那么什么叫Hard Sample,什么叫Easy Sample呢?看下面的图就知道了。
图2-1 Hard Sample 图2-2 Easy Sample1 图2-3 Easy Sample2 图2-4 Sample Space
假设,我们的任务是训练一个分类器,分类出人和马,对于上面的三张图,图2-2和图2-3应该是非常容易判断出来的,但是图2-1就是不那么容易了,它即有人的特征,又有马的特征,非常容易混淆。这种样本虽然在数据集中出现的频率可能并不高,但是想要提高分类器的性能,需要着力解决这种样本分类问题。
提出Hard Sample和Easy Sample后,可以将样本空间划分为如图2-4所示的样本空间。其中纵轴为多数类样本(Majority Class)和少数类样本(Minority Class),上面的带权重的交叉熵损失只能解决Majority Class和Minority Class的样本不平衡问题,并没有考虑Hard Sample和Easy Sample的问题,Focal Loss的提出就是为解决这个难易样本的分类问题。
2.2 Focal Loss解决方案要解决难易样本的分类问题,首先就需要找出Hard Sample和Easy Sample。这对于神经网络来说,应该是一件比较容易的事情。如图2-6所示,这是一个5分类的网络,神经网络的最后一层输出时,加上一个Softmax或者Sigmoid就会得到输出的伪概率值,代表着模型预测的每个类别的概率,
图2-6 Easy Sample Classifier Output 图2-7 Hard Sample Classifier Output
图2-6中,样本标签为1,分类器输出值最大的为第1个神经元(以0开始计数),这刚好预测准确,而且其输出值2也比其它神经元的输出值要大不少,因此可以认为这是一个易分类样本(Easy Sample);图2-7的样本标签是3,分类器输出值最大的为第4个神经元,并且这几个神经元的输出值都相差不大,神经网络无法准确判断这个样本的类别,所以可以认为这是一个难分类样本(Hard Sample)。其实说白了,判断Easy/Hard sample的方法就是看分类网络的最后的输出值。如果网络预测准确,且其概率较大,那么这是一个Easy Sample,如果网络输出的概率较小,这是一个Hard Sample。下面用数学公式严谨地表达来Focal Loss的表达式。
令一个\(C\)类分类器的输出为\(\boldsymbol{y}\in \mathcal{R}^{C\times 1}\),定义函数\(f\)将输出\(\boldsymbol{y}\)转为伪概率值\(\boldsymbol{p}=f(\boldsymbol{y})\),当前样本的类标签为\(t\),记\(p_t=\boldsymbol{p}[t]\),它表示分类器预测为\(t\)类的概率值,再结合上面的交叉熵损失,定义Focal Loss为:
\[
\text{FL} = -(1-p_t)\log(p_t)
\tag{2-1}
\]
这实质就是交叉熵损失前加了一个权重,只不过这个权重有点不一样的来头。为了更好地控制前面权重的大小,可以给前面的权重系数添加一个指数\(\gamma\),那么更改式(2-1):
\[ \text{FL} = -(1-p_t)^\gamma\log(p_t) \tag{2-2} \]
其中\(\gamma\)一值取值为2就好,\(\gamma\)取值为0时与交叉熵损失等价,\(\gamma\)越大,就越抑制Easy Sample的损失,相对就会越放大Hard Sample的损失。同时为解决样本类别不平衡的问题,可以再给式(2-2)添加一个类别的权重\(\alpha_t\)(这个类别权重上面的交叉熵损失已经实现):
\[ \text{FL} = -\alpha_t(1-p_t)^\gamma\log(p_t) \tag{2-3} \]
到这里,Focal Loss理论就结束了,非常简单,但是有效。
3 Focal Loss实现(Pytorch) 3.1 交叉熵损失实现(numpy)为了更好的理解Focal Loss的实现,先理解交叉熵损失的实现,我这里用numpy简单地实现了一下交叉熵损失。
import numpy as np def cross_entropy(output, target): out_exp = np.exp(output) out_cls = np.array([out_exp[i, t] for i, t in enumerate(target)]) ce = -np.log(out_cls / out_exp.sum(1)) return ce