主要内容

训练变分自动编码器(VAE)生成图像

这个例子展示了如何在MATLAB中创建一个变分自编码器(VAE)来生成数字图像。VAE生成MNIST数据集风格的手绘数字。

vae与常规的自编码器不同之处在于,它们不使用编码-解码过程来重建输入。相反,他们在潜在空间上施加一个概率分布,并学习这个分布,以便解码器输出的分布与观测数据的分布相匹配。然后,他们从这个分布中取样以生成新的数据。

在本例中,构建一个VAE网络,在MNIST数据集上训练它,并生成与数据集中的图像非常相似的新图像。

加载数据

下载MNIST文件http://yann.lecun.com/exdb/mnist/并将MNIST数据集加载到工作区[1]中。调用processImagesMNIST而且processLabelsMNIST附加到此示例的helper函数用于将文件中的数据加载到MATLAB数组中。

因为VAE将重构的数字与输入进行比较,而不是与分类标签进行比较,所以您不需要在MNIST数据集中使用训练标签。

trainImagesFile =“train-images-idx3-ubyte.gz”;testImagesFile =“t10k-images-idx3-ubyte.gz”;testLabelsFile =“t10k-labels-idx1-ubyte.gz”;XTrain = processImagesMNIST(trainImagesFile);
读取MNIST图像数据…数据集中的图像数量:60000…
numTrainImages = size(XTrain,4);XTest = processImagesMNIST(testImagesFile);
读取MNIST图像数据…数据集中的图像数量:10000…
YTest = processLabelsMNIST(testLabelsFile);
读取MNIST标签数据…数据集中的标签数量:10000…

构建网络

自动编码器有两部分:编码器和解码器。编码器接受图像输入和输出压缩表示(编码),这是一个大小的向量latentDim,在本例中等于20。解码器将压缩后的图像进行解码,然后重新生成原始图像。

为了使计算在数值上更加稳定,通过使网络从方差的对数中学习,将可能值的范围从[0,1]增加到[-inf, 0]。定义两个大小相同的向量latent_dim一个为手段 μ 1表示方差对数 日志 σ 2 .然后用这两个向量来创建抽样的分布。

使用二维卷积,然后是一个完全连接的层,从28 × 28 × 1 MNIST图像向下采样到潜在空间中的编码。然后,使用转置二维卷积将1 × 1 × 20编码放大为28 × 28 × 1图像。

latentDim = 20;imageSize = [28 28 1];encoderLG = layerGraph([imageInputLayer(imageSize,“名字”“input_encoder”“归一化”“没有”)卷积2dlayer (3, 32,“填充”“相同”“步”2,“名字”“conv1”) reluLayer (“名字”“relu1”)卷积2dlayer (3, 64,“填充”“相同”“步”2,“名字”“conv2”) reluLayer (“名字”“relu2”(2 * latentDim,“名字”“fc_encoder”)));decoderLG = layerGraph([imageInputLayer([1 1 latentDim],“名字”“我”“归一化”“没有”转置conv2dlayer (7, 64,“种植”“相同”“步”7“名字”“transpose1”) reluLayer (“名字”“relu1”转置conv2dlayer (3, 64,“种植”“相同”“步”2,“名字”“transpose2”) reluLayer (“名字”“relu2”转置conv2dlayer (3, 32,“种植”“相同”“步”2,“名字”“transpose3”) reluLayer (“名字”“relu3”转置conv2dlayer (3,1,“种植”“相同”“名字”“transpose4”)));

要使用自定义训练循环训练两个网络并启用自动区分,请将层图转换为dlnetwork对象。

encoderNet = dlnetwork(encoderLG);decoderNet = dlnetwork(decoderLG);

定义模型梯度函数

辅助函数modelGradients接收编码器和解码器dlnetwork对象和一小批输入数据X,并返回损失相对于网络中可学习参数的梯度。这个helper函数在本例的最后定义。

函数分两步执行这个过程:采样和损耗.采样步骤对平均值和方差向量进行采样,以创建传递给解码器网络的最终编码。但是,由于不可能通过随机抽样操作进行反向传播,因此必须使用reparameterization技巧.这个技巧将随机抽样操作转移到一个辅助变量上 ε ,然后按均值平移 μ 然后按标准差缩放 σ .这个想法是从 N μ σ 2 和抽样一样吗 μ + ε σ ,在那里 ε N 0 1 .下图生动地描述了这个想法。

损耗步骤将采样步骤生成的编码通过解码器网络传递,并确定损耗,然后用于计算梯度。vae中的损失,也称为证据下限(ELBO)损失,定义为两个独立损失项的和:

ELBO 损失 重建 损失 + 吉隆坡 损失

重建的损失通过使用均方误差(MSE)来测量解码器输出与原始输入的接近程度:

重建 损失 均方误差 译码器 输出 原始 图像

KL损失,或Kullback-Leibler散度,测量两个概率分布之间的差异。在这种情况下,最小化KL损失意味着确保学习的均值和方差尽可能接近目标(正态)分布的均值和方差。为了一个潜在的尺寸 n ,则KL损失为

吉隆坡 损失 - 0 5 1 n 1 + 日志 σ 2 - μ 2 - σ 2

加入KL损失项的实际效果是将由于重构损失而学习到的聚类紧紧地包裹在潜在空间的中心周围,形成一个连续的采样空间。

指定培训项目

在可用的GPU上训练(需要并行计算工具箱™)。

executionEnvironment =“汽车”

设置网络的培训选项。当使用Adam优化器时,您需要初始化每个网络的平均梯度和空数组的平均梯度平方衰减率

numEpochs = 50;miniBatchSize = 512;Lr = 1e-3;numIterations = floor(numTrainImages/miniBatchSize);迭代= 0;avgGradientsEncoder = [];avgGradientsSquaredEncoder = [];avgGradientsDecoder = [];avgGradientsSquaredDecoder = [];

火车模型

使用自定义训练循环训练模型。

对于一个纪元中的每一次迭代:

  • 从训练集中获取下一个小批量。

  • 将迷你批处理转换为dlarray对象,确保指定了尺寸标签“SSCB”(空间,空间,通道,批次)。

  • 对于GPU训练,转换dlarray到一个gpuArray对象。

  • 方法评估模型梯度dlfeval而且modelGradients功能。

  • 方法更新网络可学习内容和两个网络的平均梯度adamupdate函数。

在每个epoch结束时,将测试集图像通过自编码器,并显示该epoch的损失和训练时间。

epoch = 1:numEpochs tic;i = 1:numIterations迭代=迭代+ 1;idx = (i-1)*miniBatchSize+1:i*miniBatchSize;XBatch = XTrain(:,:,:,idx);XBatch = dlarray(single(XBatch),“SSCB”);如果(executionEnvironment = =“汽车”&& canUseGPU) || executionEnvironment ==“图形”XBatch = gpuArray(XBatch);结束[infGrad, genGrad] = dlfeval(...@modelGradients, encoderNet, decoderNet, XBatch);[decoderNet。可学的,avgGradientsDecoder, avgGradientsSquaredDecoder] =...adamupdate (decoderNet。可学的,...genGrad, avgGradientsDecoder, avgGradientsSquaredDecoder,迭代,lr);[encoderNet。可学的,avgGradientsEncoder, avgGradientsSquaredEncoder] =...adamupdate (encoderNet。可学的,...infGrad, avgGradientsEncoder, avgGradientsSquaredEncoder,迭代,lr);结束elapsedTime = toc;[z, zMean, zLogvar] = sampling(encoderNet, XTest);xPred = sigmoid(forward(decoderNet, z));elbo = ELBOloss(XTest, xPred, zMean, zLogvar);disp (“时代:”+时代+"测试ELBO损失= "+收集(extractdata (elbo)) +...”。epoch所用时间= "elapsedTime +“s”结束
Epoch: 1 Test ELBO loss = 28.0145。epoch: 2测试ELBO损失= 24.8995。epoch: 3测试ELBO损失= 23.2756。epoch: 4测试ELBO损失= 21.151。epoch: 5测试ELBO损失= 20.5335。epoch: 6测试ELBO损失= 20.232。epoch: 7测试ELBO损失= 19.9988。epoch: 8测试ELBO损失= 19.8955。epoch: 9测试ELBO损失= 19.7991。epoch: 10测试ELBO损失= 19.6773。 Time taken for epoch = 8.4269s Epoch : 11 Test ELBO loss = 19.5181. Time taken for epoch = 8.5771s Epoch : 12 Test ELBO loss = 19.4532. Time taken for epoch = 8.4227s Epoch : 13 Test ELBO loss = 19.3771. Time taken for epoch = 8.5807s Epoch : 14 Test ELBO loss = 19.2893. Time taken for epoch = 8.574s Epoch : 15 Test ELBO loss = 19.1641. Time taken for epoch = 8.6434s Epoch : 16 Test ELBO loss = 19.2175. Time taken for epoch = 8.8641s Epoch : 17 Test ELBO loss = 19.158. Time taken for epoch = 9.1083s Epoch : 18 Test ELBO loss = 19.085. Time taken for epoch = 8.6674s Epoch : 19 Test ELBO loss = 19.1169. Time taken for epoch = 8.6357s Epoch : 20 Test ELBO loss = 19.0791. Time taken for epoch = 8.5512s Epoch : 21 Test ELBO loss = 19.0395. Time taken for epoch = 8.4674s Epoch : 22 Test ELBO loss = 18.9556. Time taken for epoch = 8.3943s Epoch : 23 Test ELBO loss = 18.9469. Time taken for epoch = 10.2924s Epoch : 24 Test ELBO loss = 18.924. Time taken for epoch = 9.8302s Epoch : 25 Test ELBO loss = 18.9124. Time taken for epoch = 9.9603s Epoch : 26 Test ELBO loss = 18.9595. Time taken for epoch = 10.9887s Epoch : 27 Test ELBO loss = 18.9256. Time taken for epoch = 10.1402s Epoch : 28 Test ELBO loss = 18.8708. Time taken for epoch = 9.9109s Epoch : 29 Test ELBO loss = 18.8602. Time taken for epoch = 10.3075s Epoch : 30 Test ELBO loss = 18.8563. Time taken for epoch = 10.474s Epoch : 31 Test ELBO loss = 18.8127. Time taken for epoch = 9.8779s Epoch : 32 Test ELBO loss = 18.7989. Time taken for epoch = 9.6963s Epoch : 33 Test ELBO loss = 18.8. Time taken for epoch = 9.8848s Epoch : 34 Test ELBO loss = 18.8095. Time taken for epoch = 10.3168s Epoch : 35 Test ELBO loss = 18.7601. Time taken for epoch = 10.8058s Epoch : 36 Test ELBO loss = 18.7469. Time taken for epoch = 9.9365s Epoch : 37 Test ELBO loss = 18.7049. Time taken for epoch = 10.0343s Epoch : 38 Test ELBO loss = 18.7084. Time taken for epoch = 10.3214s Epoch : 39 Test ELBO loss = 18.6858. Time taken for epoch = 10.3985s Epoch : 40 Test ELBO loss = 18.7284. Time taken for epoch = 10.9685s Epoch : 41 Test ELBO loss = 18.6574. Time taken for epoch = 10.5241s Epoch : 42 Test ELBO loss = 18.6388. Time taken for epoch = 10.2392s Epoch : 43 Test ELBO loss = 18.7133. Time taken for epoch = 9.8177s Epoch : 44 Test ELBO loss = 18.6846. Time taken for epoch = 9.6858s Epoch : 45 Test ELBO loss = 18.6001. Time taken for epoch = 9.5588s Epoch : 46 Test ELBO loss = 18.5897. Time taken for epoch = 10.4554s Epoch : 47 Test ELBO loss = 18.6184. Time taken for epoch = 10.0317s Epoch : 48 Test ELBO loss = 18.6389. Time taken for epoch = 10.311s Epoch : 49 Test ELBO loss = 18.5918. Time taken for epoch = 10.4506s Epoch : 50 Test ELBO loss = 18.5081. Time taken for epoch = 9.9671s

可视化的结果

要可视化和解释结果,请使用helper可视化功能.这些helper函数在本例的最后定义。

VisualizeReconstruction函数显示从每个类中随机选择的数字,并伴随着它经过自编码器后的重建。

VisualizeLatentSpace函数取测试图像通过编码器网络后生成的均值和方差编码(每个维度为20),并对包含每个图像编码的矩阵进行主成分分析(PCA)。然后,您可以可视化由两个第一主成分表征的两个维度的均值和方差定义的潜在空间。

生成函数初始化从正态分布采样的新编码,并输出当这些编码通过解码器网络时生成的图像。

visualizerconstruct (XTest, YTest, encoderNet, decoderNet)

visualizeLatentSpace(XTest, YTest, encoderNet)

生成(decoderNet latentDim)

下一个步骤

变分自动编码器只是用于执行生成任务的许多可用模型之一。它们在图像较小且具有明确定义特征的数据集(如MNIST)上工作良好。对于具有较大图像的更复杂的数据集,生成对抗网络(GANs)往往表现更好,生成的图像噪音更少。有关显示如何实现GANs以生成64 × 64 RGB图像的示例,请参见训练生成对抗网络(GAN)

参考文献

  1. LeCun, Y., C. Cortes,和C. J. C. Burges。“MNIST手写数字数据库。”http://yann.lecun.com/exdb/mnist/

辅助函数

模型梯度函数

modelGradients函数接受编码器和解码器dlnetwork对象和一小批输入数据X,并返回损失相对于网络中可学习参数的梯度。该函数执行三个操作:

  1. 方法获取编码抽样在通过编码器网络的小批图像上执行函数。

  2. 通过将编码通过解码器网络传递并调用ELBOloss函数。

  3. 计算损失的梯度相对于两个网络的可学习参数调用dlgradient函数。

函数[infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x) [z, zMean, zLogvar] = sampling(encoderNet, x);xPred = sigmoid(forward(decoderNet, z));损失= ELBOloss(x, xPred, zMean, zLogvar);[genGrad, infGrad] = dlgradient(loss, decoderNet. net .)可学的,...encoderNet.Learnables);结束

采样和损失函数

抽样函数从输入图像中获取编码。最初,它通过编码器网络传递一小批图像,并分割大小的输出(2 * latentDim) * miniBatchSize变成均值矩阵和方差矩阵,每个矩阵都有大小latentDim * batchSize.然后,它使用这些矩阵来实现重新参数化技巧并计算编码。最后,它将此编码转换为dlarray对象的SSCB格式。

函数[zsampling, zMean, zLogvar] = sampling(encoderNet, x) compressed = forward(encoderNet, x);D = size(压缩后,1)/2;zMean =压缩(1:d,:);zLogvar = compressed(1+d:end,:);sz = size(zMean);= randn(sz);σ = exp(。5* zLogvar); z = epsilon .* sigma + zMean; z = reshape(z, [1,1,sz]); zSampled = dlarray(z,“SSCB”);结束

ELBOloss函数返回的均值和方差的编码抽样函数,并用它们来计算ELBO损失。

函数elbo = ELBOloss(x, xPred, zMean, zLogvar) squares = 0.5*(xPred-x).^2;reconstructionLoss = sum(平方,[1,2,3]);Kl = -。5* sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1); elbo = mean(reconstructionLoss + KL);结束

可视化功能

VisualizeReconstruction函数为MNIST数据集的每个数字随机选择两张图像,将它们通过VAE,并与原始输入并排绘制重建图。属性中包含的信息dlarray对象,则需要首先使用extractdata而且收集功能。

函数visualizerconstruct (XTest,YTest, encoderNet, decoderNet) f = figure;图(f)标题(“地面真实图像与重建图像的示例”I = 1:2c=0:9 idx = iRandomIdxOfClass(YTest,c);X = XTest(:,:,:,idx);[z, ~, ~] = sampling(encoderNet, X);XPred = sigmoid(forward(decoderNet, z));X = gather(extractdata(X));XPred = gather(extractdata(XPred));比较= [X, ones(size(X,1),1), XPred];次要情节(4、5、(张)* 10 + c + 1), imshow(比较,[]),结束结束结束函数idx = iRandomIdxOfClass(T,c) idx = T == categorical(c);Idx = find(Idx);Idx = Idx (randi(numel(Idx),1));结束

VisualizeLatentSpace函数可视化了由构成编码器网络输出的均值和方差矩阵定义的潜在空间,并定位由每个数字的潜在空间表示形成的簇。

函数首先从函数中提取均值和方差矩阵dlarray对象。因为用通道/批处理维度(C和B)转置矩阵是不可能的,所以函数调用stripdims在矩阵转置之前。然后,对两个矩阵进行主成分分析。为了在二维空间中可视化潜在空间,该函数保留了前两个主分量,并将它们相互绘制。最后,该函数将数字类着色,以便您可以观察聚类。

函数visualizeLatentSpace(XTest, YTest, encoderNet) [~, zMean, zLogvar] = sampling(encoderNet, XTest);zMean = stripdims(zMean)';zMean = gather(extractdata(zMean));zLogvar = stripdims(zLogvar)';zLogvar = gather(extractdata(zLogvar));[~,scoreMean] = pca(zMean);[~,scoreLogvar] = pca(zLogvar);C = parula(10);F1 =数字;图(f1)标题(“潜在的空间”) ah = subplot(1,2,1);散射(scoreMean (:, 2), scoreMean (: 1), [], c(双重(欧美):));啊。YDir =“反向”;轴平等的包含(“Z_m_u(2)”) ylabel (“Z_m_u(1)”) cb = colorbar;cb。蜱= 0:(1/9):1;cb。TickLabels = string(0:9);Ah = subplot(1,2,2);散射(scoreLogvar (:, 2), scoreLogvar (: 1), [], c(双重(欧美):));啊。YDir =“反向”;包含(“Z_v_a_r(2)”) ylabel (“Z_v_a_r(1)”) cb = colorbar;cb。蜱= 0:(1/9):1;cb。TickLabels = string(0:9);轴平等的结束

生成功能测试了VAE的生成能力。它初始化一个dlarray对象,其中包含25个随机生成的编码,将它们通过解码器网络传递,并绘制输出。

函数生成(decoderNet, latentDim) randomNoise = dlarray(randn(1,1,latentDim,25),“SSCB”);generatedImage = sigmoid(predict(decoderNet, randomNoise));generatedImage = extractdata(generatedImage);F3 =图;图(f3) imshow (imtile (generatedImage“ThumbnailSize”,[100,100]))标题(“生成的数字样本”) drawnow结束

另请参阅

||||||

相关的话题