EM算法与GMM
EM算法与GMM
Hongliang He 2014年4月 [email protected]
注:本文主要参考Andrew Ng的Lecture notes 8,并结合自己的理解和扩展完成。
GMM简介
GMM(Gaussian mixture model) 混合高斯模型在机器学习、计算机视觉等领域有着广泛的应用。其典型的应用有概率密度估计、背景建模、聚类等。
图1 GMM用于聚类 图2 GMM用于概率密度估计 图3 GMM用于背景建模
我们以GMM聚类为例子进行讨论。如图1所示,假设我们有m个点,其坐标数据为{,…}。假设m个数据分别属于k个类别,且不知道每个点属于哪一个类。倘若假设每个类的分布函数都是高斯分布,那我们该如何求得每个点所属的类别?以及每个类别的概率分布函数(概率密度估计)?我们先尝试最大似然估计。
上式中是当前m个数据出现的概率,我们要将它最大化;是出现的概率;是指第z个类;u和分别指第z个类的均值和方差;为其他的参数。为计算方便,对上式两边取对数,得到似然函数。
上说道,GMM的表达式为k个高斯分布的叠加,所以有
为类出现的先验概率。令j=,所以此时的似然函数可以写为
上式中x和z为自变量;为需要估计的参数。为高斯分布,我们可以写出解析式,但是的形式是未知的。所以如果我们不能直接对求偏导取极值。考虑到z是不能直接观测到的,我们称为隐藏变量(latent variable)。为了求解
我们引入EM算法(Expectation-Maximization)。我们从Jensen不等式开始讨论EM算法。
Jensen不等式
若实函数存在二阶导且有,则为凸函数(convex function)。的值域为,则对于
有以下不等式成立:
此不等式的几何解释如下
需要说明的是,若则不等式的方向取反。对上式进行推广,便可得到Jensen不等式(Jensen's Inequality)。倘若有为凸函数,且
则有
此结果可由数学归纳法得到,在这里不做详细的描述。值得注意的是,如果Jensen不等式中的,而且把看做概率密度,则有
上式成立的依据是,,为概率密度时,f(E(x))=且。在后续的EM算法推导中,会连续多次应用到Jensen不等式的性质。
EM算法
现在重新考虑之前的似然函数
直接对上式进行最大化求解会比较困难,所以我们考虑进行一定的变通。假设是某种概率密度函数,有且。现在对的表达式进行一定得处理,先乘以一个再除以一个,有
我们把看做是的函数; 为概率密度,则有
考虑到log函数为凹函数,利用Jensen不等式有
此时我们找到了的一个下界。而且这个下界的选取随着的不同而不同。即我们得到了一组下界。用下图来简单描述
图3 选择不同的得到不同的下界
我们的目的是最大化,如果我们不断的取的最优下界,再优化最优下界,等到算法收敛就得到了局部最大值。所以我们先取得的最优下界。上式在等号成立时取得最优下界。根据Jensen不等式的性质,取得等号时的条件有
c是不依赖于的常数。此时如果选取就可使得上式成立。又考虑到=1,所以我们可以取
所以取后验概率的时候是最优下界。如果此时在下界的基础上优化参数使其最大化,则可进一步抬高。如此循环往复的进行:取最优化下界;优化下界,便是EM算法的做法。接下来正式给出EM算法的步骤:
算法开始
E-step:取似然函数的最优下界,对于每个训练样本计算。
M-step:优化下界,即求取。
判断是否成立,若成立则算法结束。是设定的算法收敛时的增量。
这就是一个不断取最优下界,抬高下界的过程。用下图简单的表示一个迭代过程:
图4 EM算法的几何解释
我们可以这样解释:E-step就是取的最优下界,此处是。在M-step,我们优化下界,通过调整使得取得局部最优值。由于Jensen不等式始终成立,始终大于等于下界,所以的值从变为实现上升。那么这样的迭代是否是收敛的呢?
假设在t时刻的参数为此时的似然函数值为。接下来进行EM算法迭代,在E-step
第二步利用了Jensen不等式。在M-step
所以有
上式第二步中再次用到Jensen不等式。所以似然函数会一直单调递增,直到到达局部最优值。利用图4来解释的话我们可以这样看:在E-step我们选取了最优下界,此时=;在M-step我们优化得到;最后Jensen不等式一直都成立,所以有=,即。
GMM的训练
对于GMM,其表达式为
是每个gauss分量的权重。在E-step有
对于M-step
其中需要优化的参数为均值分别对其求偏导。
令
解出
这便是第l个高斯分量均值在M-step的更新公式。
对于协方差矩阵
考虑到
且有
所以有
等价于
为对称阵,,所以有
解出协方差矩阵的更新公式为
以上便是协方差矩阵
对于每个gauss分量的权重(或者说是先验概率),考虑到有等式约束
应用Lagrange乘子法
所以有
考虑到
联立方程可解得
这便是的更新公式。
总结启发
- EM算法适用于似然函数中具有隐藏变量的估计问题。
- 创造下界的想法非常精妙,应该有广泛的应用前景。
- Jensen不等式在不等式证明方面有着广泛应用。
GMM的简单应用
接下来简单讨论GMM在图像分割中的应用。以图像中每个像素的颜色信息作为特征进行聚类进而达到图像分割的目的。我们同时拿k-means算法作为对比。
- K-means和GMM用于图像分割由于只考虑了像素的颜色信息,没有考虑空间信息导致其对于复杂背景的效果很差。对于简单背景和前景的颜色分布都比较柔和的情况有较好的效果。
- K-means初始值的选择非常重要。不好的初始值经常会造成较差的聚类效果。
- 应用GMM时,先将3通道彩色图像转换为了灰度图。原因是原始的3个通道数据存在很强的相关性,导致协方差矩阵不可逆。
-
聚类(分割)时需要手动确定类别的数量。类的数量对于聚类效果也有很大的影响。
Matlab实现
根据以上推导,可以很容易实现EM算法估计GMM参数。现以1维数据2个高斯混合概率密度估计作为实例,详细代码如下所示。
% fitting_a_gmm.m
% EM算法简单实现
% Hongliang He 2014/03
clear
close all
clc
% generate data
len1 = 1000;
len2 = fix(len1 * 1.5);
data = [normrnd(0, 1, [1 len1]) normrnd(4, 2, [1 len2])] + 0.1*rand([1 len1+len2]);
data_len = length(data);
% use EM algroithm to estimate the parameters
ite_cnt = 100000; % maximum iterations
max_err = 1e-5; % 迭代停止条件
% soft boundary EM algorithm
z0 = 0.5; % prior probability
z1 = 1 - z0;
u = mean(data);
u0 = 1.2 * u;
u1 = 0.8 * u;
sigma0 = 1;
sigma1 = 1;
itetation = 0;
while( itetation < ite_cnt )
% init papameters
w0 = zeros(1, data_len); % Qi, postprior
w1 = zeros(1, data_len);
% E-step, update Qi/w to get a tight lower bound
for k1=1:data_len
p0 = z0 * gauss(data(k1), u0, sigma0);
p1 = z1 * gauss(data(k1), u1, sigma1);
p = p0 / (p0 + p1);
if p0 == 0 && p1 == 0
%p = w0(k1);
dist0 = (data(k1)-u0).^2;
dist1 = (data(k1)-u1).^2;
if dist0 > dist1
p = w0(k1) + 0.01;
elseif dist0 == dist1
else
p = w0(k1) - 0.01;
end
end
if p > 1
p = 1;
elseif p < 0
p = 0;
end
w0(k1) = p; % postprior
w1(k1) = 1 - w0(k1);
end
% record the pre-value
old_u0 = u0;
old_u1 = u1;
old_sigma0 = sigma0;
old_sigma1 = sigma1;
% M-step, maximize the lower bound
u0 = sum(w0 .* data) / sum(w0);
u1 = sum(w1 .* data) / sum(w1);
sigma0 = sqrt( sum(w0 .* (data - u0).^2) / sum(w0));
sigma1 = sqrt( sum(w1 .* (data - u1).^2) / sum(w1));
z0 = sum(w0) / data_len;
z1 = sum(w1) / data_len;
% is convergance
if mod(itetation, 10) == 0
sprintf('%d: u0=%f,d0=%f u1=%f,d1=%f\n',itetation, …
u0,sigma0,u1,sigma1)
end
d_u0 = abs(u0 - old_u0);
d_u1 = abs(u1 - old_u1);
d_sigma0 = abs(sigma0 - old_sigma0);
d_sigma1 = abs(sigma1 - old_sigma1);
% 迭代停止判断
if d_u0 < max_err && d_u1 < max_err && …
d_sigma0 < max_err && d_sigma1 < max_err
clc
sprintf('ite = %d, final value is', itetation)
sprintf('u0=%f,d0=%f u1=%f,d1=%f\n', u0,sigma0,u1,sigma1)
break;
end
itetation = itetation + 1;
end
% compare
my_hist(data, 20);
hold on;
mi = min(data);
mx = max(data);
t = linspace(mi, mx, 100);
y = z0*gauss(t, u0, sigma0) + z1*gauss(t, u1, sigma1);
plot(t, y, 'r', 'linewidth', 5);
% gauss.m
% 1维高斯函数
% Hongliang He 2014/03
function y = gauss(x, u, sigma)
y = exp( -0.5*(x-u).^2/sigma.^2 ) ./ (sqrt(2*pi)*sigma);
end
% my_hist.m
% 用直方图估计概率密度
% Hongliang He 2013/03
function my_hist(data, cnt)
dat_len = length(data);
if dat_len < cnt*5
error('There are not enough data!\n')
end
mi = min(data);
ma = max(data);
if ma <= mi
error('sorry, there is only one type of data\n')
end
dt = (ma - mi) / cnt;
t = linspace(mi, ma, cnt);
for k1=1:cnt-1
y(k1) = sum( data >= t(k1) & data < t(k1+1) );
end
y = y ./ dat_len / dt;
t = t + 0.5*dt;
bar(t(1:cnt-1), y);
%stem(t(1:cnt-1), y)
end
最终运行结果:
EM算法最终结果