Generative Adversarial Networks (GANs): A Brief Introduction

Generative adversarial networks (GANs) are among the machine learning algorithms that have recently transformed the AI world. In this post we will discuss adversarial learning and the components of a GAN model. Before delving into GANs, let's first understand what generative models are and how they differ from discriminative models. A generative model is a type of machine learning model that aims to learn the true data distribution of the training data in order to generate new, similar data. It is not always possible to learn the exact distribution of your data, so in practice you model a distribution that is as similar as possible to the true data distribution.

In machine learning, most problems you come across are discriminative in nature, and the distinction between generative and discriminative models is fundamental. Let's consider an example to understand the difference. Suppose you have a dataset of images of airplanes and motor vehicles. With sufficient data, you could train a discriminative model to predict whether a given image is an airplane or not. During training, the model would learn the features specific to airplanes and, for images with those features, would upweight its predictions accordingly. Discriminative models require labeled training data: in our example, all airplane images would be labeled as 1 and non-airplane images as 0. The model is trained to discriminate between these two groups of images and outputs the probability that a new observation has label 1, i.e. is an airplane. In other words, discriminative modeling aims to model the probability of a label *y* given some observations *x*. Generative models, on the other hand, don't require a labeled dataset because the aim is to generate entirely new images rather than to predict a label. Put simply, a generative model aims to model the probability of observing an observation *x*. Sampling from this distribution allows us to generate new observations.
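The idea of "fit a distribution, then sample from it" can be illustrated with the simplest possible generative model; this is a hedged sketch (the Gaussian assumption and all parameter values are illustrative, not from the post):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "training data": 1-D observations drawn from an unknown distribution.
data = rng.normal(loc=5.0, scale=2.0, size=1000)

# A minimal generative model: assume the data are Gaussian and estimate
# the distribution's parameters from the training data.
mu, sigma = data.mean(), data.std()

# Sampling from the fitted distribution yields new, similar observations.
new_samples = rng.normal(loc=mu, scale=sigma, size=5)
print(mu, sigma, new_samples)
```

A GAN replaces the hand-picked Gaussian with a neural network that learns a far more flexible distribution, but the goal is the same: sample new observations that resemble the training data.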

Generative adversarial networks, or GANs, are a deep-learning-based generative model used to generate new data. A GAN involves two adversarial neural networks that compete with each other to produce new observations that are indistinguishable from real data. GANs work great for synthetic data generation, particularly for image data. You might be wondering, "How can networks be adversarial?" Well, the two networks, called the generator and the discriminator, compete with each other to win, hence the name adversarial. The generator learns to create fake samples of data, which can be images, audio, text, or simply numerical data. It tries to fool the discriminator by producing novel synthesized instances that ideally look like the real data. The discriminator evaluates the generated data and tries to determine whether the data are real or fake. The key to GANs lies in how we alternate the training of the two networks. As the generator becomes more adept at fooling the discriminator, the discriminator must adapt in order to maintain its ability to correctly identify which observations are fake. This drives the generator to find new ways to fool the discriminator, and so the cycle continues. Through multiple cycles of generation and discrimination, both networks train each other while simultaneously trying to deceive each other.

**Generator Network:** The generator network tries to learn the data distribution by using random noise as inputs and producing instances that look like real data. The main goal of the generator network is to maximize the likelihood that the discriminator misclassifies its output as real.

Generator Training Process


In the generator training process, the fake sample produced by the generator is passed to the discriminator. The discriminator classifies the generated data as real or fake, and this classification produces the generator loss, which penalizes the generator for failing to dupe the discriminator. Remember, the main goal of the generator is to fool the discriminator into classifying its output as real. Backpropagation is used to adjust each of the generator's weights by calculating that weight's impact on the output.
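As a sketch of what this loss looks like numerically, here is the commonly used non-saturating form of the generator loss, -log D(G(z)) (an assumption on my part; the post does not name a specific loss, and the original minimax paper also uses log(1 - D(G(z)))):

```python
import numpy as np

def generator_loss(d_fake):
    """Non-saturating generator loss.

    d_fake: discriminator outputs D(G(z)) in (0, 1) for a batch of fakes.
    The loss is large when the discriminator confidently calls the
    fakes fake, and small when the generator has fooled it.
    """
    return -np.mean(np.log(d_fake))

# If the discriminator is fooled (D(G(z)) near 1), the loss is small;
# if the fakes are easily spotted (D(G(z)) near 0), the loss is large.
print(generator_loss(np.array([0.90, 0.95])))
print(generator_loss(np.array([0.05, 0.10])))
```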

**Discriminator Network:** The discriminator network tries to differentiate between the fake data produced by the generator network and real data. Thus, the discriminator is simply a classifier that could use any network architecture appropriate to the type of data that it's classifying.

The training data feeding into the discriminator network comes from two sources:

- the real data instances
- the fake data instances, which were generated by the generator network

The discriminator network generates predictions for how likely the instances are to be real or fake. So, the main goal of the discriminator network is to accurately distinguish between real and fake data.

Discriminator Training Process

In the process of training the discriminator network, it classifies both the real data and the fake data from the generator. Notice that the discriminator connects to two loss functions. However, during discriminator training, the discriminator ignores the generator loss and uses only the discriminator loss. The discriminator loss penalizes the discriminator for misclassifying a real data instance as fake or a fake data instance as real. The discriminator updates its weights through backpropagation of the discriminator loss through the discriminator network.
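The discriminator loss described above is just binary cross-entropy computed over both data sources; a minimal sketch (function name and example values are my own, not from the post):

```python
import numpy as np

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy over both sources of training data.

    d_real: discriminator outputs D(x) for real instances (label 1).
    d_fake: discriminator outputs D(G(z)) for fake instances (label 0).
    """
    real_term = -np.mean(np.log(d_real))        # penalizes real classified as fake
    fake_term = -np.mean(np.log(1.0 - d_fake))  # penalizes fake classified as real
    return real_term + fake_term

# A confident, correct discriminator incurs a low loss; one that cannot
# tell the two sources apart (outputs 0.5 everywhere) incurs a higher one.
print(discriminator_loss(np.array([0.95, 0.90]), np.array([0.05, 0.10])))
print(discriminator_loss(np.array([0.50, 0.50]), np.array([0.50, 0.50])))
```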

We have seen that the generator and discriminator have different training processes. So, you must be wondering, "How are GANs trained as a whole?" Well, the two networks are trained in alternating fashion. Note that the discriminator needs to train for a few epochs prior to starting the adversarial training, as the discriminator will need to be able to actually classify the data (images) as real or fake.

While alternating the training of these two networks, we also must make sure that we update the weights of only one network at a time. For example, during the generator training process, only the generator’s weights are updated. Similarly, we keep the generator's weights constant during the discriminator training phase and update the discriminator’s weights only during this phase. As discriminator training tries to figure out how to distinguish real data from fake data, it has to learn how to recognize the generator's flaws. This is a different problem for a thoroughly trained generator than it is for an untrained generator producing random output. This back-and-forth training enables GANs to tackle otherwise intractable generative problems.
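The alternating scheme above can be sketched end to end on a toy 1-D problem. This is an illustrative sketch only, not production GAN code: the linear generator G(z) = mu + s·z, the logistic discriminator D(x) = sigmoid(w·x + b), the hand-derived gradients, and the non-saturating generator loss are all simplifying assumptions of mine. Note how each step updates only one network's parameters:

```python
import numpy as np

rng = np.random.default_rng(1)

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

# Real data: samples from an "unknown" 1-D distribution, here N(3, 1).
def sample_real(n):
    return rng.normal(3.0, 1.0, size=n)

mu, s = 0.0, 1.0   # generator parameters: G(z) = mu + s * z, z ~ N(0, 1)
w, b = 0.0, 0.0    # discriminator parameters: D(x) = sigmoid(w * x + b)
lr, batch = 0.05, 64

def d_step():
    """Update only the discriminator; the generator's weights stay fixed."""
    global w, b
    x_real, z = sample_real(batch), rng.normal(size=batch)
    x_fake = mu + s * z                       # generator output (frozen here)
    d_real = sigmoid(w * x_real + b)
    d_fake = sigmoid(w * x_fake + b)
    # Gradients of the discriminator's binary cross-entropy loss
    grad_w = -np.mean((1 - d_real) * x_real) + np.mean(d_fake * x_fake)
    grad_b = -np.mean(1 - d_real) + np.mean(d_fake)
    w -= lr * grad_w
    b -= lr * grad_b

def g_step():
    """Update only the generator; the discriminator's weights stay fixed."""
    global mu, s
    z = rng.normal(size=batch)
    d_fake = sigmoid(w * (mu + s * z) + b)
    # Gradients of the non-saturating generator loss -log D(G(z))
    grad_mu = -np.mean((1 - d_fake) * w)
    grad_s = -np.mean((1 - d_fake) * w * z)
    mu -= lr * grad_mu
    s -= lr * grad_s

# Pretrain the discriminator briefly, then alternate the two updates.
for _ in range(50):
    d_step()
for _ in range(3000):
    d_step()
    g_step()

print(round(mu, 2), round(s, 2))  # mu should drift toward the real mean, 3.0
```

In frameworks like TensorFlow or PyTorch the same freezing is typically done by toggling which parameters the optimizer updates, rather than by writing gradients by hand.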

Recall that the generator network generates data that are similar to the real data, and to measure the similarity, we use objective functions. Both networks have their own objective functions, which they try to optimize during training. We train D to maximize the probability of assigning the correct label to both training instances and samples from the generator. We simultaneously train G to minimize the discriminator's reward. In other words, D and G play the two-player minimax game with the final objective function:

min_G max_D V(D, G) = Ex~Px[log D(x)] + Ez~Pz[log(1 − D(G(z)))]

In this equation, Ex represents the expected value over all real data instances, and Ez represents the expected value over all noise inputs z (and hence over the generated instances G(z)). Px represents the real data distribution, and Pz represents the distribution of the noise that is fed into the generator. D(x) represents the discriminator's output, that is, the probability that x came from the real data rather than from the generator, and G(z) represents the generator's output given noise z.

During training, the discriminator wants to maximize the objective function, whereas the generator wants to minimize it. In this way, the generator and the discriminator repeatedly learn to work together, and eventually the GAN can reach a Nash equilibrium when the following conditions are met:

- The discriminator is unable to distinguish between real and fake instances.
- The generator can produce fake samples that are indistinguishable from the real data in the training data set.
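At this equilibrium the discriminator outputs 1/2 for every input, so the value of the objective function works out to -2 log 2 ≈ -1.39; a quick check:

```python
import math

# When the discriminator cannot distinguish real from fake, D(x) = 0.5
# everywhere, so V(D, G) = Ex[log 0.5] + Ez[log(1 - 0.5)] = -2 * log 2.
d_out = 0.5
v = math.log(d_out) + math.log(1.0 - d_out)
print(v)  # -1.3862943611198906, i.e. -2 * log 2
```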

GANs offer quite a few advantages:

- GANs can be used to generate fake data with or without labels. If you have labeled data, you can use GANs to generate more synthetic labeled data, but if you have unlabeled data, you can use GANs to generate more synthetic unlabeled data.
- GANs generate data that are similar to real data. They can generate images, text, audio, video, and tabular data.
- GANs learn the internal representations of data. They can learn messy and complicated distributions of data.
- In adversarial networks, the generator network is not updated directly with data examples, but only with gradients flowing through the discriminator. This means that components of the input are not copied directly into the generator’s parameters.

There are also a number of challenges:

- Mode collapse: Usually, you expect your GAN to produce a wide variety of outputs given the input. However, if the discriminator is not powerful enough, the generator will find ways to easily trick it with a small set of nearly identical outputs. This form of GAN failure is known as *mode collapse*.
- Vanishing gradients: During backpropagation, the gradient flows backward from the final layer to the first layer. As it flows backward, it can become increasingly small. Sometimes the gradient is so small that the initial layers learn very slowly or stop learning completely. In this case, the gradient barely changes the weight values of the initial layers, so training of those layers effectively stops. This is known as the *vanishing gradients problem*.
- There is no explicit representation of the generator's distribution over the data; instead, samples are drawn by feeding random noise through the generator.
- GANs can require a lot of computational resources and can be slow to train, especially for high-resolution images or large data sets.
- GANs can reflect the biases and unfairness present in the training data, leading to discriminatory or biased synthetic data.
- GANs can be difficult to interpret or explain, making it challenging to ensure transparency, accountability, and fairness in their applications.
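The vanishing gradients problem listed above is easy to illustrate numerically. In the sketch below (an illustrative setup of mine, with weights fixed at 1 and sigmoid activations), the backpropagated gradient is multiplied by the sigmoid derivative, at most 0.25, at every layer, and so shrinks exponentially with depth:

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

# Backpropagate a gradient of 1.0 through 20 sigmoid layers with unit
# weights. Even at u = 0, where the sigmoid derivative is largest (0.25),
# the gradient reaching the first layer is vanishingly small.
grad = 1.0
for layer in range(20):
    u = 0.0                                  # pre-activation at each layer
    grad *= sigmoid(u) * (1 - sigmoid(u))    # multiply by derivative = 0.25
print(grad)  # 0.25**20, roughly 9.1e-13
```

This is one reason GAN discriminators often use activations and normalization schemes chosen specifically to keep gradients flowing to the generator.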

Read more in the post from Jason Colón, Generating Synthetic Data Using Generative Adversarial Networks.

Find more articles from SAS Global Enablement and Learning here.
