主要内容

训练变分自动编码器(VAE)生成图像

这个例子展示了如何训练一个深度学习变分自编码器(VAE)来生成图像。

要生成强烈表示数据集合中的观察值的数据,可以使用变分自编码器。自动编码器是一种模型,通过将输入转换为低维空间(编码步骤)并从低维表示重构输入(解码步骤)来训练复制其输入。

该图说明了重建数字图像的自动编码器的基本结构。

要使用变分自编码器生成新图像,请向解码器输入随机向量。

一个变分自编码器与常规自编码器的不同之处在于,它在潜在空间上施加了一个概率分布,并学习该分布,以便解码器输出的分布与观测数据的分布相匹配。特别地,潜在输出是从编码器学习的分布中随机采样的。

本例使用MNIST数据集[1],其中包含60,000张用于训练的手写数字灰度图像和10,000张用于测试的图像。

加载数据

下载训练和测试MNIST文件http://yann.lecun.com/exdb/mnist/并使用processImagesMNIST函数作为支持文件附加到本示例。金宝app要访问此函数,请将此示例作为活动脚本打开。vae不需要标记数据。

trainImagesFile =“train-images-idx3-ubyte.gz”;testImagesFile =“t10k-images-idx3-ubyte.gz”;XTrain = processImagesMNIST(trainImagesFile);
读取MNIST图像数据…数据集中的图像数量:60000…
XTest = processImagesMNIST(testImagesFile);
读取MNIST图像数据…数据集中的图像数量:10000…

定义网络架构

自动编码器有两部分:编码器和解码器。编码器接受图像输入,并使用一系列下采样操作(如卷积)输出潜在向量表示(编码)。类似地,解码器将潜在向量表示作为输入,并使用一系列上采样操作(如转置卷积)重建输入。

为了对输入进行采样,该示例使用了自定义层samplingLayer.要访问此层,请将此示例作为活动脚本打开。该层以平均向量作为输入 μ 与对数方差向量连接 日志 σ 2 并从 N μ σ 2 .该层使用对数方差使训练过程在数值上更加稳定。

定义编码器网络架构

定义以下编码器网络,将28 × 28 × 1的图像采样到16 × 1的潜在向量。

  • 对于图像输入,指定输入大小与训练数据匹配的图像输入层。不要规范化数据。

  • 要对输入进行低采样,请指定两个2-D卷积和ReLU层块。

  • 要输出一个均值和对数方差的级联向量,指定一个完全连接的层,其输出通道数量是潜在通道数量的两倍。

  • 要对统计数据指定的编码进行采样,请使用自定义层包含一个采样层samplingLayer.要访问此层,请将此示例作为活动脚本打开。

numLatentChannels = 16;imageSize = [28 28 1];layersE = [imageInputLayer(imageSize,归一化=“没有”32岁的)convolution2dLayer(3填充=“相同”,Stride=2) relullayer卷积2dlayer(3,64,填充=“相同”,Stride=2) reluLayer fullyConnectedLayer(2*numLatentChannels) samplingLayer];

定义解码器网络体系结构

定义以下编码器网络,从16 × 1潜在向量重建28 × 28 × 1图像。

  • 对于特征向量输入,指定输入大小与潜在通道数量匹配的特征输入层。

  • 使用自定义层将潜在输入投影并重塑为7 × 7 × 64数组projectAndReshapeLayer,作为支持文件附在本例中。金宝app要访问此层,请将示例作为活动脚本打开。指定投影大小为[7,7,64]

  • 要对输入进行上采样,请指定转置卷积和ReLU层的两个块。

  • 要输出一张大小为28 × 28 × 1的图像,需要包含一个带有3 × 3滤波器的转置卷积层。

  • 要将输出映射到范围[0,1]中的值,需要包含一个sigmoid激活层。

projectionSize = [7 7 64];numInputChannels = size(imageSize,1);layersD = [featureInputLayer(numLatentChannels) projectAndReshapeLayer(projectionSize,numLatentChannels)转置conv2dlayer(3,64,裁剪=“相同”,Stride=2) reluLayer转置conv2dlayer(3,32,裁剪=“相同”,Stride=2) reluLayer转置conv2dlayer (3,numInputChannels,裁剪=“相同”) sigmoidLayer);

要使用自定义训练循环训练两个网络并启用自动区分,请将层数组转换为dlnetwork对象。

netE = dlnetwork(layersE);netD = dlnetwork(layersD);

定义模型损失函数

定义一个函数,该函数返回模型损失以及相对于可学习参数的损失梯度。

modelLoss函数中定义的模型损失函数部分,将编码器和解码器网络和一小批输入数据作为输入,并返回关于网络中可学习参数的损失和损失的梯度。为了计算损失,函数使用ELBOloss函数中定义的ELBO损失函数部分,将编码器输出的均值和对数方差作为输入,并用它们来计算证据下限损失。

指定培训项目

训练30个epoch,迷你批大小为128,学习率为0.001。

numEpochs = 30;miniBatchSize = 128;learnRate = 1e-3;

火车模型

使用自定义训练循环训练模型。

创建一个minibatchqueue对象,该对象在训练期间处理和管理小批量图像。对于每个小批量:

  • 将训练数据转换为数组数据存储。指定在第4维上迭代。

  • 使用自定义小批量预处理功能preprocessMiniBatch(在本例结束时定义)将多个观察结果连接到单个小批处理中。

  • 用尺寸标签格式化图像数据“SSCB”(空间,空间,通道,批次)。默认情况下,minibatchqueue对象将数据转换为dlarray具有基础类型的对象

  • 如果有GPU,可以在GPU上进行训练。默认情况下,minibatchqueue对象将每个输出转换为gpuArray如果GPU可用。使用GPU需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU支金宝app持版本(并行计算工具箱)

  • 为了确保所有小批量都是相同的大小,丢弃任何部分小批量。

dsTrain = arrayDatastore(XTrain,IterationDimension=4);numOutputs = 1;mbq = minibatchqueue(dsTrain,numOutputs,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatch,...MiniBatchFormat =“SSCB”...PartialMiniBatch =“丢弃”);

初始化培训进度图。

图C = colororder;lineLossTrain = animatedline(Color=C(2,:));Ylim ([0 inf]) xlabel(“迭代”) ylabel (“损失”网格)

初始化亚当解算器的参数。

trailingAvgE = [];trailingAvgSqE = [];trailingAvgD = [];trailingAvgSqD = [];

使用自定义训练循环训练网络。对于每个纪元,洗牌数据并在小批量数据上循环。对于每个小批量:

  • 评估模型损失和梯度使用dlfeval而且modelLoss功能。

  • 方法更新编码器和解码器网络参数adamupdate函数。

  • 显示培训进度。

迭代= 0;开始= tic;%遍历epoch。epoch = 1:numEpochs% Shuffle数据。洗牌(兆贝可);在小批上循环。Hasdata (mbq)迭代=迭代+ 1;读取小批数据。X = next(mbq);评估损失和梯度。[loss,gradientsE,gradientsD] = dlfeval(@modelLoss,netE,netD,X);更新可学习参数。[netE,trailingAvgE,trailingAvgSqE] = adamupdate(netE, trailingAvgSqE)...gradientsE、trailingAvgE trailingAvgSqE,迭代,learnRate);[netD, trailingAvgD, trailingAvgSqD] = adamupdate(netD,...gradientsD、trailingAvgD trailingAvgSqD,迭代,learnRate);%显示培训进度。D = duration(0,0,toc(start),Format=“hh: mm: ss”);Loss = double(extractdata(Loss));addpoints (lineLossTrain、迭代、失去)标题(”时代:“+ epoch +,消失:"+字符串(D))现在绘制结束结束

测试网络

用测试集测试训练好的自动编码器。使用与训练数据相同的步骤创建数据的迷你批队列,但不要丢弃任何部分的迷你批数据。

dsTest = arrayDatastore(XTest,IterationDimension=4);numOutputs = 1;mbqTest = minibatchqueue(dsTest,numOutputs,...MiniBatchSize = MiniBatchSize,...MiniBatchFcn = @preprocessMiniBatch,...MiniBatchFormat =“SSCB”);

使用训练过的自动编码器进行预测modelPredictions函数。

YTest = modelforecasts (netE,netD,mbqTest);

通过取测试图像和重建图像的均方误差来可视化重建误差,并在直方图中可视化它们。

err = mean((XTest-YTest)。^2,[1 2 3]);图直方图(err) xlabel(“错误”) ylabel (“频率”)标题(“测试数据”

生成新图像

通过将随机采样的图像编码通过解码器生成一批新的图像。

numImages = 64;ZNew = randn(numLatentChannels,numImages);ZNew = dlarray(ZNew,“CB”);YNew = predict(netD,ZNew);YNew = extractdata(YNew);

在图形中显示生成的图像。

图I = imtile(YNew);imshow (I)标题(“生成的图像”

在这里,VAE学习了一个强大的特征表示,允许它生成与训练数据相似的图像。

辅助函数

模型损失函数

modelLoss函数将编码器和解码器网络和一小批输入数据作为输入,并返回损失和损失相对于网络中可学习参数的梯度。该函数将训练图像通过编码器传递,并将结果图像编码通过解码器传递。为了计算损失,函数使用elboLoss函数使用编码器的采样层输出的均值和对数-方差统计信息。

函数[loss,gradientsE,gradientsD] = modelLoss(netE,netD,X)%通过编码器转发。[Z,mu,logSigmaSq] = forward(netE,X);%通过解码器转发。Y = forward(netD,Z);计算损失和梯度。loss = elboLoss(Y,X,mu,logSigmaSq);[gradientsE,gradientsD] = dlgradient(loss,net . learnables,net . learnables);结束

ELBO损失函数

ELBOloss函数取编码器输出的均值和对数方差,并用它们来计算证据下限(ELBO)损失。ELBO损失由两个独立损失项的和给出:

ELBO 损失 重建 损失 + 吉隆坡 损失

重建的损失通过使用均方误差(MSE)来测量解码器输出与原始输入的接近程度:

重建 损失 均方误差 重建 图像 输入 图像

KL损失,或Kullback-Leibler散度,测量两个概率分布之间的差异。在这种情况下,最小化KL损失意味着确保学习的均值和方差尽可能接近目标(正态)分布的均值和方差。为了一个潜在的尺寸 K ,则KL损失为

吉隆坡 损失 - 0 5 1 K 1 + 日志 σ 2 - μ 2 - σ 2

加入KL损失项的实际效果是将由于重构损失而学习到的聚类紧紧地包裹在潜在空间的中心周围,形成一个连续的采样空间。

函数损失= elboLoss(Y,T,mu,logSigmaSq)%重建损失。reconstructionLoss = mse(Y,T);% KL散度。KL = -0.5 * sum(1 + logSigmaSq - mu。^2 - exp(logSigmaSq),1);KL = mean(KL);%综合损失。损失=重建损失+ KL;结束

模型预测函数

modelPredictions函数将编码器和解码器的网络对象和作为输入minibatchqueue输入数据的兆贝可并通过迭代所有数据来计算模型的预测minibatchqueue对象。

函数Y = modelforecasts (netE,netD,mbq) Y = [];在小批上循环。hasdata(mbq) X = next(mbq);%通过编码器转发。Z = predict(netE,X);%通过解码器转发。XGenerated = predict(netD,Z);提取并连接预测。Y = cat(4,Y,extractdata(XGenerated));结束结束

迷你批量预处理功能

preprocessMiniBatch函数通过沿着第四维连接输入来预处理一小批预测器。

函数X = preprocessMiniBatch(dataX)%连接。X = cat(4,dataX{:});结束

参考书目

  1. LeCun, Y., C. Cortes,和C. J. C. Burges。“MNIST手写数字数据库。”http://yann.lecun.com/exdb/mnist/

另请参阅

||||||

相关的话题