19 February 2018
These are notes on Shane Gu et al.'s paper Neural Adaptive Sequential Monte Carlo.
Goal: Given a state space model \(p(x_{1:T}, y_{1:T}) = p(x_1) \prod_{t = 2}^T p(x_t \given x_{1:t - 1}) \prod_{t = 1}^T p(y_t \given x_{1:t})\), we want to learn proposal distributions \(q_\phi(x_{1:T} \given y_{1:T}) = q_\phi(x_1 \given y_1) \prod_{t = 2}^T q_\phi(x_t \given x_{1:t - 1}, y_{1:t})\).
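To make the setup concrete, here is a minimal sketch (not from the paper) of a one-dimensional, Markovian linear-Gaussian special case of such a model, with a Gaussian proposal \(q_\phi(x_t \given x_{t - 1}, y_t)\) whose mean is a linear function of \(x_{t - 1}\) and \(y_t\). All names and constants (`A`, `Q`, `C`, `R`, `SIGMA_Q`, `proposal_mean`, `grad_log_q`) are illustrative assumptions.

```python
import numpy as np

# A minimal sketch, not from the paper: a one-dimensional, Markovian
# linear-Gaussian state space model and a Gaussian proposal whose mean is a
# linear function of the previous state and the current observation.
# All parameter values below are illustrative assumptions.

rng = np.random.default_rng(0)

A, Q = 0.9, 1.0   # transition: x_t ~ Normal(A * x_{t-1}, variance Q)
C, R = 1.0, 0.5   # emission:   y_t ~ Normal(C * x_t, variance R)
SIGMA_Q = 1.0     # fixed proposal standard deviation (an assumption)


def simulate(T):
    """Draw (x_{1:T}, y_{1:T}) from the joint p(x_{1:T}, y_{1:T})."""
    x, y = np.empty(T), np.empty(T)
    x[0] = rng.normal(0.0, 1.0)                          # p(x_1)
    y[0] = rng.normal(C * x[0], np.sqrt(R))              # p(y_1 | x_1)
    for t in range(1, T):
        x[t] = rng.normal(A * x[t - 1], np.sqrt(Q))      # p(x_t | x_{t-1})
        y[t] = rng.normal(C * x[t], np.sqrt(R))          # p(y_t | x_t)
    return x, y


def proposal_mean(phi, x_prev, y_t):
    """Mean of q_phi(x_t | x_{t-1}, y_t), with phi = (phi_0, phi_1).
    By convention, x_prev = 0 at t = 1, so q_phi(x_1 | y_1) only uses y_1."""
    return phi[0] * x_prev + phi[1] * y_t


def grad_log_q(phi, x_t, x_prev, y_t):
    """Closed-form d/dphi log Normal(x_t; proposal_mean(phi, x_prev, y_t), SIGMA_Q^2)."""
    resid = (x_t - proposal_mean(phi, x_prev, y_t)) / SIGMA_Q ** 2
    return resid * np.array([x_prev, y_t])
```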
We minimize the inclusive KL divergence \(\KL{p(x_{1:T} \given y_{1:T})}{q_\phi(x_{1:T} \given y_{1:T})}\) using stochastic gradient descent, where the gradient is:
\begin{align}
\frac{\partial}{\partial \phi} \KL{p(x_{1:T} \given y_{1:T})}{q_\phi(x_{1:T} \given y_{1:T})}
&= \frac{\partial}{\partial \phi} \int p(x_{1:T} \given y_{1:T}) \log \frac{p(x_{1:T} \given y_{1:T})}{q_\phi(x_{1:T} \given y_{1:T})} \,\mathrm dx_{1:T} \\
&= -\int p(x_{1:T} \given y_{1:T}) \frac{\partial}{\partial \phi} \log q_\phi(x_{1:T} \given y_{1:T}) \,\mathrm dx_{1:T}. \label{eq:gradient}
\end{align}
The second equality holds because \(p(x_{1:T} \given y_{1:T})\) does not depend on \(\phi\), so differentiating the integrand only hits the \(\log q_\phi\) term (assuming we may interchange differentiation and integration).
This gradient can be estimated using importance sampling (or sequential Monte Carlo): \begin{align} -\sum_{k = 1}^K \bar w_T^k \frac{\partial}{\partial \phi} \log q_\phi(x_{1:T}^k \given y_{1:T}), \label{eq:gradient-estimator-1} \end{align} where \(\bar w_T^k\) and \((x_{1:T}^k)_{k = 1}^K\) are the normalized weights and particle values of an importance sampler (or the final normalized weights and particle values of an SMC sampler). This gradient estimator is biased but consistent.
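Continuing the toy sketch above, the following shows one way estimator \eqref{eq:gradient-estimator-1} could be computed for that model; `log_joint` and `is_gradient` are hypothetical helper names, and this is only a sketch, not the paper's implementation.

```python
from scipy.stats import norm


def log_joint(x, y):
    """log p(x_{1:T}, y_{1:T}) under the linear-Gaussian model above."""
    lp = norm.logpdf(x[0], 0.0, 1.0) + norm.logpdf(y[0], C * x[0], np.sqrt(R))
    for t in range(1, len(x)):
        lp += norm.logpdf(x[t], A * x[t - 1], np.sqrt(Q))
        lp += norm.logpdf(y[t], C * x[t], np.sqrt(R))
    return lp


def is_gradient(phi, y, K):
    """Estimator (1): self-normalized importance sampling with proposal q_phi.
    The normalizing constant p(y_{1:T}) cancels in the normalized weights,
    so the joint p(x_{1:T}, y_{1:T}) can be used in place of the posterior."""
    T = len(y)
    xs = np.empty((K, T))
    log_w = np.empty(K)
    score = np.zeros((K, 2))                  # d/dphi log q_phi(x_{1:T}^k | y_{1:T})
    for k in range(K):
        x_prev, log_q = 0.0, 0.0
        for t in range(T):
            mean = proposal_mean(phi, x_prev, y[t])
            xs[k, t] = rng.normal(mean, SIGMA_Q)
            log_q += norm.logpdf(xs[k, t], mean, SIGMA_Q)
            score[k] += grad_log_q(phi, xs[k, t], x_prev, y[t])
            x_prev = xs[k, t]
        log_w[k] = log_joint(xs[k], y) - log_q
    w_bar = np.exp(log_w - log_w.max())
    w_bar /= w_bar.sum()                      # normalized weights \bar w_T^k
    return -(w_bar[:, None] * score).sum(axis=0)
```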
In the case of SMC, one could also estimate \eqref{eq:gradient} as: \begin{align} -\sum_{k = 1}^K \left(\bar w_1^k \frac{\partial}{\partial \phi} \log q_\phi(x_1^k \given y_1) + \sum_{t = 2}^T \bar w_t^k \frac{\partial}{\partial \phi} \log q_\phi(x_t^k \given x_{1:t - 1}^{a_{t - 1}^k}, y_{1:t}) \right), \label{eq:gradient-estimator-2} \end{align} where \(\bar w_t^k\) are the intermediate normalized weights and \(x_{1:t - 1}^{a_{t - 1}^k}\) are the corresponding resampled (ancestral) particle values at time \(t\). This is the estimator used in the paper. It is biased, and it is not clear whether it is consistent. It is also not clear whether the bias and/or variance of \eqref{eq:gradient-estimator-2} is smaller than that of \eqref{eq:gradient-estimator-1}.
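A corresponding sketch of estimator \eqref{eq:gradient-estimator-2}, again continuing the code above: the per-time-step terms are accumulated inside an SMC sweep that proposes from \(q_\phi\). Multinomial resampling at every step (so previous weights reset to uniform) is an assumption made here for simplicity; the resampling scheme and schedule are design choices.

```python
def smc_gradient(phi, y, K):
    """Estimator (2): accumulate the per-step weighted score terms inside an
    SMC sweep with q_phi as proposal and multinomial resampling at every step."""
    T = len(y)
    grad = np.zeros(2)
    x = np.zeros(K)    # resampled particles x_{t-1}^{a_{t-1}^k} (0 by convention at t = 1)
    for t in range(T):
        means = proposal_mean(phi, x, y[t])
        x_new = rng.normal(means, SIGMA_Q)                 # propose from q_phi
        # Incremental weights p(x_t | x_{t-1}) p(y_t | x_t) / q_phi(x_t | x_{t-1}, y_t);
        # the previous weights are uniform because we resample at every step.
        if t == 0:
            log_w = norm.logpdf(x_new, 0.0, 1.0)           # p(x_1)
        else:
            log_w = norm.logpdf(x_new, A * x, np.sqrt(Q))  # p(x_t | x_{t-1})
        log_w += norm.logpdf(y[t], C * x_new, np.sqrt(R))  # p(y_t | x_t)
        log_w -= norm.logpdf(x_new, means, SIGMA_Q)        # q_phi(x_t | x_{t-1}, y_t)
        w_bar = np.exp(log_w - log_w.max())
        w_bar /= w_bar.sum()                               # \bar w_t^k
        # Per-step term: -sum_k \bar w_t^k d/dphi log q_phi(x_t^k | x_{t-1}^{a^k}, y_t).
        resid = (x_new - means) / SIGMA_Q ** 2
        grad -= (w_bar * resid) @ np.stack([x, np.full(K, y[t])], axis=1)
        ancestors = rng.choice(K, size=K, p=w_bar)         # multinomial resampling
        x = x_new[ancestors]
    return grad
```

Either sketch could then drive a stochastic gradient descent loop on \(\phi\), e.g. drawing `x_true, y_obs = simulate(50)` (or using a fixed data set) and repeatedly updating `phi -= learning_rate * smc_gradient(phi, y_obs, K)`.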