Tuan Anh Le

Reweighted Wake-Sleep

21 February 2018

Consider a generative network $p_\theta(z, x)$ and an inference network $q_\phi(z \given x)$ on latents $z$ and observations $x$. Given a set of observations $(x^{(n)})_{n = 1}^N$ sampled iid from the true generative model $p(x)$, we want to learn $\theta$ to maximize $\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})$ and $\phi$ so that $q_\phi(z \given x)$ is close to $p_\theta(z \given x)$ for the learned $\theta$.

Evidence Lower Bound

Define the importance sampling based evidence lower bound (ELBO) as: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left(\prod_{k = 1}^K q_\phi(z^k \given x)\right) \log \left(\frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}, \end{align} where $K$ is the number of particles.
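To make this concrete, below is a minimal PyTorch sketch of a single Monte Carlo estimate of $\mathrm{ELBO}_{\text{IS}}(\theta, \phi, x)$. The toy model (a one-dimensional generative network $p_\theta(z, x) = \mathcal{N}(z; \mu, 1)\,\mathcal{N}(x; z, 1)$ with $\theta = \{\mu\}$, and inference network $q_\phi(z \given x) = \mathcal{N}(z; ax + b, e^c)$ with $\phi = \{a, b, c\}$) and all names in the code are illustrative assumptions, not part of the note.

```python
import torch
from torch.distributions import Normal

# Illustrative toy model (an assumption, not from the note):
#   p_theta(z, x) = N(z; mu, 1) * N(x; z, 1)   with theta = {mu}
#   q_phi(z | x)  = N(z; a*x + b, exp(c))      with phi   = {a, b, c}
mu = torch.zeros(1, requires_grad=True)   # theta
a = torch.ones(1, requires_grad=True)     # phi
b = torch.zeros(1, requires_grad=True)    # phi
c = torch.zeros(1, requires_grad=True)    # phi (log standard deviation)

def log_p(z, x):
    """log p_theta(z, x) of the toy model."""
    return Normal(mu, 1.0).log_prob(z) + Normal(z, 1.0).log_prob(x)

def q(x):
    """q_phi(. | x) as a torch distribution."""
    return Normal(a * x + b, c.exp())

def elbo_is_estimate(x, K=10):
    """Single Monte Carlo estimate of ELBO_IS(theta, phi, x):
    draw z^{1:K} ~ q_phi(. | x) and return log(1/K * sum_k w_k)."""
    q_x = q(x)
    z = q_x.rsample((K,))                       # z^{1:K}, shape (K, 1)
    log_w = log_p(z, x) - q_x.log_prob(z)       # log importance weights log w_k
    return torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(K)))

print(elbo_is_estimate(torch.tensor([1.5])))    # noisy lower bound on log p_theta(x)
```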

Gradient Estimator for Wake Update of $\theta$

Estimate the gradient \begin{align} \nabla_\theta \left(\frac{1}{N} \sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)}) \right) \end{align} using samples from $q_\phi(z \given x^{(n)})$. Call this gradient estimator $g_{\text{wake-}\theta}$.
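As a sketch (reusing the illustrative toy model, `log_p`, `q`, and parameters from the block above), $g_{\text{wake-}\theta}$ for a single observation can be computed by differentiating $\log \frac{1}{K} \sum_{k} w_k$ with respect to $\theta$ while holding the samples from $q_\phi$ fixed:

```python
def g_wake_theta(x, K=10):
    """Single-observation estimate of the gradient for the wake update of theta."""
    q_x = q(x)
    z = q_x.sample((K,))                        # z^{1:K} ~ q_phi(. | x), held fixed
    log_w = log_p(z, x) - q_x.log_prob(z)       # log w_k
    objective = torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(K)))
    # q_phi does not depend on theta, so differentiating log(1/K sum_k w_k)
    # with respect to theta = {mu} needs no reparameterisation.
    return torch.autograd.grad(objective, (mu,))[0]
```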

Gradient Estimator for Sleep Update of $\phi$

Estimate the gradient \begin{align} \nabla_\phi \int p_\theta(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p_\theta(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &= \int p_\theta(z, x) \left( -\nabla_\phi \log q_\phi(z \given x) \right) \,\mathrm dz \,\mathrm dx, \end{align} using samples from $p_\theta(z, x)$. Call this gradient estimator $g_{\text{sleep-}\phi}$.
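Under the same illustrative toy model as above, a single-sample sketch of $g_{\text{sleep-}\phi}$ "dreams" a pair $(z, x)$ from the current generative model and differentiates $-\log q_\phi(z \given x)$ with respect to $\phi$:

```python
def g_sleep_phi():
    """Single-sample estimate of the gradient for the sleep update of phi."""
    # Dream a latent and an observation from the current generative model p_theta(z, x).
    z = Normal(mu, 1.0).sample()
    x = Normal(z, 1.0).sample()
    loss = -q(x).log_prob(z)                    # -log q_phi(z | x)
    return torch.autograd.grad(loss, (a, b, c))
```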

Gradient Estimator for Wake Update of $\phi$

Estimate the gradient \begin{align} \nabla_\phi \int p(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &= \int p(x) \int p_\theta(z \given x) ( -\nabla_\phi \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &\approx \int p(x) \left( \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)) \right) \,\mathrm dx, \end{align} where $w_k = p_\theta(z^k, x) / q_\phi(z^k \given x)$ and $z^k \sim q_\phi(z \given x)$. This quantity is estimated using one sample $x \sim p(x)$: \begin{align} &\approx \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)). \end{align} Call this gradient estimator $g_{\text{wake-}\phi}$.
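Again under the same illustrative toy model, a single-observation sketch of $g_{\text{wake-}\phi}$ normalises the importance weights, treats them as constants, and weights $-\nabla_\phi \log q_\phi(z^k \given x)$ accordingly:

```python
def g_wake_phi(x, K=10):
    """Single-observation estimate of the gradient for the wake update of phi."""
    q_x = q(x)
    z = q_x.sample((K,))                                # z^{1:K} ~ q_phi(. | x)
    log_w = (log_p(z, x) - q_x.log_prob(z)).detach()    # treat the weights as constants
    w_bar = torch.softmax(log_w, dim=0)                 # normalised weights w_k / sum_l w_l
    loss = (w_bar * (-q_x.log_prob(z))).sum()           # sum_k w_bar_k * (-log q_phi(z^k | x))
    return torch.autograd.grad(loss, (a, b, c))
```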

It is usually advantageous to use the wake update of $\phi$ (rather than the sleep update of $\phi$) because the target is minimizing the expected KL divergence under the true data distribution $p(x)$ rather than the model's current distribution $p_\theta(x)$.

Full Algorithm

Repeat until convergence:

1. Pick an observation $x^{(n)}$ (or a minibatch) from the data.
2. Wake update of $\theta$: take a gradient ascent step along $g_{\text{wake-}\theta}$.
3. Update $\phi$: take a gradient descent step along $g_{\text{sleep-}\phi}$ (sleep update) and/or $g_{\text{wake-}\phi}$ (wake update).

A minimal training-loop sketch follows below.
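Putting the pieces together, here is a sketch of such a loop using the wake updates of both $\theta$ and $\phi$; the dataset, learning rate, and number of steps are made up for illustration, and the estimator functions are the ones sketched above.

```python
# Made-up dataset and hyperparameters, for illustration only.
data = Normal(2.0, 1.0).sample((100, 1))                # stand-in for observations x^(1:N)
lr = 1e-2

for step in range(1000):
    x = data[torch.randint(len(data), (1,))].squeeze(0) # pick one observation x^(n)

    grad_mu = g_wake_theta(x)                           # wake update of theta
    grads_phi = g_wake_phi(x)                           # wake update of phi

    with torch.no_grad():
        mu += lr * grad_mu                              # ascend the ELBO_IS estimate
        for param, grad in zip((a, b, c), grads_phi):
            param -= lr * grad                          # descend the KL estimate
```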
