Main Content

このページは前リリースの情報です。該当の英語のページはこのリリースで削除されています。

parfor を使用した複数の深層学習ネットワークの学習

この例では、parforループを使用して学習オプションについてパラメーター スイープを実行する方法を説明します。

深層学習には多くの場合、数時間または数日を要し、良好な学習オプションの探索は困難なことがあります。並列計算を使用して、良好なモデルの探索の高速化および自動化ができます。複数のグラフィックス処理装置 (GPU) を搭載したマシンにアクセスできる場合は、ローカルの parpool を使用して、データセットのローカル コピーでこの例を実行できます。さらに多くのリソースが必要な場合は、深層学習をクラウドにスケール アップできます。この例では、parfor ループを使用してクラウド上のクラスターにある学習オプションMiniBatchSizeについてパラメーター スイープを実行する方法を説明します。スクリプトを編集して、その他すべての学習オプションについてパラメーター スイープを実行できます。また、この例ではDataQueueを使用して計算中にワーカーからフィードバックを取得する方法も説明します。さらに、このスクリプトをバッチ ジョブとしてクラスターに投入すると、MATLAB での作業を続行したり、MATLAB を終了して後で結果を取得したりできます。詳細については、深層学習バッチ ジョブのクラスターへの送信(Deep Learning Toolbox)を参照してください。

要件

この例を実行するには、クラスターを構成し、データをクラウドにアップロードしなければなりません。MATLAB では、MATLAB デスクトップから直接クラウドにクラスターを作成できます。[ホーム]タブの[並列]メニューで[クラスターの作成と管理]を選択します。クラスター プロファイル マネージャーで、[クラウド クラスターの作成]をクリックします。あるいは、MathWorks Cloud Center を使用して、計算クラスターの作成およびアクセスができます。詳細については、Getting Started with Cloud Centerを参照してください。この例では、MATLAB の[ホーム]タブの[並列][既定のクラスターの選択]で、使用するクラスターを確実に既定として設定します。その後、Amazon S3 バケットにデータをアップロードし、MATLAB から直接使用します。この例では、既に Amazon S3 に保存されている CIFAR-10 データセットのコピーを使用します。手順については、クラウドへの深層学習データのアップロード(Deep Learning Toolbox)を参照してください。

クラウドからのデータセットのロード

imageDatastoreを使用して,クラウドから学習データセットおよびテスト データセットを読み込みます。学習データセットを学習セットと検証セットに分割し、パラメーター スイープから最良のネットワークをテストするためにテスト データセットを保持しておきます。この例では、Amazon S3 に保存されている CIFAR-10 データセットのコピーを使用します。クラウド内のデータ ストアへのアクセス権をワーカーが確実にもつように、AWS 認証情報の環境変数が正しく設定されていることを確認してください。クラウドへの深層学習データのアップロード(Deep Learning Toolbox)を参照してください。

imds = imageDatastore('s3://cifar10cloud/cifar10/train',...'IncludeSubfolders',true,...'LabelSource','foldernames'); imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test',...'IncludeSubfolders',true,...'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

augmentedImageDatastoreオブジェクトを作成し、拡張イメージ データを使用してネットワークに学習させます。ランダムな平行移動と水平方向の反射パターンを使用します。データ拡張は、ネットワークによる過適合と、学習イメージそのものの細部の記憶を防ぐ上で役立ちます。

imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter(...'RandXReflection',true,...'RandXTranslation',pixelRange,...'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain,...'DataAugmentation',imageAugmenter,...'OutputSizeMode','randcrop');

ネットワーク アーキテクチャの定義

CIFAR-10 データセットのネットワーク アーキテクチャを定義します。コードを簡略化するために、入力を畳み込む畳み込みブロックを使用します。プーリング層は空間次元をダウンサンプリングします。

imageSize = [32 32 3]; netDepth = 2;% netDepth controls the depth of a convolutional blocknetWidth = 16;% netWidth controls the number of filters in a convolutional blocklayers = [ imageInputLayer(imageSize) convolutionalBlock(netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(2*netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(4*netWidth,netDepth) averagePooling2dLayer(8) fullyConnectedLayer(10) softmaxLayer classificationLayer ];

複数のネットワークで同時に学習

パラメーター スイープの対象とするミニバッチ サイズを指定します。得られるネットワークと精度の変数を割り当てます。

miniBatchSizes = [64 128 256 512]; numMiniBatchSizes = numel(miniBatchSizes); trainedNetworks = cell(numMiniBatchSizes,1); accuracies = zeros(numMiniBatchSizes,1);

parforループ内でミニバッチ サイズを変化させて並列パラメーター スイープを実行し、複数のネットワークに学習させます。クラスター内のワーカーは同時に複数のネットワークに学習させ、学習の完了時に学習済みのネットワークと精度を返します。学習が機能していることを確認するには、学習オプションのVerbosetrueに設定します。ワーカーは独立して計算を行うため、コマンド ライン出力は反復の順番と同じにならないことに注意してください。

parforidx = 1:numMiniBatchSizes miniBatchSize = miniBatchSizes(idx); initialLearnRate = 1e-1 * miniBatchSize/256;% Scale the learning rate according to the mini-batch size.%定义培训选项。Set the mini-batch size.options = trainingOptions('sgdm',...'MiniBatchSize',miniBatchSize,...% Set the corresponding MiniBatchSize in the sweep.'Verbose',false,...% Do not send command line output.'InitialLearnRate',initialLearnRate,...% Set the scaled learning rate.'L2Regularization',1e-10,...“MaxEpochs”,30,...'Shuffle','every-epoch',...'ValidationData',imdsValidation,...'LearnRateSchedule','piecewise',...'LearnRateDropFactor',0.1,...'LearnRateDropPeriod',25);% Train the network in a worker in the cluster.net = trainNetwork(augmentedImdsTrain,layers,options);% To obtain the accuracy of this network, use the trained network to% classify the validation images on the worker and compare the predicted labels to the% actual labels.YPredicted = classify(net,imdsValidation); accuracies(idx) = sum(YPredicted == imdsValidation.Labels)/numel(imdsValidation.Labels);% Send the trained network back to the client.trainedNetworks{idx} = net;end
Starting parallel pool (parpool) using the 'MyClusterInTheCloud' profile ... Connected to the parallel pool (number of workers: 4).

parforの完了後のtrainedNetworksは、ワーカーによって学習させた結果のネットワークを含みます。学習済みネットワークおよびその精度を表示します。

trainedNetworks
trainedNetworks =4×1 cell array{1×1 SeriesNetwork} {1×1 SeriesNetwork} {1×1 SeriesNetwork} {1×1 SeriesNetwork}
accuracies
accuracies =4×10.8188 0.8232 0.8162 0.8050

精度が最良のネットワークを選択します。テスト データセットに対するそのパフォーマンスをテストします。

[~, I] = max(accuracies); bestNetwork = trainedNetworks{I(1)}; YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8173

学習中のフィードバック データの送信

各ワーカーの学習の進行状況を示すプロットを準備して初期化します。変化するデータを表示する便利な方法であるanimatedLineを使用します。

f = figure; f.Visible = true;fori=1:4 subplot(2,2,i) xlabel('Iteration'); ylabel('Training accuracy'); lines(i) = animatedline;end

DataQueueを使用して、ワーカーからクライアントに学習の進行状況データを送信してから、データをプロットします。afterEachを使用して、ワーカーから学習の進行状況のフィードバックが送信されるたびにプロットを更新します。パラメーターoptsはワーカー,学習反復,学習精度に関する情報を含みます。

D = parallel.pool.DataQueue; afterEach(D, @(opts) updatePlot(lines, opts{:}));

parfor ループ内で異なるミニバッチ サイズを使用して並列パラメーター スイープを実行し、複数のネットワークに学習させます。学習オプションでOutputFcnを使用すると、各反復でクライアントに学習の進行状況が送信されます。次の図は、以下のコードの実行中における 4 つの異なるワーカーの学習の進行状況を示します。

parforidx = 1:numel(miniBatchSizes) miniBatchSize = miniBatchSizes(idx); initialLearnRate = 1e-1 * miniBatchSize/256;% Scale the learning rate according to the miniBatchSize.%定义培训选项。设置一个输出功能ion to send data back% to the client each iteration.options = trainingOptions('sgdm',...'MiniBatchSize',miniBatchSize,...% Set the corresponding MiniBatchSize in the sweep.'Verbose',false,...% Do not send command line output.'InitialLearnRate',initialLearnRate,...% Set the scaled learning rate.'OutputFcn',@(state) sendTrainingProgress(D,idx,state),...% Set an output function to send intermediate results to the client.'L2Regularization',1e-10,...“MaxEpochs”,30,...'Shuffle','every-epoch',...'ValidationData',imdsValidation,...'LearnRateSchedule','piecewise',...'LearnRateDropFactor',0.1,...'LearnRateDropPeriod',25);% Train the network in a worker in the cluster. The workers send% training progress information during training to the client.net = trainNetwork(augmentedImdsTrain,layers,options);% To obtain the accuracy of this network, use the trained network to% classify the validation images on the worker and compare the predicted labels to the% actual labels.YPredicted = classify(net,imdsValidation); accuracies(idx) = sum(YPredicted == imdsValidation.Labels)/numel(imdsValidation.Labels);% Send the trained network back to the client.trainedNetworks{idx} = net;end
Analyzing and transferring files to the workers ...done.

parforの完了後のtrainedNetworksは、ワーカーによって学習させた結果のネットワークを含みます。学習済みネットワークおよびその精度を表示します。

trainedNetworks
trainedNetworks =4×1 cell array{1×1 SeriesNetwork} {1×1 SeriesNetwork} {1×1 SeriesNetwork} {1×1 SeriesNetwork}
accuracies
accuracies =4×10.8214 0.8172 0.8132 0.8084

精度が最良のネットワークを選択します。テスト データセットに対するそのパフォーマンスをテストします。

[~, I] = max(accuracies); bestNetwork = trainedNetworks{I(1)}; YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8187

補助関数

ネットワーク アーキテクチャ内で畳み込みブロックを作成する関数を定義します。

functionlayers = convolutionalBlock(numFilters,numConvLayers) layers = [ convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer ]; layers = repmat(layers,numConvLayers,1);end

DataQueueを介して学習の進行状況をクライアントに送信する関数を定義します。

functionsendTrainingProgress(D,idx,info)ifinfo.State =="iteration"send(D,{idx,info.Iteration,info.TrainingAccuracy});endend

ワーカーが中間結果を送信したときにプロットを更新する、更新関数を定義します。

functionupdatePlot(lines,idx,iter,acc) addpoints(lines(idx),iter,acc); drawnowlimitratenocallbacksend

参考

(Deep Learning Toolbox)||

関連する例

詳細