主要内容

使用最大和最小激活图像可视化图像分类

这个例子展示了如何使用数据集来找出是什么激活了深度神经网络的通道。这可以让你理解神经网络是如何工作的,并诊断训练数据集的潜在问题。

这个例子涵盖了一些简单的可视化技术,使用GoogLeNet转移学习食品数据集。通过查看最大限度或最小限度激活分类器的图像,您可以发现神经网络分类错误的原因。

加载和预处理数据

将映像加载为映像数据存储。这个小数据集总共包含9类食物的978个观察结果。

将这些数据分成训练集、验证集和测试集,为使用GoogLeNet进行迁移学习做准备。显示数据集中选定的图像。

rng默认的dataDir = fullfile(tempdir,“食品数据集”);url =“//www.tatmou.com/金宝appsupportfiles/nnet/data/ExampleFoodImageDataset.zip”如果~存在(dataDir“dir”mkdir (dataDir);结束downloadExampleFoodImagesData (url, dataDir);
下载MathWorks示例食物图像数据集…这可能需要几分钟的时间来下载…下载完成了…将文件解压缩……解压缩完成……完成了。
imds = imageDatastore(dataDir,...“IncludeSubfolders”,真的,“LabelSource”“foldernames”);[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2);rnd = randperm(numel(imds.Files),9);i = 1: nummel (rnd) subplot(3,3,i) imshow(imread(imds.Files{rnd(i)}) label = imds.Labels(rnd(i));标题(标签,“翻译”“没有”结束

训练网络分类食品图像

使用预先训练的GoogLeNet网络,再次训练它对9种食物进行分类。如果您没有深度学习工具箱™模型为GoogLeNet网络金宝app支持包安装后,再提供软件下载链接。

要尝试不同的预训练网络,请在MATLAB®中打开此示例并选择不同的网络,例如squeezenet这个网络甚至比googlenet.有关所有可用网络的列表,请参见预训练的深度神经网络

Net = googlenet;

元素的第一个元素网络的属性是图像输入层。该层需要输入大小为224 × 224 × 3的图像,其中3是彩色通道的数量。

inputSize = net.Layers(1).InputSize;

网络体系结构

网络的卷积层提取图像特征,最后的可学习层和最终的分类层使用这些特征对输入图像进行分类。这两层,“loss3-classifier”而且“输出”在GoogLeNet中,包含关于如何将网络提取的特征组合成类概率、损失值和预测标签的信息。为了训练一个预先训练好的网络来分类新图像,用适应新数据集的新层替换这两个层。

从训练好的网络中提取层图。

lgraph = layerGraph(net);

在大多数网络中,具有可学习权重的最后一层是全连接层。将这个全连接层替换为一个新的全连接层,输出的数量等于新数据集中的类的数量(在本例中为9)。

numClasses = numel(categories(imdsTrain.Labels));newfclayer = fullyConnectedLayer(numClasses,...“名字”“new_fc”...“WeightLearnRateFactor”10...“BiasLearnRateFactor”10);lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);

分类层指定网络的输出类。用一个没有类标签的新层替换分类层。trainNetwork在训练时自动设置层的输出类。

newclasslayer = classificationLayer(“名字”“new_classoutput”);lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);

列车网络的

网络需要大小为224 × 224 × 3的输入图像,但是图像数据存储中的图像大小不同。使用增强图像数据存储来自动调整训练图像的大小。指定要对训练图像执行的附加增强操作:沿着垂直轴随机翻转训练图像,随机将其平移到30像素,并将其缩放到水平和垂直的10%。数据增强有助于防止网络过度拟合和记忆训练图像的确切细节。

pixelRange = [-30 30];scaleRange = [0.9 1.1];imageAugmenter = imageDataAugmenter(...“RandXReflection”,真的,...“RandXTranslation”pixelRange,...“RandYTranslation”pixelRange,...“RandXScale”scaleRange,...“RandYScale”, scaleRange);augimdsTrain = augmentedimagedastore (inputSize(1:2)),imdsTrain,...“DataAugmentation”, imageAugmenter);

若要自动调整验证图像的大小,而不执行进一步的数据增强,请使用增强图像数据存储,而不指定任何额外的预处理操作。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

指定培训选项。集InitialLearnRate到一个小的值来减慢学习在转移层,而不是已经冻结。在前面的步骤中,您增加了最后一个可学习层的学习率因子,以加快新的最终层的学习速度。这种学习率设置的组合导致在新层中快速学习,在中间层中学习较慢,而在较早的冻结层中没有学习。

指定要训练的epoch数。在执行迁移学习时,您不需要训练许多epoch。epoch是整个训练数据集上的一个完整的训练周期。指定小批大小和验证数据。每个epoch计算一次验证精度。

miniBatchSize = 10;valFrequency = floor(nummel (augimdsTrain.Files)/miniBatchSize);选项= trainingOptions(“个”...“MiniBatchSize”miniBatchSize,...“MaxEpochs”4...“InitialLearnRate”3的军医,...“洗牌”“every-epoch”...“ValidationData”augimdsValidation,...“ValidationFrequency”valFrequency,...“详细”假的,...“阴谋”“训练进步”);

使用训练数据训练网络。默认情况下,trainNetwork使用GPU(如果有的话)。这需要并行计算工具箱™和受支持的GPU设备。金宝app有关受支持设备的信息,请参见金宝appGPU计算要求(并行计算工具箱).否则,trainNetwork使用CPU。属性指定执行环境“ExecutionEnvironment”的名称-值对参数trainingOptions.因为这个数据集很小,所以训练速度很快。如果你运行这个例子并自己训练网络,你会得到不同的结果和错误分类,这是由训练过程中涉及的随机性引起的。

net = trainNetwork(augimdsTrain,lgraph,options);

分类测试图像

利用微调网络对测试图像进行分类,并计算分类精度。

augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);[predictedClasses,predictedScores] = category (net,augimdsTest);accuracy = mean(predictedClasses == imdsTest.Labels)
准确度= 0.8418

测试集的混淆矩阵

绘制测试集预测的混淆矩阵。这突出了哪些特定的类会给网络带来最多的问题。

图;confusionchart (imdsTest。标签,predictedClasses,“归一化”“row-normalized”);

混淆矩阵表明,该网络对于某些类别的性能较差,如希腊沙拉、生鱼片、热狗和寿司。这些类在数据集中表现不足,这可能会影响网络性能。研究这些类中的一个,以更好地理解为什么网络正在挣扎。

图();直方图(imdsValidation.Labels);Ax = gca();ax.XAxis.TickLabelInterpreter =“没有”

调查的分类

调查寿司类的网络分类。

最喜欢寿司的寿司

首先,找出哪些寿司图片能最强烈地激活寿司课的网络。这回答了“网络认为哪些图像最像寿司?”的问题。

绘制最大激活的图像,这些是强烈激活全连接层的“寿司”神经元的输入图像。该图显示了排名靠前的4张图片,按降序排列。

chosenClass =“寿司”;classsidx = find(net.Layers(end)。Classes == chosenClass);numImgsToShow = 4;[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);图plotImages (imdsTest imgIdx、sortedScores predictedClasses, numImgsToShow)

为寿司课设想线索

电视台对寿司的定位正确吗?网络上最活跃的寿司类图片看起来都很相似——许多圆形的物体紧密地聚集在一起。

该网络在分类这些寿司方面做得很好。然而,为了验证这是真的,并更好地理解为什么网络做出它的决定,使用可视化技术,如Grad-CAM。有关使用Grad-CAM的更多信息,请参见Grad-CAM揭示深度学习决策背后的原因

从增强图像数据存储中读取第一个调整大小的图像,然后使用绘制Grad-CAM可视化gradCAM

imageNumber = 1;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);(标签)+ sgtitle(字符串(分数:"+马克斯(分数)+“)”

Grad-CAM地图证实了网络聚焦在图像中的寿司上。但是你也可以看到网络正在观察盘子和桌子的一部分。

第二张图左边是一堆寿司,右边是一个寿司。要了解网络的重点,请阅读第二张图像并绘制Grad-CAM。

imageNumber = 2;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图plotGradCAM (img, gradcamMapα);(标签)+ sgtitle(字符串(分数:"+马克斯(分数)+“)”

该网络将这张图片归类为寿司,因为它看到了一组寿司。然而,它能自己对寿司进行分类吗?通过看一张寿司的图片来验证这一点。

Img = imread(strcat(tempdir,“食品数据集/寿司/ sushi_18.jpg”));img = imresize(img,net.Layers(1).InputSize(1:2),“方法”“双线性”“抗锯齿”,真正的);[label,score] = category (net,img);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);(标签)+ sgtitle(字符串(分数:"+马克斯(分数)+“)”

该网络能够正确地对这个寿司进行分类。然而,Grad-CAM显示,神经网络集中在寿司的顶部和黄瓜簇上,而不是整个寿司。

在一个单独的寿司上运行Grad-CAM可视化技术,不包含任何堆叠的小块食材。

Img = imread(“crop__sushi34-copy.jpg”);img = imresize(img,net.Layers(1).InputSize(1:2),“方法”“双线性”“抗锯齿”,真正的);[label,score] = category (net,img);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(分数:"+马克斯(分数)+“)”

在这种情况下,可视化技术突出了为什么网络性能不佳。它错误地将寿司的图片归类为汉堡。

为了解决这个问题,你必须在训练过程中向网络提供更多的孤独寿司的图像。

寿司最不像寿司

现在找出哪些寿司图片对寿司班的网络激活程度最低。这回答了“网络认为哪些图像不那么像寿司?”的问题。

这很有用,因为它可以找到网络表现不佳的图像,并为其决策提供一些见解。

chosenClass =“寿司”;numImgsToShow = 9;[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);图plotImages (imdsTest imgIdx、sortedScores predictedClasses, numImgsToShow)

调查误归为生鱼片的寿司

为什么电视网把寿司归为生鱼片?该网络将9张图片中的3张归类为刺身。其中一些图像,例如图像4和9,实际上包含生鱼片,这意味着网络实际上并没有对它们进行错误分类。这些图片的标签是错误的。

要查看网络关注的是什么,请在其中一个图像上运行Grad-CAM技术。

imageNumber = 4;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:"+马克斯(分数)+“)”

不出所料,网络关注的是生鱼片而不是寿司。

调查寿司被错误归类为披萨

为什么电视网把寿司归为披萨?该网络将其中四张图片分类为披萨而不是寿司。考虑图1,这张图有一个彩色的顶部,这可能会混淆网络。

要查看网络正在查看图像的哪个部分,请在其中一个图像上运行Grad-CAM技术。

imageNumber = 1;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:"+马克斯(分数)+“)”

该网络强烈关注配料。为了帮助网络区分披萨和带配料的寿司,添加更多带配料的寿司训练图像。该网络也关注板块。这可能是因为神经网络已经学会了将某些食物与某些类型的盘子联系起来,就像在看寿司图片时强调的那样。为了提高网络的性能,可以使用更多不同类型盘子上的食物示例进行训练。

调查寿司被错误归类为汉堡

为什么电视台把寿司归为汉堡?要查看网络关注的是什么,请对错误分类的图像运行Grad-CAM技术。

imageNumber = 2;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:"+马克斯(分数)+“)”

网络聚焦在图像中的花上。五颜六色的紫色花朵和棕色茎使网络误认为这是一个汉堡。

调查被错误归类为薯条的寿司

为什么电视台把寿司归为薯条?该网络将第三张图片归类为薯条而不是寿司。这种特殊的寿司有黄色的顶部,网络可能会将这种颜色与炸薯条联系起来。

对这张照片运行Grad-CAM。

imageNumber = 3;观察= augimdsTest.readByIndex(imgIdx(imageNumber));Img = observation.input{1};label = predictedClasses(imgIdx(imageNumber));score = sortedScores(imageNumber);gradcamMap = gradCAM(net,img,label);图alpha = 0.5;plotGradCAM (img, gradcamMapα);标题(string(标签)+(寿司评分:"+马克斯(分数)+“)”“翻译”“没有”

该电视台将黄色寿司归类为炸薯条。和汉堡一样,这种不寻常的颜色也导致该网站将寿司错误归类。

为了在这种特定情况下帮助网络,用更多不是炸薯条的黄色食物的图像来训练它。

结论

调查产生大或小班级分数的数据点,以及网络自信但错误分类的数据点,是一种简单的技术,可以为训练有素的网络如何运作提供有用的见解。在食品数据集的例子中,这个例子强调了:

  • 测试数据包含了一些带有错误真实标签的图像,例如“生鱼片”实际上是“寿司”。数据还包含不完整的标签,例如同时包含寿司和生鱼片的图像。

  • 该网络认为“寿司”是“多个聚集的圆形物体”。然而,它也必须能够区分一个单独的寿司。

  • 任何带有配料或不寻常颜色的寿司或生鱼片都会让电视台感到困惑。要解决这个问题,数据必须有更广泛的寿司和生鱼片种类。

  • 为了提高性能,网络需要从未充分表示的类中看到更多的图像。

辅助函数

函数dataDir downloadExampleFoodImagesData (url)下载示例食品图像数据集,包含978张图片%不同种类的食物分为9类。The MathWorks, Inc.版权所有文件名=“ExampleFoodImageDataset.zip”;fileFullPath = fullfile(dataDir,fileName);下载.zip文件到一个临时目录。如果~存在(fileFullPath“文件”)流(“下载MathWorks示例食物图像数据集…\n”);流(“这可能需要几分钟才能下载……\n”);websave (fileFullPath、url);流(“下载完成…\ n”);其他的流(“跳过下载,文件已经存在……\n”);结束解压缩文件。通过检查文件是否已经解压缩一个类目录的%。exampleFolderFullPath = fullfile(dataDir,“披萨”);如果~存在(exampleFolderFullPath“dir”)流(“将文件解压缩…\ n”);解压缩(fileFullPath dataDir);流(“解完成…\ n”);其他的流("跳过解压缩,文件已解压缩…\n");结束流(“完成。\ n”);结束函数[sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow)在所选类别的所有图像上找到所选类别的预测分数%(例如,寿司在所有寿司图片上的预测得分)[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);按降序排列分数[sortedScores,idx] = sort(scoresForChosenClass,“下”);只返回前几个索引imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));结束函数[sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow)在所选类别的所有图像上找到所选类别的预测分数%(例如,寿司在所有寿司图片上的预测得分)[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);按升序排列分数[sortedScores,idx] = sort(scoresForChosenClass,“提升”);只返回前几个索引imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));结束函数[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores)查找className的索引(例如:“寿司”是第九节课)uniqueClasses = unique(imds.Labels);chosenClassIdx = string(uniqueClasses) == className;找到imageDatastore中标签为“className”的图像的索引%(例如找到所有寿司类的图片)imgsOfClassIdxs = find(imds。标签== className);找到所选班级在所有图片上的预测分数所选班级百分比%(例如,寿司在所有寿司图片上的预测得分)scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx);结束函数plotImages (imd, imgIdx sortedScores、predictedClasses numImgsToShow)i=1:numImgsToShow分数= sortedScores(i);sortedImgIdx = imgIdx(i);predClass = predictedClasses(sortedImgIdx);correctClass = imds.Labels(sortedImgIdx);imgPath = imds.Files{sortedImgIdx};如果predClass == correctClass color =“{绿}\颜色”其他的颜色={红}\颜色”结束predClassTitle = strrep(string(predClass),“_”' ');correctClassTitle = strrep(string(correctClass),“_”' ');次要情节(装天花板(numImgsToShow. / 3), i) imshow (imread (imgPath));标题(预测:“+ color + predClassTitle +”{黑}\换行符\颜色得分:“+ num2str(score) +"\newlineGround真相:"+ correctClassTitle);结束结束函数plotGradCAM(img,gradcamMap,alpha) subplot(1,2,1) imshow(img);H = subplot(1,2,2);imshow (img);显示亮度图像(gradcamMap“AlphaData”、α);originalSize2 = get(h,“位置”);colormap飞机colorbar集(h,“位置”, originalSize2);持有结束

另请参阅

||||||||

相关的例子

更多关于