# Reweighted Wake-Sleep

21 February 2018

Consider a generative network $p_\theta(z, x)$ and an inference network $q_\phi(z \given x)$ over latents $z$ and observations $x$. Given a set of observations $(x^{(n)})_{n = 1}^N$ sampled i.i.d. from the true data distribution $p(x)$, we want to learn $\theta$ to maximize $\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})$ and $\phi$ so that $q_\phi(z \given x)$ is close to $p_\theta(z \given x)$ for the learned $\theta$.

## Evidence Lower Bound

Define the importance sampling based evidence lower bound (ELBO) as: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left(\prod_{k = 1}^K q_\phi(z^k \given x)\right) \log \left(\frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}, \end{align} where $K$ is the number of particles. For $K = 1$ this reduces to the standard ELBO; for larger $K$ it is a tighter lower bound on $\log p_\theta(x)$ (the importance weighted bound).
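A single-sample Monte Carlo estimate of $\mathrm{ELBO}_{\text{IS}}$ is a log-mean-exp over the particle weights. Below is a minimal numpy sketch on a toy conjugate model (assumed here purely for illustration: $z \sim \mathcal N(0, 1)$, $x \given z \sim \mathcal N(z, 1)$) where $q$ is taken to be the exact posterior $\mathcal N(x/2, 1/2)$, so every weight equals $p(x)$ and the estimate is exactly $\log p(x)$ for any $K$:

```python
import numpy as np

def log_joint(z, x):
    # log p(z, x) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)
    return -0.5 * (z**2 + (x - z)**2) - np.log(2 * np.pi)

def log_q(z, x):
    # exact posterior q(z | x) = N(x / 2, 1 / 2)
    return -(z - x / 2)**2 - 0.5 * np.log(np.pi)

def elbo_is(x, K, rng):
    # one Monte Carlo sample of ELBO_IS with K particles
    z = rng.normal(x / 2, np.sqrt(0.5), size=K)   # z^k ~ q(. | x)
    logw = log_joint(z, x) - log_q(z, x)          # log importance weights
    m = logw.max()                                # log-sum-exp for numerical stability
    return m + np.log(np.mean(np.exp(logw - m)))
```

With a mismatched $q$, the same estimator is a stochastic lower bound on $\log p_\theta(x)$ that tightens as $K$ grows.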

## Gradient Estimator for Wake Update of $\theta$

Estimate the gradient \begin{align} \nabla_\theta \left(\frac{1}{N} \sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)}) \right) \end{align} using samples from $q_\phi(z \given x^{(n)})$. Call this gradient estimator $\hat g_\theta^{\text{wake}}$.
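Differentiating inside the expectation gives $\nabla_\theta \mathrm{ELBO}_{\text{IS}} = \mathbb E\left[\sum_{k} \bar w_k \nabla_\theta \log p_\theta(z^k \given x)\, \right]$ with normalized weights $\bar w_k = w_k / \sum_\ell w_\ell$, so one draw of $z^{1:K} \sim q_\phi(\cdot \given x)$ yields the estimator. A sketch, assuming a toy model $z \sim \mathcal N(\theta, 1)$, $x \given z \sim \mathcal N(z, 1)$ whose score $\nabla_\theta \log p_\theta(z, x) = z - \theta$ is available in closed form:

```python
import numpy as np

def grad_theta_wake(x, theta, q_mean, q_std, K, rng):
    # z^k ~ q_phi(. | x), here a Gaussian with parameters (q_mean, q_std)
    z = rng.normal(q_mean, q_std, size=K)
    # log w_k = log p_theta(z^k, x) - log q_phi(z^k | x);
    # additive constants cancel after normalization and are dropped
    logw = (-0.5 * (z - theta)**2 - 0.5 * (x - z)**2) \
           - (-0.5 * ((z - q_mean) / q_std)**2 - np.log(q_std))
    wbar = np.exp(logw - logw.max())
    wbar /= wbar.sum()                    # normalized weights wbar_k
    return np.sum(wbar * (z - theta))     # sum_k wbar_k * grad_theta log p_theta(z^k, x)
```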

## Gradient Estimator for Sleep Update of $\phi$

Estimate the gradient \begin{align} \nabla_\phi \int p_\theta(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p_\theta(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &= \int p_\theta(z, x) \left( -\nabla_\phi \log q_\phi(z \given x) \right) \,\mathrm dz \,\mathrm dx, \end{align} using samples from $p_\theta(z, x)$. Call this gradient estimator $\hat g_\phi^{\text{sleep}}$.
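In the sleep phase the expectation is over the model's own samples ("dreams"), so the estimator is just the negative score of $q_\phi$ averaged over $(z, x) \sim p_\theta(z, x)$. A sketch, assuming the toy model $z \sim \mathcal N(\theta, 1)$, $x \given z \sim \mathcal N(z, 1)$ and a Gaussian inference network $q_\phi(z \given x) = \mathcal N(a + b x, s^2)$ with $\phi = (a, b)$:

```python
import numpy as np

def grad_phi_sleep(theta, a, b, s, M, rng):
    # dream M pairs (z, x) ~ p_theta(z, x)
    z = rng.normal(theta, 1.0, size=M)   # z ~ p_theta(z) = N(theta, 1)
    x = rng.normal(z, 1.0)               # x ~ p_theta(x | z) = N(z, 1)
    # grad_phi log q_phi(z | x) for q = N(a + b x, s^2)
    r = (z - a - b * x) / s**2
    return np.array([-np.mean(r), -np.mean(r * x)])   # -grad w.r.t. (a, b)
```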

## Gradient Estimator for Wake Update of $\phi$

Estimate the gradient \begin{align} \nabla_\phi \int p(x) \KL{p_\theta(z \given x)}{q_\phi(z \given x)} \,\mathrm dx &= \nabla_\phi \int p(x) \int p_\theta(z \given x) (\log p_\theta(z \given x) - \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &= \int p(x) \int p_\theta(z \given x) ( -\nabla_\phi \log q_\phi(z \given x)) \,\mathrm dz \,\mathrm dx \\ &\approx \int p(x) \left( \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)) \right) \,\mathrm dx, \end{align} where $z^k \sim q_\phi(\cdot \given x)$ and $w_k = \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)}$. The inner approximation is self-normalized importance sampling, so the estimator is biased for finite $K$ but consistent as $K \to \infty$. This quantity is estimated using one sample $x \sim p(x)$: \begin{align} &\approx \sum_{k = 1}^K \frac{w_k}{\sum_{\ell = 1}^K w_\ell} (-\nabla_\phi \log q_\phi(z^k \given x)). \end{align} Call this gradient estimator $\hat g_\phi^{\text{wake}}$.
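The same toy setup as above gives a sketch of $\hat g_\phi^{\text{wake}}$: draw particles from $q_\phi$, self-normalize the weights, and average the negative scores (again assuming $z \sim \mathcal N(\theta, 1)$, $x \given z \sim \mathcal N(z, 1)$ and $q_\phi(z \given x) = \mathcal N(a + b x, s^2)$):

```python
import numpy as np

def grad_phi_wake(x, theta, a, b, s, K, rng):
    z = rng.normal(a + b * x, s, size=K)    # z^k ~ q_phi(. | x)
    # log w_k = log p_theta(z^k, x) - log q_phi(z^k | x), constants dropped
    logw = (-0.5 * (z - theta)**2 - 0.5 * (x - z)**2) \
           - (-0.5 * ((z - a - b * x) / s)**2 - np.log(s))
    w = np.exp(logw - logw.max())
    w /= w.sum()                            # w_k / sum_l w_l
    r = (z - a - b * x) / s**2              # grad log q_phi w.r.t. its mean
    return np.array([-np.sum(w * r), -np.sum(w * r * x)])   # -grad w.r.t. (a, b)
```

Note that, unlike the sleep update, no gradient flows through the sampling of $z^k$; the weights reweight the score terms instead.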

It is usually advantageous to use the wake update of $\phi$ rather than the sleep update, because it minimizes the expected KL divergence under the true data distribution $p(x)$ rather than under the model's current marginal $p_\theta(x)$.

## Full Algorithm

Repeat until convergence:

- Update $\theta$ using $\hat g_\theta^{\text{wake}}$
- Do at least one of:
    - Update $\phi$ using $-\hat g_\phi^{\text{sleep}}$
    - Update $\phi$ using $-\hat g_\phi^{\text{wake}}$
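Putting the pieces together, here is a self-contained sketch of the loop with the wake update of $\phi$, on the toy model $z \sim \mathcal N(\theta, 1)$, $x \given z \sim \mathcal N(z, 1)$ with $q_\phi(z \given x) = \mathcal N(a + b x, s^2)$. Plain SGD with a fixed $s$; all constants are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
# synthetic data from the model with theta* = 1.5, so p(x) = N(1.5, 2)
data = rng.normal(1.5, np.sqrt(2.0), size=500)

theta, a, b, s = 0.0, 0.0, 0.0, 1.0   # init; q's std s kept fixed for simplicity
K, lr = 10, 0.05
for step in range(2000):
    x = data[rng.integers(len(data))]             # one observation per step
    z = rng.normal(a + b * x, s, size=K)          # z^k ~ q_phi(. | x)
    logw = (-0.5 * (z - theta)**2 - 0.5 * (x - z)**2) \
           - (-0.5 * ((z - a - b * x) / s)**2 - np.log(s))
    w = np.exp(logw - logw.max())
    w /= w.sum()                                  # normalized weights
    theta += lr * np.sum(w * (z - theta))         # ascent along g_theta_wake
    r = (z - a - b * x) / s**2
    a += lr * np.sum(w * r)                       # ascent along -g_phi_wake
    b += lr * np.sum(w * r * x)
# theta should drift toward the MLE mean(data) ~ 1.5,
# and (a, b) toward the posterior mean map (theta / 2, 0.5)
```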
