19 December 2017
Consider a probabilistic model \(p(x, y)\) of \(\mathcal X\)-valued latents \(x\) and \(\mathcal Y\)-valued observations \(y\). Amortized inference is the task of finding a mapping \(g_\phi: \mathcal Y \to \mathcal P(\mathcal X)\) from an observation \(y\) to a distribution \(q_{\phi}(x \given y)\) that is close to the posterior \(p(x \given y)\). Let’s go with the following objective, which is to be minimized: \begin{align} \mathcal L(\phi) = \int \overbrace{w(y)\pi(y)}^{f(y)} \, \mathrm{divergence}\left(p(\cdot \given y), g_\phi(y)\right) \,\mathrm dy. \label{eqn:amortization/objective} \end{align}
Here, \(\pi(y)\) is a distribution over observations, \(w(y)\) is a weighting function, and \(\mathrm{divergence}\) is some divergence between two distributions; different choices of \(f(y)\) and the divergence yield different amortization objectives.
The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \underbrace{\left[\KL{q_{\phi}(\cdot \given y)}{p(\cdot \given y)} - \log p(y)\right]}_{-\mathrm{ELBO}(\phi, \theta, y)} \,\mathrm dy. \label{eqn:amortization/vae-objective} \end{align} Call this the \(qp\) loss. This objective is also suitable for simultaneous model learning.
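Expanding the bracket shows why the \(qp\) loss is tractable: inside the KL, \(\log p(x \given y)\) combines with the \(-\log p(y)\) term to give the joint density, \begin{align} \mathcal L(\phi) = \int p(y) \int q_{\phi}(x \given y) \left[\log q_{\phi}(x \given y) - \log p(x, y)\right] \mathrm dx \,\mathrm dy, \end{align} so the loss can be estimated with samples \(y \sim p(y)\), \(x \sim q_{\phi}(\cdot \given y)\), without ever evaluating \(p(x \given y)\) or \(p(y)\).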
The objective we are minimizing is \begin{align} \mathcal L(\phi) = \int p(y) \, \KL{p(\cdot \given y)}{q_\phi(\cdot \given y)} \,\mathrm dy. \end{align} Call this the \(pq\) loss.
Let’s compare these two losses on a sequence of increasingly difficult generative models.
The generative model:
\begin{align}
p(x) &= \mathrm{Normal}(x; \mu_0, \sigma_0^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(\mu_0 = 0\), \(\sigma_0 = 1\), \(\sigma = 1\).
The inference network: \begin{align} q_{\phi}(x \given y) = \mathrm{Normal}(x; ay + b, c^2), \end{align} where \(\phi = (a, b, c)\) consists of a multiplier \(a\), offset \(b\) and standard deviation \(c\).
We can obtain the true values for \(\phi\): \(a^* = \frac{1/\sigma^2}{1/\sigma_0^2 + 1/\sigma^2}\), \(b^* = \frac{\mu_0/\sigma_0^2}{1/\sigma_0^2 + 1/\sigma^2}\), \(c^* = \left(\frac{1}{1/\sigma_0^2 + 1/\sigma^2}\right)^{1/2}\) (note that \(c\) is a standard deviation, so we take the square root of the posterior variance).
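These are the standard Gaussian conjugate-update formulas, and they are easy to check numerically against a brute-force grid posterior. A minimal sketch, assuming the settings above (\(\mu_0 = 0\), \(\sigma_0 = 1\), \(\sigma = 1\)); the variable names are illustrative:

```python
import numpy as np

# Assumed settings from this example: mu0 = 0, sigma0 = 1, sigma = 1.
mu0, sigma0, sigma = 0.0, 1.0, 1.0

# Closed-form posterior parameters (precision-weighted combination).
prec = 1 / sigma0**2 + 1 / sigma**2
a_star = (1 / sigma**2) / prec
b_star = (mu0 / sigma0**2) / prec
c_star = (1 / prec) ** 0.5  # standard deviation, not variance

# Check against a brute-force grid posterior for one observation.
y = 1.7
x = np.linspace(-10, 10, 200_001)
log_joint = (-0.5 * ((x - mu0) / sigma0) ** 2
             - 0.5 * ((y - x) / sigma) ** 2)
w = np.exp(log_joint - log_joint.max())
w /= w.sum()
mean = (w * x).sum()
std = np.sqrt((w * (x - mean) ** 2).sum())

assert abs(mean - (a_star * y + b_star)) < 1e-3
assert abs(std - c_star) < 1e-3
```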
Amortizing inference with either the \(pq\) or the \(qp\) loss recovers \((a^*, b^*, c^*)\) spot on.
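Both losses admit simple Monte Carlo estimators for this example. Up to a \(\phi\)-independent entropy term, the \(pq\) loss equals \(-\int \int p(x, y) \log q_{\phi}(x \given y) \,\mathrm dx \,\mathrm dy\), so it can be estimated from joint samples alone, while the \(qp\) loss can be estimated with reparameterized samples from \(q_\phi\). A minimal numpy sketch, assuming this example's settings; the function names are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
mu0, sigma0, sigma = 0.0, 1.0, 1.0  # settings of this example

def log_normal(v, mean, std):
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((v - mean) / std) ** 2

def qp_loss(a, b, c, n=200_000):
    # E_{p(y)}[-ELBO]: reparameterized x ~ q_phi(. | y), scored against the joint.
    y = rng.normal(mu0, np.sqrt(sigma0**2 + sigma**2), n)  # p(y) is Normal here
    x = a * y + b + c * rng.normal(size=n)
    log_joint = log_normal(x, mu0, sigma0) + log_normal(y, x, sigma)
    return (log_normal(x, a * y + b, c) - log_joint).mean()

def pq_loss_up_to_const(a, b, c, n=200_000):
    # E_{p(x,y)}[-log q_phi(x | y)] = pq loss + a phi-independent entropy term.
    x = rng.normal(mu0, sigma0, n)
    y = rng.normal(x, sigma)
    return -log_normal(x, a * y + b, c).mean()

# The analytic optimum should beat a perturbed phi under both losses.
a_s, b_s, c_s = 0.5, 0.0, 0.5**0.5
assert qp_loss(a_s, b_s, c_s) < qp_loss(0.9, 0.3, 1.0)
assert pq_loss_up_to_const(a_s, b_s, c_s) < pq_loss_up_to_const(0.9, 0.3, 1.0)
```

At the optimum the \(qp\) loss reduces to \(-\int p(y) \log p(y) \,\mathrm dy\), the entropy of the marginal, since the KL term vanishes.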
The generative model:
\begin{align}
p(x) &= \sum_{k = 1}^K \pi_k \mathrm{Normal}(x; \mu_k, \sigma_k^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).
The inference network: \begin{align} q_{\phi}(x \given y) &= \mathrm{Normal}(x; \eta_{\phi_1}^1(y), \eta_{\phi_2}^2(y)), \end{align} where \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) are neural networks, parameterized by \(\phi = (\phi_1, \phi_2)\), that output the mean and variance of the Gaussian.
Amortizing inference using the \(pq\) loss results in mass-covering/mean-seeking behavior, whereas using the \(qp\) loss results in zero-forcing/mode-seeking behavior (for various test observations \(y\)). The zero-forcing/mode-seeking behavior is very clear in the second plot below: \(\eta_{\phi_1}^1\) always maps to either \(\mu_1\) or \(\mu_2\), depending on which posterior peak is larger, and \(\eta_{\phi_2}^2\) maps to more or less a constant. It is also interesting to look at \(\eta_{\phi_1}^1\) and \(\eta_{\phi_2}^2\) when the \(pq\) loss is used to amortize inference: it would actually make more sense if \(\eta_{\phi_2}^2\) dropped to the same value as it takes in the \(qp\) case.
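The contrast is easy to check against the exact posterior: with a mixture-of-Gaussians prior and a Gaussian likelihood, \(p(x \given y)\) is again a mixture of Gaussians, with each component updated by the conjugate formulas and the mixture weights reweighted by each component's marginal likelihood of \(y\). A small sketch, assuming this example's settings (function names are illustrative):

```python
import numpy as np

# Settings of this example: K = 2, mu = (-5, 5), pi_k = 1/2, sigma_k = 1, sigma = 10.
mus = np.array([-5.0, 5.0])
pis = np.array([0.5, 0.5])
sigma_k, sigma = 1.0, 10.0

def exact_posterior(y):
    # Component-wise conjugate Gaussian update.
    prec = 1 / sigma_k**2 + 1 / sigma**2
    means = (mus / sigma_k**2 + y / sigma**2) / prec
    var = 1 / prec
    # Mixture weights proportional to pi_k * Normal(y; mu_k, sigma_k^2 + sigma^2).
    log_w = np.log(pis) - 0.5 * (y - mus) ** 2 / (sigma_k**2 + sigma**2)
    w = np.exp(log_w - log_w.max())
    return w / w.sum(), means, var

w, means, var = exact_posterior(2.0)
post_mean = w @ means  # what a mass-covering q is pulled towards
# The posterior stays bimodal: its mean lies between the two modes,
# while a mode-seeking q latches onto the component with the larger weight.
```

A single Gaussian \(q\) trained with the \(pq\) loss has to put its mean near `post_mean`, between the modes, which is exactly the mass-covering behavior seen in the plots.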
The generative model:
\begin{align}
p(z) &= \mathrm{Discrete}(z; \pi_1, \dotsc, \pi_K) \\
p(x \given z) &= \mathrm{Normal}(x; \mu_z, \sigma_z^2) \\
p(y \given x) &= \mathrm{Normal}(y; x, \sigma^2).
\end{align}
In the experiments, we set \(K = 2\) and \(\mu_1 = -5\), \(\mu_2 = 5\), \(\pi_k = 1 / K\), \(\sigma_k = 1\), \(\sigma = 10\).
The inference network:
\begin{align}
q_{\phi}(z \given y) &= \mathrm{Discrete}(z; \eta_{\phi_z}^{z \given y}(y)) \\
q_{\phi}(x \given y, z) &= \mathrm{Normal}(x; \eta_{\phi_\mu}^{\mu \given y, z}(y, z), \eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}(y, z)),
\end{align}
where \(\eta_{\phi_z}^{z \given y}\), \(\eta_{\phi_\mu}^{\mu \given y, z}\), and \(\eta_{\phi_{\sigma^2}}^{\sigma^2 \given y, z}\) are neural networks parameterized by \(\phi = (\phi_z, \phi_\mu, \phi_{\sigma^2})\).
In the first plot below, we show the marginal posteriors \(p(z \given y)\) and \(p(x \given y)\) together with the corresponding marginals of the inference network. The posterior density is approximated by kernel density estimation on resampled importance samples. In the second plot below, we show the outputs of the neural networks for different inputs.
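The resampled-importance-samples approximation mentioned above amounts to: sample \((z, x)\) ancestrally from the prior, weight by the likelihood \(p(y \given x)\), and resample according to the normalized weights. A minimal sketch under this example's settings (names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
mus = np.array([-5.0, 5.0])
sigma_k, sigma = 1.0, 10.0

def resampled_posterior(y, n=50_000):
    # Proposal = prior: ancestral samples (z, x), weighted by p(y | x).
    z = rng.integers(0, 2, n)
    x = rng.normal(mus[z], sigma_k)
    log_w = -0.5 * ((y - x) / sigma) ** 2   # likelihood up to a constant
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    idx = rng.choice(n, size=n, p=w)        # resample: approx. unweighted posterior draws
    return z[idx], x[idx]

z_s, x_s = resampled_posterior(2.0)
# x_s can now be fed to a kernel density estimator of p(x | y),
# and the empirical frequencies of z_s approximate p(z | y).
```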
The generative model:
The inference network:
In the plot below, we can see that the inference network learns a good mapping to the posterior (for various test observations \(y\)).