function [h] = MLE_SummaryPlot(results)
% -------------------------------------------------------------------------
%   h = MLE_SummaryPlot(results)
%
%   Plot a summary of the MLE results in terms of both predicted and
%   empirical cue weights (or PSEs when no weights are available) and the
%   sigmas obtained from the fits. One row of subplots is drawn per cue
%   combination; the last panel of each row shows the sigmas (JNDs).
%
%   Input parameters:
%   results = the results structure as obtained from either MLE_Analysis
%   (for results across participants), MLE_SingleSubject or MLE_SingleSet
%   (for single participant data). The fields read here are
%   results.cueCombos{..}, results.singleCues{..} and results.nrSub.
%
%   output:
%   h = the figure handle
%       (earlier revisions overwrote h with the handles of the last plotted
%       curves; the figure handle is now returned as documented)
%
%   Requires the helper functions shadedErrorBar_CI and errorbar_CI on the
%   MATLAB path.
%
%   by Loes van Dam, 2014
%   Loes.van_Dam@uni-bielefeld.de
%
% -------------------------------------------------------------------------

    colstring = 'brmgcyk';
    nrColors = length(colstring);
    % cycle through the colour list instead of indexing past its end
    colorOf = @(k) colstring(mod(k-1,nrColors)+1);

    nrCombos = length(results.cueCombos);
    nrSingle = length(results.singleCues);

    % keep the figure handle separate from the per-axes legend handles
    h = figure();

    for cueC = 1:nrCombos
        combo = results.cueCombos{cueC};
        cueID = combo.cueID;
        nrCueInCombo = length(cueID);
        noiselevels = combo.noiselevels;
        noiselevelIDs = combo.noiselevelIDs;
        nrNoises = size(noiselevels,1);

        % the jitter of points along the x-axis to more clearly separate
        % conditions; a single cue gets no offset (avoids 0/0 = NaN)
        jitterAmp = 1/20;
        if nrCueInCombo > 1
            cueoffsetlst = (2*(1:nrCueInCombo) - (nrCueInCombo+1))/(nrCueInCombo-1);
        else
            cueoffsetlst = zeros(1,nrCueInCombo);
        end

        if nrNoises > 1
            % Columns (cues) whose noise level changes across rows. The
            % explicit dimension arguments matter: with exactly two noise
            % levels diff() returns a row vector and a plain sum() would
            % collapse it to a scalar, always selecting column 1.
            variesPerCue = sum(abs(diff(noiselevels,1,1)),1) ~= 0;
            noises = noiselevels(:,variesPerCue);
            if size(noises,2) ~= 1
                % noise differs in more than one cue (or in none): number
                % the conditions, otherwise plots will be a mess
                noises = (1:nrNoises)';
            end
            jitterAmp = min(abs(diff(noises)))*jitterAmp;
        else
            noises = 1;
        end
        cueoffsetlst = cueoffsetlst*jitterAmp;
        plotrange = [min(noises)-4*jitterAmp,max(noises)+4*jitterAmp];

        if isnan(combo.weights(1,1))
            % ---------------------------------------------------------------
            % no weights available: plot PSEs, one panel per conflict
            % ---------------------------------------------------------------
            conflicts = combo.conflicts;
            % take sizes from the array itself; squeeze() would swap the
            % roles of rows/columns when there is only a single conflict
            nrConflicts = size(conflicts,2);
            nrConflictCues = size(conflicts,3);
            nrcols = nrConflicts+1;

            for CF = 1:nrConflicts
                subplot(nrCombos,nrcols,(cueC-1)*nrcols+CF); hold on;
                hLeg = [];      % legend handles belong to this axes only
                legstr = {};
                for j = 1:nrConflictCues
                    levels = conflicts(:,CF,j);
                    if nrNoises > 1
                        hLeg(j) = plot( noises,levels, ['--',colorOf(j)]);
                    else
                        hLeg(j) = plot( (1:2)-0.5,levels*ones(1,2), ['--',colorOf(j)]);
                    end
                    legstr{j} = ['C: ', num2str(cueID(j))];
                end

                dataColor = colorOf(nrConflictCues+1);
                if nrNoises > 1
                    predictions = squeeze(combo.prediction(:,CF,:));
                    predictionsCI = squeeze(combo.prediction_CI(:,CF,:,:));
                    shadedErrorBar_CI(noises,predictions(:,1),predictionsCI(:,1,:),['--',dataColor]);
                else
                    % Index the raw arrays: squeeze() would collapse the
                    % singleton noise dimension and shift all indices.
                    % NOTE(review): assumes prediction_CI is laid out as
                    % (noise, conflict, cue, bound) - confirm against the
                    % analysis code.
                    predSingle = combo.prediction(1,CF,1);
                    predSingleCI = squeeze(combo.prediction_CI(1,CF,1,:));
                    shadedErrorBar_CI((1:2)-0.5,predSingle*ones(1,2),predSingleCI*ones(1,2),['--',dataColor]);
                end
                hLeg(nrConflictCues+1) = errorbar_CI(noises,...
                        combo.params(:,CF,1),...
                        combo.params_CI(:,CF,1,:),['.-',dataColor]);
                % label the measured curve so handles and labels match
                legstr{nrConflictCues+1} = 'Combined';

                if CF == 1
                    legend(hLeg,legstr);
                end
                title('summary PSE');
                xlabel('noise condition')
                ylabel('PSE')
            end
        else
            % ---------------------------------------------------------------
            % plot weights
            % ---------------------------------------------------------------
            if nrCueInCombo > 2
                nrWeights = nrCueInCombo;
                transparancy = 1;
                cueoffsetlst2 = cueoffsetlst;
            else
                nrWeights = 1;    % if 2 cue plot only the weight w1 since the other is 1 - w1
                transparancy = 0;
                cueoffsetlst2 = zeros(1,nrCueInCombo);
            end
            nrcols = 2;
            subplot(nrCombos,nrcols,(cueC-1)*nrcols+1); hold on;
            hLeg = [];
            legstr = {};
            for cue = 1:nrWeights
                lineColor = colorOf(cue);
                weights = combo.weights(:,cue);
                weightsCI = combo.weights_CI(:,cue,:);
                predictions = combo.predicted_weights(:,cue);
                predictionsCI = squeeze(combo.predicted_weights_CI(:,cue,:));
                if nrNoises > 1
                    shadedErrorBar_CI(noises+cueoffsetlst2(cue),predictions(:,1),predictionsCI,['--',lineColor],transparancy);
                else
                    shadedErrorBar_CI((1:2)'-0.5+cueoffsetlst2(cue),(predictions(1,1)*ones(1,2))',(predictionsCI*ones(1,2))',['--',lineColor],transparancy);
                end
                hLeg(cue) = errorbar_CI(noises+cueoffsetlst2(cue),...
                            weights,...
                            weightsCI,['.-',lineColor]);
                legstr{cue} = ['wC: ', num2str(cueID(cue))];
            end
            % reference lines at weight 0 and weight 1
            plot(plotrange,[0,0],'k--');
            plot(plotrange,[1,1],'k--');
            legend(hLeg,legstr);
            title('summary weights');
            xlabel('noise condition');
            ylabel('weight');
            xlim(plotrange);
            ylim([min([min(combo.predicted_weights_CI(:)),-0.2]),...
                  max([max(combo.predicted_weights_CI(:)),1.2])]);
            text(min(noises),0.5,sprintf('N = %d/%d',results.nrSub));
        end

        % -------------------------------------------------------------------
        % sigma (JND) panel: last column of this combo's row
        % -------------------------------------------------------------------
        subplot(nrCombos,nrcols,cueC*nrcols); hold on;

        tempJND = combo.sigma;
        tempJND_CI = combo.sigma_CI;

        % same for predictions
        predJND = combo.predicted_sigma;
        predJND_CI = combo.predicted_sigma_CI;

        if nrNoises > 1
            hLeg = [];
            legstr = {};

            shadedErrorBar_CI(noises,predJND,predJND_CI,['--',colorOf(nrCueInCombo+1)]);

            % NOTE(review): the single cues are indexed 1..numel(singleCues)
            % but cueoffsetlst, cueID and noiselevelIDs are sized by the
            % cues in this combo - verify that these always correspond.
            for cue = 1:nrSingle
                single = results.singleCues{cue};
                levels = single.params(:,2);

                if length(levels) == 1
                    % one sigma for this cue: repeat it for every noise level
                    hLeg(cue)= errorbar_CI(noises+cueoffsetlst(cue),...
                                levels.*ones(nrNoises,1),...
                                (squeeze(single.params_CI(:,2,:))*ones(1,nrNoises))',...
                                ['.-',colorOf(cue)]);
                else
                    % pick the sigma belonging to each noise level
                    hLeg(cue)= errorbar_CI(noises+cueoffsetlst(cue),...
                                levels(noiselevelIDs(:,cue)),...
                                squeeze(single.params_CI(noiselevelIDs(:,cue),2,:)),...
                                ['.-',colorOf(cue)]);
                end

                legstr{cue} = ['C: ', num2str(cueID(cue))];
            end

            hLeg(nrSingle+1)= errorbar_CI(noises,...
                        tempJND,...
                        tempJND_CI,['.-',colorOf(nrSingle+1)]);
            legstr{nrSingle+1} = 'Combined';

            legend(hLeg,legstr);
            xlabel('noise condition');
            xlim(plotrange);
        else
            % single noise level: bar chart with one bar per single cue
            % plus one for the combined condition
            shadedErrorBar_CI(nrSingle+1+[-0.5;0.5],(predJND*[1,1])',(predJND_CI'*ones(1,2))','r-');

            % columns: value, lower CI, upper CI (fresh array so no rows
            % from an earlier panel or iteration leak in)
            levels = zeros(nrSingle+1,3);
            for cue = 1:nrSingle
                levels(cue,1) = results.singleCues{cue}.params(1,2);
                levels(cue,2:3) = results.singleCues{cue}.params_CI(1,2,:);
            end
            levels(nrSingle+1,1:3) = [tempJND,tempJND_CI];
            bar(levels(:,1))
            errorbar_CI(   1:(nrSingle+1),...
                        levels(:,1),...
                        levels(:,2:3),...
                        '.');
            plot(nrSingle+1+[-0.5;0.5],predJND*[1;1],'r-');
            errorbar_CI(nrSingle+1+0.5,predJND,predJND_CI,'r.');

            TickLabels = cell(1,nrSingle+1);
            for cue = 1:nrSingle
                TickLabels{cue} = num2str(cue);
            end
            TickLabels{nrSingle+1} = 'Comb';

            set(gca,'XTick',1:(nrSingle+1));
            set(gca,'XTickLabel',TickLabels);
        end
        title('summary jnd');
        ylabel('sigma')

    end
end