MultiLabelSoftMarginLoss

1、MultiLabelSoftMarginLoss原理

MultiLabelSoftMarginLoss针对multi-label one-versus-all(多分类,且每个样本只能属于一个类)的情形。

loss的计算公式如下:

MultiLabelSoftMarginLoss

其中,x是模型预测的标签,x的shape是(N,C),N表示batch size,C是分类数;y是真实标签,shape也是(N,C),MultiLabelSoftMarginLoss

MultiLabelSoftMarginLoss的值域是(0,MultiLabelSoftMarginLoss);

MultiLabelSoftMarginLoss的值域是(1,MultiLabelSoftMarginLoss);

MultiLabelSoftMarginLoss的值域是(0,1);

MultiLabelSoftMarginLoss的值域是(-MultiLabelSoftMarginLoss,0),函数曲线如图1所示:

MultiLabelSoftMarginLoss
图1

为了看得更清楚一点,再画一下[-10,10]范围内的曲线,如图2:

MultiLabelSoftMarginLoss
图2

当y[i]=1得时候,x[i]越大==》MultiLabelSoftMarginLoss越大==》loss越小(因为MultiLabelSoftMarginLoss前面有个负号);

MultiLabelSoftMarginLoss的函数曲线如图3所示:

MultiLabelSoftMarginLoss

 [-10,10]范围内的曲线,如图4所示:

MultiLabelSoftMarginLoss

 当y[i]=0得时候,x[i]越小==》MultiLabelSoftMarginLoss越大==》loss越小(因为MultiLabelSoftMarginLoss前面有个负号);

 

2、使用MultiLabelSoftMarginLoss进行图片多分类

2.1 数据源以及如何打标签

以mnist数据源为例,共有10个分类。mnist中每张图片的标签是0-9中的一个数字,我们需要对标签进行转换:对于标签是0的图片,转换成的新标签是[1,0,0,0,0,0,0,0,0,0];对于标签是1的图片,转换成的新标签是[0,1,0,0,0,0,0,0,0,0];依此类推。

2.2 模型训练

2.2.1 模型搭建

MultiLabelSoftMarginLoss层前一层输出的特征图的大小必须是1*10,我们可以使用一个out_features是10的Linear层(全连接层)来实现。