Code covered by the BSD License  

Highlights from
Fuzzy ARTMAP

image thumbnail
from Fuzzy ARTMAP by Massimiliano Versace
This package contains an implementation of Fuzzy ARTMAP.

DEFARTMAP_biasedLearn(artmap, inputVectors, classes)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% This is an implementation of biased ARTMAP,
% as described in CAS/CNS Technical Report TR-2009-xxx
% Biased ART: A neural architecture that shifts attention towards previously disregarded features
% following an incorrect prediction.
% Boston, MA: Boston University.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Programmed by Sai Chaitanya Gaddam (May 2007-09)
%%%%%%%%%%%%%%%%%%%%%%%%%
%
% The notation follows the article's notation,
% as do the headers for each step of the
% algorithm.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


function updated_artmap = DEFARTMAP_biasedLearn(artmap, inputVectors, classes)
% DEFARTMAP_biasedLearn  Train a biased ARTMAP network on labelled samples.
%
%   updated_artmap = DEFARTMAP_biasedLearn(artmap, inputVectors, classes)
%   presents every column of inputVectors, together with its class label,
%   to the network and returns the modified network structure.
%
%   Inputs
%     artmap       - network structure. Fields read here: M, C, w, W,
%                    numClasses, rho_bar, alpha, gamma, epsilon,
%                    lambda_Attention, start_conservative, model_num and
%                    deplete_fast. The bookkeeping fields win_sequence,
%                    node_created, NodeList, NI_fail, e_realLow, e_zero
%                    and uncommitted_fail are appended to or incremented
%                    here and are assumed to exist already.
%     inputVectors - (2*artmap.M)-by-n matrix, one sample per column
%                    (presumably complement coded -- confirm with caller).
%     classes      - vector with at least n entries; classes(i) is used as
%                    a column index into artmap.W.
%
%   Output
%     updated_artmap - the input structure with updated weights (w, W, C),
%                      the per-sample vectors e_mod_dist and A_dep_dist,
%                      compare_rho (allocated as zeros(n,3)), the total
%                      search_cycles of this call, and the bookkeeping
%                      fields listed above.
%
%   Errors
%     Raised when the number of rows of inputVectors differs from
%     2*artmap.M, or when fewer than n class labels are supplied.
%
%   Helpers called (defined outside this file): DEFARTMAP_choice_function,
%   DEFARTMAP_Update_Weights, ART_depleter_new.

ON=1; OFF=0;

% Set to ON to record a per-sample trace in artmap.learn_trail.
LEARN_TRAIL=OFF;

% Make sure that the data is appropriate for the given network.
[numFeatures, n] = size(inputVectors);
if (numFeatures ~= 2*artmap.M)
    error('The input vector does not contain the same number of features as the network.');
end
% Fail before any training work instead of part-way through the data set.
if (numel(classes) < n)
    error('The number of class labels is smaller than the number of input vectors.');
end

if (LEARN_TRAIL==ON)
    [artmap.learn_trail(1:n).A]=deal([]);
    [artmap.learn_trail(1:n).K]=deal([]);
    [artmap.learn_trail(1:n).winning_node]=deal([]);
    [artmap.learn_trail(1:n).weight_before]=deal([]);
    [artmap.learn_trail(1:n).weight_after]=deal([]);
    [artmap.learn_trail(1:n).MT_nodes]=deal([]);
    [artmap.learn_trail(1:n).MT_mismatch]=deal([]);
    [artmap.learn_trail(1:n).MT_x_before]=deal([]);
    [artmap.learn_trail(1:n).MT_e_before]=deal([]);
    [artmap.learn_trail(1:n).MT_e_after]=deal([]);
    [artmap.learn_trail(1:n).MT_Adash_before]=deal([]);
    [artmap.learn_trail(1:n).MT_Adash_after]=deal([]);
end

% Allocated for the caller; nothing in this function writes into it.
artmap.compare_rho=zeros(n,3);

% The bias vector e starts at zero before every input presentation.
e_AttentionDefault=zeros(2*artmap.M,1);

search_cycles = 0;
artmap.e_mod_dist=zeros(n,1);
artmap.A_dep_dist=zeros(n,1);

% Classify and learn on each sample.
for sampleNumber = 1:n

    % Progress report every 400 samples.
    if (mod(sampleNumber,400)==0)
        disp('Sample Number:')
        disp(sampleNumber)
    end

    % Reset the activated node (NOTE: this is Moshe Bar addition)
    activated_node=[];

    % Get the next input vector.
    % B.3 / B.6
    A = inputVectors(:, sampleNumber);

    % Counters used only for indexing the learn-trail cell arrays.
    MT_nodes=0;
    MT_ALL_nodes=0;

    if LEARN_TRAIL==ON
        artmap.learn_trail(sampleNumber).A=A;
    end

    % Get the next class
    % B.3 / B.6
    K = classes(sampleNumber);
    if LEARN_TRAIL==ON
        artmap.learn_trail(sampleNumber).K=K;
    end

    % B.5 -- vigilance restarts at its baseline for every sample.
    artmap.rho = artmap.rho_bar;
    artmap.rho_dep = artmap.rho_bar;

    % Optionally shrink alpha as the number of committed nodes grows.
    if  (artmap.start_conservative==1)
        artmap.alpha=10/(artmap.C+10);
    end

    if (artmap.C==0)
        % Empty network: the first sample becomes the first node.
        artmap.C=1;
        artmap.w(:,1) = A;
        artmap.W = zeros(1,artmap.numClasses);
        artmap.W(1,K)=1;

        if LEARN_TRAIL==ON
            artmap.learn_trail(sampleNumber).winning_node=artmap.C;
            artmap.learn_trail(sampleNumber).weight_before = A';
            artmap.learn_trail(sampleNumber).weight_after = A';
        end

        activated_node=1; % The first and only node is activated.
    else
        % Candidate nodes in order of decreasing choice value.
        T = DEFARTMAP_choice_function(artmap,A);
        [T_sorted, T_j] = sort(-T);
        % NOTE(review): the 3.5 factor is not explained in this file; the
        % next-input test further down thresholds at alpha*M without it.
        T_j = T_j(logical(-T_sorted>artmap.alpha.*artmap.M*(3.5)));
        MATCH=0;

        e_AttentionBefore=e_AttentionDefault;
        A_dashBefore=A;

        % "for" iterates over columns, so force a row: with a column
        % vector the body would run once with j holding every candidate.
        for j=T_j(:).'

            % Input and weights after subtracting the current bias.
            bias = artmap.lambda_Attention*e_AttentionBefore;
            A_biased = max(A-bias,0);

            x_sum=sum(min(max(artmap.w(:,j)-bias,0),A_biased));
            z_sum=sum(min(max(1-bias,0),A_biased));

            match_ratio = (x_sum+artmap.gamma)/(sum(A_biased)+eps+artmap.gamma);
            MATCH_TRACK_TEST=(match_ratio >= artmap.rho);

            if LEARN_TRAIL==ON
                MT_ALL_nodes=MT_ALL_nodes+1;
                artmap.learn_trail(sampleNumber).rho_calc{MT_ALL_nodes}=[(x_sum+artmap.gamma)./(sum(A_dashBefore)+eps+artmap.gamma) ...
                    x_sum./artmap.M ];
                artmap.learn_trail(sampleNumber).rho{MT_ALL_nodes} = [artmap.rho_dep artmap.rho];
            end

            % Nodes that fail the vigilance test are simply skipped.
            if MATCH_TRACK_TEST
                search_cycles = search_cycles+1;
                J=find(artmap.W(j,:)); % class predicted by node j
                if J==K
                    % Correct prediction: learn on node j and stop searching.
                    MATCH=1;
                    artmap.win_sequence(length(artmap.win_sequence)+1)=j;
                    if LEARN_TRAIL==ON
                        artmap.learn_trail(sampleNumber).winning_node=j;
                        artmap.learn_trail(sampleNumber).weight_before = artmap.w(:,j);
                    end

                    artmap.w=DEFARTMAP_Update_Weights(artmap,j,A);
                    artmap.e_mod_dist(sampleNumber)=artmap.lambda_Attention*sum(e_AttentionBefore)/2;
                    artmap.A_dep_dist(sampleNumber)=artmap.lambda_Attention*sum(A_dashBefore);

                    if LEARN_TRAIL==ON
                        artmap.learn_trail(sampleNumber).weight_after = artmap.w(:,j);
                    end

                    activated_node=j;
                    break; % Don't check any more T_j
                else % J~=K
                    % Wrong prediction: raise vigilance just above this
                    % node's match value and update the bias vector.
                    artmap.rho = match_ratio + artmap.epsilon;
                    artmap.rho_uc = ((z_sum+artmap.gamma)./(sum(A_biased)+eps+artmap.gamma));

                    [e_AttentionAfter,A_dashAfter,x_Attention] =...
                        ART_depleter_new(sampleNumber,artmap.w(:,j),A,e_AttentionBefore,artmap.lambda_Attention,artmap.M,artmap.model_num,artmap.deplete_fast);

                    % Diagnostics: count (near-)zero bias vectors.
                    if sum(e_AttentionAfter)<.0001
                        artmap.e_realLow=artmap.e_realLow+1;
                    end
                    if sum(e_AttentionAfter)<2*eps
                        artmap.e_zero=artmap.e_zero+1;
                    end

                    if LEARN_TRAIL==ON
                        MT_nodes=MT_nodes+1;
                        artmap.learn_trail(sampleNumber).MT_nodes{MT_nodes}=j;
                        artmap.learn_trail(sampleNumber).MT_rho{MT_nodes} =artmap.rho;
                        artmap.learn_trail(sampleNumber).MT_x_before{MT_nodes} =x_Attention';
                        artmap.learn_trail(sampleNumber).MT_e_before{MT_nodes} =e_AttentionBefore';
                        artmap.learn_trail(sampleNumber).MT_e_after{MT_nodes} =e_AttentionAfter';
                        artmap.learn_trail(sampleNumber).MT_Adash_before{MT_nodes} =A_dashBefore';
                        artmap.learn_trail(sampleNumber).MT_Adash_after{MT_nodes} =A_dashAfter';
                    end

                    A_dashBefore=A_dashAfter;
                    e_AttentionBefore=e_AttentionAfter;
                end % if J==K
            end % if MATCH_TRACK_TEST
        end % for j=T_j

        if (MATCH==0)
            % No candidate predicted K: commit a new node coding A.
            artmap.C=artmap.C+1;
            artmap.w(:,artmap.C)=A;
            artmap.node_created=[artmap.node_created sampleNumber];
            artmap.e_mod_dist(sampleNumber)=artmap.lambda_Attention*sum(e_AttentionBefore)/2;
            artmap.A_dep_dist(sampleNumber)=artmap.lambda_Attention*sum(A_dashBefore);

            % Match value an all-ones weight vector would reach under the
            % final bias.
            z_sum=sum(min(max(1-artmap.lambda_Attention*e_AttentionBefore,0),...
                max(A-artmap.lambda_Attention*e_AttentionBefore,0)));
            rho_uc = ((z_sum+artmap.gamma)./(sum(max(A-artmap.lambda_Attention*e_AttentionBefore,0))+eps+artmap.gamma));

            % Count the cases where vigilance ended above that value.
            if artmap.rho>(rho_uc+.00001)
                artmap.uncommitted_fail=artmap.uncommitted_fail+1;
                disp('It is happening!');
            end

            if LEARN_TRAIL==ON
                artmap.learn_trail(sampleNumber).winning_node=artmap.C;
                artmap.learn_trail(sampleNumber).weight_before = A';
                artmap.learn_trail(sampleNumber).weight_after = A';
            end

            artmap.W=[artmap.W; zeros(1,artmap.numClasses)];
            artmap.W(artmap.C,K)=1;

            activated_node=artmap.C; % Activated node is new node
        end % (MATCH==0)

    end      %  if (artmap.C==0)
    artmap.NodeList = [artmap.NodeList, activated_node];

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%THIS SECTION CHECKS IF THE NEXT_INPUT TEST IS MET
    y = zeros(artmap.C,1);
    T = DEFARTMAP_choice_function(artmap,A);
    Lambda = find(T>artmap.alpha.*artmap.M);
    Lambda_prime = find(T==artmap.M);

    if ~isempty(Lambda_prime)  % Point Box Case -- w_j = A
        y(Lambda_prime)=(1./length(Lambda_prime));
    else
        sumL = sum(1./(artmap.M-T(Lambda)));
        y(Lambda)=(1./(artmap.M-T(Lambda)))./sumL;
    end

    % Class of the maximally active node(s). Collapsing the selected rows
    % with any(...,1) keeps the result a class (column) index even when
    % several nodes tie; find() on the multi-row submatrix would return a
    % linear index instead. Ties resolve to the lowest class index.
    class_val = find(any(artmap.W(y==max(y),:),1),1,'first');

    % Record samples that the network still misclassifies after learning.
    if K~=class_val
        artmap.NI_fail(end+1)=sampleNumber;
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

end      % for sampleNumber = 1:n

artmap.search_cycles = search_cycles;
updated_artmap = artmap;

return

Contact us at files@mathworks.com