The Expectation-Maximization algorithm

28 September 2016

Given a probabilistic model $p(x, y \given \theta)$ with latents $x$, observed variables $y$, and parameters $\theta$, we want to find the maximum likelihood estimate (MLE) of $\theta$ or its maximum a posteriori (MAP) estimate. The Expectation-Maximization (EM) algorithm does exactly that. These notes follow Section 9.4 of Bishop's Pattern Recognition and Machine Learning.

EM for MLE

Goal: Find $\mathrm{argmax}_{\theta} p(y \given \theta)$.

Consider an auxiliary probability distribution $q(x)$ over the latents, and define the functional $\mathcal L(q, \theta)$ through the decomposition \begin{align} \log p(y \given \theta) &= \mathcal L(q, \theta) + \KL{q(x)}{p(x \given y; \theta)}. \label{eq:notes/em/p} \end{align} Solving for $\mathcal L$, and using $p(x \given y; \theta)\, p(y \given \theta) = p(x, y \given \theta)$ in the last step: \begin{align} \mathcal L(q, \theta) &= \log p(y \given \theta) - \int q(x) (\log q(x) - \log p(x \given y; \theta)) \,\mathrm dx \\ &= \int q(x) \log p(y \given \theta) \,\mathrm dx - \int q(x) (\log q(x) - \log p(x \given y; \theta)) \,\mathrm dx \\ &= \int q(x) (\log p(y \given \theta) - \log q(x) + \log p(x \given y; \theta)) \,\mathrm dx \\ &= \int q(x) (\log p(x, y \given \theta) - \log q(x)) \,\mathrm dx. \end{align}

Since $\KL{q(x)}{p(x \given y; \theta)} \geq 0$, \begin{align} \log p(y \given \theta) \geq \mathcal L(q, \theta). \end{align}
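The decomposition and the bound can be checked numerically. A minimal sketch, assuming a toy model with a single binary latent whose joint $p(x, y \given \theta)$ is specified directly as two numbers (all values and names are illustrative, not from the notes):

```python
import numpy as np

# Toy model: binary latent x, one fixed observation y.
# The joint p(x, y | θ) is given directly as two numbers.
joint = np.array([0.12, 0.28])        # p(x=0, y | θ), p(x=1, y | θ)
evidence = joint.sum()                # p(y | θ) = Σ_x p(x, y | θ)
posterior = joint / evidence          # p(x | y; θ)

q = np.array([0.6, 0.4])              # an arbitrary distribution over x

elbo = np.sum(q * (np.log(joint) - np.log(q)))      # L(q, θ)
kl = np.sum(q * (np.log(q) - np.log(posterior)))    # KL(q || p(x | y; θ))

# log p(y | θ) = L(q, θ) + KL, and L(q, θ) is a lower bound.
print(np.isclose(np.log(evidence), elbo + kl))      # True
print(np.log(evidence) >= elbo)                     # True
```

Since $q$ here differs from the posterior, the KL term is strictly positive and the bound is strict; setting `q = posterior` makes it tight, which is exactly the E-step below.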

The algorithm improves $\theta$ iteratively, alternating between an E-step that updates $q$ and an M-step that updates $\theta$:

1. Initialize $\theta^{(0)}$.
2. For $t = 1, 2, 3, \dotsc$:
• E-step:
• $\mathrm{Maximize}_q \mathcal L(q, \theta^{(t - 1)})$ by \begin{align} q^{(t)} \leftarrow p(x \given y; \theta^{(t - 1)}). \label{eq:notes/em/E} \end{align}
• M-step:
• $\mathrm{Maximize}_{\theta} \mathcal L(q^{(t)}, \theta)$. \begin{align} \mathcal L(q^{(t)}, \theta) &= \int q^{(t)}(x) (\log p(x, y \given \theta) - \log q^{(t)}(x)) \,\mathrm dx \\ &= \int p(x \given y; \theta^{(t - 1)}) (\log p(x, y \given \theta) - \log p(x \given y; \theta^{(t - 1)})) \,\mathrm dx \\ &= \int p(x \given y; \theta^{(t - 1)}) \log p(x, y \given \theta) \,\mathrm dx - \int p(x \given y; \theta^{(t - 1)}) \log p(x \given y; \theta^{(t - 1)}) \,\mathrm dx \\ &= \int p(x \given y; \theta^{(t - 1)}) \log p(x, y \given \theta) \,\mathrm dx + \text{const (w.r.t.\ $\theta$)} \\ \implies \theta^{(t)} &\leftarrow \mathrm{argmax}_{\theta} \int p(x \given y; \theta^{(t - 1)}) \log p(x, y \given \theta) \,\mathrm dx. \label{eq:notes/em/M} \end{align}
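The two steps above can be sketched for a concrete model where both have closed forms. A minimal implementation, assuming a two-component Gaussian mixture with latent component assignments $x$ and observations $y$; the synthetic data and all variable names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data from a two-component Gaussian mixture (hypothetical).
y = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(3.0, 1.0, 300)])

def log_lik(y, pi, mu, sigma):
    # log p(y | θ) with the latent component summed out.
    comp = pi * np.exp(-0.5 * ((y[:, None] - mu) / sigma) ** 2) \
        / (sigma * np.sqrt(2 * np.pi))
    return np.log(comp.sum(axis=1)).sum()

pi, mu, sigma = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
prev = -np.inf
for t in range(100):
    # E-step: q(x) <- p(x | y; θ^{(t-1)}), the posterior responsibilities.
    comp = pi * np.exp(-0.5 * ((y[:, None] - mu) / sigma) ** 2) \
        / (sigma * np.sqrt(2 * np.pi))
    r = comp / comp.sum(axis=1, keepdims=True)
    # M-step: argmax_θ of E_q[log p(x, y | θ)], in closed form for a mixture.
    n = r.sum(axis=0)
    pi = n / len(y)
    mu = (r * y[:, None]).sum(axis=0) / n
    sigma = np.sqrt((r * (y[:, None] - mu) ** 2).sum(axis=0) / n)
    cur = log_lik(y, pi, mu, sigma)
    assert cur >= prev - 1e-9  # monotonicity, as proved below
    prev = cur
```

The assertion inside the loop checks exactly the guarantee derived in the next paragraphs: the log-likelihood never decreases across iterations.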

After the E-step in \eqref{eq:notes/em/E}, $\mathcal L(q^{(t)}, \theta^{(t - 1)}) = \log p(y \given \theta^{(t - 1)})$ because the KL divergence in \eqref{eq:notes/em/p} becomes zero.

After the M-step in \eqref{eq:notes/em/M}, we obtain $\mathcal L(q^{(t)}, \theta^{(t)}) \geq \mathcal L(q^{(t)}, \theta^{(t - 1)})$. Since $q^{(t)} = p(x \given y; \theta^{(t - 1)})$, the KL divergence in \eqref{eq:notes/em/p} is no longer necessarily zero and $\log p(y \given \theta^{(t)}) \geq \mathcal L(q^{(t)}, \theta^{(t)})$. Hence \begin{align} \log p(y \given \theta^{(t)}) &\geq \mathcal L(q^{(t)}, \theta^{(t)}) \\ &\geq \mathcal L(q^{(t)}, \theta^{(t - 1)}) \\ &= \log p(y \given \theta^{(t - 1)}). \end{align}

EM for MAP

Goal: Find $\mathrm{argmax}_{\theta} p(\theta \given y)$.

Consider the identity \begin{align} \log p(\theta \given y) &= \log p(y \given \theta) + \log p(\theta) - \log p(y) \\ &= \mathcal L(q, \theta) + \KL{q(x)}{p(x \given y; \theta)} + \log p(\theta) - \log p(y). \label{eq:notes/em/p2} \end{align}

Since the KL divergence is nonnegative, \begin{align} \log p(\theta \given y) \geq \mathcal L(q, \theta) + \log p(\theta) - \log p(y). \end{align}

The algorithm proceeds as follows:

1. Initialize $\theta^{(0)}$.
2. For $t = 1, 2, 3, \dotsc$:
• E-step:
• $\mathrm{Maximize}_q \left\{\mathcal L(q, \theta^{(t - 1)}) + \log p(\theta^{(t - 1)}) - \log p(y)\right\}$ by \begin{align} q^{(t)} \leftarrow p(x \given y; \theta^{(t - 1)}). \end{align}
• M-step:
• $\mathrm{Maximize}_\theta \left\{\mathcal L(q^{(t)}, \theta) + \log p(\theta) - \log p(y)\right\}$ by \begin{align} \theta^{(t)} \leftarrow \mathrm{argmax}_{\theta} \left\{\log p(\theta) + \int p(x \given y; \theta^{(t - 1)}) \log p(x, y \given \theta) \,\mathrm dx\right\}. \label{eq:notes/em/M2} \end{align}
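For concreteness, the MAP variant can be sketched on the same mixture model with a prior only on the mixture weights. Assuming a Dirichlet prior $p(\pi) \propto \prod_k \pi_k^{\alpha_k - 1}$ (an illustrative choice, not from the notes), the extra $\log p(\theta)$ term changes only the weight update, which remains closed-form:

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic, imbalanced two-component mixture data (hypothetical).
y = np.concatenate([rng.normal(-2.0, 1.0, 50), rng.normal(3.0, 1.0, 250)])

alpha = np.array([5.0, 5.0])  # assumed Dirichlet prior on the weights

pi, mu, sigma = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])
for t in range(100):
    # E-step: identical to MLE-EM, q <- p(x | y; θ^{(t-1)}).
    comp = pi * np.exp(-0.5 * ((y[:, None] - mu) / sigma) ** 2) \
        / (sigma * np.sqrt(2 * np.pi))
    r = comp / comp.sum(axis=1, keepdims=True)
    # M-step: log p(θ) only affects π, giving π_k ∝ n_k + α_k - 1,
    # which pulls the weights toward the prior mode.
    n = r.sum(axis=0)
    pi = (n + alpha - 1) / (len(y) + alpha.sum() - len(alpha))
    mu = (r * y[:, None]).sum(axis=0) / n
    sigma = np.sqrt((r * (y[:, None] - mu) ** 2).sum(axis=0) / n)
```

With a uniform prior ($\alpha_k = 1$) the weight update reduces to the MLE one, recovering the algorithm of the previous section.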

By reasoning analogous to the MLE case, \begin{align} \log p(\theta^{(t)} \given y) &\geq \mathcal L(q^{(t)}, \theta^{(t)}) + \log p(\theta^{(t)}) - \log p(y) & \text{(because the KL divergence in \eqref{eq:notes/em/p2} is nonnegative)} \\ &\geq \mathcal L(q^{(t)}, \theta^{(t - 1)}) + \log p(\theta^{(t - 1)}) - \log p(y) & \text{(because of the M-step in \eqref{eq:notes/em/M2})} \\ &= \log p(\theta^{(t - 1)} \given y). & \text{(because the KL divergence in \eqref{eq:notes/em/p2} is zero after the E-step)} \end{align}
