# Amortized Inference

19 December 2017

Consider a probabilistic model $p(x, y)$ of $\mathcal X$-valued latents $x$ and $\mathcal Y$-valued observations $y$. Amortized inference means finding a mapping $g_\phi: \mathcal Y \to \mathcal P(\mathcal X)$ from an observation $y$ to a distribution $q_{\phi}(x \given y)$ that is close to the posterior $p(x \given y)$. Let's go with the following objective, which is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}

Here,

• $\mathcal L$ is the objective function,
• $\pi(y)$ is a distribution over the $y$ values for which we want inference to work well at test time,
• $w(y)$ is a weighting for each $y$,
• $f(y) := w(y)\pi(y)$ is just grouping the two together, and
• $\mathrm{divergence}$ measures the discrepancy between two probability distributions (see Wikipedia).

## Variational Inference

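The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the $qp$ loss. The $-\log p(y)$ term does not depend on $\phi$, so with respect to $\phi$ this is the general objective with $f(y) = p(y)$ and the divergence $\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)}$. Because the integrand is the negative ELBO, this objective is also suitable for simultaneously learning the model parameters $\theta$.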

## Inference Compilation

The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the $pq$ loss.
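As a short worked step, using only the definitions of the two losses above, the following identities show what each loss needs in practice. The $pq$ loss is, up to a constant, an expectation under the joint $p(x, y)$, so it can be estimated from samples of the model. The integrand of the $qp$ loss is an expectation under $q_\phi$ itself: \begin{align} \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy &= \int\!\!\int p(x, y) \left[\log p(x \given y) - \log q_\phi(x \given y)\right] \,\mathrm dx \,\mathrm dy \\ &= \mathbb E_{p(x, y)}\left[-\log q_\phi(x \given y)\right] + \mathrm{const}, \\ \KL{q_\phi(\cdot \given y)}{p(\cdot \given y)} - \log p(y) &= \mathbb E_{q_\phi(x \given y)}\left[\log q_\phi(x \given y) - \log p(x, y)\right], \end{align} where $\mathrm{const} = \mathbb E_{p(x, y)}\left[\log p(x \given y)\right]$ does not depend on $\phi$.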

## Examples

Let’s compare these two losses on a sequence of increasingly difficult generative models.

### Gaussian Unknown Mean

The generative model: \begin{align} p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set $\mu_0 = 0$, $\sigma_0 = 1$, $\sigma = 1$.

The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where $\phi = (a, b, c)$ consists of a multiplier $a$, offset $b$ and standard deviation $c$.

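Since the posterior $p(x \given y)$ is itself Gaussian, we can obtain the optimal values for $\phi$ in closed form: $a^* = \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}$, $b^* = \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}$, $c^* = \left(1/\sigma_0^2 + 1/\sigma^2\right)^{-1/2}$. Note that $c^*$ is the posterior standard deviation, so it is the square root of the posterior variance $\frac{1}{1/\sigma_0^2 + 1/\sigma^2}$.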
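As a sanity check on the closed form, here is a minimal NumPy sketch that computes $(a^*, b^*, c^*)$ and compares them with posterior moments obtained by brute-force integration on a grid. It treats $c$ as a standard deviation, as in the definition of $q_\phi$. The function names, the grid range and the test observation `y = 1.5` are assumptions made for illustration; they are not from the original experiments.

```python
import numpy as np

def optimal_phi(mu_0, sigma_0, sigma):
    """Closed-form (a, b, c) with p(x | y) = Normal(x; a*y + b, c**2)."""
    precision = 1.0 / sigma_0**2 + 1.0 / sigma**2
    a = (1.0 / sigma**2) / precision
    b = (mu_0 / sigma_0**2) / precision
    c = precision ** -0.5  # standard deviation, i.e. sqrt of posterior variance
    return a, b, c

def grid_posterior_moments(y, mu_0, sigma_0, sigma, num_points=20001):
    """Posterior mean and std of x given y via normalized weights on a grid."""
    x = np.linspace(mu_0 - 12 * sigma_0, mu_0 + 12 * sigma_0, num_points)
    # Unnormalized log posterior: log p(x) + log p(y | x), constants dropped.
    log_joint = -0.5 * ((x - mu_0) / sigma_0) ** 2 - 0.5 * ((y - x) / sigma) ** 2
    weights = np.exp(log_joint - log_joint.max())
    weights /= weights.sum()
    mean = np.sum(weights * x)
    std = np.sqrt(np.sum(weights * (x - mean) ** 2))
    return mean, std

# Settings from the text: mu_0 = 0, sigma_0 = 1, sigma = 1.
a, b, c = optimal_phi(mu_0=0.0, sigma_0=1.0, sigma=1.0)
y_test = 1.5  # hypothetical test observation
mean, std = grid_posterior_moments(y_test, mu_0=0.0, sigma_0=1.0, sigma=1.0)
print("closed form:", a * y_test + b, c)
print("grid check: ", mean, std)
```

The two printed lines should agree closely, which confirms that the affine mean and constant standard deviation describe the exact posterior for this model.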

Amortizing inference using both the $pq$ and the $qp$ loss gets it spot on.
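For the $pq$ loss this result can be reproduced without gradient-based training. Up to a constant, the $pq$ loss equals the expectation of $-\log q_\phi(x \given y)$ under the joint $p(x, y)$. For the affine-Gaussian $q_\phi$ used here, minimizing its sample average is ordinary least squares of $x$ on $y$. The sketch below uses that shortcut; it is not the training procedure behind the experiments, and the sample size and seed are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)          # arbitrary seed (assumption)
mu_0, sigma_0, sigma = 0.0, 1.0, 1.0    # settings from the text
num_samples = 200_000                   # arbitrary sample size (assumption)

# Sample (x, y) from the joint p(x) p(y | x).
x = rng.normal(mu_0, sigma_0, size=num_samples)
y = rng.normal(x, sigma)

# Minimizing the average of -log Normal(x; a*y + b, c^2) over (a, b) is a
# least-squares fit of x on y; the optimal c is the RMS of the residuals.
a_hat, b_hat = np.polyfit(y, x, deg=1)
c_hat = np.sqrt(np.mean((x - (a_hat * y + b_hat)) ** 2))

# Closed-form posterior parameters for comparison.
precision = 1.0 / sigma_0**2 + 1.0 / sigma**2
a_star = (1.0 / sigma**2) / precision
b_star = (mu_0 / sigma_0**2) / precision
c_star = precision ** -0.5

print("fitted:     ", a_hat, b_hat, c_hat)
print("closed form:", a_star, b_star, c_star)
```

The fitted values should match the closed-form ones up to Monte Carlo error, which is expected because the family of $q_\phi$ contains the exact posterior.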

### Gaussian Mixtures

The generative model: \begin{align} p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set $K = 2$ and $\mu_1 = -5$, $\mu_2 = 5$, $\pi_k = 1 / K$, $\sigma_k = 1$, $\sigma = 10$.

The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ are neural networks parameterized by $\phi = (\phi_1, \phi_2)$.

Amortizing inference using the $pq$ loss results in mass-covering/mean-seeking behavior, whereas using the $qp$ loss results in zero-forcing/mode-seeking behavior (for various test observations $y$). The zero-forcing/mode-seeking behavior is very clear in the second plot below: $\eta_{\phi_1}^1$ always maps to either $\mu_1$ or $\mu_2$, depending on which posterior peak is larger, while $\eta_{\phi_2}^2$ maps to a more or less constant value. It is also interesting to look at $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ when the $pq$ loss is used to amortize inference. It would actually make more sense if $\eta_{\phi_2}^2$ in the $pq$ case dropped to the same value that $\eta_{\phi_2}^2$ takes in the $qp$ case.
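To make the mass-covering behavior concrete: for a fixed $y$, the Gaussian that minimizes $\KL{p(\cdot \given y)}{q}$ matches the mean and variance of the posterior, and for this model the posterior is again a two-component Gaussian mixture that can be computed exactly. The following sketch computes both. The function names and the test observation `y = 0` are assumptions for illustration, and the "dominant component" is only a rough stand-in for what a mode-seeking fit locks onto, not the exact $qp$ optimum.

```python
import numpy as np

def mixture_posterior(y, pis, mus, sigmas, sigma):
    """Exact posterior p(x | y) as mixture weights, means and variances."""
    pis, mus, sigmas = (np.asarray(v, dtype=float) for v in (pis, mus, sigmas))
    # Weight of component k is proportional to
    # pi_k * Normal(y; mu_k, sigma_k^2 + sigma^2).
    marg_var = sigmas**2 + sigma**2
    log_w = np.log(pis) - 0.5 * np.log(marg_var) - 0.5 * (y - mus) ** 2 / marg_var
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Conjugate Gaussian update within each component.
    post_var = 1.0 / (1.0 / sigmas**2 + 1.0 / sigma**2)
    post_mean = post_var * (mus / sigmas**2 + y / sigma**2)
    return w, post_mean, post_var

def moment_matched_gaussian(w, means, variances):
    """Mean and variance of the mixture (the Gaussian minimizing KL(p || q))."""
    mean = np.sum(w * means)
    var = np.sum(w * (variances + means**2)) - mean**2
    return mean, var

# Settings from the text: K = 2, mu = (-5, 5), pi_k = 1/2, sigma_k = 1, sigma = 10.
w, m, v = mixture_posterior(y=0.0, pis=[0.5, 0.5], mus=[-5.0, 5.0],
                            sigmas=[1.0, 1.0], sigma=10.0)
pq_mean, pq_var = moment_matched_gaussian(w, m, v)
k = int(np.argmax(w))  # dominant posterior component

print("moment-matched Gaussian:", pq_mean, pq_var)
print("dominant component:     ", m[k], v[k])
```

At `y = 0` the moment-matched Gaussian sits between the two peaks with a variance far larger than that of either component, whereas a single component has a variance close to 1.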

### Gaussian Mixtures (Non-Marginalized)

The generative model: \begin{align} p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\ p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we set $K = 2$ and $\mu_1 = -5$, $\mu_2 = 5$, $\pi_k = 1 / K$, $\sigma_k = 1$, $\sigma = 10$.

The inference network: \begin{align} q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\ q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)), \end{align} where $\eta_{\phi_z}^{z \given y}$, $\eta_{\phi_\mu}^{\mu \given y, z}$, and $\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}$ are neural networks parameterized by $\phi = (\phi_z, \phi_\mu, \phi_{\sigma^2})$.

In the first plot below, we show the marginal posteriors $p(z \given y)$ and $p(x \given y)$ and the corresponding marginals of the inference network. The posterior density of $x$ is approximated by a kernel density estimate of resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs.
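A posterior approximation of this kind can be sketched in a few lines of NumPy: propose $(z, x)$ from the prior, weight by the likelihood, resample, and smooth the resampled $x$ values with a Gaussian kernel. The proposal, the particle counts, the rule-of-thumb bandwidth, the grid and the test observation `y = 5` are all assumptions made for this sketch; the text does not say which choices were used for the plots.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed (assumption)
# Settings from the text: K = 2, mu = (-5, 5), pi_k = 1/2, sigma_k = 1, sigma = 10.
pis = np.array([0.5, 0.5])
mus = np.array([-5.0, 5.0])
sigmas = np.array([1.0, 1.0])
sigma = 10.0
y = 5.0                                        # hypothetical test observation
num_particles, num_resampled = 100_000, 2_000  # assumptions

# Propose from the prior: z ~ Discrete(pi), x ~ Normal(mu_z, sigma_z^2).
z = rng.choice(len(pis), size=num_particles, p=pis)  # 0-based component index
x = rng.normal(mus[z], sigmas[z])

# With a prior proposal, the importance weight is the likelihood Normal(y; x, sigma^2).
log_w = -0.5 * ((y - x) / sigma) ** 2
w = np.exp(log_w - log_w.max())
w /= w.sum()

# Marginal posterior over z: total normalized weight per component.
p_z_given_y = np.bincount(z, weights=w, minlength=len(pis))

# Resample according to the weights, then apply a Gaussian KDE on a grid.
idx = rng.choice(num_particles, size=num_resampled, p=w)
x_resampled = x[idx]
bandwidth = 1.06 * x_resampled.std() * num_resampled ** (-1 / 5)  # rule of thumb
grid = np.linspace(-15.0, 15.0, 301)
kernel = np.exp(-0.5 * ((grid[:, None] - x_resampled[None, :]) / bandwidth) ** 2)
density = kernel.mean(axis=1) / (bandwidth * np.sqrt(2 * np.pi))

print("estimated p(z | y):", p_z_given_y)
```

A rule-of-thumb bandwidth tends to oversmooth a bimodal density, so a smaller bandwidth may give a sharper picture of the two peaks.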

### Gaussian Mixtures (Open Universe)

The generative model:

• $K \sim \mathrm{Discrete}(\alpha_1, \alpha_2) + 1$.
• if $K = 1$:
  • $x \sim \mathrm{Normal}(\mu_{1, 1}, \sigma_{1, 1}^2)$.
• else:
  • $z \sim \mathrm{Discrete}(\pi_1, \pi_2)$.
  • if $z = 0$:
    • $x \sim \mathrm{Normal}(\mu_{2, 1}, \sigma_{2, 1}^2)$.
  • else:
    • $x \sim \mathrm{Normal}(\mu_{2, 2}, \sigma_{2, 2}^2)$.
• observe $y$ under $\mathrm{Normal}(x, \sigma^2)$.
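The generative program above can be written as an ordinary sampler. The sketch below follows the same branching, with $z$ taking values in $\{0, 1\}$ to match the "if $z = 0$" test. The text gives no parameter values for this model, so every number below ($\alpha$, $\pi$, the means and standard deviations, $\sigma$) is a hypothetical placeholder, as is the function name.

```python
import numpy as np

def sample_open_universe(rng, alphas, pis, mus, sigmas, sigma):
    """Draw (K, z, x, y) from the generative program; z is None when K = 1."""
    K = int(rng.choice(2, p=alphas)) + 1
    if K == 1:
        z = None
        x = rng.normal(mus[(1, 1)], sigmas[(1, 1)])
    else:
        z = int(rng.choice(2, p=pis))  # 0-based, matching the "if z = 0" branch
        component = (2, 1) if z == 0 else (2, 2)
        x = rng.normal(mus[component], sigmas[component])
    # Sampling y here plays the role of the "observe" statement at generation time.
    y = rng.normal(x, sigma)
    return K, z, x, y

# Hypothetical parameter values (not from the text).
alphas = [0.5, 0.5]
pis = [0.5, 0.5]
mus = {(1, 1): 0.0, (2, 1): -5.0, (2, 2): 5.0}
sigmas = {(1, 1): 1.0, (2, 1): 1.0, (2, 2): 1.0}
sigma = 1.0

rng = np.random.default_rng(0)  # arbitrary seed (assumption)
draws = [sample_open_universe(rng, alphas, pis, mus, sigmas, sigma) for _ in range(200)]
print(draws[:3])
```

Because the set of latent variables depends on $K$, each draw carries a different trace: $z$ only exists when $K = 2$. This is why the inference network needs the same branching structure.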

The inference network:

• $K \sim \mathrm{Discrete}(\eta_{\phi_1}^1(y)) + 1$.
• if $K = 1$:
  • $x \sim \mathrm{Normal}(\eta_{\phi_2}^2(y, K), \eta_{\phi_3}^3(y, K))$.
• else:
  • $z \sim \mathrm{Discrete}(\eta_{\phi_3}^3(y, K))$.
  • if $z = 0$:
    • $x \sim \mathrm{Normal}(\eta_{\phi_4}^4(y, K, z), \eta_{\phi_5}^5(y, K, z))$.
  • else:
    • $x \sim \mathrm{Normal}(\eta_{\phi_4}^4(y, K, z), \eta_{\phi_5}^5(y, K, z))$.

In the plot below, we can see that the inference network learns a good mapping to the posterior (tested on various test observations).