Clear accelerated deep learning function trace cache
Load the dlnetwork object and class names from the MAT file dlnetDigits.mat.

s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;
Accelerate the model gradients function modelGradients listed at the end of the example.

fun = @modelGradients;
accfun = dlaccelerate(fun);
Clear any previously cached traces of the accelerated function using the clearCache function.

clearCache(accfun)
View the properties of the accelerated function. Because the cache is empty, the
Occupancy property is 0.

accfun

accfun = 
  AcceleratedFunction with properties:

          Function: @modelGradients
           Enabled: 1
         CacheSize: 50
           HitRate: 0
         Occupancy: 0
         CheckMode: 'none'
    CheckTolerance: 1.0000e-04
The returned AcceleratedFunction object stores the traces of underlying function calls and reuses the cached result when the same input pattern reoccurs. To use the accelerated function in a custom training loop, replace calls to the model gradients function with calls to the accelerated function. You can invoke the accelerated function as you would invoke the underlying function. Note that the accelerated function is not a function handle.
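To illustrate trace reuse, the following sketch (not part of the original walkthrough; it assumes dlnet, dlX, and dlT exist as created below) calls the accelerated function twice with the same input pattern. The first call traces the underlying function; the second reuses the cached trace, which is reflected in the HitRate property.

```matlab
% First evaluation traces the underlying function and fills the cache.
[gradients,state,loss] = dlfeval(accfun,dlnet,dlX,dlT);

% A second call with the same input pattern reuses the cached trace,
% so the HitRate property becomes nonzero.
[gradients,state,loss] = dlfeval(accfun,dlnet,dlX,dlT);
accfun.HitRate
```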
Evaluate the accelerated model gradients function with random data using the dlfeval function.

X = rand(28,28,1,128,'single');
dlX = dlarray(X,'SSCB');
T = categorical(classNames(randi(10,[128 1])));
T = onehotencode(T,2)';
dlT = dlarray(T,'CB');
[gradients,state,loss] = dlfeval(accfun,dlnet,dlX,dlT);
View the Occupancy property of the accelerated function. Because the function has been evaluated, the cache is nonempty.

accfun.Occupancy

ans = 2
Clear the cache using the clearCache function.

clearCache(accfun)
View the Occupancy property of the accelerated function. Because the cache has been cleared, the cache is empty.

accfun.Occupancy

ans = 0
Model Gradients Function
The modelGradients function takes a dlnetwork object dlnet, a mini-batch of input data dlX with corresponding target labels dlT, and returns the gradients of the loss with respect to the learnable parameters in dlnet, the network state, and the loss. To compute the gradients, use the dlgradient function.

function [gradients,state,loss] = modelGradients(dlnet,dlX,dlT)

% Forward pass through the network, returning predictions and state.
[dlYPred,state] = forward(dlnet,dlX);

% Cross-entropy loss against the target labels.
loss = crossentropy(dlYPred,dlT);

% Gradients of the loss with respect to the learnable parameters.
gradients = dlgradient(loss,dlnet.Learnables);

end