这个示例演示如何为运行在Intel®cpu上的训练有素的变分自动编码器(VAE)网络生成MEX函数。这个例子说明了:
这个例子使用了一个预先训练的解码器网络训练变分自动编码器(VAE)生成图像示例来自深度学习工具箱™。有关更多信息,请参见训练变分自动编码器(VAE)生成图像(深度学习工具箱).
支持英特尔高级矢量扩展2(英特尔AVX2)金宝app指令的英特尔处理器。
对于非mex构建,如静态、动态库或可执行文件,本例有以下附加要求。
用于深度神经网络的英特尔数学内核库(MKL-DNN)
有关编译器和库支持的版本的信息,请参见金宝app深度学习与MATLAB编码器的先决条件
这个例子不支持在MATLAB®联机。金宝app
自动编码器有两部分:编码器和解码器。编码器接收图像输入并输出压缩表示(编码),这是一个大小向量latent_dim
,在本例中等于20。解码器获取压缩的表示,解码它,并重新生成原始图像。
vie与常规自动编码器的不同之处在于,它们不使用编码-解码过程来重构输入。相反,他们在潜在空间上施加一个概率分布,并学习这个分布,以便解码器输出的分布与观测数据的分布相匹配。然后,他们从这个分布中抽取样本,生成新的数据。
这个例子使用在训练变分自动编码器(VAE)生成图像的例子。要自己训练网络,看训练变分自动编码器(VAE)生成图像(深度学习工具箱).
的generateVAE
入口点函数加载dlnetwork
将trainedDecoderVAENet mat -文件中的对象转换为持久变量,并在后续的预测调用中重用该持久对象。它初始化一个dlarray
对象包含25个随机生成的编码,将它们通过解码器网络,并从深度学习数组对象中提取生成图像的数值数据。
类型(“generateVAE.m”)
function generatedImage = generateVAE(decoderNetFileName,latentDim,Environment) %#codegen % Copyright 2020-2021 The MathWorks, Inc. persistent decoderNet;如果是空的(decoderNet) decoderNet = coder.loadDeepLearningNetwork(decoderNetFileName);end %生成随机噪声randomNoise = dlarray(randn(1,1,latentDim,25,'single'),'SSCB');如果code .target('MATLAB') && strcmp(Environment,'gpu') randomNoise = gpuArray(randomNoise);end %从噪声生成新图像generatedImage = sigmoid(predict(decoderNet,randomNoise));%从dlarray中提取数值数据generatedImage = extractdata(generatedImage);结束
评估generateVAE
入口点函数生成数字图像并绘制结果。
latentDim = 20;matfile =“trainedDecoderVAENet.mat”;Env ='';图()标题(数字生成样本- MATLAB) generatedImageML = generateVAE(matfile, latentDim, Env);imshow (imtile (generatedImageML,“ThumbnailSize”, [100100]))
的MEX函数generateVAE
入口点函数,为MEX目标创建代码配置对象,并将目标语言设置为c++。使用编码器。DeepLearningConfig
函数创建一个MKL-DNN深度学习配置对象,并将其分配给DeepLearningConfig
属性。
cfg = coder.config (墨西哥人的);cfg。TargetLang =“c++”;cfg。DeepLearningConfig =编码器。DeepLearningConfig (“mkldnn”);args = {coder.Constant(matfile), coder.Constant(latentDim), coder.Constant(Env)};codegen配置cfgarg游戏arg游戏generateVAE报告
代码生成成功:查看报告
调用生成的MEX并显示结果。
图()标题(“使用MKL-DNN生成数字样本”) generatedImage = generateVAE_mex(matfile, latentDim, Env);imshow (imtile (generatedImage,“ThumbnailSize”, [100100]))
codegen
|coder.config
|编码器。DeepLearningConfig
|dlarray
(深度学习工具箱)|dlnetwork
(深度学习工具箱)