
Introduction

While many works have focused on learning representations with continuous features, discrete representations can be a more natural fit for several modalities, including language, speech, and images. Discrete representations also align well with complex reasoning, planning, and prediction (e.g., deciding to use an umbrella if it rains). Despite the challenges of using discrete latent variables in deep learning, the vector-quantized VAE (VQ-VAE) proposed by Van Den Oord et al. 2017 successfully combines the VAE framework with discrete latent representations. The model is built on the vector quantization (VQ) method, which maps $D$-dimensional encoder outputs onto a finite set of $K$ code vectors maintained in a codebook.

$\mathbf{Fig\ 1.}$ VQ-VAE framework (source: Van Den Oord et al. NeurIPS 2017)


VQ-VAE

Forward computation pipeline

Define a latent embedding space $e \in \mathbb{R}^{K \times D}$ where $K$ is the size of the discrete latent space (i.e., a $K$-way categorical), and $D$ is the dimensionality of each latent embedding vector $e_i$.

  1. The model takes an input $x$ and passes it through an encoder, producing the output $z_e (x)$.
  2. The discrete latent variable $z$ is then obtained by a nearest-neighbor look-up in the shared embedding space $e$, i.e. $k = \text{argmin} _ j \lVert z_e (x) - e_j \rVert_2$. Hence the posterior categorical distribution $q(z \vert x)$ is one-hot: $$ \begin{aligned} q(z=k \vert x) = \begin{cases} 1 & \quad \text{for } k = \text{argmin}_j \lVert z_e (x) - e_j \rVert_2 \\ 0 & \quad \text{otherwise} \end{cases} \end{aligned} $$
  3. The input to the decoder is the corresponding embedding vector $e_k$, i.e. $$ \begin{aligned} z_q (x) = e_k \quad \text{ where } \quad k = \text{argmin}_j \lVert z_e (x) - e_j \rVert_2 \end{aligned} $$

Although the above discussion uses a single random variable $z$ to represent the discrete latent variables, it generalizes directly to 1D vectors, 2D feature maps, and higher-order tensors of latents.
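In code, the nearest-neighbor look-up reduces to a pairwise-distance computation followed by an argmin. Below is a minimal PyTorch sketch; the `Quantizer` module, its hyperparameters, and the $(B, H, W, D)$ tensor layout are illustrative assumptions, not the authors' implementation.

```python
import torch
import torch.nn as nn


class Quantizer(nn.Module):
    """Maps encoder outputs to their nearest codebook vectors (a sketch, not the official code)."""

    def __init__(self, num_codes: int = 512, code_dim: int = 64):
        super().__init__()
        # Codebook e ∈ R^{K×D}: K code vectors of dimension D.
        self.codebook = nn.Embedding(num_codes, code_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

    def forward(self, z_e: torch.Tensor):
        # z_e: encoder output of shape (B, H, W, D), one D-dim vector per spatial position.
        flat = z_e.reshape(-1, z_e.shape[-1])                      # (B*H*W, D)
        # Pairwise L2 distances to every code vector, then nearest-neighbor indices.
        dists = torch.cdist(flat, self.codebook.weight)            # (B*H*W, K)
        indices = dists.argmin(dim=1)                              # k = argmin_j ||z_e(x) - e_j||_2
        z_q = self.codebook(indices).reshape(z_e.shape)            # decoder input z_q(x) = e_k
        return z_q, indices.reshape(z_e.shape[:-1])


# Example: a 2D grid of 8x8 latents, each quantized independently.
quantizer = Quantizer(num_codes=512, code_dim=64)
z_q, codes = quantizer(torch.randn(4, 8, 8, 64))
print(z_q.shape, codes.shape)   # torch.Size([4, 8, 8, 64]) torch.Size([4, 8, 8])
```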

Training Loss

Since the argmin operator is not differentiable, the gradient $\nabla_z L$ at the decoder input $z_q$ is copied to the encoder output $z_e$ (straight-through gradient estimation). The loss $L$ has 3 components: reconstruction loss, VQ loss, and commitment loss.

  • Reconstruction loss
    • optimizes both the decoder and the encoder
  • VQ loss
    • Note that the gradient bypasses the embedding space because of the argmin operator, so the reconstruction loss alone does not train the codebook.
    • The embedding space is therefore learned with an $L_2$ error between the embedding vectors and the encoder outputs.
    • Alternatively, the embedding vectors can be updated through an EMA (exponential moving average). Given a code vector $e_i$, let \(\{z_{i, j}\}_{j = 1}^{n_i}\) be the set of encoder output vectors that are closest to $e_i$, i.e. that are quantized to $e_i$. Then the loss term can be written as
    \[\begin{aligned} \sum_{j=1}^{n_i} \lVert z_{i, j} - e_i \rVert_2^2 \end{aligned}\]

    which has a closed-form minimizer: the average of the vectors in the set. This update cannot be performed directly when working with minibatches, so the statistics are instead accumulated online with a decay $\gamma \in [0, 1]$ (the authors found $\gamma = 0.99$ works well in practice):

    \[\begin{aligned} N_i^{(t)} & := \gamma N_{i}^{(t-1)} + (1-\gamma)n_i^{(t)} \\ m_i^{(t)} & := \gamma m_{i}^{(t-1)} + (1-\gamma) \sum_{j} z_{i, j}^{(t)} \\ e_i^{(t)} & := \frac{m_i^{(t)}}{N_i^{(t)}} \end{aligned}\]

    where $N_i$ and $m_i$ are the accumulated count and sum of the encoder outputs assigned to $e_i$, respectively (a minimal sketch of this update appears after this list).

  • Commitment loss
    • Since the encoder outputs are unconstrained, they can grow arbitrarily. Suppose $z = 0$ stands for cats and $z = 1$ for dogs, the embedding has just one dimension, and we initialize $e_1 = [-1]$ and $e_2 = [1]$. The decoder receives either $e_1$ or $e_2$ as input and naturally wants to push the latent representations of cats and dogs away from each other, e.g. toward $-\infty$ for cats and $\infty$ for dogs; the embeddings $e_1$ and $e_2$ then chase the encoder outputs through the VQ loss, so the embedding space drifts without bound.
    • The commitment loss encourages the encoder output to stay close to the chosen embedding and prevents it from fluctuating too frequently from one code vector to another.
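For concreteness, here is a minimal sketch of the EMA codebook update described above, assuming flattened encoder outputs, a hard nearest-code assignment, and plain (non-gradient) tensors; the function name and the small `eps` added for numerical stability are my own assumptions.

```python
import torch


def ema_update(codebook, ema_count, ema_sum, z_e_flat, indices, gamma=0.99, eps=1e-5):
    """One EMA update of the codebook (a sketch of the update rule above, not the authors' code).

    codebook:  (K, D) current code vectors e_i
    ema_count: (K,)   running counts N_i
    ema_sum:   (K, D) running sums m_i
    z_e_flat:  (N, D) encoder outputs in the minibatch
    indices:   (N,)   index of the nearest code for each encoder output
    """
    K = codebook.shape[0]
    one_hot = torch.nn.functional.one_hot(indices, K).type_as(z_e_flat)   # (N, K) hard assignments

    n_i = one_hot.sum(dim=0)            # how many vectors were assigned to each code in this batch
    sum_i = one_hot.t() @ z_e_flat      # sum of encoder outputs assigned to each code

    ema_count.mul_(gamma).add_((1 - gamma) * n_i)              # N_i ← γ N_i + (1-γ) n_i
    ema_sum.mul_(gamma).add_((1 - gamma) * sum_i)              # m_i ← γ m_i + (1-γ) Σ_j z_{i,j}
    codebook.copy_(ema_sum / (ema_count.unsqueeze(1) + eps))   # e_i ← m_i / N_i
    return codebook
```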

The following equation specifies the overall loss function.

\[\begin{aligned} L = \underbrace{-\log p(x \vert z_q (x))}_{\textrm{reconstruction loss}} + \underbrace{\lVert \text{sg}[z_e (x)] - e \rVert_2^2}_{\textrm{VQ loss}} + \underbrace{\beta \lVert z_e (x) - \text{sg}[e]\rVert_2^2}_{\textrm{commitment loss}} \end{aligned}\]

where $\text{sg}[\cdot]$ denotes the stop-gradient operator.
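Putting the pieces together, the straight-through estimator and the three loss terms can be sketched as follows. The `encoder` and `decoder` are placeholders, the MSE reconstruction corresponds to assuming a Gaussian decoder, and $\beta = 0.25$ follows the value used in the paper.

```python
import torch
import torch.nn.functional as F


def vq_vae_loss(x, encoder, decoder, codebook, beta=0.25):
    """Compute the VQ-VAE objective (a sketch; encoder, decoder, and codebook are placeholders).

    x:        input batch
    codebook: (K, D) embedding matrix e
    """
    z_e = encoder(x)                                   # encoder output, shape (B, ..., D)
    flat = z_e.reshape(-1, z_e.shape[-1])
    indices = torch.cdist(flat, codebook).argmin(dim=1)
    z_q = codebook[indices].reshape(z_e.shape)         # nearest code vectors

    # Straight-through estimator: forward pass uses z_q, backward copies gradients to z_e.
    z_q_st = z_e + (z_q - z_e).detach()
    x_hat = decoder(z_q_st)

    recon = F.mse_loss(x_hat, x)            # reconstruction term (-log p(x|z_q(x)) up to constants)
    vq = F.mse_loss(z_q, z_e.detach())      # ||sg[z_e(x)] - e||^2, moves the codebook toward z_e
    commit = F.mse_loss(z_e, z_q.detach())  # ||z_e(x) - sg[e]||^2, commitment term
    return recon + vq + beta * commit
```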

Prior distribution $p(z)$

While training the VQ-VAE, the prior distribution over the discrete latents $p(z)$ is kept uniform; afterwards it can be made autoregressive by letting each latent depend on the other $z$'s in the feature map.

First, train the usual VQ-VAE with the uniform prior. The figure below shows that the discrete codes are able to capture some regularities of the MNIST dataset. But since the distribution over the codebook $p(z)$ is simply uniform, this is not enough to generate novel and meaningful images: not only the embedding vectors themselves but also the relationships (dependencies) between them must be learned for the inference stage.

$\mathbf{Fig\ 2.}$ Trained embedding codes on MNIST dataset (source: Sayak Paul)


To overcome the limitations of the uniform prior over the codebook, we freeze the VQ-VAE and train an autoregressive generative model (e.g. PixelCNN for image tasks, WaveNet for audio tasks) to predict the latent codes, represented as embedding-space indices. The trained VQ-VAE decoder is then used to map the indices generated by the autoregressive model back into the data space. By coupling these representations with an autoregressive prior, VQ-VAE models can generate high-quality samples. Moreover, since the autoregressive model operates on the compressed discrete latents, this enables not only a substantial acceleration of training and sampling but also the utilization of the autoregressive model's capability to capture global structure.
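As a rough illustration of the two-stage sampling pipeline, the sketch below replaces the trained PixelCNN prior with a dummy uniform sampler and uses a made-up toy decoder, just to show the data flow from sampled indices through the codebook to the decoder.

```python
import torch
import torch.nn as nn

K, D, H, W = 512, 64, 8, 8                     # codebook size, code dim, latent grid size

codebook = nn.Embedding(K, D)                  # frozen codebook from the trained VQ-VAE
decoder = nn.Sequential(                       # placeholder decoder: 8x8 latents -> 32x32 image
    nn.ConvTranspose2d(D, 64, 4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
)

# In the real pipeline these indices come from the trained autoregressive prior
# (e.g. PixelCNN); here a uniform sampler stands in just to show the shapes.
indices = torch.randint(0, K, (1, H, W))

z_q = codebook(indices).permute(0, 3, 1, 2)    # (1, H, W, D) -> (1, D, H, W)
x = decoder(z_q)                               # generated sample
print(x.shape)                                 # torch.Size([1, 3, 32, 32])
```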

$\mathbf{Fig\ 3.}$ Samples (128x128) from a VQ-VAE with a PixelCNN prior trained on ImageNet images. (source: [1])





Reference

[1] Van Den Oord et al. "Neural Discrete Representation Learning." NeurIPS 2017.
[2] Sayak Paul, "Vector-Quantized Variational Autoencoders," Keras.
