1. 引言
实际应用中,某些领域数据有采集困难、采集成本高、标注样本难度大、时间成本高等问题,如医疗诊断、机器故障检测、汽车喷漆等领域。因此,小样本学习就是研究在样本量小的情况下,通过少量的样本训练泛化能力强的预测模型。小样本学习大多基于数据增强、基于迁移、元学习等思想来解决。
基于数据增强的方法主要研究通过辅助数据对原本的数据集进行扩充,或者是进行特征增强。但该方法有可能会在数据集中引入噪声数据,对特征提取进行干扰,使得分类精度降低。基于迁移的方法研究寻找多个任务间的相似性,将训练好的模型应用到新的任务上。Oquab等 [1] 在ImageNet数据集上对模型进行预训练,对分类层进行改进,将新的网络用于数据集分类。但以上方法通常需要大量的计算,且如果任务间差距很大或者没有相似性,那么效果并不理想。
基于元学习的方法研究如何学习,它旨在让模型能在不同的任务中获取学习的能力来指导完成新任务。常见的元学习方法是将样本投影到度量空间中,利用度量函数计算样本间的相似度,之后通过最邻算法预测类别分类 [2] 。原型网络 [3] 正是基于这个思想构建的,该网络提取每类样本的均值作为该类的原型,在分类过程中,计算待分类的样本与各类原型的距离,将距离最小的作为类别预测结果。然而对样本特征不明显或分布散的情况,用样本均值代替类别特征并不准确。
本文研究小样本图像分类任务,基于元学习中原型网络的不足,提出加权聚合原型的概念,并构建WPNet模型。考虑样本特征不明显、分散、受噪声影响大等情况,根据距离对每个类别的样本赋予权值,利用加权求和的方式得到该类的原型。实验结果显示,本文的方法提高了图像分类准确率。
2. WPNet
本节提出一个用于解决小样本图像分类的网络模型WPNet (Weighted aggregation prototype Networks),模型示意图如图1。
在小样本学习中,设我们当前的数据集为D,其内部的样本形式为:
其中
代表样本,
代表对应标签。在原型网络中提取类别特征是对类别样本求均值,本文提出加权聚合原型代替求均值的策略。
2.1. 加权聚合原型
加权聚合原型是对同一类别的每一个样本特征赋予权值,最后再利用加权求和的方式得到该类的原型,见图2。
Figure 2. Weighted aggregation prototype diagram
图2. 加权聚合原型示意图
权值的学习方法如下
(1)
其中
(2)
(3)
(4)
h为权值计算函数。通过特征提取网络得到样本的特征向量表示为
,计算加权聚合原型:
(5)
其中
是输入样本,
是对应的标签,
是样本特征向量,
是各样本的权重,n表示第n类别的数据集。
2.2. 损失函数
本文的损失函数采用三元损失函数 [4] ,它可以使得网络拉近同种类之间的距离,增加不同种类之间的距离。假设
是数据集中要分类的一个样本,则三元损失函数为:
(6)
其中
表示三元损失函数,
表示和
是同类别的样本,
表示和
是不同类别的样本。则
构成了一个三元组,
的目的是减少
和
的距离,增大
和
的距离,
表示一个常数,最后的损失函数L为softmax与
的结合:
(7)
2.3. WPNet流程
下面描述下求解小样本图像分类问题的WPNet流程。对于N-way,K-shot任务,N代表训练集中的类别数量,K代表样本数量,
是每个任务批次中的类别数目,
是每个任务中的支持集样本数量,
是每个任务中的查询集样本数量。
首先从数据集D中随机选
个样本,形成支持集S。同理,随机选取
个样本,形成查询集Q。
对于支持集S利用公式(1)~(5)求出每类的加权聚合原型。同理求出查询集Q中的加权聚合原型。
对于需要计算的是待分类样本属于哪一类,使用softmax计算分类结果。
最终通过公式(6)和(7)计算损失函数L并更新参数。
以上就是整个算法的流程。
3. 实验结果与分析
本节为实验结果与分析,通过在小样本数据集MiniImageNet和CUB-200上进行实验,之后对比不同网络在数据集上的效果,并对实验的结果做进一步的分析。
3.1. 数据集介绍
本节第一次实验采用的数据集为MiniImageNet,该数据集包含60,000张图片,共100类,每类下有600张图片,训练集、验证集、测试集类别比例是16:4:5。第二次实验采用的数据集是CUB-200,该数据集是一个鸟类数据集,有11,788张图片,共200类,训练集、验证集、测试集类别比例是2:1:1。实验前统一对数据集进行预处理,保证样本的标准化。首先调整输入大小,将图像Resize统一设置为84 × 84。接下来进行归一化处理,图像像素值为(0, 255),因此将图像中各位置的像素点值除255,得到了归一化后的数据集。
3.2. 实验设置
本文实验均在同一环境下使用同样的设置。操作系统为CentOS Linux 7,使用学习率为0.001的Adam优化器。在两个数据集上进行5-way 1-shot和5-way 5-shot的实验,模型准确率为所有分类任务的平均准确率。由于本文是基于元学习提出的模型,因此也是与基于元学习的模型进行对比实验。
3.3. MiniImageNet实验结果与分析
本小节在MiniImageNet数据集上进行,该数据集常用于检验小样本分类模型的性能。同时在相同的实验设置下,对8个广泛应用的小样本模型进行实验对比,结果如表1所示。
Table 1. Classification results on MiniImageNet
表1. MiniImageNet数据集上的分类结果
对表1的实验结果分析可知,本文的方法在5-way 1-shot和5-way 5-shot上效果都很好,尤其是5-way 5-shot中,表现优秀,精确度最高为81.75。对比其余网络,MAML [5] 、Matching nets [6] 、PN、Meta-learn LSTM [7] 等都是基于元学习的网络模型。本文的WPNet比原型网络PN的精确度提高了约5%和1.22%,这是由于原型网络中用均值代替原型的方法不适用类别样本分散或类别特征不明显的情况,WPNet提出加权聚合原型,有效的改善了此情况,使得网络性能得到了提升。
3.4. CUB-200实验结果与分析
本节实验设置与3.2中的一样,对4个广泛应用的小样本模型进行实验对比,结果如表2所示。
Table 2. Classification results on CUB-200
表2. CUB-200数据集上的分类结果
分析表2可知,由于CUB-200都是鸟类图片,虽然不同种类的鸟不一样,但与在MiniImageNet数据集相比,分类任务相似,因此相同网络的模型准确率比表1中的高。本文的WPNet比PN在5-way 1-shot和5-way 5-shot任务上准确率分别提高了7.67%和6.85%。在5-way 1-shot任务和5-way 5-shot任务中,网络表现最好的都是TriNet [8] ,本文的WPNet比TriNet分别提高了7.15%和5.25%。分析表明,本文提出的小样本学习方法能更好的提取类别特征,提高模型分类准确率,验证了WPNet能更好的处理小样本图像分类任务。
3.5. 与PN的浮点数对比实验
本文是基于原型网络提出的模型,是对原型提取策略进行改进,对网络结构没有修改。本节实验对WPNet和PN两个网络的浮点数进行估算对比,对比浮点数相当于对比两个网络的计算量,实验结果见表3。
从表3中可以知道,PN的浮点数为1.804 billion,本文提出的WPNet浮点数为1.827 billion,大约增加了0.022 billion,在计算机中增加的计算时间和难度约等于零。结合之前的两个实验结果,说明本文提出的WPNet不仅效果好,分类准确率高,且几乎没有增加计算难度与时间成本,有较高的实际应用能力。
4. 总结
本文考虑了小样本图像分类问题,基于元学习中的原型网络,提出加权聚合原型的概念,优化了提取原型的方法。通过在MiniImageNet和CUB-200小样本数据集上进行实验,与经典网络进行对比,实验结果显示本文的模型表现优秀,不论是5-way 1-shot任务或5-way 5-shot任务中准确率都最高。且本文通过与原型网络的浮点数对比,表明本文模型复杂度并没有增加,性能和实用性都很高。
基金项目
吉林省自然科学基金(20220101040JC)。
NOTES
*通讯作者。