forward
Syntax
Description
Some deep learning layers behave differently during training and inference (prediction). For example, during training, dropout layers randomly set input elements to zero to help prevent overfitting, but during inference, dropout layers do not change the input.
To compute network outputs for training, use the forward function. To compute network outputs for inference, use the predict function.
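The difference between the two functions can be sketched with a small network that contains a dropout layer. The network, layer names, and input data here are illustrative assumptions and are not part of the examples on this page.

```matlab
% Minimal sketch: a network with a single dropout layer (illustrative only).
layers = [
    featureInputLayer(4,Name="in")
    dropoutLayer(0.5,Name="drop")];
net = dlnetwork(layers);

% Mini-batch of 8 observations with 4 features each, all ones.
X = dlarray(ones(4,8,"single"),"CB");

% Training-mode pass: dropout randomly zeroes input elements,
% so YTrain generally differs from X.
YTrain = forward(net,X);

% Inference-mode pass: dropout leaves the input unchanged.
YPredict = predict(net,X);

% Check that the inference output matches the input.
isUnchanged = isequal(extractdata(YPredict),extractdata(X));
```

Because the dropout mask is random, YTrain changes from call to call, while YPredict stays equal to the input.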
[Y1,…,YN] = forward(___) returns the N outputs Y1, …, YN during training for networks that have N outputs using any of the previous syntaxes.
[Y1,…,YK] = forward(___,'Outputs',layerNames) returns the outputs Y1, …, YK during training for the specified layers using any of the previous syntaxes.
[___] = forward(___,'Acceleration',acceleration) also specifies performance optimization to use during training, in addition to the input arguments in previous syntaxes.
[___,state] = forward(___) also returns the updated network state.
[___,state,pruningActivations] = forward(___) also returns a cell array of activations of the pruning layers. This syntax is applicable only if net is a TaylorPrunableNetwork object.
To prune a deep neural network, you require the Deep Learning Toolbox™ Model Quantization Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Quantization Library.
Examples
Train Network Using Custom Training Loop
This example shows how to train a network that classifies handwritten digits with a custom learning rate schedule.
You can train most types of neural networks using the trainNetwork and trainingOptions functions. If the trainingOptions function does not provide the options you need (for example, a custom learning rate schedule), then you can define your own custom training loop using dlarray and dlnetwork objects for automatic differentiation. For an example showing how to retrain a pretrained deep learning network using the trainNetwork function, see Transfer Learning Using Pretrained Network.
Training a deep neural network is an optimization task. By considering a neural network as a function $f(X;\theta)$, where $X$ is the network input and $\theta$ is the set of learnable parameters, you can optimize $\theta$ so that it minimizes some loss value based on the training data. For example, optimize the learnable parameters $\theta$ such that, for given inputs $X$ with corresponding targets $T$, they minimize the error between the predictions $Y = f(X;\theta)$ and $T$.
The loss function used depends on the type of task. For example:
For classification tasks, you can minimize the cross entropy error between the predictions and targets.
For regression tasks, you can minimize the mean squared error between the predictions and targets.
You can optimize the objective using gradient descent: minimize the loss $L$ by iteratively updating the learnable parameters $\theta$, taking steps towards the minimum using the gradients of the loss with respect to the learnable parameters. Gradient descent algorithms typically update the learnable parameters by using a variant of an update step of the form $\theta_{t+1} = \theta_t - \rho \nabla L$, where $t$ is the iteration number, $\rho$ is the learning rate, and $\nabla L$ denotes the gradients (the derivatives of the loss with respect to the learnable parameters).
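As a minimal sketch of such an update step, the following code applies plain gradient descent to a toy scalar problem. The loss function, starting value, learning rate, and iteration count are assumptions chosen for illustration; they are unrelated to the digit classification network.

```matlab
% Toy problem (an assumption for illustration): minimize L(theta) = (theta - 3)^2.
theta = 0;            % initial value of the learnable parameter
learnRate = 0.1;      % learning rate
numIterations = 200;  % number of update steps

for t = 1:numIterations
    % Derivative of the loss with respect to theta at the current value.
    gradient = 2*(theta - 3);

    % Update step: move against the gradient, scaled by the learning rate.
    theta = theta - learnRate*gradient;
end
```

After the loop, theta is close to 3, the minimizer of the toy loss. The training loop in this example uses the sgdmupdate function instead, which adds a momentum term to this basic step.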
This example trains a network to classify handwritten digits with the time-based decay learning rate schedule: for each iteration, the solver uses the learning rate given by $\rho_t = \frac{\rho_0}{1 + k t}$, where $t$ is the iteration number, $\rho_0$ is the initial learning rate, and $k$ is the decay.
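To get a feel for the schedule, this sketch evaluates the time-based decay learning rate, the initial learning rate divided by one plus the decay times the iteration number, at a few iterations. The initial learning rate and decay match the values specified later in this example; the sample iteration numbers are arbitrary.

```matlab
initialLearnRate = 0.01;   % initial learning rate
decay = 0.01;              % decay
iteration = [1 100 1000];  % sample iteration numbers (chosen for illustration)

% Time-based decay: the learning rate shrinks as the iteration number grows.
learnRate = initialLearnRate ./ (1 + decay*iteration);
```

With these values, the learning rate has halved to 0.005 by iteration 100.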
Load Training Data
Load the digits data as an image datastore using the imageDatastore function and specify the folder containing the image data.
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset");
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
Partition the data into training and validation sets. Set aside 10% of the data for validation using the splitEachLabel function.
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,"randomize");
The network used in this example requires input images of size 28-by-28-by-1. To automatically resize the training images, use an augmented image datastore. Specify additional augmentation operations to perform on the training images: randomly translate the images up to 5 pixels in the horizontal and vertical axes. Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.
inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=imageAugmenter);
To automatically resize the validation images without performing further data augmentation, use an augmented image datastore without specifying any additional preprocessing operations.
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
Determine the number of classes in the training data.
classes = categories(imdsTrain.Labels);
numClasses = numel(classes);
Define Network
Define the network for image classification.
- For image input, specify an image input layer with input size matching the training data.
- Do not normalize the image input. Set the Normalization option of the input layer to "none".
- Specify three convolution-batchnorm-ReLU blocks.
- Pad the input to the convolution layers such that the output has the same size by setting the Padding option to "same".
- For the first convolution layer, specify 20 filters of size 5. For the remaining convolution layers, specify 20 filters of size 3.
- For classification, specify a fully connected layer with size matching the number of classes.
- To map the output to probabilities, include a softmax layer.
When training a network using a custom training loop, do not include an output layer.
layers = [
    imageInputLayer(inputSize,Normalization="none")
    convolution2dLayer(5,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers)
net =
  dlnetwork with properties:
         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1
  View summary with summary.
Define Model Loss Function
Training a deep neural network is an optimization task. By considering a neural network as a function $f(X;\theta)$, where $X$ is the network input and $\theta$ is the set of learnable parameters, you can optimize $\theta$ so that it minimizes some loss value based on the training data. For example, optimize the learnable parameters $\theta$ such that, for given inputs $X$ with corresponding targets $T$, they minimize the error between the predictions $Y = f(X;\theta)$ and $T$.
Create the function modelLoss, listed in the Model Loss Function section of the example, that takes as input the dlnetwork object and a mini-batch of input data with corresponding targets, and returns the loss, the gradients of the loss with respect to the learnable parameters, and the network state.
Specify Training Options
Train for ten epochs with a mini-batch size of 128.
numEpochs = 10;
miniBatchSize = 128;
Specify the options for SGDM optimization. Specify an initial learn rate of 0.01 with a decay of 0.01, and momentum 0.9.
initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;
Train Model
Create a minibatchqueue object that processes and manages mini-batches of images during training. For each mini-batch:
- Use the custom mini-batch preprocessing function preprocessMiniBatch (defined at the end of this example) to convert the labels to one-hot encoded variables.
- Format the image data with the dimension labels "SSCB" (spatial, spatial, channel, batch). By default, the minibatchqueue object converts the data to dlarray objects with underlying type single. Do not format the class labels.
- Train on a GPU if one is available. By default, the minibatchqueue object converts each output to a gpuArray if a GPU is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
mbq = minibatchqueue(augimdsTrain, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" ""]);
Initialize the velocity parameter for the SGDM solver.
velocity = [];
Calculate the total number of iterations for the training progress monitor.
numObservationsTrain = numel(imdsTrain.Files);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info=["Epoch","LearnRate"],XLabel="Iteration");
Train the network using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. For each mini-batch:
- Evaluate the model loss, gradients, and state using the dlfeval and modelLoss functions and update the network state.
- Determine the learning rate for the time-based decay learning rate schedule.
- Update the network parameters using the sgdmupdate function.
- Update the loss, learn rate, and epoch values in the training progress monitor.
- Stop if the Stop property is true. The Stop property value of the TrainingProgressMonitor object changes to true when you click the Stop button.
epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop

    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop

        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbq);

        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelLoss function and update the network state.
        [loss,gradients,state] = dlfeval(@modelLoss,net,X,T);
        net.State = state;

        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);

        % Update the network parameters using the SGDM optimizer.
        [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
Test Model
Test the classification accuracy of the model by comparing the predictions on the validation set with the true labels.
After training, making predictions on new data does not require the labels. Create a minibatchqueue object containing only the predictors of the test data:
- To ignore the labels for testing, set the number of outputs of the mini-batch queue to 1.
- Specify the same mini-batch size used for training.
- Preprocess the predictors using the preprocessMiniBatchPredictors function, listed at the end of the example.
- For the single output of the datastore, specify the mini-batch format "SSCB" (spatial, spatial, channel, batch).
numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatchPredictors, ...
    MiniBatchFormat="SSCB");
Loop over the mini-batches and classify the images using the modelPredictions function, listed at the end of the example.
YTest = modelPredictions(net,mbqTest,classes);
Evaluate the classification accuracy.
TTest = imdsValidation.Labels;
accuracy = mean(TTest == YTest)
accuracy = 0.9750
Visualize the predictions in a confusion chart.
figure
confusionchart(TTest,YTest)
Large values on the diagonal indicate accurate predictions for the corresponding class. Large values on the off-diagonal indicate strong confusion between the corresponding classes.
Supporting Functions
Model Loss Function
The modelLoss function takes a dlnetwork object net and a mini-batch of input data X with corresponding targets T, and returns the loss, the gradients of the loss with respect to the learnable parameters in net, and the network state. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients,state] = modelLoss(net,X,T)

% Forward data through network.
[Y,state] = forward(net,X);

% Calculate cross-entropy loss.
loss = crossentropy(Y,T);

% Calculate gradients of loss with respect to learnable parameters.
gradients = dlgradient(loss,net.Learnables);

end
Model Predictions Function
The modelPredictions function takes a dlnetwork object net, a minibatchqueue of input data mbq, and the network classes, and computes the model predictions by iterating over all data in the minibatchqueue object. The function uses the onehotdecode function to find the predicted class with the highest score.
function Y = modelPredictions(net,mbq,classes)

Y = [];

% Loop over mini-batches.
while hasdata(mbq)
    X = next(mbq);

    % Make prediction.
    scores = predict(net,X);

    % Decode labels and append to output.
    labels = onehotdecode(scores,classes,1)';
    Y = [Y; labels];
end

end
Mini-Batch Preprocessing Function
The preprocessMiniBatch function preprocesses a mini-batch of predictors and labels using the following steps:
- Preprocess the images using the preprocessMiniBatchPredictors function.
- Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.
- One-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
function [X,T] = preprocessMiniBatch(dataX,dataT)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX);

% Extract label data from cell and concatenate.
T = cat(2,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,1);

end
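To see what encoding along the first dimension produces, this sketch one-hot encodes a small categorical array and decodes it again. The three label values are hypothetical and stand in for a mini-batch of digit labels.

```matlab
% Three example labels (hypothetical values) as a 1-by-3 categorical array.
T = categorical(["3" "7" "3"]);

% Encode along dimension 1: one row per class, one column per observation.
TEncoded = onehotencode(T,1);

% Decode along the same dimension to recover the labels.
TDecoded = onehotdecode(TEncoded,categories(T),1);
```

Here TEncoded is a 2-by-3 array (classes by observations), which is the same classes-by-batch layout that the softmax output of the network uses.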
Mini-Batch Predictors Preprocessing Function
The preprocessMiniBatchPredictors function preprocesses a mini-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension.
function X = preprocessMiniBatchPredictors(dataX)

% Concatenate.
X = cat(4,dataX{1:end});

end
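The effect of concatenating over the fourth dimension can be checked on placeholder data. The cell array below is an assumption for illustration and uses three synthetic 28-by-28 arrays in place of real images.

```matlab
% Three placeholder 28-by-28 grayscale images in a cell array (illustrative data).
dataX = {zeros(28,28), ones(28,28), rand(28,28)};

% Concatenate along the fourth dimension. The third dimension becomes a
% singleton channel dimension.
X = cat(4,dataX{1:end});

% Size of the resulting mini-batch: height, width, channels, observations.
sz = size(X);
```

The resulting size is 28-by-28-by-1-by-3, which matches the "SSCB" format that the mini-batch queue applies.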
Input Arguments
net - Network for custom training loops or custom pruning loops
dlnetwork object | TaylorPrunableNetwork object
This argument can represent either of these:
- Network for custom training loops, specified as a dlnetwork object.
- Network for custom pruning loops, specified as a TaylorPrunableNetwork object.
To prune a deep neural network, you require the Deep Learning Toolbox Model Quantization Library support package. This support package is a free add-on that you can download using the Add-On Explorer. Alternatively, see Deep Learning Toolbox Model Quantization Library.
layerNames - Layers to extract outputs from
string array | cell array of character vectors
Layers to extract outputs from, specified as a string array or a cell array of character vectors containing the layer names.
- If layerNames(i) corresponds to a layer with a single output, then layerNames(i) is the name of the layer.
- If layerNames(i) corresponds to a layer with multiple outputs, then layerNames(i) is the layer name followed by the character "/" and the name of the layer output: 'layerName/outputName'.
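As a minimal sketch of extracting outputs from named layers, the following code requests the outputs of two single-output layers. The small network and the layer names "in", "fc", and "relu" are assumptions made for illustration.

```matlab
% Small illustrative network with named layers (names are assumptions).
layers = [
    featureInputLayer(3,Name="in")
    fullyConnectedLayer(2,Name="fc")
    reluLayer(Name="relu")];
net = dlnetwork(layers);

% Mini-batch of 5 observations with 3 features each.
X = dlarray(rand(3,5,"single"),"CB");

% Return the training-mode outputs of the "fc" and "relu" layers,
% in the order listed in the layer names.
[YFc,YRelu] = forward(net,X,'Outputs',["fc" "relu"]);
```

Each requested layer has a single output, so the layer name alone identifies it.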
acceleration - Performance optimization
'auto' (default) | 'none'
Performance optimization, specified as one of the following:
- 'auto' - Automatically apply a number of optimizations suitable for the input network and hardware resources.
- 'none' - Disable all acceleration.
The default option is 'auto'.
Using the 'auto' acceleration option can offer performance benefits, but at the expense of an increased initial run time. Subsequent calls with compatible parameters are faster. Use performance optimization when you plan to call the function multiple times using different input data with the same size and shape.
Output Arguments
state - Updated network state
table
Updated network state, returned as a table.
The network state is a table with three columns:
- Layer - Layer name, specified as a string scalar.
- Parameter - State parameter name, specified as a string scalar.
- Value - Value of state parameter, specified as a dlarray object.
Layer states contain information calculated during the layer operation to be retained for use in subsequent forward passes of the layer. Examples include the cell state and hidden state of LSTM layers and the running statistics in batch normalization layers.
For recurrent layers, such as LSTM layers, with the HasStateInputs property set to 1 (true), the state table does not contain entries for the states of that layer.
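The following sketch shows one way to inspect and reuse the returned state. The small network with a batch normalization layer and the layer names are assumptions chosen for illustration.

```matlab
% Small illustrative network with one stateful layer (batch normalization).
layers = [
    featureInputLayer(3,Name="in")
    fullyConnectedLayer(2,Name="fc")
    batchNormalizationLayer(Name="bn")];
net = dlnetwork(layers);

% Mini-batch of 16 observations with 3 features each.
X = dlarray(rand(3,16,"single"),"CB");

% Training-mode pass that also returns the updated state table.
[Y,state] = forward(net,X);

% Store the updated state in the network so that later passes use it.
net.State = state;
```

In this sketch, every row of the state table belongs to the "bn" layer, because it is the only layer in the network that carries state.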
pruningActivations - Activations of the pruning layers
cell array containing dlarray objects
Cell array of activations of the pruning layers, if the input network is a TaylorPrunableNetwork object.
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
This function runs on the GPU if either or both of the following conditions are met:
- Any of the values of the network learnable parameters inside net.Learnables.Value are dlarray objects with underlying data of type gpuArray.
- The input argument X is a dlarray with underlying data of type gpuArray.
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
Version History
Introduced in R2019b
R2021a: forward returns state values as dlarray objects
For dlnetwork objects, the state output argument returned by the forward function is a table containing the state parameter names and values for each layer in the network.
Starting in R2021a, the state values are dlarray objects. This change enables better support when using AcceleratedFunction objects. To accelerate deep learning functions that have frequently changing input values, for example, an input containing the network state, the frequently changing values must be specified as dlarray objects.
In previous versions, the state values are numeric arrays.
In most cases, you will not need to update your code. If you have code that requires the state values to be numeric arrays, then to reproduce the previous behavior, extract the data from the state values manually using the extractdata function with the dlupdate function.
state = dlupdate(@extractdata,net.State);