主要内容

使用最大和最小激活图像可视化图像分类

这个例子展示了如何使用数据集来找出是什么激活了深层神经网络的通道。这使您能够理解神经网络是如何工作的,并使用训练数据集诊断潜在问题。

这个例子介绍了许多简单的可视化技术,使用GoogLeNet在食物数据集上学习的转换。通过查看最大限度或最小限度地激活分类器的图像,您可以发现神经网络分类错误的原因。

加载和预处理数据

将图像加载为图像数据存储。这个小型数据集总共包含978个,其中9级食物。

将这些数据分解为训练、验证和测试集,为使用GoogLeNet进行迁移学习做准备。显示从数据集中选择的图像。

rng默认的datadir = fullfile(tempdir,“食品数据集”);URL =.“//www.tatmou.com/金宝appsupportfiles/nnet/data/ExampleFoodImageDataset.zip”如果~存在(dataDir“dir”mkdir (dataDir);结尾downloadExampleFoodImagesData (url, dataDir);
下载MathWorks示例食品图像数据集......这可能需要几分钟的时间下载......下载完成...解压缩文件...解压缩完成......完成。
imd = imageDatastore (dataDir,...“IncludeSubfolders”,真的,“LabelSource”“foldernames”);[imdsTrain, imdsValidation imdsTest] = splitEachLabel (imd, 0.6, 0.2);rnd = randperm(元素个数(imds.Files), 9);为了i = 1:numel(rnd) subplot(3,3,i) imshow(imread(imds.Files{rnd(i)})) label = imds.Labels(rnd(i));标题(标签,“翻译”“没有”结尾

食品图像分类训练网络

使用预先训练过的GoogLeNet网络,并再次训练它来分类9种食物。如果你没有深度学习工具箱™模型对于Googlenet网络金宝app安装支持包,软件提供了一个下载链接。

为了尝试一个不同的预训练网络,在MATLAB®中打开这个例子,选择一个不同的网络,例如squeezenet,一个甚至更快的网络googlenet.有关所有可用网络的列表,请参阅预先训练的深度神经网络

网= googlenet;

第一个元素网络的属性是图像输入层。该层需要输入尺寸224-×224-3的输入图像,其中3是颜色信道的数量。

InputSize = Net.Layers(1).InputSize;

网络架构

网络提取图像的卷积层的图像特征是最后一种学习层和最终分类层用于对输入图像进行分类。这两层,“loss3-classifier”'输出'在GoogLeNet中,包含关于如何将网络提取的特征组合成类别概率、损失值和预测标签的信息。为了训练一个预先训练的网络来分类新的图像,将这两层替换为适应新数据集的新层。

从训练的网络中提取层图。

lgraph = layerGraph(净);

在大多数网络中,具有可学习权值的最后一层是完全连接层。将这个完全连接层替换为一个新的完全连接层,其输出数量等于新数据集中的类数量(在本例中为9)。

numClasses =元素个数(类别(imdsTrain.Labels));newfclayer = fullyConnectedLayer (numClasses,...'姓名''new_fc'...“WeightLearnRateFactor”10...“BiasLearnRateFactor”10);lgraph = replaceLayer (lgraph net.Layers (end-2) . name, newfclayer);

分类层指定网络的输出类别。将分类层替换为一个没有类标签的新层。trainNetwork在训练时自动设置层的输出类。

newclasslayer = classificationLayer ('姓名'“new_classoutput”);lgraph = replaceLayer (lgraph net.Layers(结束). name, newclasslayer);

火车网络

网络需要大小为224 × 224 × 3的输入图像,但图像数据存储中的图像大小不同。使用扩充图像数据存储来自动调整训练图像的大小。指定要对训练图像执行的附加增强操作:沿着垂直轴随机翻转训练图像,随机将其平移至30像素,并将其水平和垂直缩放至10%。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。

PIXELRANGE = [-30 30];scalerange = [0.9 1.1];imageaugmenter = imagedataAugmenter(...“RandXReflection”,真的,...“RandXTranslation”pixelRange,...“RandYTranslation”pixelRange,...'randxscale'scaleRange,...“RandYScale”, scaleRange);augimdsTrain = augmentedImageDatastore (inputSize (1:2), imdsTrain,...'dataaugmentation',imageaugmender);

若要自动调整验证图像的大小而不执行进一步的数据扩展,请使用扩展后的图像数据存储,而不指定任何额外的预处理操作。

augimdsValidation = augmentedImageDatastore (inputSize (1:2), imdsValidation);

指定培训选项。集InitialLearnRate降低到一个很小的值,以减缓尚未冻结的转移层的学习。在前面的步骤中,您增加了最后一个可学习层的学习率因子,以加快新的最后一层的学习率。这种学习速率设置的组合导致在新层中学习快,在中间层中学习慢,而在较早的冻结层中没有学习。

指定训练的时代数量。在进行转移学习时,您不需要为尽可能多的时期训练。epoch是整个培训数据集的完整培训周期。指定迷你批量大小和验证数据。每个时代计算一次验证精度。

miniBatchSize = 10;valFrequency =地板(元素个数(augimdsTrain.Files) / miniBatchSize);选择= trainingOptions (“个”...“MiniBatchSize”miniBatchSize,...“MaxEpochs”,4,...“InitialLearnRate”3的军医,...“洗牌”“every-epoch”...“ValidationData”augimdsValidation,...“ValidationFrequency”valFrequency,...“详细”假的,...'plots'“训练进步”);

使用训练数据对网络进行训练。默认情况下,trainNetwork如果GPU可用,则使用GPU。这需要并行计算工具箱™和支持的GPU设备。金宝app有关支持的设备的信息,请参见金宝appGPU支金宝app持情况(并行计算工具箱).否则,trainNetwork使用一个CPU。属性也可以指定执行环境'executionenvironment'的名称-值对参数trainingOptions.由于这个数据集很小,训练速度很快。如果您运行这个示例并亲自训练网络,您将得到不同的结果和由于训练过程中涉及的随机性造成的错误分类。

网= trainNetwork (augimdsTrain、lgraph选项);

测试图像进行分类

利用微调网络对测试图像进行分类,并计算分类精度。

augimdsTest = augmentedImageDatastore (inputSize (1:2), imdsTest);[predictedClasses, predictedScores] =(网络,augimdsTest)进行分类;(predictedClasses == imdsTest.Labels)
精度= 0.8418

测试集的混淆矩阵

绘制测试集预测的混淆矩阵。这突出显示了哪些特定的类对网络造成了大多数问题。

图;confusionchart (imdsTest。标签,predictedClasses,“归一化”“row-normalized”);

混淆矩阵表明,该网络对某些类的性能较差,如希腊沙拉、生鱼片、热狗和寿司。这些类在数据集中表示不足,可能会影响网络性能。调查其中的一个类,以更好地理解为什么网络正在挣扎。

图();直方图(imdsValidation.Labels);甘氨胆酸ax = ();ax.XAxis.TickLabelInterpreter =“没有”

调查分类

调查寿司类的网络分类。

寿司最像寿司

首先,找出哪些寿司图片最强烈地激活了寿司课的网络。这就回答了“该网站认为哪些图片最像寿司?”

绘制最大激活的图像,这些是强烈激活全连接层的“寿司”神经元的输入图像。这张图显示了排名靠前的4张图片,以降序排列。

chosenClass =“寿司”;classIdx =找到(net.Layers(结束)。类= = chosenClass);numImgsToShow = 4;[sortedScores, imgIdx] = findMaxActivatingImages (imdsTest、chosenClass predictedScores, numImgsToShow);图plotImages (imdsTest imgIdx、sortedScores predictedClasses, numImgsToShow)

想象寿司课的线索

电视台对寿司的定位正确吗?网络上最活跃的寿司类图像看起来都很相似——许多圆形的图形紧密地聚集在一起。

网络在分类这些种类的寿司时表现得很好。但是,要验证这是真的并且更好地理解为什么网络使其决定,请使用像Grad-Cam等可视化技术。有关使用Grad-Cam的更多信息,请参阅Grad-Cam揭示为什么深入学习决策

从增强的图像数据存储中读取第一个调整大小的图像,然后使用gradCAM

Imagenumber = 1;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);sgtitle(字符串(标签)+”(分数:+马克斯(分数)+“)”

Grad-CAM地图确认了网络聚焦在图像中的寿司上。不过,你也可以看到,网络是看着的部分板和表。

第二幅图的左边是一簇寿司,右边是一个单独的寿司。要查看网络的焦点,请阅读第二幅图像并绘制grado - cam。

imageNumber = 2;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图plotGradCAM (img, gradcamMapα);sgtitle(字符串(标签)+”(分数:+马克斯(分数)+“)”

该网站将这张图片归类为寿司,因为它看到了一组寿司。然而,它是否能够单独对一种寿司进行分类呢?看一张寿司的图片来测试这一点。

img = imread (strcat (tempdir,“食品数据集/寿司/ sushi_18.jpg”));img = imresize(img,net.layers(1).inputsize(1:2),“方法”“双线性”“抗锯齿”,真正的);(标签,分数)=(净,img)进行分类;gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);sgtitle(字符串(标签)+”(分数:+马克斯(分数)+“)”

该网络能够正确地对这一种寿司进行分类。然而,GradCAM显示出网络集中在寿司顶部和黄瓜簇上,而不是整块。

在一个没有堆叠任何小块配料的寿司上运行grado - cam可视化技术。

img = imread (“crop__sushi34-copy.jpg”);img = imresize(img,net.layers(1).inputsize(1:2),“方法”“双线性”“抗锯齿”,真正的);(标签,分数)=(净,img)进行分类;gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(字符串(标签)+”(分数:+马克斯(分数)+“)”

在本例中,可视化技术强调了网络性能差的原因。它错误地将寿司的图像分类为汉堡包。

为了解决这个问题,您必须在训练过程中向网络提供更多的单个寿司的图像。

最不像寿司的寿司

现在发现寿司的哪些图像激活了寿司级别的网络。这回答了“网络思维较少寿司的图像”的问题。

这是有用的,因为它可以找到网络性能较差的图像,并提供一些对其决策的了解。

chosenClass =“寿司”;numImgsToShow = 9;[sortedScores, imgIdx] = findMinActivatingImages (imdsTest、chosenClass predictedScores, numImgsToShow);图plotImages (imdsTest imgIdx、sortedScores predictedClasses, numImgsToShow)

调查寿司被错误分类为生鱼片

为什么这个网络把寿司归类为生鱼片?该网站将9张图片中的3张分类为生鱼片。其中一些图像,例如图像4和9,实际上包含了生鱼片,这意味着网络实际上并没有错误地对它们进行分类。这些图片被标错了。

要查看网络的焦点是什么,在其中一张图像上运行grado - cam技术。

imageNumber = 4;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(字符串(标签)+(寿司评分:)+马克斯(分数)+“)”

不出所料,该网站关注的是生鱼片而不是寿司。

调查寿司被错误分类为披萨

为什么这个网络把寿司归类为披萨?该网站将其中四张图片分类为披萨而不是寿司。以图像1为例,该图像有一个彩色的顶点,这可能会混淆网络。

要查看网络正在查看图像的哪一部分,可以在其中一张图像上运行grado - cam技术。

Imagenumber = 1;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(字符串(标签)+(寿司评分:)+马克斯(分数)+“)”

这家电视台非常关注浇头。为了帮助网络区分披萨和带配料的寿司,添加更多带配料的寿司训练图像。网络也聚焦于板块。这可能是因为网络已经学会了将特定的食物与特定类型的盘子联系起来,在看寿司图片时也突出显示了这一点。为了提高网络的表现,训练时使用更多不同类型盘子上的食物例子。

调查寿司被误分类为汉堡

为什么这个网站把寿司归类为汉堡?要查看网络的焦点是什么,对分类错误的图像运行grado - cam技术。

imageNumber = 2;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(字符串(标签)+(寿司评分:)+马克斯(分数)+“)”

网络主要集中在图像中的花。五颜六色的紫色花和棕色茎迷住了网络将此图像识别为汉堡包。

调查寿司被误归类为炸薯条

为什么这个网站把寿司归类为炸薯条?该电视台将第三张图片分类为炸薯条而不是寿司。这种寿司的顶部是黄色的,网络可能会把这种颜色和薯条联系起来。

在这个图像上运行grado - cam。

imageNumber = 3;观察= augimdsTest.readByIndex (imgIdx (imageNumber));img = observation.input {1};标签= predictedClasses (imgIdx (imageNumber));分数= sortedScores (imageNumber);gradcamMap = gradCAM(净、img标签);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(字符串(标签)+(寿司评分:)+马克斯(分数)+“)”“翻译”“没有”

网络把黄色寿司归类为炸薯条。和汉堡一样,这种不同寻常的颜色也导致该网站对寿司进行了错误的分类。

为了帮助网络在这个特定的情况下,用更多黄色食物的图像训练它,而不是炸薯条。

结论

调查产生大或小类分数的数据点,以及网络自信地分类但错误的数据点,是一项简单的技术,可以提供有用的见解,了解训练有素的网络是如何运作的。在食物数据集的例子中,这个例子强调了:

  • 测试数据包含几个具有错误真实标签的图像,例如实际上是“Sushi”的“生鱼片”。数据还包含不完整的标签,例如包含Sushi和Sashimi的图像。

  • 该网站认为“寿司”是“多个、集群、圆形的东西”。然而,它也必须能够区分一个单独的寿司。

  • 任何带有浇头或不寻常颜色的寿司或生鱼片都会让网络感到困惑。为了解决这个问题,数据必须有更多种类的寿司和生鱼片。

  • 为了提高性能,网络需要看到更多来自未充分表示的类的图像。

辅助函数

函数dataDir downloadExampleFoodImagesData (url)%下载Example Food Image数据集,包含978张图片不同种类的食物分成9类。版权所有2019 The MathWorks, Inc.文件名=“ExampleFoodImageDataset.zip”;fileFullPath = fullfile (dataDir,文件名);%下载。zip文件到临时目录。如果~存在(fileFullPath“文件”)流("下载MathWorks示例食物图像数据集…\n");fprintf(“这可能需要几分钟下载…\n”);websave (fileFullPath、url);fprintf(“下载完成…\ n”);其他的fprintf("跳过下载,文件已经存在…\n");结尾%解压缩文件。通过检查文件是否存在来检查文件是否已经解压缩其中一个类目录的百分比。examplefolderfullpath = fullfile(datadir,“比萨”);如果~存在(exampleFolderFullPath“dir”)流(“将文件解压缩…\ n”);解压缩(fileFullPath dataDir);fprintf(“解完成…\ n”);其他的fprintf("跳过解压缩,文件已经解压缩…\n");结尾fprintf(“完成。\ n”);结尾函数[SortedScores,IMGIDX] = FindMaxActivationImages(IMDS,ClassName,PrediceScores,NumimgstOshow)%在所选班级的所有图像上找到所选班级的预测分数%(例如,寿司在所有寿司图片上的预测得分)[scoresforchosenclass,imgsofclassidxs] = findscoresforchosenclass(IMDS,ClassName,PredicteScores);%按降序排列分数[SortedScores,IDX] = sort(scoresforchosenclass,“下”);%返回只有前几个的指数imgIdx = imgsOfClassIdxs (idx (1: numImgsToShow));结尾函数[SortedScores,IMGIDX] = FindMinActivationImages(IMDS,ClassName,PregeteScores,NumimGstoshow)%在所选班级的所有图像上找到所选班级的预测分数%(例如,寿司在所有寿司图片上的预测得分)[scoresforchosenclass,imgsofclassidxs] = findscoresforchosenclass(IMDS,ClassName,PredicteScores);%将分数按升序排序[SortedScores,IDX] = sort(scoresforchosenclass,“提升”);%返回只有前几个的指数imgIdx = imgsOfClassIdxs (idx (1: numImgsToShow));结尾函数[scoresForChosenClass, imgsOfClassIdxs] = findScoresForChosenClass (imd,类名,predictedScores)查找className的索引(例如:“寿司”是第9课)uniqueclasses =唯一(IMDS.Labels);chosenclassidx = string(uniqueclasses)== classname;%在imageDatastore中查找标签为"className"的图像的索引%(例如找到所有寿司类图片)imgsOfClassIdxs =找到(imd)。标签= =类名);在所有的图像上找到所选班级的预测分数%选择类%(例如,寿司在所有寿司图片上的预测得分)scoresForChosenClass = predictedScores (imgsOfClassIdxs chosenClassIdx);结尾函数plotImages (imd, imgIdx sortedScores、predictedClasses numImgsToShow)为了i=1:numImgsToShow score = sortedScores(i);sortedImgIdx = imgIdx(我);predClass = predictedClasses (sortedImgIdx);correctClass = imds.Labels (sortedImgIdx);imgPath = imds.Files {sortedImgIdx};如果predClass == correctClass color =“\ color {green}”其他的颜色={红}\颜色”结尾predclasstitle = strrep(string(predclass),“_”' ');correctClassTitle = strrep (string (correctClass),“_”' ');子图(3,CEIL(Numimgstoshow./3),i)imshow(imread(ImgPath));标题(预测:“+ color + predclasstitle +”{黑}\换行符\颜色得分:“+ + num2str(得分)“\ newlineGround真理:“+ correctClassTitle);结尾结尾函数Plotgradcam(IMG,Gradcammap,Alpha)子图(1,2,1)Imshow(IMG);h =子图(1,2,2);imshow(img)持有;显示亮度图像(gradcamMap“AlphaData”、α);originalSize2 =得到(h,“位置”);colormap飞机colorbar集(h,“位置”,Originalsize2);抓住结尾

也可以看看

||||||||

相关例子

更多关于