主要内容

focalLossLayer

使用焦损函数创建焦损层

描述

焦点损失层使用焦点损失预测对象类别。增加焦点损失层,在前景类和背景类不平衡的情况下训练对象检测、语义分割或分类网络。为了补偿类别不平衡,焦点损失函数将交叉熵函数与一个调制因子相乘,该调制因子增加了网络对错误分类观测的敏感性。

创建

描述

例子

= focalLossLayer为深度学习网络创建一个焦点损失层。有关如何在目标检测网络中使用焦点损失层的信息,请参见创建SSD对象检测网络

例子

= focalLossLayer (名称,值通过使用一个或多个名称-值对参数设置焦点损耗层的属性。将每个属性名用引号括起来。

例如,focalLossLayer('名称',' focalloss ')用名称创建焦点损失层“focalloss”以及指定的平衡和聚焦参数。

属性

全部展开

焦点损失函数的平衡参数,指定为正实数。的α值线性缩放损失函数,通常设置为0.25.如果你减少α,增加γ

焦距损失函数的聚焦参数,指定为正实数。增加价值γ提高网络对错误分类观测的敏感性。

训练网络检测的对象类,指定为字符串向量、分类向量、字符向量的单元格数组或“汽车”.当你设置“汽车”,在培训时间自动设置课程。指定字符串向量或字符向量的单元格数组时,的元素都是根据输出的类别函数。

数据类型:字符串|分类|细胞|字符

层名,指定为字符向量或字符串标量。为数组输入时,trainNetwork(深度学习工具箱)assembleNetwork(深度学习工具箱)layerGraph(深度学习工具箱),dlnetwork(深度学习工具箱)函数自动为具有该名称的层分配名称

数据类型:字符|字符串

例子

全部折叠

指定类名。

类= [“汽车”“背景”];

指定焦距损失函数的平衡参数和聚焦参数。为这两个类创建一个名为“focallosslayer”的焦点损失层,显示结果。

图层= focalLossLayer(“类”、类“名字”“focallosslayer”
名称:' FocalLossLayer '超参数Gamma: 2 Alpha: 0.2500类:[2x1 categorical] LossFunction: 'focalLoss'

创建基于ResNet-18的DeepLab v3+网络。

imageSize = [480 640 3];numClasses = 5;网络=“resnet18”;lgraph = deeplabv3plusLayers(imageSize,numClasses,network,“DownsamplingFactor”, 16)
lgraph = LayerGraph with properties: Layers: [100x1 nnet.cnn.layer.Layer] Connections: [113x2 table] InputNames: {'data'} OutputNames: {'classification'}

显示网络的输出层。DeepLab v3+网络的输出层为a像素分类该算法使用交叉熵损失来预测输入二维图像中每个像素的分类标签。

lgraph.Layers(结束)
ans = PixelClassificationLayer with properties: Name: 'classification' Classes: 'auto' ClassWeights: 'none' OutputSize: 'auto'超参数LossFunction: 'crossentropyex'

替换输出像素分类损失处理数据中的不平衡类。

图层= focalLossLayer(“名称”“focalloss”);lgraph =替换层(lgraph,“分类”层);

显示网络。

analyzeNetwork (lgraph);

创建一个用于语义分割的三维U-Net网络unet3dLayers函数。将编码器-解码器深度设置为2,并指定第一个卷积层的输出通道数为16。

imageSize = [128 128 128 3];numClasses = 5;lgraph = unet3dLayers(imageSize,numClasses,“EncoderDepth”2,...“NumFirstEncoderFilters”16);图绘制(lgraph)

图中包含一个轴对象。axis对象包含一个graphplot类型的对象。

创建焦损层并替换Segmentation-Layer在网络中有焦损层。该层预测输入3-D体积中每个体素的分类标签。

图层= focalLossLayer(“名称”“focalloss”);lgraph =替换层(lgraph,“Segmentation-Layer”层)
lgraph = LayerGraph with properties: Layers: [40x1 nnet.cnn.layer.Layer] Connections: [41x2 table] InputNames: {'ImageInputLayer'} OutputNames: {'focalloss'}

显示网络。

analyzeNetwork (lgraph);

更多关于

全部展开

参考文献

[1]林、宗毅、普里亚·戈雅、罗斯·格什克、何凯明和彼得·杜dollar。密集物体检测的焦损2017年IEEE®计算机视觉国际会议,2999 - 3007。威尼斯:IEEE, 2017。https://doi.org/10.1109/ICCV.2017.324。

扩展功能

C/ c++代码生成
使用MATLAB®Coder™生成C和c++代码。

版本历史

R2020a中引入