主要内容

用LIME解释表格数据的深度网络预测

这个例子展示了如何使用局部可解释的模型不可知解释(LIME)技术来理解深度神经网络分类表格数据的预测。您可以使用LIME技术来了解哪些预测因子对网络的分类决策最重要。

在本例中,您使用LIME来解释特征数据分类网络。对于指定的查询观测,LIME生成一个合成数据集,其每个特征的统计数据与真实数据集相匹配。该合成数据集通过深度神经网络获得分类,并拟合一个简单的、可解释的模型。这个简单的模型可以用来理解前几个特征对网络分类决策的重要性。在训练这个可解释模型时,合成观测值由它们与查询观测值之间的距离加权,因此解释是“局部的”。

这个示例使用石灰(统计学和机器学习工具箱)适合(统计学和机器学习工具箱)生成一个综合数据集,并使一个简单的可解释模型适合于该综合数据集。要理解训练有素的图像分类神经网络的预测,使用imageLIME.有关更多信息,请参见理解使用LIME进行网络预测

加载数据

加载Fisher虹膜数据集。该数据包含150个观测数据,其中4个输入特征代表植物的参数,1个类别响应代表植物物种。每一个观察结果都被归类为三个物种之一:刚毛犀、花斑犀或维京犀。每次观察有四个测量值:萼片宽度,萼片长度,花瓣宽度和花瓣长度。

文件名= fullfile (toolboxdir (“统计数据”),“statsdemos”“fisheriris.mat”);加载(文件名)

将数值数据转换为表。

特点= [“花萼长度”“花萼宽”“花瓣长度”“花瓣宽度”];预测= array2table(量,“VariableNames”、功能);trueLabels = array2table(分类(物种),“VariableNames”“响应”);

创建一个培训数据表,其最后一列是响应。

data =[预测因子];

计算观察、特征和类的数量。

numObservations =大小(预测,1);numFeatures =大小(预测,2);numClasses =长度(类别(数据{:5}));

将数据分割为训练、验证和测试集

将数据集划分为训练集、验证集和测试集。留出15%的数据用于验证,15%用于测试。

确定每个分区的观察数。设置随机种子,使数据分割和CPU训练具有可重现性。

rng (“默认”);numObservationsTrain =地板(0.7 * numObservations);numObservationsValidation =地板(0.15 * numObservations);

创建一个与观察值对应的随机索引数组,并使用分区大小对其进行分区。

idx = randperm (numObservations);idxTrain = idx (1: numObservationsTrain);idxValidation = idx(numObservationsTrain + 1:numObservationsTrain + numObservationsValidation);idxTest = idx(numObservationsTrain + numObservationsValidation + 1:end);

使用索引将数据表划分为训练、验证和测试分区。

dataTrain =数据(idxTrain:);dataVal =数据(idxValidation:);人数(=数据(idxTest:);

定义网络体系结构

创建一个简单的多层感知器,有一个包含五个神经元和ReLU激活的隐藏层。特征输入层接受包含表示特征的数字标量的数据,例如Fisher虹膜数据集。

numHiddenUnits = 5;layers = [featureInputLayer(numFeatures) fulllyconnectedlayer (numHiddenUnits) reluLayer fulllyconnectedlayer (numClasses) softmaxLayer classificationLayer];

确定培训方案和培训网络

使用随机动量梯度下降(SGDM)训练网络。设置最大纪元数为30,并使用15的小批量大小,因为训练数据不包含许多观察。

选择= trainingOptions (“个”...“MaxEpochs”30岁的...“MiniBatchSize”15岁的...“洗牌”“every-epoch”...“ValidationData”dataVal,...“ExecutionEnvironment”“cpu”);

培训网络。

网= trainNetwork (dataTrain层,选择);
|======================================================================================================================| | 时代| |迭代时间| Mini-batch | |验证Mini-batch | |验证基地学习  | | | | ( hh: mm: ss) | | |精度精度损失| | |率损失|======================================================================================================================| | 1 | 1 |就是| | 40.00% 31.82% | 1.3060 | 1.2897 | 0.0100 | | 8 50 | |就是| | 86.67% 90.91% | 0.4223 | 0.3656 | 0.0100 | | 100 | |就是| | 93.33% 86.36% | 0.2947 | 0.2927 | 0.0100 | | 22 | 150 |就是|86.67% | 81.82% | 0.2804 | 0.3707 | 0.0100 | | 200 | | 29日00:00:01 | | 86.67% 90.91% | 0.2268 | 0.2129 | 0.0100 | | 210 | | 00:00:01 93.33% | | | 0.2782 | 0.1666 | 0.0100 95.45%  | |======================================================================================================================|

评估网络性能

使用训练过的网络对测试集的观测结果进行分类。

predictedLabels = net.classify(人数();trueLabels =人数({,,}结束;

使用混淆矩阵可视化结果。

图confusionchart (trueLabels predictedLabels)

该网络成功地利用四种植物特征来预测试验观测的物种。

了解不同的预测对不同的类有多重要

使用LIME来理解每个预测器对网络分类决策的重要性。

调查每个观察的两个最重要的预测因素。

numImportantPredictors = 2;

使用石灰创建一个合成数据集,其每个特征的统计数据与真实数据集相匹配。创建一个石灰对象使用深度学习模型黑箱其中包含了预测数据预测.使用一个低“KernelWidth”所以价值石灰使用关注查询点附近的示例的权重。

黑箱= @ (x)分类(净,x);讲解员=石灰(黑箱预测,“类型”“分类”“KernelWidth”, 0.1);

你可以使用LIME解释器来理解深层神经网络最重要的特征。该函数通过使用一个简单的线性模型估计特征的重要性,该模型在查询观测附近近似神经网络。

找出测试数据中与setosa类相对应的前两个观察值的指数。

trueLabelsTest =人数({,,}结束;标签=“setosa”;idxSetosa = find(trueLabelsTest == label,2);

使用适合函数使一个简单的线性模型符合指定类的前两个观察值。

explainerObs1 =适合(讲解员人数((idxSetosa (1), 1:4), numImportantPredictors);explainerObs2 =适合(讲解员人数((idxSetosa (2), 1:4), numImportantPredictors);

策划的结果。

图次要情节(2,1,1)情节(explainerObs1);次要情节(2,1,2)情节(explainerObs2);

对于萼片类来说,最重要的预测因子是低花瓣长度值和高花瓣宽度值。

对类versicolor执行相同的分析。

标签=“多色的”;idxVersicolor = find(trueLabelsTest == label,2);explainerObs1 =适合(讲解员人数((idxVersicolor (1), 1:4), numImportantPredictors);explainerObs2 =适合(讲解员人数((idxVersicolor (2), 1:4), numImportantPredictors);图次要情节(2,1,1)情节(explainerObs1);次要情节(2,1,2)情节(explainerObs2);

对于花斑类,高花瓣长度值是重要的。

最后,再考虑一下维琪卡课程。

标签=“virginica”;idxVirginica = find(trueLabelsTest == label,2);explainerObs1 =适合(讲解员人数((idxVirginica (1), 1:4), numImportantPredictors);explainerObs2 =适合(讲解员人数((idxVirginica (2), 1:4), numImportantPredictors);图次要情节(2,1,1)情节(explainerObs1);次要情节(2,1,2)情节(explainerObs2);

对于弗吉尼亚类来说,高花瓣长度值和低萼片宽度值是重要的。

验证石灰假说

LIME图显示,高花瓣长度值与花斑类和维珍类有关,低花瓣长度值与刚毛类有关。您可以通过研究数据进一步研究结果。

在数据集中绘制每个图像的花瓣长度。

setosaIdx = ismember(数据{,,},“setosa”);versicolorIdx = ismember(数据{,,},“多色的”);virginicaIdx = ismember(数据{,,},“virginica”);图保存情节(数据{setosaIdx,“花瓣长度”},“。”)情节(数据{versicolorIdx,“花瓣长度”},“。”)情节(数据{virginicaIdx,“花瓣长度”},“。”)举行包含(“观察”) ylabel (“花瓣长度”)传说([“setosa”“多色的”“virginica”])

setosa类的花瓣长度值比其他类低得多,与从石灰模型。

另请参阅

(统计学和机器学习工具箱)|(统计学和机器学习工具箱)||||

相关的话题