主要内容

使用模型函数进行预测

这个例子展示了如何通过将数据分割成小批量来使用模型函数进行预测。

对于大型数据集,或者在内存有限的硬件上进行预测时,可以通过将数据分割为小批量来进行预测。当用SeriesNetworkDAGNetwork对象时,预测函数自动将输入数据分割为小批量。对于模型函数,必须手动将数据分割为小批量。

创建模型函数和负载参数

从MAT文件加载模型参数digitsMIMO.mat.MAT文件将模型参数包含在名为参数的结构中的模型状态状态的类名一会

S =负载(“digitsMIMO.mat”);参数= s.parameters;State = s.state;classNames = s.classNames;

模型函数模型,在示例的末尾列出,它定义了给定模型参数和状态的模型。

预测负荷数据

加载数字数据进行预测。

digitDatasetPath = fullfile(matlabroot,“工具箱”“nnet”“nndemos”...“nndatasets”“DigitDataset”);imds = imageDatastore(digitDatasetPath,...“IncludeSubfolders”,真的,...“LabelSource”“foldernames”);numObservations = numel(imds.Files);

作出预测

遍历测试数据的小批,并使用自定义预测循环进行预测。

使用minibatchqueue处理和管理小批量的图像。指定迷你批处理大小为128。将映像数据存储的read size属性设置为迷你批处理大小。

对于每个小批量:

  • 使用自定义小批量预处理功能preprocessMiniBatch(在本例结束时定义)将数据连接到一个批处理中并规范化图像。

  • 用尺寸格式化图像“SSCB”(空间,空间,通道,批次)。默认情况下,minibatchqueue对象将数据转换为dlarray具有基础类型的对象

  • 在可用的GPU上进行预测。默认情况下,minibatchqueue对象将输出转换为gpuArray如果GPU可用。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱)

miniBatchSize = 128;洛桑国际管理发展学院。ReadSize = miniBatchSize;MBQ = minibatchqueue(imds,...“MiniBatchSize”miniBatchSize,...“MiniBatchFcn”@preprocessMiniBatch,...“MiniBatchFormat”“SSCB”);

对小批数据进行循环,并使用预测函数。使用onehotdecode函数确定类标签。存储预测的类标签。

doTraining = false;y1forecasts = [];y2forecasts = [];在小批上循环。hasdata(兆贝可)读取小批数据。dlX = next(mbq);使用预测函数进行预测。[dlY1Pred,dlY2Pred] = model(parameters,dlX,doTraining,state);确定相应的类。Y1PredBatch = onehotdecode(dlY1Pred,classNames,1);y1forecasts = [y1forecasts Y1PredBatch];Y2PredBatch = extractdata(dlY2Pred);y2forecasts = [y2forecasts Y2PredBatch];结束

查看一些带有预测的图片。

idx = randperm(numObservations,9);数字i = 1:9 subplot(3,3,i) i = imread(imds.Files{idx(i)});imshow (I)sz = size(I,1);Offset = sz/2;thetaPred = Y2Predictions(idx(i));plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],“r——”)举行label = string(y1forecasts (idx(i)));标题(”的标签:“+标签)结束

模型函数

这个函数模型取模型参数参数,输入数据dlX,旗帜doTraining它指定了模型是否应该返回用于训练或预测的输出,以及网络状态状态.网络输出标签的预测、角度的预测和更新的网络状态。

函数[dlY1,dlY2,state] = model(parameters,dlX,doTraining,state)%卷积weights = parameters.conv1.Weights;bias = parameters.conv1.Bias;dlY = dlconv(dlX,权重,偏差,“填充”“相同”);批处理归一化,ReLUoffset = parameters.batchnorm1.Offset;scale = parameters.batchnorm1.Scale;trainedMean = state.batchnorm1.TrainedMean;trainedVariance = state.batchnorm1.TrainedVariance;如果doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm1。受过训练的人;state.batchnorm1。trained方差= trained方差;其他的dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);结束dlY = relu(dlY);%卷积,批量归一化(跳过连接)weights = parameters.convSkip.Weights;bias = parameters.convSkip.Bias;dlYSkip = dlconv(dlY,权重,偏差,“步”2);offset = parameters.batchnormSkip.Offset;scale = parameters.batchnormSkip.Scale;trainedMean = state.batchnormSkip.TrainedMean;trainedVariance = state.batchnormSkip.TrainedVariance;如果doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnormSkip.TrainedMean = trainedMean;state.batchnormSkip.TrainedVariance = trainedVariance;其他的dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance);结束%卷积weights = parameters.conv2.Weights;bias = parameters.conv2.Bias;dlY = dlconv(dlY,权重,偏差,“填充”“相同”“步”2);批处理归一化,ReLUoffset = parameters.batchnorm2.Offset;scale = parameters.batchnorm2.Scale;trainedMean = state.batchnorm2.TrainedMean;trainedVariance = state.batchnorm2.TrainedVariance;如果doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm2。受过训练的人;state.batchnorm2。trained方差= trained方差;其他的dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);结束dlY = relu(dlY);%卷积weights = parameters.conv3.Weights;bias = parameters.conv3.Bias;dlY = dlconv(dlY,权重,偏差,“填充”“相同”);批归一化offset = parameters.batchnorm3.Offset;scale = parameters.batchnorm3.Scale;trainedMean = state.batchnorm3.TrainedMean;trainedVariance = state.batchnorm3.TrainedVariance;如果doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);%更新状态state.batchnorm3。受过训练的人;state.batchnorm3。trained方差= trained方差;其他的dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance);结束%加法,ReLUdlY = dlYSkip + dlY;dlY = relu(dlY);%完全连接,softmax(标签)weights = parameters.fc1.Weights;bias = parameters.fc1.Bias;dlY1 =完全连接(dlY,权重,偏差);dlY1 = softmax(dlY1);%完全连接(角度)weights = parameters.fc2.Weights;bias = parameters.fc2.Bias;dlY2 =完全连接(dlY,权重,偏差);结束

小批量预处理功能

preprocessMiniBatch函数按照以下步骤对数据进行预处理:

  1. 从传入单元格数组中提取数据并连接到数值数组中。在第四个维度上的连接为每个图像添加了第三个维度,用作单通道维度。

  2. 规范化之间的像素值0而且1

函数X = preprocessMiniBatch(数据)从单元格和拼接中提取图像数据X = cat(4,data{:});将图像规范化。X = X/255;结束

另请参阅

||||||||||

相关的话题