主要内容

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

敵対的生成ネットワーク(GAN)の学習

この例では,敵対的生成ネットワーク(GAN)に学習させてイメージを生成する方法を説明します。

敵対的生成ネットワーク (甘)は深層学習ネットワークの一種で、入力された実データに類似した特性をもつデータを生成できます。

赣は一緒に学習を行う 2.つのネットワークで構成されています。

  1. ジェネレーター — このネットワークは、乱数値 (潜在入力) のベクトルを入力として与えられ、学習データと同じ構造のデータを生成します。

  2. ディスクリミネーター——このネットワークは,学習データとジェネレーターにより生成されたデータの両方からの観測値を含むデータのバッチを与えられ,その観測値が”実データ”か”生成データ”かの分類を試みます。

赣に学習させる場合は、両方のネットワークの学習を同時に行うことで両者の性能を最大化します。

  • ジェネレーターに学習させて,ディスクリミネーターを”騙す”データを生成。

  • ディスクリミネーターに学習させて,実データと生成データを区別。

ジェネレーターの性能を最適化するには、生成データが与えられたときのディスクリミネーターの損失を最大化します。つまり、ジェネレーターの目的はディスクリミネーターが "実データ" と分類するようなデータを生成することです。

ディスクリミネーターの性能を最適化するには,実データと生成データの両方のバッチが与えられたときのディスクリミネーターの損失を最小化します。つまり,ディスクリミネーターの目的はジェネレーターに”騙されない”ことです。

これらの方法によって,十分に現実的なデータを生成するジェネレーターと,学習データの特性である強い特徴表現を学習したディスクリミネーターを得ることが理想的な結果です。

学習データの読み込み

のデータセット [1] をダウンロードし、解凍します。

网址=“http://download.tensorflow.org/example_images/flower_photos.tgz”;downloadFolder = tempdir;文件名= fullfile (downloadFolder,“flower_dataset.tgz”);imageFolder=fullfile(下载文件夹,‘花卉照片’);如果~存在(imageFolder“dir”) disp (下载花卉数据集(218mb)…) websave(文件名,url);解压(文件名,downloadFolder)结束

花の写真のイメージデータストアを作成します。

datasetFolder = fullfile (imageFolder);imd = imageDatastore (datasetFolder,...“IncludeSubfolders”,真正的);

データを拡張して水平方向にランダムに反転させ,イメージのサイズを64 x 64に変更します。

增量= imageDataAugmenter (“RandXReflection”,true);augimds=augmentedImageDatastore([64],imds,“DataAugmentation”增强器);

ジェネレーター ネットワークの定義

乱数値の1 x 1 x 100の配列からイメージを生成するネットワークアーキテクチャを以下のように定義します。

このネットワークは,次を行います。

  • "投影形状変更" 層を使用して、ノイズで構成される 1 x 1 x 100の配列を 7 x 7 x 128の配列に変換。

  • バッチ正規化とReLU層を用いた一連の転置畳み込み層を使用して,結果の配列を64 x 64 x 3の配列にスケールアップ。

このネットワークアーキテクチャを層グラフとして定義し,次のネットワークプロパティを指定します。

  • 転置畳み込み層では,5 x 5のフィルターを指定し,各層でフィルター数を減らし,ストライドを2にし,各エッジの出力をトリミングするように設定。

  • 最後の転置畳み込み層では、生成されたイメージの 3.つの RGBチャネルに対応する 3.つの 5 x 5のフィルターと、前の層の出力サイズを設定。

  • ネットワークの最後に,双曲正切層を追加。

ノイズ入力を投影して形状変更するには,この例にサポートファイルとして添付されている,カスタム層投影仪和Shapelayerを使用します。投影仪和Shapelayer層は,全結合演算を使用して入力をスケールアップし,出力を指定サイズに形状変更します。

filterSize=5;numFilters=64;numLatentInputs=100;projectionSize=[4 4 512];LayerGenerator=[imageInputLayer([1 1 numLatentInputs],“正常化”“没有”“名字”“在”) projectAndReshapeLayer (projectionSize numLatentInputs,“项目”);transposedConv2dLayer (filterSize 4 * numFilters,“名字”“tconv1”)批处理规范化层(“名字”“bnorm1”) reluLayer (“名字”“relu1”)transposedConv2dLayer(过滤器化,2*numFilters,“步”2.“种植”“相同”“名字”“tconv2”)批处理规范化层(“名字”“bnorm2”) reluLayer (“名字”“relu2”)transposedConv2dLayer(过滤器化、numFilters、,“步”2.“种植”“相同”“名字”“tconv3”)批处理规范化层(“名字”“bnorm3”) reluLayer (“名字”“relu3”) transposedConv2dLayer (filterSize 3“步”2.“种植”“相同”“名字”“tconv4”) tanhLayer (“名字”的双曲正切)];lgraphGenerator=layerGraph(LayerGenerator);

カスタム学習ループを使用してネットワークに学習させ,自動微分を有効にするには,層グラフを数据链路网络オブジェクトに変換します。

dlnetGenerator = dlnetwork (lgraphGenerator);

ディスクリミネーターネットワークの定義

64 x 64の実イメージと生成イメージを分類する,次のネットワークを定義します。

64 x 64 x 3のイメージを受け取り、バッチ正規化と 漏泄雷卢層のある一連の畳み込み層を使用してスカラーの予測スコアを返すネットワークを作成します。ドロップアウトを使用して、入力イメージにノイズを追加します。

  • ドロップアウト層で,ドロップアウトの確率0.5をに設定。

  • 畳み込み層で、5 x 5のフィルターを指定し、各層でフィルター数を増やす。また、ストライド 2.で出力をパディングするように指定。

  • 漏泄雷卢層で、スケールを 0.2に設定。

  • 最後の層で,4 x 4のフィルターを1つもつ畳み込み層を設定。

範囲[0,1]の確率を出力するには,モデル勾配関数の関数乙状结肠を使用します

dropoutProb = 0.5;numFilters = 64;规模= 0.2;inputSize = [64 64 3];filterSize = 5;layersDiscriminator = [imageInputLayer(inputSize,“正常化”“没有”“名字”“在”)dropoutLayer(0.5,“名字”“辍学”) convolution2dLayer (filterSize numFilters,“步”2.“填充”“相同”“名字”“conv1”) leakyReluLayer(规模、“名字”“lrelu1”)卷积2dlayer(filterSize,2*numFilters,“步”2.“填充”“相同”“名字”“conv2”)批处理规范化层(“名字”“bn2”) leakyReluLayer(规模、“名字”“lrelu2”) convolution2dLayer (filterSize 4 * numFilters,“步”2.“填充”“相同”“名字”“conv3”)批处理规范化层(“名字”“bn3”) leakyReluLayer(规模、“名字”“lrelu3”) convolution2dLayer (filterSize 8 * numFilters,“步”2.“填充”“相同”“名字”“conv4”)批处理规范化层(“名字”“bn4”) leakyReluLayer(规模、“名字”“lrelu4”1) convolution2dLayer(4日,“名字”“conv5”)];lgraphDiscriminator=layerGraph(layerDiscriminator);

カスタム学習ループを使用してネットワークに学習させ,自動微分を有効にするには,層グラフを数据链路网络オブジェクトに変換します。

dlnetDiscriminator = dlnetwork (lgraphDiscriminator);

モデル勾配,損失関数,およびスコアの定義

例のモデル勾配関数の節にリストされている関数modelGradientsを作成します。この関数は,ジェネレーターネットワークおよびディスクリミネーターネットワーク,入力データのミニバッチ,および乱数値と反転係数の配列を入力として受け取り,ネットワーク内の学習可能なパラメーターについての損失の勾配と,2つのネットワークのスコアを返します。

学習オプションの指定

ミニバッチ サイズを 128として 500エポック学習させます。大きなデータセットでは、学習させるエポック数をこれより少なくできる場合があります。

numEpochs = 500;miniBatchSize = 128;

亚当最適化のオプションを指定します。両方のネットワークで次のように設定します。

  • 学習率0.0002

  • 勾配の減衰係数0.5

  • 2乗勾配の減衰係数0.999

learnRate=0.0002;gradientDecayFactor=0.5;squaredGradientDecayFactor=0.999;

実イメージと生成イメージとを区別するディスクリミネーターの学習速度が速すぎる場合、ジェネレーターの学習に失敗する可能性があります。ディスクリミネーターとジェネレーターの学習バランスを改善するために、ラベルをランダムに反転させて実データにノイズを加えます。

実ラベルの30%を反転するように指定します。これは,ラベルの総数の15%が学習中に反転することを意味します。生成されたイメージにはすべて正しいラベルが付いているので,これがジェネレーターに損失を与えることはない点に注意してください。

flipFactor = 0.3;

生成された検証イメージを100回の反復ごとに表示します。

validationFrequency = 100;

モデルの学習

minibatchqueueを使用して,イメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数preprocessMiniBatch(この例の最後に定義)を使用して,イメージを範囲[1]で再スケーリングします。

  • 128年観測値が個未満の部分的なミニバッチは破棄します。

  • イメージデータを次元ラベル“SSCB”(空间、空间、通道、批处理)で書式設定します。既定では,minibatchqueueオブジェクトは,基となる型がdlarrayオブジェクトにデータを変換します。

  • GPUが利用できる場合,GPUで学習を行います。minibatchqueue“外部环境”オプションが“自动”のとき,GPUが利用可能であれば,minibatchqueueは各出力をgpuArrayに変換します。GPU を使用するには、Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

augimds。MiniBatchSize = MiniBatchSize;executionEnvironment =“自动”;兆贝可= minibatchqueue (augimds,...“MiniBatchSize”,小批量,...“PartialMiniBatch”“丢弃”...“MiniBatchFcn”,@minibatch,...“MiniBatchFormat”“SSCB”...“外部环境”, executionEnvironment);

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し,各反復でネットワークパラメーターを更新します。学習の進行状況を監視するには,ホールドアウトされた乱数値の配列をジェネレーターに入力して得られた生成イメージのバッチと,スコアのプロットを表示します。

亚当のパラメーターを初期化します。

trailingAvgGenerator = [];trailingAvgSqGenerator = [];trailingAvgDiscriminator = [];trailingAvgSqDiscriminator = [];

学習の進行状況を監視するには、ホールドアウトされた乱数値の固定配列のバッチをジェネレーターに渡して得られた生成イメージのバッチを表示し、ネットワークのスコアをプロットします。

ホールドアウトされた乱数値の配列を作成します。

numValidationImages = 25;ZValidation = randn (1, 1, numLatentInputs numValidationImages,“单身”);

データをdlarrayオブジェクトに変換し,次元ラベル“SSCB”(空间、空间、通道、批处理)を指定します。

dlZValidation = dlarray (ZValidation,“SSCB”);

GPUで学習する場合,データをgpuArrayオブジェクトに変換します。

如果(b)执行环境==“自动”&& canUseGPU) || executionEnvironment ==“图形”dlZValidation = gpuArray (dlZValidation);结束

学習の進行状況プロットを初期化します。图を作成して幅が2倍になるようサイズ変更します。

f =图;f.Position (3) = 2 * f.Position (3);

生成イメージとネットワーク スコアのサブプロットを作成します。

imageAxes=子批次(1,2,1);scoreAxes=子批次(1,2,2);

スコアのプロット用にアニメーションの線を初期化します。

lineScoreGenerator = animatedline (scoreAxes,“颜色”0.447 - 0.741 [0]);lineScoreDiscriminator = animatedline (scoreAxes,“颜色”, [0.85 0.325 0.098]);传奇(“发电机”“鉴别器”);ylim([0 1])包含(“迭代”) ylabel (“得分”)网格

氮化镓に学習させます。各エポックで,データストアをシャッフルしてデータのミニバッチについてループします。

各ミニバッチで次を行います。

  • 関数dlfevalおよびmodelGradientsを使用してモデルの勾配を評価します。

  • 関数adamupdateを使用してネットワーク パラメーターを更新。

  • 2つのネットワークのスコアをプロット。

  • validationFrequencyの反復がすべて終了した後で、ホールドアウトされた固定ジェネレーター入力の生成イメージのバッチを表示。

学習を行うのに時間がかかる場合があります。

迭代=0;开始=tic;%环游各个时代。时代= 1:numEpochs重置和洗牌数据存储。洗牌(mbq);%在小批量上循环。hasdata(mbq)迭代=迭代+1;%读取小批量数据。dlX =下一个(兆贝可);为发电机网络产生潜在的输入。转换为% dlarray,并指定尺寸标签'SSCB'(空间,%空间,通道,批处理)。如果训练在GPU上,然后转换%潜在输入到gpuArray。Z = randn (1, - 1, numLatentInputs大小(dlX, 4),“单身”);dlZ = dlarray (Z,“SSCB”);如果(b)执行环境==“自动”&& canUseGPU) || executionEnvironment ==“图形”dlZ = gpuArray (dlZ);结束%评估模型梯度和生成器状态使用的% dlfeval和模型梯度函数%例如。[gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] =...dlfeval(@modelGradients, dlnetGenerator, dlnetDiscriminator, dlX, dlZ, flipFactor);dlnetGenerator。状态= stateGenerator;%更新鉴别器网络参数。[dlnetDiscriminator,trailingAvgDiscriminator,trailingAvgSqDiscriminator]=...adamupdate (dlnetDiscriminator gradientsDiscriminator,...trailingAvgDiscriminator trailingAvgSqDiscriminator,迭代,...learnRate、gradientDecayFactor、squaredGradientDecayFactor);%更新生成器网络参数。[dlnetGenerator, trailingAvgGenerator trailingAvgSqGenerator] =...adamupdate (dlnetGenerator gradientsGenerator,...trailingAvgGenerator trailingAvgSqGenerator,迭代,...learnRate、gradientDecayFactor、squaredGradientDecayFactor);%每次validationFrequency迭代,使用%保持发电机输入如果mod(iteration,validationFrequency) == 0 || iteration == 1%使用保留的生成器输入生成图像。dlXGeneratedValidation=预测(dlnetGenerator、dlZValidation);在[0 1]范围内平铺并重新缩放图像。我= imtile (extractdata (dlXGeneratedValidation));I =重新调节(我);%显示图像。次要情节(1、2、1);图像(imageAxes,我)xticklabels ([]);yticklabels ([]);标题(“生成的图像”);结束%更新分数图次要情节(1、2、2)addpoints (lineScoreGenerator,迭代,...double(收集(提取数据(scoreGenerator));addpoints(lineScoreDiscriminator,迭代,...双(收集(提取数据(记分鉴别器));%用培训进度信息更新标题。D =持续时间(0,0,toc(开始),“格式”“hh:mm:ss”);标题(...”时代:“+时代+", "+...“迭代:”+迭代+", "+...“已过:”+ drawnow字符串(D))结束结束

ここでは、ディスクリミネーターは生成イメージの中から実イメージを識別する強い特徴表現を学習しました。それに対し、ジェネレーターは実データのように見えるデータを生成できるように、同様に強い特徴表現を学習しました。

学習プロットは,ジェネレーターおよびディスクリミネーターのネットワークのスコアを示しています。ネットワークのスコアを解釈する方法の詳細については,氮化镓の学習過程の監視と一般的な故障モードの識別を参照してください。

新しいイメージの生成

新しいイメージを生成するには、ジェネレーターに対して関数预测を使用して,乱数値の1 x 1 x 100の配列のバッチを含むdlarrayオブジェクトを指定します。イメージを並べて表示するには関数imtileを使用し、関数重新缩放を使ってイメージを再スケーリングします。

乱数値の1 x 1 x 100の配列25個のバッチを含むdlarrayオブジェクトを作成します。

ZNew = randn (1, 1, numLatentInputs, 25岁,“单身”);dlZNew = dlarray (ZNew,“SSCB”);

GPUを使用してイメージを生成するには,データをgpuArrayオブジェクトにも変換します。

如果(b)执行环境==“自动”&& canUseGPU) || executionEnvironment ==“图形”dlZNew = gpuArray (dlZNew);结束

関数预测をジェネレーターと入力データと共に使用して,新しいイメージを生成します。

dlXGeneratedNew =预测(dlnetGenerator dlZNew);

イメージを表示します。

I=imtile(extracteddata(dlXGeneratedNew));I=rescale(I);图形图像(I)轴标题(“生成的图像”

モデル勾配関数

関数modelGradientsは,ジェネレーターおよびディスクリミネーターの数据链路网络オブジェクトであるdlnetGeneratordlnetDiscriminator,入力データのミニバッチdlX、乱数値の配列dlZ、および実ラベルの反転する割合flipFactor,を入力として受け取り、ネットワーク内の学習可能なパラメーターについての損失の勾配、ジェネレーターの状態、および 2.つのネットワークのスコアを返します。ディスクリミネーターの出力は範囲 [0,1] に含まれないため、modelGradientsはシグモイド関数を適用してこれを確率に変換します。

函数[gradientsGenerator, gradientsDiscriminator, stateGenerator, scoreGenerator, scoreDiscriminator] =...模型渐变(dlnetGenerator、dlnetDiscriminator、dlX、dlZ、flipFactor)%用鉴别器网络计算真实数据的预测。dlYPred = forward(dlnetDiscriminator, dlX);%使用鉴别器网络计算生成数据的预测。向前(dlXGenerated stateGenerator] = (dlnetGenerator, dlZ);dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);%将鉴别器输出转换为概率。probGenerated =乙状结肠(dlYPredGenerated);probReal =乙状结肠(dlYPred);%计算鉴别器的分数。scoreDiscriminator =((意思(probReal) + (1-probGenerated)) / 2);%计算生成器的得分。scoreGenerator=平均值(生成的概率);%随机翻转真实图像的一部分标签。numObservations =大小(probReal 4);idx = randperm(numObservations,floor(flipFactor * numObservations));%翻转标签probReal(:,:,:,idx)=1-probReal(:,:,:,idx);%计算GAN损耗。[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);%对于每个网络,计算相对于损失的梯度。gradientsGenerator = dlgradient(lossGenerator, dlnetGenerator。可学的,“RetainData”,真正的);gradientsDiscriminator = dlgradient(lossDiscriminator, dlnetDiscriminator.Learnables);结束

氮化镓の損失関数とスコア

ジェネレーターの目的はディスクリミネーターが”実データ”に分類するようなデータを生成することです。ジェネレーターが生成したイメージをディスクリミネーターが実データとして分類する確率を最大化するには,負の対数尤度関数を最小化します。

ディスクリミネーターの出力 Y が与えられた場合、次のようになります。

  • Y ˆ σ Y は、入力イメージが "実" クラスに属する確率です。

  • 1 - Y ˆ は,入力イメージが”生成“クラスに属している確率です。

シグモイド演算 σ は関数modelGradientsで行われる点に注意してください。ジェネレーターの損失関数は次の式で表されます。

损耗发生器 - 的意思是 日志 Y ˆ 生成的

ここで, Y ˆ G e n e r 一个 t e d は生成イメージに対するディスクリミネーターの出力確率を表しています。

ディスクリミネーターの目的はジェネレーターに”騙されない”ことです。ディスクリミネーターが実イメージと生成イメージを正しく区別する確率を最大化するには,対応する負の対数尤度関数の和を最小化します。

ディスクリミネーターの損失関数は次の式で表されます。

lossDiscriminator - 的意思是 日志 Y ˆ 真正的 - 的意思是 日志 1 - Y ˆ 生成的

ここで, Y ˆ R e 一个 l は実イメージに対するディスクリミネーターの出力確率を表しています。

ジェネレーターとディスクリミネーターがそれぞれの目標をどれだけ達成するかを0から1のスケールで測定するには,スコアの概念を使用できます。

ジェネレーターのスコアは,生成イメージに対するディスクリミネーターの出力に対応する確率の平均です。

scoreGenerator 的意思是 Y ˆ 生成的

ディスクリミネーターのスコアは,実イメージと生成イメージの両方に対するディスクリミネーターの出力に対応する確率の平均です。

scoreDiscriminator 1 2 的意思是 Y ˆ 真正的 + 1 2 的意思是 1 - Y ˆ 生成的

スコアは損失に反比例しますが、実質的には同じ情報を表しています。

函数[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated)%计算鉴别器网络的损耗。lossDiscriminator = -mean(log(probReal)) -mean(log(1-probGenerated));%计算发电机网络的损耗。lossGenerator =意味着(日志(probGenerated));结束

ミニバッチ前処理関数

関数preprocessMiniBatchは,次の手順でデータを前処理します。

  1. 入力电池配列からイメージデータを抽出して数値配列に連結します。

  2. イメージの範囲が[1]となるように再スケーリングします。

函数X = preprocessMiniBatch(数据)%连接mini-batchX =猫(4、数据{:});%在[-1]范围内重新缩放图像。X =重新调节(X, 1, 1,“InputMin”,0,“InputMax”,255);结束

参考文献

  1. TensorFlow团队。http://download.tensorflow.org/example_images/flower_photos.tgz

  2. 雷德福,亚历克,卢克·梅茨,和苏史密斯·钦塔拉。基于深度卷积生成对抗网络的无监督表示学习arXiv预印本arXiv: 1511.06434(2015).

参考

|||||||

関連するトピック