This example shows how to train an Inflated 3-D (I3D) two-stream convolutional neural network for activity recognition using RGB and optical flow data from videos [1].
Vision-based activity recognition involves predicting the action of an object, such as walking, swimming, or sitting, using a set of video frames. Activity recognition from video has many applications, such as human-computer interaction, robot learning, anomaly detection, surveillance, and object detection. For example, online prediction of multiple actions for incoming videos from multiple cameras can be important for robot learning. Compared to image classification, action recognition using videos is challenging to model because of the noisy labels in video data sets, the wide variety of actions that actors in a video can perform, which makes the classes heavily imbalanced, and the computational cost of pretraining on large video data sets. Some deep learning techniques, such as I3D two-stream convolutional networks [1], have shown improved performance by leveraging pretraining on large image classification data sets.
This example trains an I3D network using the HMDB51 data set. Use the downloadHMDB51 supporting function, listed at the end of this example, to download the HMDB51 data set to a folder named hmdb51.
downloadFolder = fullfile(tempdir,"hmdb51");
downloadHMDB51(downloadFolder);
After the download is complete, extract the RAR file hmdb51_org.rar to the hmdb51 folder. Next, use the checkForHMDB51Folder supporting function, listed at the end of this example, to confirm that the downloaded and extracted files are in place.
allClasses = checkForHMDB51Folder(downloadFolder);
The data set contains about 2 GB of video data for 7000 clips over 51 classes, such as drink, run, and shake hands. Each video frame has a height of 240 pixels and a minimum width of 176 pixels. The number of frames ranges from 18 to approximately 1000.
To reduce training time, this example trains an activity recognition network to classify 5 action classes instead of all 51 classes in the data set. Set useAllData to true to train with all 51 classes.
useAllData = false;
if useAllData
    classes = allClasses;
else
    classes = ["kiss","laugh","pick","pour","pushup"];
end
dataFolder = fullfile(downloadFolder, "hmdb51_org");
Split the data set into a training set for training the network and a test set for evaluating it. Use 80% of the data for the training set and the rest for the test set. Use imageDatastore to collect the video files and their labels, and then use splitEachLabel to randomly select 80% of the files from each label for training, so that every class appears in the same proportion in both sets.
imds = imageDatastore(fullfile(dataFolder,classes),...
    'IncludeSubfolders', true,...
    'LabelSource', 'foldernames',...
    'FileExtensions', '.avi');
[trainImds,testImds] = splitEachLabel(imds,0.8,'randomized');
trainFilenames = trainImds.Files;
testFilenames = testImds.Files;
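To illustrate what a per-label split does, the following sketch reproduces the idea in base MATLAB on a hypothetical label vector. It assumes that the number of training files for each label is round(0.8*n); the exact rounding rule of splitEachLabel is not stated in this example, so treat this as an approximation rather than a reimplementation.

```matlab
% Hypothetical labels: 10 clips of "kiss" and 5 clips of "pour".
labels = [repmat("kiss",10,1); repmat("pour",5,1)];
p = 0.8;                       % proportion of each label used for training
rng(0);                        % make the random selection repeatable
trainIdx = [];
testIdx = [];
for c = unique(labels)'
    % Shuffle the files of one label, then keep the first p of them for training.
    idx = find(labels == c);
    idx = idx(randperm(numel(idx)));
    nTrain = round(p*numel(idx));              % assumed rounding rule
    trainIdx = [trainIdx; idx(1:nTrain)];      %#ok<AGROW>
    testIdx = [testIdx; idx(nTrain+1:end)];    %#ok<AGROW>
end
```

Because the split is done label by label, both sets keep the class proportions of the full data set, which matters when classes have different numbers of clips.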
To normalize the input data for the network, the minimum and maximum values for the data set are provided in the MAT file inputStatistics.mat, attached to this example. To find the minimum and maximum values for a different data set, use the inputStatistics supporting function, listed at the end of this example.
inputStatsFilename = 'inputStatistics.mat';
if ~exist(inputStatsFilename, 'file')
    disp("Reading all the training data for input statistics...")
    inputStats = inputStatistics(dataFolder);
else
    d = load(inputStatsFilename);
    inputStats = d.inputStats;
end
Create two FileDatastore objects for training and validation by using the createFileDatastore supporting function, defined at the end of this example. Each datastore reads a video file to provide RGB data, optical flow data, and the corresponding label information.
Specify the number of frames for each read by the datastore. Typical values are 16, 32, 64, or 128. Using more frames helps capture more temporal information, but requires more memory for training and prediction. Set the number of frames to 64 to balance memory usage against performance. You might need to lower this value depending on your system resources.
numFrames = 64;
Specify the height and width of the frames for the datastore to read. Fixing the height and width to the same value makes batching data for the network easier. Typical values are [112, 112], [224, 224], and [256, 256]. The minimum height and width of the video frames in the HMDB51 data set are 240 and 176, respectively. Specify [112, 112] to capture a larger number of frames at the cost of spatial information. If you want to specify a frame size for the datastore to read that is larger than the minimum values, such as [256, 256], first resize the frames using imresize.
frameSize = [112,112];
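Add the input size to the inputStats structure so that the read function of the fileDatastore object reads frames of the specified size. Also add the class names, which the read function uses to create the categorical labels.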
inputSize = [frameSize, numFrames];
inputStats.inputSize = inputSize;
inputStats.Classes = classes;
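To make the memory trade-off concrete, the following back-of-the-envelope calculation estimates the size of one mini-batch of network inputs. It assumes single-precision data (4 bytes per value) and the mini-batch size of 20 used later in this example, and it ignores the activations and gradients computed during training, which add to this total.

```matlab
% Sizes taken from the text; miniBatchSize matches params.MiniBatchSize below.
frameSize = [112,112];
numFrames = 64;
rgbChannels = 3;
flowChannels = 2;
miniBatchSize = 20;
bytesPerValue = 4;             % assumption: inputs stored as single

valuesPerFrame = prod(frameSize);
bytesRGB  = valuesPerFrame * numFrames * rgbChannels  * bytesPerValue;
bytesFlow = valuesPerFrame * numFrames * flowChannels * bytesPerValue;
bytesBatch = miniBatchSize * (bytesRGB + bytesFlow);

fprintf("RGB clip:  %.2f MiB\n", bytesRGB/2^20);
fprintf("Flow clip: %.2f MiB\n", bytesFlow/2^20);
fprintf("Mini-batch of inputs: %.2f MiB\n", bytesBatch/2^20);
```

With these values, one mini-batch of inputs alone takes about 306 MiB. Doubling numFrames to 128 doubles every one of these numbers, which is why lowering numFrames is the first option on systems with limited memory.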
Create two FileDatastore objects, one for training and another for validation.
isDataForValidation = false;
dsTrain = createFileDatastore(trainFilenames,inputStats,isDataForValidation);
isDataForValidation = true;
dsVal = createFileDatastore(testFilenames,inputStats,isDataForValidation);
disp("Training data size: " + string(numel(dsTrain.Files)))
Training data size: 436
disp("Validation data size: " + string(numel(dsVal.Files)))
Validation data size: 109
Using a 3-D CNN is a natural approach to extracting spatio-temporal features from videos. You can create an I3D network from a pretrained 2-D image classification network such as Inception v1 or ResNet-50 by expanding 2-D filters and pooling kernels into 3-D. This procedure reuses the weights learned from the image classification task to bootstrap the video recognition task.
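The I3D paper [1] bootstraps each inflated filter by repeating the 2-D weights N times along the new temporal dimension and dividing them by N. With that rescaling, a 3-D filter applied to a clip made of one repeated frame gives the same response as the 2-D filter applied to that frame. The following sketch applies this rule to random weights in base MATLAB. The Inflated3D supporting function is not listed in this example, so it is an assumption that it inflates weights the same way. The weight layouts (height, width, channels, filters for 2-D; height, width, time, channels, filters for 3-D) follow the convolution layers in Deep Learning Toolbox, and the temporal size T = 3 is an arbitrary choice.

```matlab
% Hypothetical 2-D convolution weights: 7x7 filters, 3 input channels, 64 filters,
% stored as height x width x channels x filters.
rng(0);
w2d = randn(7,7,3,64);
T = 3;                                   % assumed temporal filter size

% Insert a singleton temporal dimension (height x width x time x channels x filters),
% repeat the weights T times along it, and divide by T.
w3d = repmat(reshape(w2d,[7,7,1,3,64]),[1,1,T,1,1]) / T;

% Check the "repeated frame" property on one filter: a clip of T identical frames
% produces the same response as the 2-D filter applied to a single frame.
patch2d = randn(7,7,3);                                  % one 7x7x3 image patch
patch3d = repmat(reshape(patch2d,[7,7,1,3]),[1,1,T,1]);  % the same patch repeated T times
resp2d = sum(w2d(:,:,:,1) .* patch2d, 'all');
resp3d = sum(w3d(:,:,:,:,1) .* patch3d, 'all');
fprintf("2-D response: %.6f, inflated 3-D response: %.6f\n", resp2d, resp3d);
```

Dividing by T keeps the scale of the activations unchanged, so the pretrained layers that follow still receive inputs in the range they were trained on.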
The following figure shows an example of inflating a 2-D convolution layer into a 3-D convolution layer. The inflation involves expanding the filter size, weights, and bias by adding a third dimension (the temporal dimension).

Video data can be considered to have two parts: a spatial component and a temporal component.
The spatial component comprises information about the shape, texture, and color of objects in video. RGB data contains this information.
The temporal component comprises information about the motion of objects across the frames and depicts important movements between the camera and the objects in a scene. Computing optical flow is a common technique for extracting temporal information from video.
A two-stream CNN incorporates a spatial subnetwork and a temporal subnetwork [2]. A convolutional neural network trained on dense optical flow in addition to the RGB video stream can achieve better performance with limited training data than a network trained only on raw stacked RGB frames. The following illustration shows a typical two-stream I3D network.

In this example, you create an I3D network using GoogLeNet, a network pretrained on the ImageNet database.
Specify the number of channels as 3 for the RGB subnetwork and 2 for the optical flow subnetwork. The two channels for optical flow data are the x and y components of velocity, Vx and Vy, respectively.
rgbChannels = 3;
flowChannels = 2;
Obtain the minimum and maximum values for the RGB and optical flow data from the inputStats structure loaded from the inputStatistics.mat file. These values are needed for the image3dInputLayer of the I3D networks to normalize the input data.
rgbInputSize = [frameSize, numFrames, rgbChannels];
flowInputSize = [frameSize, numFrames, flowChannels];
rgbMin = inputStats.rgbMin;
rgbMax = inputStats.rgbMax;
oflowMin = inputStats.oflowMin(:,:,1:2);
oflowMax = inputStats.oflowMax(:,:,1:2);
rgbMin = reshape(rgbMin,[1,size(rgbMin)]);
rgbMax = reshape(rgbMax,[1,size(rgbMax)]);
oflowMin = reshape(oflowMin,[1,size(oflowMin)]);
oflowMax = reshape(oflowMax,[1,size(oflowMax)]);
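The reshape calls prepend a singleton dimension so that the per-channel statistics line up with the channel dimension of the height-by-width-by-frames-by-channels network input. The following sketch shows why that matters, using hypothetical per-channel minimum and maximum values of size 1-by-1-by-3 (the shape that max(frame,[],[1,2]) returns for an RGB frame) and a plain min-max rescaling. The normalization that image3dInputLayer actually applies depends on options set inside Inflated3D, which is not listed here.

```matlab
% Hypothetical clip: height x width x frames x channels, values in [0, 255].
rng(0);
clip = 255 * rand(112,112,64,3);

% Per-channel statistics shaped like max(frame,[],[1,2]) for an RGB frame: 1x1x3.
chanMin = reshape([0 10 20], [1,1,3]);
chanMax = reshape([255 245 235], [1,1,3]);

% Used as is, the 3 channel values would be matched against dimension 3 (64 frames)
% and the subtraction would fail. Prepending a singleton gives 1x1x1x3 instead.
chanMin4 = reshape(chanMin, [1, size(chanMin)]);
chanMax4 = reshape(chanMax, [1, size(chanMax)]);

% Implicit expansion applies each channel's min and max to every pixel and frame.
normalized = (clip - chanMin4) ./ (chanMax4 - chanMin4);
disp(size(chanMin4))
```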
Specify the number of classes for training the network.
numClasses = numel(classes);
Create the I3D RGB and optical flow subnetworks by using the Inflated3D supporting function, which is attached to this example. The subnetworks are created from GoogLeNet.
cnnNet = googlenet;
netRGB = Inflated3D(numClasses,rgbInputSize,rgbMin,rgbMax,cnnNet);
netFlow = Inflated3D(numClasses,flowInputSize,oflowMin,oflowMax,cnnNet);
Create a dlnetwork object from the layer graph of each of the I3D networks.
dlnetRGB = dlnetwork(netRGB);
dlnetFlow = dlnetwork(netFlow);
The modelGradients supporting function, listed at the end of this example, takes as input the RGB subnetwork dlnetRGB, the optical flow subnetwork dlnetFlow, a mini-batch of input data dlRGB and dlFlow, and a mini-batch of ground truth label data dlY. The function returns the gradients of the loss with respect to the learnable parameters of each subnetwork, the training loss value, the mini-batch accuracy of the fused and individual subnetwork predictions, and the updated state of each subnetwork.
The loss is calculated by computing the average of the cross-entropy losses of the predictions from each of the subnetworks. The output predictions of the network are probabilities between 0 and 1 for each of the classes.
The fused accuracy is calculated by averaging the RGB and optical flow predictions and comparing the class with the highest averaged score to the ground truth label of each input. The accuracy of each subnetwork is also calculated separately from its own predictions.
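The fusion rules can be checked on small numbers. The following sketch uses plain arrays instead of dlarray objects and a hand-written cross-entropy that sums over classes and averages over observations; that normalization is an assumption made for the sketch, not a restatement of the crossentropy documentation.

```matlab
% Hypothetical softmax outputs for 3 classes (rows) and 4 observations (columns),
% laid out like the "CB" format used in this example.
YRGB  = [0.7 0.3 0.5 0.2;
         0.2 0.4 0.2 0.5;
         0.1 0.3 0.3 0.3];
YFlow = [0.6 0.1 0.1 0.5;
         0.3 0.8 0.2 0.3;
         0.1 0.1 0.7 0.2];
labels = [1 2 3 1];
T = full(sparse(labels, 1:4, 1, 3, 4));   % one-hot targets, one column per observation

% Assumed cross-entropy: sum over classes, mean over observations.
xent = @(Y) -sum(T .* log(Y), 'all') / size(Y,2);
rgbLoss = xent(YRGB);
flowLoss = xent(YFlow);
loss = mean([rgbLoss, flowLoss]);         % fused loss: average of the two streams

% Fused prediction: average the class scores, then take the most likely class.
YFused = (YRGB + YFlow) / 2;
[~, predFused] = max(YFused, [], 1);
[~, predRGB] = max(YRGB, [], 1);
[~, predFlow] = max(YFlow, [], 1);
acc = mean(predFused == labels);
accRGB = mean(predRGB == labels);
accFlow = mean(predFlow == labels);
fprintf("loss %.4f | fused %.2f | RGB %.2f | flow %.2f\n", loss, acc, accRGB, accFlow);
```

In this toy batch, the fused accuracy (0.75) lies between the RGB (0.5) and flow (1.0) accuracies. Averaging can also correct a mistake of one stream: the RGB stream misclassifies the third observation, but the fused prediction gets it right.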
Train with a mini-batch size of 20 for 1500 iterations. Specify the iteration after which to save the model with the best validation accuracy by using the SaveBestAfterIteration parameter.
Specify the cosine-annealing learning rate schedule [3] parameters. For both networks, use:
A minimum learning rate of 1e-4.
A maximum learning rate of 1e-3.
Cosine cycle lengths of 300, 500, and 700 iterations, which sum to the 1500 training iterations. At the end of each cycle, the learning rate restarts at the maximum value. The CosineNumIterations parameter defines the length of each cosine cycle.
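The schedule in this list can be sketched independently of the training code. The following loop builds the learning rate for every iteration with the cosine-annealing formula of [3], lr = minLR + (maxLR - minLR)(1 + cos(pi t/T))/2, where t counts iterations from 0 within a cycle of length T. This sketch restarts at the first iteration of each cycle, which differs by one iteration from the cosineAnnealingLearnRate supporting function, so treat it as an illustration of the shape of the schedule rather than a drop-in replacement.

```matlab
% Cycle lengths and learning-rate bounds from the text.
cycleLengths = [300 500 700];
minLR = 1e-4;
maxLR = 1e-3;

lr = zeros(1, sum(cycleLengths));
k = 1;                                   % first iteration of the current cycle
for T = cycleLengths
    t = 0:T-1;                           % position within the current cycle
    lr(k:k+T-1) = minLR + (maxLR - minLR) * (1 + cos(pi * t / T)) / 2;
    k = k + T;
end
```

Plotting lr with plot(lr) shows three decaying half-cosines of increasing length, each starting again at the maximum learning rate.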
Specify the parameters for SGDM optimization. Initialize the SGDM optimization parameters at the beginning of the training for each of the RGB and optical flow networks. For both networks, use:
A momentum of 0.9.
An initial velocity of [], which sgdmupdate treats as having no previous velocity on the first update.
An L2 regularization factor of 0.0005.
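Each SGDM step with L2 regularization first adds the L2 factor times the weight to the gradient and then updates the velocity and the weight. The following sketch writes out one such step for a single scalar weight, assuming the update rule v = momentum*v - learnRate*g followed by w = w + v. Apart from the momentum and L2 factor listed above, the numbers are hypothetical.

```matlab
% Values from the text: momentum 0.9 and L2 factor 0.0005. The rest are hypothetical.
momentum = 0.9;
l2Factor = 0.0005;
learnRate = 1e-3;              % for example, the maximum of the cosine schedule
w = 1.0;                       % current weight
g = 0.5;                       % gradient of the loss with respect to w
v = 0;                         % an empty initial velocity behaves like zero

% L2 regularization: add the derivative of (l2Factor/2)*w^2 to the gradient.
g = g + l2Factor*w;

% SGDM step (assumed form of the update).
v = momentum*v - learnRate*g;
w = w + v;
fprintf("velocity %.7f, new weight %.7f\n", v, w);
```

In updateDlNetwork, only the learnable parameters named "Weights" receive the L2 term; biases and other learnables are updated without it.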
Specify whether to dispatch the data in the background using a parallel pool. If DispatchInBackground is set to true, the example opens a parallel pool with the specified number of workers and creates a DispatchInBackgroundDatastore, provided as part of this example, that loads and preprocesses data asynchronously to speed up training. By default, this example uses a GPU if one is available. Otherwise, it uses a CPU. Using a GPU requires Parallel Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU. For information about the supported compute capabilities, see GPU Support by Release (Parallel Computing Toolbox).
params.Classes = classes;
params.MiniBatchSize = 20;
params.NumIterations = 1500;
params.SaveBestAfterIteration = 900;
params.CosineNumIterations = [300, 500, 700];
params.MinLearningRate = 1e-4;
params.MaxLearningRate = 1e-3;
params.Momentum = 0.9;
params.VelocityRGB = [];
params.VelocityFlow = [];
params.L2Regularization = 0.0005;
params.ProgressPlot = false;
params.Verbose = true;
params.ValidationData = dsVal;
params.DispatchInBackground = false;
params.NumWorkers = 4;
Train the subnetworks using the RGB data and optical flow data. Set the doTraining variable to false to download the pretrained subnetworks without having to wait for training to complete. Alternatively, if you want to train the subnetworks, set the doTraining variable to true.
doTraining = false;
For each epoch:
Shuffle the data before looping over mini-batches of data.
Use minibatchqueue to loop over the mini-batches. The supporting function createMiniBatchQueue, listed at the end of this example, uses the given training datastore to create a minibatchqueue.
Use the validation data dsVal to validate the networks.
Display the loss and accuracy results for each epoch using the supporting function displayVerboseOutputEveryEpoch, listed at the end of this example.
For each mini-batch:
Convert the image data or optical flow data and the labels to dlarray objects with the underlying type single.
Treat the temporal dimension of the video and optical flow data as one of the spatial dimensions to enable processing using a 3-D CNN. Specify the dimension labels "SSSCB" (spatial, spatial, spatial, channel, batch) for the RGB or optical flow data, and "CB" for the label data.
The minibatchqueue object uses the supporting function batchRGBAndFlow, listed at the end of this example, to batch the RGB and optical flow data.
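The "SSSCB" layout can be illustrated with plain arrays. The following sketch takes two hypothetical clips stored as height-by-width-by-channels-by-frames, the order in which readRGBAndFlow fills its buffers, permutes each clip to height-by-width-by-frames-by-channels, and stacks the clips along a fifth (batch) dimension. It is a sketch of the idea, not a copy of batchRGBAndFlow.

```matlab
H = 112; W = 112; C = 3; F = 64;   % frame size, channels, and frames from the text

% Two hypothetical clips in the order they are read: H x W x C x F.
clip1 = rand(H,W,C,F,'single');
clip2 = rand(H,W,C,F,'single');

% Move the frame dimension next to the spatial ones: H x W x F x C.
% The network then treats time as a third spatial dimension ("SSS").
clip1 = permute(clip1,[1,2,4,3]);
clip2 = permute(clip2,[1,2,4,3]);

% Concatenate along dimension 5 to form the batch ("B").
batch = cat(5, clip1, clip2);
disp(size(batch))     % size is 112-by-112-by-64-by-3-by-2
```

With Deep Learning Toolbox, dlarray(batch,"SSSCB") would attach the dimension labels; that step is left out so that the sketch runs in base MATLAB.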
modelFilename = "I3D-RGBFlow-" + numClasses + "Classes-hmdb51.mat";
if doTraining
    epoch = 1;
    bestValAccuracy = 0;
    accTrain = [];
    accTrainRGB = [];
    accTrainFlow = [];
    lossTrain = [];
    iteration = 1;
    shuffled = shuffleTrainDs(dsTrain);
    % Number of outputs is three: One for RGB frames, one for optical flow
    % data, and one for ground truth labels.
    numOutputs = 3;
    mbq = createMiniBatchQueue(shuffled, numOutputs, params);
    start = tic;
    trainTime = start;
    % Use the initializeTrainingProgressPlot and initializeVerboseOutput
    % supporting functions, listed at the end of the example, to initialize
    % the training progress plot and verbose output to display the training
    % loss, training accuracy, and validation accuracy.
    plotters = initializeTrainingProgressPlot(params);
    initializeVerboseOutput(params);
    while iteration <= params.NumIterations
        % Iterate through the data set.
        [dlX1,dlX2,dlY] = next(mbq);
        % Evaluate the model gradients and loss using dlfeval.
        [gradRGB,gradFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = ...
            dlfeval(@modelGradients,dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        % Accumulate the loss and accuracies.
        lossTrain = [lossTrain, loss];
        accTrain = [accTrain, acc];
        accTrainRGB = [accTrainRGB, accRGB];
        accTrainFlow = [accTrainFlow, accFlow];
        % Update the network state.
        dlnetRGB.State = stateRGB;
        dlnetFlow.State = stateFlow;
        % Update the gradients and parameters for the RGB and optical flow
        % subnetworks using the SGDM optimizer.
        [dlnetRGB,gradRGB,params.VelocityRGB,learnRate] = ...
            updateDlNetwork(dlnetRGB,gradRGB,params,params.VelocityRGB,iteration);
        [dlnetFlow,gradFlow,params.VelocityFlow] = ...
            updateDlNetwork(dlnetFlow,gradFlow,params,params.VelocityFlow,iteration);
        if ~hasdata(mbq) || iteration == params.NumIterations
            % Current epoch is complete. Do validation and update progress.
            trainTime = toc(trainTime);
            [validationTime,cmat,lossValidation,accValidation,accValidationRGB,accValidationFlow] = ...
                doValidation(params, dlnetRGB, dlnetFlow);
            % Update the training progress.
            displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
                mean(accTrain),mean(accTrainRGB),mean(accTrainFlow),...
                accValidation,accValidationRGB,accValidationFlow,...
                mean(lossTrain),lossValidation,trainTime,validationTime);
            updateProgressPlot(params,plotters,epoch,iteration,start,mean(lossTrain),mean(accTrain),accValidation);
            % Save model with the trained dlnetwork and accuracy values.
            % Use the saveData supporting function, listed at the
            % end of this example.
            if iteration >= params.SaveBestAfterIteration
                if accValidation > bestValAccuracy
                    bestValAccuracy = accValidation;
                    saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation);
                end
            end
        end
        if ~hasdata(mbq) && iteration < params.NumIterations
            % Current epoch is complete. Initialize the training loss, accuracy
            % values, and minibatchqueue for the next epoch.
            accTrain = [];
            accTrainRGB = [];
            accTrainFlow = [];
            lossTrain = [];
            trainTime = tic;
            epoch = epoch + 1;
            shuffled = shuffleTrainDs(dsTrain);
            numOutputs = 3;
            mbq = createMiniBatchQueue(shuffled, numOutputs, params);
        end
        iteration = iteration + 1;
    end
    % Display a message when training is complete.
    endVerboseOutput(params);
    disp("Model saved to: " + modelFilename);
end
% Download the pretrained model and video file for prediction.
filename = "activityRecognition-I3D-HMDB51.zip";
downloadURL = "https://ssd.mathworks.com/supportfiles/vision/data/" + filename;
filename = fullfile(downloadFolder,filename);
if ~exist(filename,'file')
    disp('Downloading the pretrained network...');
    websave(filename,downloadURL);
end
% Unzip the contents to the download folder.
unzip(filename,downloadFolder);
if ~doTraining
    modelFilename = fullfile(downloadFolder, modelFilename);
end
Use the test data set to evaluate the accuracy of the trained subnetworks.
Load the best model saved during training.
d = load(modelFilename);
dlnetRGB = d.data.dlnetRGB;
dlnetFlow = d.data.dlnetFlow;
Create a minibatchqueue object to load batches of the test data.
numOutputs = 3;
mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);
For each batch of test data, make predictions using the RGB and optical flow networks, take the average of the predictions, and compute the prediction accuracy using a confusion matrix.
cmat = sparse(numClasses,numClasses);
while hasdata(mbq)
    [dlRGB, dlFlow, dlY] = next(mbq);
    % Pass the video input as RGB and optical flow data through the
    % two-stream subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);
    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(dlY,[],1);
    [~,YPred] = max(dlYPred,[],1);
    cmat = aggregateConfusionMetric(cmat,YTest,YPred);
end
Compute the average classification accuracy for the trained networks.
accuracyEval = sum(diag(cmat))./sum(cmat,"all")
accuracyEval = 0.60909
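The accuracy is the sum of the diagonal of the confusion matrix divided by the total number of test observations. The following sketch performs the same bookkeeping as aggregateConfusionMetric on hypothetical class indices, using accumarray in place of sparse; both count one entry per pair of true and predicted class.

```matlab
numClasses = 3;
% Hypothetical true and predicted class indices for five test clips.
YTest = [1 2 3 3 2];
YPred = [1 2 1 3 3];

% Row = true class, column = predicted class, value = number of clips.
cmat = accumarray([YTest(:), YPred(:)], 1, [numClasses, numClasses]);

% Correct predictions lie on the diagonal.
accuracy = sum(diag(cmat)) ./ sum(cmat, "all");
disp(cmat)
fprintf("accuracy = %.2f\n", accuracy);
```

Three of the five clips land on the diagonal, so the accuracy is 0.6. The off-diagonal entries show which classes are confused with each other, which is what the confusion chart visualizes.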
Display the confusion matrix.
figure
chart = confusionchart(cmat,classes);

Due to the limited number of training samples, increasing the accuracy beyond 61% is challenging. To improve the robustness of the network, additional training with a large data set is required. In addition, pretraining on a larger data set, such as Kinetics [1], can help improve results.
You can now use the trained networks to predict actions in new videos. Read and display the video pour.avi using VideoReader and vision.VideoPlayer.
videoFilename = fullfile(downloadFolder, "pour.avi");
videoReader = VideoReader(videoFilename);
videoPlayer = vision.VideoPlayer;
videoPlayer.Name = "pour";
while hasFrame(videoReader)
    frame = readFrame(videoReader);
    step(videoPlayer,frame);
end
release(videoPlayer);

Use the readRGBAndFlow supporting function, listed at the end of this example, to read the RGB and optical flow data.
isDataForValidation = true;
readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
The read function returns a logical isDone value that indicates whether all the data in the file has been read. Use the batchRGBAndFlow supporting function, listed at the end of this example, to batch the data, and pass the batches through the two-stream subnetworks to obtain the predictions.
% Use hasMoreData rather than hasdata so that the hasdata function is not shadowed.
hasMoreData = true;
userdata = [];
YPred = [];
while hasMoreData
    [data,userdata,isDone] = readFcn(videoFilename,userdata);
    [dlRGB, dlFlow] = batchRGBAndFlow(data(:,1),data(:,2),data(:,3));
    % Pass video input as RGB and optical flow data through the two-stream
    % subnetworks to get the separate predictions.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);
    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    [~,YPredCurr] = max(dlYPred,[],1);
    YPred = horzcat(YPred,YPredCurr);
    hasMoreData = ~isDone;
end
YPred = extractdata(YPred);
Count the number of predictions for each class using histcounts, and select the class with the most predictions as the predicted action for the video.
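classes = params.Classes;
% Use one bin per class index. With edges 1:numel(classes), histcounts
% returns one bin fewer than the number of classes and merges the last two.
counts = histcounts(YPred,1:numel(classes)+1);
[~,clsIdx] = max(counts);
action = classes(clsIdx)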
action = "pour"
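The vote over per-chunk predictions can be checked on a small example. In the following sketch the predictions are hypothetical class indices. It also shows why the histogram edges need one bin per class: histcounts includes both edges only in its last bin, so with edges 1:numel(classes) it returns numel(classes)-1 bins, and the last bin counts both class 4 and class 5.

```matlab
classes = ["kiss","laugh","pick","pour","pushup"];
% Hypothetical predicted class indices for the chunks of one video.
YPred = [5 5 4 5 1 4 5];

% One bin per class index: edges 1,2,...,6 give bins [1,2), ..., [5,6].
counts = histcounts(YPred, 1:numel(classes)+1);
[~, clsIdx] = max(counts);
action = classes(clsIdx)

% With edges 1:5 there are only four bins, and the last one merges classes 4 and 5.
countsMerged = histcounts(YPred, 1:numel(classes));
```

Here the correct vote returns "pushup", while the merged histogram would report "pour". mode(YPred) is a shorter alternative that, like max, resolves ties in favor of the smallest class index.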
inputStatistics
The inputStatistics function takes as input the name of the folder containing the HMDB51 data and calculates the minimum and maximum values for the RGB data and the optical flow data. The minimum and maximum values are used as normalization inputs to the input layer of the networks. This function also obtains the number of frames in each of the video files, which is used later when training and testing the network. To find the minimum and maximum values for a different data set, call this function with the name of the folder containing that data set.
function inputStats = inputStatistics(dataFolder)
    ds = createDatastore(dataFolder);
    ds.ReadFcn = @getMinMax;
    tic;
    tt = tall(ds);
    varnames = {'rgbMax','rgbMin','oflowMax','oflowMin'};
    stats = gather(groupsummary(tt,[],{'max','min'}, varnames));
    inputStats.Filename = gather(tt.Filename);
    inputStats.NumFrames = gather(tt.NumFrames);
    inputStats.rgbMax = stats.max_rgbMax;
    inputStats.rgbMin = stats.min_rgbMin;
    inputStats.oflowMax = stats.max_oflowMax;
    inputStats.oflowMin = stats.min_oflowMin;
    save('inputStatistics.mat','inputStats');
    toc;
end
function data = getMinMax(filename)
    reader = VideoReader(filename);
    opticFlow = opticalFlowFarneback;
    data = [];
    while hasFrame(reader)
        frame = readFrame(reader);
        [rgb,oflow] = findMinMax(frame,opticFlow);
        data = assignMinMax(data, rgb, oflow);
    end
    totalFrames = floor(reader.Duration * reader.FrameRate);
    totalFrames = min(totalFrames, reader.NumFrames);
    [labelName, filename] = getLabelFilename(filename);
    data.Filename = fullfile(labelName, filename);
    data.NumFrames = totalFrames;
    data = struct2table(data,'AsArray',true);
end
function data = assignMinMax(data, rgb, oflow)
    if isempty(data)
        data.rgbMax = rgb.Max;
        data.rgbMin = rgb.Min;
        data.oflowMax = oflow.Max;
        data.oflowMin = oflow.Min;
        return;
    end
    data.rgbMax = max(data.rgbMax, rgb.Max);
    data.rgbMin = min(data.rgbMin, rgb.Min);
    data.oflowMax = max(data.oflowMax, oflow.Max);
    data.oflowMin = min(data.oflowMin, oflow.Min);
end
function [rgbMinMax,oflowMinMax] = findMinMax(rgb, opticFlow)
    rgbMinMax.Max = max(rgb,[],[1,2]);
    rgbMinMax.Min = min(rgb,[],[1,2]);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticFlow,gray);
    oflow = cat(3,flow.Vx,flow.Vy,flow.Magnitude);
    oflowMinMax.Max = max(oflow,[],[1,2]);
    oflowMinMax.Min = min(oflow,[],[1,2]);
end
function ds = createDatastore(folder)
    ds = fileDatastore(folder,...
        'IncludeSubfolders', true,...
        'FileExtensions', '.avi',...
        'UniformRead', true,...
        'ReadFcn', @getMinMax);
    disp("NumFiles: " + numel(ds.Files));
end
createFileDatastore
The createFileDatastore function creates a FileDatastore object using the given file names. The FileDatastore object reads the data in 'partialfile' mode, so each read can return only part of the frames of a video. This feature helps with reading large video files when all of the frames do not fit in memory.
function datastore = createFileDatastore(filenames,inputStats,isDataForValidation)
    readFcn = @(f,u)readRGBAndFlow(f,u,inputStats,isDataForValidation);
    datastore = fileDatastore(filenames,...
        'ReadFcn',readFcn,...
        'ReadMode','partialfile');
end
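In 'partialfile' mode, the datastore calls the read function repeatedly on the same file, passing back the userdata returned by the previous call, until the function reports that it is done. The following sketch simulates that contract without a video file: the "video" is a hypothetical vector of frame indices, and the loop stands in for the datastore. The chunking rule (floor(totalFrames/numFrames) full chunks, with the remaining frames dropped) mirrors the validation branch of readRGBAndFlow.

```matlab
% Hypothetical stand-in for a video: 200 frames, each represented by its index.
frames = 1:200;
numFrames = 64;                                   % frames per read, as in the text
numBatches = max(1, floor(numel(frames)/numFrames));

userdata = [];                                    % empty on the first call, as in partialfile mode
chunks = {};
done = false;
while ~done
    if isempty(userdata)
        userdata.batchesRead = 0;                 % initialize per-file state on the first call
    end
    first = userdata.batchesRead*numFrames + 1;
    chunks{end+1} = frames(first:first+numFrames-1); %#ok<AGROW>
    userdata.batchesRead = userdata.batchesRead + 1;
    done = userdata.batchesRead == numBatches;   % report that the file is finished
end
```

With 200 frames and 64 frames per read, the loop returns three chunks and never reads frames 193 to 200.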
readRGBAndFlow
The readRGBAndFlow function reads RGB frames, the corresponding optical flow data, and the label values for a given video file. During training, the read function reads the number of frames specified by the network input size, starting at a randomly chosen frame. Optical flow is computed from the beginning of the video file so that the flow estimator is initialized, but the frames before the starting frame are discarded. During testing, all the frames are read sequentially, and the corresponding optical flow data is calculated. The RGB frames and optical flow data are randomly cropped to the network input size for training, and center cropped for testing and validation.
function [data,userdata,done] = readRGBAndFlow(filename,userdata,inputStats,isDataForValidation)
    if isempty(userdata)
        userdata.reader = VideoReader(filename);
        userdata.batchesRead = 0;
        userdata.opticalFlow = opticalFlowFarneback;
        [totalFrames,userdata.label] = getTotalFramesAndLabel(inputStats,filename);
        if isempty(totalFrames)
            totalFrames = floor(userdata.reader.Duration * userdata.reader.FrameRate);
            totalFrames = min(totalFrames, userdata.reader.NumFrames);
        end
        userdata.totalFrames = totalFrames;
    end
    reader = userdata.reader;
    totalFrames = userdata.totalFrames;
    label = userdata.label;
    batchesRead = userdata.batchesRead;
    opticalFlow = userdata.opticalFlow;
    inputSize = inputStats.inputSize;
    H = inputSize(1);
    W = inputSize(2);
    rgbC = 3;
    flowC = 2;
    numFrames = inputSize(3);
    if numFrames > totalFrames
        numBatches = 1;
    else
        numBatches = floor(totalFrames/numFrames);
    end
    imH = userdata.reader.Height;
    imW = userdata.reader.Width;
    imsz = [imH,imW];
    if ~isDataForValidation
        augmentFcn = augmentTransform([imsz,3]);
        cropWindow = randomCropWindow2d(imsz, inputSize(1:2));
        % 1. Randomly select required number of frames,
        %    starting randomly at a specific frame.
        if numFrames >= totalFrames
            idx = 1:totalFrames;
            % Add more frames to fill in the network input size.
            additional = ceil(numFrames/totalFrames);
            idx = repmat(idx,1,additional);
            idx = idx(1:numFrames);
        else
            startIdx = randperm(totalFrames - numFrames);
            startIdx = startIdx(1);
            endIdx = startIdx + numFrames - 1;
            idx = startIdx:endIdx;
        end
        video = zeros(H,W,rgbC,numFrames);
        oflow = zeros(H,W,flowC,numFrames);
        i = 1;
        % Discard the first set of frames to initialize the optical flow.
        for ii = 1:idx(1)-1
            frame = read(reader,ii);
            getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
        end
        % Read the next set of required number of frames for training.
        for ii = idx
            frame = read(reader,ii);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
    else
        augmentFcn = @(data)(data);
        cropWindow = centerCropWindow2d(imsz, inputSize(1:2));
        toRead = min([numFrames,totalFrames]);
        video = zeros(H,W,rgbC,toRead);
        oflow = zeros(H,W,flowC,toRead);
        i = 1;
        while hasFrame(reader) && i <= numFrames
            frame = readFrame(reader);
            [rgb,vxvy] = getRGBAndFlow(frame,opticalFlow,augmentFcn,cropWindow);
            video(:,:,:,i) = rgb;
            oflow(:,:,:,i) = vxvy;
            i = i + 1;
        end
        if numFrames > totalFrames
            additional = ceil(numFrames/totalFrames);
            video = repmat(video,1,1,1,additional);
            oflow = repmat(oflow,1,1,1,additional);
            video = video(:,:,:,1:numFrames);
            oflow = oflow(:,:,:,1:numFrames);
        end
    end
    % The network expects the video and optical flow input in
    % the following dlarray format:
    % "SSSCB" ==> Height x Width x Frames x Channels x Batch
    %
    % Permute the data
    %   from
    %       Height x Width x Channels x Frames
    %   to
    %       Height x Width x Frames x Channels
    video = permute(video, [1,2,4,3]);
    oflow = permute(oflow, [1,2,4,3]);
    data = {video, oflow, label};
    batchesRead = batchesRead + 1;
    userdata.batchesRead = batchesRead;
    % Set the done flag to true, if the reader has read all the frames or
    % if it is training.
    done = batchesRead == numBatches || ~isDataForValidation;
end
function [rgb,vxvy] = getRGBAndFlow(rgb,opticalFlow,augmentFcn,cropWindow)
    rgb = augmentFcn(rgb);
    gray = rgb2gray(rgb);
    flow = estimateFlow(opticalFlow,gray);
    vxvy = cat(3,flow.Vx,flow.Vy,flow.Vy);
    rgb = imcrop(rgb, cropWindow);
    vxvy = imcrop(vxvy, cropWindow);
    vxvy = vxvy(:,:,1:2);
end
function [label,fname] = getLabelFilename(filename)
    [folder,name,ext] = fileparts(string(filename));
    [~,label] = fileparts(folder);
    fname = name + ext;
    label = string(label);
    fname = string(fname);
end
function [totalFrames,label] = getTotalFramesAndLabel(info, filename)
    filenames = info.Filename;
    frames = info.NumFrames;
    [labelName, fname] = getLabelFilename(filename);
    idx = strcmp(filenames, fullfile(labelName,fname));
    totalFrames = frames(idx);
    label = categorical(string(labelName), string(info.Classes));
end
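The training branch of readRGBAndFlow picks a window of numFrames consecutive frames at a random start, or loops a short video until the window is full. The following sketch isolates that index logic. It draws the start frame with randi([1, totalFrames-numFrames+1]), an assumption that also allows the last possible window, whereas readRGBAndFlow draws it from randperm(totalFrames - numFrames).

```matlab
rng(1);
numFrames = 64;

% Long video: choose a random window of consecutive frames.
totalFrames = 200;
startIdx = randi([1, totalFrames - numFrames + 1]);
idxLong = startIdx:startIdx + numFrames - 1;

% Short video: repeat all frames in order until the window is full.
totalFrames = 30;
idxShort = repmat(1:totalFrames, 1, ceil(numFrames/totalFrames));
idxShort = idxShort(1:numFrames);
```

Sampling a new random window each time a training file is read means that every epoch sees a different temporal crop of each long video, which acts as temporal data augmentation.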
augmentTransform
The augmentTransform function creates an augmentation function that applies random left-right flipping and random scaling.
function augmentFcn = augmentTransform(sz)
    % Randomly flip and scale the image.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    rout = affineOutputView(sz,tform,'BoundsStyle','CenterOutput');
    augmentFcn = @(data)augmentData(data,tform,rout);
    function data = augmentData(data,tform,rout)
        data = imwarp(data,tform,'OutputView',rout);
    end
end
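augmentTransform samples the random flip and scale once, when it is called, and readRGBAndFlow calls it once per training read, so every frame of a clip receives the same transformation. Drawing a new random transform for each frame would introduce motion into the optical flow that is not in the video. The following sketch shows the same idea for the flip alone in base MATLAB; the flip probability of 0.5 is an assumption, and scaling is omitted.

```matlab
rng(0);
% Hypothetical clip: height x width x channels x frames.
clip = rand(4, 6, 3, 10);

% Draw the random decision once per clip, not once per frame.
doFlip = rand < 0.5;
augmented = clip;
if doFlip
    augmented = flip(clip, 2);   % reverse the width dimension of every frame
end

% Verify that every frame received the same transformation.
sameForAllFrames = true;
for k = 1:size(clip, 4)
    expected = clip(:,:,:,k);
    if doFlip
        expected = flip(expected, 2);
    end
    sameForAllFrames = sameForAllFrames && isequal(augmented(:,:,:,k), expected);
end
```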
modelGradients
The modelGradients function takes as input the RGB and optical flow subnetworks, a mini-batch of RGB data dlRGB, the corresponding optical flow data dlFlow, and the corresponding target Y. It returns the gradients of the loss with respect to the learnable parameters of each subnetwork, the fused loss, the fused and per-subnetwork training accuracy, and the updated network states. To compute the gradients, evaluate the modelGradients function using the dlfeval function in the training loop.
function [gradientsRGB,gradientsFlow,loss,acc,accRGB,accFlow,stateRGB,stateFlow] = modelGradients(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)
    % Pass video input as RGB and optical flow data through the two-stream
    % network.
    [dlYPredRGB,stateRGB] = forward(dlnetRGB,dlRGB);
    [dlYPredFlow,stateFlow] = forward(dlnetFlow,dlFlow);
    % Calculate fused loss, gradients, and accuracy for the two-stream
    % predictions.
    rgbLoss = crossentropy(dlYPredRGB,Y);
    flowLoss = crossentropy(dlYPredFlow,Y);
    % Fuse the losses.
    loss = mean([rgbLoss,flowLoss]);
    gradientsRGB = dlgradient(loss,dlnetRGB.Learnables);
    gradientsFlow = dlgradient(loss,dlnetFlow.Learnables);
    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(Y,[],1);
    [~,YPred] = max(dlYPred,[],1);
    acc = gather(extractdata(sum(YTest == YPred)./numel(YTest)));
    % Calculate the accuracy of the RGB and flow predictions.
    [~,YTest] = max(Y,[],1);
    [~,YPredRGB] = max(dlYPredRGB,[],1);
    [~,YPredFlow] = max(dlYPredFlow,[],1);
    accRGB = gather(extractdata(sum(YTest == YPredRGB)./numel(YTest)));
    accFlow = gather(extractdata(sum(YTest == YPredFlow)./numel(YTest)));
end
doValidation
The doValidation function validates the network using the validation data.
function [validationTime, cmat, lossValidation, accValidation, accValidationRGB, accValidationFlow] = doValidation(params, dlnetRGB, dlnetFlow)
    validationTime = tic;
    numOutputs = 3;
    mbq = createMiniBatchQueue(params.ValidationData, numOutputs, params);
    lossValidation = [];
    numClasses = numel(params.Classes);
    cmat = sparse(numClasses,numClasses);
    cmatRGB = sparse(numClasses,numClasses);
    cmatFlow = sparse(numClasses,numClasses);
    while hasdata(mbq)
        [dlX1,dlX2,dlY] = next(mbq);
        [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlX1,dlX2,dlY);
        lossValidation = [lossValidation,loss];
        cmat = aggregateConfusionMetric(cmat,YTest,YPred);
        cmatRGB = aggregateConfusionMetric(cmatRGB,YTest,YPredRGB);
        cmatFlow = aggregateConfusionMetric(cmatFlow,YTest,YPredFlow);
    end
    lossValidation = mean(lossValidation);
    accValidation = sum(diag(cmat))./sum(cmat,"all");
    accValidationRGB = sum(diag(cmatRGB))./sum(cmatRGB,"all");
    accValidationFlow = sum(diag(cmatFlow))./sum(cmatFlow,"all");
    validationTime = toc(validationTime);
end
predictValidation
The predictValidation function calculates the loss and prediction values using the provided dlnetwork objects for RGB and optical flow data.
function [loss,YTest,YPred,YPredRGB,YPredFlow] = predictValidation(dlnetRGB,dlnetFlow,dlRGB,dlFlow,Y)
    % Pass the video input through the two-stream
    % network.
    dlYPredRGB = predict(dlnetRGB,dlRGB);
    dlYPredFlow = predict(dlnetFlow,dlFlow);
    % Calculate the cross-entropy separately for the two-stream
    % outputs.
    rgbLoss = crossentropy(dlYPredRGB,Y);
    flowLoss = crossentropy(dlYPredFlow,Y);
    % Fuse the losses.
    loss = mean([rgbLoss,flowLoss]);
    % Fuse the predictions by calculating the average of the predictions.
    dlYPred = (dlYPredRGB + dlYPredFlow)/2;
    % Calculate the accuracy of the predictions.
    [~,YTest] = max(Y,[],1);
    [~,YPred] = max(dlYPred,[],1);
    [~,YPredRGB] = max(dlYPredRGB,[],1);
    [~,YPredFlow] = max(dlYPredFlow,[],1);
end
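updateDlNetwork
The updateDlNetwork function updates the provided dlnetwork object with the given gradients using the SGDM optimization function sgdmupdate. Before the update, it computes the learning rate from the cosine-annealing schedule and adds L2 regularization to the gradients of the weights.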
function [dlnet,gradients,velocity,learnRate] = updateDlNetwork(dlnet,gradients,params,velocity,iteration)
    % Determine the learning rate using the cosine-annealing learning rate schedule.
    learnRate = cosineAnnealingLearnRate(iteration, params);
    % Apply L2 regularization to the weights.
    idx = dlnet.Learnables.Parameter == "Weights";
    gradients(idx,:) = dlupdate(@(g,w) g + params.L2Regularization*w, gradients(idx,:), dlnet.Learnables(idx,:));
    % Update the network parameters using the SGDM optimizer.
    [dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, params.Momentum);
end
cosineAnnealingLearnRate
The cosineAnnealingLearnRate function computes the learning rate based on the current iteration number, the minimum and maximum learning rates, and the number of iterations in each annealing cycle [3].
function lr = cosineAnnealingLearnRate(iteration, params)
    % Use the minimum learning rate on the final iteration.
    if iteration == params.NumIterations
        lr = params.MinLearningRate;
        return;
    end

    % Find the annealing cycle that contains the current iteration.
    cosineNumIter = [0, params.CosineNumIterations];
    csum = cumsum(cosineNumIter);
    block = find(csum >= iteration, 1,'first');

    % Position of the iteration within that cycle.
    cosineIter = iteration - csum(block - 1);
    annealingIteration = mod(cosineIter, cosineNumIter(block));
    cosineIteration = cosineNumIter(block);

    % Interpolate between the maximum and minimum learning rates along a cosine curve.
    minR = params.MinLearningRate;
    maxR = params.MaxLearningRate;
    cosMult = 1 + cos(pi * annealingIteration / cosineIteration);
    lr = minR + ((maxR - minR) * cosMult / 2);
end
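The sketch below inlines the same computation for a hypothetical schedule of two 4-iteration cycles, so that you can inspect the resulting learning rates without creating a params structure. The parameter values are assumptions, not the ones used for training in this example.

```matlab
% Hypothetical parameters (not the values used in this example).
minLR = 1e-4;
maxLR = 1e-2;
cosineNumIterations = [4 4];              % two annealing cycles
numIterations = sum(cosineNumIterations);

cosineNumIter = [0, cosineNumIterations];
csum = cumsum(cosineNumIter);             % cycle boundaries: [0 4 8]
lr = zeros(1,numIterations);
for iteration = 1:numIterations
    if iteration == numIterations
        % The last iteration always uses the minimum learning rate.
        lr(iteration) = minLR;
        continue
    end
    % Index of the cycle that contains this iteration.
    block = find(csum >= iteration,1,'first');
    % Position inside the current cycle; mod wraps it to 0 at the cycle end.
    t = mod(iteration - csum(block-1), cosineNumIter(block));
    cosMult = 1 + cos(pi*t/cosineNumIter(block));
    lr(iteration) = minR_placeholder_guard(minLR) + (maxLR - minLR)*cosMult/2;
end
disp(lr)

function v = minR_placeholder_guard(v)
% Identity helper so the schedule line mirrors "minR + ..." in the original.
end
```

With this formulation, mod wraps the position to zero on the last iteration of each cycle, so that iteration already uses the maximum learning rate (the warm restart), and the final training iteration is forced to MinLearningRate.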
aggregateConfusionMetric
The aggregateConfusionMetric function incrementally fills a confusion matrix based on the predicted results YPred and the expected results YTest.
function cmat = aggregateConfusionMetric(cmat,YTest,YPred)
    YTest = gather(extractdata(YTest));
    YPred = gather(extractdata(YPred));
    [m,n] = size(cmat);
    cmat = cmat + full(sparse(YTest,YPred,1,m,n));
end
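The following sketch accumulates a confusion matrix over two hypothetical mini-batches of class indices, using the same sparse call as aggregateConfusionMetric, and then computes accuracy the way doValidation does, as the sum of the diagonal divided by the total count. Plain numeric arrays stand in for the dlarray inputs, so extractdata and gather are not needed.

```matlab
numClasses = 3;
cmat = sparse(numClasses,numClasses);

% Two hypothetical mini-batches of true and predicted class indices.
YTestBatches = {[1 2 3], [3 3 1]};
YPredBatches = {[1 2 2], [3 1 1]};

for k = 1:numel(YTestBatches)
    % sparse places a 1 at (true,predicted) for each observation and
    % sums duplicate pairs, so one call counts the whole mini-batch.
    cmat = cmat + full(sparse(YTestBatches{k},YPredBatches{k},1,numClasses,numClasses));
end

% Correct predictions lie on the diagonal.
acc = sum(diag(cmat))/sum(cmat,"all");
disp(full(cmat))
disp(acc)
```

Rows index the true classes and columns the predicted classes; 4 of the 6 observations fall on the diagonal.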
createMiniBatchQueue
The createMiniBatchQueue function creates a minibatchqueue object that provides mini-batches of params.MiniBatchSize observations from the given datastore. If params.DispatchInBackground is true and no parallel pool is open, the function starts one. Whenever a parallel pool is open, it wraps the datastore in a DispatchInBackgroundDatastore so that the data is prepared in the background.
function mbq = createMiniBatchQueue(datastore, numOutputs, params)
    if params.DispatchInBackground && isempty(gcp('nocreate'))
        % Start a parallel pool, if DispatchInBackground is true, to dispatch
        % data in the background using the parallel pool. Pass the modified
        % cluster object so that the requested number of workers is allowed.
        c = parcluster('local');
        c.NumWorkers = params.NumWorkers;
        parpool(c,params.NumWorkers);
    end
    p = gcp('nocreate');
    if ~isempty(p)
        datastore = DispatchInBackgroundDatastore(datastore, p.NumWorkers);
    end
    inputFormat(1:numOutputs-1) = "SSSCB";
    outputFormat = "CB";
    mbq = minibatchqueue(datastore, numOutputs, ...
        "MiniBatchSize", params.MiniBatchSize, ...
        "MiniBatchFcn", @batchRGBAndFlow, ...
        "MiniBatchFormat", [inputFormat,outputFormat]);
end
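The MiniBatchFormat value is built with a short idiom: assigning a scalar string to inputFormat(1:numOutputs-1) when inputFormat does not yet exist creates a string array with one copy per input. In the dlarray format labels, S denotes a spatial dimension, C a channel dimension, and B a batch dimension; labeling the frames as a third spatial dimension ("SSS", presumably height, width, and frames) lets 3-D convolutions slide over time as well as space. The sketch below reproduces only the format construction, assuming numOutputs is 3 as in doValidation.

```matlab
numOutputs = 3;   % two inputs (RGB, flow) and one output (labels)

% Assigning a scalar string to a range of a new variable creates a
% string array holding one copy per index.
inputFormat(1:numOutputs-1) = "SSSCB";
outputFormat = "CB";
formats = [inputFormat,outputFormat];

disp(formats)
```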
batchRGBAndFlow
The batchRGBAndFlow function batches the image, flow, and label data into corresponding dlarray values in the data formats "SSSCB", "SSSCB", and "CB", respectively.
function [dlX1,dlX2,dlY] = batchRGBAndFlow(images, flows, labels)
    % Batch dimension: 5
    X1 = cat(5,images{:});
    X2 = cat(5,flows{:});

    % Batch dimension: 2
    labels = cat(2,labels{:});

    % Feature dimension: 1
    Y = onehotencode(labels,1);

    % Cast data to single for processing.
    X1 = single(X1);
    X2 = single(X2);
    Y = single(Y);

    % Move data to the GPU if possible.
    if canUseGPU
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
        Y = gpuArray(Y);
    end

    % Return X and Y as dlarray objects.
    dlX1 = dlarray(X1,"SSSCB");
    dlX2 = dlarray(X2,"SSSCB");
    dlY = dlarray(Y,"CB");
end
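The sketch below mirrors the batching step of batchRGBAndFlow on small hypothetical arrays: two clips of size 8-by-8-by-4-by-3 (assumed to be height, width, frames, and channels) are concatenated along dimension 5, and two categorical labels are one-hot encoded along dimension 1. So that the sketch runs in base MATLAB, a comparison against the category indices stands in for onehotencode(labels,1), and the GPU transfer and dlarray conversion are omitted.

```matlab
% Hypothetical mini-batch of two clips: height x width x frames x channels.
images = {rand(8,8,4,3), rand(8,8,4,3)};
classNames = ["drink","run","sit"];
labels = {categorical("run",classNames), categorical("sit",classNames)};

% Concatenate along dimension 5 to form an SSSCB batch of size 8x8x4x3x2.
X = single(cat(5,images{:}));

% Concatenate the labels along dimension 2, then one-hot encode along
% dimension 1 (classes x observations). This comparison stands in for
% onehotencode(labels,1).
labelsCat = cat(2,labels{:});
Y = single((1:numel(classNames))' == double(labelsCat));

disp(size(X))
disp(Y)
```

Each column of Y is one observation, with a 1 in the row of its class, which matches the "CB" format of dlY.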
shuffleTrainDs
The shuffleTrainDs function returns a copy of the training datastore dsTrain with its files in random order. The original datastore is left unchanged.
function shuffled = shuffleTrainDs(dsTrain)
    shuffled = copy(dsTrain);
    n = numel(shuffled.Files);
    shuffledIndices = randperm(n);
    shuffled.Files = shuffled.Files(shuffledIndices);
    reset(shuffled);
end
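As a sketch of the shuffling step on a plain string array (standing in for dsTrain.Files), the code below draws a random permutation with randperm and reindexes the list. Seeding the generator with rng is an optional addition, not part of shuffleTrainDs, that makes the order reproducible across runs. The file names are hypothetical.

```matlab
% Hypothetical file list standing in for dsTrain.Files.
files = ["clip1.avi"; "clip2.avi"; "clip3.avi"; "clip4.avi"];

rng(0)   % optional: fix the seed so the shuffled order is reproducible
shuffledIndices = randperm(numel(files));
shuffled = files(shuffledIndices);

disp(shuffled)
```

Because randperm returns each index exactly once, the shuffled list contains every file once, only in a different order.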
saveData
The saveData function saves the given dlnetwork objects, the confusion matrix, and the validation accuracy to a MAT file.
function saveData(modelFilename, dlnetRGB, dlnetFlow, cmat, accValidation)
    dlnetRGB = gatherFromGPUToSave(dlnetRGB);
    dlnetFlow = gatherFromGPUToSave(dlnetFlow);
    data.ValidationAccuracy = accValidation;
    data.cmat = cmat;
    data.dlnetRGB = dlnetRGB;
    data.dlnetFlow = dlnetFlow;
    save(modelFilename, 'data');
end
gatherFromGPUToSave
The gatherFromGPUToSave function gathers data from the GPU in order to save the model to disk.
function dlnet = gatherFromGPUToSave(dlnet)
    if ~canUseGPU
        return;
    end
    dlnet.Learnables = gatherValues(dlnet.Learnables);
    dlnet.State = gatherValues(dlnet.State);

    function tbl = gatherValues(tbl)
        for ii = 1:height(tbl)
            tbl.Value{ii} = gather(tbl.Value{ii});
        end
    end
end
checkForHMDB51Folder
The checkForHMDB51Folder function checks that the extracted HMDB51 class folders exist in the download folder and returns the names of all 51 classes.
function classes = checkForHMDB51Folder(dataLoc)
    hmdbFolder = fullfile(dataLoc, "hmdb51_org");
    if ~exist(hmdbFolder, "dir")
        error("Download 'hmdb51_org.rar' file using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
    end
    classes = ["brush_hair","cartwheel","catch","chew","clap","climb","climb_stairs",...
        "dive","draw_sword","dribble","drink","eat","fall_floor","fencing",...
        "flic_flac","golf","handstand","hit","hug","jump","kick","kick_ball",...
        "kiss","laugh","pick","pour","pullup","punch","push","pushup","ride_bike",...
        "ride_horse","run","shake_hands","shoot_ball","shoot_bow","shoot_gun",...
        "sit","situp","smile","smoke","somersault","stand","swing_baseball","sword",...
        "sword_exercise","talk","throw","turn","walk","wave"];
    expectFolders = fullfile(hmdbFolder, classes);
    if ~all(arrayfun(@(x)exist(x,'dir'),expectFolders))
        error("Download hmdb51_org.rar using the supporting function 'downloadHMDB51' before running the example and extract the RAR file.");
    end
end
downloadHMDB51
The downloadHMDB51 function downloads the HMDB51 RAR file to the specified folder (the current folder by default) if the file is not already there. It does not extract the RAR file.
function downloadHMDB51(dataLoc)
    if nargin == 0
        dataLoc = pwd;
    end
    dataLoc = string(dataLoc);
    if ~exist(dataLoc,"dir")
        mkdir(dataLoc);
    end
    dataUrl = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar";
    options = weboptions('Timeout', Inf);
    rarFileName = fullfile(dataLoc, 'hmdb51_org.rar');
    fileExists = exist(rarFileName, 'file');

    % Download the RAR file and save it to the download folder.
    if ~fileExists
        disp("Downloading hmdb51_org.rar (2 GB) to the folder:")
        disp(dataLoc)
        disp("This download can take a few minutes...")
        websave(rarFileName, dataUrl, options);
        disp("Download complete.")
        disp("Extract the hmdb51_org.rar file contents to the folder: ")
        disp(dataLoc)
    end
end
initializeTrainingProgressPlot
The initializeTrainingProgressPlot function configures two plots for displaying the training loss, training accuracy, and validation accuracy.
function plotters = initializeTrainingProgressPlot(params)
    if params.ProgressPlot
        % Plot the loss, training accuracy, and validation accuracy.
        figure

        % Loss plot
        subplot(2,1,1)
        plotters.LossPlotter = animatedline;
        xlabel("Iteration")
        ylabel("Loss")

        % Accuracy plot
        subplot(2,1,2)
        plotters.TrainAccPlotter = animatedline('Color','b');
        plotters.ValAccPlotter = animatedline('Color','g');
        legend('Training Accuracy','Validation Accuracy','Location','northwest');
        xlabel("Iteration")
        ylabel("Accuracy")
    else
        plotters = [];
    end
end
initializeVerboseOutput
The initializeVerboseOutput function displays the training configuration (hardware, number of iterations, mini-batch size, and classes) followed by the column headings of the table of training values, such as the epoch and mini-batch accuracy.
function initializeVerboseOutput(params)
    if params.Verbose
        disp(" ")
        if canUseGPU
            disp("Training on GPU.")
        else
            disp("Training on CPU.")
        end
        p = gcp('nocreate');
        if ~isempty(p)
            disp("Training on parallel cluster '" + p.Cluster.Profile + "'. ")
        end
        disp("NumIterations:" + string(params.NumIterations));
        disp("MiniBatchSize:" + string(params.MiniBatchSize));
        disp("Classes:" + join(string(params.Classes), ","));
        disp("|=======================================================================================================================================================================|")
        disp("| Epoch | Iteration | Time Elapsed | Mini-Batch Accuracy | Validation Accuracy | Mini-Batch | Validation | Base Learning | Train Time | Validation Time |")
        disp("| | | (hh:mm:ss) | (Avg:RGB:Flow) | (Avg:RGB:Flow) | Loss | Loss | Rate | (hh:mm:ss) | (hh:mm:ss) |")
        disp("|=======================================================================================================================================================================|")
    end
end
displayVerboseOutputEveryEpoch
The displayVerboseOutputEveryEpoch function displays the verbose output of the training values, such as the epoch, mini-batch accuracy, validation accuracy, and mini-batch loss.
function displayVerboseOutputEveryEpoch(params,start,learnRate,epoch,iteration,...
        accTrain,accTrainRGB,accTrainFlow,accValidation,accValidationRGB,accValidationFlow,lossTrain,lossValidation,trainTime,validationTime)
    if params.Verbose
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        trainTime = duration(0,0,trainTime,'Format','hh:mm:ss');
        validationTime = duration(0,0,validationTime,'Format','hh:mm:ss');

        lossValidation = gather(extractdata(lossValidation));
        lossValidation = compose('%.4f',lossValidation);

        accValidation = composePadAccuracy(accValidation);
        accValidationRGB = composePadAccuracy(accValidationRGB);
        accValidationFlow = composePadAccuracy(accValidationFlow);
        accVal = join([accValidation,accValidationRGB,accValidationFlow], " : ");

        lossTrain = gather(extractdata(lossTrain));
        lossTrain = compose('%.4f',lossTrain);

        accTrain = composePadAccuracy(accTrain);
        accTrainRGB = composePadAccuracy(accTrainRGB);
        accTrainFlow = composePadAccuracy(accTrainFlow);
        accTrain = join([accTrain,accTrainRGB,accTrainFlow], " : ");

        learnRate = compose('%.13f',learnRate);

        disp("| " + ...
            pad(string(epoch),5,'both') + " | " + ...
            pad(string(iteration),9,'both') + " | " + ...
            pad(string(D),12,'both') + " | " + ...
            pad(string(accTrain),26,'both') + " | " + ...
            pad(string(accVal),26,'both') + " | " + ...
            pad(string(lossTrain),10,'both') + " | " + ...
            pad(string(lossValidation),10,'both') + " | " + ...
            pad(string(learnRate),13,'both') + " | " + ...
            pad(string(trainTime),10,'both') + " | " + ...
            pad(string(validationTime),15,'both') + " |")
    end
end

function acc = composePadAccuracy(acc)
    acc = compose('%.2f',acc*100) + "%";
    acc = pad(string(acc),6,'left');
end
endVerboseOutputThe endVerboseOutput function displays the end of verbose output during training.
function endVerboseOutput(params) if params.Verbose disp("|=======================================================================================================================================================================|") end end
updateProgressPlot
The updateProgressPlot function updates the progress plot with loss and accuracy information during training.
function updateProgressPlot(params,plotters,epoch,iteration,start,lossTrain,accuracyTrain,accuracyValidation)
    if params.ProgressPlot
        % Update the training progress.
        D = duration(0,0,toc(start),"Format","hh:mm:ss");
        title(plotters.LossPlotter.Parent,"Epoch: " + epoch + ", Elapsed: " + string(D));
        addpoints(plotters.LossPlotter,iteration,double(gather(extractdata(lossTrain))));
        addpoints(plotters.TrainAccPlotter,iteration,accuracyTrain);
        addpoints(plotters.ValAccPlotter,iteration,accuracyValidation);
        drawnow
    end
end
[1] Carreira, Joao, and Andrew Zisserman. "Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR): 6299–6308. Honolulu, HI: IEEE, 2017.
[2] Simonyan, Karen, and Andrew Zisserman. "Two-Stream Convolutional Networks for Action Recognition in Videos." Advances in Neural Information Processing Systems 27. Montreal, Canada: NIPS, 2014.
[3] Loshchilov, Ilya, and Frank Hutter. "SGDR: Stochastic Gradient Descent with Warm Restarts." International Conference on Learning Representations 2017. Toulon, France: ICLR, 2017.