论文信息
论文标题:Unsupervised Domain Adaptation for COVID-19 Information Service with Contrastive Adversarial Domain Mixup
论文作者:Huimin Zeng, Zhenrui Yue, Ziyi Kou, Lanyu Shang, Yang Zhang, Dong Wang
论文来源:aRxiv 2022
论文地址:download
论文代码:download
1 Introduction
2 Problem Statement
Regarding misinformation detection, we aim at training a model $f$ , which takes an input text $boldsymbol{x}$ (a COVID-19 claim or a piece of news) to predict whether the information contained in $boldsymbol{x}$ is valid or not (i.e., a binary classification task). Moreover, in our domain adaptation problem, we use $mathcal{P}$ to denote source domain data distribution and $mathcal{Q}$ for the target domain data distribution. Each data point ($boldsymbol{x}$, $y$) contains an input segment of COVID-19 claim or news ($boldsymbol{x}$) and a label $y in{0,1}$ ( $y=1$ for true information and $y=0$ for false information). To differentiate the notations of the data sampled from the source distribution $mathcal{P}$ and the target distribution $mathcal{Q}$ , we further introduce two definitions of the domain data:
-
- Source domain: The subscript $s$ is used to denote the source domain data: $mathcal{X}_{s}=left{left(boldsymbol{x}_{s}, y_{s}right) midleft(boldsymbol{x}_{s}, y_{s}right) sim mathcal{P}right}$ .
- Target domain: Similarly, the subscript t is used to denote the target domain data: $mathcal{X}_{t}=left{boldsymbol{x}_{t} mid boldsymbol{x}_{t} sim mathcal{P}right}$ . Note that in our unsupervised setting, the ground truth labels of target domain data $y_{t}$ are not used during training.
Our goal is to adapt a classifier $f$ trained on $mathcal{P}$ to $mathcal{Q}$ . For a given target domain input $boldsymbol{x}_{t}$ , a well-adapted model aims at making predictions as correctly as possible.
3 Method
整体框架:
3.1 Domain Discriminator
第一步是训练一个域鉴别器 $f_{D}$ 来分类输入数据是属于源域还是属于目标域。该域鉴别器与 COVID 模型共享相同的 BERT Encoder $f_{e}$,并具有不同的二进制分类模块 $f_{D}$。域鉴别器以 BERT Encoder 中的标记 [CLS] 表示作为输入,以预测输入数据的域,如所示:
$hat{y}=f_{D}(boldsymbol{z}) quadquad(1)$
其中,$z$ 是 token [CLS] 的表示。
对于 $f_{D}$ 的训练,明确地将源域数据的域标签 $y_{D}$ 定义为 $y_{D}=0$,将目标域数据的域标签定义为 $y_{D}=1$。因此,对域鉴别器的训练可以表述为:
$underset{f_{D}}{text{min}} ;; mathbb{E}_{left(boldsymbol{x}, y_{D}right) sim mathcal{X}^{prime}}left[lleft(f_{D}left(f_{e}(boldsymbol{x})right), y_{D}right)right] quadquad(2)$
其中,$mathcal{X}^{prime}$ 表示带有域标签的源域和目标域训练数据的合并数据集。
3.2 Adversarial Domain Mixup
在训练了域鉴别器后,我们提出直接干扰来自源域和目标域的输入数据的潜在表示到域鉴别器的决策边界,如 Figure 1b 所示。为此,来自两个域的扰动表示(即域对抗表示)可以变得更接近,表明域间隙减小。在此,从两个域生成的域对抗性表示在模型的潜在特征空间中形成了一个平滑的中间域混合。在数学上,通过求解一个优化问题,可以找到干扰训练样本 $ boldsymbol{x}$ 的潜在表示 $ boldsymbol{z}$ 的最优扰动 $delta^{*}$:
$begin{array}{r}mathcal{A}left(f_{e}, f_{D}, boldsymbol{x}, y_{D}, epsilonright)=underset{boldsymbol{delta}}{text{max}} left[lleft(f_{D}(boldsymbol{z}+boldsymbol{delta}), y_{D}right)right] \text { s.t. } quad|boldsymbol{delta}| leq epsilon, quad boldsymbol{z}=f_{e}(boldsymbol{x})end{array}quadquad(3)$
注意,在上面的方程中,我们引入了一个超参数 $epsilon$ 来约束扰动 $delta$ 的范数,从而避免了无穷大解。最后,将 $text{Eq.3}$ 应用于合并训练集 $mathcal{X}^{prime}$ 中的所有训练样本,得到对抗域混合 $mathcal{Z}^{prime}$:
$begin{aligned}mathcal{Z}^{prime} & =left{boldsymbol{z}^{prime} mid boldsymbol{z}^{prime}=boldsymbol{z}+mathcal{A}left(f_{e}, f_{D}, boldsymbol{x}, y_{D}, epsilonright),left(boldsymbol{x}, y_{D}right) in mathcal{X}^{prime}right} \& :=mathcal{Z}_{s}^{prime} cup mathcal{Z}_{t}^{prime}end{aligned}quadquad(4)$
其中,$mathcal{Z}_{s}^{prime}$ 是扰动的源特性,$mathcal{Z}_{t}^{prime}$ 是受干扰的目标特征。我们使用投影梯度下降(PGD)来近似 $text{Eq.3}$ 的解,如在[7],[8]。
3.3 Contrastive Domain Adaptation
接下来,受[6]的启发,我们提出了 $mathcal{Z}_{a d v}$ 的双重对比自适应损失,以进一步将源数据域的知识适应到目标数据域。首先,我们减少了类内表示之间的域差异。也就是说,如果一个表示从源数据域的标签是真(或假)和一个表示从目标数据域的伪标签是真(或假),那么这两个表示被视为类内表示,我们减少域之间的差异。其次,如 Figure 1c 所示,真实信息和虚假信息的表示之间的差异将被扩大。
为了计算我们提出的对比自适应损失,我们建议使用径向基函数(RBF)来度量标记类之间的差异。在[11]中,RBF 被证明是量化深度神经网络中不确定性的有效工具。由于我们的伪标记过程是为了自动过滤出目标域数据的低置信度标签,因此使用RBF来衡量标记类之间的差异可以有效地提高伪标签的质量,最终有助于模型的域适应。
在形式上,使用 RBF 内核的定义:$kleft(z_{1}, z_{2}right)=exp left[-frac{left|boldsymbol{z}_{1}-boldsymbol{z}_{2}right|^{2}}{2 sigma^{2}}right]$
我们定义了错误信息检测任务的类感知损失如下:
$begin{aligned}mathcal{L}_{text {con }}left(mathcal{Z}^{prime}right) =&-sum_{i=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{j=1}^{left|mathcal{Z}_{t}^{prime}right|} frac{mathbb{1}left(y_{s}^{(i)}=0, hat{y}_{t}^{(j)}=0right) kleft(boldsymbol{z}_{s}^{(i)}, boldsymbol{z}_{t}^{(j)}right)}{sum_{l=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{m=1}^{left|mathcal{Z}_{t}^{prime}right|} mathbb{1}left(y_{s}^{(l)}=0, hat{y}_{t}^{(m)}=0right)} \& -sum_{i=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{j=1}^{left|mathcal{Z}_{t}^{prime}right|} frac{mathbb{1}left(y_{s}^{(i)}=1, hat{y}_{t}^{(j)}=1right) kleft(boldsymbol{z}_{s}^{(i)}, boldsymbol{z}_{t}^{(j)}right)}{sum_{l=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{m=1}^{left|mathcal{Z}_{t}^{prime}right|} mathbb{1}left(y_{s}^{(l)}=1, hat{y}_{t}^{(m)}=1right)} \& +sum_{i=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{j=1}^{left|mathcal{Z}_{s}^{prime}right|} frac{mathbb{1}left(y_{s}^{(i)}=1, y_{s}^{(j)}=0right) kleft(boldsymbol{z}_{s}^{(i)}, boldsymbol{z}_{s}^{(j)}right)}{sum_{l=1}^{left|mathcal{Z}_{s}^{prime}right|} sum_{m=1}^{left|mathcal{Z}_{s}^{prime}right|} mathbb{1}left(y_{s}^{(l)}=1, y_{s}^{(m)}=0right)} \& +sum_{i=1}^{left|mathcal{Z}_{t}^{prime}right|} sum_{j=1}^{left|mathcal{Z}_{t}^{prime}right|} frac{mathbb{1}left(hat{y}_{t}^{(i)}=1, hat{y}_{t}^{(j)}=0right) kleft(boldsymbol{z}_{t}^{(i)}, boldsymbol{z}_{t}^{(j)}right)}{sum_{l=1}^{left|mathcal{Z}_{t}^{prime}right|} sum_{m=1}^{left|mathcal{Z}_{t}^{prime}right|} mathbb{1}left(hat{y}_{t}^{(l)}=1, hat{y}_{t}^{(m)}=0right)}end{aligned}quadquad(5)$
其中,$hat{y}_{t}$ 为目标域样本的伪标签,$z$ 表示标记 CLS 的表示。
3.4 Overall Contrastive Adaptation Loss
现在,我们将任务分类问题的交叉熵损失和上述对比自适应损失合并为 COVID 模型的单一优化目标:
$mathcal{L}_{text {all }}=mathcal{L}_{c e}(boldsymbol{mathcal { X }})+lambda mathcal{L}_{text {con }}left(mathcal{Z}^{prime}right) quadquad(6)$
其中,$mathcal{L}_{c e}$ 代表交叉熵损失函数。
4 Experiment
在我们的实验中,我们使用了三个 source misinformation datasets :GossipCop , LIAR and PHEME,两个 COVID misinformation datasets:Constraint and ANTiVax。
Results: