# Semi-supervised model learning and amortized inference

27 September 2019

This is a note on Kingma et al.’s paper “Semi-Supervised Learning with Deep Generative Models” and Siddharth et al.’s paper “Learning Disentangled Representations with Semi-Supervised Deep Generative Models.”

Say we have unsupervised data $$x_u \sim p(x)$$ and some supervised data $$y_s, x_s \sim p(y, x)$$, where $$p(x)$$ is the marginal of $$p(x, y)$$ (a detail that won’t matter much here).

Then let’s say we want to learn parameters $$\theta$$ of a generative model $$p_\theta(z, y, x)$$ with a latent variable $$z$$ and sometimes-latent variable $$y$$. We also want to learn an inference network $$q_\phi(z, y \given x)$$.

To learn $$\theta, \phi$$, we should maximize the following objective: \begin{align} \mathcal L(\theta, \phi) := \E_{p(x_u)}\left[\mathrm{ELBO}(x_u, \theta, \phi)\right] + \gamma \E_{p(x_s, y_s)}\left[\mathrm{ELBO}(x_s, y_s, \theta, \phi) + \alpha \log q_\phi(y_s \given x_s)\right], \label{eq:obj} \end{align} where \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &:= \E_{q_\phi(z, y \given x_u)} \left[\log \frac{p_\theta(z, y, x_u)}{q_\phi(z, y \given x_u)}\right], \text{and} \label{eq:elbo1}\\
\mathrm{ELBO}(x_s, y_s, \theta, \phi) &:= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z \given x_s, y_s)}\right]. \label{eq:elbo2} \end{align}

To see why maximizing \eqref{eq:obj} is a good thing to do, rewrite each ELBO in the log-marginal-likelihood-minus-KL form, and rewrite the logq term in terms of an expected KL and an entropy: \begin{align} \mathrm{ELBO}(x_u, \theta, \phi) &= \log p_\theta(x_u) - \KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}, \\
\mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \log p_\theta(x_s, y_s) - \KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}, \\
\E_{p(x_s, y_s)}\left[\log q_\phi(y_s \given x_s)\right] &= -\E_{p(x_s)p(y_s \given x_s)}\left[\left(\log p(y_s \given x_s) - \log q_\phi(y_s \given x_s)\right) - \log p(y_s \given x_s)\right] \nonumber\\
&= -\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right] - H(p(y_s \given x_s)), \end{align} where $$H(p(y_s \given x_s))$$ is the conditional entropy of $$p(y_s \given x_s)$$.
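The last identity is easy to verify numerically. A minimal numpy sketch, with made-up three-class categoricals standing in for $$p(y_s \given x_s)$$ and $$q_\phi(y_s \given x_s)$$ at a single fixed $$x_s$$:

```python
import numpy as np

# made-up distributions over three classes, standing in for
# p(y|x) and q_phi(y|x) at a single fixed x
p = np.array([0.2, 0.5, 0.3])
q = np.array([0.4, 0.4, 0.2])

lhs = np.dot(p, np.log(q))       # E_p[log q]
kl = np.dot(p, np.log(p / q))    # KL(p || q)
ent = -np.dot(p, np.log(p))      # H(p)

# identity from the derivation: E_p[log q] = -KL(p || q) - H(p)
print(np.isclose(lhs, -kl - ent))  # True
```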

This allows us to rewrite \eqref{eq:obj} as \begin{align} \mathcal L(\theta, \phi) = \color{blue}{\E_{p(x_u)}\left[\log p_\theta(x_u)\right]} \color{red}{-\E_{p(x_u)}\left[\KL{q_\phi(z, y \given x_u)}{p_\theta(z, y \given x_u)}\right]} + \color{blue}{\gamma\E_{p(x_s, y_s)}\left[\log p_\theta(x_s, y_s)\right]} \color{red}{-\gamma\E_{p(x_s, y_s)}\left[\KL{q_\phi(z \given x_s, y_s)}{p_\theta(z \given x_s, y_s)}\right]} - \color{red}{\gamma\alpha\E_{p(x_s)}\left[\KL{p(y_s \given x_s)}{q_\phi(y_s \given x_s)}\right]} - \gamma\alpha H(p(y_s \given x_s)). \end{align} Maximizing the blue terms leads to model learning and minimizing the red terms leads to amortized inference. The entropy term does not depend on either $$\theta$$ or $$\phi$$.

To estimate gradients of \eqref{eq:obj}, we can sample from $$p(x_u)$$ and $$p(x_s, y_s)$$ (which are our datasets) and “move the gradients inside the expectations.” How do we estimate the ELBOs and the logq term? If the factorization of q is nice, it is easy (Kingma). Otherwise, we need to use self-normalized importance sampling (Siddharth).

## Nice factorization of the inference network

Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(y \given x) q_\phi(z \given y, x). \end{align} Gradients of both ELBOs in \eqref{eq:elbo1} and \eqref{eq:elbo2} are straightforward to estimate as long as both $$z$$ and $$y$$ are reparameterizable, and the logq term is easy to evaluate. In Kingma et al.’s model $$y$$ is discrete and hence not directly reparameterizable; however, its support has only $$10$$ elements, so we can replace the expectation over $$y$$ with an exact sum over ten terms.
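A minimal numpy sketch of the enumeration trick, with a made-up categorical standing in for $$q_\phi(y \given x)$$ and a made-up per-class value $$f(y)$$ standing in for the inner expectation over $$z$$; the exact sum agrees with a Monte Carlo estimate:

```python
import numpy as np

rng = np.random.default_rng(0)

# made-up q_phi(y|x): a categorical over the 10 classes
logits = rng.normal(size=10)
q_y = np.exp(logits - logits.max())
q_y /= q_y.sum()

# f(y) stands in for the per-class term
# E_{q_phi(z|y,x)}[log p(z, y, x) / q_phi(z, y|x)]
f = rng.normal(size=10)

# exact: replace E_{q(y|x)}[f(y)] with a sum over the ten classes
exact = np.dot(q_y, f)

# Monte Carlo estimate, for comparison only
ys = rng.choice(10, size=200_000, p=q_y)
mc = f[ys].mean()
```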

## Unfavourable factorization of the inference network

Let’s say the inference network is factorized as \begin{align} q_\phi(z, y \given x) = q_\phi(z \given x) q_\phi(y \given z, x). \label{eq:factorization2} \end{align} There are three problems:

1. the denominator of \eqref{eq:elbo2}, $$q_\phi(z \given x_s, y_s)$$, is difficult to evaluate,
2. the expectation in \eqref{eq:elbo2} under $$q_\phi(z \given x_s, y_s)$$ is difficult to sample from, and
3. the term $$\log q_\phi(y_s \given x_s)$$ is difficult to evaluate.

To solve problem 1, use the identity $$\log q_\phi(z, y \given x) = \log q_\phi(y \given x) + \log q_\phi(z \given y, x)$$—where the terms on the RHS are only implicitly defined through the factorization \eqref{eq:factorization2}—and rewrite the ELBO in \eqref{eq:elbo2} as \begin{align} \mathrm{ELBO}(x_s, y_s, \theta, \phi) &= \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] + \log q_\phi(y_s \given x_s). \label{eq:elbo3} \end{align} The extra logq term can be lumped together with the logq term in \eqref{eq:obj} so that we have $$(\alpha + 1) \log q_\phi(y_s \given x_s)$$ instead of $$\alpha \log q_\phi(y_s \given x_s)$$.

To solve problem 2, we use self-normalized importance sampling where the proposal is $$q_\phi(z \given x)$$ and the unnormalized target distribution over $$z$$ is $$q_\phi(z, y \given x) = q_\phi(z \given y, x) q_\phi(y \given x) \propto q_\phi(z \given y, x)$$. This now allows us to estimate the expectation in \eqref{eq:elbo3} as \begin{align} \E_{q_\phi(z\given x_s, y_s)} \left[\log \frac{p_\theta(z, y_s, x_s)}{q_\phi(z, y_s \given x_s)}\right] \approx \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)}, \end{align} where $$z_k \sim q_\phi(z \given x_s)$$, $$w_k = q_\phi(z_k, y_s \given x_s) / q_\phi(z_k \given x_s)$$, and $$\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell$$.
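A minimal numpy sketch of this SNIS step, with one-dimensional Gaussian stand-ins: the proposal $$q_\phi(z \given x_s)$$ is a standard normal, the conditional $$q_\phi(z \given y_s, x_s)$$ is $$\mathcal N(0.5, 0.8^2)$$, and $$q_\phi(y_s \given x_s) = 0.3$$ is a made-up constant that, per the $$\propto$$ above, SNIS never actually needs to know:

```python
import numpy as np

def log_normal(z, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (z - mu) ** 2 / (2 * sigma**2)

rng = np.random.default_rng(0)
K = 200_000

# proposal q(z|x): standard normal
z = rng.normal(size=K)
log_proposal = log_normal(z, 0.0, 1.0)

# unnormalized target q(z, y|x) = q(z|y, x) q(y|x); here q(z|y, x) is
# N(0.5, 0.8^2) and q(y|x) = 0.3 is a made-up constant
log_target = log_normal(z, 0.5, 0.8) + np.log(0.3)

# normalized weights \bar w_k (computed in log space for stability)
log_w = log_target - log_proposal
w_bar = np.exp(log_w - log_w.max())
w_bar /= w_bar.sum()

# SNIS estimate of E_{q(z|y,x)}[z]; the true value is 0.5
snis_mean = np.dot(w_bar, z)
```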

To solve problem 3, we evaluate an IWAE-like lower bound on $$\log q_\phi(y_s \given x_s)$$, where we again use $$q_\phi(z \given x)$$ as the proposal and treat $$q_\phi(z, y \given x)$$ as the unnormalized target distribution whose normalizing constant is $$q_\phi(y \given x)$$. This allows us to reuse the previously sampled $$z_k$$ and weights $$w_k$$ to evaluate the stochastic lower bound \begin{align} \widehat{\log q} := \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right), \end{align} whose expectation is, by Jensen’s inequality, a lower bound on $$\log q_\phi(y_s \given x_s)$$.
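A minimal numpy sketch of this bound, with one-dimensional Gaussian stand-ins: the proposal $$q_\phi(z \given x)$$ is a standard normal, $$q_\phi(z \given y, x)$$ is $$\mathcal N(0.5, 0.8^2)$$, and $$q_\phi(y \given x) = 0.3$$ is a made-up normalizer. Since $$\E[w_k] = q_\phi(y \given x)$$, for large $$K$$ the bound approaches $$\log 0.3$$:

```python
import numpy as np

def log_normal(z, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (z - mu) ** 2 / (2 * sigma**2)

rng = np.random.default_rng(0)
K = 100_000

# z_k ~ q(z|x), the standard-normal proposal
z = rng.normal(size=K)

# log w_k = log q(z_k, y|x) - log q(z_k|x), with q(z|y, x) = N(0.5, 0.8^2)
# and a made-up q(y|x) = 0.3
log_w = log_normal(z, 0.5, 0.8) + np.log(0.3) - log_normal(z, 0.0, 1.0)

# stochastic lower bound log((1/K) sum_k w_k), computed stably in log space;
# for large K it approaches log q(y|x) = log 0.3
log_q_hat = np.logaddexp.reduce(log_w) - np.log(K)
```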

This allows us to estimate the gradient of \eqref{eq:obj} as \begin{align} \hat g = \nabla_{\theta, \phi} \left(\mathrm{ELBO}(x_u, \theta, \phi) + \gamma \sum_{k = 1}^K \bar w_k \log \frac{p_\theta(z_k, y_s, x_s)}{q_\phi(z_k, y_s \given x_s)} + \gamma(\alpha + 1) \log\left(\frac{1}{K}\sum_{k = 1}^K w_k\right)\right). \end{align} All sampling is reparameterized.
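The reparameterization underlying this gradient estimate can be illustrated in isolation. A minimal numpy sketch with a made-up integrand $$f(z) = z^2$$ and $$z = \mu + \sigma\epsilon$$, checking the pathwise gradient estimate against the analytic gradient $$\partial_\mu \E[z^2] = 2\mu$$:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7
eps = rng.normal(size=200_000)

# reparameterize: z = mu + sigma * eps, so dz/dmu = 1 and the gradient
# can move inside the expectation
z = mu + sigma * eps

# pathwise estimate of d E[f(z)] / d mu with f(z) = z^2:
# average of f'(z) * dz/dmu = 2 z
grad_est = np.mean(2 * z)

# analytic check: E[z^2] = mu^2 + sigma^2, so the gradient is 2 * mu = 3.0
```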

All of this can be generalized to other unfavourable factorizations of the inference network.
