主要内容

训练卷积神经网络用于回归

这个例子展示了如何使用卷积神经网络拟合回归模型来预测手写数字的旋转角度。

卷积神经网络(cnn或ConvNets)是深度学习的重要工具,特别适合于分析图像数据。例如,您可以使用cnn对图像进行分类。要预测连续数据(如角度和距离),可以在网络的末端包含一个回归层。

该实例构造了一个卷积神经网络结构,训练了一个网络,并使用训练的网络来预测旋转后的手写数字的角度。这些预测对光学字符识别是有用的。

您可以选择使用imrotate(图像处理工具箱™)以旋转图像箱线图(Statistics and Machine Learning Toolbox™)创建残差箱图。

加载数据

数据集包含手写数字的合成图像以及相应的旋转角度(以角度表示)。

将训练和验证图像加载为4-D数组digitTrain4DArrayDatadigitTest4DArrayData.输出YTrainYValidation是旋转角度的度数。每个训练和验证数据集包含5000张图像。

[XTrain ~, YTrain] = digitTrain4DArrayData;[XValidation ~, YValidation] = digitTest4DArrayData;

显示20个随机训练图像使用imshow

numTrainImages =元素个数(YTrain);图idx = randperm(numTrainImages,20);i = 1:元素个数(idx)次要情节(4、5、i) imshow (XTrain (:,:,:, idx(我)))结束

图中包含20个轴对象。axis对象1包含一个image类型的对象。axis对象2包含一个image类型的对象。axis对象3包含一个image类型的对象。axis对象4包含一个image类型的对象。axis对象5包含一个类型为image的对象。axis对象6包含一个image类型的对象。axis对象7包含一个image类型的对象。axis对象8包含一个image类型的对象。axis对象9包含一个image类型的对象。 Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

检查数据归一化

在训练神经网络时,确保你的数据在网络的所有阶段都是标准化的通常是有帮助的。归一化有助于使用梯度下降来稳定和加速网络训练。如果您的数据缩放不当,则可能会造成损失网络参数在训练过程中会出现发散。标准化数据的常用方法包括对数据进行缩放,使其范围变为[0,1],或者使其均值为0,标准差为1。可以对以下数据进行标准化处理:

  • 输入数据。在将预测器输入网络之前将其归一化。在本例中,输入图像已经归一化到范围[0,1]。

  • 层输出。您可以使用批处理归一化层对每个卷积和完全连接层的输出进行归一化。

  • 响应。如果使用批处理归一化层对网络末端的层输出进行归一化,则在训练开始时对网络的预测进行归一化。如果响应的规模与这些预测非常不同,那么网络训练可能无法收敛。如果你的反应规模不大,那么试着将其正常化,看看网络训练是否有所改善。如果在训练前对响应进行归一化,则必须转换训练网络的预测,以获得原始响应的预测。

绘制响应的分布。响应(旋转角度)在-45和45之间近似均匀分布,不需要标准化就能很好地工作。在分类问题中,输出是类的概率,总是被归一化。

图直方图(YTrain)轴ylabel (“计数”)包含(旋转角度的

图中包含一个轴对象。坐标轴对象包含一个直方图类型的对象。

通常,数据不必完全规范化。但是,如果你训练这个例子中的网络去预测100 * YTrainYTrain + 500而不是YTrain,那么损失就变成了训练开始时,网络参数出现发散。这些结果即使发生的唯一不同的网络预测aY + b还有一个预测网络Y是最终完全连接层的权重和偏差的简单缩放。

如果输入或响应的分布非常不均匀或歪斜,您还可以在训练网络之前对数据进行非线性转换(例如,取对数)。

创建网络层

为了解决回归问题,创建网络的层,并在网络的末端包括一个回归层。

第一层定义输入数据的大小和类型。输入的图像是28 × 28 × 1。创建与训练图像大小相同的图像输入层。

网络的中间层定义了网络的核心架构,大部分的计算和学习都发生在中间层。

最后一层定义输出数据的大小和类型。对于回归问题,一个完全连接的层必须先于网络末端的回归层。创建一个大小为1的全连接输出层和一个回归层。

将所有的层合并在一起数组中。

layers = [imageInputLayer([28 28 1])]“填充”“相同”reluLayer averageepooling2dlayer (2,“步”2) convolution2dLayer(16日“填充”“相同”reluLayer averageepooling2dlayer (2,“步”32岁的,2)convolution2dLayer (3“填充”“相同”) batchNormalizationLayer reluLayer卷积2dlayer (3,32,“填充”“相同”) batchNormalizationLayer relullayer dropoutLayer(0.2) fulllyconnectedlayer (1) regressionLayer];

列车网络的

创建网络培训选项。训练30个时代。初始学习率设置为0.001,20个周期后降低学习率。在培训期间,通过指定验证数据和验证频率来监控网络的准确性。该软件在训练数据上对网络进行训练,并在训练期间定期计算验证数据的准确性。验证数据不用于更新网络权重。打开训练进度图,并关闭命令窗口输出。

miniBatchSize = 128;validationFrequency =地板(元素个数(YTrain) / miniBatchSize);选择= trainingOptions (“个”...“MiniBatchSize”miniBatchSize,...“MaxEpochs”30岁的...“InitialLearnRate”1 e - 3,...“LearnRateSchedule”“分段”...“LearnRateDropFactor”, 0.1,...“LearnRateDropPeriod”, 20岁,...“洗牌”“every-epoch”...“ValidationData”{XValidation, YValidation},...“ValidationFrequency”validationFrequency,...“阴谋”“训练进步”...“详细”、假);

使用以下命令创建网络trainNetwork.该命令使用兼容的GPU(如果可用)。使用GPU需要并行计算工具箱™和支持的GPU设备。金宝app有关支持的设备的信息,请参见金宝appGPU支金宝app持情况(并行计算工具箱).否则,trainNetwork使用CPU。

网= trainNetwork (XTrain、YTrain层,选择);

Figure Training Progress (01-Sep-2021 08:26:26)包含2个轴对象和另一个类型为uigridlayout的对象。axis对象1包含10个类型为patch, text, line的对象。axis对象2包含10个类型为patch, text, line的对象。

中包含的网络架构的详细信息的属性

网。层
ans = 18x1图层数组:1“imageinput”28 x28x1图像输入图像与“zerocenter”正常化2 conv_1的卷积8 3 x3x1旋转步[1]和填充“相同”3“batchnorm_1”批量标准化批量标准化8通道4的relu_1 ReLU ReLU 5“avgpool2d_1”平均池2 x2平均池步(2 - 2)和填充[0 0 0 0]6“conv_2”Convolution 16 3x3x8 convolutions with stride [1 1] and padding 'same' 7 'batchnorm_2' Batch Normalization Batch normalization with 16 channels 8 'relu_2' ReLU ReLU 9 'avgpool2d_2' Average Pooling 2x2 average pooling with stride [2 2] and padding [0 0 0 0] 10 'conv_3' Convolution 32 3x3x16 convolutions with stride [1 1] and padding 'same' 11 'batchnorm_3' Batch Normalization Batch normalization with 32 channels 12 'relu_3' ReLU ReLU 13 'conv_4' Convolution 32 3x3x32 convolutions with stride [1 1] and padding 'same' 14 'batchnorm_4' Batch Normalization Batch normalization with 32 channels 15 'relu_4' ReLU ReLU 16 'dropout' Dropout 20% dropout 17 'fc' Fully Connected 1 fully connected layer 18 'regressionoutput' Regression Output mean-squared-error with response 'Response'

测试网络

通过评估验证数据的准确性来测试网络的性能。

使用预测来预测验证图像的旋转角度。

YPredicted =预测(净,XValidation);

评估性能

通过计算来评估模型的性能:

  1. 在可接受的误差范围内预测的百分比

  2. 预测和实际旋转角度的均方根误差(RMSE)

计算预测与实际转角之间的预测误差。

predictionError = YValidation - YPredicted;

从真实的角度计算在可接受的误差范围内的预测数量。设置阈值为10度。计算在这个阈值内的预测的百分比。

用力推= 10;numCorrect = sum(abs(predictionError) < thr);numValidationImages =元素个数(YValidation);= numCorrect / numValidationImages准确性
精度= 0.9716

使用均方根误差(RMSE)来测量预测的和实际的旋转角度之间的差异。

广场= predictionError。^ 2;rmse =√意味着(广场))
rmse =4.5505

可视化预测

在散点图中可视化预测。将预测值与真实值作对比。

图散射(YPredicted YValidation,“+”)包含(“预测价值”) ylabel (“真正的价值”)举行Plot ([-60 60], [-60 60],“r——”

图中包含一个轴对象。坐标轴对象包含两个散点和直线类型的对象。

正确的数字旋转

您可以使用图像处理工具箱中的函数来调整数字并将它们一起显示。旋转49个样本数字根据他们的预测旋转角度使用imrotate(图像处理工具箱)。

idx = randperm (numValidationImages, 49);i = 1:元素个数(idx)图像= XValidation (:,:,:, idx(我));predictedAngle = YPredicted (idx (i));imagesRotated (::,:, i) = imrotate(形象,predictedAngle,“双三次的”“作物”);结束

显示原始数字及其校正旋转。您可以使用蒙太奇(图像处理工具箱)以在单一图像中同时显示数字。

figure subplot(1,2,1) montage(XValidation(:,:,:,idx)) title(“原始”) subplot(1,2,2)蒙太奇(imagesrotate) title(“纠正”

图中包含2个轴对象。标题为Original的轴对象1包含一个类型为image的对象。标题为Corrected的轴对象2包含一个类型为image的对象。

另请参阅

|

相关的话题