预测
利用神经网络分类器对观测数据进行分类
语法
描述
例子
利用神经网络对测试集观测值进行分类
使用神经网络分类器预测测试集观察的标签。
加载病人
数据集。从数据集中创建一个表。每一行对应一个病人,每列对应一个诊断变量。使用吸烟者
变量作为响应变量,其余变量作为预测变量。
负载病人tbl = table(舒张压,收缩压,性别,身高,体重,年龄,吸烟);
将数据分成一个训练集tblTrain
和一个测试集tblTest
通过使用分层坚持分区。该软件为测试数据集保留大约30%的观测值,并将其余观测值用于训练数据集。
rng (“默认”)用于分区的再现性C = cvpartition(tbl.)抽烟,“坚持”, 0.30);trainingIndices = training(c);testIndices =测试(c);tblTrain = tbl(trainingIndices,:);tblTest = tbl(testIndices,:);
使用训练集训练神经网络分类器。指定吸烟者
列的tblTrain
作为响应变量。指定以标准化数值预测器。
Mdl = fitcnet(tblTrain,“抽烟”,...“标准化”,真正的);
对测试集观测值进行分类。使用混淆矩阵将结果可视化。
label = predict(Mdl,tblTest);confusionchart (tblTest.Smoker、标签)
神经网络模型正确地分类了除两个测试集观测值之外的所有测试集观测值。
选择要包含在神经网络分类器中的特征
通过比较测试集分类边缘、边缘、误差和预测来执行特征选择。将使用所有预测器训练的模型的测试集指标与仅使用预测器子集训练的模型的测试集指标进行比较。
加载示例文件fisheriris.csv
,其中包含虹膜数据,包括萼片长度、萼片宽度、花瓣长度、花瓣宽度和物种类型。将文件读入表。
渔场=可读表(“fisheriris.csv”);
将数据分成一个训练集trainTbl
和一个测试集testTbl
通过使用分层坚持分区。该软件为测试数据集保留大约30%的观测值,并将其余观测值用于训练数据集。
rng (“默认”c = cvpartition(渔场表。物种,“坚持”, 0.3);trainTbl =渔场(训练(c),:);testTbl = fishtable (test(c),:);
使用训练集中的所有预测器训练一个神经网络分类器,并使用训练集中的所有预测器训练另一个分类器PetalWidth
.对于这两个模型,请指定物种
作为响应变量,并对预测函数进行标准化。
allMdl = fitcnet(trainTbl,“物种”,“标准化”,真正的);subsetMdl = fitcnet(trainTbl,“物种~ SepalLength + SepalWidth + PetalLength”,...“标准化”,真正的);
计算两个模型的测试集分类边界。因为测试集只包含45个观察值,所以使用条形图显示边缘。
对于每个观察,分类裕度是真实类别的分类分数与虚假类别的最大分数之间的差值。因为神经网络分类器返回的分类分数是后验概率,边际值接近1表示正确分类,边际值为负表示错误分类。
tiledlayout (2, 1)%顶轴Ax1 = nexttile;alledges = margin(allMdl,testTbl);栏(ax₁,allMargins)包含(ax₁“观察”) ylabel (ax₁,“保证金”)标题(ax₁,“预测”)%底轴Ax2 = nexttile;subsetmargin = margin(subsetMdl,testTbl);栏(ax2 subsetMargins)包含(ax2,“观察”) ylabel (ax2,“保证金”)标题(ax2,“预测因子子集”)
比较两个模型的测试集分类边缘,或分类边缘的平均值。
allEdge = edge(allMdl,testTbl)
allEdge = 0.8198
subsetEdge = edge(subsetMdl,testTbl)
subsetEdge = 0.9556
基于测试集分类边界和边,在一个预测器子集上训练的模型似乎优于在所有预测器上训练的模型。
比较两种模型的测试集分类误差。
allError = loss(allMdl,testTbl);allAccuracy = 1-allError
allAccuracy = 0.9111
subsetError = loss(subsetMdl,testTbl);subsetAccuracy = 1-subsetError
subsetAccuracy = 0.9778
同样,只使用一个预测器子集训练的模型似乎比使用所有预测器训练的模型表现得更好。
使用混淆矩阵可视化测试集分类结果。
allLabels = predict(allMdl,testTbl);图(testtable . species,allLabels)“预测”)
subsetLabels = predict(subsetMdl,testTbl);figure figure (testtable . species,subsetLabels)“预测因子子集”)
使用所有预测器训练的模型错误地分类了四个测试集观测值。使用预测器子集训练的模型只错分类了一个测试集观测值。
考虑到这两个模型的测试集性能,考虑使用使用所有预测器训练的模型PetalWidth
.
利用神经网络分类器的层结构进行预测
了解神经网络分类器的各层如何协同工作,以预测单个观察的标签和分类分数。
加载示例文件fisheriris.csv
,其中包含虹膜数据,包括萼片长度、萼片宽度、花瓣长度、花瓣宽度和物种类型。将文件读入表。
渔场=可读表(“fisheriris.csv”);
使用数据集训练神经网络分类器。指定物种
列的fishertable
作为响应变量。
Mdl = fitcnet(渔场表,“物种”);
从数据集中选择第15个观测值。看看神经网络分类器的各层如何获取观察结果并返回预测的类标签newPointLabel
分类分数newPointScores
.
newPoint = Mdl。X{15日:}
newPoint =1×45.8000 4.0000 1.2000 0.2000
firstFCStep = (Mdl.LayerWeights{1})*newPoint' + Mdl.LayerBiases{1};reluStep = max(firstFCStep,0);finalFCStep = (Mdl.LayerWeights{end})*reluStep + Mdl.LayerBiases{end};finalSoftmaxStep = softmax(finalFCStep);[~, classsidx] = max(finalSoftmaxStep);newPointLabel = Mdl。一会{classIdx}
newPointLabel = 'setosa'
newPointScores = finalSoftmaxStep'
newPointScores =1×31.0000 0.0000 0.0000
方法返回的预测是否匹配预测
对象的功能。
[predictedLabel,predictedScores] = predict(Mdl,newPoint)
predictedLabel =1x1单元阵列{' setosa '}
predictedScores =1×31.0000 0.0000 0.0000
输入参数
Mdl
- - - - - -训练神经网络分类器
ClassificationNeuralNetwork
模型对象|CompactClassificationNeuralNetwork
模型对象
训练过的神经网络分类器,指定为ClassificationNeuralNetwork
模型对象或CompactClassificationNeuralNetwork
返回的模型对象fitcnet
或紧凑的
,分别。
X
- - - - - -预测数据要分类
数字矩阵|表格
要分类的预测器数据,指定为数字矩阵或表格。
默认情况下,每一行X
对应一个观察结果,每一列对应一个变量。
对于数值矩阵:
列中的变量
X
必须与训练的预测变量有相同的顺序Mdl
.如果你训练
Mdl
使用表格(例如,资源描述
),资源描述
那么,只包含数值预测变量X
可以是数值矩阵。处理中的数值预测器资源描述
作为分类的训练,识别分类预测因子使用CategoricalPredictors
的名称-值参数fitcnet
.如果资源描述
包含异构预测变量(例如,数字和分类数据类型)和X
是数字矩阵吗预测
抛出错误。
对于表格:
预测
不支持多列变量或除金宝app字符向量的单元格数组外的单元格数组。如果你训练
Mdl
使用表格(例如,资源描述
),然后输入所有预测变量X
必须具有与训练的变量相同的变量名和数据类型Mdl
(存储在Mdl。PredictorNames
).的列序X
是否需要对应的列顺序资源描述
.同时,资源描述
而且X
可以包含额外的变量(响应变量、观察权重等),但是预测
忽略了它们。如果你训练
Mdl
使用一个数字矩阵,然后预测器名称在Mdl。PredictorNames
必须与中对应的预测变量名称相同X
.若要在训练期间指定预测器名称,请使用PredictorNames
的名称-值参数fitcnet
.所有预测变量X
必须是数值向量。X
可以包含额外的变量(响应变量、观察权重等),但是预测
忽略了它们。
如果你设置“标准化”,真的
在fitcnet
当训练Mdl
,然后软件用相应的均值和标准差对预测器数据的数值列进行标准化。
请注意
如果你定位你的预测矩阵,使观察结果与列相对应,并指定“ObservationsIn”、“列”
,那么您可能会经历计算时间的显著减少。你不能指定“ObservationsIn”、“列”
用于表中的预测器数据。
数据类型:单
|双
|表格
维
- - - - - -预测器数据观测维数
“行”
(默认)|“列”
预测器数据观测维数,指定为“行”
或“列”
.
请注意
如果你定位你的预测矩阵,使观察结果与列相对应,并指定“ObservationsIn”、“列”
,那么您可能会经历计算时间的显著减少。你不能指定“ObservationsIn”、“列”
用于表中的预测器数据。
数据类型:字符
|字符串
输出参数
更多关于
分类的分数
的分类的分数对于神经网络分类器使用softmax激活函数计算,该激活函数跟随网络中最后的完全连接层。分数与后验概率相对应。
后验概率表示观测结果x是一流的k是
在哪里
P(x|k)条件概率是x给定类k.
P(k)是类的先验概率吗k.
K响应变量中的类数。
一个k(x)是k从最终的全连接层输出观察x.
选择功能
金宝app仿真软件块
将神经网络分类模型的预测集成到Simulink中金宝app®,你可以使用ClassificationNeuralNetwork预测块的统计和机器学习工具箱™库或MATLAB®函数块。预测
函数。有关示例,请参见使用ClassificationNeuralNetwork预测块预测类标签而且使用MATLAB函数块预测类标签.
在决定使用哪种方法时,请考虑以下因素:
如果使用“统计和机器学习工具箱”库块,则可以使用定点的工具(定点设计师)将浮点模型转换为定点模型。
金宝app控件的MATLAB函数块必须启用对可变大小数组的支持
预测
函数。如果使用MATLAB函数块,则可以在同一MATLAB函数块中使用MATLAB函数进行预测前后的预处理或后处理。
扩展功能
C/ c++代码生成
使用MATLAB®Coder™生成C和c++代码。
使用注意事项和限制:
使用
saveLearnerForCoder
,loadLearnerForCoder
,codegen
(MATLAB编码器)方法生成代码预测
函数。通过使用保存一个训练好的模型saveLearnerForCoder
.定义一个入口点函数,该函数通过loadLearnerForCoder
并调用预测
函数。然后使用codegen
为入口点函数生成代码。生成单精度C/ c++代码
预测
,指定名称-值参数“数据类型”、“单身”
当你打电话给loadLearnerForCoder
函数。该表包含关于的参数的注释
预测
.完全支持不包括在本表中的参数。金宝app论点 注意事项和限制 Mdl
有关模型对象的使用说明和限制,请参见代码生成的
CompactClassificationNeuralNetwork
对象。X
X
必须是单精度或双精度矩阵或包含数值变量、类别变量或两者的表。的行数,或观测值
X
可以是大小可变的,但列的数量呢X
必须修复。如果你想指定
X
作为一个表,那么你的模型必须使用一个表来训练,你的预测入口点函数必须做到以下几点:接受数据作为数组。
根据数据输入参数创建一个表,并在表中指定变量名。
把桌子递给
预测
.
有关此表工作流的示例,请参见生成代码对表中的数据进行分类.有关在代码生成中使用表的详细信息,请参见表的代码生成(MATLAB编码器)而且代码生成的表限制(MATLAB编码器).
ObservationsIn
的
维
的值。ObservationsIn
名称-值参数必须是编译时常量。例如,使用“ObservationsIn”、“列”
在生成的代码中,包含{coder.Constant(“ObservationsIn”),coder.Constant(“列”)}
在arg游戏
的价值codegen
(MATLAB编码器).
有关更多信息,请参见代码生成简介.
版本历史
R2021a中引入
Abrir比如
Tiene una versión modificada de este ejemplo。¿Desea abrir este ejemplo con sus modificaciones?
MATLAB突击队
Ha hecho clic en unenlace que对应一个este commando de MATLAB:
弹射突击队introduciéndolo en la ventana de commandos de MATLAB。Los navegadores web no permission comandos de MATLAB。
您也可以从以下列表中选择一个网站:
如何获得最佳的网站性能
选择中国站点(中文或英文)以获得最佳站点性能。其他MathWorks国家站点没有针对您所在位置的访问进行优化。