【论文笔记】2019-WWW-Multiple Treatment Effect Estimation using Deep Generative Model with Task Embedding
背景
这篇文章考虑了一个新的causal inference设定:treatment不是简单的二元变量 { 0 , 1 } \left\{0,1\right\} {0,1},而是二元变量的组合 { 0 , 1 } k \left\{0,1\right\}^k {0,1}k。这个设定也比较好理解,还用医生治病的例子来说,通常医生使用的是多种药的组合。如果总共涉及到三种药物,而病人使用了第一种和第三种,则对应的 k = 3 k=3 k=3,treatment就是 [ 1 , 0 , 1 ] [1,0,1] [1,0,1]。
挑战
这个设定的挑战在于如何设计针对多个treatment的网络结构。在经典的TARnet和Dragonnet中,作者针对 p ( y ∣ t = 0 , x ) p(y|t=0,x) p(y∣t=0,x)和 p ( y ∣ t = 1 , x ) p(y|t=1,x) p(y∣t=1,x)都设计了不同的网络,如果本文也沿用这个方法,就会出现网络结构冗余的问题。比如例子中涉及到3个treatment的组合,那就要相应设计 2 3 = 8 2^3=8 23=8个网络,非常不高效,还会出现因为数据分布不均匀网络训练不准确的问题。
方法
整体的框架还是套用的CEVAE(可以参见笔者写的上一篇文章),创新之处在于引入了一个可学习的embedding matrix。
Encoder
网络结构如下图所示:
前向传播:首先输入
x
x
x会经过网络
g
1
g_1
g1得到
q
(
t
∣
x
)
=
∏
i
=
1
k
B
e
r
n
(
q
t
,
i
)
q(t|x)=\prod_{i=1}^k Bern(q_{t,i})
q(t∣x)=∏i=1kBern(qt,i),然后从
q
(
t
∣
x
)
q(t|x)
q(t∣x)中采样得到
t
′
t'
t′(这里有个问题就是怎么反向传播?采样得到
t
′
t'
t′没法反向传播吧),接下来
t
′
t'
t′会和一个embedding matrix
W
W
W相乘得到新的表示
τ
=
W
⋅
t
′
\tau=W\cdot t'
τ=W⋅t′。新表示
τ
\tau
τ经过网络
g
2
g_2
g2得到
q
(
y
∣
t
,
x
)
=
N
(
g
2
,
1
)
q(y|t,x)=N(g_2,1)
q(y∣t,x)=N(g2,1),这里方差设为1也是为了简单防止过拟合吧,避免网络中要学习太多变量。之后,作者把
τ
,
x
,
g
2
\tau, x, g_2
τ,x,g2concatenate到一起得到
g
3
g_3
g3和
g
4
g_4
g4的输入,
g
3
g_3
g3和
g
4
g_4
g4的输出恰好是
q
(
z
∣
x
,
t
,
y
)
q(z|x,t,y)
q(z∣x,t,y)的均值和方差。
Decoder
网络结构如下图所示:前向传播:这里作者没写清楚decoder的输入
z
z
z怎么来的(吐槽一句,作者有很多细节都没写清楚),我猜测就是从encoder的输出采样得到。接下来先看下面四个网络
f
1
,
f
2
,
f
3
,
f
4
f_1,f_2,f_3,f_4
f1,f2,f3,f4,其实是针对
x
x
x的三种可能情形:二元变量、目录变量、连续变量,这里只以连续变量为例进行说明。
f
1
f_1
f1和
f
2
f_2
f2的输出分别是
p
(
x
∣
z
)
p(x|z)
p(x∣z)的均值和方差。
f
5
f_5
f5的设计和
g
1
g_1
g1基本一致,输出就是
p
(
t
∣
z
)
=
∏
i
=
1
k
B
e
r
n
(
p
t
,
i
)
p(t|z)=\prod_{i=1}^k Bern(p_{t,i})
p(t∣z)=∏i=1kBern(pt,i),然后继续采样得到
t
~
\widetilde{t}
t
,
t
~
\widetilde{t}
t
再与embedding matrix相乘得到
τ
~
=
W
⋅
t
~
\widetilde{\tau}=W \cdot \widetilde{t}
τ
=W⋅t
。之后作者在文章里说把
τ
~
,
x
,
z
\widetilde{\tau},x,z
τ
,x,z concatanate到一起作为
f
6
f_6
f6的输入,但根据流程图似乎没有
x
x
x?(这个作者写作有点不认真啊,文章居然和图对不上)
作者没具体写出训练的目标函数(很迷,这么重要的东西居然文章里没有明确写出来),只是说利用和VAE类似的变分推断的方法,估计是和CEVAE差不多,先验分布也是标准正态分布。
总结
文章的亮点在于提出了multiple treatment的范式和embedding的解决思路,缺点在于作者写作实在太不严谨了,很多细节没交代清楚(当然也可能是我读的还不够细),类似于采样 t t t怎么反向传播、目标函数之类的都没有具体写出来。