EM (Expectation–Maximization) Algorithm: Intuition and Derivation

Table of contents: links to previous posts

Jensen’s inequality

  • Theorem: Let $f$ be a convex function, and let $X$ be a random variable. Then:

$$E[f(X)] \geq f(E[X])$$

    Moreover, if $f$ is strictly convex, then $E[f(X)] = f(E[X])$ holds if and only if $X$ is a constant (with probability 1).

  • Later in the post we are going to use the following consequence of Jensen's inequality.
    Suppose $\lambda_j \geq 0$ for all $j$, $\sum_j \lambda_j = 1$, and $y_j > 0$. Then

$$\log \sum_j \lambda_j y_j \geq \sum_j \lambda_j \log y_j$$

    since the $\log$ function is concave (for a concave function, the inequality in the theorem is reversed).
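A quick numerical sanity check of this inequality, as a minimal Python sketch. The helper name `jensen_gap` and the random test values are illustrative assumptions, not part of the original post. It draws random weights $\lambda_j$ that sum to 1 and random $y_j > 0$, and checks that $\log \sum_j \lambda_j y_j - \sum_j \lambda_j \log y_j$ is never negative (up to floating-point rounding).

```python
import math
import random

def jensen_gap(lambdas, ys):
    """log(sum_j lambda_j * y_j) - sum_j lambda_j * log(y_j); Jensen says this is >= 0."""
    lhs = math.log(sum(l * y for l, y in zip(lambdas, ys)))
    rhs = sum(l * math.log(y) for l, y in zip(lambdas, ys))
    return lhs - rhs

rng = random.Random(0)
for _ in range(1000):
    k = rng.randint(1, 6)
    w = [rng.uniform(0.01, 1.0) for _ in range(k)]
    lambdas = [x / sum(w) for x in w]                 # lambda_j >= 0 and they sum to 1
    ys = [rng.uniform(0.01, 10.0) for _ in range(k)]  # y_j > 0 so log(y_j) is defined
    assert jensen_gap(lambdas, ys) >= -1e-12          # small slack for float rounding

# Equality case: all y_j equal, i.e. "X is a constant"; prints 0 up to rounding.
print(jensen_gap([0.2, 0.3, 0.5], [4.0, 4.0, 4.0]))
```

The gap is strictly positive whenever the $y_j$ with nonzero weight are not all equal, matching the strict-convexity part of the theorem.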

Overview of Expectation–Maximization (EM) algorithm

In this post, let $Y$ be a set of observed data, $Z$ a set of unobserved latent data, and $\theta$ the unknown parameters.

(After this post, you will be comfortable with the following description about the EM algorithm.)

The Expectation–Maximization (EM) algorithm is an iterative method for finding a (local) maximum likelihood estimate (MLE) of the parameters $\theta$, i.e., a (local) maximizer of $L(\theta) = p(Y|\theta)$, when the model depends on unobserved latent variables $Z$.

Algorithm:

  1. Initialize the parameters $\theta_0$.

Iterate between steps 2 and 3 until convergence:

  2. An expectation (E) step, which creates a function $Q(\theta, \theta_i)$ for the expectation of the log-likelihood $\log p(Y,Z|\theta)$, evaluated using the current conditional distribution of $Z$ given $Y$ and the current estimate of the parameters $\theta_i$, where

$$\begin{aligned} Q(\theta, \theta_i) &= \sum_Z P(Z|Y,\theta_i) \cdot \log p(Y,Z|\theta) \\ &= E_{Z \sim P(Z|Y,\theta_i)}[\log p(Y,Z|\theta)] \end{aligned}$$

  3. A maximization (M) step, which computes the parameters maximizing the expected log-likelihood $Q(\theta, \theta_i)$ found in the E step, and then updates the parameters to $\theta_{i+1}$.

These parameter estimates are then used to determine the distribution of the latent variables in the next E step. We say the algorithm has converged when the increase in the log-likelihood between successive iterations is smaller than some tolerance.

In general, the likelihood may have multiple (local) maxima, and there is no guarantee that EM finds the global maximum.
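Steps 1 to 3 can be sketched in Python. This is a minimal sketch under assumptions that are not from the original post: a hypothetical two-coin mixture where each trial secretly picks coin A (with probability $\pi$) or coin B, flips it $n$ times, and we observe only the number of heads. The coin choice is the latent $Z$, and $\theta = (\pi, p_A, p_B)$. The function names, the data, and the tolerance are illustrative; the M step uses the closed form obtained by setting the derivatives of $Q(\theta, \theta_i)$ to zero, which is specific to this toy model.

```python
import math

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial with h heads in n flips; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)     # p(z | theta), heads probability
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def log_lik(heads, n, theta):
    """l(theta) = log p(Y | theta): the latent z is summed out inside the log."""
    return sum(math.log(joint(h, n, 0, theta) + joint(h, n, 1, theta)) for h in heads)

def e_step(heads, n, theta):
    """Responsibilities r_k = P(z_k = A | y_k, theta_i) for every trial."""
    return [joint(h, n, 0, theta) / (joint(h, n, 0, theta) + joint(h, n, 1, theta))
            for h in heads]

def m_step(heads, n, resp):
    """Closed-form argmax_theta Q(theta, theta_i) for this toy model."""
    w_a = sum(resp)
    w_b = len(resp) - w_a
    p_a = sum(r * h for r, h in zip(resp, heads)) / (n * w_a)
    p_b = sum((1 - r) * h for r, h in zip(resp, heads)) / (n * w_b)
    return (w_a / len(resp), p_a, p_b)

def em(heads, n, theta0, tol=1e-8, max_iter=1000):
    """Run E and M steps until the log-likelihood gain drops below tol."""
    theta = theta0                                      # step 1: initialization
    history = [log_lik(heads, n, theta)]
    for _ in range(max_iter):
        resp = e_step(heads, n, theta)                  # step 2: E step
        theta = m_step(heads, n, resp)                  # step 3: M step
        history.append(log_lik(heads, n, theta))
        if history[-1] - history[-2] < tol:             # convergence test
            break
    return theta, history

heads = [5, 9, 8, 4, 7]    # hypothetical data: heads out of n = 10 flips per trial
theta, history = em(heads, 10, theta0=(0.5, 0.6, 0.5))
print("theta:", theta)
print("log-likelihood per iteration:", history)
```

With a different starting point $\theta_0$ the loop may end at a different (local) maximum, or with the roles of coins A and B swapped.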

Intuition: Why we need the EM algorithm

Sometimes maximizing the likelihood $\ell(\theta)$ explicitly is difficult because the model contains unknown latent variables. In such settings, the EM algorithm gives an efficient method for maximum likelihood estimation.

Complete Case vs. Incomplete Case

Complete case

$(Y, Z)$ is observable, and the log-likelihood can be written as

$$\begin{aligned} \ell(\theta) &= \log p(Y, Z | \theta) \\ &= \log \left[ p(Z|\theta) \cdot p(Y|Z, \theta) \right] \\ &= \log p(Z|\theta) + \log p(Y|Z, \theta) \end{aligned}$$

We can split the task of maximizing $\ell(\theta)$ into two sub-tasks. Note that in both $\log p(Z|\theta)$ and $\log p(Y|Z, \theta)$ the only unknown is the parameter $\theta$, so they are just two standard MLE problems, which can be solved easily, e.g., by gradient ascent (equivalently, gradient descent on the negative log-likelihood).
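As an illustration, consider a hypothetical two-coin model (an assumption, not from the original post): each trial picks coin A with probability $\pi$ or coin B otherwise, then flips it $n$ times. In the complete case we also observe which coin was used, so $\log p(Z|\theta)$ involves only $\pi$ and $\log p(Y|Z,\theta)$ involves only $(p_A, p_B)$. For this model each piece even has a closed-form maximizer (a frequency), so no iterative optimizer is needed. A minimal sketch, with illustrative names and data:

```python
def complete_data_mle(heads, n, coins):
    """MLE of theta = (pi, p_A, p_B) when the coin labels Z are observed.

    log p(Z | theta) involves only pi, and log p(Y | Z, theta) involves only
    p_A and p_B, so each part is maximized separately by a simple frequency.
    """
    k = len(coins)
    n_a = sum(1 for c in coins if c == "A")            # trials that used coin A
    pi = n_a / k                                        # maximizes log p(Z | theta)
    heads_a = sum(h for h, c in zip(heads, coins) if c == "A")
    heads_b = sum(h for h, c in zip(heads, coins) if c == "B")
    p_a = heads_a / (n * n_a)                           # these two maximize
    p_b = heads_b / (n * (k - n_a))                     # log p(Y | Z, theta)
    return pi, p_a, p_b

heads = [5, 9, 8, 4, 7]             # hypothetical heads counts out of n = 10 flips
coins = ["B", "A", "A", "B", "A"]   # hypothetical observed labels Z
print(complete_data_mle(heads, 10, coins))  # (0.6, 0.8, 0.45)
```

The sketch assumes each coin appears in at least one trial; otherwise the corresponding frequency is undefined.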

Incomplete case

$Y$ is observable, but $Z$ is unknown. We need to marginalize over the latent variable $Z$:

$$\begin{aligned} \ell(\theta) &= \log p(Y | \theta) \\ &= \log \sum_Z p(Y, Z | \theta) \\ &= \log \sum_Z p(Z|\theta) \cdot p(Y|Z, \theta) \end{aligned}$$

Here we have a summation inside the log, so it is hard to take derivatives or apply standard optimization methods directly. This is where the EM algorithm comes to our aid.
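To see the "log of a sum" concretely, here is a minimal sketch computing $\ell(\theta)$ for a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$). Each $\log p(y, z|\theta)$ splits into simple pieces, but the sum over $z$ sits inside the log, so $\pi$, $p_A$ and $p_B$ stay coupled in every term. The `log_sum_exp` helper (our name; it is the standard max-shift trick) evaluates that log of a sum without underflow.

```python
import math

def log_sum_exp(xs):
    """Numerically stable log(sum_i exp(x_i))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_joint(h, n, z, theta):
    """log p(y, z | theta) = log p(z | theta) + log p(y | z, theta) for one trial."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return (math.log(prior) + math.log(math.comb(n, h))
            + h * math.log(p) + (n - h) * math.log(1 - p))

def incomplete_log_lik(heads, n, theta):
    """l(theta) = sum_k log sum_z p(y_k, z | theta): the sum over z is inside the log."""
    return sum(log_sum_exp([log_joint(h, n, z, theta) for z in (0, 1)]) for h in heads)

heads = [5, 9, 8, 4, 7]    # hypothetical heads counts out of n = 10 flips
print(incomplete_log_lik(heads, 10, (0.5, 0.6, 0.5)))
```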

EM Algorithm Derivation (Using MLE)

Given the observed data $Y$, we want to maximize the likelihood $L(\theta) = p(Y|\theta)$, which is the same as maximizing the log-likelihood $\log p(Y|\theta)$ because $\log$ is strictly increasing. Therefore, from now on we will try to maximize the log-likelihood

$$\ell(\theta) = \log p(Y|\theta)$$

Taking the unknown variable $Z$ into account, we rewrite the objective function as

$$\begin{aligned} \ell(\theta) &= \log p(Y|\theta) \\ &= \log \sum_Z p(Y, Z | \theta) \\ &= \log \sum_Z p(Y|Z,\theta) \cdot p(Z|\theta) \end{aligned}$$

Note that in the last line the $\log$ is outside of the $\sum$, which makes the objective hard to compute and optimize. Check out my previous post to see why we prefer to have the $\log$ inside the $\sum$ instead of outside. So next we will find a lower bound that approximates it, using Jensen's inequality.

Suppose we have followed steps 2 (E) and 3 (M) repeatedly and have updated the parameters to $\theta_i$. Then the difference between $\ell(\theta)$ and the current value $\ell(\theta_i)$ is $\ell(\theta) - \ell(\theta_i)$. You can think of this difference as the improvement that the next estimate of $\theta$ tries to achieve, so we want to choose $\theta_{i+1}$ to make this difference as large as possible:

$$\theta_{i+1} = \mathop{\rm arg\,max}\limits_{\theta} \,\, \left(\ell(\theta) - \ell(\theta_i)\right)$$

Note that we know the value of $\theta_i$, and therefore also the value of $\ell(\theta_i)$. Multiplying and dividing each term of the sum by $P(Z|Y,\theta_i)$, we get

$$\begin{aligned} \ell(\theta) - \ell(\theta_i) &= \log p(Y|\theta) - \log p(Y|\theta_i) \\ &= \log \sum_Z p(Y|Z,\theta) \cdot p(Z|\theta) - \log p(Y|\theta_i) \\ &= \log \sum_Z P(Z|Y,\theta_i) \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i)} - \log p(Y|\theta_i) && (1) \end{aligned}$$

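Since $P(Z|Y,\theta_i) \geq 0$ for all $Z$ and $\sum_Z P(Z|Y,\theta_i) = 1$, we can apply Jensen's inequality to the first term of $(1)$, with weights $\lambda_Z = P(Z|Y,\theta_i)$. For the second term, we use the same fact that the weights sum to 1 to write $\log p(Y|\theta_i) = \sum_Z P(Z|Y,\theta_i) \cdot \log p(Y|\theta_i)$. This gives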

$$\begin{aligned} \ell(\theta) - \ell(\theta_i) &\geq \sum_Z P(Z|Y,\theta_i) \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i)} - \sum_Z P(Z|Y,\theta_i) \cdot \log p(Y|\theta_i)\\ &= \sum_Z P(Z|Y,\theta_i) \left( \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i)} - \log p(Y|\theta_i) \right) \\ \ell(\theta) &\geq \ell(\theta_i) + \sum_Z P(Z|Y,\theta_i) \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i) \cdot p(Y|\theta_i)} \end{aligned}$$
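A numerical check of the last line, as a minimal sketch on a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$; names and data are ours). Because the trials are independent, the posterior over the whole latent vector $Z$ factorizes, so the sum over $Z$ splits into a sum over trials. The right-hand side never exceeds $\ell(\theta)$, and at $\theta = \theta_i$ it equals $\ell(\theta_i)$.

```python
import math
import random

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def log_lik(heads, n, theta):
    """l(theta) = log p(Y | theta)."""
    return sum(math.log(joint(h, n, 0, theta) + joint(h, n, 1, theta)) for h in heads)

def lower_bound(heads, n, theta, theta_i):
    """l(theta_i) + sum_Z P(Z|Y,theta_i) log[p(Y,Z|theta) / (P(Z|Y,theta_i) p(Y|theta_i))]."""
    total = log_lik(heads, n, theta_i)
    for h in heads:
        evidence = joint(h, n, 0, theta_i) + joint(h, n, 1, theta_i)   # p(y | theta_i)
        for z in (0, 1):
            post = joint(h, n, z, theta_i) / evidence                    # P(z | y, theta_i)
            total += post * math.log(joint(h, n, z, theta) / (post * evidence))
    return total

heads, n = [5, 9, 8, 4, 7], 10    # hypothetical data
theta_i = (0.5, 0.6, 0.5)
rng = random.Random(1)
for _ in range(200):
    theta = tuple(rng.uniform(0.05, 0.95) for _ in range(3))
    assert lower_bound(heads, n, theta, theta_i) <= log_lik(heads, n, theta) + 1e-9

# The bound touches l(theta) at theta = theta_i; prints 0 up to rounding.
print(lower_bound(heads, n, theta_i, theta_i) - log_lik(heads, n, theta_i))
```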

Now we define

$$B(\theta, \theta_i) \triangleq \ell(\theta_i) + \sum_Z P(Z|Y,\theta_i) \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i) \cdot p(Y|\theta_i)}$$

So we see that

$$\ell(\theta) \geq B(\theta, \theta_i)$$

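which implies that $B(\theta, \theta_i)$ is a lower bound of $\ell(\theta)$ for every $\theta$, at every iteration $i$. Moreover, the bound is tight at $\theta = \theta_i$: since $P(Z|Y,\theta_i) \cdot p(Y|\theta_i) = p(Y,Z|\theta_i) = p(Y|Z,\theta_i) \cdot p(Z|\theta_i)$, every fraction inside the log equals 1, so $B(\theta_i, \theta_i) = \ell(\theta_i)$. Therefore any $\theta$ that increases $B(\theta, \theta_i)$ above $B(\theta_i, \theta_i)$ also increases $\ell(\theta)$ above $\ell(\theta_i)$, so our next step is to maximize the lower bound $B(\theta, \theta_i)$. In the M step, we define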

$$\begin{aligned} \theta_{i+1} &= \mathop{\rm arg\,max}\limits_{\theta} B(\theta, \theta_i)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, \left(\ell(\theta_i) + \sum_Z P(Z|Y,\theta_i) \cdot \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i) \cdot p(Y|\theta_i)} \right) && (2)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, \left(\sum_Z P(Z|Y,\theta_i) \cdot \log \frac{p(Y|Z,\theta) \cdot p(Z|\theta)}{P(Z|Y,\theta_i) \cdot p(Y|\theta_i)} \right) && (3)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, \left(\sum_Z P(Z|Y,\theta_i) \cdot \left[\log \left(p(Y|Z,\theta) \cdot p(Z|\theta)\right) - \log \left(P(Z|Y,\theta_i) \cdot p(Y|\theta_i)\right)\right] \right) && (4)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, \left(\sum_Z P(Z|Y,\theta_i) \cdot \log \left(p(Y|Z,\theta) \cdot p(Z|\theta)\right) \right) && (5)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, \left(\sum_Z P(Z|Y,\theta_i) \cdot \log p(Y,Z|\theta) \right) && (6)\\ &= \mathop{\rm arg\,max}\limits_{\theta} \, Q(\theta, \theta_i) && (7) \end{aligned}$$

Remark:

  • $(2) \to (3)$: $\ell(\theta_i)$ does not depend on $\theta$.

  • $(3) \to (4)$: we used $\log \frac{A}{B} = \log A - \log B$.

  • $(4) \to (5)$: the term $\sum_Z P(Z|Y,\theta_i) \cdot \log \left(P(Z|Y,\theta_i) \cdot p(Y|\theta_i)\right)$ does not depend on $\theta$.

  • $(5) \to (6)$: $p(Y|Z,\theta) \cdot p(Z|\theta) = p(Y,Z|\theta)$.

  • $(6) \to (7)$: we define $Q(\theta, \theta_i) = \sum_Z P(Z|Y,\theta_i) \cdot \log p(Y,Z|\theta)$.

  • Since both $Y$ and $\theta_i$ are known, we know the distribution $Z \sim P(Z|Y,\theta_i)$. Therefore the only unknown parameter in $Q(\theta, \theta_i)$ is $\theta$, which means we are back in the complete case mentioned earlier in the post: this is now a standard MLE problem.
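Remarks $(2) \to (5)$ say that $B(\theta, \theta_i)$ and $Q(\theta, \theta_i)$ differ only by terms that do not depend on $\theta$, so they have the same argmax. A minimal sketch checking this numerically on a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$). Since the trials are independent, the sums over $Z$ are split per trial; the function names are ours.

```python
import math
import random

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def posterior(h, n, theta_i):
    """[P(z=0 | y, theta_i), P(z=1 | y, theta_i)] for one trial."""
    evidence = joint(h, n, 0, theta_i) + joint(h, n, 1, theta_i)
    return [joint(h, n, z, theta_i) / evidence for z in (0, 1)]

def q_func(heads, n, theta, theta_i):
    """Q(theta, theta_i) = sum_Z P(Z|Y,theta_i) log p(Y,Z|theta)."""
    return sum(w * math.log(joint(h, n, z, theta))
               for h in heads for z, w in enumerate(posterior(h, n, theta_i)))

def b_func(heads, n, theta, theta_i):
    """B(theta, theta_i) exactly as defined in the text."""
    total = 0.0
    for h in heads:
        evidence = joint(h, n, 0, theta_i) + joint(h, n, 1, theta_i)   # p(y | theta_i)
        total += math.log(evidence)                                  # builds up l(theta_i)
        for z, w in enumerate(posterior(h, n, theta_i)):
            total += w * math.log(joint(h, n, z, theta) / (w * evidence))
    return total

heads, n = [5, 9, 8, 4, 7], 10    # hypothetical data
theta_i = (0.5, 0.6, 0.5)
rng = random.Random(2)
thetas = [tuple(rng.uniform(0.05, 0.95) for _ in range(3)) for _ in range(5)]
gaps = [b_func(heads, n, t, theta_i) - q_func(heads, n, t, theta_i) for t in thetas]
print(gaps)    # the same value (up to rounding) for every theta
```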

Summary

  1. $\theta_{i+1} = \mathop{\rm arg\,max}\limits_{\theta} B(\theta, \theta_i) = \mathop{\rm arg\,max}\limits_{\theta} \, Q(\theta, \theta_i)$. Instead of maximizing the improvement $\ell(\theta) - \ell(\theta_i)$ directly, we maximize its lower bound $B(\theta, \theta_i) - \ell(\theta_i)$, which is the same as maximizing $Q(\theta, \theta_i)$.

  2. Note that $Q(\theta, \theta_i) = \sum_Z P(Z|Y,\theta_i) \cdot \log p(Y,Z|\theta)$ is just the expectation of $\log p(Y,Z|\theta)$, where $Z$ is drawn from the current conditional distribution $P(Z|Y,\theta_i)$. Therefore we have
    $$Q(\theta, \theta_i) = E_{Z \sim P(Z|Y,\theta_i)}[\log p(Y,Z|\theta)]$$
    That's why it's called the Expectation step: in the E step we compute the expectation of $\log p(Y,Z|\theta)$, where the unknown variable $Z$ follows the current conditional distribution given $Y$ and $\theta_i$. Then in the Maximization step we maximize this expectation over $\theta$, which is why it's called the M step.
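For a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$), setting the partial derivatives of $Q(\theta, \theta_i)$ to zero gives a closed-form M step: $\pi$ is the average responsibility $r_k = P(z_k = A|y_k, \theta_i)$, and $p_A$, $p_B$ are responsibility-weighted head frequencies. The minimal sketch below (names and data are ours) computes $Q$ as the posterior expectation of $\log p(Y,Z|\theta)$ and checks that no randomly drawn $\theta$ beats the closed-form update.

```python
import math
import random

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def posterior(h, n, theta_i):
    """[P(z=0 | y, theta_i), P(z=1 | y, theta_i)] for one trial."""
    evidence = joint(h, n, 0, theta_i) + joint(h, n, 1, theta_i)
    return [joint(h, n, z, theta_i) / evidence for z in (0, 1)]

def q_func(heads, n, theta, theta_i):
    """E step: Q(theta, theta_i) = E_{Z ~ P(Z|Y,theta_i)}[log p(Y, Z | theta)]."""
    return sum(w * math.log(joint(h, n, z, theta))
               for h in heads for z, w in enumerate(posterior(h, n, theta_i)))

def m_step(heads, n, theta_i):
    """M step: closed-form argmax_theta Q(theta, theta_i) for this toy model."""
    resp = [posterior(h, n, theta_i)[0] for h in heads]   # r_k = P(z_k = A | y_k, theta_i)
    w_a = sum(resp)
    w_b = len(resp) - w_a
    p_a = sum(r * h for r, h in zip(resp, heads)) / (n * w_a)
    p_b = sum((1 - r) * h for r, h in zip(resp, heads)) / (n * w_b)
    return (w_a / len(resp), p_a, p_b)

heads, n = [5, 9, 8, 4, 7], 10    # hypothetical data
theta_i = (0.5, 0.6, 0.5)
theta_next = m_step(heads, n, theta_i)
q_best = q_func(heads, n, theta_next, theta_i)
rng = random.Random(3)
for _ in range(1000):
    theta = tuple(rng.uniform(0.01, 0.99) for _ in range(3))
    assert q_func(heads, n, theta, theta_i) <= q_best + 1e-9
print("theta_{i+1} =", theta_next)
```

Here $Q$ is concave in $(\pi, p_A, p_B)$, so the stationary point found by the closed form is its global maximum.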

Coordinate Ascent/Descent: viewing EM from a different perspective

[Figure: coordinate ascent path of the EM iterations]

In the E step, we fix the parameters at $\theta_i$ and optimize the lower bound over the distribution of the latent variables $Z$; the optimal choice is the posterior $P(Z|Y,\theta_i)$, which gives us $Q(\theta, \theta_i)$.

In the M step, we fix this distribution (and hence the function $Q(\theta, \theta_i)$) and optimize over $\theta$ to get $\theta_{i+1}$.

Each time we optimize only one variable and keep the other fixed. Therefore, in the graph above, every iteration moves either vertically or horizontally, i.e., along one coordinate axis at a time.

To see how to perform the coordinate descent, check out my previous post.
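A minimal sketch of this coordinate-ascent view on a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$). It tracks $F(q, \theta) = \sum_Z q(Z) \log \frac{p(Y,Z|\theta)}{q(Z)}$, a standard way to write the lower bound for a general distribution $q$ over $Z$ (the name $F$ is ours); with $q = P(Z|Y,\theta_i)$ it equals $B(\theta, \theta_i)$. The E step maximizes $F$ over $q$ with $\theta$ fixed, the M step maximizes it over $\theta$ with $q$ fixed, so $F$ never decreases after either half-step.

```python
import math

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def f_obj(heads, n, q, theta):
    """F(q, theta) = sum_k sum_z q[k][z] * log(p(y_k, z | theta) / q[k][z])."""
    return sum(qz * math.log(joint(h, n, z, theta) / qz)
               for h, qk in zip(heads, q) for z, qz in enumerate(qk))

def e_step(heads, n, theta):
    """Coordinate step in q (theta fixed): the maximizer is the posterior P(z | y, theta)."""
    q = []
    for h in heads:
        a, b = joint(h, n, 0, theta), joint(h, n, 1, theta)
        q.append([a / (a + b), b / (a + b)])
    return q

def m_step(heads, n, q):
    """Coordinate step in theta (q fixed): closed-form argmax of F for this toy model."""
    w_a = sum(qk[0] for qk in q)
    w_b = len(q) - w_a
    p_a = sum(qk[0] * h for qk, h in zip(q, heads)) / (n * w_a)
    p_b = sum(qk[1] * h for qk, h in zip(q, heads)) / (n * w_b)
    return (w_a / len(q), p_a, p_b)

heads, n = [5, 9, 8, 4, 7], 10        # hypothetical data
theta = (0.5, 0.6, 0.5)
q = [[0.5, 0.5] for _ in heads]       # arbitrary starting q (an assumption)
trace = [f_obj(heads, n, q, theta)]
for _ in range(20):
    q = e_step(heads, n, theta)       # move along the q axis
    trace.append(f_obj(heads, n, q, theta))
    theta = m_step(heads, n, q)       # move along the theta axis
    trace.append(f_obj(heads, n, q, theta))
print(trace[:6])
```

Right after each E step, $F(q, \theta)$ equals $\ell(\theta)$, which is why the E step makes the bound tight.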

Convergence of EM algorithm

By following the algorithm, we keep updating the parameters $\theta_i$ and computing the log-likelihood $\ell(\theta_i)$. But does $\ell(\theta_i)$ actually keep increasing toward a maximum of $\ell(\theta)$ as we run more iterations? Keep in mind that in MLE our final goal is to maximize $\ell(\theta)$.

Suppose $\theta_i$ and $\theta_{i+1}$ are the parameters from two successive iterations of EM. We will now prove that $\ell(\theta_i) \leq \ell(\theta_{i+1})$, which shows that EM monotonically improves (never decreases) the log-likelihood.

$$\begin{aligned} \ell(\theta) &= \log p(Y|\theta) \\ &= \log \frac{p(Y,Z|\theta)}{p(Z|Y, \theta)} \\ &= \log p(Y,Z|\theta) - \log p(Z|Y, \theta) \\ &= \sum_Z p(Z|Y, \theta_i)\cdot \log p(Y,Z|\theta) - \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta) \\ &= Q(\theta, \theta_i) + H(\theta, \theta_i) && (8) \end{aligned}$$

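where $H(\theta, \theta_i) = - \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta)$. The fourth line follows because the third line holds for every value of $Z$: multiplying both sides by $p(Z|Y,\theta_i)$ and summing over $Z$ leaves the left-hand side unchanged, since it does not depend on $Z$ and $\sum_Z p(Z|Y,\theta_i) = 1$. Equation $(8)$ holds for every value of $\theta$, including $\theta = \theta_i$, which means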

$$\ell(\theta_i) = Q(\theta_i, \theta_i) + H(\theta_i, \theta_i)$$

Therefore, subtracting $\ell(\theta_i)$ from $\ell(\theta_{i+1})$ gives

$$\begin{aligned} \ell(\theta_{i+1}) - \ell(\theta_i) &= [Q(\theta_{i+1}, \theta_{i}) + H(\theta_{i+1}, \theta_{i})] - [Q(\theta_i, \theta_i) + H(\theta_i, \theta_i)]\\ &= [Q(\theta_{i+1}, \theta_{i}) - Q(\theta_i, \theta_{i})] + [H(\theta_{i+1}, \theta_{i}) - H(\theta_i, \theta_{i})] \end{aligned}$$

Since $\theta_{i+1} = \mathop{\rm arg\,max}\limits_{\theta} \, Q(\theta, \theta_i)$, we have $Q(\theta_{i+1}, \theta_i) \geq Q(\theta_i, \theta_i)$. For the second bracket we have

$$\begin{aligned} H(\theta_{i+1}, \theta_{i}) - H(\theta_i, \theta_{i}) &= - \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta_{i+1}) + \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta_i) \\ &= - \left( \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta_{i+1}) - \sum_Z p(Z|Y, \theta_i)\cdot \log p(Z|Y, \theta_i)\right) \\ &= - \left( \sum_Z p(Z|Y, \theta_i) \cdot \log \frac{p(Z|Y, \theta_{i+1})}{p(Z|Y, \theta_{i})} \right) && (9)\\ &\geq - \log \left( \sum_Z p(Z|Y, \theta_i) \cdot \frac{p(Z|Y, \theta_{i+1})}{p(Z|Y, \theta_{i})} \right) && (10)\\ &= - \log \sum_Z p(Z|Y, \theta_{i+1}) = - \log 1 = 0 \end{aligned}$$

From $(9)$ to $(10)$ we used Jensen's inequality (because of the negative sign in front, the inequality is reversed).
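A numerical check of equation $(8)$ and of the bound on the $H$ terms, as a minimal sketch on a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$; names and data are ours). The quantity in $(9)$ is the Kullback-Leibler divergence from $p(Z|Y,\theta_i)$ to $p(Z|Y,\theta_{i+1})$, so it is nonnegative for any second parameter value, not only for the M-step update; the sketch therefore tests random $\theta$.

```python
import math
import random

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def posterior(h, n, theta):
    """[p(z=0 | y, theta), p(z=1 | y, theta)] for one trial."""
    evidence = joint(h, n, 0, theta) + joint(h, n, 1, theta)
    return [joint(h, n, z, theta) / evidence for z in (0, 1)]

def log_lik(heads, n, theta):
    """l(theta) = log p(Y | theta)."""
    return sum(math.log(joint(h, n, 0, theta) + joint(h, n, 1, theta)) for h in heads)

def q_func(heads, n, theta, theta_i):
    """Q(theta, theta_i) = sum_Z p(Z|Y,theta_i) log p(Y,Z|theta), split per trial."""
    return sum(w * math.log(joint(h, n, z, theta))
               for h in heads for z, w in enumerate(posterior(h, n, theta_i)))

def h_func(heads, n, theta, theta_i):
    """H(theta, theta_i) = -sum_Z p(Z|Y,theta_i) log p(Z|Y,theta), split per trial."""
    return -sum(w * math.log(v)
                for h in heads
                for w, v in zip(posterior(h, n, theta_i), posterior(h, n, theta)))

heads, n = [5, 9, 8, 4, 7], 10    # hypothetical data
theta_i = (0.5, 0.6, 0.5)
rng = random.Random(4)
for _ in range(200):
    theta = tuple(rng.uniform(0.05, 0.95) for _ in range(3))
    # Equation (8): l(theta) = Q(theta, theta_i) + H(theta, theta_i).
    gap = log_lik(heads, n, theta) - q_func(heads, n, theta, theta_i) \
        - h_func(heads, n, theta, theta_i)
    assert abs(gap) < 1e-9
    # (9) >= 0: H(theta, theta_i) - H(theta_i, theta_i) is a KL divergence.
    kl = h_func(heads, n, theta, theta_i) - h_func(heads, n, theta_i, theta_i)
    assert kl >= -1e-12
print("checks passed")
```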

Since $Q(\theta_{i+1}, \theta_{i}) \geq Q(\theta_i, \theta_{i})$ and $H(\theta_{i+1}, \theta_{i}) \geq H(\theta_i, \theta_{i})$, we have

$$\ell(\theta_{i+1}) \geq \ell(\theta_i) \qquad\qquad \text{for all } i$$

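Since the sequence $\ell(\theta_i)$ is monotonically nondecreasing, it converges whenever it is bounded above (for example, for discrete $Y$, where $p(Y|\theta) \leq 1$ and hence $\ell(\theta) \leq 0$). Note that this guarantees convergence of the log-likelihood values, not that $\theta_i$ reaches the global maximum.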
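A minimal sketch illustrating both ingredients on a hypothetical two-coin mixture (an illustrative assumption: each trial picks coin A with probability $\pi$, else coin B, then flips it $n$ times; $\theta = (\pi, p_A, p_B)$; the deliberately poor starting point and the names are also ours). Because the data are discrete, $p(Y|\theta) \leq 1$, so $\ell(\theta) \leq 0$ is bounded above, and the EM iterates never decrease $\ell$.

```python
import math

def joint(h, n, z, theta):
    """p(y, z | theta) for one trial; z = 0 is coin A, z = 1 is coin B."""
    pi, pa, pb = theta
    prior, p = (pi, pa) if z == 0 else (1 - pi, pb)
    return prior * math.comb(n, h) * p**h * (1 - p) ** (n - h)

def log_lik(heads, n, theta):
    """l(theta) = log p(Y | theta)."""
    return sum(math.log(joint(h, n, 0, theta) + joint(h, n, 1, theta)) for h in heads)

def em_step(heads, n, theta):
    """One E step (responsibilities) followed by one closed-form M step."""
    resp = [joint(h, n, 0, theta) / (joint(h, n, 0, theta) + joint(h, n, 1, theta))
            for h in heads]
    w_a = sum(resp)
    w_b = len(resp) - w_a
    p_a = sum(r * h for r, h in zip(resp, heads)) / (n * w_a)
    p_b = sum((1 - r) * h for r, h in zip(resp, heads)) / (n * w_b)
    return (w_a / len(resp), p_a, p_b)

heads, n = [5, 9, 8, 4, 7], 10    # hypothetical data
theta = (0.3, 0.35, 0.9)          # deliberately poor starting point
lls = [log_lik(heads, n, theta)]
for _ in range(50):
    theta = em_step(heads, n, theta)
    lls.append(log_lik(heads, n, theta))

gains = [b - a for a, b in zip(lls, lls[1:])]
print(f"smallest gain: {min(gains):.3e}, largest l(theta_i): {max(lls):.4f}")
```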

What’s next?

In a future post, I'm going to discuss how the EM algorithm is applied in K-means and GMM. Stay tuned!


Reference:

  • https://en.wikipedia.org/wiki/Expectation–maximization_algorithm
  • Part IX: The EM algorithm from CS229 Lecture notes by Andrew Ng http://cs229.stanford.edu/notes/cs229-notes8.pdf
  • http://guillefix.me/cosmos/static/Jensen%27s%2520inequality.html
  • https://en.wikipedia.org/wiki/Jensen%27s_inequality
  • http://www.adeveloperdiary.com/data-science/machine-learning/introduction-to-coordinate-descent-using-least-squares-regression/
