这个例子展示了如何使用无监督的图像到图像转换网络(UNIT)在白天和黄昏光照条件下转换图像。
域翻译是将风格和特征从一个图像域转移到另一个图像域的任务。该技术可以扩展到其他图像到图像的学习操作,如图像增强、图像着色、缺陷生成和医学图像分析。
单位[1]是一种生成式对抗网络(GAN),由一个生成器网络和两个同时训练的判别器网络组成,以最大化整体性能。有关UNIT的更多信息,请参见开始了解用于图像到图像转换的GANs.
本例使用CamVid数据集[2]来自剑桥大学的培训。该数据集是701张图像的集合,包含在驾驶时获得的街道视图。
下载CamVid数据集。下载时间取决于你的网络连接。
imageURL =“http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip”;dataDir = fullfile(tempdir,“CamVid”);downloadCamVidImageData (dataDir imageURL);imgDir = fullfile(dataDir,“图片”,“701 _stillsraw_full”);
CamVid图像数据集包括497张白天采集的图像和124张黄昏采集的图像。训练后的UNIT网络性能有限,因为CamVid训练图像的数量相对较少,限制了训练后网络的性能。此外,一些图像属于一个图像序列,因此与数据集中的其他图像相关。为了最大限度地减少这些限制的影响,本示例以最大限度地提高训练数据的可变性的方式手动将数据划分为训练数据集和测试数据集。
通过加载文件获得用于训练和测试的白天和黄昏图像的文件名camvidDayDuskDatasetFileNames.mat
.训练数据集由263张白天图像和107张黄昏图像组成。测试数据集由234张白天图像和17张黄昏图像组成。
负载(“camvidDayDuskDatasetFileNames.mat”);
创建imageDatastore
对象,这些对象管理用于训练和测试的白天和黄昏图像。
imdsDayTrain = imageDatastore(fullfile(imgDir,trainDayNames));imdsDuskTrain = imageDatastore(fullfile(imgDir,trainDuskNames));imdsDayTest = imageDatastore(fullfile(imgDir,testDayNames));imdsDuskTest = imageDatastore(fullfile(imgDir,testDuskNames));
从白天和黄昏训练数据集预览一个训练图像。
day = preview(imdsDayTrain);黄昏=预览(imdsDuskTrain);蒙太奇({天,黄昏})
指定源和目标图像的图像输入大小。
inputSize = [256,256,3];
对训练数据进行增强和预处理变换
函数使用helper函数指定的自定义预处理操作augmentDataForDayToDusk
.该函数作为支持文件附加到示例中。金宝app
的augmentDataForDayToDusk
函数执行以下操作:
使用双三次插值将图像调整为指定的输入大小。
在水平方向上随机翻转图像。
将图像缩放到范围[- 1,1]。这个范围与韵母的范围相匹配tanhLayer
(深度学习工具箱)用于发电机。
imdsDayTrain = transform(imdsDayTrain, @(x)augmentDataForDayToDusk(x,inputSize));imdsDuskTrain = transform(imdsDuskTrain, @(x)augmentDataForDayToDusk(x,inputSize));
创建一个UNIT生成器网络unitGenerator
函数。发生器的源编码器和目标编码器部分各由两个下采样块和五个剩余块组成。编码器部分共享五个剩余块中的两个。类似地,生成器的源和目标解码器部分各由两个下采样块和五个剩余块组成,解码器部分共享五个剩余块中的两个。
gen = unitGenerator(inputSize,“NumResidualBlocks”5,“NumSharedBlocks”2);
可视化发电机网络。
analyzeNetwork(创)
属性创建两个鉴别器网络,分别用于源域和目标域patchGANDiscriminator
函数。日是源域,黄昏是目标域。
discDay = patchGANDiscriminator(inputSize,“NumDownsamplingBlocks”4“FilterSize”3,...“ConvolutionWeightsInitializer”,“narrow-normal”,“NormalizationLayer”,“没有”);disdissk = patchGANDiscriminator(inputSize,“NumDownsamplingBlocks”4“FilterSize”3,...“ConvolutionWeightsInitializer”,“narrow-normal”,“NormalizationLayer”,“没有”);
可视化鉴别器网络。
analyzeNetwork (discDay);analyzeNetwork (discDusk);
的modelGradientsDisc
而且modelGradientGen
辅助函数分别计算鉴别器和生成器的梯度和损失。函数中定义了这些函数金宝app支持功能部分的示例。
每个鉴别器的目标是正确区分其域中图像的真实图像(1)和翻译图像(0)。每个鉴别器都有一个损失函数。
生成器的目标是生成经过翻译的图像,鉴别器将其分类为真实图像.发电机损耗是五种类型损耗的加权和:自重构损耗、循环一致性损耗、隐藏KL损耗、循环隐藏KL损耗和对抗损耗。
指定各种损失的权重因子。
lossWeights。自我ReconLossWeight = 10; lossWeights.hiddenKLLossWeight = 0.01; lossWeights.cycleConsisLossWeight = 10; lossWeights.cycleHiddenKLLossWeight = 0.01; lossWeights.advLossWeight = 1; lossWeights.discLossWeight = 0.5;
指定Adam优化的选项。训练网络35个epoch。为生成器和鉴别器网络指定相同的选项。
指定0.0001的相等学习率。
初始化尾随平均梯度和尾随平均梯度平方衰减率[]
.
使用梯度衰减因子为0.5和平方梯度衰减因子为0.999。
使用系数为0.0001的权重衰减正则化。
使用1个小批量进行训练。
learnRate = 0.0001;gradDecay = 0.5;sqGradDecay = 0.999;weightDecay = 0.0001;genAvgGradient = [];genAvgGradientSq = [];discDayAvgGradient = [];discDayAvgGradientSq = [];disduskavggradient = [];discDuskAvgGradientSq = []; miniBatchSize = 1; numEpochs = 35;
创建一个minibatchqueue
(深度学习工具箱)对象,该对象在自定义训练循环中管理观察结果的迷你批处理。的minibatchqueue
对象也将数据强制转换为dlarray
(深度学习工具箱)对象,用于在深度学习应用程序中实现自动区分。
指定小批数据提取格式为"SSCB”
(空间,空间,通道,批次)。设置DispatchInBackground”
返回的布尔值参数canUseGPU
.如果支持的金宝appGPU可用于计算,则minibatchqueue
对象在训练期间在并行池的后台预处理小批。
mbqDayTrain = minibatchqueue(imdsDayTrain,“MiniBatchSize”miniBatchSize,...“MiniBatchFormat”,“SSCB”,“DispatchInBackground”, canUseGPU);mbqDuskTrain = minibatchqueue(imdsDuskTrain,“MiniBatchSize”miniBatchSize,...“MiniBatchFormat”,“SSCB”,“DispatchInBackground”, canUseGPU);
默认情况下,本例使用helper函数下载CamVid数据集的UNIT生成器的预训练版本downloadTrainedDayDuskGeneratorNet
.helper函数作为支持文件附加到示例中。金宝app预训练的网络使您可以运行整个示例,而无需等待训练完成。
为了训练网络,设置doTraining
变量转换为真正的
.在自定义训练循环中训练模型。对于每个迭代:
方法读取当前小批处理的数据下一个
(深度学习工具箱)函数。
方法评估模型梯度dlfeval
(深度学习工具箱)功能和modelGradientsDisc
而且modelGradientGen
辅助功能。
方法更新网络参数adamupdate
(深度学习工具箱)函数。
在每个纪元之后显示源域和目标域的输入和翻译图像。
如果有GPU,可以在GPU上进行训练。使用GPU需要并行计算工具箱™和支持CUDA®的NVIDIA®GPU。有关更多信息,请参见GPU支金宝app持版本(并行计算工具箱).在NVIDIA Titan RTX上训练大约需要88个小时。
doTraining = false;如果doTraining创建一个显示结果的图形图(“单位”,“归一化”);为iPlot = 1:4 ax(iPlot) = subplot(2,2,iPlot);结束迭代= 0;%遍历epoch为epoch = 1:numEpochs每个纪元洗牌数据重置(mbqDayTrain);洗牌(mbqDayTrain);重置(mbqDuskTrain);洗牌(mbqDuskTrain);运行循环,直到迷你批处理队列mbqDayTrain中的所有图像都被处理完毕而hasdata(mbqDayTrain)迭代=迭代+ 1;从日域读取数据imDay = next(mbqDayTrain);从黄昏域读取数据如果hasdata(mbqDuskTrain) == 0 reset(mbqDuskTrain);洗牌(mbqDuskTrain);结束imDusk = next(mbqDuskTrain);计算鉴别器梯度和损失[discdaylosses,disDuskLoss,disDuskLoss] = dlfeval(@modelGradientDisc,...创,discDay、discDusk imDay、imDusk lossWeights.discLossWeight);在日鉴别器梯度上应用权重衰减正则化discday= dlupdate(@(g,w) g+weightDecay*w, discday, discDay.Learnables);更新日鉴别器参数[discDay,discDayAvgGradient,discDayAvgGradientSq] = adamupdate(discDay, discdaygradientsq,discDayAvgGradientSq)...discDayAvgGradient discDayAvgGradientSq,迭代,learnRate、gradDecay sqGradDecay);在黄昏鉴别器梯度上应用权重衰减正则化discdusk= dlupdate(@(g,w) g+weightDecay*w, discdusk, discDusk.Learnables);%更新黄昏鉴别器参数[discDuskAvgGradient,discDuskAvgGradientSq] = adamupdate(discDusk, discduskgradientsq,discDuskAvgGradientSq,...discDuskAvgGradient discDuskAvgGradientSq,迭代,learnRate、gradDecay sqGradDecay);计算发电机梯度和损耗[genGrad,genLoss,images] = dlfeval(@modelGradientGen,gen,discDay, disdusk,imDay,imDusk,lossWeights);对发电机梯度应用权重衰减正则化genGrad = dlupdate(@(g,w) g+weightDecay*w,genGrad,gen.Learnables);更新发电机参数[gen,genAvgGradient,genAvgGradientSq] = adamupdate(gen,genGrad,genAvgGradient,...迭代,genAvgGradientSq learnRate、gradDecay sqGradDecay);结束显示结果updateTrainingPlotDayToDusk (ax,图片{:});结束保存已训练的网络modelDateTime = string(datetime(“现在”,“格式”,“yyyy-MM-dd-HH-mm-ss”));保存(strcat (“trainedDayDuskUNITGeneratorNet——”modelDateTime,“时代——”num2str (numEpochs),“.mat”),“创”);其他的net_url =“https://ssd.mathworks.com/金宝appsupportfiles/vision/data/trainedDayDuskUNITGeneratorNet.zip”;downloadTrainedDayDuskGeneratorNet (net_url dataDir);负载(fullfile (dataDir“trainedDayDuskUNITGeneratorNet.mat”));结束
源到目标图像转换使用UNIT生成器从源域(天)中的图像生成目标域(黄昏)中的图像。
从日测试映像的数据存储中读取映像。
idxToTest = 1;dayTestImage = readimage(imdsDayTest,idxToTest);
将图像转换为数据类型单
并将图像归一化到范围[- 1,1]。
dayTestImage = im2single(dayTestImage);dayTestImage = (dayTestImage-0.5)/0.5;
创建一个dlarray
对象,该对象向生成器输入数据。如果支持的金宝appGPU可用于计算,则通过将数据转换为a在GPU上执行推理gpuArray
对象。
dlDayImage = dlarray(dayTestImage,“SSCB”);如果canUseGPU dlDayImage = gpuArray(dlDayImage);结束
方法将输入的日图像转换为黄昏域unitPredict
函数。
dlDayToDuskImage = unitPredict(gen,dlDayImage);dayToDuskImage = extractdata(gather(dlDayToDuskImage));
发电机网络的最后一层产生范围为[- 1,1]的激活。为了显示,将激活缩放到范围[0,1]。另外,在显示之前重新调整输入的日期图像。
dayToDuskImage = rescale(dayToDuskImage);dayTestImage = rescale(dayTestImage);
以蒙太奇的方式显示输入的日图像及其翻译的黄昏版本。
图蒙太奇({dayTestImage dayToDuskImage})标题([“日间测试图像”num2str (idxToTest),“带有黄昏图像翻译”])
目标到源图像转换使用UNIT生成器从目标域(黄昏)中的图像生成源域(天)中的图像。
从黄昏测试图像的数据存储中读取图像。
idxToTest = 1;duskTestImage = readimage(imdsDuskTest,idxToTest);
将图像转换为数据类型单
并将图像归一化到范围[- 1,1]。
duskTestImage = im2single(duskTestImage);duskTestImage = (duskTestImage-0.5)/0.5;
创建一个dlarray
对象,该对象向生成器输入数据。如果支持的金宝appGPU可用于计算,则通过将数据转换为a在GPU上执行推理gpuArray
对象。
dlDuskImage = darray (duskTestImage,“SSCB”);如果canUseGPU dlDuskImage = gpuArray(dlDuskImage);结束
控件将输入的黄昏图像转换为白天域unitPredict
函数。
dlDuskToDayImage = unitPredict(gen,dlDuskImage,“OutputType”,“TargetToSource”);duskToDayImage =提取数据(收集(dlDuskToDayImage));
为了显示,将激活缩放到范围[0,1]。此外,在显示之前重新调整输入黄昏图像的大小。
duskToDayImage = rescale(duskToDayImage);duskTestImage = rescale(duskTestImage);
在蒙太奇中显示输入的黄昏图像及其翻译的白天版本。
蒙太奇({duskTestImage duskToDayImage}) title([“测试黄昏图像”num2str (idxToTest),“带翻译的白天图像”])
的modelGradientDisc
Helper函数计算两个鉴别器的梯度和损失。
函数[discgrads, discgrds, discloss, discloss] = modelGradientDisc(gen,...discA,discB,ImageA,ImageB,discLossWeight) [~,fakeA,fakeB,~] = forward(gen,ImageA,ImageB);计算X_A的鉴别器损失outA = forward(discA,ImageA);outfA = forward(discA,fakeA);discalss = disclosight *computeDiscLoss(outA,outfA);更新X的鉴别器参数discAGrads = dlgradient(discalss,discA.Learnables);计算X_B的鉴别器损失outB = forward(discB,ImageB);outfB = forward(discB,fakeB);discloloss = discloight *computeDiscLoss(outB,outfB);更新Y的鉴别器参数discblosses = dlgradient(discBLoss,discB.Learnables);将数据类型从单数组转换为单数组discALoss = extractdata(discALoss);diskloss = extractdata(diskloss);结束
的modelGradientGen
Helper函数计算生成器的梯度和损失。
函数[genGrad,genLoss,images] = modelGradientGen(gen,discA,discB,ImageA,ImageB,lossWeights) [ImageAA,ImageBA,ImageAB,ImageBB] = forward(gen,ImageA,ImageB);隐藏= forward(gen,ImageA,ImageB,“输出”,“encoderSharedBlock”);[~,ImageABA,ImageBAB,~] = forward(gen,ImageBA,ImageAB);cycle_hidden = forward(gen,ImageBA,ImageAB,“输出”,“encoderSharedBlock”);计算不同损失selfReconLoss = computeReconLoss(ImageA,ImageAA) + computeReconLoss(ImageB,ImageBB);hiddenKLLoss = computeKLLoss(隐藏);cycleReconLoss = computeReconLoss(ImageA,ImageABA) + computeReconLoss(ImageB,ImageBAB);cycleHiddenKLLoss = computeKLLoss(cycle_hidden);outA = forward(discA,ImageBA);outB = forward(discB,ImageAB);advLoss = computeAdvLoss(outA) + computeAdvLoss(outB);将发电机的总损耗计算为5的加权和%的损失genTotalLoss =...selfReconLoss * lossWeights。selfReconLossWeight +...hiddenKLLoss * lossWeights。hiddenKLLossWeight +...cycleReconLoss * lossWeights。cycleConsisLossWeight +...cycleHiddenKLLoss * lossWeights。cycleHiddenKLLossWeight +...advLoss * lossWeights.advLossWeight;更新发电机参数genGrad = dlgradient(genTotalLoss,gen.Learnables);将数据类型从单数组转换为单数组genLoss = extractdata(genTotalLoss);images = {ImageA,ImageAB,ImageB,ImageBA};结束
的computeDiscLoss
Helper函数计算鉴别器损失。每个鉴别器损失是两个分量的和:
1向量与鉴别器在真实图像上的预测值之间的平方差,
零向量与鉴别器对生成图像的预测之间的平方差,
函数discLoss = computeDiscLoss(Yreal,Ytranslated) discLoss = mean(((1-Yreal).^2),“所有”) +...意思是(((0-Ytranslated) ^ 2),“所有”);结束
的computeAdvLoss
辅助函数计算发电机的对抗损失。对抗损失是1的向量与翻译图像上的鉴别器预测之间的平方差。
函数advLoss = computeAdvLoss(Ytranslated) advLoss = mean(((Ytranslated-1).^2),“所有”);结束
的computeReconLoss
辅助函数计算发电机的自重构损失和周期一致性损失。自我重建的损失是
输入图像与自重建图像之间的距离。循环一致性损失是
输入图像与其循环重建版本之间的距离。
函数reconLoss = computeReconLoss(Yreal,Yrecon) = mean(abs(Yreal-Yrecon),“所有”);结束
的computeKLLoss
helper函数计算发电机的隐藏KL损失和循环隐藏KL损失。隐藏KL损失是零向量和encoderSharedBlock
自我重建流的激活。周期隐藏KL损失是零向量和encoderSharedBlock
循环重建流的激活。
函数klLoss = computeKLLoss(hidden) klLoss = mean(abs(hidden.^2),“所有”);结束
[1]刘明宇,Thomas Breuel, Jan Kautz,“无监督图像到图像的翻译网络”。在神经信息处理系统研究进展,2017.https://arxiv.org/abs/1703.00848.
[2] Brostow, Gabriel J., Julien Fauqueur和Roberto Cipolla。视频中的语义对象类:一个高清地面真相数据库。模式识别信.第30卷,第2期,2009年,pp 88-97。
变换
|unitGenerator
|unitPredict
|dlarray
(深度学习工具箱)|dlfeval
(深度学习工具箱)|adamupdate
(深度学习工具箱)|minibatchqueue
(深度学习工具箱)|patchGANDiscriminator