用GAIN来补充缺失数据之论文篇(一)

原文章

如下图是算法的图示:用GAIN来补充缺失数据之论文篇(一)

一、定义变量

在这个算法中,我们定义如下几个变量:
通过原始数据保留未缺失的数据得到:
 X~={Xiif Mi =1otherwise \ \tilde X = \begin{cases} X_i&\text{if $M_i$ =1}\\ *&\text otherwise \end{cases}
接下来通过训练生成器通过随机变量Z来填补:
Xˉ=G(X~,M,(1M)Z) \\ \bar X = G(\tilde X,M,(1-M)\odot Z)
当m=1 时,用x原始值;当m=0时,用生成器训练出的值得出最后结果。
 X^=MX~+(1M)Xˉ \ \hat X = M \odot \tilde X + (1-M) \odot \bar X
结束定义关于数据集的变量,开始定义关于hint_matrix的变量。
首先定义一个辅助变量B。
 B=(B1,...,Bd){0,1}d \ B = (B_1,...,B_d) \in {\left\{0,1\right\}}^d
B中的具体值为随机均匀选取1到d中的一个数字然后设置
 B={1 if jk 0 if j=k  \ B = \begin{cases} 1&\text { if $j \neq k $ }\\ 0&\text { if $j = k $ } \end{cases}
由此可见,在d个B向量的元素中,只有一个随机的元素为0,其他均为1。
紧接着,我们定义
 H=BM+0.5(1B) \ H = B \odot M+0.5(1-B)
通过一张表格来显示B和H的关系:

M B H
1 0 0.5
0 0 0.5
1 1 1
0 1 0

在上面的表格中,我们可以发现一个规律:
当H等于0.5时,我们并不能从H的值中得到正确的M的值。因为M可能为1也可能为0。此时B等于0。
当H等于0或1时,H的值与M的数值相等。即H=1,M=1;H=0,M=0。因此我们可以通过H的数值推断出M的确切值。此时B等于1。

由此可见,我们可以根据H推断出准确的M值,除了唯一一个B=0的位置。而那个我们唯一不确定的M值,就是我们要通过模型训练的对象(如果我们训练确定的M值,很有可能会出现过度拟合问题)。

二、解读算法

下面我们正式进入到对GAIN算法的解读:

对于每一个batch,生成两个随机矩阵Z和B,即对于每一个在一个batch中的样本((x(j), m(j)))都有对应的 z(j) 和 b(j)来计算我们上面定义的其他变量。这样,每一个batch都有自己相应的一套变量。

生成器(generator)

1. 输入(input)

原数据 + mask_matrix + Z

2. 输出(output)

生成器输出的是对于缺失值的估算量。

3. 损失函数(loss function)

因为我们只训练B(j)=0的位置,所以只取B(j)=0时的损失和。
生成器的目的是干扰判别器的判别结果,尽量减小判别器得出正确答案的概率。在这里,我们最小化

  1. 在真实数据缺失时,判别结果(生成器生成的数据在原数据中的缺失几率)判定为缺失数据的情况发生;
     LG(m(j),D(x^(j),h(j)),b(j))=i(1mi)log(m^i) \ L_G(m(j), D(\hat x(j),h(j)),b(j))=- \sum_i (1-m_i) log(\hat m_i)\\
  2. 真实数据未缺失时,生成结果与原数据的差异。

 LM(x(j),x~(j))=imiLM(xi,xi) LM(xi,xi)={(xixi)2mi continuous,xilog(xi)mi binary.  \ L_M(x(j), \tilde x(j))= \sum_i m_iL_M(x_i, x_i\prime )\\ \ L_M(x_i, x_i\prime )= \begin{cases} {(x_i\prime-x_i)^2}&\text{$m_i$ continuous},\\ -x_ilog(x_i\prime)&\text{$m_i$ binary}. \end{cases}\
综上所述,我们最终的损失函数是
 LG(m(j),D(x^(j),h(j)),b(j))+αLM(x(j),x~(j)) \ \sum L_G(m(j), D(\hat x(j),h(j)),b(j)) + \alpha * L_M(x(j), \tilde x(j))

判别器(discriminator)

1. 输入(input)

生成器的输出(估计数据未缺失的概率)+ mask_matrix(计算损失函数)

2. 输出(output)

在判别器中,需要判别的是每个从生成器传输过来的完整数据是否是生成器自己补充的数据。因此判别器需要输出估计的mask_matrix 中元素的值即数据未缺失的概率(因为1代表未缺失,0代表缺失)。这个判别器的估计值我们用

 m^(j)=D(x^(j),m(j)) \ \hat m(j) = D(\hat x(j), m(j)) 表示

3. 损失函数(loss function)

在这里,判别器应该计算的是:
真实数据缺失时,判别结果(生成器生成的数据在原数据中的缺失几率)与1的差异;
真实数据未缺失时,判别结果与0的差异。

 LD(m(j),D(x^(j),h(j)),b(j))=i:bi=0[milog(m^i)+(1mi)log(1m^i)] \ L_D(m(j), D(\hat x(j),h(j)),b(j))= \sum_{i: b_i =0} [m_i log(\hat m_i)+(1-m_i) log(1-\hat m_i)]

因为
 m^(j)=D(x^(j),m(j))[0,1] \ \hat m(j) = D(\hat x(j), m(j)) \in [0,1]
其中的log项均为负数,真正需要最小化的差异是
 LD(m(j),D(x^(j),h(j)),b(j)) \ -L_D(m(j), D(\hat x(j),h(j)),b(j))

三、伪代码

用GAIN来补充缺失数据之论文篇(一)