Code covered by the BSD License  

Highlights from
EM for HMM Multivariate Gaussian processes

image thumbnail

EM for HMM Multivariate Gaussian processes

by

 

15 Jul 2008

A fast implementation of the EM algorithm for hidden Markov models (HMMs) with multivariate Gaussian emissions (an HMM-driven multivariate Gaussian mixture).

test_em_ghmm.m
% Example: train and test an HMM whose hidden states each emit a 2-D
% Gaussian, i.e. a mixture of 2 Gaussians driven by a Markov chain.
%
% Sizes:
%   d : dimension of each observation vector (mean vectors below are d x 1)
%   m : number of hidden states / Gaussian components
d = 2;
m = 2;

% Run settings.
L              = 1;     % last argument of sample_ghmm -- presumably the number of sequences (TODO confirm)
R              = 1;     % trailing size of the initial parameters -- presumably number of replicates (TODO confirm)
Ntrain         = 3000;  % number of training samples
Ntest          = 10000; % number of test samples
options.nb_ite = 30;    % presumably the number of EM iterations (TODO confirm)

% True model: PI (m x 1), A (m x m), component means M (d x 1 x m) and
% component covariances S (d x d x m).
PI = [0.5 ; 0.5];
A  = [0.95 0.05 ; 0.05 0.95];
M  = cat(3 , [-1 ; -1] , [2 ; 2]);
S  = cat(3 , [1 0.3 ; 0.3 0.8] , [0.7 0.6; 0.6 1]);

% Draw the training observations Ztrain and hidden states Xtrain. Then
% shift the state labels down by one so they are 0-based, matching the
% decoded labels compared against them later.
[Ztrain , Xtrain] = sample_ghmm(Ntrain , PI , A , M , S , L);
Xtrain = Xtrain - 1;

%%%%% initial parameters %%%%
% Random starting point for EM. Shapes must match the true model:
%   PI0 : m x 1 x R      random start for PI, each column normalised to sum to 1
%   A0  : m x m x R      random start for A, each column normalised to sum to 1
%                        (row- vs column-stochastic convention of em_ghmm is
%                        not visible here; this keeps the column normalisation)
%   M0  : d x 1 x m x R  random component means (same layout as M)
%   S0  : d x d x m x R  starting component covariances (same layout as S)
% m (number of states) and d (observation dimension) are used separately,
% so the example stays valid when d ~= m.

PI0 = rand(m , 1 , R);
% Explicit dim 1 so that the normalisation is still per column when m == 1.
PI0 = PI0 ./ repmat(sum(PI0 , 1) , [m , 1 , 1]);

A0  = rand(m , m , R);
A0  = A0 ./ repmat(sum(A0 , 1) , [m , 1 , 1]);

M0  = randn(d , 1 , m , R);

if (d == 2) && (m == 2)
    % Original hand-picked diagonal starting covariances for the 2-D, 2-state case.
    S0 = repmat(cat(3 , [2 0 ; 0 2] , [3 0; 0 2]) , [1 , 1 , 1, R]);
else
    % General case: a scaled identity for every component and replicate.
    S0 = repmat(2*eye(d) , [1 , 1 , m , R]);
end

%%%%% EM algorithm %%%%
% Fit the HMM to the training observations, starting from the random
% initial parameters.
[logl , PIest , Aest , Mest , Sest] = em_ghmm(Ztrain , PI0 , A0 , M0 , S0 , options);

% Contours of the true (M, S) and fitted (Mest, Sest) Gaussians, for plotting.
[x , y]       = ndellipse(M , S);
[xest , yest] = ndellipse(Mest , Sest);

% Decode the training state sequence with the fitted model and make the
% labels 0-based, like Xtrain.
Ltrain_est = likelihood_mvgm(Ztrain , Mest , Sest);
Xtrain_est = forward_backward(PIest , Aest , Ltrain_est) - 1;

% Misclassification rate. EM may number the two states either way round,
% so score both the decoded labels and their complement and keep the
% smaller rate. This trick assumes two states.
Err_train = min(sum(Xtrain ~= Xtrain_est , 2) , sum(Xtrain ~= ~Xtrain_est , 2)) / Ntrain;

% Training points split by decoded state.
ind1 = (Xtrain_est == 0);
ind2 = (Xtrain_est == 1);

figure(1) ,
h = plot(Ztrain(1 , ind1) , Ztrain(2 , ind1) , 'k+' , ...
         Ztrain(1 , ind2) , Ztrain(2 , ind2) , 'g+' , ...
         x    , y    , 'b' , ...
         xest , yest , 'r' , ...
         'linewidth' , 2);
% h(1): first data series. h(3:m:end) presumably picks one true and one
% estimated ellipse line (assumes ndellipse yields m curves each -- TODO confirm).
legend([h(1) ; h(3:m:end)] , 'Train data' , 'True'  , 'Estimated' , 'location' , 'best')
title(sprintf('Train data, Error rate = %4.2f%%' , Err_train*100))

%%%%% Test data  %%%%
% Draw a fresh sequence from the true model and make its labels 0-based.
[Ztest , Xtest] = sample_ghmm(Ntest , PI , A , M , S , L);
Xtest = Xtest - 1;

% Decode it with the parameters fitted on the training data.
Ltest_est = likelihood_mvgm(Ztest , Mest , Sest);
Xtest_est = forward_backward(PIest , Aest , Ltest_est) - 1;

% Label-switch-invariant misclassification rate (two states only).
Err_test = min(sum(Xtest ~= Xtest_est , 2) , sum(Xtest ~= ~Xtest_est , 2)) / Ntest;

% Test points split by decoded state.
ind1 = (Xtest_est == 0);
ind2 = (Xtest_est == 1);

figure(2),
h = plot(Ztest(1 , ind1) , Ztest(2 , ind1) , 'k+' , ...
         Ztest(1 , ind2) , Ztest(2 , ind2) , 'g+' , ...
         x    , y    , 'b' , ...
         xest , yest , 'r' , ...
         'linewidth' , 2);
legend([h(1) ; h(3:m:end)] , 'Test data' , 'True'  , 'Estimated' , 'location' , 'best')
title(sprintf('Test data, Error rate = %4.2f%%' , Err_test*100))

Contact us