17 December 2019
This is an alternative proof that the asymptotic signal-to-noise ratio of the importance weighted autoencoder (IWAE)-based inference gradient estimator is \(O(1 / \sqrt{K})\) in the number of particles \(K\), as given in Theorem 1 of our paper. This alternative proof is due to Finke and Thiery’s Remark 1 (second bullet point).
Consider a generative network \(p_\theta(z, x)\) and an inference network \(q_\phi(z \given x)\) on latents \(z\) and observations \(x\). Given a set of observations \((x^{(n)})_{n = 1}^N\) sampled iid from the true generative model \(p(x)\), we want to learn \(\theta\) to maximize \(\frac{1}{N} \sum_{n = 1}^N \log p_\theta(x^{(n)})\) and \(\phi\) so that \(q_\phi(z \given x)\) is close to \(p_\theta(z \given x)\) for the learned \(\theta\).
The \(K\)-particle IWAE-based evidence lower bound (ELBO) is defined as \begin{align} \mathrm{ELBO}_{\text{IWAE}}^K(\theta, \phi, x) = \E_{\prod_{k = 1}^K q_\phi(z_k \given x)}\left[\log \left(\frac{1}{K} \sum_{k = 1}^K w_k \right)\right], \end{align} where \begin{align} w_k = \frac{p_\theta(z_k, x)}{q_\phi(z_k \given x)} \end{align} is the unnormalized importance weight.
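As a concrete illustration (not from the original post; the toy model, parameter names, and Gaussian \(q_\phi\) are assumptions made for the example), here is a minimal sketch of a single Monte Carlo estimate of this ELBO:

```python
import math
import torch

# Minimal sketch (illustrative assumptions): a single Monte Carlo estimate of the
# K-particle IWAE ELBO with a Gaussian q_phi(z | x) = Normal(q_mean, q_std).
def iwae_elbo_estimate(log_joint, q_mean, q_std, x, K):
    eps = torch.randn(K)                              # epsilon_k ~ s(epsilon) = N(0, 1)
    z = q_mean + q_std * eps                          # z_k = r_phi(epsilon_k, x)
    log_q = torch.distributions.Normal(q_mean, q_std).log_prob(z)
    log_w = log_joint(z, x) - log_q                   # log w_k
    # log( (1/K) * sum_k w_k ), computed stably via logsumexp
    return torch.logsumexp(log_w, dim=0) - math.log(K)

# Assumed toy generative model: z ~ N(0, 1), x | z ~ N(z, 1).
log_joint = lambda z, x: (torch.distributions.Normal(0.0, 1.0).log_prob(z)
                          + torch.distributions.Normal(z, 1.0).log_prob(x))
elbo = iwae_elbo_estimate(log_joint, torch.tensor(0.5), torch.tensor(1.0), torch.tensor(0.3), K=10)
```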
Given a differentiable reparameterization function \(r_\phi\) and a noise distribution \(s(\epsilon)\) such that \(r_\phi(\epsilon, x)\) for \(\epsilon \sim s(\epsilon)\) has the same distribution as \(z \sim q_\phi(z \given x)\), the gradient estimator for \(\nabla_\phi \mathrm{ELBO}_{\text{IWAE}}^K(\theta, \phi, x)\) can be written as \begin{align} \hat g_K = \sum_{k = 1}^K \bar w_k \nabla_\phi \log w_k, \end{align} where \(w_k = {p_\theta(r_\phi(\epsilon_k, x), x)} / {q_\phi(r_\phi(\epsilon_k, x) \given x)}\), \(\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell\) and \(\epsilon_k \sim s(\epsilon)\) for \(k = 1, \dotsc, K\).
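Under the same assumed toy Gaussian setup (a sketch, not the paper's code), one can check numerically that differentiating the IWAE objective by automatic differentiation gives exactly \(\sum_k \bar w_k \nabla_\phi \log w_k\):

```python
import math
import torch

torch.manual_seed(0)
x = torch.tensor(0.3)
q_mean = torch.tensor(0.5, requires_grad=True)   # illustrative phi = (q_mean, q_std)
q_std = torch.tensor(1.0, requires_grad=True)
K = 10

eps = torch.randn(K)                             # epsilon_k ~ s(epsilon), held fixed
z = q_mean + q_std * eps                         # z_k = r_phi(epsilon_k, x)
log_w = (torch.distributions.Normal(0.0, 1.0).log_prob(z)
         + torch.distributions.Normal(z, 1.0).log_prob(x)
         - torch.distributions.Normal(q_mean, q_std).log_prob(z))

# Autodiff through log( (1/K) * sum_k w_k ) ...
elbo = torch.logsumexp(log_w, dim=0) - math.log(K)
g_autodiff = torch.autograd.grad(elbo, [q_mean, q_std], retain_graph=True)

# ... equals the explicit estimator sum_k w_bar_k * grad_phi log w_k.
w_bar = torch.softmax(log_w, dim=0).detach()     # normalized weights, treated as constants
g_explicit = torch.autograd.grad((w_bar * log_w).sum(), [q_mean, q_std])

print(torch.allclose(g_autodiff[0], g_explicit[0]), torch.allclose(g_autodiff[1], g_explicit[1]))
```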
To prove that the signal-to-noise ratio (SNR) of \(\hat g_K\), defined as \(\mathrm{SNR}(\hat g_K) = \E[\hat g_K] / \mathrm{std}(\hat g_K)\), is \(O(1 / \sqrt{K})\), we show that \(\hat g_K\) is a \(K\)-particle self-normalized importance sampling (SNIS) estimate of a quantity that is identically zero. It then follows from standard results on SNIS that the bias, and hence \(\E[\hat g_K]\), is \(O(1 / K)\), while \(\mathrm{std}(\hat g_K)\) is \(O(1 / \sqrt{K})\); therefore \(\mathrm{SNR}(\hat g_K)\) is \(O(1 / \sqrt{K})\).
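The decay of the SNR can also be seen empirically. The following is an illustrative simulation (again using the assumed toy Gaussian model, not part of the proof) that estimates the SNR of one component of \(\hat g_K\) over repeated draws; it should shrink roughly like \(1 / \sqrt{K}\):

```python
import math
import torch

torch.manual_seed(0)
x = torch.tensor(0.3)

def g_hat_K(q_mean, q_std, K):
    # One sample of the q_mean-component of the IWAE inference gradient estimator.
    eps = torch.randn(K)
    z = q_mean + q_std * eps
    log_w = (torch.distributions.Normal(0.0, 1.0).log_prob(z)
             + torch.distributions.Normal(z, 1.0).log_prob(x)
             - torch.distributions.Normal(q_mean, q_std).log_prob(z))
    elbo = torch.logsumexp(log_w, dim=0) - math.log(K)
    return torch.autograd.grad(elbo, q_mean)[0].item()

for K in [1, 10, 100, 1000]:
    samples = [g_hat_K(torch.tensor(0.5, requires_grad=True), torch.tensor(1.0), K)
               for _ in range(500)]
    grads = torch.tensor(samples)
    snr = (grads.mean().abs() / grads.std()).item()
    print(f"K = {K:4d}   empirical SNR of the q_mean gradient = {snr:.3f}")
```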
We prove this in two steps:
1. show that \(\hat g_K\) is a \(K\)-particle SNIS estimate of \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right]\), where parts of the computation of \(w\) are held fixed with respect to \(\phi\) (i.e. detached or stop_gradiented, shown as boxed terms below);
2. show that \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\).

In the first step, we use the fact that any reparameterization function \(r_\phi(\cdot, x): \mathcal E \to \mathcal Z\) must be a bijection. This means that for a fixed \(\phi\) and \(x\), there exists an inverse \(r_\phi^{-1}(\cdot, x): \mathcal Z \to \mathcal E\) such that \(r_\phi^{-1}(r_\phi(\epsilon, x), x) = \epsilon\) for all \(\epsilon\).
Using this, we can rewrite \(\hat g_K\) as \begin{align} \hat g_K = \sum_{k = 1}^K \bar w_k \nabla_\phi \log w_k, \end{align} where \(w_k = {p_\theta(r_\phi(\boxed{r_\phi^{-1}(z_k, x)}, x), x)} / {q_\phi(r_\phi(\boxed{r_\phi^{-1}(z_k, x)}, x) \given x)}\), \(\bar w_k = w_k / \sum_{\ell = 1}^K w_\ell\) and \(z_k \sim q_\phi(z \given x)\) for \(k = 1, \dotsc, K\).
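To make the boxed (detached / stop_gradiented) terms concrete, here is a small sketch assuming a Gaussian \(q_\phi\) with \(r_\phi(\epsilon, x) = \mu + \sigma \epsilon\) and \(r_\phi^{-1}(z, x) = (z - \mu) / \sigma\); it checks that computing \(\log w\) through the detached inverse gives the same \(\phi\)-gradient as the original \(\epsilon\)-based parameterization:

```python
import torch

torch.manual_seed(0)
x = torch.tensor(0.3)
q_mean = torch.tensor(0.5, requires_grad=True)   # illustrative phi = (q_mean, q_std)
q_std = torch.tensor(1.0, requires_grad=True)

def log_w(z):
    # log w = log p_theta(z, x) - log q_phi(z | x) for the toy model z ~ N(0,1), x | z ~ N(z,1)
    return (torch.distributions.Normal(0.0, 1.0).log_prob(z)
            + torch.distributions.Normal(z, 1.0).log_prob(x)
            - torch.distributions.Normal(q_mean, q_std).log_prob(z))

eps = torch.randn(())
z = q_mean + q_std * eps                                    # z = r_phi(eps, x), so z ~ q_phi(z | x)
z_boxed = q_mean + q_std * ((z - q_mean) / q_std).detach()  # r_phi( boxed r_phi^{-1}(z, x), x )

g = torch.autograd.grad(log_w(z), [q_mean, q_std], retain_graph=True)
g_boxed = torch.autograd.grad(log_w(z_boxed), [q_mean, q_std])
print(torch.allclose(g[0], g_boxed[0]), torch.allclose(g[1], g_boxed[1]))  # True True
```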
This means that \(\hat g_K\) is an SNIS estimate of \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right]\).
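Explicitly (spelling out a step left implicit above), this is the standard SNIS form with proposal \(q_\phi(z \given x)\), unnormalized target \(p_\theta(z, x) \propto p_\theta(z \given x)\), and test function \(h(z) = \nabla_\phi \log w\) (with the boxed terms held fixed): \begin{align} \hat g_K = \sum_{k = 1}^K \bar w_k \, h(z_k) \approx \E_{p_\theta(z \given x)}\left[h(z)\right], \qquad z_k \sim q_\phi(z \given x). \end{align}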
To prove \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\), we will use the equivalence of the REINFORCE and reparameterization tricks (Tucker et al. 2018, equation 5), which states that for any \(f(z)\) that can potentially also depend on \(\phi\): \begin{align} \E_{q_\phi(z \given x)}\left[f(z) \frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right] = \E_{s(\epsilon)}\left[\frac{\partial f(z)}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right]. \end{align}
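For completeness (this justification is not in the post), the identity follows from computing \(\nabla_\phi \E_{q_\phi(z \given x)}\left[f(z)\right]\) in two ways:
\begin{align}
\nabla_\phi \E_{q_\phi(z \given x)}\left[f(z)\right]
&= \E_{q_\phi(z \given x)}\left[f(z) \frac{\partial}{\partial \phi}\log q_\phi(z \given x) + \frac{\partial f(z)}{\partial \phi}\right] \\
&= \E_{s(\epsilon)}\left[\frac{\partial f(z)}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi} + \frac{\partial f(z)}{\partial \phi}\right],
\end{align}
where the first line is the REINFORCE (score-function) form, together with the product rule for the explicit \(\phi\)-dependence of \(f\), and the second line is the reparameterized form with \(z = z(\epsilon, \phi)\). Since \(\E_{q_\phi(z \given x)}\left[\partial f(z) / \partial \phi\right] = \E_{s(\epsilon)}\left[\partial f(z) / \partial \phi\right]\), subtracting this term from both lines gives the identity above.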
To show this, expand \(\nabla_\phi \log w\) using the chain rule, noting that the only explicit dependence of \(\log w\) on \(\phi\) is through \(-\log q_\phi(z \given x)\), since \(p_\theta\) does not depend on \(\phi\):
\begin{align}
\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right]
&= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi} + \frac{\partial \log w}{\partial \phi}\right] \\
&= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi} - \frac{\partial \log q}{\partial \phi}\right]
\end{align}
It therefore remains to prove that \begin{align} \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \phi}\right] = \E_{p_\theta(z \given x)}\left[\frac{\partial \log q}{\partial \phi}\right]. \end{align}
Substituting \(f = w\) into the identity above, we obtain
\begin{align}
\E_{q_\phi(z \given x)}\left[w \frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right]
&= \E_{s(\epsilon)}\left[\frac{\partial w}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right] \\
&= \E_{s(\epsilon)}\left[w \frac{\partial \log w}{\partial z} \frac{\partial z(\epsilon, \phi)}{\partial \phi}\right]
\end{align}
Rewriting the right-hand side, first in terms of \(z = r_\phi(\epsilon, x) \sim q_\phi(z \given x)\) and then using \(w \, q_\phi(z \given x) = p_\theta(z, x) = p_\theta(z \given x) \, p_\theta(x)\):
\begin{align}
\text{RHS}
&= \E_{q_\phi(z \given x)}\left[w \frac{\partial \log w}{\partial z} \frac{\partial r_\phi(\boxed{r_\phi^{-1}(z, x)}, x)}{\partial \phi}\right] \\
&= \E_{p_\theta(z \given x)}\left[\frac{\partial \log w}{\partial z} \frac{\partial r_\phi(\boxed{r_\phi^{-1}(z, x)}, x)}{\partial \phi}\right] p_\theta(x)
\end{align}
Similarly, using the same change of measure, the left-hand side is \begin{align} \text{LHS} &= \E_{p_\theta(z \given x)}\left[\frac{\partial}{\partial \phi}\log q_\phi(z \given x)\right] p_\theta(x). \end{align}
Dividing both sides by \(p_\theta(x)\) shows that the two expectations under \(p_\theta(z \given x)\) are equal, which proves the remaining identity and hence \(\E_{p_\theta(z \given x)}\left[\nabla_\phi \log w\right] = 0\). This completes the proof.
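As a concrete sanity check (an added example with an assumed conjugate Gaussian model, not part of the original proof), take \(p_\theta(z) = \mathcal N(z; 0, 1)\), \(p_\theta(x \given z) = \mathcal N(x; z, 1)\), and \(q_\phi(z \given x) = \mathcal N(z; \mu, \sigma^2)\) with \(r_\phi(\epsilon, x) = \mu + \sigma \epsilon\). Then the \(\mu\)-component of \(\nabla_\phi \log w\) is \begin{align} \frac{\partial \log w}{\partial z}\frac{\partial z}{\partial \mu} + \frac{\partial \log w}{\partial \mu} = \left(-z + (x - z) + \frac{z - \mu}{\sigma^2}\right) \cdot 1 - \frac{z - \mu}{\sigma^2} = x - 2z, \end{align} and since the posterior is \(p_\theta(z \given x) = \mathcal N(z; x/2, 1/2)\), we get \(\E_{p_\theta(z \given x)}\left[x - 2z\right] = x - 2 \cdot x/2 = 0\) for any \(\mu\) and \(\sigma\); a similar calculation gives zero for the \(\sigma\)-component.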