How GANs work and how you can use them to synthesize data
If you’re working in deep learning, you’ve probably heard of GANs, or Generative Adversarial Networks (Goodfellow et al., 2014). In this post we will explain what GANs are, and discuss some use cases with real examples. This post also links to my GAN playground, called MP-GAN (Multi-Purpose GAN). I prepared this playground on GitHub as a research framework, and you are welcome to use it to train and explore GANs for yourselves. In the appendices I present and discuss some of the experiments I did on GAN training, using this playground.
GANs are part of a family of generative deep learning architectures, whose goal is to generate synthetic data instead of predicting features of existing data points, as classifiers and regressors do. Classifiers and regressors both belong to a family of models called discriminative models. The object detection networks discussed in some of my previous posts, like YOLOv3 and CenterNet, combine a classifier and a regressor, and are therefore also discriminative models. Other generative machine learning models, which we will not discuss at this time, include variational autoencoders (VAEs), diffusion models, and restricted Boltzmann machines (RBMs).
Why Generate Data?
- To improve the training of discriminative models – Some applications, e.g. autonomous driving, require data from an extremely large number of driven miles. Furthermore, for safety, the models need to train extensively on marginal cases like accidents, near-accidents, and aberrant behavior of other vehicles, for which the actual collected data does not contain enough examples. Other examples: image-based fire detection systems; automatic flaw detection in IC production lines; synthetic scenarios for fraud detection algorithms and for multi-sensor machine failure detection systems (tabular data synthesis).
- Commercial – Many appealing images are hard or impossible to create in reality, or expensive and time-consuming to paint by hand (even when using dedicated software). In this case an artificially generated image can be a fair substitute. For example: synthetic bedroom images in linen commercials, or a synthetic human face (as shown in Fig. 1) for a toothpaste commercial.
- Artistic – If you are experienced in the medium, then generative models are a tool, just like a brush. Some artists are experts in generating artificial images that are visually appealing.
GAN Structure and Flow
As their name suggests, GANs consist of two rival neural networks. One, the generator (or G), tries to generate synthetic examples of data, and the other, the discriminator (or D), tries to distinguish the synthetic samples from real samples. D is, in fact, a classification model. Let’s assume we want to automate our mailing system. We train a robot arm connected to a camera to read zip code digits off the envelope, but we fear we don’t have enough samples, and the robot will get confused by hard examples. Therefore, we want to generate many synthetic handwritten digit images to boost the training set. The basic training flow after initializing both D and G goes like this:
- Freeze G and train only D on a few real and a few synthetic images (generated by G).
- Freeze D and train only G with the loss corresponding to the ratio of samples that D correctly classified as ‘synthetic samples’.
- Evaluate the results, and repeat until a satisfactory performance is achieved (if the ratio of real to synthetic images presented to D at stage 1 is 50/50, then the ideal result is that at stage 3, D misclassifies both the synthetic and the real examples 50% of the time, i.e. it does no better than chance).
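The three stages above can be sketched on a toy problem. This is a minimal illustration under loud assumptions, not the MP-GAN code: the "images" are single numbers drawn from N(4, 1), G is a linear map `a*z + b`, D is a logistic regression, the gradients are written by hand, and G's loss is the commonly used `-log D(G(z))` rather than a literal count of caught samples. The name `train_toy_gan` and all hyperparameters are hypothetical.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def train_toy_gan(steps=2000, batch=64, lr=0.05, seed=0):
    """Alternate D and G updates on 1-D data and return the final state."""
    rng = np.random.default_rng(seed)
    a, b = 1.0, 0.0  # generator:     G(z) = a*z + b
    w, c = 0.0, 0.0  # discriminator: D(x) = sigmoid(w*x + c)
    for _ in range(steps):
        # Stage 1: G is "frozen" (a and b are not updated); train only D.
        x_real = rng.normal(4.0, 1.0, batch)     # toy "real" data (assumption)
        x_fake = a * rng.normal(size=batch) + b  # samples generated by G
        d_real = sigmoid(w * x_real + c)
        d_fake = sigmoid(w * x_fake + c)
        # Binary cross-entropy gradients, with labels real=1 and synthetic=0.
        grad_w = np.mean((d_real - 1.0) * x_real) + np.mean(d_fake * x_fake)
        grad_c = np.mean(d_real - 1.0) + np.mean(d_fake)
        w -= lr * grad_w
        c -= lr * grad_c

        # Stage 2: D is "frozen" (w and c are not updated); train only G.
        z = rng.normal(size=batch)
        d_fake = sigmoid(w * (a * z + b) + c)
        # Gradient of -log D(G(z)): large whenever D catches the fakes.
        grad_a = np.mean((d_fake - 1.0) * w * z)
        grad_b = np.mean((d_fake - 1.0) * w)
        a -= lr * grad_a
        b -= lr * grad_b

    # Stage 3: evaluate D's accuracy on a 50/50 mix of real and synthetic data.
    x_real = rng.normal(4.0, 1.0, 1000)
    x_fake = a * rng.normal(size=1000) + b
    acc_real = np.mean(sigmoid(w * x_real + c) > 0.5)
    acc_fake = np.mean(sigmoid(w * x_fake + c) <= 0.5)
    return a, b, w, c, 0.5 * (acc_real + acc_fake)

if __name__ == "__main__":
    a, b, w, c, d_acc = train_toy_gan()
    print(f"G(z) = {a:.2f}*z + {b:.2f}, D accuracy = {d_acc:.2f}")
```

Here "freezing" a model simply means not touching its parameters during the other model's update. In a deep learning framework the same effect is usually achieved by marking the frozen model's weights as non-trainable.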
The high-level structure of a GAN is illustrated in Fig. 2.
As can be expected, at first both G and D will suck at what they do: in the first training steps D has no idea what a valid digit image should look like, and neither does G. But through the labels (‘real’, ‘synthetic’) supplied to the D training phase, D gains some knowledge of what real data samples should look like. After a few examples, D improves, marginally, at classifying samples as real or synthetic (remember: at this point the synthetic samples are terrible, so it’s not so hard to learn the difference). Then we freeze its parameters, and train G. The loss incurred when D catches the fraud pushes G to generate samples that look like what D perceives as real, and so forth. Fig. 3 demonstrates the process of training a GAN to generate samples of the handwritten digit ‘8’.
Experimenting with GANs for high-resolution color images, such as human faces, is very compute-heavy, so for simplicity, let’s limit our discussion to MNIST data, i.e. 28×28 pixel grayscale images of handwritten digits. (MNIST is a modification of NIST data, created by LeCun, Cortes and Burges; see Cohen, G., Afshar, S., Tapson, J., & van Schaik, A. (2017). EMNIST: an extension of MNIST to handwritten letters. Retrieved from http://arxiv.org/abs/1702.05373. The MNIST dataset is made available under the terms of the Creative Commons Attribution-Share Alike 3.0 License.)
We will treat each digit image as a 28×28 = 784-dimensional vector, with the value of each coordinate being equal to the grayscale level of the corresponding pixel. The volume of this image space is finite, but huge: it’s a hypercube with 256⁷⁸⁴ different coordinate combinations, or images.
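To get a feel for how huge that number is, the arithmetic can be checked directly, since Python integers have arbitrary precision. The only assumption here is the usual 8-bit encoding, i.e. 256 gray levels per pixel.

```python
# Each of the 784 pixels takes one of 256 gray levels (8-bit grayscale).
n_pixels = 28 * 28
n_images = 256 ** n_pixels

# Count the decimal digits of the number of distinct images.
n_digits = len(str(n_images))
print(n_pixels, n_digits)  # prints: 784 1889
```

In other words, the number of possible 28×28 grayscale images is a number with almost 1,900 decimal digits.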
Naturally, most of the volume in this space corresponds to completely meaningless images. A tiny portion of the space corresponds to meaningful images, and a smaller-still portion corresponds to digit images.
Let’s phrase our steps and objectives in the terms used by researchers:
- Generating a synthetic digit image is equivalent to conjuring up a vector in the 784-dimensional space, somewhere inside a densely populated blob of real digit images, where the probability distribution of digit images is high.
- While we can, relatively easily, classify a given vector as either a valid or non-valid digit image (by training a classifier), the reverse process, of conjuring up a vector inside a digit image blob, is hard. That is mainly because the valid digit images are not simply concentrated in one, or a few nice and spherical blobs. Instead, they are scattered in numerous filaments across this 784-dimensional space; as a quick demonstration, think of a set of valid digit images, and then shift all the images one pixel to the right — this would form another set of valid digit images, but very far from the first set in the 784-dimensional space.
- Therefore, we modify our task: we learn a transformation from a nice, cozy and known distribution (e.g. Gaussian) in another (latent) space, to the filaments of valid digit images in the 784-dimensional image space.
- After we’ve done that, we can draw points from dense areas in the latent space, to generate more samples of realistic images. The transformations learned by G and D are illustrated in Fig. 4.
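The one-pixel-shift argument in the list above can be checked numerically. The sketch below is an illustration under stated assumptions: instead of a real MNIST digit it uses a hypothetical stand-in, a one-pixel-wide vertical stroke resembling a thin ‘1’, and it measures distance with the plain Euclidean norm in the 784-dimensional space. The helper `shift_right` is made up for this example.

```python
import numpy as np

def shift_right(img, pixels=1):
    """Shift an image to the right, filling the vacated columns with zeros."""
    out = np.zeros_like(img)
    out[:, pixels:] = img[:, :-pixels]
    return out

# Hypothetical stand-in for a digit: a 20-pixel-tall, 1-pixel-wide stroke.
img = np.zeros((28, 28), dtype=np.float64)
img[4:24, 14] = 255.0

shifted = shift_right(img)

# Euclidean distances between the flattened 784-dimensional vectors.
dist_to_shifted = np.linalg.norm(img.ravel() - shifted.ravel())
dist_to_blank = np.linalg.norm(img.ravel())
print(dist_to_shifted > dist_to_blank)  # prints: True
```

For this thin stroke, the shifted copy is farther from the original than a completely blank image is, even though both copies are equally valid to a human eye. A one-pixel-wide stroke is the extreme case; real digits with thicker strokes overlap more after the shift, so the effect is weaker but still present.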
A few points about this modified task:
- Whenever this learned function transforms a latent point into a non-valid image (as judged by the discriminator), the incurred loss pushes it slightly toward one of the valid image blobs.
- By definition, the model will train on more examples from the dense part of the distribution in the latent space than on points from the sparse part, and will, therefore, be more motivated to transform them into valid images (i.e. near centers of valid image blobs). Given that the learning process mentioned above is effective, this will eventually lead to points in the dense part of the latent space distribution transforming to images in dense areas of the valid image blobs. What this means in practice is that if you’ve trained your GAN on noise sampled from a normal distribution, then you can expect samples from the region close to the origin to produce valid data samples, and noise samples that are far away from the origin to produce less realistic data samples. Check out the experiments section in the appendix to see how the generated samples change as we traverse the latent space.
If we’re training the discriminator from scratch (as is usually the case with GANs), then the contours that the generator tries to learn are very inaccurate at first (because the discriminator doesn’t know any better) but they improve with each discriminator learning phase. However, if we somehow have a somewhat trained classifier, we can use it as our discriminator and have a head start for the training of the generator.
The details of GAN training vary between users. Some train D for a few steps, and then train G, and so on. Some switch between them with each step. In my MP-GAN framework each G-training step (batch) is preceded by a D-training step, split into two: in the first half the discriminator is shown real data, and in the second half it is shown synthetic data.
It is commonly regarded as good practice to use the Adam optimizer, probably because its momentum helps to stabilize the very noisy training process. As I show in the experiments section, this practice seems to be empirically justified.
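To make the momentum remark concrete, here is a minimal sketch of a single Adam update next to a plain SGD update. It follows the standard Adam update rule with commonly used default hyperparameters. The function names `adam_step` and `sgd_step` are made up for this illustration and are not part of MP-GAN or of any library.

```python
import numpy as np

def sgd_step(param, grad, lr=1e-3):
    """Plain SGD: follow the current (possibly very noisy) gradient."""
    return param - lr * grad

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step counter."""
    # Running average of gradients (the momentum term).
    m = beta1 * m + (1.0 - beta1) * grad
    # Running average of squared gradients (per-parameter scale).
    v = beta2 * v + (1.0 - beta2) * grad ** 2
    # Bias correction for the zero-initialized averages.
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Usage: one step from a fresh optimizer state.
param = np.array([1.0])
grad = np.array([4.0])
param_adam, m, v = adam_step(param, grad, np.zeros(1), np.zeros(1), t=1)
param_sgd = sgd_step(param, grad)
```

Because Adam averages gradients over many steps, a single noisy batch has a limited effect on the update direction, whereas SGD follows each noisy gradient as-is.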
Note that training GANs is trickier than training discriminative models. In particular, Goodfellow et al., the authors of the original paper, mention ‘the Helvetica scenario’, later dubbed ‘mode collapse’, in which G maps multiple points in the latent space to the same output, or to a narrow region in the output space. This can happen if the training process is imbalanced, and G trains too much compared to D. For example, if at one point in the training D is, by chance, worse at telling synthetic images of the digit ‘0’ from real ones than it is for the other digits, and it stops training, then G is encouraged to transform all the latent points to ‘0’. By the time D starts to train again, G may already be happily stuck in a local minimum with no motivation to leave (unless it’s somehow penalized for not generating other digits).
Another issue that makes GAN training hard is the inherent instability of the process. The training tries to minimize two loss functions at once, but it does so in alternation: it samples one loss landscape and updates that model’s parameters, then does the same for the other. Every change to the parameters of one model also changes the loss of the other, so each model is chasing a moving target.
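For reference, the original paper (Goodfellow et al., 2014) expresses the two competing objectives as a single value function that D tries to maximize and G tries to minimize. The notation below is the paper's, not this post's: p_data is the distribution of real samples and p_z is the latent (noise) distribution.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

A D-training step moves D's parameters uphill on V with G held fixed, and a G-training step moves G's parameters downhill on the second term with D held fixed. Since both terms depend on D, and the second also depends on G, each update reshapes the landscape that the other model sees.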
GANs, or Generative Adversarial Networks, are a deep learning mechanism that learns to generate new data samples via a training competition between two models — a generator and a discriminator.
Training GANs is trickier than training discriminative models, due to the inherent instability of the problem and the risk of mode collapse.
Using a framework such as the one I propose in the appendices, it is possible to build and train various architectures of GANs, and research the dynamics of their training.
Check out my GAN experiments, a few lines down!
In this section I’ll present several GAN experiments I did, using a GAN playground (MP-GAN) I prepared on GitHub.
You’re welcome to fork my MP-GAN (Multi-Purpose GAN) repository and experiment with different GAN architectures, datasets and training scenarios. This project supports both image data (currently single-channel only) and tabular data.
Experiment 1 — Effect of optimizer on training convergence
I used my MP-GAN GitHub framework to train two identical architectures, one using the SGD optimizer, and another using the Adam optimizer. Sampling the generator’s output reveals the difference right from the first epochs. My findings support the claim made by multiple sources: the Adam optimizer is indeed better for this task. A sample is given in Fig. 5.
Experiment 2 — Compare convergence from scratch to using a pretrained discriminator
I hypothesized that, while the generator G is trained from scratch by definition, there shouldn’t be a mathematically fundamental reason why the discriminator needs to train from scratch. The common reason for that is practical: usually we simply don’t have a model trained to classify real and synthetic images. But if we did have a trained discriminator somehow (e.g. from a previous training), using it shouldn’t hurt the convergence of the generator. In Fig. 6 I compare samples from a generator trained with a discriminator from scratch, vs. a generator trained with a discriminator taken from a previous training session. As can be seen, after 10 epochs the GAN using a pretrained discriminator produces more advanced and realistic samples than the one that trains from scratch. However, after 50 epochs the GAN trained from scratch seems to have caught up (Fig. 7). The scarcity of some digits in the synthetic samples in both runs may indicate a mode collapse, as explained above.
Following that line of thought — I added an option to freeze the discriminator parameters and save training time. I found that, as expected, the total training time shortens, but not significantly, since in this architecture the discriminator is much smaller than the generator (40k params vs. 1M params).
Another experiment with possibly interesting implications (I haven’t done it yet) is to take a regular classifier that was trained on the dataset (possibly the very same classifier you wish to improve by synthesizing more data) and use it as the pretrained discriminator to speed up the convergence of the generator training, with or without freezing most of the discriminator parameters to save time.
Experiment 3 — Explore trajectories in the latent space
A nice experiment is to see if neighboring points in the latent space transform to similar images. In Fig. 8, I created a 2D coordinate grid and embedded it in the 100-dimensional latent space, where the other 98 coordinates were kept fixed at 0 (the left and right panes are from two different planes in the latent space). On the left we can see a smooth transition from a synthetic ‘9’, to ‘4’, to ‘1’. On the right the synthetic images transform smoothly from ‘3’ to ‘9’, ‘8’ and ‘7’. Interestingly, I found that in all the experiments, the region close to 0 did not transform to a meaningful digit. I believe that the reason for this is that the non-linearity in the generator (the Leaky ReLU activation) makes 0 a natural border between digit blobs in the latent space, which makes points close to 0 borderline points between multiple digit progenitor regions.
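A grid like the one described above can be built in a few lines. This is a minimal sketch, not the MP-GAN code: the function name `latent_plane_grid`, the choice of which two latent axes span the plane, and the grid extent and resolution are all assumptions made for illustration.

```python
import numpy as np

def latent_plane_grid(latent_dim=100, axes=(0, 1), extent=3.0, steps=15):
    """Return latent vectors covering a 2D grid embedded in the latent space.

    Two coordinates (chosen by `axes`) sweep [-extent, extent]; all the
    other coordinates stay fixed at 0.
    """
    values = np.linspace(-extent, extent, steps)
    xx, yy = np.meshgrid(values, values)
    z = np.zeros((steps * steps, latent_dim))
    z[:, axes[0]] = xx.ravel()
    z[:, axes[1]] = yy.ravel()
    return z

# Usage: 15 x 15 = 225 latent points on the plane spanned by axes 0 and 1.
z = latent_plane_grid()
print(z.shape)  # prints: (225, 100)
```

Feeding each row of `z` to a trained generator and tiling the resulting images in grid order would give a picture like one pane of Fig. 8. Choosing a different pair of axes gives a different plane.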
The GAN I created for experimenting with the MNIST image dataset in the MP-GAN infrastructure has the following structure: The discriminator (Fig. 9) is a CNN (convolutional neural network) with two convolutions separated by dropout layers, ending with a Dense layer with a sigmoid activation. The generator (Fig. 10) starts with a dense layer followed by a reshape that transforms the latent-dimensional input vector into an image shape. Then deconv layers increase the spatial dimensions until the desired shape is reached, and a final convolution uses the information in the feature dimension to generate the final 1-channel image. I experimented with replacing the dropouts with batchnorms in the discriminator, but that didn’t improve results.
Don’t be alarmed by the discriminator params appearing as non-trainable. This is due to the specific implementation of the training flow in this pipeline: in phase 2 (training the generator) we, in practice, train the entire GAN model, but with the discriminator params frozen.