Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム学習ループでのバッチ正規化統計量の更新

この例では、カスタム学習ループでネットワークの状態を更新する方法を示します。

バッチ正規化層は、ミニバッチ全体で各入力チャネルを正規化します。畳み込みニューラル ネットワークの学習速度を上げ、ネットワークの初期化に対する感度を下げるには、畳み込み層の間にあるバッチ正規化層と、ReLU 層などの非線形性を使用します。

学習中、バッチ正規化層は、まず、ミニバッチの平均を減算し、ミニバッチの標準偏差で除算することにより、各チャネルの活性化を正規化します。その後、この層は、学習可能なオフセットβだけ入力をシフトし、それを学習可能なスケール係数γだけスケーリングします。

ネットワークの学習が終了したら、バッチ正規化層は学習セット全体の平均と分散を計算し、その値をTrainedMeanプロパティおよびTrainedVarianceプロパティに格納します。学習済みネットワークを使用して新しいイメージについて予測を実行する場合、バッチ正規化層はミニバッチの平均と分散ではなく、学習済みの平均と分散を使用して活性化を正規化します。

データセットの統計量を計算するために、バッチ正規化層は継続的に更新される状態を使用してミニバッチの統計量を追跡します。カスタム学習ループを実装している場合、ミニバッチ間でネットワークの状態を更新しなければなりません。

学習データの読み込み

関数digitTrain4DArrayDataは、手書き数字のイメージとその数字ラベルを読み込みます。イメージと角度についてarrayDatastoreオブジェクトを作成してから、関数combineを使用してすべての学習データを含む単一のデータストアを作成します。クラス名を抽出します。

[XTrain,YTrain] = digitTrain4DArrayData; dsXTrain = arrayDatastore(XTrain,'IterationDimension',4); dsYTrain = arrayDatastore(YTrain); dsTrain = combine(dsXTrain,dsYTrain); classNames = categories(YTrain); numClasses = numel(classNames);

ネットワークの定義

ネットワークを定義し、イメージ入力層で'Mean'オプションを使用して平均イメージを指定します。

layers = [ imageInputLayer([28 28 1],'Name','input','Mean', mean(XTrain,4)) convolution2dLayer(5, 20,'Name','conv1') batchNormalizationLayer('Name','bn1') reluLayer('Name','relu1') convolution2dLayer(3, 20,“爸爸dding', 1,'Name','conv2') batchNormalizationLayer('Name','bn2') reluLayer('Name','relu2') convolution2dLayer(3, 20,“爸爸dding', 1,'Name','conv3') batchNormalizationLayer('Name','bn3') reluLayer('Name','relu3') fullyConnectedLayer(numClasses,'Name','fc') softmaxLayer('Name','softmax')]; lgraph = layerGraph(layers);

層グラフからdlnetworkオブジェクトを作成します。

dlnet = dlnetwork(lgraph)
dlnet = dlnetwork with properties: Layers: [12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'input'} OutputNames: {'softmax'}

ネットワークの状態を表示します。各バッチ正規化層は、データセットの平均と分散をぞれぞれ含む、TrainedMeanパラメーターとTrainedVarianceパラメーターをもちます。

dlnet.State
ans=6×3 tableLayer Parameter Value _____ _________________ _______________ "bn1" "TrainedMean" {1×1×20 single} "bn1" "TrainedVariance" {1×1×20 single} "bn2" "TrainedMean" {1×1×20 single} "bn2" "TrainedVariance" {1×1×20 single} "bn3" "TrainedMean" {1×1×20 single} "bn3" "TrainedVariance" {1×1×20 single}

モデル勾配関数の定義

この例の最後にリストされている関数modelGradientsを作成します。この関数はdlnetworkオブジェクトdlnet、入力データdlXのミニバッチとそれに対応するラベルYを入力として受け取り、dlnetにおける学習可能なパラメーターについての損失の勾配、および対応する損失を返します。

学習オプションの指定

ミニバッチ サイズを 128 として、学習を 5 エポック行います。SGDM 最適化では、学習率に 0.01、モーメンタムに 0.9 を指定します。

numEpochs = 5;miniBatchSize = 128;learnRate = 0。01; momentum = 0.9;

学習の進行状況をプロットに可視化します。

plots ="training-progress";

モデルの学習

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数preprocessMiniBatch(この例の最後に定義) を使用して、クラス ラベルを one-hot 符号化します。

  • イメージ データを次元ラベル'SSCB'(spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueueオブジェクトは、基となる型がsingledlarrayオブジェクトにデータを変換します。書式をクラス ラベルに追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueueオブジェクトは、GPU が利用可能な場合、各出力をgpuArrayに変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート(Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain,...'MiniBatchSize',miniBatchSize,...'MiniBatchFcn', @preprocessMiniBatch,...'MiniBatchFormat',{'SSCB',''});

カスタム学習ループを使用してモデルに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各ミニバッチで次を行います。

  • dlfevalと関数modelGradientsを使用してモデルの勾配、状態、および損失を評価し、ネットワークの状態を更新。

  • 関数sgdmupdateを使用してネットワーク パラメーターを更新。

学習の進行状況プロットを初期化します。

ifplots =="training-progress"figure lineLossTrain = animatedline('Color',[0.85 0.325 0.098]); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") gridonend

SGDM ソルバーの速度パラメーターを初期化します。

velocity = [];

ネットワークに学習をさせます。

iteration = 0; start = tic;% Loop over epochs.forepoch = 1:numEpochs% Shuffle data.shuffle(mbq)% Loop over mini-batches.whilehasdata(mbq) iteration = iteration + 1;% Read mini-batch of data and convert the labels to dummy% variables.[dlX,dlY] = next(mbq);% Evaluate the model gradients, state, and loss using dlfeval and the% modelGradients function and update the network state.[gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY); dlnet.State = state;% Update the network parameters using the SGDM optimizer.[dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);% Display the training progress.ifplots =="training-progress"D = duration(0,0,toc(start),'Format','hh:mm:ss'); addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) title("Epoch: "+ epoch +", Elapsed: "+ string(D)) drawnowendendend

モデルのテスト

真のラベルと角度をもつテスト セットで予測を比較して、モデルの分類精度をテストします。学習データと同じ設定のminibatchqueueオブジェクトを使用して、テスト データ セットを管理します。

[XTest,YTest] = digitTest4DArrayData; dsXTest = arrayDatastore(XTest,'IterationDimension',4); dsYTest = arrayDatastore(YTest); dsTest = combine(dsXTest,dsYTest); mbqTest = minibatchqueue(dsTest,...'MiniBatchSize',miniBatchSize,...'MiniBatchFcn', @preprocessMiniBatch,...'MiniBatchFormat',{'SSCB',''});

例の最後にリストされている関数modelPredictionsを使用してイメージを分類します。この関数は、予測されたクラス、および真の値との比較を返します。

[classesPredictions,classCorr] = modelPredictions(dlnet,mbqTest,classNames);

分類精度を評価します。

accuracy = mean(classCorr)
accuracy = 0.9946

モデル勾配関数

関数modelGradientsは、dlnetworkオブジェクトdlnet、入力データdlXのミニバッチとそれに対応するラベルYを入力として受け取り、dlnetにおける学習可能なパラメーターについての損失の勾配、ネットワークの状態、および損失を返します。勾配を自動的に計算するには、関数dlgradientを使用します。

function[gradients,state,loss] = modelGradients(dlnet,dlX,Y) [dlYPred,state] = forward(dlnet,dlX); loss = crossentropy(dlYPred,Y); gradients = dlgradient(loss,dlnet.Learnables);end

モデル予測関数

関数modelPredictionsは、入力としてdlnetworkオブジェクトdlnet、入力データmbqminibatchqueueを受け取り、minibatchqueueのすべてのデータを反復処理することでモデル予測を計算します。この関数は、関数onehotdecodeを使用して、スコアが最も高い予測されたクラスを見つけ、その予測を真のクラスと比較します。この関数は、予測および予測の正誤を表す 0 と 1 のベクトルを返します。

function[classesPredictions,classCorr] = modelPredictions(dlnet,mbq,classes) classesPredictions = []; classCorr = [];whilehasdata(mbq) [dlX,dlY] = next(mbq);% Make predictions using the model function.dlYPred = predict(dlnet,dlX);% Determine predicted classes.YPredBatch = onehotdecode(dlYPred,classes,1); classesPredictions = [classesPredictions YPredBatch];% Compare predicted and true classesY = onehotdecode(dlY,classes,1); classCorr = [classCorr YPredBatch == Y];endend

ミニ バッチ前処理関数

関数preprocessMiniBatchは、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

  2. 入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function[X,Y] = preprocessMiniBatch(XCell,YCell)% Extract image data from cell and concatenateX = cat(4,XCell{:});%从细胞中提取标签数据和连接Y = cat(2,YCell{:});% One-hot encode labelsY = onehotencode(Y,1);end

参考

|||||||||

関連するトピック