(CVPR 2019)Universal Domain adaptation笔记

(CVPR 2019)Universal Domain Adaptation

文章链接

本文主要是针对Universal Domain Adaptation问题提出的方法
Universal Domain Adaptation是指目标域的标签空间未知的无监督领域自适应(Unsupervised Domain Adaptation)问题,如下图所示(CVPR 2019)Universal Domain adaptation笔记

网络结构

(CVPR 2019)Universal Domain adaptation笔记

训练部分,图像 x x x输入进入特征提取器 F F F,得到特征向量 z z z

D ′ D' D是一个非对抗领域判别器, D D D是一个对抗领域判别器,用于判断输入 z z z是源域数据的可能性。 D D D和传统的对抗领域判别器类似,用于混淆源域和目标域的特征数据。不同点在于,损失函数中加入了样布属于公共标签的概率作为权值

损失函数

E G = E ( x , y ) ∼ p L ( y , G ( F ( x ) ) ) E_G=E_{(x,y) \sim p}L(y,G(F(x))) EG=E(x,y)pL(y,G(F(x)))

E D ′ = − E x ∼ p l o g ( D ′ ( F ( x ) ) − E x ∼ q l o g ( 1 − D ′ ( F ( x ) ) ) E_{D'}=-E_{x \sim p}log(D'(F(x))-E_{x \sim q}log(1-D'(F(x))) ED=Explog(D(F(x))Exqlog(1D(F(x)))

E D = − E x ∼ p ω s ( x ) l o g ( D ( F ( x ) ) ) − E x ∼ q ω t ( x ) l o g ( 1 − D ( F ( x ) ) ) E_D=-E_{x \sim p}\omega^s(x)log(D(F(x)))-E_{x \sim q}\omega^t(x)log(1-D(F(x))) ED=Expωs(x)log(D(F(x)))Exqωt(x)log(1D(F(x)))

优化目标

m a x D ( m i n F , G ( E G − λ E D ) ) max_D(min_{F,G}(E_G-\lambda E_D)) maxD(minF,G(EGλED))

m i n D ′ ( E D ′ ) min_{D'}(E_{D'}) minD(ED)

D D D反传的时候包含一个梯度翻转层,,这样在优化 F F F G G G的时候就与D的优化形成了对抗

ω s ( x ) 和 ω t ( x ) \omega^s(x)和\omega^t(x) ωs(x)ωt(x)分别代表源域样本属于公共样本空间的概率和目标域样本属于公共样本空间的概率


novel内容

文中提出了两个假设

1. E x ∼ p C s ‾ d ′ ^ > E x ∼ p C d ′ ^ > E x ∼ q C d ′ ^ > E x ∼ q C t ‾ d ′ ^ E_{x \sim p_{\overline{C_s}}} \hat{d'}>E_{x \sim p_C} \hat{d'}>E_{x \sim q_C} \hat{d'}>E_{x \sim q_{\overline{C_t}}} \hat{d'} ExpCsd^>ExpCd^>ExqCd^>ExqCtd^

2. E x ∼ q C t ‾ H ( y ^ ) > E x ∼ q C H ( y ^ ) > E x ∼ p C H ( y ^ ) > E x ∼ p C s ‾ H ( y ^ ) E_{x \sim q_{\overline{C_t}}} H(\hat{y})>E_{x \sim q_C} H(\hat{y})>E_{x \sim p_C} H(\hat{y})>E_{x \sim p_{\overline{C_s}}} H(\hat{y}) ExqCtH(y^)>ExqCH(y^)>ExpCH(y^)>ExpCsH(y^)

其中
p C s ‾ p_{\overline{C_s}} pCs代表源域中与目标域中类别重叠部分的数据的分布
p C p_C pC代表源域中与目标域中类别重叠部分的数据的分布,
q C t ‾ q_{\overline{C_t}} qCt代表目标域中与源域中类别重叠部分的数据的分布,
q C q_C qC代表目标域中不与源域中类别重叠部分的数据的分布,

以上的假设基于直觉得出,首先观察第一个假设, d ^ ′ \hat{d}' d^代表样本属于源域的概率,文章认为只属于源域的类别的数据更容易被划分成源域的类别,那么 d ^ ′ \hat{d}' d^就会越高。而在公共类别中,源域和目标域数据相互影响, d ^ ′ \hat{d}' d^会相对较低,而只属于目标域的类别的数据 d ^ ′ \hat{d}' d^会最低

H ( y ^ ) H(\hat{y}) H(y^)代表伪标签 y ^ \hat{y} y^的熵,熵越大代表分类结果越确定,越小代表越不确定,思考方式类似于假设1

可以根据以上两个假设构造出 ω s ( x ) 和 ω t ( x ) \omega^s(x)和\omega^t(x) ωs(x)ωt(x)的计算方法

ω s ( x ) = H ( y ^ ) l o g ∣ C s ∣ − d ^ ′ ( x ) \omega^s(x)=\frac{H(\hat{y})}{log|C_s|}-\hat{d}'(x) ωs(x)=logCsH(y^)d^(x)

ω t ( x ) = d ^ ′ ( x ) − H ( y ^ ) l o g ∣ C s ∣ \omega^t(x)=\hat{d}'(x)-\frac{H(\hat{y})}{log|C_s|} ωt(x)=d^(x)logCsH(y^)

构造方法的思路是,对于源域数据,我们希望 ω s ( x ) \omega^s(x) ωs(x)更大,而对于目标域数据,我们希望 ω t ( x ) \omega^t(x) ωt(x)越大


测试阶段

根据 G G G给出的分类结果判定类别,根据 D ′ D' D给出的 ω t ( x ) \omega^t(x) ωt(x)判断其类别是不是unknown,即是不是源域中出现过的类别