GAN3 Wasserstein GANs
Xiaoyang Bai (cne2@illinois.edu)
Introduction
If we view GAN training as a distribution matching process, that is, training a generated distribution $\mathbb{P}_g$ to match the ground-truth data distribution $\mathbb{P}_r$, then choosing an appropriate distance between distributions becomes the core problem.
The original GAN design (henceforth vanilla GAN) [1] minimizes an approximation of the Jensen-Shannon divergence (JSD). Alternatives include:
- The Total Variation (TV) distance
- The Kullback-Leibler divergence (KLD)
- The Earth-Mover’s distance (EMD), also called Wasserstein-1
In this blog post, we investigate these distances and look into Wasserstein GAN (WGAN) [2], which replaces the vanilla discriminator criterion with EMD. After that, we explore WGAN-GP [3], an improved version of WGAN that supports higher-capacity models and has more stable training dynamics.
We shall see that WGAN and WGAN-GP have several advantages over vanilla GAN: reduced mode collapse, fewer vanishing gradients, and more interpretable discriminator losses.
What Is Earth Mover’s Distance?
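First, we formalize each of these distances for distributions $\mathbb{P}_r$ and $\mathbb{P}_g$ on a space $\mathcal{X}$:
TV distance is calculated as $\delta(\mathbb{P}_r,\mathbb{P}_g)=\sup_{A}|\mathbb{P}_r(A)-\mathbb{P}_g(A)|$, where the supremum runs over all measurable sets $A\subseteq\mathcal{X}$.
KLD is defined as $KL(\mathbb{P}_r\|\mathbb{P}_g)=\int\log\left(\frac{P_r(x)}{P_g(x)}\right)P_r(x)\,d\mu(x)$, where $P_r$ and $P_g$ are the densities of $\mathbb{P}_r$ and $\mathbb{P}_g$ with respect to a common measure $\mu$ on $\mathcal{X}$.
JSD is defined as $JS(\mathbb{P}_r,\mathbb{P}_g)=\frac{1}{2}KL(\mathbb{P}_r\|\mathbb{P}_m)+\frac{1}{2}KL(\mathbb{P}_g\|\mathbb{P}_m)$, where $\mathbb{P}_m=(\mathbb{P}_r+\mathbb{P}_g)/2$.
EMD is defined as $W(\mathbb{P}_r,\mathbb{P}_g)=\inf_{\gamma\in\Pi(\mathbb{P}_r,\mathbb{P}_g)}\mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]$, where $\Pi(\mathbb{P}_r,\mathbb{P}_g)$ is the set of all joint distributions $\gamma(x,y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.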
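A simple way to understand EMD is to imagine two piles of earth shaped like $\mathbb{P}_r$ and $\mathbb{P}_g$: EMD is the minimal effort (mass moved times distance moved) needed to reshape one pile into the other. In other words, we pick a transport plan $\gamma$ between $\mathbb{P}_r$ and $\mathbb{P}_g$ out of the set of all plans $\Pi(\mathbb{P}_r,\mathbb{P}_g)$, where $\gamma(x,y)$ says how much mass moves from $x$ to $y$, and we choose the plan that minimizes the total cost $\mathbb{E}_{(x,y)\sim\gamma}[\|x-y\|]$.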
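For equally weighted empirical samples on the real line, the minimization over matchings can be made concrete. The sketch below (function names are hypothetical; it assumes two 1-D samples of equal size) computes EMD in two ways: by brute force over all one-to-one matchings, and by pairing the sorted samples, which is known to be the optimal matching in 1-D. Restricting couplings to one-to-one matchings loses nothing for equal-size uniform samples (Birkhoff-von Neumann theorem).

```python
import itertools

def emd_1d(xs, ys):
    """EMD between two equal-size, equally weighted 1-D samples.

    In 1-D the optimal plan matches the i-th smallest x with the i-th smallest y.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)

def emd_brute_force(xs, ys):
    """Minimum average cost over every one-to-one matching (exponential time)."""
    n = len(xs)
    return min(
        sum(abs(xs[i] - ys[perm[i]]) for i in range(n)) / n
        for perm in itertools.permutations(range(n))
    )

real = [0.0, 1.0, 3.0]  # samples standing in for P_r
fake = [2.0, 5.0, 0.5]  # samples standing in for P_g
print(emd_1d(real, fake), emd_brute_force(real, fake))
```

The brute-force version grows factorially and only serves to confirm the sorted matching; higher-dimensional data would need a general optimal-transport solver.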
Why Earth Mover’s Distance?
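To get a sense of how these distances differ, consider a simple example.
Example. Suppose we want to learn the point $0$ on the real line. Let $\delta_\theta$ denote the Dirac delta distribution centered at $\theta$, so that $\mathbb{P}_r=\delta_0$ and $\mathbb{P}_g=\delta_\theta$. Then:
$$W(\delta_0,\delta_\theta)=|\theta|$$
$$JS(\delta_0,\delta_\theta)=\begin{cases}\log 2 & \theta\neq 0\\ 0 & \theta=0\end{cases}$$
$$KL(\delta_0\|\delta_\theta)=KL(\delta_\theta\|\delta_0)=\begin{cases}+\infty & \theta\neq 0\\ 0 & \theta=0\end{cases}$$
$$\delta(\delta_0,\delta_\theta)=\begin{cases}1 & \theta\neq 0\\ 0 & \theta=0\end{cases}$$
Note that every metric except EMD is discontinuous at $\theta=0$: none of them approaches 0 as $\theta\to 0$. The following figure is a visualization of EMD (left) and JSD (right).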
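The example can also be checked numerically. The following sketch represents finite-support distributions as hypothetical `{point: probability}` dictionaries and evaluates all four distances between $\delta_0$ and $\delta_\theta$ as $\theta$ shrinks. TV uses the fact that, for discrete distributions, the supremum over sets equals half the $L_1$ distance; EMD uses the 1-D identity $W=\int|F_r(x)-F_g(x)|\,dx$ on the cumulative distribution functions.

```python
import math

def tv(p, q):
    # For discrete distributions, sup_A |P(A) - Q(A)| equals half the L1 distance.
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(x, 0.0) - q.get(x, 0.0)) for x in support)

def kl(p, q):
    total = 0.0
    for x, px in p.items():
        if px == 0.0:
            continue
        qx = q.get(x, 0.0)
        if qx == 0.0:
            return math.inf  # p puts mass where q has none
        total += px * math.log(px / qx)
    return total

def js(p, q):
    support = set(p) | set(q)
    m = {x: 0.5 * (p.get(x, 0.0) + q.get(x, 0.0)) for x in support}
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def emd(p, q):
    # 1-D: W = integral of |F_p - F_q|; both CDFs are constant between support points.
    xs = sorted(set(p) | set(q))
    cdf_p = cdf_q = total = 0.0
    for left, right in zip(xs, xs[1:]):
        cdf_p += p.get(left, 0.0)
        cdf_q += q.get(left, 0.0)
        total += abs(cdf_p - cdf_q) * (right - left)
    return total

def delta(theta):
    return {theta: 1.0}  # Dirac delta centered at theta

for theta in [1.0, 0.1, 0.001, 0.0]:
    p_r, p_g = delta(0.0), delta(theta)
    print(f"theta={theta:<6} EMD={emd(p_r, p_g):.3f} JS={js(p_r, p_g):.3f} "
          f"TV={tv(p_r, p_g):.3f} KL={kl(p_r, p_g)}")
```

Only the EMD column shrinks with $\theta$; JS stays at $\log 2\approx 0.693$, TV at 1 and KL at infinity until $\theta$ is exactly 0.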
More formally, the WGAN paper proves that:
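- Under mild assumptions on the generator $g_\theta$ (such as local Lipschitzness), $W(\mathbb{P}_r,\mathbb{P}_g)$ is continuous everywhere and differentiable almost everywhere as a function of the generator parameters $\theta$
- The same guarantees do not hold for the other metrics (TV, KLD and JSD), as the example above shows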
The following figure shows that when $\mathbb{P}_r$ and $\mathbb{P}_g$ are far apart (so that a well-trained discriminator separates real and fake samples almost perfectly), the vanilla GAN discriminator saturates, outputting values near 1 on real samples and near 0 on fake ones, whereas the WGAN critic (WGAN's name for its discriminator) varies roughly linearly between them.
As a result, in vanilla GAN the discriminator's gradients are almost zero for both real and fake inputs, which stalls further training of the generator, while the WGAN critic avoids such vanishing gradients.
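A one-dimensional toy calculation illustrates the saturation. Assume, hypothetically, that real samples lie at $x>0$, fake samples at $x<0$, and a nearly perfect discriminator $D(x)=\sigma(Kx)$ with a large sharpness $K$. The vanilla generator loss $\log(1-D(x))$ then has derivative $-K\sigma(Kx)$, which vanishes for fake samples far from the decision boundary, while a linear critic $f(x)=x$ gives the WGAN generator loss $-f(x)$ a constant gradient of $-1$.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

K = 10.0  # hypothetical sharpness of a near-perfect discriminator D(x) = sigmoid(K * x)

def vanilla_generator_grad(x):
    # d/dx log(1 - sigmoid(K x)) = -K * sigmoid(K x)
    return -K * sigmoid(K * x)

def wgan_generator_grad(x):
    # d/dx of -f(x) for the linear critic f(x) = x: independent of x
    return -1.0

for x_fake in [-3.0, -1.0, -0.1]:
    print(x_fake, vanilla_generator_grad(x_fake), wgan_generator_grad(x_fake))
```

The vanilla gradient at $x=-3$ is on the order of $10^{-12}$, whereas the critic's gradient does not depend on how far the fake sample is from the real data.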
What’s more, if we consider a sequence of distributions $(\mathbb{P}_n)_{n\in\mathbb{N}}$ and rank the metrics by how strong a notion of convergence $\mathbb{P}_n\to\mathbb{P}$ they induce, from strongest to weakest, we get:
- KLD
- TV and JSD
- EMD
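For example, on a compact space $\mathcal{X}$, $KL(\mathbb{P}_n\|\mathbb{P})\to 0 \implies \delta(\mathbb{P}_n,\mathbb{P})\to 0 \implies W(\mathbb{P}_n,\mathbb{P})\to 0$, but not the other way around. Because EMD is the weakest of these, more sequences of distributions converge under it, which makes EMD the easiest criterion for a model to make progress on.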
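Moreover, the paper proves that $W(\mathbb{P}_n,\mathbb{P})\to 0$ if and only if $\mathbb{P}_n$ converges in distribution to $\mathbb{P}$. That is, even though EMD induces the weakest notion of convergence, it still guarantees that our generated distribution $\mathbb{P}_g$ converges to the ground truth $\mathbb{P}_r$.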
Therefore it is desirable for us to use EMD as the distribution matching criterion.
Wasserstein GAN
Model Design
The infimum over all joint distributions in the original EMD formula is highly intractable. Fortunately, the Kantorovich-Rubinstein duality lets us rewrite it as:
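$$W(\mathbb{P}_r,\mathbb{P}_g)=\sup_{\|f\|_L\le 1}\mathbb{E}_{x\sim\mathbb{P}_r}[f(x)]-\mathbb{E}_{x\sim\mathbb{P}_g}[f(x)]$$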
Here $f$ is a function that maps a data sample to a scalar, and the constraint $\|f\|_L\le 1$ means that $f$ is restricted to be 1-Lipschitz, i.e. $|f(x)-f(y)|\le\|x-y\|$ for all $x$ and $y$.
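The duality can be verified numerically in 1-D. The sketch below (hypothetical helper names, equal-size 1-D samples assumed) compares the primal EMD from sorted matching with the dual objective $\mathbb{E}_{\mathbb{P}_r}[f]-\mathbb{E}_{\mathbb{P}_g}[f]$ for several 1-Lipschitz candidates. It builds one maximizing critic as a piecewise-linear function with slope $-\operatorname{sign}(F_r-F_g)$ between consecutive sample points; integrating by parts shows that this choice turns the dual objective into $\int|F_r-F_g|\,dx$, the 1-D EMD.

```python
def emd_1d(xs, ys):
    """Primal EMD for equal-size, equally weighted 1-D samples (sorted matching)."""
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)

def critic_gap(f, xs, ys):
    """Dual objective E_{x~P_r}[f(x)] - E_{y~P_g}[f(y)] for a candidate critic f."""
    return sum(f(x) for x in xs) / len(xs) - sum(f(y) for y in ys) / len(ys)

def optimal_critic(xs, ys):
    """1-Lipschitz critic with slope -sign(F_r - F_g) between consecutive sample points."""
    pts = sorted(set(xs) | set(ys))
    values = {pts[0]: 0.0}
    for a, b in zip(pts, pts[1:]):
        # Compare F_r(a) and F_g(a) via cross-multiplied counts to avoid float rounding.
        cr = sum(1 for x in xs if x <= a) * len(ys)
        cg = sum(1 for y in ys if y <= a) * len(xs)
        slope = -1.0 if cr > cg else (1.0 if cr < cg else 0.0)
        values[b] = values[a] + slope * (b - a)
    # Only defined on the sample points, which is all the dual objective needs here.
    return lambda u: values[u]

real = [0.0, 1.0, 3.0]
fake = [2.0, 5.0, 0.5]
f_star = optimal_critic(real, fake)
print("primal EMD:", emd_1d(real, fake), "dual with f*:", critic_gap(f_star, real, fake))

candidates = [lambda u: u, lambda u: -u, lambda u: abs(u - 2.0), lambda u: 0.5 * u]
for f in candidates:
    print(critic_gap(f, real, fake))  # each is at most the EMD
```

Here $f(x)=-x$ also attains the EMD, because every quantile of the fake sample lies to the right of the corresponding real quantile; candidates such as $f(x)=|x-2|$ score strictly less.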
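This formula fits naturally into the GAN objective: the function $f$ plays the role of the discriminator, which WGAN calls the critic because it outputs an unbounded score rather than a probability. We can therefore write out the objective of WGAN as: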
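$$\min_G\max_{D\in\mathcal{D}}\ \mathbb{E}_{x\sim\mathbb{P}_r}[D(x)]-\mathbb{E}_{z\sim p(z)}[D(G(z))]$$
where $\mathcal{D}$ is the set of 1-Lipschitz functions and the generator $G$ maps noise $z\sim p(z)$ to samples from $\mathbb{P}_g$.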
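The above expression means that we train the discriminator to approximate EMD as accurately as possible, and train the generator to minimize the estimated distance.
However, we still need to ensure that $D$ is $K$-Lipschitz for some finite $K$; maximizing over $K$-Lipschitz functions yields $K\cdot W(\mathbb{P}_r,\mathbb{P}_g)$, which has the same minimizer. To this end, the WGAN paper proposes weight clipping: after every update, each critic weight is clamped to a fixed range $[-c,c]$. WGAN-GP later replaces this rather crude method.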
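A minimal sketch of the training loop with weight clipping on a deliberately tiny problem: real data $\mathcal{N}(3,1)$, a hypothetical generator $x=z+\theta$ with $z\sim\mathcal{N}(0,1)$, and a linear critic $f(x)=wx$ whose single weight is clipped to $[-c,c]$, making it $c$-Lipschitz. As in the WGAN recipe, the critic takes several ascent steps per generator step; plain SGD and all hyperparameter values are assumptions chosen for this toy (the paper uses RMSProp on deep networks).

```python
import random

def train_toy_wgan(steps=300, n_critic=5, clip=1.0, lr_critic=0.5, lr_gen=0.05,
                   batch=128, real_mean=3.0, seed=0):
    """Toy 1-D WGAN with weight clipping.

    Real data: N(real_mean, 1). Generator: x = z + theta with z ~ N(0, 1).
    Critic: f(x) = w * x with w clipped to [-clip, clip].
    """
    rng = random.Random(seed)
    theta, w = 0.0, 0.0
    for _ in range(steps):
        # Critic: ascent on E_real[f(x)] - E_fake[f(x)], then clip the weight.
        for _ in range(n_critic):
            real = [rng.gauss(real_mean, 1.0) for _ in range(batch)]
            fake = [rng.gauss(0.0, 1.0) + theta for _ in range(batch)]
            grad_w = sum(real) / batch - sum(fake) / batch  # derivative w.r.t. w
            w += lr_critic * grad_w
            w = max(-clip, min(clip, w))  # weight clipping
        # Generator: descent on -E_z[f(z + theta)] = -w * (E[z] + theta); d/dtheta = -w.
        theta -= lr_gen * (-w)
    return theta, w

theta, w = train_toy_wgan()
print(f"learned theta = {theta:.2f} (real mean 3.0), critic weight w = {w:.2f}")
```

Because a linear critic can only compare means, this toy learns only the mean of the data; the point is the structure of the loop (critic steps, clipping, generator step), not the capacity of the critic.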
Experimental Results
The paper trains WGAN and vanilla GAN on the LSUN-bedroom dataset with two different architectures: a smaller-scale MLP and a larger-scale DCGAN. The following figure shows that WGAN is able to converge with both architectures, and the discriminator criterion decreases smoothly as the image quality gets better.
In contrast, vanilla GAN fails to converge with the MLP architecture, and its discriminator loss does not reflect sample quality.
We can conclude from these results that WGAN trains more stably, works with less expressive architectures, and provides a more interpretable discriminator loss.
WGAN-GP
Weight Clipping: What’s Wrong?
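The biggest bottleneck of WGAN is its use of weight clipping to enforce the Lipschitz condition on the critic. To show that this design causes trouble, the authors of WGAN-GP prove the following proposition:
Proposition. Let $f^*$ be the optimal solution to the Kantorovich-Rubinstein dual form of EMD. Then $f^*$ has gradient norm 1 almost everywhere under $\mathbb{P}_r$ and $\mathbb{P}_g$. Furthermore, if $f^*$ is differentiable, then for pairs $(x,y)$ with $x\sim\mathbb{P}_g$ and $y\sim\mathbb{P}_r$ drawn from the optimal coupling $\pi$ (with $x\neq y$), it holds almost surely that $\nabla f^*(x_t)=\frac{y-x_t}{\|y-x_t\|}$, where $x_t=tx+(1-t)y$ for $0\le t\le 1$.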
Therefore, checking whether the critic has unit gradient norm tells us how well WGAN approximates the EMD.
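The proposition can be checked on the simplest case, $\mathbb{P}_r=\delta_y$ and $\mathbb{P}_g=\delta_x$. One optimal critic there is $f^*(u)=-\|u-y\|$: it is 1-Lipschitz and attains $f^*(y)-f^*(x)=\|x-y\|=W$. The sketch below (hypothetical helpers, central finite differences) evaluates its gradient at the interpolates $x_t$ and confirms unit norm and the direction $(y-x_t)/\|y-x_t\|$.

```python
import math

def critic(u, y):
    """One optimal critic for P_r = delta_y, P_g = delta_x: f*(u) = -||u - y||."""
    return -math.dist(u, y)

def numerical_grad(f, u, h=1e-6):
    """Central finite-difference gradient of a scalar function f at point u."""
    grad = []
    for i in range(len(u)):
        up, down = list(u), list(u)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

x = (0.0, 0.0)  # fake sample
y = (3.0, 4.0)  # real sample; ||y - x|| = 5, so W(delta_y, delta_x) = 5
print("dual objective f*(y) - f*(x):", critic(y, y) - critic(x, y))

grads = []
for t in (0.1, 0.5, 0.9):
    x_t = [t * xi + (1 - t) * yi for xi, yi in zip(x, y)]  # x_t = t x + (1 - t) y
    g = numerical_grad(lambda u: critic(u, y), x_t)
    grads.append(g)
    print(f"t={t}: grad=({g[0]:.4f}, {g[1]:.4f}), norm={math.hypot(*g):.4f}")
```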
Strikingly, across different clipping thresholds, WGAN consistently fails to converge to a critic with unit gradient norm. Furthermore, the critic's weights pile up at the two ends of the clipping range, severely limiting the critic's capacity.
Training WGAN on several toy datasets reveals a further problem: the critics end up learning overly simple functions.
Gradient Penalty
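A gentler alternative is to encourage (rather than enforce) the Lipschitz condition on the critic $D$ by regularizing its gradient norm toward 1. The additional loss term, called the gradient penalty (GP), has the following form:
$$GP=\mathbb{E}_{\hat{x}\sim\mathbb{P}_{\hat{x}}}\left[\left(\|\nabla_{\hat{x}}D(\hat{x})\|_2-1\right)^2\right]$$
where $\hat{x}=\epsilon x+(1-\epsilon)\tilde{x}$ with $x\sim\mathbb{P}_r$, $\tilde{x}\sim\mathbb{P}_g$ and $\epsilon\sim U[0,1]$. That is, $\hat{x}$ is obtained by interpolating between real and fake samples.
The GP term is then added to the critic loss with a weight $\lambda$ (the paper uses $\lambda=10$), giving $L=\mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}[D(\tilde{x})]-\mathbb{E}_{x\sim\mathbb{P}_r}[D(x)]+\lambda\cdot GP$.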
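The penalized critic loss can be sketched without a deep-learning framework. The code below assumes a hypothetical two-dimensional critic $D(u)=\sum_j a_j\tanh(w_j\cdot u)$ whose input gradient is written out by hand, samples one $\epsilon\sim U[0,1]$ per real/fake pair, and returns the loss $\mathbb{E}[D(\tilde{x})]-\mathbb{E}[D(x)]+\lambda\cdot GP$ together with the GP estimate.

```python
import math
import random

# Hypothetical critic D(u) = sum_j a[j] * tanh(W[j] . u) for inputs u in R^2.
W = [[1.5, -0.5], [0.3, 2.0], [-1.0, 1.0]]
a = [0.8, -0.6, 0.4]

def critic(u):
    return sum(aj * math.tanh(wj[0] * u[0] + wj[1] * u[1]) for aj, wj in zip(a, W))

def critic_grad(u):
    """Gradient of D w.r.t. its input: sum_j a[j] * (1 - tanh^2(W[j] . u)) * W[j]."""
    g = [0.0, 0.0]
    for aj, wj in zip(a, W):
        s = 1.0 - math.tanh(wj[0] * u[0] + wj[1] * u[1]) ** 2
        g[0] += aj * s * wj[0]
        g[1] += aj * s * wj[1]
    return g

def wgan_gp_critic_loss(real, fake, rng, lam=10.0):
    """Return (E[D(fake)] - E[D(real)] + lam * GP, GP) for equal-size batches."""
    n = len(real)
    wasserstein_term = sum(critic(f) for f in fake) / n - sum(critic(r) for r in real) / n
    penalty = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()  # eps ~ U[0, 1], one draw per real/fake pair
        x_hat = [eps * ri + (1 - eps) * fi for ri, fi in zip(r, f)]
        penalty += (math.hypot(*critic_grad(x_hat)) - 1.0) ** 2
    gp = penalty / n
    return wasserstein_term + lam * gp, gp

rng = random.Random(0)
real = [[rng.gauss(2.0, 0.5), rng.gauss(2.0, 0.5)] for _ in range(64)]
fake = [[rng.gauss(0.0, 0.5), rng.gauss(0.0, 0.5)] for _ in range(64)]
loss, gp = wgan_gp_critic_loss(real, fake, rng)
print(f"critic loss = {loss:.4f}, gradient penalty = {gp:.4f}")
```

In practice $\nabla_{\hat{x}}D(\hat{x})$ comes from automatic differentiation, and the penalty must itself be differentiated with respect to the critic's parameters; in PyTorch, for instance, this is done with `torch.autograd.grad(..., create_graph=True)`.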
Experimental Results
The paper conducts a large-scale comparison between WGAN-GP and vanilla DCGAN by training both with 200 different architectures. Using the Inception Score (IS) as the metric, the following table shows that WGAN-GP consistently outperforms DCGAN:
Measured by IS, WGAN-GP also outperforms WGAN on CIFAR-10:
Finally, WGAN-GP outperforms other state-of-the-art GAN models in the unsupervised setting, and is comparable to the best model in the supervised setting:
Below are some samples from a WGAN-GP generator trained on LSUN-bedroom.