Code covered by the BSD License  

from Mass Spectrometry Bayesian Network Analysis Tool by Karl Kuschner
Finds diagnostic features in the spectra of biologic samples by using a Bayesian Network approach

function [IntOut IDOut PredClass Class2Vars Var2Vars MetaVars TrialErr]...
    = WMBAT (Intensities, Class, ID, MZ, Options, nfold, repeats, threshold)
% The William and Mary Bayesian Analysis Tool
%           (c) 2009 Karl Kuschner, College of William and Mary
% 
% DESCRIPTION
%       WMBAT takes an array of mass spec peak intensities, a vector
%       describing which of two classes each sample belongs to, 
%       and other information and builds and assesses a Bayesian network
%       after selecting features (peaks) from within the data array that
%       are diagnostic of the class. The primary output is an adjacency
%       matrix describing the resulting Bayesian network.
% 
% USAGE
%       [IntOut IDOut PredClass Class2Vars Var2Vars MetaVars TrialErr] = WMBAT (Intensities,
%                       Class, ID, MZ, Options, nfold, repeats, threshold)
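% 
% EXAMPLE (illustrative sketch only, not part of the original tool; the
%       random data and m/z range below are placeholders that just show
%       the expected shapes, and are not expected to yield meaningful
%       diagnostic peaks)
%       Intensities = rand(40, 25) * 100;      % 40 cases x 25 peaks
%       Class = [ones(20,1); 2*ones(20,1)];    % column vector of 1s and 2s
%       ID = (1:40)';                          % one sample ID per case
%       MZ = linspace(2000, 12000, 25);        % one m/z label per peak
%       Options = [true; true; false; false; false; false];
%       [IntOut, IDOut, PredClass, Class2Vars, Var2Vars, MetaVars, ...
%           TrialErr] = WMBAT(Intensities, Class, ID, MZ, Options, ...
%           10, 100, 1);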
% 
% INPUTS
%       Intensities: Double array of intensity values of size #cases x #variables.
%           Each row is a case (spectrum), each column a specific global 
%           m/z (mass) position. Each entry in the row is the intensity
%           value of that case's spectrum at that specific m/z position.
%       Class: Integer vector of length "#cases", values 1 or 2 identifying the 
%           class of each case, such as "disease, non-disease"
%       ID: Double one or two column array of length #cases containing the sample ID
%           for each case. Second column is optional and would identify
%           replicates of the same sample. 
%       MZ: Double vector of length "#variables" holding m/z labels for peaks
%       Options: Logical 6x1 array. Options are:
%          1. Normalize on population total ion count (sum across rows)
%          2. Clip low data values (values below 1 are set to 1)
%          3. After normalizing, before binning, average cases with same ID
%          4. NOT USED - SET TO FALSE
%          5. Take log(data) prior to binning.  Negative values set to 1.
%          6. NOT USED - SET TO FALSE 
%       nfold: the "n" in n-fold cross validation (integer 4-10). 10 is
%           recommended.
%       repeats: Integer, times to repeat the whole process (e.g.
%           re-crossvalidate). 100 is recommended.
%       threshold: Factor by which the maximum "random" MI is multiplied to
%           find the minimum "significant" MI (double, 1.0-5.0). We
%           recommend starting with 1 and increasing until a "reasonable"
%           number of diagnostic peaks is reached and error rates are
%           minimized. This setting is dependent on the data and the
%           correlations between variables.
% 
% OUTPUTS
%       IntOut: The Intensities input array, after processing by the
%           various options selected by the logical Options above.
%       IDOut: The ID number of each row in the IntOut array. With no replicate
%           averaging, each ID will be preserved (but reformatted) from the
%           input.  With replicate averaging, only the primary ID number
%           remains.
%       PredClass: The predicted class of each case, during each of the
%           "repeats" number of trials 
%       Class2Vars: A vector whose ith value is the fraction of times peak i
%           (from the vector MZ) was selected as being connected to the
%           class. The maximum times it could have been selected was
%           nfold*repeats.
%       Var2Vars: A double array whose (i,j) entry is the fraction of times a
%           second level link was found from peak i to peak j, once peak i
%           was connected to the class, as found in Class2Vars. 
%       MetaVars: A double array whose (i,j) entry is the fraction of times a
%           metavariable was created using peak i and peak j and stored in
%           the level 1 variable peak i, once peak i was found connected to
%           the class.
%       TrialErr: The error rate for each of the "repeats" possible
%           trials. Records the fraction of cases where PredClass was not
%           equal to the input Class.
% 
% CALLED FUNCTIONS
% 
%       DoTheMath: Learns a Bayesian Network from the data

%% Initialize
% Package the inputs into a data structure, as needed by DoTheMath

In.Intensities=Intensities;
In.MZ=MZ;
In.ID=ID;
In.Class=Class;
In.Options=Options;
In.n=nfold;
In.Repeats=repeats;
%       drop: MI loss percentage threshold for testing independence. Set to
%           .75 (75%) and adjust to filter too few/too many variable-to-
%           variable connections.
In.Drop=0.75;
In.Threshold=threshold;

%% Call the main function

Out=DoTheMath(In);

%% Format the results
numvars=max(size(MZ));
numtrials=nfold*repeats;
IntOut = Out.Intensities;
IDOut = Out.ID;
PredClass = Out.PredictedClass;
FirstLevel=Out.SumAdj(numvars+1,1:numvars); % class row, without the class column
Class2Vars=FirstLevel/numtrials;

Var2Vars=zeros(numvars);
MetaVars=zeros(numvars);

Chances=repmat(FirstLevel',1, numvars); % How often a variable could be linked
VtoV=Out.SumAdj(1:numvars,1:numvars); % Variable to variable connections
t=VtoV==0; % Places where there are no V to V conn
Var2Vars(~t)=VtoV(~t)./Chances(~t); % scale the V to V
MetaVars(Out.MetaVars~=0)=Out.MetaVars(Out.MetaVars~=0)./Chances(Out.MetaVars~=0);
TrialErr = Out.ErrorRate;

end

function OutputStructure = DoTheMath (InputStructure)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% DoTheMath takes a data set and performs feature selection 
% 
% DESCRIPTION
%       DoTheMath takes a data array, class vector, and other information
%       and builds and assesses a Bayesian network after selecting features
%       from within the data array.  It is called from the user interface
%       "orca.m." 
%
%       This is the umbrella script that loops a specified number of times
%       (see "repeats" below), each time doing a full n-fold cross
%       validation and recording the results.  All input and output data
%       are stored in a single data structure, described below.
% 
% USAGE
%       OutputDataStructure = DoTheMath (InputStructure)
% 
% INPUTS
%       InputStructure: Data repository with fields: 
%       Intensities: Array of intensity values of size #cases x #variables
%       Class: Vector of length "#cases", with discrete values identifying
%          class of each case (may be integer)
%       ID: Patient ID array of length #cases, with one or more cols
%       MZ: Vector of length "#variables" holding labels for variables
%       Options: Logical 6x1 array. Options are:
%          1. Normalize on population total ion count (sum across rows)
%          2. Clip low data values (values below 1 are set to 1)
%          3. After normalizing, before binning, average cases with same ID
%          4. Find the MI threshold by randomization
%          5. Take log(data) prior to binning.  Negative values set to 1.
%          6. Remove Low Signal cases
%              NOT DONE: 3 Bin (2 Bin if False)
%       n: the "n" in n-fold cross validation
%       repeats: Times to repeat the whole process (e.g. re-crossvalidate)
%       threshold: Factor by which the maximum "random" MI is multiplied to
%           find the minimum "significant" MI (double, 1.0-5.0).
% 
% OUTPUTS
%       OutputDataStructure: all the fields of InputStructure, plus: 
%       ErrorRate: Vector containing misclassification rate for each repeat
%       KeyFeatures: Index to vector MZ that identifies features selected
% 
% CALLED FUNCTIONS
% 
%       InitialProcessing: Applies the options listed above
%       BuildBayesNet: Learns a Bayesian Network from the training data
%       ChooseMetaVars: Combines variables that may not be physically
%           separate molecules.
%       TestCases: Given the BayesNet, tests the "test group" to determine
%           the probability of being in each class.
%       opt3bin: Discretizes continuous data into 3 bins, optimizing MI
%       FindProbTables: Learns the values P(C,V) for each variable
%       cvpartition and training are MATLAB Statistics toolbox functions.


%% Initialize

tic % start a timer
                
%% Initial Processing
% According to options, remove negative values, normalize and/or take
% logarithm of data, replicate average. Store in output data structure.

%  display('Starting Initial Processing of Data');
OutputStructure = InitialProcessing( InputStructure);
display('Initial processing complete.')
display (' ');


% Get values out of Data structure to be used later
drop=InputStructure.Drop; % MI loss percentage threshold for testing 
                % independence, see clipclassconnections
ff=OutputStructure.Threshold;
n=double(OutputStructure.n); % for n-fold cross validation; default is 10
repeats=OutputStructure.Repeats; % Number of times to repeat CV, default 30
numtrials=repeats*n;
cverrorrate=zeros(numtrials,1);
errorrate=zeros(repeats,1);
data=OutputStructure.Intensities;
class=OutputStructure.Class;

% Find some sizes and initialize variables
[rows cols]=size(data);
% OutputStructure.varlist=zeros(cols,1);
class_predict=zeros(rows,repeats);
class_prob=zeros(rows,repeats);
trial=0; % counter of how many times we perform Bayes Analysis (n*repeats)

%% "Repeat Entire Process" Loop

% Repeat all processes the number of times requested
for r=1:repeats
    display (' '); 
    display(['Working on repetition number ', num2str(r),' at ',...
        num2str(toc/60),' mins']);
     
    %% Cross Validation Loop
    % This section selects a training and testing group out of the data by
    % dividing it into n groups, and using n-1 of those for training and 1
    % for testing. MATLAB (ver. 2008a or later) has a built in class for
    % this. See MATLAB documentation for "cvpartition" and "training."  
    cvgroups = cvpartition ( class, 'kfold', n ); 
    
    for cv = 1:n % for each of n test groups, together spanning all cases
        trial=trial+1; % Keep track of each trial
        display(['     Working on cross-validation number ',num2str(cv),...
            ' of ',num2str(n)])
        
        % The next line uses a function inside "cvpartition" called
        % "training" that returns a logical vector identifying which cases
        % to use as the training group in cross validation.
        traingrpindex=training(cvgroups,cv);
        
        % Use the index to extract the data and class of the training cases
        traingrp=data(traingrpindex,:);  
        traingrpclass=class(traingrpindex,:);
        
        % The test cases are cases NOT in the training group
        testgrp=data(~traingrpindex,:); 
        testgrpclass= class(~traingrpindex,:);
        
        %% Discretize the groups into hi-med-low
        % by optimizing MI(V,C) for each V (feature) in the training data.
        
        [leftbndry,rightbndry,traingrpbin, maxMI]=opt3bin(traingrp,...
            traingrpclass); %#ok<NASGU>
 
        %% Build an augmented Naive Bayesian Network with the training data
        % The adjacency matrix is a logical with true values meaning "there
        % is an arc from row index to column index." The last row
        % represents the class variable.
        
        adjmat = BuildBayesNet( traingrpbin, traingrpclass, ff, drop ); %adjacency matrix

        
        %% Find MetaVariables, rebuild data
        % Depending on the option set, reduce the V->V links by removing
        % them, or combining them into a single variable. The result is a
        % naive Bayesian network with only connections C->V.
 
        meta_option=1; % Hard coded for now
        classrow=cols+1;
        listvec=1:cols; % just a list of numbers
        varlist=unique(listvec(adjmat(classrow,:))); % top level vars
        
        if meta_option==1
            [finaldata metas leftbndry rightbndry] = ...
                ChooseMetaVars (traingrp, traingrpclass, adjmat);
        end
        
        % Bin up the test group using these final results, combining
        % variables per the instructions encoded in the "metas" logical
        % matrix.
        
        testdata=zeros(size(testgrp));
        
        if isempty(varlist) % in case no links are found
            disp ('Not finding any links yet...');
            cverrorrate(trial) = 1; % per-trial vector; errorrate is per repeat
        else % if we do find links
            for var = varlist; % each of the parents of metavariables
                metavar=[var listvec(metas(var,:))]; % concatenate children
                testdata(:,var)=sum(testgrp(:,metavar),2); % sum parent/child
            end
            
            
            % Now keep only the columns of the top level variables
            finaltestdata=testdata(:,varlist);
            
            % And bin the result
            testgrpbin=zeros(size(finaltestdata)); %will be stored here
            % Build boundary arrays to test against
            testcases=size(testgrp,1);
            lb=repmat(leftbndry,testcases,1);
            rb=repmat(rightbndry,testcases,1);
            %  test each value and record the bin
            testgrpbin(finaltestdata<lb)=1;
            testgrpbin(finaltestdata>=lb)=2;
            testgrpbin(finaltestdata>rb)=3;
            
            %% Populate Bayesian Network
            
            % With the final set of data and the adjacency matrix, build the
            % probability tables and test each of the test group cases, to see
            % if we can determine the class.
            
            % Build the probability tables empirically with the training group
            % results
            ptable=FindProbTables(finaldata, traingrpclass);
            % Class priors are estimated from the training group only
            prior=histc(traingrpclass, unique(traingrpclass))/numel(traingrpclass);

            % find out the probability of each case being in class 1,2,etc.
            % Cases are in rows, class in columns.
            classprobtable = TestCases (ptable, prior, testgrpbin);
            [P_C predclass]=max(classprobtable,[],2);
            class_prob(~traingrpindex,r)=P_C;
            class_predict(~traingrpindex,r)=predclass;
            
            %Get the per trial error rate (fraction of test cases misclassified)
            cverrorrate(trial)= sum(predclass~=testgrpclass)/testcases;
            
            %Store some "per trial" data
            OutputStructure.Adjacency(trial,:,:)=adjmat;
            OutputStructure.MetaVariablesFound(trial,1:cols,1:cols)=metas;
            ProbTables(trial).TrialTable=ptable; %#ok<AGROW>
            
        end % of finding metavariables
        
    end % of Cross Validation loop
   
    wrong=sum(~(class==class_predict(:,r)));
    errorrate(r)=wrong/rows;
    
end % of repeating entire process loop

% Record the results in the output structure
OutputStructure.ErrorRate=errorrate; % one for each repeat
OutputStructure.CvErrorRate=cverrorrate; % one for each of n*repeats trials
OutputStructure.PredictedClass=class_predict;

% Find out how often each case was predicted correctly
classrep=repmat(class,1,repeats);
WasIright=classrep==OutputStructure.PredictedClass;
OutputStructure.CasePredictionRate=sum(WasIright, 2)/double(repeats);

OutputStructure.ClassProbability=class_prob;
OutputStructure.ProbTables=ProbTables;
OutputStructure.SumAdj=squeeze(sum(OutputStructure.Adjacency,1));
OutputStructure.MetaVars=squeeze(sum(OutputStructure.MetaVariablesFound,1));
% Save the results as a .mat data file and alert the user.
save results -struct  OutputStructure
disp('Additional Results are saved in the current directory as results.mat')

end % of the function DoTheMath

function StructOut = InitialProcessing( StructIn)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% INITIALPROCESSING Initial Prep of Data from Signal Processing 
% 
% DESCRIPTION
%       Takes peaklists that have been imported into MATLAB and prepares 
%       them for Bayesian Analysis.
% 
% USAGE
%       StructOut = InitialProcessing( StructIn)
% 
% INPUTS
%       Structure with the following double-typed arrays
%       Intensities: n x m real-valued array with variables (peaks) in
%           columns, cases (samples) in rows.
%       MZ: List of the labels (m/z value) for each of the variables.
%           Must be the same size as the number of variables in Intensities
%       Class: Classification of each sample (disease state)-- 1 or 2--must
%       be the same size as the number of cases in Intensities
%       ID: Case or patient ID number, same size as class.  May have second
%           column, so each row is [ID1 ID2] where ID2 is replicate number.
%       Options (logical):  Array of processing options with elements:
%           1. Normalize
%           2. Clip Data (remove negatives)
%           3. Replicate Average
%           4. Auto threshold MI
%           5. Use Log of Data
%           6. Remove Low Signal cases
%           NOT DONE: 3 Bin (2 Bin if False)
% 
% OUTPUTS
%
%       DataStructure: MATLAB data structure with the following components:
%           RawData: Intensities as input
%           ClipData: RawData where all values less than 1 are set to 1
%           NormData: ClipData normalized by total ion count, i.e.
%               divided by the sum of all variables for each case
%           LogData: Natural logarithm of NormData
%           Class, MZ: Same as input
%           ID: Single column. If replicates are not averaged, the entries
%               are now ID1.ID2. If replicates averaged, then just ID1 
%           DeltaMZ: difference in peak m/z values to look for adducts
%           RatioMZ: ratios of m/z values to look for satellites
% 
% CALLED FUNCTIONS
% 
%       None. (cdfplot is MATLAB "stat" toolbox)


%% Initialize  Data
%  find the size, create the output structure,and transfer info

[rows cols]=size (StructIn.Intensities);
StructOut = StructIn;
StructOut.RawData = StructIn.Intensities;

%% Option 2: Clip Negatives from data
%  set values below 1 to be 1 because negative
%   molecule counts are not physically reasonable
% 1 is chosen rather than 0 in case log(data) is used
% Note: the decision to do this before normalization was based on
% discussions with Dr. William Cooke, who created the data set.

if StructOut.Options(2)
    StructOut.Intensities(StructOut.Intensities<1)=1; % logical indexing
end

%%  Option 6: Removal of Cases with Low Signal
%   find the sum of all values for each row, then normalize each row to
%   account for the effects of signal strength over time and other
%   instrumental variations in total strength of the signal

% Find the total ion count for each case, then the global average.
% Determine a correction factor for each case (NormFactor)
if StructOut.Options(1) ||  StructOut.Options(6)
    RowTotalIonCount=sum(StructOut.Intensities, 2);
    AvgTotalIonCount=mean(RowTotalIonCount); %Population average
    NormFactor=AvgTotalIonCount./RowTotalIonCount; %Vector of norm factors
    StructOut.NormFactor=NormFactor;  %save this in the structure
end
% If Remove Low Signal is desired, interact with user to determine
% threshold, then remove all cases that are below the threshold.

if StructOut.Options(6)
    figure(999);
    cdfplot(NormFactor);
    title('Cumulative Distribution of Normalization Factors');
    
    % Request cutoff
    
    text(1.3,0.5,['Click on the graph where you want';...
          'the normalization threshold      ';...
          'Cases with high norm factor (or  ';...
          'low signal) will be removed.     ']);
    [NormThreshold, Fraction] = ginput(1);
    display([num2str(floor((1-Fraction)*100)),'% of cases removed']);
    close(999);
    TossMe=find (NormFactor>NormThreshold); %Low signal cases
    
    % Now record, then remove, those cases with low signal
    
    StructOut.LowSignalRemovedCases=StructOut.ID(TossMe,:);
    StructOut.LowSignalRemovedCasesNormFactors=NormFactor(TossMe);
    StructOut.Intensities(TossMe,:)=[];
    StructOut.ID(TossMe,:)=[];
    StructOut.Class(TossMe,:)=[];
    
end


%% Option 3: Replicate Average
% This option causes cases with same ID numbers to be averaged, peak by
% peak.

if StructOut.Options(3) %Replicate Average
    % Collapse to unique IDs only, throw out replicate ID column
    StructOut.Replicate_ID=StructOut.ID; %Save old data
    StructOut.Replicate_Class=StructOut.Class;
    
    newID=unique(StructOut.ID(:,1)); % List of unique IDs
    num=size(newID,1); %how many are there?
    newClass=zeros(num,1); % Holders for extracted class, data
    newData=zeros(num,cols);
    for i=1:num % for each unique ID
        id=newID(i); % work on this one
        cases=find(StructOut.ID(:,1)==id); % Get a list of cases with this ID
        newClass(i)=StructOut.Class(cases(1)); % save their class
        casedata=StructOut.Intensities(cases, :); % get their data
        newData(i,:)=mean(casedata, 1); % and save the average
    end
    StructOut.Intensities=newData;
    StructOut.Class=newClass;
    StructOut.ID=newID;
    clear newID newClass newData
else % If replicates exist, combine the 2 column ID into a single ID
    ID= StructOut.ID;
    if min(size(ID))==2
        shortID=ID(:,1)+(ID(:,2)*.001); % Now single entry is ID1.ID2
        StructOut.OldID=StructOut.ID;
        StructOut.ID=shortID;
        clear ID shortID
    end
    
end

%% Option 1: Normalize total ion count
% Apply the normalization factor to each row to normalize total ion count.
% We'll recalc norm factors in case data was replicate averaged.
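% Worked example (illustrative numbers, not from the data set): for two
% cases with intensities [1 1] and [3 3], the row totals are 2 and 6, the
% population average is 4, and the norm factors are 4/2 = 2 and 4/6 = 2/3.
% After scaling, both rows become [2 2], so every case ends up with the
% same total ion count (4).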
if StructOut.Options(1)
    RowTotalIonCount=sum(StructOut.Intensities, 2);
    AvgTotalIonCount=mean(RowTotalIonCount); %Population average
    NormFactor=AvgTotalIonCount./RowTotalIonCount; %Vector of norm factors
    StructOut.NormFactor=NormFactor;  %save this in the structure
    NFmat=repmat(NormFactor, 1, cols); % match size of Intensities
    StructOut.Intensities=StructOut.Intensities.*NFmat;
    clear NFmat RowTotalIonCount AvgTotalIonCount NormFactor;
end


%%  Option 5: Work with log (data)

if StructOut.Options(5)
    StructOut.Intensities=log(StructOut.Intensities);
end


%% end function

end %of the function InitialProcessing

function adjacency = BuildBayesNet( data, class, ffactor, drop )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% BuildBayesNet selects features and metafeatures based on mutual info.
%  
% 
% DESCRIPTION
%       This function takes a set of training data and an additional
%       variable called "class" and tries to learn a Bayesian Network
%       Structure by examining Mutual Information.  The class variable C is
%       assumed to be the ancestor of all other variables V.  Arcs from C
%       to V are declared if MI(C;V)>>z, where z is a maximum expected MI
%       of similar, but random data...multiplied by a "fudge factor."  Arcs
%       from Vi to Vj are similarly declared. Then various tests are
%       performed to prune the network structure and combine variables that
%       exhibit high correlations. Finally the network is pruned to be a
%       Naive Bayesian Classifier, with only C->V arcs remaining.
% 
% USAGE
%       adjacency = BuildBayesNet( training_data, class, ffactor, drop )
% 
% INPUTS
%       training_data: cases in rows, variables in cols, integer array
%               containing the data used to build the Bayes net
%       class: the known class variable for each case (1:c col vector)
%       ffactor: multiple of auto MI to use to threshold C->V connections
%       drop: MI loss percentage threshold for testing independence (see
%               clipclassconnections), e.g. 0.75
% 
% OUTPUTS
%
%       adjmatrix: a matrix of zeros and ones, where one in row i, column j
%               denotes a directed link in a Bayesian network between 
%               variable i and variable j. The class variable is the last
%               row/column.
% 
% CALLED FUNCTIONS
% 
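%       automi: finds an MI threshold based on data
%       findmutualinfos: finds all values MI(V;C), MI(V;V) and MI(V;C|V)
%       getarcs: declares C->V and V->V arcs whose MI exceeds the thresholds
%       clearirrarcs: clears arcs of variables with no path to the class
%       clipclassconnections: disconnects a variable from the class when
%               MI(Vi;C|Vj)<<MI(Vi;C)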

%% Initialize

% Initialize the network object and some constants
network.data=data;
network.class=class;

automireps=10; %times to repeat the auto MI thresholding to find avg.

% Check the sizes of various things
[rows cols]=size(data); %#ok<NASGU>
cases=max(size(class));
if rows~=cases
    % Raise an error rather than returning with the output unassigned
    error('BuildBayesNet:sizeMismatch', ...
        '# of rows in the data and class must be equal.')
end
clear cases

% network.adjmat=zeros(cols+1); % all variables plus class as last row/col
dataalphabet=max(size(unique(data))); % number of possible values of data
classalphabet=max(size(unique(class))); % Number of values of class

%% Step 0: Find all the necessary mutual information values, thresholds
% The function below finds all values MI(V;C|V) and other combos needed and
% stores them in the network structure.

[ network.mi_vc, network.mi_vv, network.mi_vc_v ]...
                                = findmutualinfos( data, class );
                            
% Find a threshold MI by examining MI under randomization
% ******************************
% Come back to the next line
% ****************************
network.vcthreshold = automi( data, class, automireps )*ffactor ; %scalar MI threshold, 10 repetitions
network.vvthreshold = network.vcthreshold * log(dataalphabet)/log(classalphabet);
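% Worked example (assuming 3-bin data and a 2-class problem): the V->V
% threshold is the C->V threshold scaled by log(3)/log(2), about 1.585.
% This presumably compensates for the larger MI that two 3-valued
% variables can share (at most log(3)) compared with a variable and a
% binary class (at most log(2)).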


%% Step 1: Find all the possible arcs.
% Find the variables with high MI with the class, i.e. MI(V,C)>>0 and
% connect a link in the adjacency matrix C->V.  Also connect variable Vi,Vj
% if MI(Vi;Vj)>>0

network.adjmat1=getarcs(network.mi_vc, network.vcthreshold, network.mi_vv,...
                network.vvthreshold);

%% Step 2: Prune the variable set by clearing irrelevant features
% If there is no path from V to the class, clear all entries V<->Vi (all i) 
network.adjmat2 = clearirrarcs( network.adjmat1 );

%% Step 3: Cut connections to class
% Where two variables are connected to each other and also to the class,
% attempt to select one as the child of the other and disconnect it from
% the class. Use MI(Vi;C|Vj)<<MI(Vi;C) as a test.

temp = clipclassconnections (network.adjmat2, ...
    network.mi_vc,network.mi_vc_v, drop);

% and once again clear features no longer near class and end function
adjacency= clearirrarcs( temp ); 

end % of the function BuildBayesNet

function [ mi_vc, mi_vv, mi_vc_v ] = findmutualinfos( data, class )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% FINDMUTUALINFOS finds the various mutual info combos among variables.
% 
% DESCRIPTION
%       Given a set of data (many cases, each with values for many
%       variables) and an additional value stored in the vector class,
%       finds the MI values described below in "OUTPUTS."
% 
% USAGE
%       [mi_vc, mi_vv, mi_vc_v] = findmutualinfos( data, class )
% 
% NOTE: THIS VERSION USES SOMEONE ELSE'S CODE, BUT IT'S 10x FASTER THAN MINE.
% 
% INPUTS:
% 
% data: A number of cases (in rows), each with a measurement for a group of
%    variables (in columns). The data should be discretized into integers 1
%    through k. The columns are considered variables V1, V2, ...
% class: an additional measurement of class C. A column vector of length 
%    "cases" with integer values 1,2...
% 
% OUTPUTS (all type double, >=0)
% 
% mi_vc: a row vector whose ith value is MI(Vi,C).
% mi_vv: Symmetric matrix with values MI(Vi,Vj).
% mi_vc_v: Non-sym matrix with values MI(Vi;C|Vj).
% 
% CALLED FUNCTIONS
% 
% mutualinfo and condmutualinfo are from the mutualinfo package (c) 2002 by
% Hanchuan Peng, <penghanchuan@yahoo.com>.

% Calculate the value MI(Vi,C)
[rows cols]=size(data);
mi_vc = zeros(1,cols);

for v=1:cols
    mi_vc(v)=mutualinfo(data(:,v),class); % Fast using Peng's DLLs
end

% For each variable Vj, calculate MI(Vi,Vj) and MI(Vi;C|Vj)

mi_vv=zeros(cols,cols);
mi_vc_v=zeros(cols,cols);
        
for i=1:cols
    for j=1:cols
        mi_vv(i,j)=mutualinfo(data(:,i),data(:,j));
        mi_vc_v(i,j)=condmutualinfo(data(:,i),class,data(:,j));
    end
end

end %of the function findmutualinfos

function [l, r, binned, mi] = opt3bin (data, class)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% OPT3BIN finds the two bin boundaries (3 bins) that optimize MI with class
% 
% DESCRIPTION
%       This function takes an array of continuous sample data of size
%       cases (rows) by variables (columns), along with a class vector of
%       integers 1:c, each integer specifying the class. The class vector 
%       has the same number of cases as the data.  The function outputs the
%       position of the 2 bin boundaries (3 bins) that optimize the mutual
%       information of each variable's data vector with the class vector.    
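%
%       Example (illustrative values only): with l = 1 and r = 3 for a
%       variable, the values 0.2, 1.5 and 3.7 fall in bins 1, 2 and 3
%       respectively. This follows the rule used in looklr, where a value
%       strictly greater than a boundary moves to the next bin, i.e.
%       bin = 1 + (x > l) + (x > r).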
% 
% USAGE
%       [l,r,binned, mi]=opt3bin(data,class)
% 
% INPUTS
%       data: double array of continuous values, cases in rows and 
%           variables in columns. Distribution is unknown.
%       class: double column vector, values 1:c representing classification
%           of each case. 
% 
% OUTPUTS
%
%       l     - row vector of left boundary position for each var.
%       r     - row vector of right boundary position for each var.
%       binned- data array discretized using boundaries in l and r
%       mi    - row vector of mutual info between each discr. variable 
%                  and class 
% 
% CALLED FUNCTIONS
% 
%       opt2bin: Similar function that finds a single boundary. This is
%           used as a seed for the 3 bin optimization.
%       looklr: See below.


%% Initialize
% 
%  Variable Prep : find sizes of arrays and create placeholders for locals

steps=150;
[rows cols]=size(data);
boundary=zeros(2,cols);

%% Method
% Find starting point by finding the maximum value of a 2 bin mi. Next, go
% left and right from that position, finding the position of the
% next boundary that maximizes MI.

[mi boundary(1,:)] = opt2bin (data, class, steps, 2);

% We've located a good starting (center) bin boundary.  Search L/R for a
% second boundary to do a 3 bin discretization.
[mi boundary(2,:)] = looklr (data, class, boundary(1,:), steps);

% We've now found the optimum SECOND boundary position given the best 2 bin
% center boundary.  Now re-search using that SECOND boundary position,
% dropping the original (2 bin).  The result should be at, or near, the
% optimal 3 bin position.
[mi boundary(1,:) binned] = looklr (data, class, boundary(2,:), steps);

% from the two boundaries found above, sort the left and right
r=max(boundary);
l=min(boundary);

% Now return the vector of left and right boundaries, the disc. data, and
% max MI found.
end % of function

function [miout nextboundary binned] = looklr (data, class, startbd, steps)
% given a start position, finds another boundary (to create 3 bins) that
% maximizes MI with the class
[rows cols]=size(data);
farleft=min(data,[],1);
farright=max(data,[],1);
miout=zeros(1,cols);
binned=zeros(rows,cols);
nextboundary=zeros(1,cols);

for peak=1:cols % for each peak/variable separately...

    % discretize this variable's values. Sweep through the possible
    % bin boundaries from the startbd to the furthest value of the
    % data, creating 2 boundaries for 3 bins. Record the binned values in
    % a "cases x steps" array, where "steps" is the granularity of the
    % sweep. The data vector starts off as a column...

    testmat=repmat(data(:,peak),1,steps); % and is replicated to an array.
    
    % Create same size array of bin boundaries. Each row is the same.
    checkptsL=repmat(linspace(farleft(peak),startbd(peak),steps),rows,1); 
    checkptsR=repmat(linspace(startbd(peak),farright(peak),steps),rows,1);
    
    % Create a place to hold the discrete info, starting with all ones. The
    % "left" array will represent data binned holding the center boundary
    % fixed and sweeping out a second boundary to the left; similarly the
    % right boundary starts at "startbd" and sweeps higher.
    binarrayL=ones(rows,steps); 
    binarrayR=ones(rows,steps); 
    
    % Those in the L test array that are higher than the left boundary -> 2
    binarrayL(testmat>checkptsL)=2;
    binarrayL(testmat>startbd(peak))=3; % >center boundary -> 3
    
    % Similarly using center and right boundaries
    binarrayR(testmat>startbd(peak))=2;
    binarrayR(testmat>checkptsR)=3;

    % Now at each of those step positions, check MI (var;class).
    miout(peak) = 0;
    newboundary = startbd(peak); % fallback so it is defined if no MI > 0
    for j=1:steps
        miL = mutualinfo(binarrayL(:,j),class);% MI(V;C) using left/center
        miR = mutualinfo(binarrayR(:,j),class);% MI(V;C) using center/right
        if miL>miout(peak) % check if that step's MI is the highest yet
            miout(peak)=miL; % if so, record it
            newboundary=checkptsL(1,j); % and record the boundary
            binned(:,peak)=binarrayL(:,j); % and record the discrete data
        end
        if miR>miout(peak) % and check the center/right combo similarly
            miout(peak)=miR;
            newboundary=checkptsR(1,j);
            binned(:,peak)=binarrayR(:,j);
        end
            
    end % checking each possible boundary position.
    
    % we should now know the best possible place to put a second boundary,
    % either left or right of a given starting position. Record that.
    nextboundary(peak)=newboundary;
    
end % of that variable's search.  Go to next variable.

end % of function looklr. Return the best boundary and associated MI and data

function  [finaldata metamatrix leftbound rightbound] =...
    ChooseMetaVars ( data, class, adj)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% ChooseMetaVars attempts to combine variables into better variables 
% 
% DESCRIPTION
%       Finds the V-V pairs in the adjacency matrix, and attempts
%       to combine them into a metavariable with a higher mutual
%       information than either variable alone. If it is possible to do
%       this, it returns a new data matrix with the variables combined. 
% 
% USAGE
%       [finaldata metamatrix leftbound rightbound] =
%                        ChooseMetaVars ( data, class, adj)
% 
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows 
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%       adj: Adjacency matrix, #variables+1 by #variables. Last row is
%           class node. Logical meaning "there is an arc from i to j."
% 
% OUTPUTS
%       metamatrix: logical whose (i,j) means "variable j was combined into
%           variable i (and erased)"
%       finaldata: The data matrix with the variable combined and rebinned
%       leftbound: The new left boundary (vector) for binning.
%       rightbound: The new right boundary (vector) for binning.    
% 
% CALLED FUNCTIONS
%       opt3bin: rebins combined variables to determine highest MI.

%% Initialize
[rows cols]=size(data);
[classrow numvars]=size(adj);
bindata=zeros(rows,cols);
metamatrix=false(cols);

% Create a list of all the variables V to check by examining the adjacency
% matrix's last row, i.e. those with C->V connections
listvec=1:numvars;
varstocheck=unique(listvec(adj(classrow,:))); 
l=zeros(1,numvars);
r=zeros(1,numvars);

% Now go through that list, testing each V->W connection to see if adding V
% and W creates a new variable Z that has a higher MI with the class than V
% alone.  V is the list above, W is the list of variables connected to a V.

for v=varstocheck % Pull out the W variables connected to V and test
    wlist=unique(listvec(adj(v,:)));
    [l(v), r(v), binned, mitobeat] = opt3bin(data(:,v), class);
    bindata(:,v)=binned;
    if ~isempty(wlist)
        for w=wlist
            newdata=data(:,v)+data(:,w);
            [left, right, binned, newmi] = opt3bin(newdata, class);
            if newmi>mitobeat
                mitobeat=newmi;
                data(:,v)=data(:,v)+data(:,w);
                metamatrix(v,w)=true; % record the combination
                bindata(:,v)=binned; 
                l(v)=left;
                r(v)=right;
            end
        end
    end
    
                
                % ********************************************************
                %              Too simple - should check all combos
                %    May be that a later variable is better
                % ***************************************************
end

% Pull out just the C->V (class-connected) columns from the data matrix.
finaldata=bindata(:,adj(classrow,:));
leftbound=l(adj(classrow,:));
rightbound=r(adj(classrow,:));
end %of function ChooseMetaVars

function adjout = clearirrarcs( adjin )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% CLEARIRRARCS clears arcs that are not C->V or C->V<->V 
% 
% DESCRIPTION
%       Given an adjacency matrix with V<->V arcs in a square matrix and an
%       additional row representing C->V (class to variable), this function
%       clears out all V1->V2 arcs where V1 is not a member of the set of
%       V's that are class-connected, i.e. have arcs in the final row.
% 
% USAGE
%       adjout = clearirrarcs( adjin )
% 
% INPUTS
%       adjin: a logical array where a true value at position (i,j) means
%           that there is an arc in a directed acyclic graph between
%           (variable) i and variable j.
% 
% OUTPUTS
%       adjout: copy of adjin with unneeded arcs cleared
% 
% CALLED FUNCTIONS
%       None.
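%
% EXAMPLE
%       An illustrative sketch, not part of the original documentation; the
%       adjacency values are made up. Three variables, with only V1
%       connected to the class (last row):
%
%           adjin  = logical([0 1 1; 1 0 1; 1 1 0; 1 0 0]);
%           adjout = clearirrarcs(adjin);
%
%       Tracing the code by hand, adjout should be
%           [0 1 1; 0 0 0; 0 0 0; 1 0 0]
%       i.e. V2<->V3 is dropped (neither is class-connected) and only the
%       arcs leaving V1 survive, giving C->V1->V2 and C->V1->V3.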

%% Initialize
% Find the sizes of the input
[classrow, numvars]=size(adjin);

%% Main processing
% Find out which variables are connected to class
conntocls=(adjin(classrow,:));

% Remove all arcs that don't have at least one variable in this list,
% e.g. all Vi<->Vj such that ~(C->Vi or C->Vj). These are all the entries
% in the adjacency matrix whose i and j are NOT in the list above.

% Make a matrix with ones where neither variable is in the list above
noconnmat=repmat(~conntocls,numvars,1) & repmat(~conntocls',1,numvars);

% Use that to erase all the irrelevant entries in the square adj matrix, at
% the same time remove the diagonal (arcs Vi<->Vi)
adjout=adjin (1:numvars, 1:numvars)& ~noconnmat & ~eye(numvars);

% Bidirectional arcs are temporarily permitted between nodes connected
% directly to the class, but not between nodes where only one is connected
% to the class- those are assumed to flow C->V1->V2 only.  Remove V2->V1.

% Get a matrix of ones in rows that are class connected. V->V arcs are only
% allowed to be in these rows:
parents=repmat(conntocls',1,numvars);
% Remove anything else
adjout=adjout & parents;

% Now add back in the class row at the bottom of the square matrix
adjout(classrow,:)=adjin(classrow,:);

end % of function clearirrarcs

function adjout = clipclassconnections( adj, mivc_vec, mivcv, dropthreshold )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% clipclassconnections delinks variables from class 
% 
% DESCRIPTION
%        Where two variables are connected to each other and also 
%        to the class, attempt to select one as the child of the other and
%        disconnect it from the class. Use MI(Vi;C|Vj)<<MI(Vi;C) as a test.
% 
% USAGE
%       adjout = clipclassconnections( adj, mivc_vec, mivcv, dropthreshold )
% 
% INPUTS
%       adj: (logical) matrix where "true" entries at (i,j) mean "an arc
%            exists from the Bayesian network node Vi to Vj." The class 
%            variable C is added at row (number of V's + 1). "0" values
%            mean no arc.
%       mivc_vec: (double) row vector containing MI(C;Vi) for each variable
%       mivcv: (double) array whose (i,j) entry is MI(Vi;C|Vj).
%       dropthreshold: fractional drop (e.g. 0.7) from MI(Vj;C) to
%           MI(Vj;C|Vi) required before declaring that Vi is between C
%           and Vj.
% 
% OUTPUTS
%
%       adjout: copy of adj with the appropriate arcs removed.
% 
% CALLED FUNCTIONS
% 
%       None.
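%
% EXAMPLE
%       An illustrative sketch, not part of the original documentation; all
%       MI values are made up. Two variables, both linked to the class and
%       to each other:
%
%           adj      = logical([0 1; 1 0; 1 1]);
%           mivc_vec = [0.5 0.4];          % MI(V1;C), MI(V2;C)
%           mivcv    = [0 0.45; 0.05 0];   % (i,j) entry is MI(Vi;C|Vj)
%           adjout   = clipclassconnections(adj, mivc_vec, mivcv, 0.7);
%
%       Tracing the code by hand: MI(V1;C|V2) > MI(V2;C|V1), so the arc is
%       oriented V1->V2. MI(V2;C) falls from 0.4 to 0.05 given V1, a drop
%       of 0.875 > 0.7, so C->V2 is removed. adjout should be
%           [0 1; 0 0; 1 0]
%       which is the chain C->V1->V2.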


%% Initialize

[classrow, numvars]=size(adj);
classconnect=adj(classrow, :); % the last row of adj stores arcs C->V
adjout=false(classrow, numvars); % placeholder for output array

%% Identify triply connected arcs

% First look for pairs that are connected to each other and connected to
% the class. 

% Connected to each other: build logical array with (i,j) true if Vi<->Vj
vv_conn=adj(1:numvars, 1:numvars);

% Connected to the class: logical array with (i,j) true if C->Vi and C->Vj
vcv_conn=repmat(classconnect, numvars,1) & repmat(classconnect',1,numvars);

% Find all (i,j) with both true
triple_conn = vv_conn & vcv_conn;

%% Determine preferred direction on V<->V arcs

% Determine the Vi<->Vj direction by finding the greater of MI(C;Vi|Vj) or
% MI(C;Vj|Vi). The greater MI means instantiating the other variable has
% less effect, so the arc is kept pointing away from that variable.
arcdirection=mivcv > mivcv'; %Only the larger survive
dag_triple_conn=arcdirection & triple_conn; % Wipes out the smaller ->

% Find links that should NOT be kept under the test above,
linkstoremove=(~arcdirection) & triple_conn;
% and if they are in the connection list, remove them
adjout(1:numvars, 1:numvars)=xor(vv_conn,linkstoremove);

% Now we need to test whether we can remove the link between C and which
% ever V (i or j) is the child of the other. We look for a "significant"
% drop in MI(Vj;C) when instantiating Vi, e.g. MI(Vj;C|Vi)<<MI(Vj;C).
%
% dropthreshold of .7, for example, means link breaks if 1st term is less
% than 30% of the second term.
%
% If there is a big drop in MI(C;Vj) when Vi is given, and Vi->Vj exists in
% the DAG, then we can remove the link C->Vj and leave C->Vi->Vj.

% Build an array out of the mivc_vec vector
mivc=repmat(mivc_vec',1,numvars);
% Test for the large drop described above
bigdrop=((mivc-mivcv)./mivc) > dropthreshold;
% Test for the big drop and the V-V connection
breakconn = bigdrop' & dag_triple_conn;
% If any of the elements in a column of the result are true, remove that
% variable's C->V link, since it is a child.
linkstokeep=~any(breakconn);
adjout(classrow,:)= adj(classrow,:) & linkstokeep;

% V->V links are now one-way only, and C->V links are removed where needed.
end % of function clipclassconnections

function p=FindProbTables(data, class)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% FindProbTables estimates the probabilities P(class=c|data=D)  
% 
% DESCRIPTION
%       Input a training group of data arranged with cases in rows and 
%       variables in columns, as well as the class value c for that vector. 
%       Each case represents a data vector V.  For each possible data value 
%       vi, and each variable Vi, it calculates P(C=c|Vi=vi) and stores 
%       that result in a 3-D table.  The table is arranged with the 
%       dimensions (class value, data value, variable number).
% 
% USAGE
%       probtable = FindProbTables(data, class)
% 
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows 
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
% 
% OUTPUTS
%
%       probtable: 3-D array whose (c,d,v) value is P(class=c|data=d) for
%           variable v.
% 
% CALLED FUNCTIONS
% 
%       None.
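%
% EXAMPLE
%       An illustrative sketch, not part of the original documentation; the
%       data are made up. One variable, four cases, two per class:
%
%           data  = [1; 1; 1; 2];
%           class = [1; 1; 2; 2];
%           p     = FindProbTables(data, class);
%
%       Tracing the code by hand, the counts are divided by the total
%       number of cases (4), so p(:,:,1) should be
%           [0.50 0.00; 0.25 0.25]
%       with rows indexed by class and columns by data value.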

%% Initialize
% Find the sizes of the inputs and the number of possible values
[cases numvars]=size(data);
datavals=max(size(unique(data)));
classvals=max(size(unique(class)));
% Build some placeholders and loop indices
p=zeros(classvals, datavals, numvars ); % triplet: (class, value, variable#) 
databins=1:datavals;
classbins=1:classvals;

%% Find Probabilities
% For each classification value, extract the data with that class
for c=classbins
    datainthatclass=data(class==c,:); % array of just cases with class=c
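    % Count how often each data value occurs in this class. Dividing by the
    % total number of cases (not the number in class c) makes each entry
    % the fraction of ALL cases that have class=c and that data value.
    p(c,:,:)=histc(datainthatclass,databins)/cases;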
end

end % of function FindProbTables

function adjacency = getarcs( mvc, vcthreshold, mvv, vvthreshold )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% GETARCS builds the adjacency matrix for a set of variables 
% 
% DESCRIPTION
%       By comparing the mutual information between two variables to
%       thresholds determined separately, this function declares there to
%       be an arc in a Bayesian network. Arcs are stored in an adjacency
%       matrix, described below.
% 
%       The primary tests are:
%       MI(Vi;C)  > vcthreshold : tests for links between Vi and the class
%       MI(Vi;Vj) > vvthreshold : tests the links between variables
% 
% USAGE
%       adjacency = getarcs( mvc, vcthreshold, mvv, vvthreshold )
% 
% INPUTS
%       mvc [mvv]: double vector [array] with mutual information between
%           variables and the class [variables and other variables]. The
%           (i,j) entries of mvv are MI(Vi;Vj).
%       vc/vvthreshold: scalar thresholds used to test for the existence
%           of links
% 
% OUTPUTS
%
%       adjacency: logical matrix whose entries "1" at (i,j) mean "an arc
%            exists from the Bayesian network node Vi to Vj." The class 
%            variable C is added at row (number of V's + 1). "0" values
%            mean no arc.
% 
% CALLED FUNCTIONS
% 
%       None.
% 
% For more information on the tests and the links, see my dissertation.
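%
% EXAMPLE
%       An illustrative sketch, not part of the original documentation; the
%       MI values and thresholds are made up. Three variables:
%
%           mvc = [0.30 0.05 0.20];
%           mvv = [0 0.40 0.10; 0.40 0 0.05; 0.10 0.05 0];
%           adjacency = getarcs(mvc, 0.1, mvv, 0.2);
%
%       Tracing the code by hand, adjacency should be
%           [0 1 0; 1 0 0; 0 0 0; 1 0 1]
%       i.e. a symmetric V1<->V2 link in the square part, and C->V1 and
%       C->V3 in the final (class) row.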


%% Initialize
numvars=max(size(mvc)); %the number of variables
classrow=numvars+1; %row to store links C->V
adjacency= false(classrow,numvars); %the blank adjacency matrix

%% Test for adjacency to class
adjacency ( classrow , : )= mvc > vcthreshold;

%% Test for links between variables
% This test results in a symmetric logical matrix since MI(X;Y) is
% symmetric. To create a directed graph, these arcs will need to be pruned.
adjacency ( 1:numvars, 1:numvars ) = mvv > vvthreshold;

end %of function getarcs

function [mi boundary binneddata] = opt2bin (rawdata, class, steps,...
    typesearch, minint, maxint)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% opt2bin finds the best single boundary for each variable to maximize MI 
% 
% DESCRIPTION
%       This function takes an array of continuous data, with cases in rows
%       and variables in columns, along with a vector "class" which holds
%       the known class of each of the cases, and returns an array
%       "binneddata" that holds the 2 bin discretized data.  The
%       discretization bin boundary is found by maximizing the mutual
%       information with the class; the resulting MI and boundary are also
%       returned. The starting boundaries for the search can be given in
%       the vectors min and max, or either one, or neither, in which case
%       the data values determine the search boundaries.
%
% USAGE
%       [mi boundary binneddata] = opt2bin(rawdata, class, steps,
%           typesearch [, minint, maxint])
% 
% INPUTS
%       rawdata: double array of continuous values, cases in rows and 
%           variables in columns. Distribution is unknown.
%       class: double column vector, values 1:c representing classification
%           of each case. 
%       steps: Number of steps to test at while finding maximum MI
%       typesearch =0: starting bndry based on data's actual max/min values
%                  =1: use the value passed in maxint as maximum (right)
%                  =-1: use the value passed in minint as minimum (left)
%                  =2: use the values passed via maxint and minint
%       minint, maxint: optional vectors whose values limit the range of
%           the search for each variable's boundary.
% 
% OUTPUTS
%
%       mi: row vector holding the maximum values of MI(C;Vi) found
%       boundary: The location used to bin the data to get max MI
%       binneddata: The resulting data binned into "1" (low) or "2" (hi)
% 
% CALLED FUNCTIONS
% 
%       mutualinfo: Finds the MI of one discretized column with a separate
%           vector (the class in this case). External; the in-code comment
%           attributes it to Peng's DLLs.

%% Initialize
[rows cols]=size(rawdata);
mi=zeros(1,cols);
boundary=zeros(1,cols);
binneddata=zeros(rows,cols);
currentmi=zeros(steps,cols);

% If not passed, find the leftmost and rightmost possible bin boundaries
% from the data. typesearch==0 always uses the data's own range.

if nargin~=6 || typesearch==0
    minint=min(rawdata,[],1);
    maxint=max(rawdata,[],1);
elseif typesearch==1
    minint=min(rawdata,[],1);
elseif typesearch==-1
    maxint=max(rawdata,[],1);
elseif typesearch==2
    disp('using passed values')
else
    disp('typesearch must = 0,1,-1,2')
    return
end

%% Find best boundary

for peak=1:cols %look at each variable separately

    % Create an array of possible bin boundary locations, min->max
    checkpoints=repmat(linspace(minint(peak),maxint(peak),steps),rows,1); 
    
    % discretize the variable's values at each of these possible
    % boundaries, putting 2's everywhere (value > boundary), 1 elsewhere 
    binarray=(repmat(rawdata(:,peak), 1, steps)>checkpoints)+1; 
    
    % find the MI(C,V) for each possible binning  

    [rbin cbin]=size(binarray);
    mivec=zeros(1,cbin);
    for v=1:cbin
        mivec(v)=mutualinfo(binarray(:,v),class); % Fast using Peng's DLLs
    end
    currentmi(1:steps,peak)=mivec;
    
    % Now pick out the highest MI, i.e. best bin boundary 
    [mi(peak) atstep]=max(currentmi(:,peak));
    boundary(peak)=checkpoints(1,atstep);
    
    % and record the binned data using that boundary.
    binneddata(:,peak)=binarray(:,atstep);
end

end % of function opt2bin

function classprobs = TestCases( p, prior, data)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% TestCases uses Bayes' rule to classify each case
% 
% DESCRIPTION
%       Tests each of a set of data vectors by looking up P(data|class) in
%       a probability table, then finding P(case|class) by multiplying each
%       of those values in a product.  Then uses Bayes' rule to calculate
%       P(class|data) for each possible value of class.  Reports this as an
%       array of class probabilities for each case.
% 
% USAGE
%       classprobs = TestCases( p, prior, data)
% 
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows 
%           and variables in columns.
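%       p: 3-D double array of probabilities (c,d,v).  The first dimension 
%           is the class, the second is the data value, the third is the 
%           variable number. The entry is P(var v=value d | class=value c).
%       prior: double column vector, one entry per class, holding the prior
%           probability of each class. It is multiplied elementwise with
%           the column vector of P(case|class) values.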
% 
% OUTPUTS
%
%       classprobs: 2-D double array whose value is P(class=c|data) for
%           each case. Cases are in rows, class in cols.
% 
% CALLED FUNCTIONS
% 
%       None.
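%
% EXAMPLE
%       An illustrative sketch, not part of the original documentation; the
%       probabilities are made up. Two classes, one variable with two
%       possible values, equal priors, and a single case with value 1:
%
%           p = zeros(2,2,1);
%           p(:,:,1) = [0.8 0.2; 0.3 0.7];   % rows: class, cols: value
%           prior = [0.5; 0.5];
%           classprobs = TestCases(p, prior, 1);
%
%       Tracing the code by hand: P(case|class) is [0.8; 0.3], so
%       classprobs should be [0.4 0.15]/0.55, about [0.7273 0.2727].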

%% Initialize

% Find the sizes of the inputs and the number of possible values
[cases numvars]=size(data);
classvals=size(p,1);
pvec=zeros(classvals,numvars);
classprobs = zeros(cases, classvals); % holds the classification results

%% Find the probabilities

% Create pvec, an array whose first row is P(V=v|c=1) for each V
for casenum=1:cases
    casedata=data(casenum,:); % The case to be checked
    for c=1:classvals
        for v=1:numvars
            pvec(c,v)=max(p(c,casedata(v),v),.01); % Don't want any zeros
        end
    end
    % Now find P(case|class) for each class by multiplying each individual
    % P(V|C) together, assuming they are independent.
    
    Pdc=prod(pvec,2);
    
    % Use Bayes' Rule
    
    classprobs(casenum,:) =(Pdc.*prior)/sum(Pdc.*prior); 

end

end % of function TestCases

function threshold = automi( data, class, repeats )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% automi finds a threshold for randomized MI(V; C) 
% 
% DESCRIPTION
%       Finds the threshold of a data set's mutual information with a class
%       vector, above which a variable's MI(class, variable) can be
%       expected to be significant. The threshold for MI (significance
%       level) is found by taking the data set and randomizing the class
%       vector, then calculating MI(C;V) for all the variables. This is
%       repeated a number of times. The resulting list of length (#repeats
%       * #variables) is sorted, and the 99th percentile MI is taken
%       as the threshold.
%
% USAGE
%       threshold = automi( data, class, repeats )
% 
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows 
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%       repeats: the number of times to repeat the randomization
% 
% OUTPUTS
%
%       threshold: the significance level for MI(C;V)
% 
% CALLED FUNCTIONS
% 
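%       mutualinfo(v,class): returns MI(Vi;Class) for one variable column.
%           External; the in-code comment attributes it to Peng's DLLs.
%       prctile: percentile of a vector (Statistics Toolbox).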

%% Initialize

% Find the size of the data (cases x variables) and check against class
[rows cols]=size(data);
cases=max(size(class));
if rows==cases
    clear cases
else
    disp('# of rows in the data and class must be equal.')
    return
end


%% Repeat a number of times

mifound=zeros(cols,repeats); % stores the results of the randomized MI
for i=1:repeats
    c=class(randperm(rows)); % creates a randomized class vector
    
    [rbin cbin]=size(data);
    mivec=zeros(1,cbin);
    for v=1:cbin
        mivec(v)=mutualinfo(data(:,v),c); % Fast using Peng's DLLs
    end
    mifound(:,i)=mivec;
end

% pull off the 99th percentile highest MI
mi_in_a_vector=reshape(mifound,repeats*cols,1); % prctile needs vector
threshold=prctile(mi_in_a_vector,99);

end % of function automi

