function [results, resultsSub] = MLE_Analysis(A, nrStim, fitlapses, force_standards, ploton)

% -------------------------------------------------------------------------
%   [results, resultsSub] = MLE_Analysis(A, nrStim, fitlapses, force_standards, ploton)
%
%   Perform the MLE analysis by separating single-cue data from multi-cue
%   data and then fitting a cumulative Gaussian (see FitCumulativeGauss) to
%   each part.
%
%   Input parameters:
%   A = The general data matrix. Rows denote data for different comparison
%       stimuli for an individual subject.
%       Each row is of the following format:
%       [C-S, P(C>S), N, S, [cues on/off], [Conf added], [noise level], subject]
%       where  C: comparison
%              S: standard
%              P(): proportion
%              N: number of data points (repetitions)
%              Conf: inter-sensory conflict.
%       Participants must be numbered consecutively starting at 1.
%
%       For instance, for a two-cue integration data set, a row like this:
%       [  -50, 0.2, 10, 0, 1, 1, -10, 10, 0, 2, 1]
%
%       indicates a difference standard/comparison of -50 (1st col),
%       20% 'yes' responses (2nd col) for this comparison
%       in ten repetitions (3rd col),
%       a standard value of 0 (4th col),
%       a bimodal presentation (5th and 6th col are both set to 1),
%       an intersensory conflict of 20: 10 subtracted from cue 1 (7th col)
%       and 10 added to cue 2 (8th col),
%       no noise on cue 1 (9th col) and
%       the 2nd noise level on cue 2 (10th col),
%       for the first participant (11th col).
%
%       A row like this:
%       [  20, 0.9, 10, 10, 0, 1, 0, 0, 0, 0, 10]
%
%       indicates a difference standard/comparison of 20 (1st col),
%       90% 'yes' responses (2nd col) for this comparison
%       in ten repetitions (3rd col),
%       a standard value of 10 (4th col),
%       a unimodal presentation of cue 2 (5th col is 0, 6th col is 1),
%       with no conflict and no noise added,
%       for the 10th participant included in the dataset (11th col).
%
%       Note: the analysis is not restricted to designs with only two
%       cues; it also works for multiple cue combinations within one
%       dataset. The number of columns then increases. For example, when
%       3 cues are involved, [cues on/off], [Conf added] and [noise level]
%       each cover three columns instead of two.
%
%   nrStim = the number of stimuli presented in one trial. For instance, if
%       a 2-interval forced choice (2IFC) task was used, the JND represents
%       sqrt(2) times the sigma of a single stimulus. To take this
%       correction into account, the system needs to know the number of
%       stimuli used in one trial: nrStim = 1 for tasks that use only one
%       stimulus (e.g. for slant this could be which side of the single
%       stimulus is more in front); nrStim = 2 for tasks like the 2IFC in
%       which a comparison stimulus is compared to a standard (which one is
%       more slanted). Since most MLE studies involve a 2AFC/2IFC type of
%       task, the default is nrStim = 2 (thus taking sigma = JND/sqrt(2)).
%       For a reference see Rohde, van Dam & Ernst (hopefully 2015).
%
%   fitlapses = whether or not to fit an additional parameter, the
%       lapse rate, to each curve. 1 = yes; 0 = no; default is 0.
%
%   force_standards = whether or not to force combining the standards in
%       the analysis. 1 = yes; 0 = no (default). If data for different
%       standards are not forcefully combined, the single-cue analysis is
%       performed for each standard separately and paired t-tests
%       (Bonferroni corrected, on both mu and sigma) check for differences
%       between standards. Only if no such differences are found are the
%       standards combined; otherwise an error is raised.
%       force_standards = 1 skips this test and thus forces the
%       combination of the standards without testing (neither on the
%       participant level nor across participants).
%
%   ploton = a 3-element vector (0 = off; 1 = on). Respectively: plot
%       single-subject curves, plot single-subject summary plot, plot
%       multi-subject summary. Default is [0,0,1]; missing trailing
%       elements are filled with 1.
%
%
%   The output "results" is a structure with the following fields:
%   nrSub = (only set here when A contains more than one participant)
%       [nr of participants included in the averages, total nr of
%       participants]. Participants whose JND exceeds the measurement
%       range are excluded from the averages.
%   singleCues = a cell array in which each element refers to one of the
%       cues used. For the structure of these elements see below.
%   cueCombos = a cell array in which each element refers to one of the
%       cue combination conditions. See below.
%   For a single participant, "results" is the result of MLE_SingleSubject.
%
%   The output "resultsSub" is a cell array holding, for each participant,
%   the result structure returned by MLE_SingleSubject.
%
%
%   Structure subfields for singleCues:
%   params:  stores the fitted mu, sigma and lapse rate. If there are
%       multiple levels of noise in the design each row corresponds to a
%       different noise level.
%   params_STD: standard deviations (across participants) of the fitted
%       parameters. Each row corresponds to a different noise level for
%       this cue.
%   params_CI: 95% confidence intervals for each parameter.
%   noiselevels: The noise levels for this cue.
%
%   Structure subfields for cueCombos:
%   params: stores the fitted mu, sigma and lapse rate. For cueCombos params
%       is a 3D array with the first dimension representing the different
%       noise levels involved, the second dimension the levels of conflict
%       and the third the different parameters.
%   params_STD: The standard deviation of the fitted parameters. The 3D
%       structure is the same as for params.
%   params_CI: 95% confidence intervals for each parameter.
%   prediction: The predicted mu and sigmas. The 3D array has the same
%       format as params and params_STD.
%   prediction_STD: standard deviations for the predictions.
%   prediction_CI: 95% confidence intervals for the predictions.
%   noiselevels: the noise levels involved. This is a 2D array with
%       dimensions noiselevels, cues.
%   noiselevelIDs: the indices for the noise levels involved. This is a 2D
%       array with dimensions noiselevels, cues. This array is useful for
%       finding the corresponding unisensory cues more easily.
%   conflicts: the partial conflicts added to the standard for each cue.
%       This is a 3D array with dimensions noiselevels, conflicts, cues.
%   cueID: the cue numbers included in this cue combination. This is of
%       particular use when more than 2 cues are included in the design.
%   weights: the empirical weights for each cue in the combined conditions
%   weights_STD: standard deviations of the empirical weights
%   weights_CI: confidence intervals of the empirical weights
%   predicted_weights: the predicted weights for each cue in the combined
%       conditions
%   predicted_weights_STD: standard deviations of the predicted weights
%   predicted_weights_CI: confidence intervals of the predicted weights
%   sigma: the empirical sigma for the combined conditions across conflicts
%   sigma_STD: standard deviations of the empirical sigma across conflicts
%   sigma_CI: confidence intervals of the empirical sigma across conflicts
%   predicted_sigma: the predicted sigma for the combined conditions
%   predicted_sigma_STD: standard deviations of the predicted sigma
%   predicted_sigma_CI: confidence intervals of the predicted sigma
%
%
%   Note:
%   To measure optimal integration there are two things of importance:
%   1: The PSE (Point of Subjective Equality) for the combined percept
%   should be a weighted average of the single cues alone. The weights in
%   this case should correspond to the relative reliability.
%   2: The JND (Just Noticeable Difference) should be better than either
%   cue alone according to
%   sigma_c^2 = sigma_1^2 * sigma_2^2 / ( sigma_1^2 + sigma_2^2 )
%   or, since sigma relates to the JND by a constant, also
%   JNDc^2 = JND1^2 * JND2^2 / ( JND1^2 + JND2^2 ).
%
%
%   by Loes van Dam, 2014
%   Loes.van_Dam@uni-bielefeld.de
%
% -------------------------------------------------------------------------

sizeA = size(A);
fprintf(...
    ['Will now infer number of cues from the size of A\n'...
     '...\n']);
% Each row has 4 leading columns, three blocks of nrcues columns
% ([cues on/off], [Conf added], [noise level]) and one subject column.
nrcues = (sizeA(2)-5)/3;
if nrcues >= 1 && nrcues == round(nrcues),
    % success
    fprintf('Seems like nr of cues = %d\n', nrcues);
    fprintf('Will continue with this value\n\n');
else
    % fail
    fprintf(...
        ['The matrix A does not have the appropriate format.\n',...
         'Did you include all the necessary variables?\n'...
         'Type "help MLE_Analysis" for more information\n']);
    % single-argument error() does not expand escape sequences, so no '\n'
    error([ 'Error: ',...
        'The matrix A does not seem to have the appropriate format.']);
end

if nargin < 2 || isempty(nrStim),
    nrStim = 2;
end
if nrStim ~= 1 && nrStim ~= 2,
    error('nrStim does not seem to have an appriate value. Should be either 1 or 2.');
end

if nargin < 3 || isempty(fitlapses),
    fitlapses = 0;
end

if nargin < 4 || isempty(force_standards),
    force_standards = 0;
end

if nargin < 5 || isempty(ploton),
    ploton = [0,0,1];
end
if length(ploton)<2,
    ploton = [ploton,1];
end
if length(ploton)<3,
    ploton = [ploton,1];
end

% -------------------------------------------------------------------------
%   Set how to compute confidence intervals:
%   using norminv (assuming normal distribution): CI_NormT = 0
%   using tinv (assuming student t distribution): CI_NormT = 1
% -------------------------------------------------------------------------

CI_NormT = 0;

% -------------------------------------------------------------------------
%   Extract nr of participants, conditions etc from A
% -------------------------------------------------------------------------

nrSub = max(A(:,sizeA(2)));
fprintf('Nr of Participants: %d\n\n', nrSub);
% [included in averages, total]; the first entry is decremented below for
% every participant that gets excluded.
results.nrSub = [nrSub,nrSub];

uniquesubs = length(unique(A(:,sizeA(2))));

if uniquesubs ~= nrSub,
    error(['Participants indexes are higher than the number of unique participants. ',...
        'Please fix the format of input matrix A such that participants are numbered consecutively.']);
end

standard_conditions = unique(A(:,4));  % the distinct standard values
nrStandards = length(standard_conditions);

assumption_test = 1;

if nrSub > 1,
    % ---------------------------------------------------------------------
    %   Fit the single cues for each participant and each standard
    %   separately, to test whether the standards may be pooled
    % ---------------------------------------------------------------------

    if force_standards == 0 && nrStandards > 1,
        for sub = 1:nrSub,
            dataSub = A(A(:,sizeA(2)) == sub,1:end-1);

            % separate for the different standards
            for st = 1:nrStandards,
                dataSubB = dataSub(dataSub(:,4)==standard_conditions(st),:);
                % skip the check for combination of standards within participants
                % also do not plot in this case
                tempSub = MLE_SingleSubject(dataSubB,nrStim,fitlapses,1,[0,0],0);

                % only the single-cue mu and sigma are needed here;
                % musig dims: participant x standard x noise level x [mu, sigma]
                for cue = 1:nrcues,
                    singleCue{cue}.musig(sub,st,:,:) = tempSub.singleCues{cue}.params(:,1:2);
                end
            end
        end

        % -----------------------------------------------------------------
        %   Check if standards are the same across participants
        % -----------------------------------------------------------------
        SCnoiselevels = 0;
        for cue = 1:nrcues,
            SCnoiselevels = SCnoiselevels + size(singleCue{cue}.musig,3);
        end

        % Bonferroni correction: one test per pair of standards, per
        % parameter (mu and sigma), per single-cue noise level.
        alpha = 0.05;
        nrtests = sum(1:(nrStandards-1)) * 2 * SCnoiselevels;
        alpha = alpha/nrtests;

        for cue = 1:nrcues,
            for j = 1:nrStandards,
                for k = (j+1):nrStandards,
                    for m = 1:size(singleCue{cue}.musig,3),
                        for p = 1:2,  % 1 = mu, 2 = sigma
                            % paired t-test across participants
                            temp1 = singleCue{cue}.musig(:,j,m,p);
                            temp2 = singleCue{cue}.musig(:,k,m,p);

                            h = ttest(temp1,temp2,alpha);

                            if h == 1, % the standards are different and thus our assumption is invalid
                                assumption_test = 0;
                            end
                        end
                    end
                end
            end
        end

    end
    % ---------------------------------------------------------------------
    %   Do MLE analysis with standards pooled for each participant
    % ---------------------------------------------------------------------
    if assumption_test == 1,
        subsub = 0;  % number of participants included in the averages
        for sub = 1:nrSub,
            dataSub = A(A(:,sizeA(2)) == sub,1:end-1);
            [resultsSub{sub},flagLargeJND] = MLE_SingleSubject(dataSub,nrStim,fitlapses,1,ploton(1:2),0);

            if sub == 1,
                nrCombos = length(resultsSub{sub}.cueCombos);
            end

            if flagLargeJND == 1,
                fprintf(['WARNING: for subject %d the JND is larger than the measurement range used in one or more conditions.\n'...
                    'This participant will be excluded from the average results\n'],sub);
                results.nrSub(1) = results.nrSub(1)-1;
            end

            if flagLargeJND == 0;
                subsub = subsub+1;
                for j = 1:nrcues,
                    singleCues{j}(subsub,:,:) = resultsSub{sub}.singleCues{j}.params;
                    if subsub == 1,
                        results.singleCues{j}.noiselevels = resultsSub{sub}.singleCues{j}.noiselevels;
                    end
                end

                for j = 1:nrCombos,
                    cueCombos{j}(subsub,:,:,:) = resultsSub{sub}.cueCombos{j}.params;
                    cuePredic{j}(subsub,:,:,:) = resultsSub{sub}.cueCombos{j}.prediction;
                    weightPredic{j}(subsub,:,:) = resultsSub{sub}.cueCombos{j}.predicted_weights;
                    cueWeights{j}(subsub,:,:) = resultsSub{sub}.cueCombos{j}.weights;

                    cueSigma{j}(subsub,:) = resultsSub{sub}.cueCombos{j}.sigma;
                    cueSigmaPred{j}(subsub,:) = resultsSub{sub}.cueCombos{j}.predicted_sigma;

                    if subsub == 1,
                        results.cueCombos{j}.noiselevels = resultsSub{sub}.cueCombos{j}.noiselevels;
                        results.cueCombos{j}.noiselevelIDs = resultsSub{sub}.cueCombos{j}.noiselevelIDs;
                        results.cueCombos{j}.conflicts = resultsSub{sub}.cueCombos{j}.conflicts;
                        results.cueCombos{j}.cueID = resultsSub{sub}.cueCombos{j}.cueID;
                    end
                end
            end
        end

        if subsub == 0,
            error(['All participants were excluded because their JND exceeded the ',...
                'measurement range. No across-participant results can be computed.']);
        end

        % -----------------------------------------------------------------
        %   Summarize the results across participants
        % -----------------------------------------------------------------

        % Critical value for the 95% confidence intervals.
        if CI_NormT == 0,
            crit = norminv(0.975);
        else
            crit = tinv(0.975,subsub-1);
        end
        % One row per included participant (reshape instead of squeeze so a
        % single included participant still gives a 1 x K row).
        asRows = @(X) reshape(X, subsub, []);
        % K x 2 matrix of [lower, upper] bounds for means mu and SDs sd.
        ciOf = @(mu, sd) mu(:)*ones(1,2) + crit*sd(:)*[-1,1]/sqrt(subsub);

        for j = 1:nrcues,
            for NL = 1:length(results.singleCues{j}.noiselevels),
                X = asRows(singleCues{j}(:,NL,:));
                mu = mean(X,1);
                sd = std(X,0,1);
                results.singleCues{j}.params(NL,:) = mu;
                results.singleCues{j}.params_STD(NL,:) = sd;
                results.singleCues{j}.params_CI(NL,:,:) = ciOf(mu,sd);
            end
        end

        for j = 1:nrCombos,
            nNL = size(results.cueCombos{j}.noiselevels,1);
            % conflicts is noiselevels x conflicts x cues; take dim 2
            % directly (squeeze would mis-shape a single conflict level).
            nCF = size(results.cueCombos{j}.conflicts,2);
            for NL = 1:nNL,
                for CF = 1:nCF,
                    X = asRows(cueCombos{j}(:,NL,CF,:));
                    mu = mean(X,1);
                    sd = std(X,0,1);
                    results.cueCombos{j}.params(NL,CF,:) = mu;
                    results.cueCombos{j}.params_STD(NL,CF,:) = sd;
                    results.cueCombos{j}.params_CI(NL,CF,:,:) = ciOf(mu,sd);

                    X = asRows(cuePredic{j}(:,NL,CF,:));
                    mu = mean(X,1);
                    sd = std(X,0,1);
                    results.cueCombos{j}.prediction(NL,CF,:) = mu;
                    results.cueCombos{j}.prediction_STD(NL,CF,:) = sd;
                    results.cueCombos{j}.prediction_CI(NL,CF,:,:) = ciOf(mu,sd);
                end

                X = asRows(weightPredic{j}(:,NL,:));
                mu = mean(X,1);
                sd = std(X,0,1);
                results.cueCombos{j}.predicted_weights(NL,:) = mu;
                results.cueCombos{j}.predicted_weights_STD(NL,:) = sd;
                results.cueCombos{j}.predicted_weights_CI(NL,:,:) = ciOf(mu,sd);

                X = asRows(cueWeights{j}(:,NL,:));
                mu = mean(X,1);
                sd = std(X,0,1);
                results.cueCombos{j}.weights(NL,:) = mu;
                results.cueCombos{j}.weights_STD(NL,:) = sd;
                results.cueCombos{j}.weights_CI(NL,:,:) = ciOf(mu,sd);

                % ---------------------------------------------------------
                % mostly for plotting: average sigma across conflicts but
                % within the same noise level. The JND should be the same
                % for all conflicts and this makes plotting easier later.
                % ---------------------------------------------------------
                X = asRows(cueSigma{j}(:,NL));
                mu = mean(X,1);
                sd = std(X,0,1);
                results.cueCombos{j}.sigma(NL,1) = mu;
                results.cueCombos{j}.sigma_STD(NL,1) = sd;
                results.cueCombos{j}.sigma_CI(NL,:) = ciOf(mu,sd);

                % same for the predictions (already equal across conflicts)
                X = asRows(cueSigmaPred{j}(:,NL));
                mu = mean(X,1);
                sd = std(X,0,1);
                results.cueCombos{j}.predicted_sigma(NL,1) = mu;
                results.cueCombos{j}.predicted_sigma_STD(NL,1) = sd;
                results.cueCombos{j}.predicted_sigma_CI(NL,:) = ciOf(mu,sd);
            end
        end

    else
        error(['Results for combined standards violates same parameters assumption. ',...
        'Change the analysis to compute the results for each standard separately. ',...
        'For instance, treat the standards as additional noise levels.']);
    end
else
    % Single participant: the fourth argument of MLE_SingleSubject controls
    % whether the standards are combined without testing, so honour the
    % caller's force_standards here.
    dataSub = A(A(:,sizeA(2)) == 1,1:end-1);
    [resultsSub{1},flagLargeJND] = MLE_SingleSubject(dataSub,nrStim,fitlapses,force_standards,ploton(1:2),0);
    if flagLargeJND,
        fprintf(['WARNING: for subject %d the JND is larger than the measurement range used in one or more conditions.\n'...
            'This participant should be excluded from the average results\n'],1);
    end
    results = resultsSub{1};

end

% -------------------------------------------------------------------------
%   if ploton(3) == 1, make summary plot across participants
% -------------------------------------------------------------------------
if ploton(3) == 1,
    MLE_SummaryPlot(results);
end