image thumbnail

3D visualization of GMM learning via the EM algorithm

by

 

10 Jan 2012 (Updated )

The evolution of a GMM in the EM algorithm is visualized by interpolating between iterations.

EM_GMM_3d(c,wk,N,movie,Y,D,Cv)
function X = EM_GMM_3d(c,wk,N,movie,Y,D,Cv)
%% 3D visualization of EM for GMMs
% function X = EM_GMM_3d(c,wk,N,movie,Y,D,Cv)
%
% Inputs: c      - # clusters (default: 5)
%         wk     - array of # gaussians to fit (default: 1:10)
%         N      - # GMM samples (default: 200)
%         movie  - string: writes .avi for each k (default: [])
%         Y      - NxD data matrix (default: generate new data)
%         D      - data dimensionality (default: 3)
%         Cv     - covariance type (0: diag, 1: full) (default: full)
% Output: X      - NxD data matrix
%
% How it works:
%   For each value of k in 'wk':
%    - Fit GMM with EM
%    - Interpolate model params between EM iterations
%    - Coloring: Gaussians - transparency = GMM mixing weights
%                data      - color blend = posterior probs
%    - Play movie of EM learning
%    - (optional) Save movie to .avi file
%
% Author: Johannes Traa, UIUC, Nov-Dec '11
% 
% Revised Mar '13 (cleaned up code, added full covariance capability)


%% check input args
if nargin < 1 || isempty(c);  c  = 5;    end
if nargin < 2 || isempty(wk); wk = 1:10; end
if nargin < 3 || isempty(N);  N  = 200;  end
if nargin < 4 || isempty(movie)
  m_val = 0;
  set(0,'DefaultFigureWindowStyle','docked')
else
  m_val = 1;
  set(0,'DefaultFigureWindowStyle','normal')
  if ~ischar(movie)
    movie = 'EM_k';
  end
end
if nargin < 6 || isempty(D);  D  = 3; end
if nargin < 7 || isempty(Cv); Cv = 1; end

%% get data
if ~(nargin<5 || isempty(Y))
  X = Y;
  [N,D] = size(X);
else
  % GMM params
  m = 10*randn(c,D); % means
  C = zeros(D,D,c); % covariance matrices
  for j=1:c
    Cj = randn(D,10);
    C(:,:,j) = Cj*Cj';
  end
  w = dirrand(3*ones(1,c),1); % mixing weights
  
  % sample from GMM
  X = zeros(N,D);
  for i=1:N
    j = find(mnrnd(1,w));
    X(i,:) = m(j,:) + randn(1,D)*sqrtm(C(:,:,j));
  end
end

%% get GMM and show movie for different k
% initialize figure window
fh = figure;

for k=wk
  % ----- fit k-gaussian GMM -----
  [M,C,P,ii] = EM_GMM_movie(X,k,Cv);
  
  % ----- interpolate params between iterations -----
  fps = 35; % frames per second
  f = fps*ii+1; % # frames
  
  m = zeros(k,D,f); % interpolated means
  Cm = zeros(D,D,k,f); % interpolated covars
  Pi = zeros(N,k,f); % interpolated posteriors
  
  % set anchor frames (true values of model at each iteration)
  for i=1:ii
    m(:,:,(i-1)*fps+1) = M{i};
    Cm(:,:,:,(i-1)*fps+1) = C{i};
    Pi(:,:,(i-1)*fps+1) = P{i};
  end
  
  % correct last frame
  m(:,:,ii*fps+1) = m(:,:,(ii-1)*fps+1);
  Cm(:,:,:,ii*fps+1) = Cm(:,:,:,(ii-1)*fps+1);
  Pi(:,:,ii*fps+1) = Pi(:,:,(ii-1)*fps+1);
  
  % interpolate missing frames for continuity
  mix = linspace(0,1,fps); % mixing weights for interpolation
  for i=1:ii
    % get anchor values
    m1 = m(:,:,(i-1)*fps+1);
    m2 = m(:,:,i*fps+1);
    c1 = Cm(:,:,:,(i-1)*fps+1);
    c2 = Cm(:,:,:,i*fps+1);
    p1 = Pi(:,:,(i-1)*fps+1);
    p2 = Pi(:,:,i*fps+1);
    
    % interpolate
    for j=2:fps
      m(:,:,(i-1)*fps+j) = m1*mix(fps-j+1) + m2*mix(j);
      Cm(:,:,:,(i-1)*fps+j) = c1*mix(fps-j+1) + c2*mix(j);
      Pi(:,:,(i-1)*fps+j) = p1*mix(fps-j+1) + p2*mix(j);
    end
  end
    
  % fills (2D) or ellipsoids (3D)
  rr = 20; % resolution
  theta = linspace(0,2*pi,rr);
  if D == 2
    E = zeros(rr,D,k,f); % fill boundaries
    v = [cos(theta); sin(theta)]';
  else
    E = zeros(rr,rr,D,k,f); % surf contours
    phi = linspace(0,pi,rr);
    v = [kron(cos(theta),sin(phi)); ...
      kron(sin(theta),sin(phi)); ...
      repmat(cos(phi),[1,rr])]';
  end
  for i=1:f
    for j=1:k
      E_ij = bsxfun(@plus,m(j,:,i),v*sqrtm(Cm(:,:,j,i)));
      if D == 2
        E(:,:,j,i) = E_ij;
      else
        for d=1:3
          E(:,:,d,j,i) = reshape(E_ij(:,d),[rr,rr]);
        end
      end
    end
  end
  
  % ----- coloring -----
  % posterior = color
  cc = jet(k);
  C = zeros(N,3,f);
  for i=1:f
    C(:,:,i) = Pi(:,:,i)*cc;
  end
  
  % mixing weight = transparency
  w = sum(Pi,1)/N; % mixing weights
  w = reshape(w,[k,f])'; % (frames x Gaussians)
  w = bsxfun(@rdivide,w,max(w,[],2)); % normalize rows
  w = bsxfun(@times,w,linspace(0.2,0.6,f)'); % fade in ellipsoids
  
  
  % ----- set up for movie -----
  % set up to write frames to avi object/set frame skip rate
  if ~m_val % don't write movie, skip frames for speed
    fi = 5;
  else % write frames for movie
    aviobj = avifile([movie num2str(k)]);
    fi = 1;
  end
  
  % axis handles (means, fills/ellipsoids, data)
  hold on
  eh = zeros(k,1); % fills/ellipsoids
  if D == 2
    mmh = scatter(m(:,1,1),m(:,2,1),200,'+k'); % means
    for g=1:k % fills
      eh(g) = fill(E(:,1,g,1),E(:,2,g,1),cc(g,:),...
        'FaceAlpha',w(1,g),'EdgeColor','none');
    end
    Xh = scatter(X(:,1),X(:,2),30,C(:,:,1),'filled'); % data
  else
    mmh = scatter3(m(:,1,1),m(:,2,1),m(:,3,1),200,'+k'); % means
    for g=1:k % ellipsoids
      eh(g) = surf(E(:,:,1,g,1),E(:,:,2,g,1),E(:,:,3,g,1),...
        'FaceColor',cc(g,:),'FaceAlpha',w(1,g),'EdgeColor','none');
    end
    Xh = scatter3(X(:,1),X(:,2),X(:,3),30,C(:,:,1),'filled'); % data
  end
  
  % figure details
  axis tight square
  if D == 3
    axis vis3d
  end
  grid on
  rot = 0;
  set(gcf,'Color','w')
  
  
  % ----- play movie showing evolution of GMM -----
  for i=[1:fi:f-1 f]
    it = 1+floor((i-1)/fps); % iteration #
    
    % update data coloring
    set(Xh,'CData',C(:,:,i));
    
    % update interpolated means ('+')
    set(mmh,'XData',m(:,1,i));
    set(mmh,'YData',m(:,2,i));
    if D == 3
      set(mmh,'ZData',m(:,3,i));
    end
        
    % update ellipsoids
    for g=1:k
      if D == 2
        set(eh(g),'XData',E(:,1,g,i))
        set(eh(g),'YData',E(:,2,g,i))
      else
        set(eh(g),'XData',E(:,:,1,g,i))
        set(eh(g),'YData',E(:,:,2,g,i))
        set(eh(g),'ZData',E(:,:,3,g,i))
      end
      set(eh(g),'FaceAlpha',w(i,g))
    end
    
    % figure details
    xlabel(['iteration: ' num2str(it) '  '])
    if D == 3
      view(-50+rot,30)
      rot = rot + 0.4;
    end
    
    % add frame to .avi object
    if m_val
      F = getframe(fh);
      aviobj = addframe(aviobj,F);
    end
    
    pause(1/fps);
  end
  
  
  % ----- clean up for next k -----
  xlabel(['iteration: ' num2str(ii) '  '])
  
  % write out video
  if m_val
    for i=1:fps
      aviobj = addframe(aviobj,F);
    end
    aviobj = close(aviobj);
    clf
  elseif k ~= wk(end)
    pause
  end
  
  % new figure
  if k < wk(end) && ~m_val
    figure
  end
end




%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% dirrand

function X = dirrand(a,N)
%% Sample from Dirichlet distribution
% function X = dirrand(a,N)
%
% Inputs: a - parameters of Dirichlet distribution
%         N - number of samples
% Output: X - Nxd matrix of Dirichlet samples

%% 
X = gamrnd(repmat(a,N,1),1);
X = bsxfun(@rdivide,X,sum(X,2));



%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% EM_GMM_movie

function [M,C,P,i] = EM_GMM_movie(X,k,Cv)
%% EM for GMMs
% [M,C,P,iters] = EM_GMM(X,k)
%
% Inputs: X - data (points x dims)
%         k - # clusters
% Outputs: M - cell array of means at each iteration
%          C - cell array of covars at each iteration
%          P - cell array of posteriors at each iteration
%          i - iterations to convergence

[N,d] = size(X);
A = zeros(N,k); % posteriors

% initialize
Cx = sqrtm(cov(X));
m = bsxfun(@plus,sum(X,1)/N,randn(k,d)*Cx);
w = ones(1,k)/k; % mixing weights
if Cv; Ca = repmat(Cx,[1,1,k]); % full cov
else   Ca = ones(k,d); % diagonal cov
end

% save params from each iteration
M = cell(1,50);
C = cell(1,50);
P = cell(1,50);

M{1} = m;
if Cv
  C{1} = Ca;
else
  C{1} = zeros(d,d,k);
  for j=1:k
    C{1}(:,:,j) = diag(Ca(j,:));
  end
end
P{1} = ones(N,k)/k;

% iterate
ll = -Inf;
for i=2:50
  ll_old = ll;
  
  % -- E step --
  if Cv
    for j=1:k
      v = bsxfun(@minus,X,m(j,:));
      A(:,j) = -1/2*log(det(2*pi*Ca(:,:,j))) ...
        - 1/2*sum((v/(Ca(:,:,j)+10^-3)).*v,2) + log(w(j));
    end
  else
    for j=1:k
      v = bsxfun(@minus,X,m(j,:));
      A(:,j) = sum(bsxfun(@minus,-1/2*log(2*pi*Ca(j,:)), ...
        bsxfun(@rdivide,v.^2,2*Ca(j,:)+10^-3)),2) ...
        + log(w(j));
    end
  end
  As = logsum(A,2);
  p = exp(bsxfun(@minus,A,As)); % Nxk posteriors
  
  % -- M step --
  ps = sum(p,1) + eps;
  m = bsxfun(@rdivide,p'*X,ps');
  if Cv
    for j=1:k
      Xm = bsxfun(@minus,X,m(j,:));
      Ca(:,:,j) = bsxfun(@times,p(:,j),Xm)'*Xm/ps(j);
    end
  else
    for j=1:k
      Ca(j,:) = p(:,j)'*bsxfun(@minus,X,m(j,:)).^2/ps(j);
    end
  end
  w = ps/N;
  
  % -- save means, covars, posteriors --
  M{i} = m;
  if Cv
    C{i} = Ca;
  else
    C{i} = zeros(d,d,k);
    for j=1:k
      C{i}(:,:,j) = diag(Ca(j,:));
    end
  end
  P{i} = p;
  
  % -- convergence --
  ll = sum(As);
  if abs((ll-ll_old)/ll_old) < 10^-4
    break
  end
end



%% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% logsum

function y = logsum(x,d)
%% Log-sum-exp for sumation in the log domain
% function y = logsum(x,d)
%
% Inputs: x - matrix or vector
%         d - dimension to sum over
% Output: y - logsum output

m = max(x,[],d);
y = log(sum(exp(bsxfun(@minus,x,m)),d)) + m;

Contact us