将分类网络转换为回归网络
这个例子展示了如何将一个训练好的分类网络转换成一个回归网络。
预先训练的图像分类网络已经在超过100万张图像上进行了训练,可以将图像分为1000个对象类别,例如键盘、咖啡杯、铅笔和许多动物。该网络已经学习了广泛图像的丰富特征表示。该网络将图像作为输入,然后输出图像中对象的标签以及每个对象类别的概率。
迁移学习是深度学习应用中常用的一种方法。你可以使用预先训练好的网络,并将其作为学习新任务的起点。这个例子展示了如何使用预训练的分类网络,并为回归任务重新训练它。
该示例加载了一个预先训练好的卷积神经网络架构进行分类,替换了分类的层,并重新训练网络来预测旋转手写数字的角度。您可以选择使用imrotate
(图像处理工具箱™)校正图像旋转使用的预测值。
负荷预训练网络
从支持文件中加载预训练的网络金宝appdigitsNet.mat
.该文件包含一个分类网络,用于对手写数字进行分类。
负载digitsNet层=净。层
图层数组= 15x14 'relu_1' ReLU ReLU 5 'maxpool_1' 2- d Max Pooling 2x2 Max Pooling with stride [2 2] and padding [0 0 0 0 0] 6 'conv_2' 2- d Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch normalization Batch normalization with 16 channel 8'relu_2' ReLU ReLU 9 'maxpool_2'二维最大池化2x2最大池化与stride[2 2]和填充[0 0 0 0 0]10 'conv_3'二维卷积32 3x3x16卷积与stride[1 1]和填充'相同' 11 'batchnorm_3'批归一化批归一化32通道12 'relu_3' ReLU ReLU 13 'fc'全连接10全连接层14 'softmax' softmax softmax 15 'classoutput'分类输出crossentropyex与'0'和9个其他类
加载数据
数据集包含手写数字的合成图像以及每个图像旋转的相应角度(以度为单位)。
将训练图像和验证图像加载为4-D数组digitTrain4DArrayData
而且digitTest4DArrayData
.输出YTrain
而且YValidation
是旋转角度,单位是度。训练和验证数据集各包含5000张图像。
[XTrain,~,YTrain] = digitTrain4DArrayData;[XValidation,~,YValidation] = digitTest4DArrayData;
显示20个随机训练图像使用imshow
.
numTrainImages =数字(YTrain);figure idx = randperm(numTrainImages,20);为i = 1:元素个数(idx)次要情节(4、5、i) imshow (XTrain (:,:,:, idx(我)))结束
更换最终图层
网络的卷积层提取图像特征,最后的可学习层和最终的分类层使用这些特征对输入图像进行分类。这两层,“俱乐部”
而且“classoutput”
在digitsNet
,包含了如何将网络提取的特征组合成类概率、损失值和预测标签的信息。为了重新训练一个预先训练好的网络进行回归,用适应任务的新层替换这两个层。
将最后的全连接层、softmax层和分类输出层替换为大小为1(响应数量)的全连接层和回归层。
numResponses = 1;layers = [layers(1:12) fullyConnectedLayer(numResponses) regressionLayer];
冻结初始层
网络现在已经准备好接受新数据的重新训练。可以选择,通过将这些层的学习率设置为零,可以“冻结”网络中较早层的权重。在培训期间,trainNetwork
不更新冻结层的参数。由于冻结层的梯度不需要计算,冻结许多初始层的权值可以显著加快网络训练。如果新的数据集很小,那么冻结早期的网络层也可以防止这些层过度拟合到新的数据集。
使用支持函数金宝appfreezeWeights
在前12层中设置学习率为零。
layers(1:12) = freezeWeights(layers(1:12));
列车网络的
创建网络培训选项。将初始学习速率设置为0.001。在培训过程中通过指定验证数据监控网络的准确性。打开训练进度图,并关闭命令窗口输出。
选项= trainingOptions(“个”,...“InitialLearnRate”, 0.001,...“ValidationData”{XValidation, YValidation},...“阴谋”,“训练进步”,...“详细”、假);
使用以下命令创建网络trainNetwork
.如果可用,该命令使用兼容的图形处理器。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱).否则,trainNetwork
占用CPU。
net = trainNetwork(XTrain,YTrain,图层,选项);
测试网络
通过评估验证数据的准确性来测试网络的性能。
使用预测
预测验证图像的旋转角度。
YPred = predict(net,XValidation);
通过计算来评估模型的性能:
在可接受的误差范围内预测的百分比
预测和实际旋转角度的均方根误差(RMSE)
计算预测角度与实际角度之间的预测误差。
predictionError = YValidation - YPred;
从真实角度计算在可接受的误差范围内的预测数。设置阈值为10度。计算该阈值内预测的百分比。
THR = 10;numCorrect = sum(abs(predictionError) < thr);numImagesValidation = numel(YValidation);accuracy = numCorrect/numImagesValidation
准确度= 0.7532
使用均方根误差(RMSE)来测量预测和实际旋转角度之间的差异。
rmse =√(mean(predictionError.^2))
rmse =单9.0270
正确的数字旋转
可以使用“图像处理工具箱”中的函数将数字拉直并一起显示。旋转49个样本数字根据他们的预测旋转角度使用imrotate
(图像处理工具箱)。
idx = randperm(numImagesValidation,49);为i = 1:numel(idx) i = XValidation(:,:,:,idx(i));Y = YPred(idx(i));XValidationCorrected(:,:,:,i) = imrotate(i,Y,“双三次的”,“作物”);结束
显示原始数字及其正确的旋转。使用蒙太奇
(图像处理工具箱)显示数字一起在一个单一的图像。
图subplot(1,2,1)蒙太奇(XValidation(:,:,:,idx))标题(“原始”) subplot(1,2,2)蒙太奇(XValidationCorrected)“纠正”)
另请参阅
regressionLayer
|classificationLayer