21 February 2018
Consider a generative network \(p_\theta(z, x)\) and an inference network \(q_\phi(z \given x)\) on latents \(z\) and observations \(x\). Given a set of observations \((x^{(n)})_{n = 1}^N\) sampled i.i.d. from the true data distribution \(p(x)\), we want to learn \(\theta\) to maximize \(\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})\) and \(\phi\) so that \(q_\phi(z \given x)\) is close to \(p_\theta(z \given x)\) for the learned \(\theta\).
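As a running example (an illustrative assumption, not part of the setup above), the sketches in this post use a toy model with scalar \(z\) and \(x\): \(p_\theta(z, x) = \mathcal N(z; 0, 1) \, \mathcal N(x; \theta z, 1)\), and \(q_\phi(z \given x)\) Gaussian with mean \(\phi_1 x + \phi_2\) and standard deviation \(e^{\phi_3}\), written in JAX. All function names below are my own.

```python
# Toy model (an illustrative assumption, not part of the text):
# p_theta(z, x) = N(z; 0, 1) N(x; theta*z, 1) and
# q_phi(z | x) = N(z; phi[0]*x + phi[1], exp(phi[2])), with scalar z and x.
import jax
import jax.numpy as jnp

def log_normal(v, mean, std):
    # log density of N(v; mean, std^2)
    return -0.5 * jnp.log(2 * jnp.pi) - jnp.log(std) - 0.5 * ((v - mean) / std) ** 2

def log_joint(theta, z, x):
    # log p_theta(z, x)
    return log_normal(z, 0.0, 1.0) + log_normal(x, theta * z, 1.0)

def q_params(phi, x):
    # mean and standard deviation of q_phi(z | x)
    return phi[0] * x + phi[1], jnp.exp(phi[2])

def log_q(phi, z, x):
    # log q_phi(z | x)
    mean, std = q_params(phi, x)
    return log_normal(z, mean, std)
```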
Define the importance-sampling-based evidence lower bound (ELBO) as: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left(\prod_{k = 1}^K q_\phi(z^k \given x)\right) \log \left(\frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}, \end{align} where \(K\) is the number of particles.
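With the toy-model helpers above, a single draw of \(z^{1:K}\) from \(q_\phi\) gives a Monte Carlo estimate of \(\mathrm{ELBO}_{\text{IS}}\) as a log-sum-exp of the log weights; a minimal sketch:

```python
from jax.scipy.special import logsumexp

def elbo_is(theta, phi, x, key, K):
    # Single-sample estimate of ELBO_IS(theta, phi, x) with K particles.
    mean, std = q_params(phi, x)
    z = mean + std * jax.random.normal(key, (K,))       # z^{1:K} ~ q_phi(. | x)
    log_w = log_joint(theta, z, x) - log_q(phi, z, x)   # log w_k
    return logsumexp(log_w) - jnp.log(K)                # log((1/K) sum_k w_k)
```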
Estimate the gradient \begin{align} \nabla_\theta \left(\frac{1}{N} \sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)}) \right) \end{align} using samples from \(q_\phi(z \given x^{(n)})\). Call this gradient estimator \(\hat g_\theta^{\text{wake}}\).
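A sketch of \(\hat g_\theta^{\text{wake}}\) under the toy model above: since the samples \(z^{1:K} \sim q_\phi\) do not depend on \(\theta\), it suffices to differentiate the averaged Monte Carlo estimate with respect to \(\theta\).

```python
def g_theta_wake(theta, phi, xs, key, K):
    # hat g_theta^wake: gradient w.r.t. theta of the average single-sample
    # ELBO_IS estimate over the dataset xs (K particles per data point).
    keys = jax.random.split(key, xs.shape[0])
    def avg_elbo(th):
        return jnp.mean(jax.vmap(lambda x, k: elbo_is(th, phi, x, k, K))(xs, keys))
    return jax.grad(avg_elbo)(theta)
```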
Estimate the gradient
\begin{align}
\nabla_\phi \int p_\theta(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx
&= \nabla_\phi \int p_\theta(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&= \int p_\theta(z, x) \left( -\nabla_\phi \log q_\phi(z \given x) \right) \,\mathrm dz \,\mathrm dx,
\end{align}
using samples from \(p_\theta(z, x)\); the second equality holds because \(p_\theta(x) \, p_\theta(z \given x) = p_\theta(z, x)\) and \(\log p_\theta(z \given x)\) does not depend on \(\phi\).
Call this gradient estimator \(\hat g_\phi^{\text{sleep}}\).
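A sketch of \(\hat g_\phi^{\text{sleep}}\) under the same toy model: draw "dreamed" pairs \((z, x)\) from \(p_\theta(z, x)\) by ancestral sampling and average the negative score of \(q_\phi\) (the sample count `S` is an illustrative choice).

```python
def g_phi_sleep(theta, phi, key, S):
    # hat g_phi^sleep: average of -grad_phi log q_phi(z | x) over S "dreamed"
    # pairs (z, x) drawn from p_theta(z, x) by ancestral sampling.
    kz, kx = jax.random.split(key)
    z = jax.random.normal(kz, (S,))              # z ~ p_theta(z) = N(0, 1)
    x = theta * z + jax.random.normal(kx, (S,))  # x | z ~ N(theta*z, 1)
    return jax.grad(lambda ph: -jnp.mean(log_q(ph, z, x)))(phi)
```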
Estimate the gradient
\begin{align}
\nabla_\phi \int p(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx
&= \nabla_\phi \int p(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&= \int p(x) \int p_\theta(z \given x) ( -\nabla_\phi \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&\approx \int p(x) \left( \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)) \right) \,\mathrm dx,
\end{align}
where \(z^k \sim q_\phi(\cdot \given x)\) and \(w_k = \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)}\).
In practice, this quantity is estimated using a single observation \(x \sim p(x)\), i.e., one of the \(x^{(n)}\):
\begin{align}
&\approx \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)).
\end{align}
Call this gradient estimator \(\hat g_\phi^{\text{wake}}\).
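A sketch of \(\hat g_\phi^{\text{wake}}\) for a single observation \(x\), again assuming the toy-model helpers; note that \(z^{1:K}\) and the normalized weights enter as constants, and only \(\log q_\phi(z^k \given x)\) is differentiated.

```python
def g_phi_wake(theta, phi, x, key, K):
    # hat g_phi^wake for one observation x: self-normalized importance
    # weights times the negative score of q_phi.
    mean, std = q_params(phi, x)
    z = mean + std * jax.random.normal(key, (K,))       # z^{1:K} ~ q_phi(. | x)
    log_w = log_joint(theta, z, x) - log_q(phi, z, x)   # log w_k
    w_norm = jax.nn.softmax(log_w)                      # w_k / sum_l w_l
    # z and w_norm are constants w.r.t. ph below; only log q_phi(z^k | x)
    # is differentiated, matching the estimator above.
    return jax.grad(lambda ph: -jnp.sum(w_norm * log_q(ph, z, x)))(phi)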
It is usually advantageous to use the wake update of \(\phi\) rather than the sleep update, because the wake update targets the expected KL divergence under the true data distribution \(p(x)\), whereas the sleep update targets it under the current model distribution \(p_\theta(x)\); the inference network is therefore trained on the kind of observations it will actually be conditioned on.
Repeat until convergence: update \(\theta\) using \(\hat g_\theta^{\text{wake}}\), and update \(\phi\) using either \(\hat g_\phi^{\text{sleep}}\) or \(\hat g_\phi^{\text{wake}}\); a minimal sketch of this loop follows.
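Putting the pieces together, a sketch of the loop, assuming the estimators above, the wake update of \(\phi\), and plain SGD with arbitrary hyperparameters (all illustrative choices):

```python
def train(xs, key, K=10, steps=1000, lr=1e-2):
    # Alternate the wake update of theta with the wake update of phi.
    theta, phi = 0.5, jnp.zeros(3)   # arbitrary initialization
    for _ in range(steps):
        key, k1, k2, k3 = jax.random.split(key, 4)
        theta = theta + lr * g_theta_wake(theta, phi, xs, k1, K)  # ascend the bound
        n = jax.random.randint(k2, (), 0, xs.shape[0])            # one x ~ p(x)
        phi = phi - lr * g_phi_wake(theta, phi, xs[n], k3, K)     # descend the KL
    return theta, phi
```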