主要内容

利用可解释的一类分类神经网络检测图像异常

这个例子展示了如何训练一个异常检测器来视觉检查药丸图像。

在一类异常检测方法中,训练是半监督的,这意味着网络只在没有异常的正常图像组成的数据上训练[1].尽管只对正常场景的样本进行训练,但模型学会了如何区分正常和异常场景。单类学习为异常检测问题提供了许多优势:

  • 异常的表现可能是稀缺的。

  • 异常可能代表代价高昂或灾难性的结果。

  • 可能会有很多种异常,而这些异常的类型会在模型的生命周期中发生变化。描述“好”通常比提供代表现实世界环境中所有可能异常的数据更可行。

异常检测的一个关键目标是让人类观察者能够理解为什么经过训练的网络将图像分类为异常。可辩解的分类用证明神经网络如何做出分类决策的信息补充类预测。

这个例子探讨了如何使用一类深度学习来创建有效的异常检测器。该示例还使用一个网络实现了可解释的分类,该网络返回一个热图,其中每个像素都有可能是异常的。分类器根据异常评分热图的平均值将图像标记为正常或异常。

下载药丸图像分类数据集

本例使用pillQC数据集。该数据集包含三个类别的图像:正常的没有缺陷的图像,芯片药丸中有芯片缺陷的图像,还有污垢有污垢污染的图像。数据集提供了149正常的图片,43芯片图片和138污垢图像。数据集的大小为3.57 MB。

dataDir作为数据集的期望位置。方法下载数据集downloadPillQCDatahelper函数。该函数作为支持文件附加到示例中。金宝app该函数下载一个ZIP文件并将数据提取到子目录中芯片污垢,正常的

dataDir = fullfile(tempdir,“PillDefects”);downloadPillQCData (dataDir)

此图像显示了每个类的示例图像。左边是没有缺陷的正常药片,中间是被污垢污染的药片,右边是有芯片缺陷的药片。虽然此数据集中的图像包含阴影、焦点模糊和背景颜色变化的实例,但本示例中使用的方法对这些图像采集工件具有鲁棒性。

加载和预处理数据

创建一个imageDatastore读取和管理图像数据。将每个图像标记为芯片污垢,或正常的根据其目录的名称。

imagadir = fullfile(dataDir,“pillQC-main”“图片”);imds = imageDatastore(imageDir, inclesubfolders =true,LabelSource=“foldernames”);

将数据划分为训练集、校准集和测试集

为了模拟一个更典型的半监督工作流,创建一个包含70个图像的训练集正常的类。包括两个异常训练图像从每个芯片而且污垢类,以获得更好的分类结果。从每个异常类中分配30张正常图像和15张图像到校准集。本例使用校准集为分类器选择阈值。分类器将异常分数高于阈值的图像标记为异常。使用单独的校准和测试集避免了信息从测试集泄漏到分类器的设计中。将剩余的图像分配给测试集。

numTrainNormal = 70;numTrainAnomaly = 2;numCalNormal = 30;numCalAnomaly = 15;[imdsNormalTrain,imdsNormalCal,imdsNormalTest] = splitEachLabel(imds,numTrainNormal,numCalNormal,“随机”,包括=“正常”);[imdsAnomalyTrain,imdsAnomalyCal,imdsAnomalyTest] = splitEachLabel(imds,numTrainAnomaly,numCalAnomaly,“随机”,包括= (“芯片”“土”]);imdsTrain = imageDatastore(vertcat(imdsNormalTrain.Files,imdsAnomalyTrain.Files),LabelSource=“foldernames”, IncludeSubfolders = true);imdsCal = imageDatastore(vertcat(imdsNormalCal.Files,imdsAnomalyCal.Files),LabelSource=“foldernames”, IncludeSubfolders = true);imdsTest = imageDatastore(vertcat(imdsNormalTest.Files,imdsAnomalyTest.Files),LabelSource=“foldernames”, IncludeSubfolders = true);trainLabels = countlabels(imdsTrain.Labels)
trainLabels =3×3表标签计数百分比______ _____ _______芯片2 2.7027污垢2 2.7027正常70 94.595

增强训练数据

扩充训练数据变换函数使用helper函数指定的自定义预处理操作augmentDataForPillAnomalyDetector而且addConfettiNoiseForPillAnomalyDetector.helper函数作为支持文件附加到示例中。金宝app

augmentDataForPillAnomalyDetector函数对每个输入图像随机应用90度旋转和水平和垂直反射。的addConfettiNoiseForPillAnomalyDetector函数添加五彩纸屑噪声来模拟正常图像中的局部异常。每张正常的图像都有50%的可能性添加了五彩纸屑的噪音。这个增强步骤在正常和异常之间平衡训练数据,这有助于在训练过程中稳定损失函数。利用模拟异常图像平衡训练数据在异常图像稀缺的应用中非常有用。

dsTrain = transform(imdsTrain,@augmentDataForPillAnomalyDetector);dsTrain = transform(dsTrain,@addConfettiNoiseForPillAnomalyDetector,IncludeInfo=true);

方法将二进制标签添加到校准和测试数据集变换属性指定的操作addLabelDatahelper函数。方法中定义的辅助函数在本例的末尾,并在正常的a类二进制标签0和图片芯片污垢类一个二进制标签1

dsCal = transform(imdsCal,@addLabelData,IncludeInfo=true);dsTest = transform(imdsTest,@addLabelData,IncludeInfo=true);

可视化9个增强训练图像的样本。大约有一半的训练图像有五彩纸屑噪声异常。

exampleData = readall(子集(dsTrain,1:9));蒙太奇(exampleData (: 1));

创建FCDD模型

这个例子使用了一个全卷积数据描述(FCDD)模型[1].FCDD的基本思想是训练一个网络来生成一个异常分值图,描述输入图像中每个区域包含异常内容的概率。

本例以VGG-16网络为例[3.]在ImageNet上训练[4]作为全卷积网络架构的基础。该示例冻结了模型的大部分,并随机初始化和训练最后的卷积阶段。这种方法可以用少量的输入训练数据进行快速训练。

vgg16函数返回预先训练好的VGG-16网络。该功能需要VGG-16网络的深度学习工具箱™模型支持包。金宝app如果没有安装此支金宝app持包,则该函数将提供下载链接。

Net = vgg16;

用一个新的输入层替换编码器中的图像输入层,该输入层使用计算的平均值执行零中心归一化。将网络的输入大小设置为数据集中图像的大小。冻结网络的前24层freezeLayershelper函数。helper函数在本例的最后定义。

inputSize = [225 225 3];pretrainedVGG = [imageInputLayer(inputSize,Name= .“输入”归一化=“zerocenter”) net.Layers (24)];pretrainedVGG = freezeLayers(pretrainedVGG);

添加最后的卷积阶段。该阶段类似于vg -16的下一个卷积阶段,但具有随机初始化和可训练的卷积层,并具有批量归一化。1乘1卷积将网络输出压缩成一个单通道异常评分热图。下一层是伪huber损失函数,用于用FCDD损失稳定训练,并将输出热图限制在范围[0,] [1] [2].resize图层用于将输出热图调整为与输入图像相同的大小。全局平均池化层计算标量异常分数作为网络返回的输出热图的平均值。最后一个自定义损耗层fcddLossLayerForPillAnomalyDetector用于实现损失函数[1].

additionalFCLayers = [convolution2dLayer(3,512,Padding=“相同”)卷积2dlayer(3,512,填充=“相同”) batchNormalizationLayer reluLayer convolution2dLayer(1,1) functionLayer(@(x)√(x.^2+1)-1) resize2dLayer(EnableReferenceInput=true,Method=“双线性”、名称=“upsampleHeatmap”) globalAveragePooling2dLayer fcddLossLayer];

组装完整的网络。

lgraph = layerGraph([pretrainedVGG;additionalFCLayers]);lgraph = connectLayers(“输入”“upsampleHeatmap / ref”);

训练网络或下载预训练网络

默认情况下,本例使用helper函数下载预先训练好的VGG-16网络版本downloadTrainedNetwork.helper函数作为支持文件附加到这个示例中。金宝app您可以使用预训练的网络来运行整个示例,而无需等待训练完成。

为了训练网络,设置doTraining变量转换为真正的.指定用于训练的epoch数numEpochs通过在字段中输入一个值。训练模型使用trainNetwork函数。

如果可用,在一个或多个gpu上训练。使用GPU需要并行计算工具箱™和支持CUDA®的NVIDIA®GPU。有关更多信息,请参见GPU支金宝app持版本(并行计算工具箱).在NVIDIA Titan上训练大约需要6分钟RTX™。

doTraining =虚假的;numEpochs =One hundred.如果doTraining options = trainingOptions(“亚当”...洗牌=“every-epoch”...MaxEpochs = numEpochs InitialLearnRate = 1的军医,...MiniBatchSize = numpartitions (dsTrain));net = trainNetwork(dsTrain,lgraph,options);modelDateTime = string(datetime(“现在”格式=“yyyy-MM-dd-HH-mm-ss”));保存(fullfile (dataDir“trainedPillAnomalyDetector——”+ modelDateTime +“.mat”),“净”);其他的trainedPillAnomalyDetectorNet_url =“https://ssd.mathworks.com/金宝appsupportfiles/vision/data/trainedFCDDPillAnomalyDetector.zip”;downloadTrainedNetwork (trainedPillAnomalyDetectorNet_url dataDir);net = load(fullfile(dataDir,“trainedAnomalyDetector”“trainedPillFCDDNet.mat”));Net = net.net;结束

创建分类模型

根据图像的平均异常评分是否大于或小于阈值,将图像分为正常或异常。平均异常评分是异常评分热图的平均值。本示例计算最准确地分类校准图像集的阈值。

计算校准集中每个图像的平均异常分数和已知的地面真实标签(正常或异常)。

分数= predict(net,dsCal);labels = imdsCal。标签~ =“正常”

绘制正常和异常类别的平均异常得分的直方图。这些分布被模型预测的异常得分很好地分开。

numBins = 20;[~,edges] = histcounts(scores,numBins);figure normal =直方图(scores(labels==0),edges);持有hAnomaly =直方图(分数(标签==1),边);持有传奇([hNormal, hAnomaly],“正常”“异常”)包含(“平均异常分”);ylabel (“计数”);

创建接收者工作特征(ROC)曲线来计算异常阈值。ROC曲线上的每个点代表假阳性率(x-坐标)和真阳性率(y-coordinate)当使用不同的阈值对校准集图像进行分类时。最佳阈值可以使真阳性率最大化,假阳性率最小化。使用ROC曲线和相关指标允许您根据假阳性和假阴性之间的权衡选择阈值。这些权衡取决于将图像错误分类为假阳性和假阴性对特定应用程序的影响。

创建ROC曲线perfcurve(统计和机器学习工具箱)函数。蓝色实线代表ROC曲线。红色虚线表示一个随机分类器,对应于50%的成功率。显示图标题中校准集的曲线下面积(AUC)度量。一个完美的分类器具有最大AUC为1的ROC曲线。

[xroc,yroc,troc,auc] = perfcurve(标签,分数,true);图lroc = plot(xroc,yroc);持有Lchance = plot([0 1],[0 1],“r——”);持有包含(“假阳性率”) ylabel (“真阳性率”)标题(ROC曲线AUC:+ auc);传奇([lroc lchance),“ROC曲线”“随机的机会”

本例使用最大约登指数度量从ROC曲线中选择异常评分阈值。该值对应于使蓝色模型ROC曲线与红色随机概率ROC曲线之间的距离最大化的阈值。

[~,ind] = max(yroc-xroc);异常阈值= troc(ind)
anomalyThreshold =0.3696

评估分类模型

预测测试集中每个图像的平均异常分数。并得到每个测试图像的ground truth标签。

分数= predict(net,dsTest);标签= imdsTest。标签~ =“正常”

通过将平均异常分数与阈值进行比较,为测试图像分配一个类标签。

testSetOutputLabels = scores >异常阈值;

计算测试集的混淆矩阵和分类精度。本例中的分类模型非常准确,并预测了一小部分假阳性和假阴性。

testSetTargetLabels = logical(labels);M = confusimat (testSetTargetLabels,testSetOutputLabels);confusionchart (M,“正常”“异常”) = sum(diag(M)) / sum(M,“所有”);标题(准确性:“+ acc);

解释分类决定

您可以使用网络预测的异常热图来帮助解释为什么图像被分类为正常或异常。这种方法对于识别假阴性和假阳性的模式很有用。您可以使用这些模式来确定增加训练数据的类平衡或提高网络性能的策略。

查看异常热图

选择一个正确分类的异常的图像。这个结果是一个真正的阳性分类。显示图像。

idxTruePositive = find(testSetTargetLabels & testSetOutputLabels);dsExample =子集(dsTest,idxTruePositive);data = read(dsExample);Img =数据{1};图imshow (img)

的激活,获得异常图像的热图resize2dLayerupsampleHeatmap网络的。调整大小层返回与输入图像相同大小的异常评分热图。

Map =激活(净,单(img),“upsampleHeatmap”);

方法在输入图像上显示由网络预测的热图的覆盖heatmapOverlayhelper函数。这个函数在示例的最后定义。计算反映整个测试集中观察到的热图值范围的显示范围。在本例中,为所有热图应用显示范围。的最小值displayRange为0。将最大值设置为具有最大平均异常得分的测试集图像的热图的第80个百分位值。方法计算百分位值prctile函数。

[~,sampleIdx] = max(scores);sampleMaxScore = read(子集(dsTest,sampleIdx));热mapmaxscore =激活(net,sampleMaxScore{1}“upsampleHeatmap”);displayRange = [0, prtile (heatmapMaxScore,80,“所有”));imshow (heatmapOverlay (img,地图,displayRange))

为了定量确认结果,显示网络预测的真阳性测试图像的平均异常分。大于异常评分阈值。

disp (测试图像平均热图异常分:+的分数(idxTruePositive (1)));
试验图像平均热图异常分:1.1949

查看法向图像热图

选择并显示分类正确的正常图像。这个结果是一个真正的否定分类。

idxtrunegative = find(~(testSetTargetLabels | testSetOutputLabels));dsExample =子集(dsTest, idxtrunegative);data = read(dsExample);Img =数据{1};imshow (img)

方法的激活提取法向图像的热图resize2dLayerupsampleHeatmap网络的。方法在输入图像上显示由网络预测的热图的覆盖heatmapOverlayhelper函数。这个函数在示例的最后定义。许多真正的阴性测试图像,例如这张测试图像,要么没有可见的异常区域,要么在图像的局部部分具有较低的异常分数。

Map =激活(净,单(img),“upsampleHeatmap”);imshow (heatmapOverlay (img,地图,displayRange))

显示网络预测的真阴性测试图像的平均异常分。该值小于异常评分阈值。

disp (测试图像平均热图异常分:+的分数(idxTrueNegative (1)));
测试图像平均热图异常分:0.12476

查看假阴性图像热图

假阴性是带有药丸缺陷异常的图像,网络将其归类为正常。使用来自网络的解释来深入了解错误分类。

从测试集中找出任何假阴性图像。获取假阴性图像的热图叠加变换函数。属性的匿名函数指定了转换的操作heatmapOverlay的辅助函数的激活resize2dLayerupsampleHeatmap网络的。的heatmapOverlayHelper函数在示例的末尾定义。显示假阴性图像作为蒙太奇。如果没有假阴性,则该图为空。

falseNegativeIdx = find(testSetTargetLabels & ~testSetOutputLabels);如果~isempty(falseNegativeIdx) fnExamples =子集(dsTest,falseNegativeIdx);fnexampleswithheatmapoverlay =变换(fnExamples,@(x) {heatmapOverlay(x{1},激活(net,x{1},“upsampleHeatmap”), displayRange)});fnExamples = readall(fnExamples);fnExamples = fnExamples(:,1);fnexampleswithheatmapoverlay = readall(fnexampleswithheatmapoverlay);其他的[fnExamples, fnexampleswithheatmapoverlay] = deal([]);结束蒙太奇(fnExamples)

显示热图叠加为蒙太奇。正如预期的那样,该网络预测了芯片缺陷和污垢斑点周围可见的异常分数。

蒙太奇(fnExamplesWithHeatmapOverlays)

显示网络预测的假阴性检测图像的平均异常分。平均得分低于异常得分阈值,导致误分类。

disp (平均热图异常得分:);分数(falseNegativeIdx)
平均热图异常得分:
ans =2×1单列向量0.2603 - 0.3277

查看假阳性图像热图

假阳性图像没有丸缺陷异常,网络分类为异常。使用来自网络的解释来深入了解错误分类。

从测试集中找出任何假阳性图像。获取假阳性图像的热图叠加变换函数。属性的匿名函数指定了转换的操作heatmapOverlay的辅助函数的激活resize2dLayerupsampleHeatmap网络的。的heatmapOverlayHelper函数在示例的末尾定义。显示假阳性图像作为蒙太奇。如果没有假阳性,则该图为空。

falsePositiveIdx = find(~testSetTargetLabels & testSetOutputLabels);如果~isempty(falsePositiveIdx) fpExamples =子集(dsTest,falsePositiveIdx);fpexampleswithheatmapoverlay = transform(fpExamples,@(x) {heatmapOverlay(x{1},激活(net,x{1},“upsampleHeatmap”), displayRange)});fpExamples = readall(fpExamples);fpExamples = fpExamples(:,1);fpexampleswithheatmapoverlay = readall(fpexampleswithheatmapoverlay);其他的[fpExamples, fpexampleswithheatmapoverlay] = deal([]);结束蒙太奇(fpExamples)

显示热图叠加为蒙太奇。假阳性图像显示了网络标记为异常的区域。您可以使用网络行为的这种解释来深入了解分类问题。例如,如果异常分数被定位到图像背景中,您可以在预处理期间探索抑制背景。

蒙太奇(fpExamplesWithHeatmapOverlays)

显示网络预测的假阳性测试图像的平均异常得分。平均得分大于异常得分阈值,导致误分类。

disp (平均热图异常得分:);分数(falsePositiveIdx)
平均热图异常得分:
ans =0.4467

金宝app支持功能

freezeLayersHelper函数冻结由层数组指定的网络层

函数图层= freezeLayers(图层)Idx = 1:长度(层数)如果isprop(层(idx),“重量”) layers(idx) = setlearnnratefactor (layers(idx),Weights=0);layers(idx) = setlearnnratefactor (layers(idx),Bias=0);结束结束结束

heatmapOverlay辅助功能覆盖彩色热图hmap指定的显示范围displayRange在图像上img

函数out = heatmapOverlay(img,hmap,displayRange)%归一化到范围[0,1]Img = mat2gray(Img);hmap = rescale(hmap,InputMin=displayRange(1),InputMax=displayRange(2));使用颜色图将热图转换为RGB图像Map = jet(256);hmapRGB = ind2rgb(gray2ind(hmap,size(map,1)),map);混合结果%hmapWeight = hmap;imgWeight = 1-hmapWeight;out = im2uint8(imgWeight. out = im2uint8)*img + hmapWeight.*hmapRGB);结束

addLabelData中创建标签信息的单热编码表示数据

函数[data,info] = addLabelData(data,info)如果信息。标签== category (“正常”) onehotencoding = 0;其他的Onehotencoding = 1;结束Data = {Data,onehotencoding};结束

参考文献

[1]利兹涅斯基、菲利普、卢卡斯·拉夫、罗伯特·a·范德穆伦、比利·乔·弗兰克斯、马吕斯·克洛夫特和克劳斯·罗伯特Müller。“可解释的深层单类分类。”预印本,2021年3月18日提交。https://arxiv.org/abs/2007.01760

[2]拉夫,卢卡斯,罗伯特·a·范德穆伦,比利·乔·弗兰克斯,克劳斯-罗伯特Müller和马吕斯·克洛夫特。“重新思考深度异常检测中的假设。”预印本,2020年5月30日提交。https://arxiv.org/abs/2006.00339

Simonyan, Karen, Andrew Zisserman。“用于大规模图像识别的深度卷积网络。”预印本,2015年4月10日提交。https://arxiv.org/abs/1409.1556

[4]ImageNethttps://www.image-net.org

另请参阅

|||||(统计和机器学习工具箱)|(统计和机器学习工具箱)|(统计和机器学习工具箱)

相关的例子

更多关于