具体实现是将PSPNet的输出经过softmax得到各类别的概率,然后进行两次筛选。第一次筛选基于label的有效区域(值不为255的像素):label为255的位置不纳入loss的计算。第二次筛选在第一次筛选保留下来的像素中进行:若某像素对其真实类别(label)的预测概率大于阈值,则把该像素在label中也置为255。阈值默认为0.7;若按0.7筛选后保留的像素数(大致)少于min_kept,阈值会相应提高。最后只有剩余的区域参与loss的计算。
import logging

import numpy as np
import scipy.ndimage as nd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

logger = logging.getLogger(__name__)


class OhemCrossEntropy2d(nn.Module):
    """Online hard example mining (OHEM) cross-entropy loss for segmentation.

    Pixels are filtered twice before the cross-entropy is computed:

    1. Pixels whose label equals ``ignore_label`` are excluded.
    2. Of the remaining pixels, those whose softmax probability for their
       ground-truth class is greater than a threshold are also excluded, so
       only poorly predicted ("hard") pixels contribute to the loss.

    The threshold is ``thresh`` unless keeping roughly ``min_kept`` pixels
    requires a larger one (see ``find_threshold``).

    Args:
        ignore_label: label value excluded from the loss.
        thresh: lower bound of the threshold; pixels whose ground-truth
            probability is <= the threshold are kept.
        min_kept: approximate number of valid pixels to keep, counted at full
            resolution (the threshold search runs on a map downsampled by
            ``factor`` and divides this by ``factor ** 2``).
        factor: spatial downsampling factor used only for the threshold search.
    """

    def __init__(self, ignore_label=255, thresh=0.7, min_kept=100000, factor=8):
        super(OhemCrossEntropy2d, self).__init__()
        self.ignore_label = ignore_label
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        # find_threshold() reads self.factor, so it must be stored here.
        self.factor = factor
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    def _gt_class_prob(self, np_predict, np_target):
        """Gather, for every non-ignored pixel, the probability of its label.

        Args:
            np_predict: array of shape (n, c, h, w) with per-class probabilities.
            np_target: array of shape (n, h, w) with integer labels.

        Returns:
            ``(flat_label, valid_inds, gt_prob)``: the flattened int32 labels,
            the flat indices of pixels whose label != ``ignore_label``, and for
            each of those pixels the predicted probability of its label class.
        """
        num_classes = np_predict.shape[1]
        flat_label = np_target.ravel().astype(np.int32)
        # (n, c, h, w) -> (c, n*h*w); column j corresponds to flat_label[j].
        flat_prob = np.rollaxis(np_predict, 1).reshape((num_classes, -1))
        valid_inds = np.flatnonzero(flat_label != self.ignore_label)
        # Paired fancy indexing: row = ground-truth class, column = pixel,
        # i.e. "how confident is the model in the correct class here".
        gt_prob = flat_prob[flat_label[valid_inds], valid_inds]
        return flat_label, valid_inds, gt_prob

    def find_threshold(self, np_predict, np_target):
        """Return the probability threshold used by the hard-pixel filter.

        The search runs on copies downsampled by ``self.factor`` (``order=1``
        interpolation for probabilities, ``order=0`` / nearest for labels) to
        save time; ``min_kept`` is divided by ``factor ** 2`` accordingly.

        Args:
            np_predict: array of shape (n, c, h, w) with per-class probabilities.
            np_target: array of shape (n, h, w) with integer labels.

        Returns:
            ``1.0`` (keep every valid pixel) when the downsampled map has no
            more valid pixels than the scaled ``min_kept``; otherwise
            ``max(self.thresh, p_k)`` where ``p_k`` is the scaled-``min_kept``-th
            smallest ground-truth-class probability (``self.thresh`` if the
            scaled ``min_kept`` is not positive).
        """
        factor = self.factor
        if factor == 1:
            # Zooming by 1.0 is an identity; skip the resampling.
            predict, target = np_predict, np_target
        else:
            predict = nd.zoom(np_predict, (1.0, 1.0, 1.0 / factor, 1.0 / factor), order=1)
            target = nd.zoom(np_target, (1.0, 1.0 / factor, 1.0 / factor), order=0)
        # int() keeps np.partition's kth integral even if factor is a float.
        min_kept = int(self.min_kept // (factor * factor))

        _, _, gt_prob = self._gt_class_prob(predict, target)
        num_valid = gt_prob.size
        if min_kept >= num_valid:
            # Not enough valid pixels to be selective: keep them all.
            return 1.0

        threshold = self.thresh
        if min_kept > 0:
            # min_kept < num_valid here, so this index is in range.
            k_th = min_kept - 1
            kth_prob = np.partition(gt_prob, k_th)[k_th]
            # Raise the threshold so at least min_kept pixels survive.
            if kth_prob > threshold:
                threshold = kth_prob
        return threshold

    def generate_new_target(self, predict, target):
        """Return a copy of ``target`` with ignored and easy pixels masked out.

        Args:
            predict: tensor of shape (n, c, h, w) with per-class probabilities.
            target: tensor of shape (n, h, w) with integer labels.

        Returns:
            A LongTensor shaped like ``target`` and on ``target``'s device. A
            pixel keeps its label only if the label != ``ignore_label`` and the
            predicted probability of that label is <= the threshold from
            ``find_threshold``; every other pixel is set to ``ignore_label``.
        """
        np_predict = predict.detach().cpu().numpy()
        np_target = target.detach().cpu().numpy()
        threshold = self.find_threshold(np_predict, np_target)

        flat_label, valid_inds, gt_prob = self._gt_class_prob(np_predict, np_target)
        kept_inds = valid_inds[gt_prob <= threshold]
        logger.debug('Labels: %d %s', len(kept_inds), threshold)

        new_label = np.full_like(flat_label, self.ignore_label)
        new_label[kept_inds] = flat_label[kept_inds]
        # .to(target.device) works for CPU tensors too, unlike
        # .cuda(target.get_device()) which fails when get_device() is -1.
        return torch.from_numpy(new_label.reshape(target.size())).long().to(target.device)

    def forward(self, predict, target, weight=None):
        """Compute the OHEM cross-entropy loss.

        Args:
            predict: logits of shape (n, c, h, w).
            target: integer labels of shape (n, h, w).
            weight (Tensor, optional): a manual rescaling weight given to each
                class. If given, has to be a Tensor of size "nclasses".

        Returns:
            Scalar mean cross-entropy over the pixels kept by the OHEM filter.
        """
        assert not target.requires_grad
        # Probabilities are used only to choose which pixels to keep; the loss
        # itself is computed on the raw logits.
        input_prob = F.softmax(predict, 1)
        target = self.generate_new_target(input_prob, target)
        if weight is None:
            return self.criterion(predict, target)
        return F.cross_entropy(predict, target, weight=weight, ignore_index=self.ignore_label)


# References:
# Zhao H, Shi J, Qi X, et al. Pyramid scene parsing network[C]//Proceedings of
# the IEEE conference on computer vision and pattern recognition. 2017: 2881-2890.
Yuan Y, Wang J. Ocnet: Object context network for scene parsing[J]. arXiv preprint arXiv:1809.00916, 2018.