Attend, Infer, Repeat
19 February 2018
Generative network
Here is pseudocode for the generative network \(p_\theta(x \given z)\), where the observation \(x \in \mathbb R^{D \times D}\) is an image and \(z\) collects all the latent variables in the execution trace; a code sketch follows the listing.
\(\theta\) contains parameters of various neural nets in the generative network.
- Initialize the image mean \(\mu = 0\),
- While \(\mathrm{sample}\left(\mathrm{Bernoulli}(\rho)\right)\):
    - \(z^{\text{where}} = \mathrm{sample}\left(\mathrm{Normal}(0, I)\right)\),
    - \(z^{\text{what}} = \mathrm{sample}\left(\mathrm{Normal}(0, I)\right)\),
    - \(\hat g = D_\theta(z^{\text{what}})\),
    - \(\mu = \mu + \mathrm{STN}^{-1}(\hat g, z^{\text{where}})\),
- \(\mathrm{observe}\left(x, \prod_{\text{pixel } i} \mathrm{Normal}(\mu_i, \sigma_x^2)\right)\).
where \(\mathrm{STN}^{-1}\) is an inverse Spatial Transformer Network and \(D_\theta\) is a parametric decoder function.
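
To make the steps concrete, here is a minimal PyTorch sketch of this generative process. Everything beyond the pseudocode is an assumption: the image and glimpse sizes, the latent dimensions, the decoder architecture standing in for \(D_\theta\), the cap on the number of loop iterations, and the use of `affine_grid`/`grid_sample` (with a clamped positive scale) as the inverse spatial transformer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Normal

# Assumed sizes (not taken from the paper).
D = 50           # the image is D x D
G = 20           # the decoded glimpse is G x G
Z_WHAT_DIM = 50
Z_WHERE_DIM = 3  # (scale, shift_x, shift_y)

class Decoder(nn.Module):
    """Stand-in for D_theta: maps z_what to a G x G glimpse."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(Z_WHAT_DIM, 200), nn.ReLU(),
            nn.Linear(200, G * G), nn.Sigmoid())

    def forward(self, z_what):
        return self.net(z_what).view(-1, 1, G, G)

def inverse_stn(glimpse, z_where):
    """Place a G x G glimpse onto a D x D canvas (inverse spatial transformer)."""
    # The pseudocode draws z_where from a standard normal; to keep this sketch
    # numerically sane we treat its first component as a positive scale.
    s = z_where[:, 0].abs().clamp(min=0.2)
    tx, ty = z_where[:, 1], z_where[:, 2]
    # Affine map from canvas coordinates to glimpse coordinates.
    theta = torch.zeros(z_where.size(0), 2, 3)
    theta[:, 0, 0] = 1.0 / s
    theta[:, 1, 1] = 1.0 / s
    theta[:, 0, 2] = -tx / s
    theta[:, 1, 2] = -ty / s
    grid = F.affine_grid(theta, (z_where.size(0), 1, D, D), align_corners=False)
    return F.grid_sample(glimpse, grid, align_corners=False)

def generate(decoder, rho=0.5, sigma_x=0.3, max_steps=3):
    """Sample one image: keep adding objects while the Bernoulli(rho) coin is 1."""
    mu = torch.zeros(1, 1, D, D)
    for _ in range(max_steps):          # cap on objects, for the sketch only
        if Bernoulli(rho).sample().item() == 0:
            break
        z_where = Normal(0., 1.).sample((1, Z_WHERE_DIM))
        z_what = Normal(0., 1.).sample((1, Z_WHAT_DIM))
        glimpse = decoder(z_what)                # \hat g = D_theta(z_what)
        mu = mu + inverse_stn(glimpse, z_where)  # paste the glimpse onto the canvas
    x = Normal(mu, sigma_x).sample()             # per-pixel Gaussian likelihood
    return x, mu
```

For example, `x, mu = generate(Decoder())` draws a single (untrained) image and its mean canvas.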
Inference network
Here is pseudocode for the inference network \(q_\phi(z \given x)\); a code sketch follows the listing.
\(\phi\) contains parameters of various neural nets in the inference network.
- Initialize the hidden state \(h = 0\) for the LSTM cell \(R_\phi\),
- \(w, h = R_\phi(\mathrm{concat}(x, 0, 0), h)\),
- While \(\mathrm{sample}\left(\mathrm{Bernoulli}(f_\phi(w))\right)\):
    - \(z^{\text{where}} = \mathrm{sample}\left(\mathrm{Normal}(\mu_\phi^{\text{where}}(w), \sigma_\phi^{\text{where}}(w)^2)\right)\),
    - \(g = \mathrm{STN}(x, z^{\text{where}})\),
    - \(z^{\text{what}} = \mathrm{sample}\left(\mathrm{Normal}(\mu_\phi^{\text{what}}(g), \sigma_\phi^{\text{what}}(g)^2)\right)\),
    - \(w, h = R_\phi(\mathrm{concat}(x, z^{\text{where}}, z^{\text{what}}), h)\),
where \(\mathrm{STN}\) is a Spatial Transformer Network, and the LSTM cell \(R_\phi\) takes in an (input, hidden state) pair and outputs an (output, next hidden state) pair.
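
And here is a matching sketch of the inference network, reusing the assumed sizes from the generative sketch. Again, the LSTM width, the linear read-out layers standing in for \(f_\phi\), \(\mu_\phi^{\text{where/what}}\), and \(\sigma_\phi^{\text{where/what}}\), the softplus used to keep the scales positive, the step cap, and the batch size of 1 are assumptions, not the paper's choices; the hidden state of the `LSTMCell` is used as the output \(w\).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Normal

D, G, Z_WHAT_DIM, Z_WHERE_DIM, HIDDEN = 50, 20, 50, 3, 256  # assumed sizes

def stn(image, z_where):
    """Attend: crop and rescale a G x G glimpse out of the D x D image."""
    s = z_where[:, 0].abs().clamp(min=0.2)   # keep the scale positive
    theta = torch.zeros(z_where.size(0), 2, 3)
    theta[:, 0, 0] = s
    theta[:, 1, 1] = s
    theta[:, 0, 2] = z_where[:, 1]
    theta[:, 1, 2] = z_where[:, 2]
    grid = F.affine_grid(theta, (z_where.size(0), 1, G, G), align_corners=False)
    return F.grid_sample(image, grid, align_corners=False)

class InferenceNetwork(nn.Module):
    """Sketch of q_phi(z | x) for a single image x of shape (1, 1, D, D)."""
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTMCell(D * D + Z_WHERE_DIM + Z_WHAT_DIM, HIDDEN)  # R_phi
        self.presence = nn.Linear(HIDDEN, 1)                 # f_phi
        self.where_loc = nn.Linear(HIDDEN, Z_WHERE_DIM)      # mu_phi^where
        self.where_scale = nn.Linear(HIDDEN, Z_WHERE_DIM)    # sigma_phi^where
        self.what_loc = nn.Linear(G * G, Z_WHAT_DIM)         # mu_phi^what
        self.what_scale = nn.Linear(G * G, Z_WHAT_DIM)       # sigma_phi^what

    def forward(self, x, max_steps=3):
        x_flat = x.view(1, -1)
        h, c = torch.zeros(1, HIDDEN), torch.zeros(1, HIDDEN)
        z_where = torch.zeros(1, Z_WHERE_DIM)
        z_what = torch.zeros(1, Z_WHAT_DIM)
        zs = []
        # First RNN step with zero latents, as in the pseudocode.
        h, c = self.rnn(torch.cat([x_flat, z_where, z_what], dim=1), (h, c))
        w = h  # use the hidden state as the LSTM output w
        for _ in range(max_steps):       # cap on objects, for the sketch only
            if Bernoulli(torch.sigmoid(self.presence(w))).sample().item() == 0:
                break
            z_where = Normal(self.where_loc(w),
                             F.softplus(self.where_scale(w))).sample()
            g = stn(x, z_where).view(1, -1)          # attended glimpse
            z_what = Normal(self.what_loc(g),
                            F.softplus(self.what_scale(g))).sample()
            zs.append((z_where, z_what))
            # Feed the new latents back in before the next presence decision.
            h, c = self.rnn(torch.cat([x_flat, z_where, z_what], dim=1), (h, c))
            w = h
        return zs
```

Calling `InferenceNetwork()(x)` on an image tensor `x` of shape `(1, 1, D, D)` returns the sampled \((z^{\text{where}}, z^{\text{what}})\) pairs for one trace.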