主要内容

将分类网络转换为回归网络

这个例子展示了如何将一个训练好的分类网络转换成一个回归网络。

预先训练的图像分类网络已经在超过100万张图像上进行了训练,可以将图像分为1000个对象类别,例如键盘、咖啡杯、铅笔和许多动物。该网络已经学习了广泛图像的丰富特征表示。该网络将图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。

迁移学习是深度学习应用中常用的一种方法。你可以使用预先训练好的网络,并将其作为学习新任务的起点。这个例子展示了如何使用预训练的分类网络,并为回归任务重新训练它。

该示例加载了一个预先训练好的卷积神经网络架构进行分类,替换了分类的层,并重新训练网络来预测旋转手写数字的角度。您可以选择使用imrotate(图像处理工具箱™)校正图像旋转使用的预测值。

负荷预训练网络

从支持文件中加载预训练的网络金宝appdigitsNet.mat.该文件包含一个分类网络,用于对手写数字进行分类。

负载digitsNet层=净。层
图层数组= 15x14 'relu_1' ReLU ReLU 5 'maxpool_1' 2- d Max Pooling 2x2 Max Pooling with stride [2 2] and padding [0 0 0 0 0] 6 'conv_2' 2- d Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch normalization Batch normalization with 16 channel 8'relu_2' ReLU ReLU 9 'maxpool_2'二维最大池化2x2最大池化与stride[2 2]和填充[0 0 0 0 0]10 'conv_3'二维卷积32 3x3x16卷积与stride[1 1]和填充'相同' 11 'batchnorm_3'批归一化批归一化32通道12 'relu_3' ReLU ReLU 13 'fc'全连接10全连接层14 'softmax' softmax softmax 15 'classoutput'分类输出crossentropyex与'0'和9个其他类

加载数据

数据集包含手写数字的合成图像以及每个图像旋转的相应角度(以度为单位)。

将训练图像和验证图像加载为4-D数组digitTrain4DArrayData而且digitTest4DArrayData.输出YTrain而且YValidation是旋转角度,单位是度。训练和验证数据集各包含5000张图像。

[XTrain,~,YTrain] = digitTrain4DArrayData;[XValidation,~,YValidation] = digitTest4DArrayData;

显示20个随机训练图像使用imshow

numTrainImages =数字(YTrain);figure idx = randperm(numTrainImages,20);i = 1:元素个数(idx)次要情节(4、5、i) imshow (XTrain (:,:,:, idx(我)))结束

图中包含20个轴对象。坐标轴对象1包含一个image类型的对象。坐标轴对象2包含一个image类型的对象。坐标轴对象3包含一个image类型的对象。Axes对象4包含一个image类型的对象。Axes对象5包含一个image类型的对象。Axes对象6包含一个image类型的对象。Axes对象7包含一个image类型的对象。Axes对象8包含一个image类型的对象。Axes对象9包含一个image类型的对象。 Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

更换最终图层

网络的卷积层提取图像特征,最后的可学习层和最终的分类层使用这些特征对输入图像进行分类。这两层,“俱乐部”而且“classoutput”digitsNet,包含了如何将网络提取的特征组合成类概率、损失值和预测标签的信息。为了重新训练一个预先训练好的网络进行回归,用适应任务的新层替换这两个层。

将最后的全连接层、softmax层和分类输出层替换为大小为1(响应数量)的全连接层和回归层。

numResponses = 1;layers = [layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];

冻结初始层

网络现在已经准备好接受新数据的重新训练。可以选择,通过将这些层的学习率设置为零,可以“冻结”网络中较早层的权重。在培训期间,trainNetwork不更新冻结层的参数。由于冻结层的梯度不需要计算,冻结许多初始层的权值可以显著加快网络训练。如果新的数据集很小,那么冻结早期的网络层也可以防止这些层过度拟合到新的数据集。

使用支持函数金宝appfreezeWeights在前12层中设置学习率为零。

layers(1:12) = freezeWeights(layers(1:12));

列车网络的

创建网络培训选项。将初始学习速率设置为0.001。在培训过程中通过指定验证数据监控网络的准确性。打开训练进度图,并关闭命令窗口输出。

选项= trainingOptions(“个”...“InitialLearnRate”, 0.001,...“ValidationData”{XValidation, YValidation},...“阴谋”“训练进步”...“详细”、假);

使用以下命令创建网络trainNetwork.如果可用,该命令使用兼容的图形处理器。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱).否则,trainNetwork占用CPU。

net = trainNetwork(XTrain,YTrain,图层,选项);

{

测试网络

通过评估验证数据的准确性来测试网络的性能。

使用预测预测验证图像的旋转角度。

YPred = predict(net,XValidation);

通过计算来评估模型的性能:

  1. 在可接受的误差范围内预测的百分比

  2. 预测和实际旋转角度的均方根误差(RMSE)

计算预测角度与实际角度之间的预测误差。

predictionError = YValidation - YPred;

从真实角度计算在可接受的误差范围内的预测数。设置阈值为10度。计算该阈值内预测的百分比。

THR = 10;numCorrect = sum(abs(predictionError) < thr);numImagesValidation = numel(YValidation);accuracy = numCorrect/numImagesValidation
准确度= 0.7532

使用均方根误差(RMSE)来测量预测和实际旋转角度之间的差异。

rmse =√(mean(predictionError.^2))
rmse =9.0270

正确的数字旋转

可以使用“图像处理工具箱”中的函数将数字拉直并一起显示。旋转49个样本数字根据他们的预测旋转角度使用imrotate(图像处理工具箱)。

idx = randperm(numImagesValidation,49);i = 1:numel(idx) i = XValidation(:,:,:,idx(i));Y = YPred(idx(i));XValidationCorrected(:,:,:,i) = imrotate(i,Y,“双三次的”“作物”);结束

显示原始数字及其正确的旋转。使用蒙太奇(图像处理工具箱)显示数字一起在一个单一的图像。

图subplot(1,2,1)蒙太奇(XValidation(:,:,:,idx))标题(“原始”) subplot(1,2,2)蒙太奇(XValidationCorrected)“纠正”

图中包含2个轴对象。标题为Original的坐标轴对象1包含一个image类型的对象。标题为Corrected的坐标轴对象2包含一个image类型的对象。

另请参阅

|

相关的话题