25 February 2018
These are notes on the VIMCO paper.
The goal is to reduce the variance of gradient estimators for importance weighted autoencoders, especially when the latent variables are discrete and a REINFORCE gradient estimator must be used.
Let \(p_\theta(z, x)\) be a generative network of latent variables \(z\) and observations \(x\). Let \(q_\phi(z \given x)\) be the inference network. Given a dataset \((x^{(n)})_{n = 1}^N\), we want to maximize \(\sum_{n = 1}^N \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x^{(n)})\) where: \begin{align} \mathrm{ELBO}_{\text{IS}}(\theta, \phi, x) = \int \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right) \log \left( \frac{1}{K} \sum_{k = 1}^K \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)} \right) \,\mathrm dz^{1:K}. \label{eq:elbo} \end{align}
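In practice the integrand of \eqref{eq:elbo} is evaluated from log weights \(\log \frac{p_\theta(z^k, x)}{q_\phi(z^k \given x)}\), since the weights themselves can underflow. The following is a minimal sketch of that computation, assuming the log weights are already available; the name `log_mean_exp` and the example values are made up for illustration.

```python
import math

def log_mean_exp(log_w):
    """Numerically stable log((1/K) * sum_k exp(log_w[k]))."""
    m = max(log_w)  # subtract the maximum so the largest exponent is exp(0) = 1
    return m + math.log(sum(math.exp(lw - m) for lw in log_w)) - math.log(len(log_w))

# Hypothetical log weights for K = 3 samples. Exponentiating them directly
# would underflow to 0.0 in double precision.
log_w = [-1000.0, -1001.0, -1002.0]
estimate = log_mean_exp(log_w)  # single-sample estimate of the bound
```

Averaging `estimate` over many independent draws of \(z^{1:K}\) would approximate \(\mathrm{ELBO}_{\text{IS}}(\theta, \phi, x)\).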
If \(z^{1:K}\) is not reparameterizable, we must use the REINFORCE gradient estimator to estimate gradients of \eqref{eq:elbo} with respect to \(\phi\): \begin{align} \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right) + \nabla_\phi \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right), \label{eq:reinforce} \end{align} where \(z^{1:K} \sim \prod_{k = 1}^K q_\phi(\mathrm dz^k \given x)\) and \(f_{\theta, \phi}(z, x) := \frac{p_\theta(z, x)}{q_\phi(z \given x)}\). Although this estimator is unbiased, it has high variance because of the first term.
First, let’s rewrite the first term of the estimator in \eqref{eq:reinforce} as:
\begin{align}
\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log \left( \prod_{k = 1}^K q_\phi(z^k \given x) \right)
&= \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \left( \sum_{\ell = 1}^K \nabla_\phi \log q_\phi(z^\ell \given x) \right) \\
&= \sum_{\ell = 1}^K \left( \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right). \label{eq:reinforce-2}
\end{align}
Given a function \(\hat f(z^{-\ell}, x)\) (where \(z^{-\ell} := (z^1, \dotsc, z^{\ell - 1}, z^{\ell + 1}, \dotsc, z^K)\)) that does not depend on \(z^\ell\), we subtract a per-sample baseline from the learning signal of each summand in \eqref{eq:reinforce-2}. This changes the estimator but, as verified further on, not its expectation: \begin{align} \sum_{\ell = 1}^K \left( \left(\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) - \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right). \label{eq:reinforce-3} \end{align}
The authors experimented with the arithmetic mean of the other weights, \(\hat f(z^{-\ell}, x) = \frac{1}{K - 1} \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x)\), but found the geometric mean \begin{align} \hat f(z^{-\ell}, x) := \exp\left( \frac{1}{K - 1} \sum_{k \neq \ell} \log f_{\theta, \phi}(z^k, x) \right) \end{align} to work better.
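The per-sample learning signals in \eqref{eq:reinforce-3} with the geometric-mean \(\hat f\) can be computed entirely in log space. The sketch below assumes the log weights \(\log f_{\theta, \phi}(z^k, x)\) are given and \(K \geq 2\); the function names are illustrative and not taken from the paper.

```python
import math

def log_mean_exp(log_w):
    """Numerically stable log((1/K) * sum_k exp(log_w[k]))."""
    m = max(log_w)
    return m + math.log(sum(math.exp(lw - m) for lw in log_w)) - math.log(len(log_w))

def vimco_learning_signals(log_w):
    """Return, for each l, log-mean-weight minus the leave-one-out baseline."""
    K = len(log_w)
    total = log_mean_exp(log_w)
    sum_log_w = sum(log_w)
    signals = []
    for l in range(K):
        # log of the geometric mean of the other K - 1 weights
        log_f_hat = (sum_log_w - log_w[l]) / (K - 1)
        # replace the l-th weight by f_hat, then take the log mean again
        replaced = log_w[:l] + [log_f_hat] + log_w[l + 1:]
        signals.append(total - log_mean_exp(replaced))
    return signals

# Hypothetical weights f = (1, 1, 4), given as log weights.
signals = vimco_learning_signals([0.0, 0.0, math.log(4.0)])
```

Each entry of `signals` multiplies the corresponding score \(\nabla_\phi \log q_\phi(z^\ell \given x)\). When all weights are equal, every signal is zero, so samples that are no better than their peers contribute nothing to the first term.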
Since the term
\begin{align}
g(z^{-\ell}, x) := \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right)
\end{align}
is independent of \(z^\ell\), we can verify that:
\begin{align}
\E\left[\sum_{\ell = 1}^K g(z^{-\ell}, x) \nabla_\phi \log q_\phi(z^\ell \given x) \right]
&= \sum_{\ell = 1}^K \E\left[ g(z^{-\ell}, x) \nabla_\phi \log q_\phi(z^\ell \given x) \right] \\
&= \sum_{\ell = 1}^K \E\left[ g(z^{-\ell}, x) \right] \E\left[ \nabla_\phi \log q_\phi(z^\ell \given x) \right] && \text{(since } z^{-\ell} \text{ and } z^\ell \text{ are independent)} \\
&= 0,
\end{align}
where we use the fact that \(\E\left[ \nabla_\phi \log q_\phi(z^\ell \given x) \right] = \int \nabla_\phi q_\phi(z^\ell \given x) \,\mathrm dz^\ell = \nabla_\phi \int q_\phi(z^\ell \given x) \,\mathrm dz^\ell = \nabla_\phi 1 = 0\).
Hence \eqref{eq:reinforce-2} and \eqref{eq:reinforce-3} have the same expectation.
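The zero-expectation argument holds for any function of \(z^{-\ell}\), not only the specific \(g\) above. As a small sanity check, the sketch below enumerates all outcomes exactly for \(K = 2\) samples from a three-category \(q_\phi\) parameterized by softmax logits. The logits and the table `g` are arbitrary values chosen for illustration.

```python
import itertools
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(a - m) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]

def score(z, probs):
    """Gradient of log q(z) with respect to the logits: one_hot(z) - probs."""
    return [(1.0 if j == z else 0.0) - probs[j] for j in range(len(probs))]

probs = softmax([0.2, -1.0, 0.7])  # hypothetical logits phi
g = {0: 3.0, 1: -2.0, 2: 0.5}      # arbitrary function of the *other* sample z^2

# E[g(z^2) * grad_phi log q(z^1)], computed by summing over all (z^1, z^2).
expectation = [0.0] * len(probs)
for z1, z2 in itertools.product(range(3), repeat=2):
    weight = probs[z1] * probs[z2]
    s = score(z1, probs)
    for j in range(len(probs)):
        expectation[j] += weight * g[z2] * s[j]
```

Every component of `expectation` vanishes up to floating-point rounding, regardless of the values stored in `g`.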
Putting everything together, the VIMCO estimator is:
\begin{align}
&\sum_{\ell = 1}^K \left( \left(\log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right) - \log \left( \frac{1}{K} \left( \hat f(z^{-\ell}, x) + \sum_{k \neq \ell} f_{\theta, \phi}(z^k, x) \right) \right) \right) \nabla_\phi \log q_\phi(z^\ell \given x) \right) \nonumber\\
&+ \nabla_\phi \log \left( \frac{1}{K} \sum_{k = 1}^K f_{\theta, \phi}(z^k, x) \right).
\end{align}
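To check the whole estimator, it helps to use a model small enough that expectations can be computed by enumeration. The sketch below assumes a single binary latent with \(q_\phi(z = 1 \given x) = \sigma(\phi)\) and a made-up table `P_JOINT` of values \(p_\theta(z, x)\) for one fixed \(x\). It compares the exact expectation of the VIMCO estimator against a finite-difference gradient of \(\mathrm{ELBO}_{\text{IS}}\).

```python
import itertools
import math

P_JOINT = {0: 0.1, 1: 0.3}  # hypothetical p(z, x) for one fixed x

def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))

def q_prob(z, phi):
    s = sigmoid(phi)
    return s if z == 1 else 1.0 - s

def f(z, phi):
    return P_JOINT[z] / q_prob(z, phi)

def score(z, phi):
    """d/dphi log q_phi(z | x) for a Bernoulli with logit phi."""
    return z - sigmoid(phi)

def elbo_is(phi, K):
    """Exact ELBO_IS: sum over all z^{1:K} in {0, 1}^K."""
    total = 0.0
    for zs in itertools.product([0, 1], repeat=K):
        prob = math.prod(q_prob(z, phi) for z in zs)
        total += prob * math.log(sum(f(z, phi) for z in zs) / K)
    return total

def vimco_estimate(zs, phi):
    """VIMCO gradient estimate for one draw zs, with the geometric-mean f_hat."""
    K = len(zs)
    fs = [f(z, phi) for z in zs]
    log_fs = [math.log(v) for v in fs]
    sum_f = sum(fs)
    log_avg = math.log(sum_f / K)
    first = 0.0
    for l in range(K):
        f_hat = math.exp((sum(log_fs) - log_fs[l]) / (K - 1))
        baseline = math.log((f_hat + sum_f - fs[l]) / K)
        first += (log_avg - baseline) * score(zs[l], phi)
    # grad_phi log(mean f): since grad_phi f_k = -f_k * score_k,
    # this is minus the normalized-weight average of the scores.
    second = -sum(fs[k] / sum_f * score(zs[k], phi) for k in range(K))
    return first + second

def expected_vimco(phi, K):
    return sum(
        math.prod(q_prob(z, phi) for z in zs) * vimco_estimate(zs, phi)
        for zs in itertools.product([0, 1], repeat=K)
    )

phi, K, eps = 0.4, 3, 1e-5
finite_diff = (elbo_is(phi + eps, K) - elbo_is(phi - eps, K)) / (2 * eps)
exact_mean = expected_vimco(phi, K)
```

Here `exact_mean` and `finite_diff` agree up to the finite-difference error, which is consistent with the estimator being unbiased for \(\nabla_\phi \mathrm{ELBO}_{\text{IS}}\). The check says nothing about variance, which would need a comparison of second moments against plain REINFORCE.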