1、 年第 期(第 卷总第 期)基于因果关系和特征对齐的图像分类域泛化模型明水根 张 洪(中国科学技术大学 大数据学院 安徽 合肥 中国科学技术大学 管理学院 安徽 合肥)摘 要:针对现有域泛化方法性能较差或缺乏理论可解释性的缺点 提出了一种基于因果关系和特征对齐的图像分类域泛化模型 并证明了该模型的可识别性 该模型利用域泛化中的因果关系来学习含有不同信息的特征 将域泛化问题转化为特征相关分布的偏移 再利用特征对齐消除偏移 为提高模型的性能 采用对抗训练进一步优化学到的特征 在公共数据集上的实验结果表明 新提出的模型与目前最优的方法性能相当 表明该模型具有理论可解释性的同时 还有不俗的实际性能表现
2、关键词:域泛化 变分自编码器 因果关系 特征对齐 对抗训练中图分类号:文献标识码:/引用格式:明水根 张洪.基于因果关系和特征对齐的图像分类域泛化模型.网络安全与数据治理 ():.():.:引言深度学习在计算机视觉和自然语言处理等领域都取得了惊人的成就 传统的深度学习方法基于一个基本假设:训练数据和测试数据是独立且同分布的()但是 在现实任务中 如医学成像和自主驾驶等领域 这种 假设通常不成立 当训练数据的分布(也称为域)与测试数据的分布不同时 由于存在分布差距 训练出的模型通常表现不佳 训练数据的域与测试数据的域不同的现象也被称为域偏移 上述观点促进了域适应()和域泛化()的研究 域泛化的目
3、标是从多个相似分布(也称为源域)中学习一种通用表征 一般数据都存在某些与输出(即标签)相关且在不同域间都保持不变的特征 那么就可以将这种特征迁移到具有未见过分布(也称为目标域)的测试数据上 域泛化任务的示例如图 所示大多数域泛化研究集中在学习不受域干扰的表征从而得到所谓的域不变特征 例如 等采用了域适应研究中的思想 使用对抗训练来学习域不变特征以解决域泛化问题 尽管这些基于学习策略的方法在真实世界的任务中表现良好 但缺乏理论可解释性 等和人工智能 年第 期(第 卷总第 期)图 域泛化任务实例 等采用特征解耦方法来学习标签的特定特征 并希望这些特征是关于域不变的 这种基于特征解耦的方法在理论上是
4、可解释的 但在实际的域泛化任务中表现不佳 因此 研究既具有理论可解释性又在真实的域泛化任务中表现良好的方法非常重要本文提出了一种称为对抗域不变变分自动编码器()的模型来解决域泛化问题 该模型先使用变分自动编码器()框架将输入数据解耦成三个潜在因子:域信息因子、标签信息因子和包含任何残留信息的因子 然后 将因果关系引入到域泛化任务中 将域偏移问题拆分为两个相关分布的偏移问题 为了修正这些偏移 本文采用特征对齐方法来学习域不变特征 此外 为了解决 中存在的解耦不完全问题 本文采用对抗训练来消除潜在因子中的混淆信息 以进一步提高模型的实际表现 本文在两个域泛化公开数据集 和 上 对 进行了大量实验
5、实验结果表明 在域泛化表现方面具有与目前最优方法相当的竞争力 本文的主要贡献有:其一 本文基于域泛化中的因果关系使用特征对齐来解决域偏移问题 其二 本文采用对抗训练来解决 特征解耦不完全的问题 提升 在实际任务中的性能 其三 本文证明了 的可识别性理论 相关工作 学习不变表征由于在域泛化中目标域是未知的 因此许多方法尝试去学习一个关于域不变的数据表征(即给定标签后学习到的表征是有条件或无条件独立于域的)等提出了域对抗神经网络去解决域适应问题 该网络使用对抗训练来减少不同域数据之间的特征差异 这也启发了许多域泛化方面的研究 本文也采用对抗训练来减少特征的域差异 提升模型的实际表现 解耦学习与因果
6、关系很多文献表明 因果关系和域泛化有着深刻的联系 域泛化问题的关键是域偏移 可以看作是一种干预 因此将因果关系引入域泛化任务中是自然且有效的但由于因果关系很难学习 一些域泛化工作便采用特征解耦方法来引入因果关系 特征解耦方法旨在将输入数据映射到具有域特定或标签特定的特征上 域不变变分自编码器()就是一个经典的基于特征解耦方法的模型该模型使用 框架将输入数据分解为标签因子、域因子和包含其他信息的因子 再进行优化 然而 缺乏有效的训练策略 导致了特征解耦不完全以及实际表现一般的问题 本文借鉴了 的思想 也使用了基于 框架的特征解耦方法 学习策略其他域泛化算法大多是基于学习策略的方法 包括数据操作、
7、基于集成学习的方法、基于元学习的方法等 这些方法通常在真实的域泛化任务中表现良好但在某种程度上缺乏理论可解释性 因此 开发既在真实域泛化任务上表现良好又具有理论可解释性的方法至关重要 模型与方法 任务定义与记号首先使用数学符号来定义域泛化问题 记、和 分别为输入样本、标签和域的空间 其相应的随机变量用、和 表示 其中 表示训练集数据中所有 个源域的集合 在图像分类任务中 一般指图像数据 为图像对应的标签 为图像对应的域 用 表示目标域 即测试集数据所在域 其数据仅在测试时可用 又记域 的数据联合分布为()测试域的数据联合分布记为()域泛化中的深度学习模型通常都含有特征提取器(编码器)()和标签
8、分类器 ()域泛化问题的目标就是从 个源域中抽取数据样本用来训练深度学习模型 从而学习出一个广义的预测函数:使其在未知测试域 上达到最小的预测误差:()()()()其中 为损失函数 模型结构许多研究表明 将因果理论引入域泛化问题中可提高模型泛化能力 图 显示了本文在此采用的域泛化问题对应的因果图 该图代表了实际域泛化任务中最常见的设置 由图 可得 不同域的标签边缘分布()不同(也被称为目标偏移)条件分布 ()也不同(也被称为条件偏移)投稿网址:年第 期(第 卷总第 期)图 域泛化任务的因果图应用全概率公式 如果所有源域的条件分布()和边缘分布()都能够对齐(即修正偏移)那么模型将学习到一个关于
9、域不变的特征 即对于任何 都有()()此外 在深度学习模型中 对()和()进行建模是自然和简单的 更确切地说 通过源域 中的各标签比例可以简单地表示其标签分布()而对于分布()则可以通过特征提取器近似表示 其准确的表达几乎不可能获得首先 本文采用变分贝叶斯推断 使用高斯分布族来近似真实分布()自然地 就构建了基于 框架的模型 受 的启发 本文将输入数据的特征解耦为三个部分:域特定信息、标签特定信息和其他剩余信息 并用 中的隐变量(也称潜在因子)表示它们 分别记为 和 隐变量 和 分别用于进一步预测标签 和域 接下来 对 中的相关分布进行对齐 以学习到关于域不变的特征 此外 为了进一步解决 框架
10、中的不完全分解问题本文应用对抗训练来减少潜在因子中的混淆因素 基于以上思路 本文提出了适用于域泛化任务的对抗域不变变分自编码器 模型 的详细模型结构如图 所示 包括编码器、解码器、域分类器和标签分类器图 的模型结构 损失函数 变分推断损失记 ()为隐变量()的后验分布 这也视为 ()的一个近似 仿照大多数基于 框架模型的做法 可假设 ()是一个完全因子化的高斯分布 其中包含可学习的模型参数 通过先验分布和后验分布的 ()散度来优化模型参数 此 散度可以进一步转换为置信下界()它是一个可使用训练数据和模型参数进行具体计算的目标函数 仿照 等的方法 可以将此变分推断的 表示为:()()()()()
11、()()()()其中第一项中的 ()代表了 中解码器的生成过程 ()()和 ()分别代表了 和 的先验分布 而 ()()和 ()分别代表了 和 的后验分布 这里后验分布 ()已通过特征解耦被分解为()()()在实际应用中 中的第一项一般用原始输入 和解码器输出的 损失或均方误差损失代替 而后三项中的分布均为带可学习参数的高斯分布 由此可保证 能很容易地计算 特征对齐损失为了解决条件偏移问题 可以将真实的后验分布()进行对齐来学习关于域不变的特征 由于()计算困难 无法准确对齐 因此本文对()的估计 是对 ()进 行 对 齐从而近似对齐 ()由于域泛化的最终目的是预测输入数据 的标签 而隐变量
12、是唯一用于预测标签的特征 因此只需确保隐变量 在给定 和 条件下的条件分布 ()能够对齐即可 同样地 本文对由变分贝叶斯推断得到的 ()的估计 是对()进行对齐 从而近似对齐 ()在 节中 由于后验分布 ()比较容易计算 因此本文通过对齐具有相同标签的样本对应的()来对齐 ()且在此过程中 完全忽略样本的域信息 为此 本文采用对比损失来确保模型提取的样本标签特征 在具有相同标签的样本之间尽可能接近 不同标签的样本之间则尽可能远离 且同样忽略样本的域信息 对比损失是无监督学习中提出的一种特殊损失 在深度学习中被广泛使用本文将具有相同标签但不同域的两个输入视为一对人工智能 年第 期(第 卷总第 期
13、)正匹配 将来自不同标签的两个输入视为一对负匹配故自然地希望正匹配的特征之间的差异远小于负匹配的特征之间的差异 假设对于一个输入图像 存在一个与 具有相同标签但不同域的样本 以及 个与具有不同标签的样本 即 和 构成一对正匹配 和 构成一对负匹配 其中 则本文的对比损失可定义为:()(/)(/)(/)()其中 是一个衡量 和 的特征相似度的度量 是一个针对困难样本的超参数如前所述 本文旨在通过对齐 ()来学习到关于域不变的特征 为此 这里使用 距离来度量不同分布间的差距 进而度量不同样本特征之间的相似性 假设 和 的后验分布 ()分别为 ()和 ()其中 和 分别对应样本 和 此时 相应的 距
14、离为:()/()将式()中的 用 ()代替 其中 至此 就得到了 模型中使用的分布对比损失 在实际实验中 分布对比损失由如下的损失平均值代替:()()其中 是实验中训练集一个批次中的样本数量 需要注意的是 最小化分布对比损失时完全忽略了输入样本 的域信息 这样可以将特征 中的域信息排除 从而得到一个关于域不变的特征 分类损失与对抗训练域泛化本质上还是一个分类任务 因此还需要一个针对标签分类器 ()的标签分类损失 标准的标签分类损失是通过计算真实标签边际分布 ()和标签分类器输出的标签后验分布 ()之间的交叉熵损失得到的 然而 采用标准标签分类损失是无法解决目标偏移问题的 为了解决目标偏移问题
15、本文通过自适应加权来对齐标签后验分布 且在文献的附录 中推导出了如下的自适应权重:()()()()其中 代表域 ()为域 数据中的标签分布 而()为训练集数据的标签分布 训练集由所有源域数据组成 于是就得到了最终的加权标签分类损失为:()()其中 为不同标签的总数 ()为样本 的真实标签分布 ()为由标签分类器 输出的标签后验分布正如图 所示 还包含一个域分类器 用于特征解耦中学习域因子 故相对应地有一个关于域的标准分类损失 其公式为:()()其中 ()代表交叉熵损失通过最小化加权标签损失和域分类损失 标签信息和域信息分别在 和 中会得到相应的增加 然而 经过对齐后的标签因子 仍然有域信息残留
16、 这将导致域偏移问题不能得到完全解决 为了消除潜在因子中冗余的信息(即标签因子 中残留的域信息)本文借鉴了 等提出的对抗式训练方法 最终 的对抗训练模块如图 底部所示 在该对抗训练模块中 标签因子 在经过一个梯度反转层()后 又被输入到域分类器 中 此时 自然地产生了一个由 得到的域分类损失:()()其中 表示 层 层在神经网络模型的前向传播阶段没有任何影响 但在反向传播阶段 通过该模块的梯度方向会被反转 即添加一个负号 因此 当最小化时 编码器 的参数将向最大化域分类损失 的方向学习 换句话说 标签因子 会使域分类器 越来越无法正确预测其输入所在的域 从而可以清除 中含有的域信息 另一方面
17、域分类器 中的参数仍然向最小化域分类损失 的方向学习 以期望正确地预测其输入所在的域 这就形成了对抗式训练注意到上述对抗训练产生的损失本质上还是分类损失 故可将标签分类损失、域分类损失和对抗训练损失线性组合为一个最终的分类损失 具体公式如下:()其中参数 用来平衡分类损失和对抗训练产生的损失最后 结合变分推断、特征分布对齐以及分类产生的损失 就得到了 模型的总损失 可以表示为:()其中 为权重系数 理论结果在本节中 本文将建立 的可识别性理论 许多文献专注于深层潜变量模型 包括 的可识别性研究 其中大多数模型被证明是不可识别的 因此基于 框架的模型是无法保证模型可识别性的 需要投稿网址:年第
18、期(第 卷总第 期)额外的约束条件 等证明了具有非条件先验隐变量的 是不可识别的 但在一些特殊条件下 具有条件因子先验的 模型是可识别的 下面的定理证明了基于 框架的 模型具有可识别性定理 假设 模型中 隐变量()的先验分布属于高斯分布族 且具有如下的形式:()()()()()则 模型是可识别的 至多相差一个线性变换定理的证明可见文献 的附录 可识别性定理的证明 保证了 模型的理论可行性 也为模型可解释性提供了支撑 实验 实验设置 数据集选择公开推荐数据集 和 作为实验评估数据集 数据集是 的“旋转”变体 随机从 数据集中抽取了 类数字灰度图作为基础图像 然后以角度 间隔 分别旋转基础图像中的
19、数字 具有相同旋转角度的图像被视为具有相同的域 因此 旋转 数据集包含了 个域 分别记为 每个域包含 个数字类别 每个类别的图像数量均为 张共 张图像 数据集则更复杂 是从网络上收集了包含着 个共享类别(狗、大象、长颈鹿、吉他、马、房子和人类)的 张 图像 涵盖着 个不同的域(照片、艺术画、漫画和素描)具体信息请参见文献 附录 中的数据集说明 需要注意的是 数据集不存在标签漂移问题 而 数据集中存在严重的标签漂移问题 评估指标与大多数域泛化研究一样 本文采用留一法()来评估模型的域泛化性能 留一法从数据集中单独取一个域数据作为目标域(测试集)数据而剩下其他域数据则作为源域(训练集)数据 数据集
20、中的每个域都会被选为目标域进行实验 记录模型在每个目标域上的标签分类准确率 所有准确率的平均值会被用来度量模型的域泛化能力 准确率越高 则表明模型域泛化能力越强 参数设置本文在 环境中使用 深度学习框架进行实验 在所有实验中 框架中的隐变量维度均设置为 分布对比损失中 正负匹配比为:(即式()中 )对于 数据集 本文使用预训练好的 来提取图像特征 而对于 数据集 由于灰度图的简单性和模型对比的公平性 本文使用和 相同的自建 来提取图像特征 实验中使用 作为优化器 初始学习率为 基线本文分别在 和 数据集上将本文模型 与各种基线模型作比较 对于 数据集 本文选取了、和 作为对比模型、和 都是基于
21、自编码器的早期模型 是采用对抗训练来消除域信息的模型 是基于 来学习域不变特征的模型 对于 数据集 本文选取了、和 作为对比模型 模型完全忽略域变量 只是简单使用 进行图像分类 是受拼图游戏启发的域泛化算法 模型将聚类算法和对抗训练策略结合 和 都是基于学习策略的模型 是基于数据操纵方法的模型 这些对比模型的实验结果均来自于模型的原始论文 比较实验表 中记录了在 数据集做留一实验时 个域中各模型的标签分类准确率 最后一列是所有域的平均标签准确率 从表中数据可以看出 在大多数域泛化任务中 的分类准确率最高 明显优于其他方法 尤其是在目标域 和 这两个任务上表现更加突出 要注意到的是 大部分模型在
22、 和 这两个域上的分类准确率都不如其他域 说明这两个域作为目标域时的域泛化任务难度更高 的表现优于 这表明通过对抗训练和特征对齐可以改善不充分的特征解耦 提升模型的域泛化性能 此外 基线(和)的域泛化表现远不如专门针对域偏移优化的方法()这说明解决域偏移问题是解决域泛化任务的关键表 的实验结果()模型平均值 人工智能 年第 期(第 卷总第 期)表 记录了在 数据集做留一实验时 个域中的标签分类准确率以及它们的平均值 从表中数据可以看出 虽然本文模型 没有取得最高的平均值表现但是仍然优于大多数基线模型 需要注意的是 除了素描域外 在其他 个域上 和最优表现模型(和)的域泛化表现相当 在素描域表现
23、相对差的原因可能是素描域中的图像太简单(仅由线条和空白组成)难以使用像 这样的自编码器网络进行学习 这也可以解释为什么大多数方法都在素描域上表现最差 但是 优于 的两个模型(和)完全基于一些特定的学习策略 相当缺乏理论可解释性而 采用了 框架和基于因果理论的方法 因此具有很好的可解释性 这也说明 能够在保证较好域泛化表现的同时 仍然具有优秀的理论可解释性 使两者达成一个平衡表 的实验结果()模型艺术画卡通画素描图片平均值 消融实验为验证对抗训练和特征对齐在 中的有效性本节对这两个模块进行了消融实验()消融对抗训练此时从 模型中去除对抗训练模块 即去掉 层和对抗训练损失 记此时的消融模型为()消
24、融特征对齐 此时从 模型的总损失中去除分布对比损失项 和自适应权重 记此时的消融模型为()同时消融对抗训练和特征对齐 此时的消融模型类似于 模型 记为 表 显示了在 数据集上进行消融实验的结果表 在 数据集上的消融实验结果()模型艺术画卡通画素描图片平均值 从表 可以看出 随着对抗训练和特征对齐模块的去除 模型的域泛化评估指标都在不同程度上有所降低当两者都去除后 模型的评估指标达到最低 这些结果都验证了 模型中对抗性训练模块和特征对齐操作的有效性 结论本文提出了一种新型的深度学习模型 用于解决域泛化问题 其主要思想是通过训练具有三个潜在因子的 来学习域不变特征 基于域泛化任务背后的因果理论 本
25、文使用特征对齐方法 使用新的分布对比损失和加权分类损失来解决域偏移问题 为了解决解耦不完全问题 本文使用了对抗训练来优化相关特征 同时 本文还证明了 的可识别性 为 提供了理论依据 在两个公开数据集上的实验结果表明 在域泛化任务上的表现与目前最优算法相当 总之 不仅具有良好的理论可解释性 还能保证较高的域泛化性能参考文献 ./():.:/:./:.:.():.:.:()():./:./:.投稿网址:年第 期(第 卷总第 期).:./:./:./:./:.“”/.().:/././:.:/:./():./():.:.():./:./():./:.(收稿日期:)作者简介:明水根()男 硕士研究生 主要研究方向:深度学习、因果推断张洪()男 博士 教授 博士生导师 主要研究方向:机器学习、因果推断、遗传统计人工智能