Specify Training Options in Custom Training Loop
For most tasks, you can control the training algorithm details using the trainingOptions and trainNetwork functions. If the trainingOptions function does not provide the options you need for your task (for example, a custom learning rate schedule), then you can define your own custom training loop using a dlnetwork object. A dlnetwork object allows you to train a network specified as a layer graph using automatic differentiation.
To specify the same options as the trainingOptions function, use these examples as a guide:
Training Option | trainingOptions Argument | Example
---|---|---
Adam solver | 'adam' (solverName argument) | Adaptive Moment Estimation (ADAM)
RMSProp solver | 'rmsprop' (solverName argument) | Root Mean Square Propagation (RMSProp)
SGDM solver | 'sgdm' (solverName argument) | Stochastic Gradient Descent with Momentum (SGDM)
Learn rate | 'InitialLearnRate' | Learn Rate
Learn rate schedule | 'LearnRateSchedule', 'LearnRateDropPeriod', 'LearnRateDropFactor' | Piecewise Learn Rate Schedule
Training progress | 'Plots' | Plots
Verbose output | 'Verbose', 'VerboseFrequency' | Verbose Output
Mini-batch size | 'MiniBatchSize' | Mini-Batch Size
Number of epochs | 'MaxEpochs' | Number of Epochs
Validation | 'ValidationData', 'ValidationFrequency' | Validation
L2 regularization | 'L2Regularization' | L2 Regularization
Gradient clipping | 'GradientThreshold', 'GradientThresholdMethod' | Gradient Clipping
Single CPU or GPU training | 'ExecutionEnvironment' | Single CPU or GPU Training
Checkpoints | 'CheckpointPath' | Checkpoints
Solver Options
To specify the solver, use the adamupdate, rmspropupdate, and sgdmupdate functions for the update step in your training loop. To implement your own custom solver, update the learnable parameters using the dlupdate function.
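For example, the following is a minimal sketch of a custom solver (plain gradient descent without momentum) implemented with dlupdate. It assumes that net is a dlnetwork object and that gradients is the table of gradients that dlgradient returns for net.Learnables inside a dlfeval call:
% Sketch of a custom solver using dlupdate. The function handle receives
% each learnable parameter and its corresponding gradient.
learnRate = 0.01;
sgdFcn = @(w,g) w - learnRate*g;
net = dlupdate(sgdFcn,net,gradients);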
Adaptive Moment Estimation (ADAM)
To update your network parameters using Adam, use the adamupdate function. Specify the gradient decay and the squared gradient decay factors using the corresponding input arguments.
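For example, this sketch shows the Adam update step. It assumes that averageGrad and averageSqGrad are initialized to [] before the training loop and that iteration counts the update steps; the decay factor values are illustrative:
gradDecay = 0.9;     % Gradient decay factor
sqGradDecay = 0.999; % Squared gradient decay factor

[net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
    averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);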
Root Mean Square Propagation (RMSProp)
To update your network parameters using RMSProp, use the rmspropupdate function. Specify the denominator offset (epsilon) value using the corresponding input argument.
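For example, this sketch shows the RMSProp update step. It assumes that averageSqGrad is initialized to [] before the training loop; the squared gradient decay factor and epsilon values are illustrative:
sqGradDecay = 0.99; % Squared gradient decay factor
epsilon = 1e-8;     % Denominator offset

[net,averageSqGrad] = rmspropupdate(net,gradients,averageSqGrad, ...
    learnRate,sqGradDecay,epsilon);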
Stochastic Gradient Descent with Momentum (SGDM)
To update your network parameters using SGDM, use the sgdmupdate function. Specify the momentum using the corresponding input argument.
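For example, this sketch shows the SGDM update step. It assumes that velocity is initialized to [] before the training loop; the momentum value is illustrative:
momentum = 0.9;

[net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);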
Learn Rate
To specify the learn rate, use the learn rate input arguments of the adamupdate, rmspropupdate, and sgdmupdate functions.
To easily adjust the learn rate or use it for custom learn rate schedules, set the initial learn rate before the custom training loop.
learnRate = 0.01;
Piecewise Learn Rate Schedule
To automatically drop the learn rate during training using a piecewise learn rate schedule, multiply the learn rate by a given drop factor after a specified interval.
To easily specify a piecewise learn rate schedule, create the variables learnRate, learnRateSchedule, learnRateDropFactor, and learnRateDropPeriod, where learnRate is the initial learn rate, learnRateSchedule contains either "piecewise" or "none", learnRateDropFactor is a scalar in the range [0, 1] that specifies the factor for dropping the learn rate, and learnRateDropPeriod is a positive integer that specifies how many epochs elapse between drops of the learn rate.
learnRate = 0.01;
learnRateSchedule = "piecewise";
learnRateDropPeriod = 10;
learnRateDropFactor = 0.1;
Inside the training loop, at the end of each epoch, drop the learn rate when the learnRateSchedule option is "piecewise" and the current epoch number is a multiple of learnRateDropPeriod. Set the new learn rate to the product of the learn rate and the learn rate drop factor.
if learnRateSchedule == "piecewise" && mod(epoch,learnRateDropPeriod) == 0
    learnRate = learnRate * learnRateDropFactor;
end
Plots
To plot the training loss and accuracy during training, calculate the mini-batch loss and either the accuracy or the root-mean-squared error (RMSE) in the model loss function and plot them using a TrainingProgressMonitor object.
To easily specify whether the plot is on or off, set the Visible property of the TrainingProgressMonitor object. By default, Visible is set to true. When Visible is set to false, the software logs the training metrics and information but does not display the Training Progress window. You can display the Training Progress window after training by changing the Visible property. To also plot validation metrics, use the same options validationData and validationFrequency described in Validation.
validationData = {XValidation,TValidation};
validationFrequency = 50;
Before training, initialize a TrainingProgressMonitor object. The monitor automatically tracks the elapsed time since the construction of the object. To use this elapsed time as a proxy for training time, make sure you create the TrainingProgressMonitor object close to the start of the training loop.
For classification tasks, create a plot to track the loss and accuracy for the training and validation data. Also track the epoch number and the training progress percentage.
monitor = trainingProgressMonitor;
monitor.Metrics = ["TrainingAccuracy","ValidationAccuracy","TrainingLoss","ValidationLoss"];
groupSubPlot(monitor,"Accuracy",["TrainingAccuracy","ValidationAccuracy"]);
groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);
monitor.Info = "Epoch";
monitor.XLabel = "Iteration";
monitor.Progress = 0;
For regression tasks, adjust the code by changing the variable names and labels so that it initializes plots for the training and validation RMSE instead of the training and validation accuracy.
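For example, this sketch initializes the monitor for a regression task (the metric names are illustrative; any names work as long as you use them consistently with recordMetrics):
monitor = trainingProgressMonitor;
monitor.Metrics = ["TrainingRMSE","ValidationRMSE","TrainingLoss","ValidationLoss"];
groupSubPlot(monitor,"RMSE",["TrainingRMSE","ValidationRMSE"]);
groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"]);
monitor.Info = "Epoch";
monitor.XLabel = "Iteration";
monitor.Progress = 0;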
Inside the training loop, at the end of an iteration, use the recordMetrics and updateInfo functions to include the appropriate metrics and information for the training loop. For classification tasks, add points corresponding to the mini-batch accuracy and the mini-batch loss. If the current iteration is either 1 or a multiple of the validation frequency option, then also add points for the validation data.
recordMetrics(monitor,iteration, ...
    TrainingLoss=lossTrain, ...
    TrainingAccuracy=accuracyTrain);
updateInfo(monitor,Epoch=string(epoch) + " of " + string(numEpochs));

if iteration == 1 || mod(iteration,validationFrequency) == 0
    recordMetrics(monitor,iteration, ...
        ValidationLoss=lossValidation, ...
        ValidationAccuracy=accuracyValidation);
end

monitor.Progress = 100*iteration/numIterations;
accuracyTrain and lossTrain correspond to the mini-batch accuracy and loss calculated in the model loss function. For regression tasks, use the mini-batch RMSE values instead of the mini-batch accuracies.
You can stop training using the Stop button in the Training Progress window. When you click Stop, the Stop property of the monitor changes to 1 (true). Training stops if your training loop exits when the Stop property is 1.
while epoch < maxEpochs && ~monitor.Stop
    % Custom training loop code.
end
For more information about plotting and recording metrics during training, see Monitor Custom Training Loop Progress During Training.
To learn how to compute validation metrics, see Validation.
Verbose Output
To display the training loss and accuracy during training in a verbose table, calculate the mini-batch loss and either the accuracy (for classification tasks) or the RMSE (for regression tasks) in the model loss function and display them using the disp function.
To easily specify whether the verbose table is on or off, create the variables verbose and verboseFrequency, where verbose is true or false and verboseFrequency specifies how many iterations elapse between printing verbose output. To display validation metrics, use the same options validationData and validationFrequency described in Validation.
verbose = true;
verboseFrequency = 50;

validationData = {XValidation,TValidation};
validationFrequency = 50;
Before training, display the verbose output table headings and initialize a timer using the tic function.
disp("|======================================================================================================================|") disp("| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |") disp("| | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate |") disp("|======================================================================================================================|") start = tic;
For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.
Inside the training loop, at the end of an iteration, print the verbose output when the verbose option is true and it is either the first iteration or the iteration number is a multiple of verboseFrequency.
ifverbose && (iteration == 1 || mod(iteration,verboseFrequency) == 0 D = duration(0,0,toc(start),'Format','hh:mm:ss');ifisempty(validationData) || mod(iteration,validationFrequency) ~= 0 accuracyValidation =""; lossValidation ="";enddisp("| "+...pad(epoch,7,'left') +" | "+...pad(iteration,11,'left') +" | "+...pad(D,14,'left') +" | "+...pad(accuracyTrain,12,'left') +" | "+...pad(accuracyValidation,12,'left') +" | "+...pad(lossTrain,12,'left') +" | "+...pad(lossValidation,12,'left') +" | "+...pad(learnRate,15,'left') +" |")end
For regression tasks, adjust the code so that it displays the training and validation RMSE instead of the training and validation accuracy.
When training is finished, print the last border of the verbose table.
disp("|======================================================================================================================|")
To learn how to compute validation metrics, see Validation.
Mini-Batch Size
How you set the mini-batch size depends on the format of your data and the type of datastore that you use.
To easily specify the mini-batch size, create a variable miniBatchSize.
miniBatchSize = 128;
For data in an image datastore, before training, set the ReadSize property of the datastore to the mini-batch size.
imds.ReadSize = miniBatchSize;
For data in an augmented image datastore, before training, set the MiniBatchSize property of the datastore to the mini-batch size.
augimds.MiniBatchSize = miniBatchSize;
For in-memory data, during training at the start of each iteration, read the observations directly from the array.
idx = ((iteration - 1)*miniBatchSize + 1):(iteration*miniBatchSize);
X = XTrain(:,:,:,idx);
Number of Epochs
Specify the maximum number of epochs for training in the outer loop of the training loop.
To easily specify the maximum number of epochs, create the variable maxEpochs that contains the maximum number of epochs.
maxEpochs = 30;
In the outer loop of the training loop, specify to loop over the range 1, 2, …, maxEpochs.
epoch = 0;
while epoch < maxEpochs
    epoch = epoch + 1;

    ...
end
Validation
To validate your network during training, set aside a held-out validation set and evaluate how well the network performs on that data.
To easily specify validation options, create the variables validationData and validationFrequency, where validationData contains the validation data or is empty, and validationFrequency specifies how many iterations elapse between validating the network.
validationData = {XValidation,TValidation};
validationFrequency = 50;
During the training loop, after updating the network parameters, test how well the network performs on the held-out validation set using the predict function. Validate the network only when validation data is specified and it is either the first iteration or the current iteration is a multiple of the validationFrequency option.
if iteration == 1 || mod(iteration,validationFrequency) == 0
    YValidation = predict(net,XValidation);
    lossValidation = crossentropy(YValidation,TValidation);

    [~,idx] = max(YValidation);
    labelsPredValidation = classNames(idx);
    accuracyValidation = mean(labelsPredValidation == labelsValidation);
end
TValidation is a one-hot encoded array of labels over classNames. To calculate the accuracy, convert TValidation to an array of labels (labelsValidation in the code above).
For regression tasks, adjust the code so that it calculates the validation RMSE instead of the validation accuracy.
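For example, this sketch shows the regression variant, assuming that TValidation contains the target responses and that the loss is the half mean squared error computed by the mse function:
if iteration == 1 || mod(iteration,validationFrequency) == 0
    YValidation = predict(net,XValidation);
    lossValidation = mse(YValidation,TValidation);
    rmseValidation = sqrt(mean((YValidation - TValidation).^2,"all"));
end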
For an example showing how to calculate and plot validation metrics during training, see Monitor Custom Training Loop Progress During Training.
Early Stopping
To stop training early when the loss on the held-out validation set stops decreasing, use a flag to break out of the training loops.
To easily specify the validation patience (the number of times that the validation loss can be larger than or equal to the previously smallest loss before network training stops), create the variable validationPatience.
validationPatience = 5;
Before training, initialize the variables earlyStop and validationLosses, where earlyStop is a flag to stop training early and validationLosses contains the losses to compare. Initialize the early stopping flag with false and the array of validation losses with inf.
earlyStop = false;
if isfinite(validationPatience)
    validationLosses = inf(1,validationPatience);
end
Inside the training loop, in the loop over mini-batches, add the earlyStop flag to the loop condition.
while hasdata(ds) && ~earlyStop
    ...
end
During the validation step, append the new validation loss to the array validationLosses. If the first element of the array is the smallest, then set the earlyStop flag to true. Otherwise, remove the first element.
if isfinite(validationPatience)
    validationLosses = [validationLosses validationLoss];

    if min(validationLosses) == validationLosses(1)
        earlyStop = true;
    else
        validationLosses(1) = [];
    end
end
L2 Regularization
To apply L2 regularization to the weights, use the dlupdate function.
To easily specify the L2 regularization factor, create the variable l2Regularization that contains the L2 regularization factor.
l2Regularization = 0.0001;
During training, after computing the model loss and gradients, for each of the weight parameters, add the product of the L2 regularization factor and the weights to the computed gradients using the dlupdate function. To update only the weight parameters, extract the parameters with the name "Weights".
idx = net.Learnables.Parameter == "Weights";
gradients(idx,:) = dlupdate(@(g,w) g + l2Regularization*w, ...
    gradients(idx,:),net.Learnables(idx,:));
After adding the L2 regularization parameter to the gradients, update the network parameters.
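For example, if you are training with the Adam solver, the update step might look like this sketch (averageGrad and averageSqGrad are assumed to be initialized to [] before the training loop):
[net,averageGrad,averageSqGrad] = adamupdate(net,gradients, ...
    averageGrad,averageSqGrad,iteration,learnRate);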
Gradient Clipping
To clip the gradients, use the dlupdate function.
To easily specify gradient clipping options, create the variables gradientThresholdMethod and gradientThreshold, where gradientThresholdMethod contains "global-l2norm", "l2norm", or "absolute-value", and gradientThreshold is a positive scalar containing the threshold or inf.
gradientThresholdMethod ="global-l2norm"; gradientThreshold = 2;
Create functions named thresholdGlobalL2Norm, thresholdL2Norm, and thresholdAbsoluteValue that apply the "global-l2norm", "l2norm", and "absolute-value" threshold methods, respectively.
For the"global-l2norm"
option, the function operates on all gradients of the model.
function gradients = thresholdGlobalL2Norm(gradients,gradientThreshold)
globalL2Norm = 0;
for i = 1:numel(gradients)
    globalL2Norm = globalL2Norm + sum(gradients{i}(:).^2);
end
globalL2Norm = sqrt(globalL2Norm);

if globalL2Norm > gradientThreshold
    normScale = gradientThreshold / globalL2Norm;
    for i = 1:numel(gradients)
        gradients{i} = gradients{i} * normScale;
    end
end
end
For the"l2norm"
and"absolute-value"
options, the functions operate on each gradient independently.
function gradients = thresholdL2Norm(gradients,gradientThreshold)
gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end
end
function gradients = thresholdAbsoluteValue(gradients,gradientThreshold)
gradients(gradients > gradientThreshold) = gradientThreshold;
gradients(gradients < -gradientThreshold) = -gradientThreshold;
end
During training, after computing the model loss and gradients, apply the appropriate gradient clipping method to the gradients using the dlupdate function. Because the "global-l2norm" option requires all of the gradient values, apply the thresholdGlobalL2Norm function directly to the gradients. For the "l2norm" and "absolute-value" options, update the gradients independently using the dlupdate function.
switch gradientThresholdMethod
    case "global-l2norm"
        gradients = thresholdGlobalL2Norm(gradients,gradientThreshold);
    case "l2norm"
        gradients = dlupdate(@(g) thresholdL2Norm(g,gradientThreshold),gradients);
    case "absolute-value"
        gradients = dlupdate(@(g) thresholdAbsoluteValue(g,gradientThreshold),gradients);
end
After applying the gradient threshold operation, update the network parameters.
Single CPU or GPU Training
By default, the software performs calculations using only the CPU. To train on a single GPU, convert the data to gpuArray objects. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
To easily specify the execution environment, create the variable executionEnvironment that contains either "cpu", "gpu", or "auto".
executionEnvironment ="auto"
During training, after reading a mini-batch, check the execution environment option and convert the data to a gpuArray if necessary. The canUseGPU function checks for usable GPUs.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    X = gpuArray(X);
end
Checkpoints
To save checkpoint networks during training, save the network using the save function. To easily specify whether checkpoints are switched on, create the variable checkpointPath that contains the folder for the checkpoint networks or is empty.
checkpointPath = fullfile(tempdir,"checkpoints");
If the checkpoint folder does not exist, then before training, create the checkpoint folder.
if ~exist(checkpointPath,"dir")
    mkdir(checkpointPath)
end
During training, at the end of an epoch, save the network in a MAT file. Specify a file name containing the current iteration number, date, and time.
if ~isempty(checkpointPath)
    D = string(datetime("now",Format="yyyy_MM_dd__HH_mm_ss"));
    filename = "net_checkpoint__" + iteration + "__" + D + ".mat";
    save(fullfile(checkpointPath,filename),"net")
end
net is the dlnetwork object to be saved.
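To resume training from a saved checkpoint, you can load the most recent MAT file from the checkpoint folder. This sketch assumes the file naming convention above and uses the file modification dates to find the latest checkpoint:
% Load the most recently saved checkpoint network.
checkpoints = dir(fullfile(checkpointPath,"net_checkpoint__*.mat"));
[~,idx] = max([checkpoints.datenum]);
loaded = load(fullfile(checkpointPath,checkpoints(idx).name),"net");
net = loaded.net;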
See Also
adamupdate | rmspropupdate | sgdmupdate | dlupdate | dlarray | dlgradient | dlfeval | dlnetwork
Related Topics
- Define Custom Training Loops, Loss Functions, and Networks
- Define Model Loss Function for Custom Training Loop
- Train Network Using Custom Training Loop
- Train Network Using Model Function
- Make Predictions Using dlnetwork Object
- Make Predictions Using Model Function
- Initialize Learnable Parameters for Model Function
- Update Batch Normalization Statistics in Custom Training Loop
- Update Batch Normalization Statistics Using Model Function
- Train Generative Adversarial Network (GAN)
- List of Functions with dlarray Support