How to plot the loss function on the overall dataset in Training Progress

I am writing a Convolutional Neural Network for regression in MATLAB R2021b. I'm using the trainNetwork function in Deep Learning Toolbox and in options I have 'Plots','training-progress'.
As I understand it, the Training Progress plot shows, for each iteration, the MSE value computed on the current mini-batch.
I would like to know whether the training progress plot can instead show the MSE computed over the entire Training Set and the entire Validation Set, rather than on the mini-batch.
Thanks in advance.

Answers (1)

Aneela on 13 Sep 2024
Edited: Aneela on 22 Sep 2024
Hi Maria,
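The trainNetwork function's default training progress plot displays the mini-batch loss and, for regression networks, the mini-batch RMSE (accuracy is shown for classification networks) at each iteration.
  • It does not provide an option to display training metrics computed over the entire training set during training. Validation metrics appear only if 'ValidationData' is specified in the training options, in which case they are evaluated on the validation data every 'ValidationFrequency' iterations.
  • To plot the loss over the whole training set, replace trainNetwork with a custom training loop.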
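Before moving to a custom loop, it may be worth checking what the built-in options already report for the validation set. The following is a minimal sketch, not a definitive setup: the synthetic data, the tiny network, and all hyperparameter values are assumptions chosen only so the script runs on its own.

```matlab
% Synthetic stand-in data (assumption): 8x8 grayscale images, one scalar response each.
rng(0);
XTrain      = rand(8,8,1,64,'single');
YTrain      = reshape(sum(XTrain,[1 2 3]),[],1);        % 64x1 responses
XValidation = rand(8,8,1,16,'single');
YValidation = reshape(sum(XValidation,[1 2 3]),[],1);   % 16x1 responses

% Small illustrative regression CNN (assumption, not the asker's network).
layers = [
    imageInputLayer([8 8 1],'Normalization','none')
    convolution2dLayer(3,4,'Padding','same')
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

% With 'ValidationData' set, the training-progress plot adds validation
% curves, evaluated every 'ValidationFrequency' iterations.
options = trainingOptions('adam', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',16, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',4, ...
    'Plots','training-progress', ...
    'Verbose',false);

[net,info] = trainNetwork(XTrain,YTrain,layers,options);

% info.TrainingLoss holds the mini-batch loss for every iteration;
% info.ValidationLoss holds the validation loss at the validation iterations.
disp(numel(info.TrainingLoss))
```

The second output of trainNetwork lets you re-plot the recorded curves after training. The training curve is still per mini-batch, so this only helps with the validation half of the question.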
Here’s a possible workaround:
  • Set hyperparameters such as the learning rate, number of epochs, and mini-batch size.
  • Iterate over the epochs.
  • Within each epoch, iterate over the mini-batches and compute predictions for each mini-batch.
  • Compute the gradients of the loss with respect to the model parameters.
  • Update the model parameters with an optimization algorithm (for example SGDM or Adam) to minimize the loss.
  • Compute the MSE over the full training and validation datasets after each mini-batch update. Note that this requires a full pass over both datasets every time, so for large datasets it may be more practical to do it every few iterations or once per epoch. Here’s a sample code snippet:
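% net: the network; (XTrain, YTrain): training data;
% (XValidation, YValidation): validation data.
% If net is a dlnetwork, pass dlarray inputs to predict and call
% extractdata on the predictions before the arithmetic below.
YPredTrain = predict(net, XTrain);
% 'all' averages over every element, so the result is a scalar even
% when the responses are a matrix.
trainMSE = mean((YPredTrain - YTrain).^2, 'all');
YPredValidation = predict(net, XValidation);
validationMSE = mean((YPredValidation - YValidation).^2, 'all');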
  • Plot the MSE for both the training and validation datasets throughout the training process using “addpoints” and “drawnow”.
% Create the two lines once, before the training loop starts.
trainingPlot = animatedline('Color','r');
validationPlot = animatedline('Color','b');
% Inside the loop, after computing trainMSE and validationMSE:
addpoints(trainingPlot, iteration, double(trainMSE));
addpoints(validationPlot, iteration, double(validationMSE));
drawnow;
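The steps above can be put together roughly as follows. This is a sketch under stated assumptions, not a drop-in replacement for your code: the synthetic data, the small network, the Adam settings, and the helper names modelGradients and fullMSE are all hypothetical. It also evaluates the full-dataset MSE once per epoch rather than after every mini-batch, purely to keep the cost down; move the evaluation into the inner loop if you need per-iteration points.

```matlab
% Synthetic stand-in data (assumption): 8x8 grayscale images, scalar targets.
rng(0);
XTrain      = rand(8,8,1,64,'single');
YTrain      = reshape(sum(XTrain,[1 2 3]),1,[]);        % 1x64 targets
XValidation = rand(8,8,1,16,'single');
YValidation = reshape(sum(XValidation,[1 2 3]),1,[]);   % 1x16 targets

% A dlnetwork has no output layer; the loss is computed in modelGradients.
layers = [
    imageInputLayer([8 8 1],'Normalization','none','Name','in')
    convolution2dLayer(3,4,'Padding','same','Name','conv')
    reluLayer('Name','relu')
    fullyConnectedLayer(1,'Name','fc')];
net = dlnetwork(layerGraph(layers));

% Hyperparameters (illustrative values).
numEpochs = 5; miniBatchSize = 16; learnRate = 1e-2;
avgGrad = []; avgSqGrad = []; iteration = 0;
numObs = size(XTrain,4);

figure
trainLine = animatedline('Color','r');
valLine   = animatedline('Color','b');
xlabel('Epoch'); ylabel('MSE'); legend('Training set','Validation set');

for epoch = 1:numEpochs
    idx = randperm(numObs);                       % shuffle every epoch
    for i = 1:miniBatchSize:numObs
        iteration = iteration + 1;
        batch = idx(i:min(i+miniBatchSize-1,numObs));
        X = dlarray(XTrain(:,:,:,batch),'SSCB');
        T = dlarray(YTrain(:,batch),'CB');
        [~,gradients] = dlfeval(@modelGradients,net,X,T);
        [net,avgGrad,avgSqGrad] = adamupdate(net,gradients, ...
            avgGrad,avgSqGrad,iteration,learnRate);
    end
    % MSE over the complete training and validation sets.
    trainMSE = fullMSE(net,XTrain,YTrain);
    valMSE   = fullMSE(net,XValidation,YValidation);
    addpoints(trainLine,epoch,double(trainMSE));
    addpoints(valLine,epoch,double(valMSE));
    drawnow
end

function [loss,gradients] = modelGradients(net,X,T)
    % Mini-batch loss and its gradients with respect to the learnables.
    % For dlarray inputs, mse returns the half mean squared error.
    Y = forward(net,X);
    loss = mse(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end

function m = fullMSE(net,X,T)
    % Plain mean squared error over a whole dataset in a single pass.
    Y = predict(net,dlarray(X,'SSCB'));
    m = mean((extractdata(Y) - T).^2,'all');
end
```

fullMSE pushes the whole dataset through the network at once, which is fine for small data. For a large dataset you would likely need to loop over chunks and accumulate the squared errors before dividing by the number of elements.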
Refer to the following MathWorks documentation for more information on:
