How can I plot a confusion matrix for a multi-class or non-binary classification problem?

I want to make a plot similar to the confusion matrix created in the Classification Learner app. That app can produce a confusion matrix for a multi-class (non-binary) classification problem, and it can also plot quantities such as the True Positive and False Negative rates.
How can I do this?

Accepted Answer

MathWorks Support Team on 5 Jul 2017
As with the binary (two-class) problem, this can be done using the "plotconfusion" function. By default, this command also plots the True Positive, False Negative, Positive Predictive, and False Discovery rates in the grey-colored boxes. Please refer to the following example:
targetsVector = [1 2 1 1 3 2]; % True classes
outputsVector = [1 3 1 2 3 1]; % Predicted classes
% Convert this data to a [numClasses x 6] matrix
targets = zeros(3,6);
outputs = zeros(3,6);
targetsIdx = sub2ind(size(targets), targetsVector, 1:6);
outputsIdx = sub2ind(size(outputs), outputsVector, 1:6);
targets(targetsIdx) = 1;
outputs(outputsIdx) = 1;
% Plot the confusion matrix for a 3-class problem
plotconfusion(targets,outputs)
The class labels can be customized by setting the 'XTickLabel' and 'YTickLabel' properties of the axes:
h = gca;
h.XTickLabel = {'Class A','Class B','Class C',''};
h.YTickLabel = {'Class A','Class B','Class C',''};
h.YTickLabelRotation = 90;
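The rates shown in the grey boxes can also be computed numerically. The following is a minimal sketch, assuming the Statistics and Machine Learning Toolbox is available for "confusionmat"; the variable names C, tpr, fnr, ppv, fdr, and accuracy are illustrative and not part of the original answer. Note that "confusionmat" puts the true classes in rows and the predicted classes in columns.

```matlab
targetsVector = [1 2 1 1 3 2]; % True classes (same data as in the example above)
outputsVector = [1 3 1 2 3 1]; % Predicted classes

% Confusion matrix: rows are true classes, columns are predicted classes
C = confusionmat(targetsVector, outputsVector);

tpr = diag(C) ./ sum(C,2);     % True Positive rate per class (column vector)
fnr = 1 - tpr;                 % False Negative rate per class
ppv = diag(C).' ./ sum(C,1);   % Positive Predictive value per class (row vector)
fdr = 1 - ppv;                 % False Discovery rate per class
accuracy = sum(diag(C)) / sum(C(:)); % Overall fraction of correct predictions
```

If a class is never predicted (or never occurs), the corresponding division yields NaN, so real data may need a check for empty rows or columns.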
Munshida P on 13 Feb 2020
% Load the image dataset
faceDatabase = imageSet('facedatabaseatt','recursive');
% Split into training and test sets
[training,test] = partition(faceDatabase,[0.8 0.2]);
% Extract HOG features for the training set
featureCount = 1;
for i = 1:size(training,2)
    for j = 1:training(i).Count
        trainingFeatures(featureCount,:) = extractHOGFeatures(read(training(i),j));
        % imshow(read(training(i),j));
        trainingLabel{featureCount} = training(i).Description;
        featureCount = featureCount + 1;
    end
    personIndex{i} = training(i).Description;
end
% Create a 40-class classifier
faceClassifier = fitcknn(trainingFeatures,trainingLabel);
for person = 1:40
    for j = 1:test(person).Count
        queryImage = read(test(person),j);
        queryFeatures = extractHOGFeatures(queryImage);
        actualLabel = predict(faceClassifier,queryFeatures);
        % Map back to training set to find identity
        %booleanIndex = strcmp(actualLabel, personIndex);
        %integerIndex = find(booleanIndex);
        if isempty(al)==0
Sir, how can I plot the confusion matrix for the code above?
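One possible approach, as a sketch only: collect the true and predicted labels while looping over the test set, then pass both lists to "confusionmat" or "confusionchart". This assumes the Statistics and Machine Learning Toolbox and, for "confusionchart", MATLAB R2018b or later. The label values below are hypothetical stand-ins for the labels that the test loop would gather; the names trueLabels and predictedLabels are illustrative.

```matlab
% Hypothetical stand-ins for labels gathered inside the test loop, e.g.
%   trueLabels{end+1,1}      = test(person).Description;
%   predictedLabels(end+1,1) = actualLabel;  % predict returns a 1-by-1 cell here
trueLabels      = {'s1'; 's1'; 's2'; 's2'; 's3'; 's3'};
predictedLabels = {'s1'; 's2'; 's2'; 's2'; 's3'; 's1'};

% Numeric confusion matrix (rows: true class, columns: predicted class)
[C, order] = confusionmat(trueLabels, predictedLabels);

% Chart with per-class rates in the row and column summaries
cm = confusionchart(trueLabels, predictedLabels, ...
    'RowSummary', 'row-normalized', ...
    'ColumnSummary', 'column-normalized');
```

The row summary gives the True Positive and False Negative rates per class, and the column summary gives the Positive Predictive and False Discovery rates.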


More Answers (2)

David Franco on 23 Jan 2018
Edited: MathWorks Support Team on 16 Mar 2018
Implementation code:
Confusion Matrix
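function [] = confusion_matrix(T,Y)
% T: true class indices, Y: predicted class indices (integers 1..M)
T = T(:).'; % force row vectors so that each column is one sample
Y = Y(:).';
M = max([T Y]); % number of classes
N = numel(T); % number of samples
targets = zeros(M,N);
outputs = zeros(M,N);
targetsIdx = sub2ind(size(targets), T, 1:N);
outputsIdx = sub2ind(size(outputs), Y, 1:N);
targets(targetsIdx) = 1;
outputs(outputsIdx) = 1;
% Plot the confusion matrix
plotconfusion(targets,outputs)
end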
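The one-hot conversion used in this function can also be written with "ind2vec". This is a minimal sketch, assuming the Deep Learning Toolbox (which also provides "plotconfusion") is installed; the data values are hypothetical.

```matlab
T = [1 2 1 1 3 2];              % true class indices (hypothetical data)
Y = [1 3 1 2 3 1];              % predicted class indices
M = max([T Y]);                 % number of classes

% ind2vec returns a sparse M-by-N matrix with a single 1 per column
targets = full(ind2vec(T, M));
outputs = full(ind2vec(Y, M));

plotconfusion(targets, outputs)
```

Passing M explicitly keeps both matrices the same height even when the highest class index appears in only one of the two vectors.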


Fatai Anifowose on 28 Aug 2019
I am trying to use the "plotconfusion" function in my code, but it ran for a very long time and then MATLAB crashed.
What could be the reason for this? I have a high-end workstation, so I would not expect it to be a memory issue.
Thanks for your help.
Martin Jendryka on 25 Nov 2020
Transposing the input variables so that the column vectors become row vectors might help (see link below).
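A likely explanation, stated here as an assumption rather than something confirmed in this thread: "plotconfusion" expects one column per sample, so an N-by-1 vector of labels may be read as N classes with a single sample, which would produce a very large plot. The sketch below forces a row orientation and one-hot encodes the labels before plotting; the data and variable names are hypothetical, and the comparison relies on implicit expansion (R2016b or later).

```matlab
labels = [1; 2; 1; 1; 3; 2];       % hypothetical N-by-1 column vector of class indices
labelsRow = labels(:).';           % 1-by-N row vector
numClasses = max(labelsRow);

% One-hot encode into a numClasses-by-N matrix via implicit expansion
targets = double((1:numClasses).' == labelsRow);

% Guard: there should be exactly one column per sample
assert(size(targets, 2) == numel(labels), 'Expected one column per sample.');
disp(size(targets))
```

The same encoding would be applied to the predicted labels before calling "plotconfusion" with both matrices.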
