sgdmupdate
使用随机动量梯度下降(SGDM)更新参数
语法
描述
使用随机动量梯度下降(SGDM)算法更新自定义训练循环中的网络可学习参数。
请注意
此函数应用SGDM优化算法来更新自定义训练循环中的网络参数,该循环使用定义为的网络dlnetwork
对象或模型函数。如果你想训练一个定义为a的网络层
数组或作为LayerGraph
,使用以下函数:
创建一个
TrainingOptionsSGDM
对象使用trainingOptions
函数。使用
TrainingOptionsSGDM
对象的trainNetwork
函数。
例子
使用更新可学习参数sgdmupdate
执行全局学习率为的单个SGDM更新步骤0.05
和动量0.95
.
将参数和参数梯度创建为数值数组。
Params = rand(3,3,4);Grad = ones(3,3,4);
初始化第一次迭代的参数速度。
Vel = [];
指定全局学习率和动量的自定义值。
learnRate = 0.05;动量= 0.95;
使用更新可学习参数sgdmupdate
.
[params,vel] = sgdmupdate(params,grad,vel,learnRate,动量);
列车网络使用sgdmupdate
使用sgdmupdate
使用SGDM算法训练网络。
负荷训练数据
加载数字训练数据。
[XTrain,TTrain] = digitTrain4DArrayData;类=类别(TTrain);numClasses = nummel(类);
定义网络
属性定义网络体系结构并指定平均图像值的意思是
选项在图像输入层。
图层= [imageInputLayer([28 28 1],“的意思是”reluLayer卷积2dlayer (3,20),“填充”,1) relullayer卷积2dlayer (3,20,“填充”,1) reluLayer fullyConnectedLayer(numClasses) softmaxLayer];
创建一个dlnetwork
对象。
Net = dlnetwork(layers);
定义模型损失函数
创建helper函数modelLoss
,在示例的末尾列出。函数的参数为dlnetwork
对象和带有相应标签的小批输入数据,并返回损失和损失相对于可学习参数的梯度。
指定培训项目
指定在培训期间使用的选项。
miniBatchSize = 128;numEpochs = 20;numObservations = numel(TTrain);numIterationsPerEpoch = floor(numObservations./miniBatchSize);
列车网络的
初始化velocity参数。
Vel = [];
计算训练进度监控器的总迭代次数。
numIterations = nummepochs * numIterationsPerEpoch;
初始化TrainingProgressMonitor
对象。因为计时器在创建监视器对象时开始,所以请确保创建的对象接近训练循环。
monitor = trainingProgressMonitor(指标=“损失”信息=“时代”包含=“迭代”);
使用自定义训练循环训练模型。对于每个纪元,洗牌数据并在小批量数据上循环。方法更新网络参数sgdmupdate
函数。在每次迭代结束时,显示训练进度。
在GPU上训练(如果有的话)。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱).
迭代= 0;Epoch = 0;而epoch < numEpochs && ~monitor。停止epoch = epoch + 1;% Shuffle数据。idx = randperm(数字(TTrain));XTrain = XTrain(:,:,:,idx);TTrain = TTrain(idx);I = 0;而i < numIterationsPerEpoch && ~monitor。Stop i = i + 1;迭代=迭代+ 1;读取小批数据并将标签转换为虚拟标签%变量。idx = (i-1)*miniBatchSize+1:i*miniBatchSize;X = XTrain(:,:,:,idx);T = 0 (numClasses, miniBatchSize,“单身”);为T(c,TTrain(idx)==classes(c)) = 1;结束将小批数据转换为大数组。X = dlarray(single(X),“SSCB”);如果在GPU上训练,则将数据转换为gpuArray。如果canUseGPU X = gpuArray(X);结束计算模型损失和梯度使用dlfeval和% modelLoss函数。[loss,gradients] = dlfeval(@modelLoss,net,X,T);使用SGDM优化器更新网络参数。[net,vel] = sgdmupdate(net,gradients,vel);更新培训进度监视器。recordMetrics(监控、迭代损失=损失);updateInfo(监视、时代=时代+“的”+ numEpochs);班长。进度= 100 * iteration/numIterations;结束结束
测试网络
通过比较测试集上的预测与真实标签来测试模型的分类准确性。
[XTest,TTest] = digitTest4DArrayData;
将数据转换为adlarray
使用维度格式“SSCB”
(空间,空间,通道,批次)。对于GPU预测,也将数据转换为agpuArray
.
XTest = dlarray(XTest,“SSCB”);如果canUseGPU XTest = gpuArray(XTest);结束
对图像进行分类dlnetwork
对象时,使用预测
计算并找出得分最高的课程。
YTest =预测(net,XTest);[~,idx] = max(extractdata(YTest),[],1);YTest = classes(idx);
评估分类准确率。
精度=平均值(YTest==TTest)
准确度= 0.9916
模型损失函数
的modelLoss
函数的参数为dlnetwork
对象网
和一小批输入数据X
有相应的标签T
,并返回损失以及损失相对于中可学习参数的梯度网
.要自动计算梯度,请使用dlgradient
函数。
函数[loss,gradients] = modelLoss(net,X,T) Y = forward(net,X);损失=交叉熵(Y,T);gradients = dlgradient(loss,net.Learnables);结束
输入参数
网
- - - - - -网络
dlnetwork
对象
网络,指定为dlnetwork
对象。
函数更新可学的
的属性dlnetwork
对象。网可学的
是一个包含三个变量的表:
层
-层名,指定为字符串标量。参数
—参数名称,指定为字符串标量。价值
参数的值,指定为包含dlarray
.
输入参数研究生
一定是和?一样形式的表网可学的
.
参数个数
- - - - - -网络可学习参数
dlarray
|数字数组|单元阵列|结构|表格
网络可学习参数,指定为dlarray
、数字数组、单元格数组、结构体或表。
如果你指定参数个数
作为一个表,它必须包含以下三个变量。
层
-层名,指定为字符串标量。参数
—参数名称,指定为字符串标量。价值
参数的值,指定为包含dlarray
.
你可以指定参数个数
作为使用单元格数组、结构、表或嵌套单元格数组或结构的网络可学习参数的容器。单元格数组、结构或表中的可学习参数必须为dlarray
或数据类型的数值双
或单
.
输入参数研究生
必须提供与?完全相同的数据类型、顺序和字段(用于结构)或变量(用于表)参数个数
.
数据类型:单
|双
|结构体
|表格
|细胞
研究生
- - - - - -损失的梯度
dlarray
|数字数组|单元阵列|结构|表格
损耗的梯度,指定为adlarray
、数字数组、单元格数组、结构体或表。
确切的形式研究生
取决于输入网络或可学习参数。下表显示了所需的格式研究生
可能的输入sgdmupdate
.
输入 | 可学的参数 | 梯度 |
---|---|---|
网 |
表格网可学的 包含层 ,参数 ,价值 变量。的价值 变量由单元格数组组成,单元格数组包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序网可学的 .研究生 必须有一个价值 由包含每个可学习参数梯度的单元格数组组成的变量。 |
参数个数 |
dlarray |
dlarray 使用相同的数据类型和顺序参数个数 |
数字数组 | 具有相同数据类型和顺序的数值数组参数个数 |
|
单元阵列 | 单元格数组,具有相同的数据类型、结构和顺序参数个数 |
|
结构 | 结构,具有相同的数据类型、字段和排序参数个数 |
|
表层 ,参数 ,价值 变量。的价值 变量必须由单元格数组组成,其中包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序参数个数 .研究生 必须有一个价值 由包含每个可学习参数梯度的单元格数组组成的变量。 |
你可以获得研究生
从电话到dlfeval
对包含调用的函数求值dlgradient
.有关更多信息,请参见在深度学习工具箱中使用自动区分.
韦尔
- - - - - -速度参数
[]
|dlarray
|数字数组|单元阵列|结构|表格
参数velocity(指定为空数组)dlarray
、数字数组、单元格数组、结构体或表。
确切的形式韦尔
取决于输入网络或可学习参数。下表显示了所需的格式韦尔
可能的输入sgdmpdate
.
输入 | 可学的参数 | 速度 |
---|---|---|
网 |
表格网可学的 包含层 ,参数 ,价值 变量。的价值 变量由单元格数组组成,单元格数组包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序网可学的 .韦尔 必须有一个价值 由包含每个可学习参数的速度的单元格数组组成的变量。 |
参数个数 |
dlarray |
dlarray 使用相同的数据类型和顺序参数个数 |
数字数组 | 具有相同数据类型和顺序的数值数组参数个数 |
|
单元阵列 | 单元格数组,具有相同的数据类型、结构和顺序参数个数 |
|
结构 | 结构,具有相同的数据类型、字段和排序参数个数 |
|
表层 ,参数 ,价值 变量。的价值 变量必须由单元格数组组成,其中包含每个可学习参数dlarray . |
表具有相同的数据类型、变量和排序参数个数 .韦尔 必须有一个价值 由包含每个可学习参数的速度的单元格数组组成的变量。 |
如果你指定韦尔
作为一个空数组,函数假设没有之前的速度,并以与一系列迭代中的第一次更新相同的方式运行。要迭代地更新可学习参数,请使用韦尔
的前一次调用的输出sgdmupdate
随着韦尔
输入。
learnRate
- - - - - -全球学习率
0.01
(默认)|积极的标量
学习率,指定为正标量。的默认值learnRate
是0.01
.
如果指定网络参数为adlnetwork
对象时,每个参数的学习率是全局学习率乘以网络层中定义的相应学习率因子属性。
动力
- - - - - -动力
0.9
(默认)|之间的正标量0
而且1
动量,指之间的正标量0
而且1
.的默认值动力
是0.9
.
输出参数
netUpdated
-更新网络
dlnetwork
对象
网络,返回为adlnetwork
对象。
函数更新可学的
的属性dlnetwork
对象。
参数个数
—更新网络可学习参数
dlarray
|数字数组|单元数组|结构|表
更新网络可学习参数,返回为dlarray
类型的数字数组、单元格数组、结构体或表价值
变量,包含网络更新后的可学习参数。
韦尔
-更新参数速度
dlarray
|数字数组|单元数组|结构|表
更新的参数velocity,返回为dlarray
、数字数组、单元格数组、结构体或表。
更多关于
带动量的随机梯度下降
该函数采用动量随机梯度下降算法来更新可学习参数。有关更多信息,请参阅下面的随机梯度下降算法的定义随机梯度下降在trainingOptions
参考页面。
扩展功能
GPU数组
通过使用并行计算工具箱™在图形处理单元(GPU)上运行来加速代码。
使用注意事项和限制:
当以下输入参数中至少有一个是
gpuArray
或者一个dlarray
类型的底层数据gpuArray
,该函数运行在GPU上。研究生
参数个数
有关更多信息,请参见在图形处理器上运行MATLAB函数(并行计算工具箱).
版本历史
R2019b引入
MATLAB命令
你点击了一个对应于这个MATLAB命令的链接:
在MATLAB命令窗口中输入该命令来运行该命令。Web浏览器不支持MATLAB命令。金宝app
您也可以从以下列表中选择一个网站:
如何获得最佳的网站性能
选择中国站点(中文或英文)以获得最佳站点性能。其他MathWorks国家站点没有针对您所在位置的访问进行优化。