1. EM算法

假设有数据\(Z = (Z^{(o)}, Z^{(m)})\),其中\(Z^{(o)}\)表示数据是可见的,而\(Z^{(m)}\)表示这个数据不可见(因为种种原因,比如数据丢失,或者是人为故意造出来的hidden variable),完整数据有分布\(P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta})\),写成log-likelihood:

\[l(\boldsymbol{\theta}; Z) = \sum_{i=1}^N \log P(z_i^{(o)}, z_i^{(m)} \vert \boldsymbol{\theta})\]

由于\(z_i^{(m)}\)的值我们不知道,最简单的办法就是把\(z_i^{(m)}\)积掉:

\[l(\boldsymbol{\theta}; Z^{(o)}) = \sum_{i=1}^N \log \int P(z_i^{(o)}, z_i^{(m)} \vert \boldsymbol{\theta}) dz_i^{(m)}\]

这样会带来的问题是显而易见的:对数中存在积分,当你对\(\boldsymbol{\theta}\)中任何一维求导的时候,分母上总是会出现一个积分项,这使得即便对于很简单的分布函数,你都很难将特定的参数分离出来,得到闭解几乎是不可能的。

但是,为什么要将隐变量积掉呢?EM的想法是把\(Z^{(m)}\)用它的期望\(\mathbb{E}_\tilde{P}[Z^{(m)}]\)代替,其中\(\tilde{P}(Z^{(m)}) = P(Z^{(m)} \vert Z^{(o)}, \hat{\boldsymbol{\theta}}^{(j)})\)

这就是EM算法的E-step,这里把\(\boldsymbol{\theta}\)写作\(\hat{\boldsymbol{\theta}}^{(j)}\),是因为这是上一次迭代的结果(初始给定\(\hat{\boldsymbol{\theta}}^{(0)}\)一个随机值),于是对于第\(j\)步iteration,有

E-step:计算\(Q(\boldsymbol{\theta}, \hat{\boldsymbol{\theta}}^{(j)}) = \mathbb{E}_\tilde{P}[l(\boldsymbol{\theta}; Z^{(o)}, Z^{(m)})]\)

M-step:最大化\(Q(\boldsymbol{\theta}, \hat{\boldsymbol{\theta}}^{(j)})\),得到新的\(\hat{\boldsymbol{\theta}}^{(j+1)}\),进入下一轮迭代。

注意\(P(Z^{(m)} \vert Z^{(o)}, \hat{\boldsymbol{\theta}}^{(j)})\)容易计算是因为\(P(Z^{(m)} Z^{(o)} \vert \hat{\boldsymbol{\theta}}^{(j)})\)是已知的模型,除以一个归一化项就是\(P(Z^{(m)} \vert Z^{(o)}, \hat{\boldsymbol{\theta}}^{(j)})\)了。

对于书上举例的高斯混合模型而言,引入了一个missing变量\(\Delta_k\),表示第\(k\)个高斯模型是否active,而\(P(\Delta_k = 1) = \pi_k\),书上Algorithm-8.1的\(\hat{\gamma}_i\)就是\(\mathbb{E}(\Delta_{1i}\vert y_i) = P(\Delta_{1i}=1 \vert y_i)\)(因为是二项分布)

2. EM算法原理

首先,对于observed变量的log-likelihood,

\[l(\boldsymbol{\theta}; Z^{(o)}) = \log P(Z^{(o)} \vert \boldsymbol{\theta}) = \sum_{i=1}^N \log P(z_i^{(o)} \vert \boldsymbol{\theta})\]

由于它和missing变量\(Z^{(m)}\)无关,因此令\(\tilde{P}(Z^{(m)})\)为任意概率分布,总是有

\[\begin{equation} \label{eq:E-likelihood} l(\boldsymbol{\theta}; Z^{(o)}) = \mathbb{E}_\tilde{P}[l(\boldsymbol{\theta}; Z^{(o)})] = \sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log P(Z^{(o)} \vert \boldsymbol{\theta}) \end{equation}\]

这构成了后面讨论的基础。

给定\(\tilde{P}(Z^{(m)})\)的情况下最大化\(Q\)

由于

\[P(Z^{(o)} \vert \boldsymbol{\theta}) = \frac{P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta})}{P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta})}\]

(\(\ref{eq:E-likelihood}\))式就变成了:

\[\begin{equation} \begin{split} \label{eq:likelihood-decom1} l(\boldsymbol{\theta}; Z^{(o)}) &= \mathbb{E}_\tilde{P}[l(\boldsymbol{\theta}; Z^{(o)}, Z^{(m)})] - \mathbb{E}_\tilde{P}[\log P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta})] \\ &= \sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta}) - \sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}) \\ &\triangleq Q_{\tilde{P}}(\boldsymbol{\theta}) - R_{\tilde{P}}(\boldsymbol{\theta}) \end{split} \end{equation}\]

假设已经给定了\(\tilde{P}(Z^{(m)}) = P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(old)})\),令\(\hat{\boldsymbol{\theta}}^{(j)}\)表示上一轮迭代的参数估计值(注意根据EM,它会出现在本轮的\(\tilde{P}(Z^{(m)})\)中),令\(\hat{\boldsymbol{\theta}}^{(j+1)}\)表示本轮最大化\(Q_{\tilde{P}}(\boldsymbol{\theta})\)得到的估计值:

\[\begin{equation} \begin{split} \label{eq:likelihood-diff} l(\hat{\boldsymbol{\theta}}^{(j+1)}; Z^{(o)}) - l(\hat{\boldsymbol{\theta}}^{(j)}; Z^{(o)}) &= \sum\limits_{Z^{(m)}} P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(j)}) \log\left\{ \frac{P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta}^{(j+1)})}{P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta}^{(j)})} \right\} \\ &\quad - \sum\limits_{Z^{(m)}} P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(j)}) \log \left\{ \frac{P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(j+1)})}{P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(j)})} \right\} \end{split} \end{equation}\]

上式的第一项显然是大于等于0的,因为\(\boldsymbol{\theta}^{(j+1)}\)是使其最大化的估计值;而第二项(含负号)是KL距离,当且仅当分子分母相同的时候达到最小值0,因此第二项也是大于等于0的,于是我们证明了,用EM算法得到的observed变量的log-likelihood,其每轮迭代总是上升的。

这里我们分析了EM中的M-step,有时候M-step的最大化也不好求,这里启示我们,并不一定要找绝对的最优解,只要M-step的解比原来好一点就可以了,这个就是GEM(generalized EM)的思路。

对下界的两轮Maximization

在上面的分析中,我们还漏了一个重要问题:为什么E-step要选择\(\tilde{P}(Z^{(m)}) = P(Z^{(m)} \vert Z^{(o)}, \hat{\boldsymbol{\theta}}^{(j)})\),选择其他分布行不行?下面来说这个问题。

对(\(\ref{eq:likelihood-decom1}\))式的\(R\),\(Q\)两项分别加减\(\sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log \tilde{P}(Z^{(m)})\),等式依然成立:

\[\begin{equation} \begin{split} \label{eq:likelihood-decom2} l(\boldsymbol{\theta}; Z^{(o)}) &= \sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log\left\{ \frac{P(Z^{(o)}, Z^{(m)} \vert \boldsymbol{\theta})}{\tilde{P}(Z^{(m)})} \right\} \\ &\quad - \sum\limits_{Z^{(m)}} \tilde{P}(Z^{(m)}) \log\left\{ \frac{P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta})}{\tilde{P}(Z^{(m)})} \right\} \\ &\triangleq \mathcal{L}\left[\tilde{P}(Z^{(m)}), \boldsymbol{\theta}\right] + \operatorname{KL}\left[\tilde{P}(Z^{(m)}) \| P(Z^{(m)}, \vert Z^{(o)}, \boldsymbol{\theta})\right] \end{split} \end{equation}\]

由于KL距离总是大于等于0的,因此\(\mathcal{L}\)构成了observed变量log-likelihood的一个下界,EM的目标就是优化这个下界。

两步maximization的第一步是给定\(\boldsymbol{\theta}^{(j)}\),寻找\(\tilde{P}\)最大化\(\mathcal{L}\left[\tilde{P}(Z^{(m)}), \boldsymbol{\theta}^{(j)}\right]\)。由于等式左侧\(l(\boldsymbol{\theta}^{(j)}; Z^{(o)})\)和\(\tilde{P}\)没有关系(在\(\ref{eq:E-likelihood}\)式处已经讨论过这个问题了),那么\(\mathcal{L}\)就应该在\(\operatorname{KL} = 0\),也就是\(\tilde{P}(Z^{(m)}) = P(Z^{(m)} \vert Z^{(o)}, \boldsymbol{\theta}^{(j)})\)的时候取到最大值。

第二步是给定\(\tilde{P}\)的情况下寻找新的\(\boldsymbol{\theta}^{(j+1)}\)最大化\(\mathcal{L}\),这就是M-step. 这个在上一节已经分析过了,并且我们发现,由于\(\boldsymbol{\theta}^{(j+1)}\)在未收敛的情况下不等于\(\boldsymbol{\theta}^{(j)}\),因此\(l(\boldsymbol{\theta}; Z^{(o)})\)的增长要比其下界更多一些(参见\(\ref{eq:likelihood-diff}\)式)

(窃以为书上对maximization-maximization的描述不够清楚,PRML这部分内容写得比较好)

3. EM和Gibbs-sampling的联系

EM和Gibbs-sampling联系起来的关键是将missing变量\(Z^{(m)}\)看成是需要sample的变量/参数,这一点相信敏感的同学在看到\(\mathbb{E}_\tilde{P}\)的时候已经感受到了。我们来对比一下两种方法应用在GMM的流程(为简便起见,书上假设gibbs-sampling中只有\(\Delta\)和\(\mu\)是需要求解的参数):

可见Algorithm-8.1的E-step对应Algorithm-8.4的(a),区别在于前者是求expectation,后者是直接sample;而在M-step,前者是求maximization,后者是继续sample。注意到Algorithm-8.4中的sample,\(\Delta\)依赖于\(Y\),\(\mu\)依赖于\(\Delta\)和\(Y\),是典型的gibbs-sampling。

另外需要注意的是gibbs-sampling为了能够sample \(\mu\),为它增加了一个共轭(Gaussian)分布,共轭分布的均值\(\hat{\mu}_1\),\(\hat{\mu}_2\)是(8.40)式对\(\mu_1\),\(\mu_2\)求导得到的MLE。