*19 December 2017*

Consider a probabilistic model $p(x, y)$ with $\mathcal X$-valued latents $x$ and $\mathcal Y$-valued observes $y$. Amortized inference is finding a mapping $g_\phi$ from an observation $y$ to a distribution $g_\phi(y)$ that is close to the posterior $p(\cdot \given y)$. Let’s go with the following objective that is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}

Here,

- $\mathcal L(\phi)$ is the objective function,
- $\pi$ is some distribution over values $y$ for which we are interested in performing good inference during test time,
- $w(y)$ is a weighting for each $y$,
- $f(y)$ just groups the two together, and
- $\mathrm{divergence}$ measures a distance between two probability distributions (see Wikipedia).

Choosing $\pi = p$, $w(y) = 1$, and the KL divergence from $q_\phi$ to the posterior (plus a constant $-\int p(y) \log p(y) \,\mathrm dy$ that does not depend on $\phi$), the objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the $\KL{q_\phi}{p}$ loss. This objective is also suitable for simultaneous model learning.
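To see why the bracketed term is the negative ELBO, expand the KL divergence (a standard identity, restated here for completeness): \begin{align} \KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y) &= \mathbb E_{q_{\phi}(x \given y)}\left[\log \frac{q_{\phi}(x \given y)}{p(x \given y)}\right] - \log p(y) \\ &= \mathbb E_{q_{\phi}(x \given y)}\left[\log q_{\phi}(x \given y) - \log p(x, y)\right] \\ &= -\mathrm{ELBO}(\phi, \theta, y). \end{align} Since the ELBO involves only $\log p(x, y)$ and $\log q_{\phi}(x \given y)$, it can be estimated and optimized without ever evaluating the intractable $p(y)$, which is what makes this objective practical and why it also supports learning $\theta$.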

Taking the KL divergence in the other direction, the objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the $\KL{p}{q_\phi}$ loss.
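A standard observation worth spelling out: up to a constant that does not depend on $\phi$, this loss is the expected negative log-density of $q_\phi$ under the joint, \begin{align} \mathcal L(\phi) &= \int p(y) \int p(x \given y) \log \frac{p(x \given y)}{q_\phi(x \given y)} \,\mathrm dx \,\mathrm dy \\ &= -\mathbb E_{p(x, y)}\left[\log q_\phi(x \given y)\right] + \mathrm{const}, \end{align} so it can be estimated by ancestral sampling of $(x, y)$ from the generative model, even though neither $p(y)$ nor $p(x \given y)$ is tractable.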

Let’s compare these two losses on a sequence of increasingly difficult generative models.

The generative model: \begin{align} p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix the values of $\mu_0$, $\sigma_0$, and $\sigma$.

The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where $\phi = (a, b, c)$ consists of a multiplier $a$, an offset $b$, and a standard deviation $c$.

By Gaussian conjugacy, we can obtain the true values for $(a, b, c)$: $a = \sigma_0^2 / (\sigma_0^2 + \sigma^2)$, $b = \sigma^2 \mu_0 / (\sigma_0^2 + \sigma^2)$, $c = \sigma_0 \sigma / \sqrt{\sigma_0^2 + \sigma^2}$.
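These closed-form values can be sanity-checked against the posterior computed by brute-force quadrature. The following is a minimal numpy sketch; the values of $\mu_0$, $\sigma_0$, $\sigma$, and $y$ are assumed for illustration, not taken from the experiments:

```python
import numpy as np

# Assumed illustrative settings (not the post's actual values).
mu0, s0, s = 0.0, 1.0, 0.5

# Closed-form amortized posterior parameters for the Gaussian-Gaussian model.
a = s0**2 / (s0**2 + s**2)
b = s**2 * mu0 / (s0**2 + s**2)
c = np.sqrt(s0**2 * s**2 / (s0**2 + s**2))

# Posterior at a single observation y, by quadrature on a dense grid.
y = 1.3
x = np.linspace(-10.0, 10.0, 200001)
log_joint = -0.5 * ((x - mu0) / s0)**2 - 0.5 * ((y - x) / s)**2
w = np.exp(log_joint - log_joint.max())
w /= w.sum()
post_mean = np.sum(w * x)
post_std = np.sqrt(np.sum(w * (x - post_mean)**2))

print(post_mean, a * y + b)  # both ~ 1.04
print(post_std, c)           # both ~ 0.447
```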

Amortizing inference using both the $\KL{q_\phi}{p}$ and the $\KL{p}{q_\phi}$ loss gets it spot on.
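One way to see this concretely: minimize a Monte Carlo estimate of the $\KL{q_\phi}{p}$ loss by plain gradient descent and watch $(a, b, c)$ converge to the closed-form values. Because the Gaussian family contains the exact posterior, the $\KL{p}{q_\phi}$ loss has the same minimizer. A numpy sketch with assumed settings (finite-difference gradients keep it dependency-free):

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed illustrative settings (not the post's actual values).
mu0, s0, s = 0.5, 1.0, 0.5

# True amortized posterior: p(x | y) = Normal(A y + B, C^2) by conjugacy.
A = s0**2 / (s0**2 + s**2)
B = s**2 * mu0 / (s0**2 + s**2)
C = np.sqrt(s0**2 * s**2 / (s0**2 + s**2))

def kl_normal(m1, v1, m2, v2):
    """KL(Normal(m1, v1) || Normal(m2, v2)) for scalar Gaussians (v = variance)."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2)**2) / v2 - 1.0)

# Fixed sample of observations y ~ p(y) = Normal(mu0, s0^2 + s^2).
ys = rng.normal(mu0, np.sqrt(s0**2 + s**2), size=2000)

def loss(params):
    # Monte Carlo estimate of E_{p(y)} KL(q_phi(. | y) || p(. | y)).
    a, b, c = params
    return kl_normal(a * ys + b, c**2, A * ys + B, C**2).mean()

# Plain gradient descent with central finite-difference gradients.
params = np.array([0.0, 0.0, 1.0])  # initial (a, b, c)
eps = 1e-6
for _ in range(2000):
    grad = np.array([
        (loss(params + eps * e) - loss(params - eps * e)) / (2 * eps)
        for e in np.eye(3)
    ])
    params = params - 0.05 * grad

print(params, (A, B, C))  # the learned (a, b, c) match the closed-form values
```

In practice one would use reparameterized stochastic gradients of the ELBO; the closed-form KL is only available here because the example is conjugate.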

The generative model: \begin{align} p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix $K$ and the values of $\pi_k$, $\mu_k$, $\sigma_k$ (for $k = 1, \dotsc, K$) and $\sigma$.

The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ are neural networks, outputting the mean and variance, parameterized by $\phi_1$ and $\phi_2$ respectively.

Amortizing inference using the $\KL{p}{q_\phi}$ loss results in the mass-covering/mean-seeking behavior whereas using the $\KL{q_\phi}{p}$ loss results in the zero-forcing/mode-seeking behavior (for various test observations $y$). The zero-forcing/mode-seeking behavior is very clear in the second plot below: $\eta_{\phi_1}^1(y)$ always maps to (roughly) one of the component means $\mu_k$, depending on which peak of the posterior is larger; $\eta_{\phi_2}^2(y)$ always maps to more or less a constant. It is also interesting to look at $\eta_{\phi_1}^1$ and $\eta_{\phi_2}^2$ when the $\KL{p}{q_\phi}$ loss is used to amortize inference. It would actually make more sense if $\eta_{\phi_2}^2$ dropped to the same value as in the $\KL{q_\phi}{p}$ case.
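The two behaviors can be reproduced in isolation by fitting a single Gaussian to a fixed bimodal density by grid search, once under each direction of the KL. All settings below (mixture, grids) are assumed for illustration:

```python
import numpy as np

x = np.linspace(-12.0, 12.0, 4001)
dx = x[1] - x[0]

def normal_pdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s)**2) / (s * np.sqrt(2 * np.pi))

# Assumed bimodal target: an equal mixture with modes at -2 and 2.
p = 0.5 * normal_pdf(x, -2.0, 0.5) + 0.5 * normal_pdf(x, 2.0, 0.5)

def kl(f, g):
    """KL(f || g) by quadrature, skipping points where f is negligible."""
    mask = f > 1e-12
    return np.sum(f[mask] * np.log(f[mask] / g[mask])) * dx

best_fwd, best_rev = None, None
for m in np.linspace(-3.0, 3.0, 61):
    for sdev in np.linspace(0.3, 2.5, 45):
        q = normal_pdf(x, m, sdev)
        fwd, rev = kl(p, q), kl(q, p)  # KL(p||q) and KL(q||p)
        if best_fwd is None or fwd < best_fwd[0]:
            best_fwd = (fwd, m, sdev)
        if best_rev is None or rev < best_rev[0]:
            best_rev = (rev, m, sdev)

# Mass-covering: mean between the modes, large standard deviation.
print("KL(p||q) minimizer (m, s):", best_fwd[1:])
# Mode-seeking: mean at one of the modes, small standard deviation.
print("KL(q||p) minimizer (m, s):", best_rev[1:])
```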

The generative model: \begin{align} p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\ p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\ p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2). \end{align} In the experiments, we fix $K$ and the values of $\pi_k$, $\mu_k$, $\sigma_k$ (for $k = 1, \dotsc, K$) and $\sigma$.

The inference network: \begin{align} q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\ q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)), \end{align} where $\eta_{\phi_z}^{z \given y}$, $\eta_{\phi_\mu}^{\mu \given y, z}$, and $\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}$ are neural networks parameterized by $\phi_z$, $\phi_\mu$, and $\phi_{\sigma^2}$ respectively.

In the first plot below, we show the marginal posteriors $p(z \given y)$ and $p(x \given y)$ and the corresponding marginals of the inference network. The posterior density is approximated by a kernel density estimate of resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs $y$.
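The resampled-importance-sampling approximation mentioned above can be sketched as follows: sample $(z, x)$ from the prior, weight by the likelihood $p(y \given x)$, and resample. All settings below are assumed for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed illustrative settings (not the post's actual values).
pis = np.array([0.5, 0.5])
mus = np.array([-2.0, 2.0])
sigmas = np.array([0.5, 0.5])
sigma = 0.5
y = 1.0

# Self-normalized importance sampling with the prior as proposal:
# the weight of each sample is its likelihood p(y | x).
N = 100000
z = rng.choice(len(pis), size=N, p=pis)
x = rng.normal(mus[z], sigmas[z])
logw = -0.5 * ((y - x) / sigma)**2  # log p(y | x) up to a constant
w = np.exp(logw - logw.max())
w /= w.sum()

# Resample to get (approximately) unweighted posterior draws,
# suitable for kernel density estimation.
idx = rng.choice(N, size=N, p=w)
xs, zs = x[idx], z[idx]
print(xs.mean(), (zs == 1).mean())  # posterior mean of x, and P(z = 1 | y)
```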

Generative model:

- .
- if :
    - .
- else:
    - .
- if :
    - .
- else:
    - .
- observe under .

The inference network:

- .
- if :
    - .
- else:
    - .
- if :
    - .
- else:
    - .

In the plot below, we can see that the inference network learns a good mapping to the posterior (evaluated on various test observations $y$).
