从两个例子理解EM算法

从两个例子理解EM算法

本文是作者对EM算法学习的笔记,从EM算法出发介绍EM算法,为了更好理解,用两个应用EM算法求解的例子进一步解释EM的应用。


EM算法

EM算法引入

EM算法,指的是最大期望算法(Expectation Maximization Algorithm,期望最大化算法),是一种迭代算法,在统计学中被用于寻找,依赖于不可观察的隐性变量的概率模型中,参数的最大似然估计。基本思想是首先随机取一个值去初始化待估计的参数值,然后不断迭代寻找更优的参数使得其似然函数比原来的似然函数大。

  • EM算法当做最大似然估计的拓展,解决难以给出解析解(模型中存在隐变量)的最大似然估计(MLE)问题
  • 在算法中加入隐变量的思想可以类比为几何题中加入一条辅助线的做法。

假定有训练集{x(1),x(2),...x(m)},包含m个独立样本,希望从中找到该组数据的模型p(x,z)的参数。
对数似然函数表达如下:
从两个例子理解EM算法

在表达式中因为存在隐变量,直接找到参数估计比较困难,所以我们通过EM算法迭代求解下界的最大值,直到收敛。

我们通过以下的图片来解释这一过程:
从两个例子理解EM算法

图片上的紫色部分是我们的目标模型p(x|θ)曲线,该模型比较复杂,难以直接求解其解析解,为了消除隐变量z带来的影响,我们可以得到一个不包含的z的模型r(x|θ)(该函数是我们自己选定的,因此最大值可求解), 同时满足条件r(x|θ)p(x|θ)

  • 我们先取一个θ1,使得r(x|θ1)=p(x|θ1)(如绿线所示),然后再对此时的r求其最大值,得到极值点θ2,实现参数的更新。
  • 不断重复以上过程,在更新过程中始终满足rp直到收敛。

从以上过程来看,EM算法的核心就是如何找到这个r,即p的下界函数。

这个下界函数有多种方法理解,我们从Jensen不等式的角度来理解。

从两个例子理解EM算法

上述等号成立的条件是
p(x(i),z(i);θ)Qi(z(i))=c, zQi(z(i))=1,所以:
从两个例子理解EM算法

最终框架如下:
从两个例子理解EM算法

EM推导高斯混合模型

高斯混合模型GMM

设有随机变量X, 则高斯混合模型可以用p(x)=KπkN(x|μk,Σk),其中N(x|μk,Σk)表示混合模型中的第k个分量πk表示混合系数,满足
kπk=10πk1

我们知道高斯函数的概率分布为f(x)=1(2π)σexp((xμ)22σ2), 在混合高斯分布中待估计变量就包括了μ,σ,π

对数似然函数为lμ,Σ,π=i=1Nlog(k=1K)πkN(xi|μk,Σk))

EM 推导过程

第一步:估算数据来自于哪个组分,即估计每一个组分生成的概率,对每个样本xi,它由第k个组份生成的概率可以记作:γ(i,k)=πkN(xi|μk,Σk)jπjN(xi|μj,Σj)

第二步:估计每个组份的参数

E-step: 在给定了样本和每个高斯分布的参数以及组份的分布函数的情况下

wj(i)=Qi(z(i)=j)=p(z(i))=j|x(i);ϕ,μ,Σ)

M-step:将多项式分布和高斯分布的参数带入:
i=1mz(i)Qi(z(i))logp(x(i),z(i);ϕ,μ,Σ)Qi(z(i))
=i=1mj=1kQi(z(i)=j)logp(x(i)|z(i)=j;ϕ,μ,Σ)p(z(i)=j;ϕ)Qi(z(i))
i=1mj=1kwj(i)log1(2π)n2|Σj|(12)exp(12(x(i)μj)TΣj1(x(i)μj))ϕjwj(i)

分别对其中的未知参数求偏导数:

  • 对均值求偏导
    uji=1mj=1kwj(i)log1(2π)n2|Σj|(12)exp(12(x(i)μj)TΣj1(x(i)μj))ϕjwj(i)

=uji=1mj=1kwj(i)12(x(i)μj)TΣj1(x(i)μj)
=i=1mwj(i)Σj1(x(i)μj)=0

可得
μj=i=1mwj(i)x(i)i=1mwj(i)

  • 对方差求导:
    Σji=1mj=1kwj(i)log1(2π)n2|Σj|(12)exp(12(x(i)μj)TΣj1(x(i)μj))ϕjwj(i)

=Σji=1mj=1kwj(i)(logΣ1212(x(i)μj)TΣj1(x(i)μj))
=i=1mwj(i)Σ(1)i=1mwj(i)(x(i)μj)(x(i)μj)TΣ(2)=0
可得
Σj=i=1mwj(i)(x(i)μj)(x(i)μj)Ti=1mwj(i)

  • ϕ求偏导, 等式约束,用到拉格朗日乘子法, 删除常数项目得到:
    ϕji=1mj=1kwj(i)log(ϕj)+β(j=1kϕj1)
    =i=1mwj(i)ϕj+β=0
    β=i=1mj=1kwj(i)=m
    可得
    ϕj=1mi=1mwj(i)

EM推导PLSA模型

详细过程可参考作者的另一篇博客plsaEM的详细推导