# Semi-supervised model learning and amortized inference

27 September 2019

This is a note on Kingma et al.’s paper and Siddharth et al.’s paper.

Say we have unsupervised data $x_u \sim p(x)$ and some supervised data $(x_s, y_s) \sim p(x, y)$ (where $p(x)$ is the marginal of $p(x, y)$, which is not super important here).

Then let’s say we want to learn parameters $\theta$ of a generative model $p_\theta(z, y, x)$ with a latent variable $z$ and sometimes-latent variable $y$. We also want to learn an inference network $q_\phi(z, y \given x)$.

To learn $\theta, \phi$, we should maximize the following objective: \begin{align} \mathcal L(\theta, \phi) := \E_{p(x_u)}\left[\mathrm{ELBO}(x_u, \theta, \phi)\right] + \gamma \E_{p(x_s, y_s)}\left[\mathrm{ELBO}(x_s, y_s, \theta, \phi) + \alpha \log q_\phi(y_s \given x_s)\right], \label{eq:obj} \end{align} where \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &:= \E_{q_\phi(z, y \given x_u)} \left[\log \frac{p_\theta(z, y, x_u)}{q_\phi(z, y \given x_u)}\right], \text{and} \label{eq:elbo1}\\ \mathrm{ELBO}(x_s, y_s, \theta, \phi) &:= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z \given x_s, y_s)}\right]. \label{eq:elbo2} \end{align}

To see why maximizing \eqref{eq:obj} is a good thing to do, rewrite the ELBOs into the logp - KL form and rewrite the logq term as an expected KL: \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &= \log p_\theta(x_u) - \KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}, \\ \mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \log p_\theta(x_s, y_s) - \KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}, \\ \E_{p(x_s, y_s)}\left[\log q_\phi(y_s \given x_s)\right] &= -\E_{p(x_s)p(y_s \given x_s)}\left[\log p(y_s \given x_s) - \log q_\phi(y_s \given x_s) - \log p(y_s \given x_s)\right] \nonumber\\ &= -\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right] - H(p(y_s \given x_s)), \end{align} where $H(p(y_s \given x_s))$ is the conditional entropy of $p(y_s \given x_s)$.
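
The last identity is easy to check numerically for a single fixed $x_s$ with a discrete label. The sketch below is only an illustration: the tables `p` and `q` are made-up values standing in for $p(y_s \given x_s)$ and $q_\phi(y_s \given x_s)$, and the helper names are hypothetical.

```python
import math

def entropy(p):
    """Entropy H(p) = -sum_y p(y) log p(y)."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def kl(p, q):
    """KL(p || q) for two discrete distributions on the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def expected_log_q(p, q):
    """E_p[log q(y)], the logq term for one fixed x_s."""
    return sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

# Hypothetical tables for p(y_s | x_s) and q_phi(y_s | x_s) over three labels.
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

lhs = expected_log_q(p, q)
rhs = -kl(p, q) - entropy(p)
print(lhs, rhs)  # the two values agree up to floating-point error
```

Since the KL is non-negative, the logq term is largest when $q_\phi(y_s \given x_s)$ matches $p(y_s \given x_s)$, where it equals the negative entropy.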

This allows us to rewrite \eqref{eq:obj} as \begin{align} \mathcal L(\theta, \phi) = \color{blue}{\E_{p(x_u)}\left[\log p_\theta(x_u)\right]} \color{red}{-\E_{p(x_u)}\left[\KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}\right]} + \color{blue}{\gamma\E_{p(x_s, y_s)}\left[\log p_\theta(x_s, y_s)\right]} \color{red}{-\gamma\E_{p(x_s, y_s)}\left[\KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}\right]} - \color{red}{\gamma\alpha\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right]} - \gamma\alpha H(p(y_s \given x_s)). \end{align} Maximizing the blue terms leads to model learning, and minimizing the KL divergences in the red terms leads to amortized inference. The $H$ term does not depend on $\theta$ or $\phi$.

To estimate gradients of \eqref{eq:obj}, we can sample from $p(x_u)$ and $p(x_s, y_s)$ (which are our datasets) and “move the gradients inside the expectations.” How do we estimate the ELBOs and the logq term? If the factorization of q is nice, it is easy (Kingma). Otherwise, we need to use self-normalized importance sampling (Siddharth).
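
As a rough sketch of the minibatch version of \eqref{eq:obj}, assume we already have per-example estimates of the two ELBOs and of the logq term (how to get them is the subject of the next sections). The function name and argument names below are hypothetical, not taken from either paper.

```python
def objective_estimate(elbo_u, elbo_s, log_q_s, gamma, alpha):
    """Monte Carlo estimate of the objective from two minibatches.

    elbo_u:  ELBO estimates for the unsupervised minibatch (one per x_u)
    elbo_s:  ELBO estimates for the supervised minibatch (one per (x_s, y_s))
    log_q_s: estimates of log q(y_s | x_s), aligned with elbo_s
    """
    unsupervised = sum(elbo_u) / len(elbo_u)
    supervised = sum(e + alpha * l for e, l in zip(elbo_s, log_q_s)) / len(elbo_s)
    return unsupervised + gamma * supervised

# Toy numbers, purely for illustration.
value = objective_estimate([-2.0, -4.0], [-1.0], [-0.5], gamma=2.0, alpha=0.1)
print(value)
```

In an autodiff framework the same expression would be built from differentiable tensors, and its gradient with respect to $\theta, \phi$ would be the gradient estimate.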

## Nice factorization of the inference network

Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(y \given x) q_\phi(z \given y, x). \end{align} Gradients of both ELBOs in \eqref{eq:elbo1} and \eqref{eq:elbo2} are straightforward to estimate, as long as both $z$ and $y$ are reparameterizable. The logq term is also easy to evaluate. Kingma et al. use a model where $y$ is discrete, so it cannot be reparameterized directly; however, its support has just $10$ elements, so we can replace the expectation over $y$ with a sum over ten terms.
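
Under this factorization, the unsupervised ELBO can be written as $\sum_y q_\phi(y \given x_u)\left[\mathrm{ELBO}(x_u, y, \theta, \phi) - \log q_\phi(y \given x_u)\right]$, where the inner ELBO is the labelled one evaluated at each candidate label. A minimal sketch of that sum, assuming the per-label ELBO estimates are already available (the function and argument names are hypothetical):

```python
import math

def unsupervised_elbo(q_y, labelled_elbos):
    """Marginalize a discrete y exactly: sum_y q(y|x) * (ELBO(x, y) - log q(y|x)).

    q_y:            probabilities q(y | x_u), one per label
    labelled_elbos: ELBO(x_u, y) estimates, one per label, using q(z | y, x_u)
    """
    total = 0.0
    for prob, elbo_y in zip(q_y, labelled_elbos):
        if prob > 0:  # labels with zero probability contribute nothing
            total += prob * (elbo_y - math.log(prob))
    return total

# Ten labels, uniform q(y | x_u), and identical toy ELBO values.
value = unsupervised_elbo([0.1] * 10, [-5.0] * 10)
print(value)
```

The $-\log q_\phi(y \given x_u)$ part sums to the entropy of $q_\phi(y \given x_u)$, so only $z$ needs to be sampled.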

## Unfavourable factorization of the inference network

Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(z \given x) q_\phi(y \given z, x). \label{eq:factorization2} \end{align} There are three problems:

1. the denominator of \eqref{eq:elbo2}, $q_\phi(z \given x_s, y_s)$, is difficult to evaluate,
2. the distribution $q_\phi(z \given x_s, y_s)$, under which the expectation in \eqref{eq:elbo2} is taken, is difficult to sample from, and
3. the term $\log q_\phi(y_s \given x_s)$ is difficult to evaluate.

To solve problem 1, use the identity $\log q_\phi(z, y \given x) = \log q_\phi(y \given x) + \log q_\phi(z \given y, x)$ (where the terms on the right-hand side are only implicitly defined through the factorization \eqref{eq:factorization2}) and rewrite the ELBO in \eqref{eq:elbo2} as \begin{align} \mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] + \log q_\phi(y_s \given x_s). \label{eq:elbo3} \end{align} The extra logq term can be lumped together with the logq term in \eqref{eq:obj} so that we have $(\alpha + 1) \log q_\phi(y_s \given x_s)$ instead of $\alpha \log q_\phi(y_s \given x_s)$.
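
The rewrite can be sanity-checked on a toy problem where $z$ takes three values, so that every quantity can be computed exactly. All tables below are hypothetical numbers chosen for illustration.

```python
import math

# Hypothetical tables for one labelled pair (x_s, y_s) and z in {0, 1, 2}.
q_z = [0.2, 0.5, 0.3]          # q(z | x_s)
q_y_given_z = [0.9, 0.4, 0.1]  # q(y_s | z, x_s)
p_joint = [0.05, 0.02, 0.01]   # p(z, y_s, x_s)

q_joint = [qz * qy for qz, qy in zip(q_z, q_y_given_z)]  # q(z, y_s | x_s)
q_y = sum(q_joint)                                       # q(y_s | x_s)
q_post = [qj / q_y for qj in q_joint]                    # q(z | x_s, y_s)

# Original form: the denominator is q(z | x_s, y_s).
elbo_direct = sum(w * math.log(p / w) for w, p in zip(q_post, p_joint))
# Rewritten form: the denominator is q(z, y_s | x_s), plus log q(y_s | x_s).
elbo_rewritten = (
    sum(w * math.log(p / qj) for w, p, qj in zip(q_post, p_joint, q_joint))
    + math.log(q_y)
)
print(elbo_direct, elbo_rewritten)  # equal up to floating-point error
```

Here the posterior $q_\phi(z \given x_s, y_s)$ is available by enumeration; with a continuous $z$ it is not, which is why the remaining two problems need importance sampling.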

To solve problem 2, we use self-normalized importance sampling where the proposal is $q_\phi(z \given x)$ and the unnormalized target distribution over $z$ is $q_\phi(z, y \given x) = q_\phi(z \given y, x) q_\phi(y \given x) \propto q_\phi(z \given y, x)$. This now allows us to estimate the expectation in \eqref{eq:elbo3} as \begin{align} \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] \approx \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)}, \end{align} where $z_k \sim q_\phi(z \given x_s)$, $w_k = q_\phi(z_k, y_s \given x_s) / q_\phi(z_k \given x_s) = q_\phi(y_s \given z_k, x_s)$, and $\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell$.
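
A minimal sketch of the self-normalized estimator, working with log weights for numerical stability. The proposal (a standard normal) and the classifier $q_\phi(y \given z, x)$ (a Bernoulli with probability $\mathrm{sigmoid}(z)$) are assumptions made up for this illustration, and the estimated quantity is the posterior mean of $z$ rather than the log ratio in \eqref{eq:elbo3}.

```python
import math
import random

def snis_estimate(log_weights, values):
    """Self-normalized importance sampling estimate: sum_k wbar_k * values[k]."""
    m = max(log_weights)  # subtract the max before exponentiating, for stability
    unnorm = [math.exp(lw - m) for lw in log_weights]
    total = sum(unnorm)
    return sum(u / total * v for u, v in zip(unnorm, values))

def log_q_y_given_z(y, z):
    """Hypothetical q(y | z, x): Bernoulli with success probability sigmoid(z)."""
    prob_one = 1.0 / (1.0 + math.exp(-z))
    return math.log(prob_one if y == 1 else 1.0 - prob_one)

random.seed(0)
K, y_s = 1000, 1
# Proposal q(z | x_s): a standard normal here, purely for illustration.
z = [random.gauss(0.0, 1.0) for _ in range(K)]
# log w_k = log q(z_k, y_s | x_s) - log q(z_k | x_s) = log q(y_s | z_k, x_s)
log_w = [log_q_y_given_z(y_s, zk) for zk in z]
# Estimate E_{q(z | x_s, y_s)}[z]; in the ELBO the values would be the log ratios.
posterior_mean = snis_estimate(log_w, z)
print(posterior_mean)
```

Observing $y_s = 1$ favours larger $z$ under this toy classifier, so the estimate comes out positive even though the proposal is centred at zero.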

To solve problem 3, we evaluate an IWAE-like lower bound on $\log q_\phi(y_s \given x_s)$ where we also use $q_\phi(z \given x)$ as the proposal and treat $q_\phi(z, y \given x)$ as the unnormalized target distribution whose normalizing constant is $q_\phi(y \given x)$. This allows us to reuse the previously sampled $z_k$ and weights $w_k$ in evaluating the stochastic lower bound \begin{align} \widehat{\log q} := \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right) \end{align} whose expectation is a lower bound on $\log q_\phi(y_s \given x_s)$.
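
To see the lower-bound property concretely, the sketch below takes a toy case where $z$ is binary, so the expectation of the estimator over all $K$-tuples of samples can be enumerated exactly instead of simulated. The tables are hypothetical, and `log_mean_exp` is a helper name introduced here.

```python
import itertools
import math

def log_mean_exp(log_weights):
    """Stable evaluation of log((1/K) * sum_k exp(log_weights[k]))."""
    m = max(log_weights)
    return m + math.log(sum(math.exp(lw - m) for lw in log_weights) / len(log_weights))

# Hypothetical tables for a binary z.
q_z = [0.5, 0.5]          # q(z | x_s)
q_y_given_z = [0.2, 0.8]  # q(y_s | z, x_s), which is also the weight w for that z

# Exact log q(y_s | x_s) by summing z out.
exact = math.log(sum(pz * py for pz, py in zip(q_z, q_y_given_z)))

def expected_bound(K):
    """Exact expectation of the K-sample bound, enumerating every tuple (z_1..z_K)."""
    total = 0.0
    for zs in itertools.product(range(len(q_z)), repeat=K):
        prob = math.prod(q_z[z] for z in zs)
        total += prob * log_mean_exp([math.log(q_y_given_z[z]) for z in zs])
    return total

print(expected_bound(1), expected_bound(2), exact)
```

In this toy case the expected bound increases from $K = 1$ to $K = 2$ while staying below the exact value.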

This allows us to estimate the gradient of \eqref{eq:obj} as \begin{align} \hat g = \nabla_{\theta, \phi} \left(\mathrm{ELBO}(x_u, \theta, \phi) + \gamma \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)} + \gamma(\alpha + 1) \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right)\right). \end{align} All sampling is reparameterized.

All of this can be generalized to other bad factorizations of the inference network.
