代码中第5行,可能稍微有点难以理解,它不过是为了找出标签对应的输出值。比如第2个样本的标签值为3,那它分类器的输出应当选择第2行,第3列的值。
3.2 Focal Loss实现下面的代码的10~12行:依据输出,计算概率,再将其转为focal_weight;15~16行,将类权重和focal_weight添加到交叉熵损失,得到最终的focal_loss;18~21行,实现mean和sum两种reduction方法,注意求平均不是简单的直接平均,而是加权平均。
class FocalLoss(nn.Module): def __init__(self, gamma=2, weight=None, reduction='mean'): super(FocalLoss, self).__init__() self.gamma = gamma self.weight = weight self.reduction = reduction def forward(self, output, target): # convert output to presudo probability out_target = torch.stack([output[i, t] for i, t in enumerate(target)]) probs = torch.sigmoid(out_target) focal_weight = torch.pow(1-probs, self.gamma) # add focal weight to cross entropy ce_loss = F.cross_entropy(output, target, weight=self.weight, reduction='none') focal_loss = focal_weight * ce_loss if self.reduction == 'mean': focal_loss = (focal_loss/focal_weight.sum()).sum() elif self.reduction == 'sum': focal_loss = focal_loss.sum() return focal_loss注:上面实现中,output的维度应当满足output.dim==2,并且其形状为(batch_size, C),且target.max()<C。
总结Focal Loss从2017年提出至今,该论文已有2000多引用,足以说明其有效性。其实从本质上讲,它也只不过是给样本重新分配权重,它相对类别权重的分配方法,只不过是将样本空间进行更为细致的划分,从图2-6很容易理角,类别权重的方法,只是将样本空间划分为蓝色线上下两个部分,而加入难易样本的划分,又可以将空间划分为左右两个部分,如此,样本空间便被划分4个部分,这样更加细致。其实借助于这个思想,我们是否可以根据不同任务的需求,更加细致划分我们的样本空间,然后再相应的分配不同的权重呢?
参考文献[1] [Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. Paper presented at the Proceedings of the IEEE international conference on computer vision.]()