主要内容

训练多输出网络

这个例子展示了如何训练一个具有多个输出的深度学习网络,这些输出预测手写数字的标签和旋转角度。

要训练具有多个输出的网络,必须使用自定义训练循环来训练网络。

负荷训练数据

digitTrain4DArrayData函数加载图像、它们的数字标签以及它们从垂直方向旋转的角度。创建一个arrayDatastore对象获取图像、标签和角度,然后使用结合函数创建一个包含所有训练数据的单个数据存储。提取类名和非离散响应的数量。

[XTrain,T1Train,T2Train] = digitTrain4DArrayData;dsXTrain = arrayDatastore(XTrain,IterationDimension=4);dsT1Train = arrayDatastore(T1Train);dsT2Train = arrayDatastore(T2Train);dsTrain = combine(dsXTrain,dsT1Train,dsT2Train);classNames =类别(T1Train);numClasses = numel(classNames);numObservations = numel(T1Train);

查看训练数据中的一些图像。

idx = randperm(numObservations,64);I = imtile(XTrain(:,:,:,idx));图imshow(我)

图中包含一个轴对象。axis对象包含一个image类型的对象。

定义深度学习模型

定义以下预测标签和旋转角度的网络。

  • 带有16个5 × 5滤波器的卷积-batchnorm- relu块。

  • 两个卷积-batchnorm- relu块,每个块有32个3 × 3滤波器。

  • 围绕前两个块的跳过连接,其中包含32个1 × 1卷积的卷积-batchnorm- relu块。

  • 使用加法合并跳过连接。

  • 对于分类输出,一个具有大小为10(类的数量)的全连接操作和一个softmax操作的分支。

  • 对于回归输出,具有大小为1(响应的数量)的完全连接操作的分支。

将层的主块定义为层图。

layers = [imageInputLayer([28 28 1],归一化=“没有”) convolution2dLayer(填充= 5,16日“相同”batchNormalizationLayer reluLayer(Name= .“relu_1”32岁的)convolution2dLayer(3填充=“相同”batchNormalizationLayer reluLayer convolution2dLayer(3,32,Padding=“相同”batchNormalizationLayer reluLayer additionLayer(2,Name=“添加”) fullyConnectedLayer(numClasses) softmaxLayer(Name=“softmax”));lgraph = layerGraph(图层);

添加跳转连接。

layers = [convolution2dLayer(1,32,Stride=2,Name= .“conv_skip”batchNormalizationLayer reluLayer(Name= .“relu_skip”));lgraph = addLayers(lgraph,layers);lgraph = connectLayers(“relu_1”“conv_skip”);lgraph = connectLayers(“relu_skip”“添加/ in2”);

为回归添加全连接层。

layers = fullyConnectedLayer(1,Name=“fc_2”);lgraph = addLayers(lgraph,layers);lgraph = connectLayers(“添加”“fc_2”);

在图中查看层图。

图绘制(lgraph)

图中包含一个轴对象。axis对象包含一个graphplot类型的对象。

创建一个dlnetwork对象从图层图。

Net = dlnetwork(lgraph)
net = dlnetwork with properties: Layers: [17×1 nnet.cnn.layer.Layer] Connections: [17×2 table] Learnables: [20×3 table] State: [8×3 table] InputNames: {'imageinput'} OutputNames: {'softmax' 'fc_2'} Initialized: 1使用summary查看summary。

定义模型损失函数

创建函数modelLoss,在示例末尾列出,它将dlnetwork对象,一个小批量的输入数据,其对应的目标包含标签和角度,并返回损失、损失相对于可学习参数的梯度以及更新的网络状态。

指定培训项目

指定培训选项。使用128个小批量训练30个epoch。

numEpochs = 30;miniBatchSize = 128;

火车模型

使用minibatchqueue处理和管理小批量的图像。对于每个小批量:

  • 使用自定义小批量预处理功能preprocessMiniBatch(在本例末尾定义)来对类标签进行一次性编码。

  • 用尺寸标签格式化图像数据“SSCB”(空间,空间,通道,批次)。默认情况下,minibatchqueue对象将数据转换为dlarray具有基础类型的对象.不要向类标签或角度添加格式。

  • 如果有GPU,可以在GPU上进行训练。默认情况下,minibatchqueue对象将每个输出转换为gpuArray如果GPU可用。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱)

mbq = minibatchqueue(dsTrain,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessData,...MiniBatchFormat = [“SSCB”""""]);

使用自定义训练循环训练模型。对于每个纪元,洗牌数据并在小批量数据上循环。在每次迭代结束时,显示训练进度。对于每个小批量:

  • 评估模型损失和梯度使用dlfevalmodelLoss函数。

  • 方法更新网络参数adamupdate函数。

初始化Adam的参数。

trailingAvg = [];trailingAvgSq = [];

计算训练进度监控器的总迭代次数

numIterationsPerEpoch = ceil(numObservations / miniBatchSize);numIterations = nummepochs * numIterationsPerEpoch;

初始化TrainingProgressMonitor对象。因为计时器在创建监视器对象时开始,所以请确保创建的对象接近训练循环。

monitor = trainingProgressMonitor(...指标=“损失”...信息=“时代”...包含=“迭代”);

训练模型。

Epoch = 0;迭代= 0;epoch < numEpochs && ~monitor。停止epoch = epoch + 1;% Shuffle数据。洗牌(兆贝可)在小批上循环。Hasdata (mbq) && ~monitor。停止迭代=迭代+ 1;[X,T1,T2] = next(mbq);评估模型损耗、梯度和状态使用% dlfeval和modelLoss函数。[loss,gradients,state] = dlfeval(@modelLoss,net,X,T1,T2);网状态=状态;使用Adam优化器更新网络参数。。[net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients,...trailingAvg trailingAvgSq,迭代);更新培训进度监视器。recordMetrics(监控、迭代损失=损失);updateInfo(监视、时代=时代+“的”+ numEpochs);班长。进度= 100*iteration/numIterations;结束结束

测试模型

通过将测试集上的预测结果与真实标签和角度进行比较,测试模型的分类精度。方法管理测试数据集minibatchqueue对象使用与训练数据相同的设置。

[XTest,T1Test,T2Test] = digitTest4DArrayData;dsXTest = arrayDatastore(XTest,IterationDimension=4);dsT1Test = arrayDatastore(T1Test);dsT2Test = arrayDatastore(T2Test);dsTest = combine(dsXTest,dsT1Test,dsT2Test);mbqTest = minibatchqueue(dsTest...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessData,...MiniBatchFormat = [“SSCB”""""]);

要预测验证数据的标签和角度,请遍历小批并使用预测函数。存储预测的类和角度。比较预测和真实的类和角度,并存储结果。

classesforecasts = [];anglesforecasts = [];classCorr = [];angleDiff = [];在小批上循环。hasdata (mbqTest)读取小批数据。[X,T1,T2] = next(mbqTest);使用预测函数进行预测。[Y1,Y2] = predict(net,X,Outputs=[“softmax”“fc_2”]);确定预测的类。Y1 = onehotdecode(Y1,classNames,1);classesforecasts = [classesforecasts Y1];% Dermine预测角度Y2 = extractdata(Y2);anglesforecasts = [anglesforecasts Y2];比较预测的和真实的类。T1 = onehotdecode(T1,classNames,1);classCorr = [classCorr Y1 == T1];比较预测角度和真实角度。angleDiffBatch = Y2 - T2;angleDiffBatch = extractdata(gather(angleDiffBatch));angleDiff = [angleDiff angleDiffBatch];结束

评估分类准确率。

精确度=平均值(classCorr)
准确度= 0.9882

评估回归精度。

angleRMSE =√(mean(angleDiff.^2))
angleRMSE =6.3569

查看一些带有预测的图片。红色显示预测角度,绿色显示正确标签。

idx = randperm(size(XTest,4),9);数字i = 1:9 subplot(3,3,i) i = XTest(:,:,:,idx(i));imshow (I)sz = size(I,1);Offset = sz/2;thetaPred =角预测(idx(i));plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],“r——”) thetaValidation = T2Test(idx(i));plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],,“g——”)举行label = string(classesforecasts (idx(i)));标题(”的标签:“+标签)结束

图中包含9个轴对象。轴对象1的标题标签:8包含3个对象类型的图像,线。轴对象2带有标题标签:2包含3个类型为image, line的对象。轴对象3带有标题标签:7包含3个类型为image, line的对象。轴对象4的标题标签:1包含3个对象类型的图像,线。轴对象5带有标题标签:3包含3个类型为image, line的对象。轴对象6的标题标签:0包含3个类型为image, line的对象。轴对象7与标题标签:6包含3个对象类型的图像,线。轴对象8与标题标签:2包含3个对象类型的图像,线。轴对象9与标题标签:2包含3个对象类型的图像,线。

模型损失函数

modelLoss函数作为输入dlnetwork对象,一小批输入数据X有相应的目标T1而且T2分别包含标签和角度,并返回损失、损失相对于可学习参数的梯度以及更新的网络状态。

函数[loss,gradients,state] = modelLoss(net,X,T1,T2) [Y1,Y2,state] = forward(net,X,Outputs=[“softmax”“fc_2”]);lossLabels = crossentropy(Y1,T1);lossAngles = mse(Y2,T2);loss = lossLabels + 0.1*lossAngles;gradients = dlgradient(loss,net.Learnables);结束

小批量预处理功能

preprocessMiniBatch函数按照以下步骤对数据进行预处理:

  1. 从传入单元格数组中提取图像数据并连接到数值数组中。将图像数据连接到第四个维度将为每个图像添加第三个维度,用作单通道维度。

  2. 从传入单元格数组中提取标签和角度数据,并沿着第二维分别连接到分类数组和数值数组。

  3. One-hot将分类标签编码为数字数组。编码到第一个维度会产生一个与网络输出形状匹配的编码数组。

函数[X,T1,T2] = preprocessData(dataX,dataT1,dataT2)从单元格和拼接中提取图像数据X = cat(4,dataX{:});从单元格和级联中提取标签数据T1 = cat(2,dataT1{:});从单元格和拼接中提取角度数据T2 = cat(2,dataT2{:});单热编码标签T1 = onehotencode(T1,1);结束

另请参阅

|||||||||||

相关的话题