Tuan Anh Le

Reweighted Wake-Sleep

21 February 2018

Consider a generative network \(p_\theta(z, x)\) and an inference network \(q_\phi(z \given x)\) on latents \(z\) and observations \(x\). Given a set of observations \((x^{(n)})_{n = 1}^N\) sampled iid from the true data distribution \(p(x)\), we want to learn \(\theta\) to maximize \(\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})\) and \(\phi\) so that \(q_\phi(z \given x)\) is close to \(p_\theta(z \given x)\) for the learned \(\theta\).

Evidence Lower Bound

Define the importance sampling based evidence lower bound (ELBO) as: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left(\prod_{k = 1}^K q_\phi(z^k \given x)\right) \log \left(\frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}, \end{align} where \(K\) is the number of particles.
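As a rough illustration of how this integral can be estimated, the sketch below draws \(K\) particles from \(q\) and returns the log of the average importance weight, which is a one-sample Monte Carlo estimate of \(\mathrm{ELBO}_{\text{IS}}\). The toy model (\(p(z) = \mathcal N(0, 1)\), \(p(x \given z) = \mathcal N(z, 1)\)), the Gaussian form of \(q\), and all function names are assumptions made for this example only.

```python
import math
import random


def log_normal(value, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at value."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def log_mean_exp(values):
    """Numerically stable log((1/K) * sum_k exp(values[k]))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def elbo_is_estimate(x, q_mean, q_var, num_particles, rng):
    """One Monte Carlo estimate of ELBO_IS for the assumed toy model
    p(z) = N(0, 1), p(x | z) = N(z, 1), with q(z | x) = N(q_mean, q_var)."""
    log_weights = []
    for _ in range(num_particles):
        z = rng.gauss(q_mean, math.sqrt(q_var))  # z^k ~ q(. | x)
        log_joint = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)
        log_q = log_normal(z, q_mean, q_var)
        log_weights.append(log_joint - log_q)  # log w_k
    return log_mean_exp(log_weights)


rng = random.Random(0)
x = 1.3
print(elbo_is_estimate(x, q_mean=0.0, q_var=1.0, num_particles=10, rng=rng))
```

In this toy model the exact posterior is \(\mathcal N(x / 2, 1 / 2)\). If \(q\) is set to it, every weight equals \(p(x)\), so the estimate coincides with \(\log p(x)\) for any \(K\).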

Gradient Estimator for Wake Update of \(\theta\)

Estimate the gradient \begin{align} \nabla_\theta \left(\frac{1}{N} \sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)}) \right) \end{align} using samples from \(q_\phi(z \given x^{(n)})\). Call this gradient estimator \(\hat g_\theta^{\text{wake}}\).
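Because the particles are drawn from \(q_\phi\), which does not depend on \(\theta\), the gradient of \(\log \frac{1}{K} \sum_k w_k\) with respect to \(\theta\) for fixed particles is \(\sum_k \frac{w_k}{\sum_\ell w_\ell} \nabla_\theta \log p_\theta(z^k, x)\), with \(w_k = p_\theta(z^k, x) / q_\phi(z^k \given x)\). The sketch below computes this for a single observation under an assumed toy model, \(p_\theta(z) = \mathcal N(\theta, 1)\) and \(p(x \given z) = \mathcal N(z, 1)\), for which \(\nabla_\theta \log p_\theta(z, x) = z - \theta\). The model and the function names are hypothetical and chosen only to keep the example short.

```python
import math
import random


def log_normal(value, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at value."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def normalized_weights(log_weights):
    """Self-normalized weights w_k / sum_l w_l, computed stably."""
    m = max(log_weights)
    unnormalized = [math.exp(lw - m) for lw in log_weights]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def wake_theta_gradient(theta, x, zs, q_mean, q_var):
    """Gradient w.r.t. theta of log((1/K) sum_k w_k) for fixed particles zs.

    Assumed toy model: p_theta(z) = N(theta, 1), p(x | z) = N(z, 1),
    with particles zs drawn from q(z | x) = N(q_mean, q_var).
    """
    log_weights = [
        log_normal(z, theta, 1.0) + log_normal(x, z, 1.0)
        - log_normal(z, q_mean, q_var)
        for z in zs
    ]
    w_bar = normalized_weights(log_weights)
    # grad_theta log p_theta(z, x) = z - theta in this toy model.
    return sum(w * (z - theta) for w, z in zip(w_bar, zs))


rng = random.Random(1)
q_mean, q_var = 0.5, 0.8
zs = [rng.gauss(q_mean, math.sqrt(q_var)) for _ in range(6)]
print(wake_theta_gradient(theta=0.3, x=1.2, zs=zs, q_mean=q_mean, q_var=q_var))
```

Averaging this quantity over the observations \(x^{(n)}\), or over a minibatch of them, gives an estimate of the full gradient.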

Gradient Estimator for Sleep Update of \(\phi\)

Estimate the gradient \begin{align} \nabla_\phi \int p_\theta(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p_\theta(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&= \int p_\theta(z, x) \left( -\nabla_\phi \log q_\phi(z \given x) \right) \,\mathrm dz \,\mathrm dx, \end{align} using samples from \(p_\theta(z, x)\). Call this gradient estimator \(\hat g_\phi^{\text{sleep}}\).
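The last line of this derivation is an expectation under \(p_\theta(z, x)\), so it can be estimated by sampling \((z, x)\) from the generative network and averaging \(-\nabla_\phi \log q_\phi(z \given x)\). The sketch below does this for an assumed toy setup: \(p_\theta(z) = \mathcal N(\theta, 1)\), \(p(x \given z) = \mathcal N(z, 1)\), and \(q_\phi(z \given x) = \mathcal N(x / 2 + \phi, 1 / 2)\). None of these choices come from the text above; they are picked so that the gradient has a closed form.

```python
import random


def sleep_phi_gradient(theta, phi, num_samples, rng):
    """Monte Carlo estimate of the sleep-phase gradient w.r.t. phi.

    Assumed toy setup: p_theta(z) = N(theta, 1), p(x | z) = N(z, 1),
    q_phi(z | x) = N(x / 2 + phi, 1 / 2).
    """
    q_var = 0.5
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(theta, 1.0)  # z ~ p_theta(z)
        x = rng.gauss(z, 1.0)      # x ~ p(x | z)
        # grad_phi log q_phi(z | x) for a Gaussian with mean x / 2 + phi.
        grad_log_q = (z - (0.5 * x + phi)) / q_var
        total += -grad_log_q
    return total / num_samples


rng = random.Random(0)
print(sleep_phi_gradient(theta=1.0, phi=0.0, num_samples=20000, rng=rng))
```

In this setup the expected gradient is \(2\phi - \theta\), so it vanishes at \(\phi = \theta / 2\), which is exactly where \(q_\phi\) matches the posterior \(\mathcal N((\theta + x) / 2, 1 / 2)\).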

Gradient Estimator for Wake Update of \(\phi\)

Estimate the gradient \begin{align} \nabla_\phi \int p(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&= \int p(x) \int p_\theta(z \given x) ( -\nabla_\phi \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\
&\approx \int p(x) \left( \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)) \right) \,\mathrm dx, \end{align} where the particles \(z^k \sim q_\phi(\cdot \given x)\) are drawn independently and \(w_k = \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)}\) are the unnormalized importance weights. The outer integral is in turn estimated using a single sample \(x \sim p(x)\), in practice an observation \(x^{(n)}\): \begin{align} &\approx \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)). \end{align} Call this gradient estimator \(\hat g_\phi^{\text{wake}}\).
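The self-normalized estimator above can be sketched in a few lines. The example assumes a toy setup that is not part of the text: \(p_\theta(z) = \mathcal N(\theta, 1)\), \(p(x \given z) = \mathcal N(z, 1)\), and \(q_\phi(z \given x) = \mathcal N(x / 2 + \phi, 1 / 2)\). The function name and its return values are likewise hypothetical.

```python
import math
import random


def log_normal(value, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at value."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def wake_phi_gradient(theta, phi, x, num_particles, rng):
    """Self-normalized importance sampling estimate of the wake-phase
    gradient w.r.t. phi for one observation x (assumed toy setup)."""
    q_mean, q_var = 0.5 * x + phi, 0.5
    zs = [rng.gauss(q_mean, math.sqrt(q_var)) for _ in range(num_particles)]
    log_w = [
        log_normal(z, theta, 1.0) + log_normal(x, z, 1.0)
        - log_normal(z, q_mean, q_var)
        for z in zs
    ]
    # Normalize the weights in a numerically stable way.
    m = max(log_w)
    unnormalized = [math.exp(lw - m) for lw in log_w]
    total = sum(unnormalized)
    w_bar = [u / total for u in unnormalized]
    # -grad_phi log q_phi(z | x) = -(z - q_mean) / q_var for this Gaussian q.
    grad = sum(w * (-(z - q_mean) / q_var) for w, z in zip(w_bar, zs))
    return grad, w_bar, zs


rng = random.Random(0)
grad, w_bar, zs = wake_phi_gradient(theta=1.0, phi=0.0, x=0.7,
                                    num_particles=8, rng=rng)
print(grad, w_bar)
```

When \(\phi = \theta / 2\), \(q_\phi\) equals the posterior in this toy setup, so all normalized weights are \(1 / K\). Away from that point the weights shift mass toward particles that the posterior favors.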

It is usually advantageous to use the wake update of \(\phi\) rather than the sleep update of \(\phi\), because the wake update targets the expected KL divergence under the true data distribution \(p(x)\), whereas the sleep update targets it under the current model distribution \(p_\theta(x)\).

Full Algorithm

Repeat until convergence:
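The individual steps of the loop are not listed here. As one possible arrangement, and only as a sketch, each iteration could pick an observation, draw \(K\) particles from \(q_\phi\), and reuse the same normalized weights for a wake update of \(\theta\) and a wake update of \(\phi\). The toy model (\(p_\theta(z) = \mathcal N(\theta, 1)\), \(p(x \given z) = \mathcal N(z, 1)\), \(q_\phi(z \given x) = \mathcal N(x / 2 + \phi, 1 / 2)\)), the plain SGD steps, and every name and hyperparameter below are assumptions of this example.

```python
import math
import random


def log_normal(value, mean, var):
    """Log density of a univariate Gaussian N(mean, var) at value."""
    return -0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def normalized_weights(log_weights):
    """Self-normalized weights w_k / sum_l w_l, computed stably."""
    m = max(log_weights)
    unnormalized = [math.exp(lw - m) for lw in log_weights]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def rws_toy(data, num_steps=5000, num_particles=10, lr=0.01, seed=0):
    """Hypothetical wake-wake training loop on the assumed toy model."""
    rng = random.Random(seed)
    theta, phi = 0.0, 0.0
    for _ in range(num_steps):
        x = rng.choice(data)  # one observation per step
        q_mean, q_var = 0.5 * x + phi, 0.5
        zs = [rng.gauss(q_mean, math.sqrt(q_var))
              for _ in range(num_particles)]
        log_w = [
            log_normal(z, theta, 1.0) + log_normal(x, z, 1.0)
            - log_normal(z, q_mean, q_var)
            for z in zs
        ]
        w_bar = normalized_weights(log_w)
        # Wake update of theta: ascend the ELBO_IS gradient estimate.
        g_theta = sum(w * (z - theta) for w, z in zip(w_bar, zs))
        # Wake update of phi: descend the estimated KL gradient.
        g_phi = sum(w * (-(z - q_mean) / q_var) for w, z in zip(w_bar, zs))
        theta += lr * g_theta
        phi -= lr * g_phi
    return theta, phi


data_rng = random.Random(123)
data = [data_rng.gauss(2.0, 1.0) + data_rng.gauss(0.0, 1.0)
        for _ in range(200)]
theta, phi = rws_toy(data)
print(theta, phi, sum(data) / len(data))
```

For this toy model the maximizer of \(\frac{1}{N} \sum_n \log p_\theta(x^{(n)})\) is the sample mean of the data, and the exact posterior corresponds to \(\phi = \theta / 2\), which gives a simple check on where the loop ends up. Swapping the \(\phi\) step for a sleep update would only change how `g_phi` is computed.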
