Background on VAEs

Variational Autoencoders (VAEs) model the data distribution $p(\mathbf{x})$ by introducing latent random variables. Intuitively, a VAE encodes the input into a compressed representation in the latent space, and by forcing a faithful reconstruction, the model hopefully captures some structure of the data distribution.

Formally, we write
$$p_\theta(\mathbf{x})=\int_{\mathbf{z}} p_\theta(\mathbf{x}|\mathbf{z})\,p_\theta(\mathbf{z})\,d\mathbf{z}$$

where $p_\theta(\mathbf{z})$ is the prior distribution, usually assumed to be a standard normal distribution, $p_\theta(\mathbf{x}|\mathbf{z})$ is the likelihood (the probabilistic decoder), and $p_\theta(\mathbf{z}|\mathbf{x})$ is the posterior.

Evaluating this posterior, however, requires the integral above, which is intractable when $\mathbf{z}$ is high dimensional. We therefore introduce a variational approximation $q_\phi(\mathbf{z}|\mathbf{x})$ (the probabilistic encoder).

The architecture can be summarized as follows:

VAE architecture
Image credits: Wikipedia, Variational Autoencoder.

Naturally, to train the model we want to maximize the likelihood of the dataset, $p_\theta(\mathbf{x})$. We achieve this by maximizing a lower bound on it.

Evidence Lower Bound (ELBO)

$$\begin{align*} \log p_\theta(\mathbf{x}) &\geq \log p_\theta(\mathbf{x})-\overbrace{D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z}|\mathbf{x})\right)}^{\geq 0} \\ &=\underbrace{\underbrace{\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]}_\text{Reconstruction} - \underbrace{D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)}_\text{Regularization}}_\text{ELBO}\end{align*}$$

In practice, we use gradient descent to minimize the negative ELBO, called the VAE loss.

$$\mathcal{L}_{\theta,\phi}(\mathbf{x})=-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]+D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)$$
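To make the two terms concrete, here is a minimal PyTorch sketch of this loss, assuming a diagonal-Gaussian encoder that returns `(mu, logvar)` and a Bernoulli decoder that returns logits; `encoder` and `decoder` are placeholder modules, not tied to any particular codebase.

```python
import torch
import torch.nn.functional as F

def vae_loss(x, encoder, decoder):
    # q_phi(z|x) = N(mu, diag(exp(logvar))): the encoder returns its parameters.
    mu, logvar = encoder(x)

    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
    # so the sampling step stays differentiable with respect to phi.
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    # Reconstruction term: -E_q[log p_theta(x|z)], here with a Bernoulli decoder,
    # summed over pixels and averaged over the batch.
    recon = F.binary_cross_entropy_with_logits(decoder(z), x, reduction="none")
    recon = recon.flatten(1).sum(dim=1).mean()

    # KL(q_phi(z|x) || N(0, I)) in closed form for diagonal Gaussians.
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).flatten(1).sum(dim=1).mean()

    return recon + kl, recon, kl
```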

What is Desired? Disentanglement!

Disentanglement = Independence + Semantics

We hope that unsupervised learning can produce representations that are meaningful to human beings. One specific criterion is whether each dimension of the latent space carries an atomic meaning, capturing a single concept from the dataset.

  • Unsupervised learning of a disentangled posterior distribution over the underlying generative factors of sensory data is a major challenge in AI research 1 2.
  • Motivations include discovering independent components, controllable sample generation, and generalization/robustness.
  • Facilitates interpretable decision making and controlled transfer.

The following figure from Ricky Chen’s talk demonstrates clearly what we want. Different sample points in the latent space correspond to faces of different genders, ages, and so on. However, moving along the axis indicated by the arrow controls a single factor: whether the generated face wears sunglasses. Such disentanglement in the space means we can reliably predict how the generated images will change as we traverse the latent space.

Axis-aligned traversal in the representation space and Global interpretability in data space. Image Credits to Ricky Chen’s talk at NIPS 2018

On the other hand, the vanilla VAE objective focuses only on reconstruction; as the examples on the right show, traversing along a latent axis does not produce a smooth, single-factor change.

Traversal of the rotational latent dimension 3.

Datasets for Disentanglement

Most of these datasets are specifically constructed so that the intended disentanglement factors are known. Take dSprites as an example: its factors include *posX*, *posY*, *rotation*, *shape*, and *scale*.

Common datasets for disentanglement task
Common datasets used in the disentanglement task4.

Related Works

DC-IGN

An obvious first attempt is to have the designer attach meanings to the latent dimensions. The Deep Convolutional Inverse Graphics Network (DC-IGN) 5 is a VAE-like model with a specially designed training procedure that enforces a chosen latent structure.

DC-IGN architecture
DC-IGN architecture 5.
DC-IGN latent structure
DC-IGN latent structure 5

In short, to enforce this structure they use a modified training procedure: first select a latent dimension corresponding to a factor, then form a minibatch in which only that factor changes. The outputs of the other latent dimensions are clamped to their batch average, so the gradient signal for those dimensions is mixed and the network is forced to capture the within-batch changes in the selected dimension.
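As a rough illustration of this clamping idea (not the authors' code; the DC-IGN paper also manipulates the backward gradients of the inactive dimensions, which is omitted here), a sketch with placeholder `encoder` and `decoder` modules might look like:

```python
import torch
import torch.nn.functional as F

def dc_ign_clamp_step(x_batch, encoder, decoder, active_dim):
    """One loss evaluation on a minibatch in which only one generative factor varies.

    `active_dim` is the latent index assigned to that factor; every other latent
    dimension is clamped to its minibatch average, so only the active dimension
    can explain the within-batch variation.
    """
    mu, logvar = encoder(x_batch)
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    # Clamp the inactive dimensions to the (detached) batch mean.
    z_clamped = z.mean(dim=0, keepdim=True).detach().expand_as(z).clone()
    z_clamped[:, active_dim] = z[:, active_dim]

    return F.mse_loss(decoder(z_clamped), x_batch)
```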

DC-IGN training
DC-IGN training 5

InfoGAN

The GAN formulation uses a simple factored continuous noise vector $\mathbf{z}$ as input, but imposes no restrictions on how the generator may use it, so the generator may use it in a highly entangled way.

InfoGAN6, however, addresses this as follows:

  • It uses a set of structured latent variables $\mathbf{c}=(c_1,\dots,c_L)$ and assumes a factored prior $p(\mathbf{c})=\prod_{i=1}^L p(c_i)$.
  • The generator becomes $G(\mathbf{z}, \mathbf{c})$.
  • With no further constraints, the generator could simply ignore $\mathbf{c}$, i.e. $p_G(\mathbf{x}|\mathbf{c})=p_G(\mathbf{x})$.
  • Therefore, there should be high mutual information between the latent code $\mathbf{c}$ and the generator distribution: $I(\mathbf{c};G(\mathbf{z},\mathbf{c}))$ should be high (a sketch of the standard variational lower bound on this term follows the list).
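Since $I(\mathbf{c};G(\mathbf{z},\mathbf{c}))$ is intractable in general, InfoGAN maximizes the variational lower bound $L_I=\mathbb{E}\left[\log Q(\mathbf{c}|G(\mathbf{z},\mathbf{c}))\right]+H(\mathbf{c})$ using an auxiliary network $Q(\mathbf{c}|\mathbf{x})$. Below is a minimal sketch for a categorical code; `generator` and `q_network` are assumed placeholder modules.

```python
import torch
import torch.nn.functional as F

def infogan_mi_term(z, c_index, generator, q_network, num_categories):
    """Variational lower bound on I(c; G(z, c)) for a categorical code c.

    `q_network` outputs logits approximating Q(c|x); maximizing
    E[log Q(c|G(z, c))] (plus the constant entropy H(c)) lower-bounds the MI.
    """
    c_onehot = F.one_hot(c_index, num_categories).float()
    x_fake = generator(z, c_onehot)
    logits = q_network(x_fake)
    # Negative cross-entropy = Monte Carlo estimate of E[log Q(c|x)].
    return -F.cross_entropy(logits, c_index)
```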

An Attempt: $\beta$-VAE

ELBO from Another Perspective

Quick Mention on Karush-Kuhn-Tucker (KKT) Conditions

Consider a non-linear programming problem:
$$\begin{align*} &\text{Optimize } &&f(\mathbf{x}) \\ &\text{subject to } &&g_i(\mathbf{x})\leq 0,\quad i=1,\dots,m \\ & &&h_j(\mathbf{x})=0,\quad j=1,\dots,r \end{align*}$$

Then, we can form the Lagrangian function:
$$L(\mathbf{x},\boldsymbol{\mu},\boldsymbol{\lambda})=f(\mathbf{x})+\boldsymbol{\mu}^T\left[g_1(\mathbf{x}),\dots,g_m(\mathbf{x})\right]^T+\boldsymbol{\lambda}^T\left[h_1(\mathbf{x}),\dots,h_r(\mathbf{x})\right]^T$$

If $(\mathbf{x}^*,\boldsymbol{\mu}^*,\boldsymbol{\lambda}^*)$ solves the problem, then the Karush-Kuhn-Tucker conditions hold (a toy worked example follows the list):

  • Stationarity: $\nabla f(\mathbf{x}^*)+\sum_{i=1}^m \mu_i\nabla g_i(\mathbf{x}^*)+\sum_{j=1}^r\lambda_j\nabla h_j(\mathbf{x}^*)=0$ (for minimization).
  • Primal Feasibility: $g_i(\mathbf{x}^*)\leq 0,\ i=1,\dots,m$ and $h_j(\mathbf{x}^*)=0,\ j=1,\dots,r$.
  • Dual Feasibility: $\mu_i\geq 0,\ i=1,\dots,m$.
  • Complementary Slackness: $\sum_{i=1}^m \mu_i g_i(\mathbf{x}^*)=0$.
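As a quick sanity check of these conditions, consider minimizing $f(x)=x^2$ subject to the single constraint $g(x)=1-x\leq 0$, with Lagrangian $L(x,\mu)=x^2+\mu(1-x)$:

$$\begin{aligned} &\text{Stationarity: } 2x^*-\mu^*=0 \;\Rightarrow\; \mu^*=2x^*.\\ &\text{Complementary slackness: } \mu^*(1-x^*)=0.\ \text{If } \mu^*=0 \text{ then } x^*=0, \text{ violating } 1-x^*\leq 0,\\ &\text{so } 1-x^*=0,\ \text{giving } x^*=1 \text{ and } \mu^*=2\geq 0 \text{ (dual feasibility holds).} \end{aligned}$$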

If we take a look at the VAE loss again
$$\theta,\phi=\underset{\theta,\phi}{\arg\min}\left\{-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]+D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)\right\}$$

We can formulate it as a constrained optimization problem:

Optimization Problem from ELBO

$$\min_{\theta,\phi}-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]\quad\text{subject to}\quad D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)<\epsilon$$

Rewriting it as a Lagrangian under KKT conditions, we have
$$\mathcal{F}(\theta, \phi, \beta;\mathbf{x},\mathbf{z})=-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]+\beta\left(D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)-\epsilon\right)$$

Since $\beta,\epsilon\geq 0$ (by dual feasibility and the non-negative tolerance), the $-\beta\epsilon$ term is non-positive, so dropping it gives an upper bound on $\mathcal{F}$: the $\beta$-VAE loss, which we minimize instead.
$$\mathcal{F}(\theta, \phi, \beta;\mathbf{x},\mathbf{z})\leq \mathcal{L}(\theta, \phi, \beta;\mathbf{x},\mathbf{z})=-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]+\beta D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)$$

$\beta$-VAE Loss

$$\mathcal{L}(\theta, \phi, \beta;\mathbf{x},\mathbf{z})=-\mathbb{E}_{\mathbf{z}\sim q_\phi(\mathbf{z}|\mathbf{x})}\left[\log p_\theta(\mathbf{x}|\mathbf{z})\right]+\beta D_\text{KL}\left(q_\phi(\mathbf{z}|\mathbf{x})\|p_\theta(\mathbf{z})\right)$$

Observations

  • Setting $\beta=1$ corresponds to the original VAE formulation.

  • Setting $\beta>1$ puts a stronger constraint on the latent bottleneck:

    • Limiting the capacity of $\mathbf{z}$ while trying to maximize the log-likelihood should encourage the model to learn a more efficient representation.
    • A higher value of $\beta$ should encourage conditional independence in $q_\phi(\mathbf{z}|\mathbf{x})$, because more weight is put on the $D_\text{KL}$ term.
  • Disentangled representations emerge when the right balance is found between reconstruction accuracy and latent capacity restriction.

    • This creates a trade-off between reconstruction fidelity and the quality of the disentanglement.
  • Note: in real implementations, $\beta$ is usually annealed over training steps, from 0 up to its target value. The intuition behind this warm-up is to first let the network learn to reconstruct; a minimal schedule is sketched below.
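For illustration, a linear warm-up might look like the following; the schedule shape and the `warmup_steps` value are arbitrary choices rather than anything prescribed by the paper, and `recon`/`kl` are the two terms from the earlier VAE loss sketch.

```python
def beta_schedule(step, beta_max, warmup_steps=10_000):
    """Linearly anneal beta from 0 to beta_max over the first warmup_steps steps."""
    return beta_max * min(1.0, step / warmup_steps)

# Inside the training loop, the beta-VAE loss at the current step is then:
#   loss = recon + beta_schedule(step, beta_max) * kl
```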

Measuring Disentanglement - Higgins’ Metric

The basic idea for measuring the quality of disentanglement is to generate pairs of data points in which one factor is fixed while the others are sampled randomly. A classifier is then trained on the difference between their latent representations to predict which factor was fixed, and the classifier accuracy is reported as the disentanglement score.
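A sketch of how one training example for this classifier could be constructed is shown below; `sample_factors`, `sample_observations`, and `encode` are assumed interfaces to the dataset's ground-truth simulator and the trained encoder, and the exact protocol in the paper differs in some details.

```python
import numpy as np

def higgins_metric_example(sample_factors, sample_observations, encode, n_pairs=64):
    """Build one (feature, label) pair for the disentanglement classifier.

    Fix one ground-truth factor k, sample `n_pairs` pairs of factor vectors that
    agree on factor k but are otherwise random, encode both observations, and
    average the absolute differences of their latent codes. A simple (linear)
    classifier is then trained to predict k from this averaged difference.
    """
    num_factors = sample_factors(1).shape[1]
    k = np.random.randint(num_factors)

    v1 = sample_factors(n_pairs)
    v2 = sample_factors(n_pairs)
    v2[:, k] = v1[:, k]                     # the fixed factor

    z_diff = np.abs(encode(sample_observations(v1)) - encode(sample_observations(v2)))
    return z_diff.mean(axis=0), k           # classifier input and target label
```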

disentanglement score 1
Image from 3.

Results

betaVAE table
Results from $\beta$-VAE3.
betaVAE results
Results from $\beta$-VAE3.
betaVAE results
Results from $\beta$-VAE3.
betaVAE results
Results from $\beta$-VAE3.
betaVAE results
Results from $\beta$-VAE3.

The Effect of Tuning $\beta$

  • $\beta$ is a mixing coefficient that weighs the gradient magnitudes of the reconstruction term against those of the prior-matching term, so it is natural to analyze a normalized $\beta$ that accounts for the latent dimension $M$ and the input dimension $N$: $\beta_\text{norm}=\frac{\beta M}{N}$.
  • If $\beta$ is too low or too high, the model learns an entangled representation, due to either too much or too little capacity in the latent $\mathbf{z}$ bottleneck.
  • Well-disentangled representations often come with blurrier reconstructions; in general, however, $\beta>1$ is necessary to achieve good disentanglement.
beta effect.
Positive correlation is present between the size of $\mathbf{z}$ and the optimal normalised values of $\beta$ for disentangled factor learning, for a fixed $\beta$-VAE architecture. Orange approximately corresponds to unnormalized $\beta=1$ 3.

How Does It Work?

In short, this is not clear from $\beta$-VAE itself. We therefore need to investigate why a large $\beta$ penalizing the KL term has this effect.

Decomposing the ELBO Further

Quick Mention on Mutual Information (MI)

Let $(X,Y)$ be a pair of random variables over the space $\mathcal{X}\times\mathcal{Y}$. Then their mutual information can be written as

  1. $I(X;Y)=D_\text{KL}(p(X,Y)\|p(X)p(Y))$
  2. $I(X;Y)=\mathbb{E}_X\left[D_\text{KL}\left(p(Y|X)\|p(Y)\right)\right]$

$I(X;Y)$ intuitively measures how much you can infer about one of the random variables given knowledge of the other. $I(X;Y)=0$ corresponds to independence: nothing about one variable can be inferred from the other.
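As a tiny numerical illustration (not tied to any library used in this post), the discrete definition can be computed directly from a joint probability table:

```python
import numpy as np

def mutual_information(p_xy):
    """I(X;Y) = KL(p(x,y) || p(x)p(y)) for a discrete joint distribution table."""
    p_x = p_xy.sum(axis=1, keepdims=True)
    p_y = p_xy.sum(axis=0, keepdims=True)
    mask = p_xy > 0
    return float((p_xy[mask] * np.log(p_xy[mask] / (p_x @ p_y)[mask])).sum())

print(mutual_information(np.array([[0.5, 0.0], [0.0, 0.5]])))      # fully dependent: log 2 ~ 0.693 nats
print(mutual_information(np.array([[0.25, 0.25], [0.25, 0.25]])))  # independent: 0.0
```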

TC-Decomposition

Define a uniform random variable $n$ on $\{1,2,\dots,N\}$ indexing the data points. Denote $q(\mathbf{z}|n)=q(\mathbf{z}|x_n)$ and $q(\mathbf{z}, n)=q(\mathbf{z}|n)p(n)=q(\mathbf{z}|n)\frac{1}{N}$. The *aggregated posterior* is $q(\mathbf{z})=\sum_{n=1}^N q(\mathbf{z}|n)p(n)$. Then we can decompose the (dataset-averaged) regularization term of the ELBO as

$$\begin{align*} \dfrac{1}{N}\sum_{n=1}^N D_\text{KL}\left(q(\mathbf{z}|x_n)\|p(\mathbf{z})\right) &= \mathbb{E}_{p(n)}\left[D_\text{KL}\left(q(\mathbf{z}|n)\|p(\mathbf{z})\right)\right] \\ &=\underbrace{D_\text{KL}\left(q(\mathbf{z},n)\|q(\mathbf{z})p(n)\right)}_\text{Index-Code MI} + \underbrace{D_\text{KL}\Big(q(\mathbf{z})\,\Big\|\,\prod_j q(z_j)\Big)}_\text{Total Correlation} + \underbrace{\sum_j D_\text{KL}\left(q(z_j)\|p(z_j)\right)}_\text{Dimension-wise KL} \end{align*}$$

  • The index-code MI is the mutual information $I_q(\mathbf{z};n)$ between the data index and the latent code. It has been argued that higher mutual information can lead to better disentanglement, but recent investigations also claim that penalizing it encourages compact and disentangled representations.
  • The total correlation is one of several generalizations of mutual information to more than two variables. It measures the dependence among the latent dimensions and is claimed to be the main source of disentanglement.
  • The dimension-wise KL divergence mainly prevents individual latent dimensions from deviating too far from their priors; it acts as a complexity penalty.

$\beta$-TCVAE Loss

$$\mathcal{L}=-\mathbb{E}_{q(\mathbf{z}|n)p(n)}\left[\log p(n|\mathbf{z})\right]+\alpha I_q(\mathbf{z};n) + \beta D_\text{KL}\Big(q(\mathbf{z})\,\Big\|\,\prod_j q(z_j)\Big) +\gamma \sum_j D_\text{KL}\left(q(z_j)\|p(z_j)\right)$$

  • Ablation studies show that tuning $\beta$ leads to the best results. The proposed model uses $\alpha=\gamma=1$, which gives the same form of objective as FactorVAE7.
  • This provides a better trade-off between density estimation and disentanglement. Unlike $\beta$-VAE, a higher value of $\beta$ does not over-penalize the mutual information term.
alpha term
Ablation study shows that setting $\alpha$ to zero gives no clear improvement4.

Estimating the Density from a Minibatch

The decomposition requires evaluating the density $q(\mathbf{z})=\mathbb{E}_{p(n)}\left[q(\mathbf{z}|n)\right]$, which depends on the entire dataset. A naive Monte Carlo approximation is unlikely to work well, so we use weighted sampling. Given a minibatch of samples $\{n_1,\dots,n_M\}$, we use the estimator

$$\mathbb{E}_{q(\mathbf{z})}\left[\log q(\mathbf{z})\right]\approx \dfrac{1}{M}\sum_{i=1}^M \left[\log \dfrac{1}{NM}\sum_{j=1}^M q(\mathbf{z}(n_i)|n_j)\right]$$

where $\mathbf{z}(n_i)\sim q(\mathbf{z}|n_i)$.
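A sketch of this minibatch-weighted-sampling estimator, applied to the total correlation term under the usual diagonal-Gaussian posterior, might look as follows; the helper names are mine, and `dataset_size` is the $N$ above.

```python
import math
import torch

def log_gaussian(z, mu, logvar):
    """Elementwise log N(z; mu, diag(exp(logvar)))."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (z - mu) ** 2 / logvar.exp())

def total_correlation_mws(z, mu, logvar, dataset_size):
    """Minibatch-weighted-sampling estimate of KL(q(z) || prod_j q(z_j)).

    z, mu, logvar: (M, D) tensors for a minibatch, with z sampled from q(z|n_i).
    """
    M = z.shape[0]
    # log q(z_i | n_j) for every (i, j) pair: shape (M, M, D).
    log_qz_pairs = log_gaussian(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))

    # log q(z_i) ~= logsumexp_j [ sum_d log q(z_id | n_j) ] - log(N * M)
    log_qz = torch.logsumexp(log_qz_pairs.sum(dim=2), dim=1) - math.log(dataset_size * M)

    # log prod_d q(z_id) ~= sum_d [ logsumexp_j log q(z_id | n_j) - log(N * M) ]
    log_qz_product = (torch.logsumexp(log_qz_pairs, dim=1)
                      - math.log(dataset_size * M)).sum(dim=1)

    return (log_qz - log_qz_product).mean()
```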

Measuring Disentanglement - Mutual Information Gap (MIG)

Higgins’ metric uses an extra classifier, which introduces hyperparameters and additional training time. In addition, it cannot measure axis alignment. Is there a metric based only on the distribution of the factors and the latent variables?

The Mutual Information Gap (MIG) is introduced to solve these problems. The mutual information between a latent variable $z_j$ and a ground-truth factor $v_k$ is estimated through the joint $q(z_j,v_k)=\sum_{n=1}^N p(v_k)p(n|v_k)q(z_j|n)$. A higher mutual information implies that $z_j$ contains a lot of information about $v_k$, and the MI is maximal if there exists a deterministic, invertible relationship between $z_j$ and $v_k$.

  1. For each $v_k$, take the latent dimensions $z_j$ and $z_l$ that have the highest and the second highest mutual information with $v_k$.
  2. $\text{MIG}=\frac{1}{K}\sum_{k=1}^K \frac{1}{H(v_k)}\left(I(z_j;v_k)-I(z_l;v_k)\right)$

Averaging over the $K$ factors and normalizing by the entropy $H(v_k)$ yields a value between 0 and 1; $\text{MIG}\rightarrow 1$ implies good disentanglement. A short sketch of this final computation is given below.
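Given a matrix of (estimated) mutual informations and the factor entropies, the final score is a few lines of numpy; estimating the MI matrix itself, via the joint $q(z_j, v_k)$ above, is the expensive part and is assumed to be done already.

```python
import numpy as np

def mig_score(mi_matrix, factor_entropies):
    """Mutual Information Gap from mi_matrix[j, k] = I(z_j; v_k) and H(v_k)."""
    sorted_mi = np.sort(mi_matrix, axis=0)[::-1]             # descending over latents
    gaps = (sorted_mi[0] - sorted_mi[1]) / factor_entropies  # normalized gap per factor
    return float(gaps.mean())
```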

MIG 1
Joint distribution between latent variables and ground truth factors. Image Credits to Ricky Chen’s talk at NIPS 2018
MIG 2
Mutual information between latent variables and ground truth factors. Image Credits to Ricky Chen’s talk at NIPS 2018

Results

TC results 1
Results from $\beta$-TCVAE4.
TC results 2
Results from $\beta$-TCVAE4.
TC results 3
Results from $\beta$-TCVAE4.

Conclusion

There have been many efforts across different machine learning communities to produce interpretable artificial intelligence systems. Unsupervised learning is a setting where enforcing interpretability and independence of representations is particularly hard. Through the explorations and attempts surveyed here, however, we have gained a better understanding of the objective (the ELBO) and the optimization process, and we now have many impressive results in which the underlying factors are disentangled.


  1. Yoshua Bengio, Aaron Courville, and Pascal Vincent. Representation learning: A review and new perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(8):1798–1828, 2013. ↩︎

  2. Brenden M Lake, Tomer D Ullman, Joshua B Tenenbaum, and Samuel J Gershman. Building machines that learn and think like people. Behavioral and brain sciences, 40, 2017. ↩︎

  3. Irina Higgins, Loic Matthey, Arka Pal, Christopher Burgess, Xavier Glorot, Matthew Botvinick, Shakir Mohamed, and Alexander Lerchner. beta-VAE: Learning basic visual concepts with a constrained variational framework. 2016. ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  4. Ricky TQ Chen, Xuechen Li, Roger Grosse, and David Duvenaud. Isolating sources of disentanglement in variational autoencoders. arXiv preprint arXiv:1802.04942, 2018. ↩︎ ↩︎ ↩︎ ↩︎ ↩︎

  5. Tejas D Kulkarni, Will Whitney, Pushmeet Kohli, and Joshua B Tenenbaum. Deep convolutional inverse graphics network. arXiv preprint arXiv:1503.03167, 2015. ↩︎ ↩︎ ↩︎ ↩︎

  6. Xi Chen, Yan Duan, Rein Houthooft, John Schulman, Ilya Sutskever, and Pieter Abbeel. Infogan: Interpretable representation learning by information maximizing generative adversarial nets. In Proceedings of the 30th International Conference on Neural Information Processing Systems, pages 2180–2188, 2016. ↩︎

  7. Hyunjik Kim and Andriy Mnih. Disentangling by factorising. In International Conference on Machine Learning, pages 2649–2658. PMLR, 2018. ↩︎