Sequential VAE

Sequential VAEs model an observed sequence \(\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_T\) with a latent sequence \(\mathbf{z}_1, \mathbf{z}_2, \dots, \mathbf{z}_T\).

In this article, we use \(\mathbf{x}_{1:T}\) and \(\mathbf{z}_{1:T}\) to denote \(\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_T\) and \(\mathbf{z}_1, \mathbf{z}_2, \dots, \mathbf{z}_T\), respectively. Also, we use \(\mathbf{x}_{\neg t}\) to denote \(\mathbf{x}_1,\dots,\mathbf{x}_{t-1},\mathbf{x}_{t+1},\dots,\mathbf{x}_T\), and \(\mathbf{z}_{\neg t}\) likewise.

The Future Dependency Problem

In this section, we discuss the future dependency problem, i.e., the dependency of \(\mathbf{z}_t\) on \(\mathbf{x}_{t:T}\) that the variational distribution \(q_{\phi}(\mathbf{z}_{1:T}|\mathbf{x}_{1:T})\) should capture, using a simple hidden-state Markov chain model.

Suppose \(\mathbf{z}_{1:T}\) is a Markov chain that serves as a sequence of hidden states determining the observations \(\mathbf{x}_{1:T}\). Formally, the model can be written as: \[ \begin{align} p_{\lambda}(\mathbf{z}_{1:T}) &= \prod_{t=1}^T p_{\lambda}(\mathbf{z}_t|\mathbf{z}_{t-1}) \\ p_{\theta}(\mathbf{x}_{1:T}|\mathbf{z}_{1:T}) &= \prod_{t=1}^T p_{\theta}(\mathbf{x}_t|\mathbf{z}_t) \end{align} \] where \(p_{\lambda}(\mathbf{z}_1|\mathbf{z}_0)\) denotes the initial distribution \(p_{\lambda}(\mathbf{z}_1)\). This formulation can also be visualized as the following Bayesian network diagram:

Figure 1: Hidden State Markov Chain Model
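
As a concrete illustration, here is a minimal NumPy sketch of ancestral sampling from such a model, using a linear-Gaussian transition and emission; the matrices and noise scales are arbitrary illustrative choices, not part of the formulation above.

```python
import numpy as np

def sample_hidden_markov_chain(T, dz=2, dx=3, seed=0):
    """Ancestral sampling from a linear-Gaussian hidden-state Markov chain:
    z_t ~ N(A z_{t-1}, I),  x_t ~ N(C z_t, 0.1^2 I)."""
    rng = np.random.default_rng(seed)
    A = 0.9 * np.eye(dz)                # illustrative transition matrix
    C = rng.standard_normal((dx, dz))   # illustrative emission matrix
    z_prev = np.zeros(dz)               # convention: z_0 = 0
    zs, xs = [], []
    for _ in range(T):
        z_t = A @ z_prev + rng.standard_normal(dz)      # p(z_t | z_{t-1})
        x_t = C @ z_t + 0.1 * rng.standard_normal(dx)   # p(x_t | z_t)
        zs.append(z_t)
        xs.append(x_t)
        z_prev = z_t
    return np.stack(zs), np.stack(xs)

z_seq, x_seq = sample_hidden_markov_chain(T=10)
print(z_seq.shape, x_seq.shape)  # (10, 2) (10, 3)
```

Since every later observation is generated downstream of \(\mathbf{z}_t\), each of \(\mathbf{x}_t, \mathbf{x}_{t+1}, \dots, \mathbf{x}_T\) carries information about \(\mathbf{z}_t\), which is exactly the future dependency discussed next.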

Perhaps surprisingly, the true posterior \(p_{\theta}(\mathbf{z}_{1:T}|\mathbf{x}_{1:T})\) exhibits future dependency, as noted by Fraccaro et al. (2016). Using d-separation, one can show that: \[ p_{\theta}(\mathbf{z}_{1:T}|\mathbf{x}_{1:T}) = \prod_{t=1}^T p_{\theta}(\mathbf{z}_t|\mathbf{z}_{t-1}, \mathbf{x}_{t:T}) \] This conditional independence structure is also illustrated in the following diagram. Conditioning on \(\mathbf{z}_{t-1}\) blocks the dependency of \(\mathbf{z}_t\) on \(\mathbf{x}_{1:(t-1)}\) and \(\mathbf{z}_{1:(t-2)}\), but it does not block the dependence on the future observations \(\mathbf{x}_{(t+1):T}\).

Figure 2: d-separation illustration for \(\mathbf{z}_t\)
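
To see this more explicitly, expand the posterior with the chain rule and then drop the variables that are d-separated from \(\mathbf{z}_t\): \[ \begin{align} p_{\theta}(\mathbf{z}_{1:T}|\mathbf{x}_{1:T}) &= \prod_{t=1}^T p_{\theta}(\mathbf{z}_t|\mathbf{z}_{1:(t-1)}, \mathbf{x}_{1:T}) \\ &= \prod_{t=1}^T p_{\theta}(\mathbf{z}_t|\mathbf{z}_{t-1}, \mathbf{x}_{t:T}) \end{align} \] where the second equality holds because, given \(\mathbf{z}_{t-1}\), every path from \(\mathbf{z}_t\) to \(\mathbf{z}_{1:(t-2)}\) or \(\mathbf{x}_{1:(t-1)}\) passes through \(\mathbf{z}_{t-1}\) and is therefore blocked, while the paths to \(\mathbf{x}_{(t+1):T}\) through \(\mathbf{z}_{t+1}\) remain open.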

The future dependency may seem counter-intuitive at first glance, but it can be explained naturally from an information-theoretic perspective. Since the sequence \(\mathbf{x}_{t:T}\) is generated under the influence of \(\mathbf{z}_t\), i.e., the information of \(\mathbf{z}_t\) flows into \(\mathbf{x}_{t:T}\), it should not be surprising that knowing \(\mathbf{x}_{t:T}\) helps to infer \(\mathbf{z}_t\).

VRNN

Chung et al. (2015) proposed to embed a variational autoencoder into each step of an LSTM or GRU recurrent network, formalized as: \[ \begin{align} p_{\lambda}(\mathbf{z}_t|\mathbf{h}_{t-1}) &= \mathcal{N}(\mathbf{z}_t|\boldsymbol{\mu}_{\lambda,t}, \boldsymbol{\sigma}_{\lambda,t}^2) \\ p_{\theta}(\mathbf{x}_t|\mathbf{z}_t,\mathbf{h}_{t-1}) &= \mathcal{N}(\mathbf{x}_t| \boldsymbol{\mu}_{\theta,t}, \boldsymbol{\sigma}_{\theta,t}^2) \\ q_{\phi}(\mathbf{z}_t|\mathbf{x}_t,\mathbf{h}_{t-1}) &= \mathcal{N}(\mathbf{z}_t| \boldsymbol{\mu}_{\phi,t}, \boldsymbol{\sigma}_{\phi,t}^2) \\ \mathbf{h}_t &= f_{\text{rnn}}(f_{\mathbf{x}}(\mathbf{x}_t),f_{\mathbf{z}}(\mathbf{z}_t),\mathbf{h}_{t-1}) \\ [\boldsymbol{\mu}_{\lambda,t}, \boldsymbol{\sigma}_{\lambda,t}] &= f_{\text{prior}}(\mathbf{h}_{t-1}) \\ [\boldsymbol{\mu}_{\theta,t}, \boldsymbol{\sigma}_{\theta,t}] &= f_{\text{dec}}(f_{\mathbf{z}}(\mathbf{z}_t), \mathbf{h}_{t-1}) \\ [\boldsymbol{\mu}_{\phi,t}, \boldsymbol{\sigma}_{\phi,t}] &= f_{\text{enc}}(f_{\mathbf{x}}(\mathbf{x}_t), \mathbf{h}_{t-1}) \end{align} \] where \(f_{\mathbf{z}}\) and \(f_{\mathbf{x}}\) are feature networks shared by the encoder, the decoder and the recurrent transition, which are “crucial for learning complex sequences” according to Chung et al. (2015). \(f_{\text{rnn}}\) is the transition kernel of the recurrent network, and \(\mathbf{h}_{1:T}\) are its deterministic hidden states. \(f_{\text{prior}}\), \(f_{\text{enc}}\) and \(f_{\text{dec}}\) are the neural networks of the prior, the encoder and the decoder, respectively.
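
As a concrete (and deliberately simplified) illustration, the following PyTorch sketch implements one VRNN step under the equations above. The GRU cell for \(f_{\text{rnn}}\), the single-layer networks, and all layer sizes are assumptions made for illustration, not the authors' architecture.

```python
import torch
import torch.nn as nn

class VRNNCell(nn.Module):
    """One VRNN timestep: prior, encoder, decoder and recurrence (a sketch)."""
    def __init__(self, x_dim, z_dim, h_dim):
        super().__init__()
        self.f_x = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())  # feature net for x_t
        self.f_z = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())  # feature net for z_t
        self.f_prior = nn.Linear(h_dim, 2 * z_dim)      # [mu, log_sigma] from h_{t-1}
        self.f_enc = nn.Linear(2 * h_dim, 2 * z_dim)    # [mu, log_sigma] from (f_x(x_t), h_{t-1})
        self.f_dec = nn.Linear(2 * h_dim, 2 * x_dim)    # [mu, log_sigma] from (f_z(z_t), h_{t-1})
        self.f_rnn = nn.GRUCell(2 * h_dim, h_dim)       # h_t = f_rnn(f_x(x_t), f_z(z_t), h_{t-1})

    def forward(self, x_t, h_prev):
        phi_x = self.f_x(x_t)
        # Encoder q(z_t | x_t, h_{t-1}) with the reparameterization trick.
        mu_q, log_sig_q = self.f_enc(torch.cat([phi_x, h_prev], -1)).chunk(2, -1)
        z_t = mu_q + log_sig_q.exp() * torch.randn_like(mu_q)
        phi_z = self.f_z(z_t)
        # Prior p(z_t | h_{t-1}) and decoder p(x_t | z_t, h_{t-1}).
        mu_p, log_sig_p = self.f_prior(h_prev).chunk(2, -1)
        mu_x, log_sig_x = self.f_dec(torch.cat([phi_z, h_prev], -1)).chunk(2, -1)
        # Deterministic recurrence.
        h_t = self.f_rnn(torch.cat([phi_x, phi_z], -1), h_prev)
        return z_t, h_t, (mu_q, log_sig_q), (mu_p, log_sig_p), (mu_x, log_sig_x)
```

Unrolling the cell over \(t = 1, \dots, T\) with \(\mathbf{h}_0 = \mathbf{0}\), accumulating the Gaussian reconstruction term and the KL term between the encoder and the prior at each step, gives the per-timestep ELBO used for training.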

The overall architecture of VRNN can be illustrated as the following figure, given by Chung et al. (2015):

Figure 3: Overall architecture of VRNN

The following figure provides a clearer illustration of the dependencies among \(\mathbf{h}_{1:T}\), \(\mathbf{x}_{1:T}\) and \(\mathbf{z}_{1:T}\) in the generative part:

Figure 4: Dependency graph of VRNN in the generative part

The authors did not provide a theoretical analysis of the dependency structure in their variational posterior \(q_{\phi}(\mathbf{z}_{1:T}|\mathbf{x}_{1:T})\), but by d-separation the correct dependency for the true posterior should be: \[ p_{\theta}(\mathbf{z}_t|\mathbf{z}_{1:(t-1)},\mathbf{x}_{1:T},\mathbf{h}_{1:T}) = p_{\theta}(\mathbf{z}_t|\mathbf{h}_{t-1},\mathbf{x}_t,\mathbf{h}_t) \] Since \(\mathbf{h}_t\) is computed from \(\mathbf{z}_t\) only after \(\mathbf{z}_t\) has been sampled, this dependency of \(\mathbf{z}_t\) on \(\mathbf{h}_t\) makes posterior inference troublesome. Chung et al. (2015) simply neglected it, whereas Fraccaro et al. (2016) took it into account and proposed SRNN, which leads to a theoretically more sound factorization.

SRNN

Fraccaro et al. (2016) proposed to factorize \(\mathbf{z}_{1:T}\) as a state-space model, conditioned on the deterministic hidden states \(\mathbf{d}_{1:T}\) of a recurrent network and, potentially, the input \(\mathbf{u}_{1:T}\) of each time step. The observation \(\mathbf{x}_t\) at each time step is then assumed to depend only on \(\mathbf{d}_t\) and \(\mathbf{z}_t\). The overall architecture of SRNN is illustrated in the following figure (Fraccaro et al. 2016):

Figure 5: Overall architecture of SRNN

Generative Part

The initial states are chosen to be \(\mathbf{z}_0 = \mathbf{0}\) and \(\mathbf{d}_0 = \mathbf{0}\). The generative part is formulated as: \[ \begin{align} p_{\lambda}(\mathbf{z}_{1:T},\mathbf{d}_{1:T}|\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) &= \prod_{t=1}^T p_{\lambda}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_t) \, p_{\lambda}(\mathbf{d}_t|\mathbf{d}_{t-1},\mathbf{u}_t) \\ p_{\theta}(\mathbf{x}_{1:T}|\mathbf{z}_{1:T},\mathbf{d}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) &= \prod_{t=1}^T p_{\theta}(\mathbf{x}_t|\mathbf{z}_t,\mathbf{d}_t) \end{align} \]

\(p_{\lambda}(\mathbf{d}_t|\mathbf{d}_{t-1},\mathbf{u}_t)\) is a Dirac delta distribution, produced by a recurrent network \(\text{RNN}^{(p)}\): \[ \begin{align} p_{\lambda}(\mathbf{d}_t|\mathbf{d}_{t-1},\mathbf{u}_t) &= \delta(\mathbf{d}_t-\widetilde{\mathbf{d}}_t) \\ \widetilde{\mathbf{d}}_t &= \text{RNN}^{(p)}(\mathbf{d}_{t-1}, \mathbf{u}_t) \end{align} \] \(p_{\lambda}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_t)\) is the state-space transition, given by: \[ \begin{align} p_{\lambda}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_t) &= \mathcal{N}(\mathbf{z}_t| \boldsymbol{\mu}_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t), \boldsymbol{\sigma}^2_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t)) \\ \boldsymbol{\mu}_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t) &= \text{NN}^{(p)}_1(\mathbf{z}_{t-1},\mathbf{d}_t) \\ \log \boldsymbol{\sigma}_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t) &= \text{NN}^{(p)}_2(\mathbf{z}_{t-1},\mathbf{d}_t) \end{align} \]

\(p_{\theta}(\mathbf{x}_t|\mathbf{z}_t, \mathbf{d}_t)\) is derived by: \[ p_{\theta}(\mathbf{x}_t|\mathbf{z}_t,\mathbf{d}_t) = \mathcal{N}(\mathbf{x}_t| \boldsymbol{\mu}_{\theta}(\mathbf{z}_t,\mathbf{d}_t), \boldsymbol{\sigma}^2_{\theta}(\mathbf{z}_t,\mathbf{d}_t)) \]

where \(\boldsymbol{\mu}_{\theta}(\mathbf{z}_t,\mathbf{d}_t)\) and \(\boldsymbol{\sigma}_{\theta}(\mathbf{z}_t,\mathbf{d}_t)\) are parameterized in the same way as \(\boldsymbol{\mu}_{\lambda}\) and \(\boldsymbol{\sigma}_{\lambda}\) in \(p_{\lambda}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_t)\).
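
The following PyTorch sketch puts one generative step of SRNN together under the equations above; the GRU cell for \(\text{RNN}^{(p)}\), the single-layer networks standing in for \(\text{NN}^{(p)}_1\), \(\text{NN}^{(p)}_2\) and the emission, and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

class SRNNGenerativeStep(nn.Module):
    """One generative timestep of SRNN: d_t, then z_t, then x_t (a sketch)."""
    def __init__(self, x_dim, z_dim, d_dim, u_dim):
        super().__init__()
        self.rnn_p = nn.GRUCell(u_dim, d_dim)                # d_t = RNN^(p)(d_{t-1}, u_t)
        self.prior_mu = nn.Linear(z_dim + d_dim, z_dim)      # NN^(p)_1: prior mean
        self.prior_logsig = nn.Linear(z_dim + d_dim, z_dim)  # NN^(p)_2: prior log std
        self.dec_mu = nn.Linear(z_dim + d_dim, x_dim)        # emission mean
        self.dec_logsig = nn.Linear(z_dim + d_dim, x_dim)    # emission log std

    def forward(self, u_t, z_prev, d_prev):
        d_t = self.rnn_p(u_t, d_prev)                         # deterministic (Dirac) transition
        zd = torch.cat([z_prev, d_t], -1)
        mu_z, log_sig_z = self.prior_mu(zd), self.prior_logsig(zd)
        z_t = mu_z + log_sig_z.exp() * torch.randn_like(mu_z)  # z_t ~ p(z_t | z_{t-1}, d_t)
        xd = torch.cat([z_t, d_t], -1)
        mu_x, log_sig_x = self.dec_mu(xd), self.dec_logsig(xd)
        x_t = mu_x + log_sig_x.exp() * torch.randn_like(mu_x)  # x_t ~ p(x_t | z_t, d_t)
        return x_t, z_t, d_t
```

Iterating this step from \(\mathbf{z}_0 = \mathbf{0}\), \(\mathbf{d}_0 = \mathbf{0}\) over the input sequence \(\mathbf{u}_{1:T}\) generates a sample of \(\mathbf{x}_{1:T}\) by ancestral sampling.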

Variational Part

The variational approximate posterior can be factorized as: \[ \begin{align} q_{\phi}(\mathbf{z}_{1:T},\mathbf{d}_{1:T}|\mathbf{x}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) &= q_{\phi}(\mathbf{z}_{1:T}|\mathbf{d}_{1:T},\mathbf{x}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0)\, q_{\phi}(\mathbf{d}_{1:T}|\mathbf{x}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) \end{align} \]

Since \(p_{\lambda}(\mathbf{d}_t|\mathbf{d}_{t-1},\mathbf{u}_t)\) is a Dirac delta distribution in the generative part, Fraccaro et al. (2016) chose the second factor in the above equation to be: \[ q_{\phi}(\mathbf{d}_{1:T}|\mathbf{x}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) \equiv p_{\lambda}(\mathbf{d}_{1:T}|\mathbf{u}_{1:T},\mathbf{d}_0) = \prod_{t=1}^T p_{\lambda}(\mathbf{d}_t|\mathbf{d}_{t-1},\mathbf{u}_t) \] That is, the same recurrent network \(\text{RNN}^{(p)}\) produces the deterministic states \(\mathbf{d}_{1:T}\) in both the generative part and the variational part.

The first factor is factorized according to d-separation, as: \[ q_{\phi}(\mathbf{z}_{1:T}|\mathbf{d}_{1:T},\mathbf{x}_{1:T},\mathbf{u}_{1:T},\mathbf{z}_0,\mathbf{d}_0) = \prod_{t=1}^T q_{\phi}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_{t:T},\mathbf{x}_{t:T}) \] where \(q_{\phi}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_{t:T},\mathbf{x}_{t:T})\) is derived from a reverse recurrent network \(\text{RNN}^{(q)}\), whose hidden state is denoted \(\mathbf{a}_t\), as illustrated in Fig. 5. It is formalized as: \[ \begin{align} q_{\phi}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{d}_{t:T},\mathbf{x}_{t:T}) &= q_{\phi}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{a}_t) \\ q_{\phi}(\mathbf{z}_t|\mathbf{z}_{t-1},\mathbf{a}_t) &= \mathcal{N}(\mathbf{z}_t| \boldsymbol{\mu}_{\phi}(\mathbf{z}_{t-1},\mathbf{a}_t), \boldsymbol{\sigma}_{\phi}^2(\mathbf{z}_{t-1},\mathbf{a}_t)) \\ \boldsymbol{\mu}_{\phi}(\mathbf{z}_{t-1},\mathbf{a}_t) &= \boldsymbol{\mu}_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t) + \text{NN}^{(q)}_1(\mathbf{z}_{t-1},\mathbf{a}_t) \\ \log \boldsymbol{\sigma}_{\phi}(\mathbf{z}_{t-1},\mathbf{a}_t) &= \text{NN}^{(q)}_2(\mathbf{z}_{t-1},\mathbf{a}_t) \\ \mathbf{a}_t &= \text{RNN}^{(q)}(\mathbf{a}_{t+1},[\mathbf{d}_t,\mathbf{x}_t]) \end{align} \] Note that \(\text{NN}^{(q)}_1(\mathbf{z}_{t-1},\mathbf{a}_t)\) learns the residual \(\boldsymbol{\mu}_{\phi}(\mathbf{z}_{t-1},\mathbf{a}_t) - \boldsymbol{\mu}_{\lambda}(\mathbf{z}_{t-1},\mathbf{d}_t)\) rather than \(\boldsymbol{\mu}_{\phi}(\mathbf{z}_{t-1},\mathbf{a}_t)\) directly. Fraccaro et al. (2016) found that this residual parameterization leads to better performance.
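
A minimal PyTorch sketch of this inference network follows: a backward recurrence producing \(\mathbf{a}_{T}, \dots, \mathbf{a}_1\), then the residual parameterization of the posterior mean. The GRU cell and layer sizes are illustrative assumptions; prior_mu is assumed to be the same module as the generative \(\boldsymbol{\mu}_{\lambda}\) network (e.g., SRNNGenerativeStep.prior_mu in the sketch above), so the posterior mean is expressed as the prior mean plus a learned residual.

```python
import torch
import torch.nn as nn

class SRNNInference(nn.Module):
    """Backward RNN over (d_t, x_t) plus residual Gaussian posterior (a sketch)."""
    def __init__(self, x_dim, z_dim, d_dim, a_dim, prior_mu):
        super().__init__()
        self.rnn_q = nn.GRUCell(d_dim + x_dim, a_dim)  # a_t = RNN^(q)(a_{t+1}, [d_t, x_t])
        self.res_mu = nn.Linear(z_dim + a_dim, z_dim)   # NN^(q)_1: residual of the mean
        self.logsig = nn.Linear(z_dim + a_dim, z_dim)   # NN^(q)_2: posterior log std
        self.prior_mu = prior_mu                        # mu_lambda(z_{t-1}, d_t), shared with the prior

    def backward_states(self, d_seq, x_seq):
        # d_seq, x_seq: (T, batch, dim); run the recurrence from t = T down to 1.
        T, B = x_seq.shape[0], x_seq.shape[1]
        a_t = x_seq.new_zeros(B, self.rnn_q.hidden_size)
        a_seq = []
        for t in reversed(range(T)):
            a_t = self.rnn_q(torch.cat([d_seq[t], x_seq[t]], -1), a_t)
            a_seq.append(a_t)
        return torch.stack(list(reversed(a_seq)))       # (T, batch, a_dim), a_seq[t] = a_{t+1}

    def posterior(self, z_prev, d_t, a_t):
        za = torch.cat([z_prev, a_t], -1)
        mu_q = self.prior_mu(torch.cat([z_prev, d_t], -1)) + self.res_mu(za)  # residual mean
        log_sig_q = self.logsig(za)
        z_t = mu_q + log_sig_q.exp() * torch.randn_like(mu_q)
        return z_t, mu_q, log_sig_q
```

During training, \(\mathbf{d}_{1:T}\) comes from the shared forward \(\text{RNN}^{(p)}\) and \(\mathbf{a}_{1:T}\) from this backward recurrence; \(\mathbf{z}_{1:T}\) is then sampled step by step from the posterior to evaluate the training objective.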

References

Chung, Junyoung, Kyle Kastner, Laurent Dinh, Kratarth Goel, Aaron C. Courville, and Yoshua Bengio. 2015. “A Recurrent Latent Variable Model for Sequential Data.” In Advances in Neural Information Processing Systems, 2980–2988.

Fraccaro, Marco, Søren Kaae Sønderby, Ulrich Paquet, and Ole Winther. 2016. “Sequential Neural Models with Stochastic Layers.” In Advances in Neural Information Processing Systems, 2199–2207.