dlupdate
Syntax

netUpdated = dlupdate(fun,net)
params = dlupdate(fun,params)
[___] = dlupdate(fun,___,A1,...,An)
[___,X1,...,Xm] = dlupdate(fun,___)

Description

netUpdated = dlupdate(fun,net) updates the learnable parameters of the dlnetwork object net by evaluating the function fun with each learnable parameter as an input. fun is a function handle to a function that takes one parameter array as an input argument and returns an updated parameter array.

params = dlupdate(fun,params) updates the learnable parameters in params by evaluating the function fun with each learnable parameter as an input.

[___] = dlupdate(fun,___,A1,...,An) also evaluates fun with the additional input arguments A1,...,An, using any of the previous syntaxes.

[___,X1,...,Xm] = dlupdate(fun,___) returns the additional outputs X1,...,Xm when fun is a function handle to a function that returns multiple outputs.
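For example, the following minimal sketch (the parameter structure and the scaling factor are hypothetical, introduced only for illustration) scales every learnable parameter in a structure by a constant:

% Sketch only: params is assumed to be a structure of dlarray parameters,
% for example params.Weights and params.Bias.
decay = 0.99;                         % hypothetical scaling factor
scaleFcn = @(p) decay.*p;             % takes one parameter array, returns one
params = dlupdate(scaleFcn,params);   % scaleFcn is applied to each parameter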
Examples
L1 Regularization with dlupdate
Perform L1 regularization on a structure of parameter gradients.
Create the sample input data.
dlX = dlarray(rand(100,100,3),'SSC');
Initialize the learnable parameters for the convolution operation.
params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));
Calculate the gradients for the convolution operation using the helper function convGradients, defined at the end of this example.
gradients = dlfeval(@convGradients,dlX,params);
Define the regularization factor.
L1Factor = 0.001;
Create an anonymous function that regularizes the gradients. By using an anonymous function to pass a scalar constant to the function, you can avoid having to expand the constant value to the same size and structure as the parameter variable.
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);
Use dlupdate to apply the regularization function to each of the gradients.
gradients = dlupdate(L1Regularizer,gradients,params);
The values in gradients are now regularized according to the function L1Regularizer.
convGradients Function

The convGradients helper function takes the learnable parameters of the convolution operation and a mini-batch of input data dlX, and returns the gradients with respect to the learnable parameters.
function gradients = convGradients(dlX,params)
dlY = dlconv(dlX,params.Weights,params.Bias);
dlY = sum(dlY,'all');
gradients = dlgradient(dlY,params);
end
Use dlupdate to Train Network Using Custom Update Function

Use dlupdate to train a network using a custom update function that implements the stochastic gradient descent algorithm (without momentum).
Load Training Data
Load the digits training data.
[XTrain,TTrain] = digitTrain4DArrayData;
classes = categories(TTrain);
numClasses = numel(classes);
Define the Network
Define the network architecture and specify the average image value using the Mean option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers);
Define Model Loss Function
Create the helper function modelLoss, listed at the end of this example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.
Define Stochastic Gradient Descent Function
Create the helper function sgdFunction, listed at the end of this example. The function takes the parameters and the gradients of the loss with respect to the parameters, and returns the updated parameters using the stochastic gradient descent algorithm, expressed as

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell})$$

where $\ell$ is the iteration number, $\alpha$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(TTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Specify the learning rate.
learnRate = 0.01;
Train Network
Calculate the total number of iterations for the training progress monitor.
numIterations = numEpochs * numIterationsPerEpoch;
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters by calling dlupdate with the function sgdFunction defined at the end of this example. At the end of each epoch, display the training progress.

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
iteration = 0;
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);

    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        T = zeros(numClasses, miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to dlarray.
        X = dlarray(single(X),"SSCB");

        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function.
        updateFcn = @(net,gradients) sgdFunction(net,gradients,learnRate);
        net = dlupdate(updateFcn,net,gradients);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest,TTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.
XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YTest==TTest)
accuracy = 0.9040
Model Loss Function
The helper function modelLoss takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients] = modelLoss(net,X,T)
Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);
end
Stochastic Gradient Descent Function
The helper function sgdFunction takes the learnable parameters parameters, the gradients of the loss with respect to the learnable parameters, and the learning rate learnRate, and returns the updated parameters using the stochastic gradient descent algorithm, expressed as

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell})$$

where $\ell$ is the iteration number, $\alpha$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.
function parameters = sgdFunction(parameters,gradients,learnRate)
parameters = parameters - learnRate .* gradients;
end
Input Arguments
net — Network
dlnetwork object

Network, specified as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:
- Layer — Layer name, specified as a string scalar.
- Parameter — Parameter name, specified as a string scalar.
- Value — Value of parameter, specified as a cell array containing a dlarray.
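For example, a minimal sketch (assuming net is an initialized dlnetwork object, with a clipping bound chosen only for illustration) that clips every learnable parameter:

% Sketch only: clip each learnable parameter of net to the range [-1,1].
clipFcn = @(p) max(min(p,1),-1);
net = dlupdate(clipFcn,net);   % updates each cell of net.Learnables.Value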
params — Network learnable parameters
dlarray | numeric array | cell array | structure | table

Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the following three variables:
- Layer — Layer name, specified as a string scalar.
- Parameter — Parameter name, specified as a string scalar.
- Value — Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.
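For example, a minimal sketch (the field names and sizes are hypothetical) that stores learnable parameters in a nested structure and scales them with dlupdate:

% Sketch only: nested structure of dlarray learnable parameters.
params.fc.Weights = dlarray(randn(10,784,"single"));
params.fc.Bias    = dlarray(zeros(10,1,"single"));
params = dlupdate(@(p) 0.5.*p,params);   % fun is applied to each nested dlarray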
The input arguments A1,...,An must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Data Types: single | double | struct | table | cell
A1,...,An — Additional input arguments
dlarray | numeric array | cell array | structure | table

Additional input arguments to fun, specified as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of A1,...,An depends on the input network or learnable parameters. The following table shows the required format for A1,...,An for possible inputs to dlupdate.
| Input | Learnable Parameters | A1,...,An |
|---|---|---|
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input arguments for the function fun to apply to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
|  | Numeric array | Numeric array with the same data type and ordering as params. |
|  | Cell array | Cell array with the same data types, structure, and ordering as params. |
|  | Structure | Structure with the same data types, fields, and ordering as params. |
|  | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input argument for the function fun to apply to each learnable parameter. |
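For example, when the input is a dlnetwork object, a gradients table produced by dlgradient on net.Learnables already has the required form and can be passed directly as an additional input. A minimal sketch (the learning rate value is hypothetical; gradients is assumed to be a table matching net.Learnables, as in the training example above):

% Sketch only: apply one gradient step to every learnable parameter.
sgdStep = @(p,g) p - 0.01.*g;            % hypothetical learning rate of 0.01
net = dlupdate(sgdStep,net,gradients);   % gradients matches net.Learnables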
Output Arguments
netUpdated — Updated network
dlnetwork object

Network, returned as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object.
params — Updated network learnable parameters
dlarray | numeric array | cell array | structure | table

Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
X1,...,Xm — Additional output arguments
dlarray | numeric array | cell array | structure | table

Additional output arguments from the function fun, where fun is a function handle to a function that returns multiple outputs, returned as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of X1,...,Xm depends on the input network or learnable parameters. The following table shows the returned format of X1,...,Xm for possible inputs to dlupdate.
| Input | Learnable Parameters | X1,...,Xm |
|---|---|---|
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output arguments of the function fun applied to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
|  | Numeric array | Numeric array with the same data type and ordering as params. |
|  | Cell array | Cell array with the same data types, structure, and ordering as params. |
|  | Structure | Structure with the same data types, fields, and ordering as params. |
|  | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output argument of the function fun applied to each learnable parameter. |
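For example, a minimal sketch (the helper function and its outputs are hypothetical) in which fun returns the updated parameter and, as a second output, the norm of the applied step. When params is a structure, stepNorms is returned as a structure with the same fields:

[params,stepNorms] = dlupdate(@scaleAndReport,params);

function [p,stepNorm] = scaleAndReport(p)
    % Sketch only: subtract a small fraction of the parameter value and
    % report the size of the step.
    step = 0.01.*p;
    p = p - step;
    stepNorm = sqrt(sum(step.^2,"all"));
end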
Extended Capabilities
GPU Arrays
Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.
Usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:

- params
- A1,...,An

For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
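For example, a minimal sketch (assuming a supported GPU and a params container of dlarray values) that moves the learnable parameters to the GPU so that subsequent dlupdate calls run there:

% Sketch only: convert each learnable parameter to a gpuArray.
if canUseGPU
    params = dlupdate(@gpuArray,params);
end
params = dlupdate(@(p) 0.9.*p,params);   % this call now runs on the GPU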
Version History
Introduced in R2019b