Code covered by the BSD License  

Highlights from
Computing the posterior balanced accuracy

image thumbnail

Computing the posterior balanced accuracy

by Kay H. Brodersen

 

A set of MATLAB functions for evaluating generalization performance in binary classification.

bacc_demo
% Simple demo to compare accuracies and balanced accuracies.
% 
% Usage:
%     bacc_demo
% 
% Note that, throughout the code, confusion matrices are assumed to be 2x2
% matrices of the form: ACTUAL x PREDICTED. Confusion matrices with this
% convention can be generated using the Matlab function CONFUSIONMAT.
% 
% Literature:
%     K.H. Brodersen, C.S. Ong, K.E. Stephan, J.M. Buhmann (2010).
%     The balanced accuracy and its posterior distribution. In: Proceedings
%     of the 20th International Conference on Pattern Recognition.

% Kay H. Brodersen, ETH Zurich, Switzerland
% http://people.inf.ethz.ch/bkay/
% $Id: bacc_demo.m 5522 2010-04-22 18:33:11Z bkay $
% -------------------------------------------------------------------------
function bacc_demo
    % Runs the demo on two example classification outcomes, each a 2x2
    % confusion matrix (ACTUAL x PREDICTED):
    %   - the first: 138 of 140 test cases correct, 70 per actual class;
    %   - the second: an imbalanced test set (45 actual positives,
    %     10 actual negatives).
    % Opens four figures: the matrices, the summary bars, and one full
    % posterior plot per matrix.
    confMats = {[69,1;1,69], [40,5;8,2]};
    labels   = {'C_1', 'C_2'};

    % Posterior probability intervals cover 1-alpha of the posterior mass
    alpha = 0.05;

    % Figure 1: the confusion matrices, in the outer columns of a 1x3 grid
    figure;
    slots = [1, 3];
    for k = 1:numel(confMats)
        subplot(1, 3, slots(k));
        plotMatrix(confMats{k});
        title(labels{k});
    end

    % Figure 2: point estimates with their intervals, side by side
    figure;
    for k = 1:numel(confMats)
        subplot(1, 2, k);
        plotBars(confMats{k}, alpha);
    end

    % Figures 3 and 4: complete posterior densities, one figure per matrix
    for k = 1:numel(confMats)
        figure;
        plotDistr(confMats{k}, alpha);
    end
end

% -------------------------------------------------------------------------
function plotMatrix(C)
% PLOTMATRIX Show the 2x2 confusion matrix C (ACTUAL x PREDICTED) as a
% grey-scale image in the current axes, with each count printed in its cell.
%
% A count is printed in black when its cell is brighter than the middle of
% the colour scale, and in white otherwise, so that it remains readable.
% The tick labels assume a 2x2 matrix.
    hold on;
    imagesc(C);
    axis ij;        % row 1 at the top, matching the matrix layout
    colormap gray;  % low counts dark, high counts light
    %colorbar

    % imagesc maps the axes colour limits (CLim) onto the colormap, so the
    % brightness midpoint is the middle of CLim - not max(C)/2, which is
    % wrong whenever min(C) is far from zero.
    cl = get(gca, 'CLim');
    mid = (cl(1) + cl(2)) / 2;

    for y = 1:size(C, 1)
        for x = 1:size(C, 2)
            if C(y,x) > mid
                col = [0,0,0];   % light cell -> black text
            else
                col = [1,1,1];   % dark cell -> white text
            end
            text(x, y, num2str(C(y,x)), 'HorizontalAlignment', 'center', ...
                'color', col);
        end
    end

    set(gca, 'Box', 'on');
    set(gca, 'XTick', [1 2]);
    set(gca, 'YTick', [1 2]);
    set(gca, 'XTickLabel', {'predicted +', 'predicted -'});
    set(gca, 'YTickLabel', {'actual +', 'actual -'});
    axis square
    axis tight;
end

% -------------------------------------------------------------------------
function plotBars(C, alpha)
% PLOTBARS Draw three summaries of the 2x2 confusion matrix C
% (ACTUAL x PREDICTED) into the current axes:
%   x = 0.5  accuracy mode with +/- 2 naive standard errors (dark red),
%   x = 1.0  posterior mean accuracy with its (1-alpha) posterior
%            probability interval from acc_ppi (dark blue),
%   x = 1.5  posterior mean balanced accuracy with its (1-alpha) posterior
%            probability interval from bacc_ppi (green),
% plus a grey horizontal line at chance level (0.5).
%
% The y-axis spans [0.4, 1]. If an error bar reaches below 0.4, the lower
% limit is moved down to the next multiple of 0.1 (never below 0) so the
% bar is not clipped.
    hold on;

    % Draw chance bar
    x = 1;
    line([x-1 x+1], [0.5 0.5], 'Color', [0.5 0.5 0.5], 'LineWidth', 2);

    % Accuracy mode with a naive +/- 2 SEM interval
    amode = acc_mode(C);
    sem = acc_sem(C);
    adjustErrorBarWidth(errorbar(x-0.5, amode, 2*sem, 2*sem, ...
                'LineWidth', 2, 'color', [192 0 0]/255));

    % Posterior mean accuracy with its posterior probability interval
    a = acc_mean(C);
    [a_lower,a_upper] = acc_ppi(C,alpha);
    adjustErrorBarWidth(errorbar(x, a, a-a_lower, a_upper-a, ...
                'LineWidth', 2, 'color', [37 64 97]/255));

    % Posterior mean balanced accuracy with its posterior probability interval
    b = bacc_mean(C);
    [b_lower,b_upper] = bacc_ppi(C,alpha);
    adjustErrorBarWidth(errorbar(x+0.5, b, b-b_lower, b_upper-b, ...
                'LineWidth', 2, 'color', [0 176 80]/255));

    % Finalise figure: keep the default lower limit of 0.4 unless an error
    % bar reaches further down, so no interval is silently cut off.
    lowest = min([amode - 2*sem, a_lower, b_lower]);
    ylo = max(0, min(0.4, floor(lowest*10)/10));
    v = axis;
    v(3:4) = [ylo, 1];
    axis(v);
end

% -------------------------------------------------------------------------
function plotDistr(C, alpha)
% PLOTDISTR Plot the full posterior density of the accuracy (upper panel)
% and of the balanced accuracy (lower panel) for the 2x2 confusion matrix C
% (ACTUAL x PREDICTED) into the current figure.
%
% Each panel shows:
%   - a chance line at 0.5;
%   - the density;
%   - the (1-alpha) posterior probability interval (PPI) from acc_ppi or
%     bacc_ppi, shaded grey;
%   - vertical lines at the mean (red), mode (green, dashed) and median
%     (blue, dash-dot).
% The lower panel also marks the naive balanced accuracy (black, dotted)
% and carries the legend.

    % Beta distribution parameters for the accuracy (counts + 1)
    A = C(1,1) + C(2,2) + 1;  % corrects
    B = C(1,2) + C(2,1) + 1;  % incorrects

    % Beta distribution parameters for the balanced accuracy, per actual class
    A1 = C(1,1) + 1;  % positives: corrects
    B1 = C(1,2) + 1;  % positives: incorrects
    A2 = C(2,2) + 1;  % negatives: corrects
    B2 = C(2,1) + 1;  % negatives: incorrects

    accPdf  = @(t) betapdf(t, A, B);
    baccPdf = @(t) betaavgpdf(t, A1, B1, A2, B2);

    % Posterior probability intervals
    [a_lower, a_upper] = acc_ppi(C, alpha);
    [b_lower, b_upper] = bacc_ppi(C, alpha);
    ppiLabel = [num2str(round((1-alpha)*100)), '% PPI'];

    % Grid on which the densities are evaluated
    x = 0:0.001:1;

    % Upper panel: accuracy
    subplot(2,1,1);
    plotPosteriorPanel(x, accPdf, a_lower, a_upper, ...
        acc_mean(C), acc_mode(C), acc_med(C), ppiLabel);
    title('Posterior accuracy');

    % Lower panel: balanced accuracy, plus the naive estimate and the legend
    subplot(2,1,2);
    [h, labels] = plotPosteriorPanel(x, baccPdf, b_lower, b_upper, ...
        bacc_mean(C), bacc_mode(C), bacc_med(C), ppiLabel);
    bnaive = bacc_naive(C);
    h(end+1) = plot([bnaive bnaive], [0 baccPdf(bnaive)], ':', ...
        'color', [0 0 0], 'linewidth', 2);
    labels{end+1} = 'naive';
    title('Posterior balanced accuracy');
    % Explicit handles keep legend entries matched to their lines
    legend(h, labels, 'Location', 'NorthWest');
end

% -------------------------------------------------------------------------
function [h, labels] = plotPosteriorPanel(x, pdf, lower, upper, mu, mo, md, ppiLabel)
% Draw one posterior panel into the current axes and return the graphics
% handles with matching legend labels, in this order: chance, pdf, PPI,
% mean, mode, median. The PPI entry is present only when at least one grid
% point of x lies in [lower, upper].
%
%   x        row vector of evaluation points
%   pdf      function handle returning the density at given points
%   lower, upper  PPI bounds to shade
%   mu, mo, md    posterior mean, mode and median
%   ppiLabel legend text for the shaded interval (default 'PPI')
    if nargin < 8
        ppiLabel = 'PPI';
    end
    hold on;
    y = reshape(pdf(x), 1, []);
    peak = pdf(mo);   % density at the mode: height of chance and mode lines

    % Shade the PPI first so that the curves are drawn on top of it
    inner = (lower <= x) & (x <= upper);
    hPPI = [];
    if any(inner)
        xi = x(inner);
        hPPI = fill([xi, fliplr(xi)], [y(inner), zeros(1, numel(xi))], ...
            [0.6 0.6 0.6], 'EdgeColor', 'none');
    end

    hChance = plot([0.5 0.5], [0 peak], '--', 'color', [0 0 0], 'linewidth', 2);
    hPdf    = plot(x, y, 'k');
    hMean   = plot([mu mu], [0 pdf(mu)], 'color', [192,0,0]/255, 'linewidth', 2);
    hMode   = plot([mo mo], [0 peak], '--', 'color', [0,176,80]/255, 'linewidth', 2);
    hMed    = plot([md md], [0 pdf(md)], '-.', 'color', [0,112,192]/255, 'linewidth', 2);

    h = [hChance, hPdf, hPPI, hMean, hMode, hMed];
    labels = {'chance', 'pdf', 'mean', 'mode', 'median'};
    if ~isempty(hPPI)
        labels = [labels(1:2), {ppiLabel}, labels(3:end)];
    end
end

Contact us