Tuan Anh Le

Amortized Inference

19 December 2017

Consider a probabilistic model $p(x, y)$ of $\mathsf X$-valued latents $x$ and $\mathsf Y$-valued observes $y$. Amortized inference is finding a mapping $g_\phi$ from an observation $y$ to a distribution $g_\phi(y)$ that is close to the posterior $p(\cdot \given y)$. Let’s go with the following objective that is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}

Here, $\pi$ is a distribution over observations, $w$ is a weighting function (together they form $f$), $g_\phi$ is the inference network with parameters $\phi$, and $\mathrm{divergence}$ measures how far $g_\phi(y)$ is from the posterior $p(\cdot \given y)$. Different choices of $f$ and of the divergence give rise to different amortized inference algorithms.

Variational Inference

The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the $qp$ loss: it corresponds to choosing $f(y) = p(y)$ and $\mathrm{divergence} = \KL{q_\phi(\cdot \given y)}{p(\cdot \given y)}$ (the $-\log p(y)$ term does not depend on $\phi$). Since this is the negative ELBO averaged over observations, the objective is also suitable for simultaneously learning the model parameters $\theta$.
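To make the $qp$ loss concrete, here is a minimal Monte Carlo sketch of the per-observation negative ELBO for a conjugate Gaussian model with a Gaussian $q_\phi(x \given y) = \mathrm{Normal}(x; ay + b, c^2)$. The model parameters $\mu_0 = 0$, $\sigma_0 = \sigma = 1$ are illustrative assumptions, not the post's actual settings:

```python
import numpy as np

def log_normal(x, mean, std):
    # Log density of Normal(x; mean, std^2).
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def negative_elbo(y, a, b, c, mu0=0.0, sigma0=1.0, sigma=1.0,
                  num_samples=1000, seed=0):
    # Monte Carlo estimate of KL(q_phi(.|y) || p(.|y)) - log p(y)
    # = E_{x ~ q_phi(.|y)}[log q_phi(x|y) - log p(x) - log p(y|x)].
    rng = np.random.default_rng(seed)
    x = (a * y + b) + c * rng.standard_normal(num_samples)  # x ~ q_phi(.|y)
    log_q = log_normal(x, a * y + b, c)
    log_joint = log_normal(x, mu0, sigma0) + log_normal(y, x, sigma)
    return float(np.mean(log_q - log_joint))
```

When $q_\phi$ equals the true posterior, $\log q_\phi(x \given y) - \log p(x, y) = -\log p(y)$ pointwise, so the estimator returns $-\log p(y)$ with zero variance; any other $(a, b, c)$ adds the (positive) KL term.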

Inference Compilation

The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the $pq$ loss: it corresponds to choosing $f(y) = p(y)$ and $\mathrm{divergence} = \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)}$, the KL in the opposite direction.
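Up to a $\phi$-independent constant, the $pq$ loss is $\mathbb E_{p(x, y)}[-\log q_\phi(x \given y)]$, so it can be estimated by sampling from the generative model and scoring the latents under $q_\phi$ — no posterior samples needed. A sketch for the same illustrative conjugate Gaussian model as above (assumed $\mu_0 = 0$, $\sigma_0 = \sigma = 1$):

```python
import numpy as np

def pq_loss(a, b, c, mu0=0.0, sigma0=1.0, sigma=1.0,
            num_samples=5000, seed=0):
    # Monte Carlo estimate of E_{p(y)}[KL(p(.|y) || q_phi(.|y))] up to a
    # phi-independent constant: draw (x, y) ~ p(x, y), score x under q_phi(.|y).
    rng = np.random.default_rng(seed)
    x = mu0 + sigma0 * rng.standard_normal(num_samples)   # x ~ p(x)
    y = x + sigma * rng.standard_normal(num_samples)      # y ~ p(y | x)
    mean = a * y + b
    log_q = -0.5 * np.log(2 * np.pi) - np.log(c) - 0.5 * ((x - mean) / c) ** 2
    return float(-np.mean(log_q))
```

This estimator is minimized (in expectation) by the true posterior parameters, which is what inference compilation exploits.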

Examples

Let’s compare these two losses on a sequence of increasingly difficult generative models.

Gaussian Unknown Mean

The generative model: \begin{align} p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix values for the prior mean $\mu_0$, the prior standard deviation $\sigma_0$, and the likelihood standard deviation $\sigma$.

The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where $\phi = (a, b, c)$ consists of a multiplier $a$, an offset $b$, and a standard deviation $c$.

We can obtain the true values for $(a, b, c)$ in closed form from the conjugate Gaussian posterior: \begin{align} a = \frac{\sigma_0^2}{\sigma_0^2 + \sigma^2}, && b = \frac{\sigma^2 \mu_0}{\sigma_0^2 + \sigma^2}, && c^2 = \frac{\sigma_0^2 \sigma^2}{\sigma_0^2 + \sigma^2}. \end{align}
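The closed-form solution above can be sketched directly (the function name is mine, not from the post's scripts):

```python
def true_posterior_params(mu0, sigma0, sigma):
    # Conjugate Gaussian posterior p(x | y) = Normal(x; a*y + b, c^2):
    # precision adds, and the mean is a precision-weighted combination.
    s = sigma0 ** 2 + sigma ** 2
    a = sigma0 ** 2 / s
    b = sigma ** 2 * mu0 / s
    c = (sigma0 ** 2 * sigma ** 2 / s) ** 0.5
    return a, b, c
```

For example, with $\mu_0 = 0$ and $\sigma_0 = \sigma = 1$ this gives $a = 1/2$, $b = 0$, $c = \sqrt{1/2}$.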

Amortizing inference using both the $qp$ and the $pq$ loss gets it spot on.

Python script for generating these figures.

Gaussian Mixtures

The generative model: \begin{align} p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix $K$ and values for the mixture weights $\pi_k$, the component means $\mu_k$, the component standard deviations $\sigma_k$, and the likelihood standard deviation $\sigma$.
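For this model the exact posterior $p(x \given y)$ is available in closed form: it is again a $K$-component Gaussian mixture whose components are conjugate-updated and whose weights are reweighted by each component's marginal likelihood of $y$. A sketch (the parameter values in the test are illustrative, not the post's settings):

```python
import numpy as np

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

def gmm_posterior(xs, y, pis, mus, sigmas, sigma):
    # Exact p(x | y) for a Gaussian-mixture prior and Gaussian likelihood:
    # each component is updated as in the conjugate Gaussian case, and the
    # mixture weights become proportional to pi_k * N(y; mu_k, sigma_k^2 + sigma^2).
    pis, mus, sigmas = map(np.asarray, (pis, mus, sigmas))
    marg = normal_pdf(y, mus, sigmas ** 2 + sigma ** 2)   # per-component p(y)
    weights = pis * marg / np.sum(pis * marg)
    post_var = sigmas ** 2 * sigma ** 2 / (sigmas ** 2 + sigma ** 2)
    post_mean = (sigmas ** 2 * y + sigma ** 2 * mus) / (sigmas ** 2 + sigma ** 2)
    return sum(w * normal_pdf(xs, m, v)
               for w, m, v in zip(weights, post_mean, post_var))
```

This multimodal posterior is exactly what a single Gaussian $q_\phi$ cannot match, which is why the choice of loss matters below.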

The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ are neural networks parameterized by $\phi_1$ and $\phi_2$ respectively.

Amortizing inference using the $pq$ loss results in the mass-covering/mean-seeking behavior whereas using the $qp$ loss results in zero-forcing/mode-seeking behavior (for various test observations $y$). The zero-forcing/mode-seeking behavior is very clear in the second plot below: the mean network $\eta_{\phi_1}^1$ always maps $y$ to one of the mixture component means, depending on which posterior peak is larger; the variance network $\eta_{\phi_2}^2$ maps $y$ to more or less a constant. It is also interesting to look at $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ when the $pq$ loss is used to amortize inference. It would actually make more sense if $\eta_{\phi_2}^2$ dropped to the same value as in the $qp$ case.

Python script for generating these figures.

Gaussian Mixtures (Non-Marginalized)

The generative model: \begin{align} p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\ p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix $K$ and values for the mixture weights $\pi_k$, the component means $\mu_k$, the component standard deviations $\sigma_k$, and the likelihood standard deviation $\sigma$.
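Because $z$ is no longer marginalized out, training data for the $pq$ loss comes from ancestral sampling of the full generative model. A minimal sketch (parameter values in the test are illustrative assumptions):

```python
import numpy as np

def sample_model(pis, mus, sigmas, sigma, num_samples=10000, seed=0):
    # Ancestral sampling: z ~ Discrete(pi_1, ..., pi_K),
    # x ~ Normal(mu_z, sigma_z^2), y ~ Normal(x, sigma^2).
    rng = np.random.default_rng(seed)
    pis, mus, sigmas = map(np.asarray, (pis, mus, sigmas))
    z = rng.choice(len(pis), size=num_samples, p=pis)
    x = mus[z] + sigmas[z] * rng.standard_normal(num_samples)
    y = x + sigma * rng.standard_normal(num_samples)
    return z, x, y
```

Each triple $(z, x, y)$ is then scored under the factorized inference network $q_\phi(z \given y)\, q_\phi(x \given y, z)$ defined below.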

The inference network: \begin{align} q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\ q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)), \end{align} where $\eta_{\phi_z}^{z \given y}$, $\eta_{\phi_\mu}^{\mu \given y, z}$, and $\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}$ are neural networks parameterized by $\phi = (\phi_z, \phi_\mu, \phi_{\sigma^2})$.

In the first plot below, we show the marginal posteriors $p(z \given y)$ and $p(x \given y)$ and the corresponding marginals of the inference network, $q_\phi(z \given y)$ and $q_\phi(x \given y)$. The posterior density $p(x \given y)$ is approximated as a kernel density estimate of resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs $y$.

Python script for generating these figures.

Gaussian Mixtures (Open Universe)

Generative model:

The inference network:

In the plot below, we can see that the inference network learns a good mapping from an observation to its posterior (tested on various test observations $y$).

Python script for generating these figures.
