如何解决回归任务数据不均衡的问题?

摘要:现有的处理不平衡数据/长尾分布的方法绝大多数都是针对分类问题,而回归问题中出现的数据不均衡问题确极少被研究。

本文分享自华为云社区《如何解决回归任务数据不均衡的问题?》,原文作者:PG13。

现有的处理不平衡数据/长尾分布的方法绝大多数都是针对分类问题,而回归问题中出现的数据不均衡问题确极少被研究。但是,现实很多的工业预测场景都是需要解决回归的问题,也就是涉及到连续的,甚至是无限多的目标值,如何解决回归问题中出现的数据不均衡问题呢?ICML2021一篇被接收为Long oral presentation的论文:Delving into Deep Imbalanced Regression,推广了传统不均衡分类问题的范式,将数据不平衡问题从离散值域推广到了连续值域,并提出了两种解决深度不均衡回归问题的方法。

主要的贡献是三个方面:1)提出了一个深度不均衡回归(Deep Imbalanced Regression, DIR)任务,定义为从具有连续目标的不平衡数据中学习,并能泛化到整个目标范围;2)提出了两种解决DIR的新方法,标签分布平滑(label distribution smoothing, LDS)和特征分布平滑(feature distribution smoothing, FDS),来解决具有连续目标的不平衡数据的学习问题;3)建立了5个新的DIR数据集,包括了CV、NLP、healthcare上的不平衡回归任务,致力于帮助未来在不平衡数据上的研究。

数据不平衡问题背景

现实世界的数据通常不会每个类别都具有理想的均匀分布,而是呈现出长尾的偏斜分布,其中某些目标值的观测值明显较少,这对于深度学习模型有较大的挑战。传统的解决办法可以分为基于数据基于模型两种:基于数据的解决方案无非对少数群体进行过采样和对多数群体进行下采样,比如SMOTE算法;基于模型的解决方案包括对损失函数的重加权(re-weighting)或利用相关的学习技巧,如迁移学习、元学习、两阶段训练等。

但是现有的数据不平衡解决方案,主要是针对具有categorical index的目标值,也就是离散的类别标签数据。其目标值属于不同的类别,并且具有严格的硬边界,不同类别之间没有重叠。现实世界很多的预测场景可能涉及到连续目标值的标签数据。比如,根据人脸视觉图片预测年龄,年龄便是一个连续的目标值,并且在目标范围内可能会高度失衡。在工业领域中,也会发生类似的问题,比如在水泥领域,水泥熟料的质量,一般都是连续的目标值;在配煤领域,焦炭的热强指标也是连续的目标值。这些应用中需要预测的目标变量往往存在许多稀有和极端值。在连续域的不平衡问题在线性模型和深度模型中都是存在的,在深度模型中甚至更为严重,这是因为深度学习模型的预测往往都是over-confident的,会导致这种不平衡问题被严重的放大。

因此,这篇文章定义了深度不平衡回归问题(DIR),即从具有连续目标值的不平衡数据中学习,同时需要处理某些目标区域的潜在确实数据,并使最终模型能够泛化到整个支持所有目标值的范围上。

https://bbs-img.huaweicloud.com/blogs/img/images_162328840109677.png

不平衡回归问题的挑战

解决DIR问题的三个挑战如下:

 对于连续的目标值(标签),不同目标值之间的硬边界不再存在,无法直接采用不平衡分类的处理方法。

 连续标签本质上说明在不同的目标值之间的距离是有意义的。这些目标值直接告诉了哪些数据之间相隔更近,指导我们该如何理解这个连续区间上的数据不均衡的程度。

 对于DIR,某些目标值可能根本没有数据,这为对目标值做extrapolation和interpolation提供了需求。

解决方法一:标签分布平滑(LDS)

首先通过一个例子展示一下当数据出现不均衡的时候,分类和回归问题之间的区别。作者在两个不同的数据集:(1)CIFAR-100,一个100类的图像分类数据集;(2)IMDB-WIKI,一个用于根据人像估算年龄(回归)的图像数据集,进行了比较。通过采样处理来模拟数据不平衡,保证两个数据集具有完全相同的标签密度分布,如下图所示:

https://bbs-img.huaweicloud.com/blogs/img/images_162328846042796.png

然后,分别在两个数据集上训练一个ResNet-50模型,并画出它们的测试误差的分布。从图中可以看出,在不平衡的分类数据集CIFAR-100上,测试误差的分布与标签密度的分布是高度负相关的,这很好理解,因为拥有更多样本的类别更容易学好。但是,连续标签空间的IMDB-WIKI的测试误差分布更加平滑,且不再与标签密度分布很好地相关。这说明了对于连续标签,其经验标签密度并不能准确地反映模型所看到的不均衡。这是因为相临标签的数据样本之间是相关的,相互依赖的。

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zzpsdw.html