
Generative Adversarial Networks (GANs): A Brief Introduction

Started 11-28-2023 by
Modified 11-28-2023 by

Generative adversarial networks (GANs) are among the recent machine learning algorithms that have had a revolutionary impact on the AI world. In this post, we will discuss adversarial learning and the components of a GAN model. Before delving into GANs, let's first understand what generative models are and how they differ from discriminative models. A generative model is a type of machine learning model that aims to learn the true distribution of the training data so that it can generate new, similar data. Because it is not always possible to learn the exact distribution of your data, you instead try to model a distribution that is as close as possible to the true data distribution.

 

Generative versus Discriminative Models

 

In machine learning, most problems you come across are discriminative in nature, and the distinction between generative and discriminative models is fundamental. Let's consider an example to understand the difference. Suppose you have a dataset of images of airplanes and motor vehicles. With sufficient data, you could train a discriminative model to predict whether a given image is an airplane. During training, the model would learn the features specific to airplanes and, for images with those features, would upweight its predictions accordingly. Discriminative models also require labeled training data. In our example, all airplane images would be labeled 1 and all non-airplane images 0. The model is trained to discriminate between these two groups of images and outputs the probability that a new observation has label 1, i.e., is an airplane. In other words, discriminative modeling aims to model p(y | x), the probability of a label y given an observation x. Generative models, on the other hand, don't require a labeled dataset, because the aim is to generate entirely new images rather than to predict a label. Put simply, a generative model aims to model p(x), the probability of observing an observation x. Sampling from this distribution allows us to generate new observations.
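To make the distinction concrete, here is a minimal Python sketch on made-up one-dimensional data. The feature values, function names, and the choice of a Gaussian for p(x) and logistic regression for p(y | x) are illustrative assumptions, not part of any SAS tool. The generative part fits p(x) to unlabeled values and samples new ones; the discriminative part needs labels and learns p(y = 1 | x).

```python
import math
import random
import statistics

# Generative sketch: model p(x) for unlabeled 1-D data with a Gaussian
# (an illustrative assumption; real data are rarely this simple).
def fit_gaussian(xs):
    return statistics.mean(xs), statistics.stdev(xs)

def sample_new(mu, sigma, n, rng):
    # Sampling from the learned p(x) yields new, similar observations.
    return [rng.gauss(mu, sigma) for _ in range(n)]

# Discriminative sketch: model p(y = 1 | x) with logistic regression
# trained by plain gradient descent on labeled data.
def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def fit_logistic(xs, ys, lr=0.1, epochs=200):
    w, b = 0.0, 0.0
    for _ in range(epochs):
        gw = gb = 0.0
        for x, y in zip(xs, ys):
            err = sigmoid(w * x + b) - y  # gradient of cross-entropy w.r.t. the logit
            gw += err * x
            gb += err
        w -= lr * gw / len(xs)
        b -= lr * gb / len(xs)
    return w, b

rng = random.Random(0)
# Hypothetical 1-D "feature" values: label 1 ("airplane") near 2, label 0 near -2.
planes = [rng.gauss(2.0, 0.5) for _ in range(200)]
others = [rng.gauss(-2.0, 0.5) for _ in range(200)]

mu, sigma = fit_gaussian(planes)              # generative: no labels needed
new_planes = sample_new(mu, sigma, 5, rng)    # brand-new observations

w, b = fit_logistic(planes + others, [1] * 200 + [0] * 200)  # discriminative: labels required
print(new_planes)
print(sigmoid(w * 2.0 + b), sigmoid(w * -2.0 + b))  # p(airplane | x) for two inputs
```

Note that the generative model never sees the label list, while the discriminative model cannot be trained without it.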

 

Generative Adversarial Networks

 

Generative adversarial networks, or GANs, are deep-learning-based generative models used to generate new data. A GAN involves two adversarial neural networks that compete with each other, with the goal of generating new observations that are indistinguishable from real data. Both networks are typically deep neural networks. GANs work well for synthetic data generation, particularly for image data. You might be wondering, "How can networks be adversarial?" Well, the two networks, called the generator and the discriminator, compete with each other to win, hence the name adversarial. The generator learns to create fake samples of data, which can be images, audio, text, or simply numerical data. It tries to fool the discriminator by producing novel synthesized instances that ideally look like the real data. The discriminator evaluates the data it receives and tries to determine whether they are real or fake. The key to GANs lies in how we alternate the training of the two networks. As the generator becomes more adept at fooling the discriminator, the discriminator must adapt in order to maintain its ability to correctly identify which observations are fake. This drives the generator to find new ways to fool the discriminator, and so the cycle continues. Through many cycles of generation and discrimination, the two networks drive each other to improve even as they compete against each other.

 

Components of GANs

 

Generator Network: The generator network tries to learn the data distribution by taking random noise as input and producing instances that look like real data. The main goal of the generator network is to maximize the likelihood that the discriminator misclassifies its output as real.
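As a hedged illustration, the sketch below treats a generator as nothing more than a function that turns noise into data-like samples. It assumes a standard normal noise prior and a hypothetical two-parameter affine generator G(z) = a*z + b; real GAN generators are deep networks with many more parameters.

```python
import random

class Generator:
    """Hypothetical 1-D generator G(z) = a*z + b with two learnable parameters."""

    def __init__(self, a=1.0, b=0.0):
        self.a = a
        self.b = b

    def __call__(self, z):
        # Map one noise value to one synthetic data value.
        return self.a * z + self.b

def sample_noise(n, rng):
    # Noise prior: standard normal, a common (assumed) choice.
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

rng = random.Random(42)
G = Generator(a=0.5, b=3.0)
fakes = [G(z) for z in sample_noise(1000, rng)]
print(sum(fakes) / len(fakes))  # sample mean lands near b = 3.0
```

Training, described next, is what adjusts a and b so that these fakes start to resemble the real data.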

 

[Figure: Generator Training Process]

 


 

In the generator training process, the fake samples produced by the generator are passed to the discriminator, which classifies them as real or fake. The discriminator's output is then used to compute the generator loss, which penalizes the generator for failing to fool the discriminator. Remember, the main goal of the generator is to get the discriminator to classify its output as real. Backpropagation then calculates each weight's impact on the output: gradients flow from the generator loss back through the discriminator into the generator, and the generator's weights are adjusted accordingly.
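Here is a minimal sketch of one generator update, assuming a toy affine generator x = a*z + b and a hypothetical logistic discriminator D(x) = sigmoid(w*x + c). It uses the non-saturating generator loss -log D(G(z)) suggested in the original GAN paper by Goodfellow et al.; that choice is an assumption, since this post does not specify the exact loss. The gradient is carried back through the discriminator by the chain rule, but only the generator's parameters change.

```python
import math
import random

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def generator_loss(a, b, w, c, zs):
    """Mean non-saturating generator loss -log D(G(z)) over a batch of noise."""
    return sum(-math.log(sigmoid(w * (a * z + b) + c)) for z in zs) / len(zs)

def generator_step(a, b, w, c, zs, lr=0.05):
    """One gradient step on the generator (a, b); the discriminator (w, c) is only read.

    Generator: x = a*z + b.  Discriminator: D(x) = sigmoid(w*x + c).
    """
    grad_a = grad_b = 0.0
    for z in zs:
        x = a * z + b
        d = sigmoid(w * x + c)
        dloss_dx = -(1.0 - d) * w   # chain rule back through the frozen discriminator
        grad_a += dloss_dx * z      # dx/da = z
        grad_b += dloss_dx          # dx/db = 1
    n = len(zs)
    return a - lr * grad_a / n, b - lr * grad_b / n

rng = random.Random(1)
zs = [rng.gauss(0.0, 1.0) for _ in range(256)]
w, c = 2.0, -4.0   # hypothetical frozen discriminator that favors x above 2
a, b = 0.5, 0.0    # generator currently produces samples near 0
before = generator_loss(a, b, w, c, zs)
a, b = generator_step(a, b, w, c, zs)
after = generator_loss(a, b, w, c, zs)
print(before, after)  # the loss drops as the fakes shift toward where D says "real"
```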

 

Discriminator Network: The discriminator network tries to differentiate between the fake data produced by the generator network and real data. Thus, the discriminator network is simply a classifier that could use any network architecture appropriate to the type of data that it's classifying.

 

The training data fed into the discriminator network comes from two sources:

  • real data instances
  • fake data instances generated by the generator network

The discriminator network predicts how likely each instance is to be real or fake. So, the main goal of the discriminator network is to accurately distinguish between real and fake data.
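A minimal sketch of how a discriminator's predictions can be scored, assuming the common binary cross-entropy formulation with label 1 for real instances and label 0 for fake ones. The function name and the eps clipping are illustrative choices, not a specific library's API.

```python
import math

def discriminator_loss(d_real, d_fake, eps=1e-12):
    """Mean binary cross-entropy over a batch of discriminator outputs.

    d_real: D's predicted probabilities for real instances (target label 1).
    d_fake: D's predicted probabilities for generated instances (target label 0).
    eps guards against log(0) for fully confident predictions.
    """
    real_term = sum(-math.log(max(p, eps)) for p in d_real)
    fake_term = sum(-math.log(max(1.0 - p, eps)) for p in d_fake)
    return (real_term + fake_term) / (len(d_real) + len(d_fake))

# A discriminator that gets both sources right has a low loss...
good = discriminator_loss([0.9, 0.95], [0.1, 0.05])
# ...while one that mixes them up is penalized heavily.
bad = discriminator_loss([0.2, 0.3], [0.8, 0.7])
print(good, bad)
```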

 

 

[Figure: Discriminator Training Process]

 

When training the discriminator network, it classifies both the real data and the fake data from the generator. Notice that the discriminator connects to two loss functions. During discriminator training, however, the discriminator ignores the generator loss and uses only the discriminator loss. The discriminator loss penalizes the discriminator for misclassifying a real data instance as fake or a fake data instance as real. The discriminator updates its weights by backpropagating the discriminator loss through the discriminator network.
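The sketch below shows this phase for a hypothetical logistic discriminator D(x) = sigmoid(w*x + c) trained with binary cross-entropy. The real and fake values are stand-ins drawn from made-up normal distributions; the point is that the fake samples arrive as fixed numbers, so nothing about the generator is updated here.

```python
import math
import random

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def discriminator_step(w, c, real_xs, fake_xs, lr=0.5):
    """One gradient step on the discriminator D(x) = sigmoid(w*x + c).

    fake_xs are plain numbers already produced by the generator, so no gradient
    reaches the generator: its weights stay fixed during this phase.
    """
    labeled = [(x, 1.0) for x in real_xs] + [(x, 0.0) for x in fake_xs]
    grad_w = grad_c = 0.0
    for x, label in labeled:
        err = sigmoid(w * x + c) - label  # derivative of binary cross-entropy w.r.t. the logit
        grad_w += err * x
        grad_c += err
    n = len(labeled)
    return w - lr * grad_w / n, c - lr * grad_c / n

rng = random.Random(0)
w, c = 0.0, 0.0
for _ in range(1000):
    real = [rng.gauss(3.0, 0.5) for _ in range(64)]  # stand-in for real data
    fake = [rng.gauss(0.0, 0.5) for _ in range(64)]  # stand-in for a frozen generator's output
    w, c = discriminator_step(w, c, real, fake)
print(sigmoid(w * 3.0 + c), sigmoid(c))  # D(3) near 1, D(0) near 0
```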

 

GAN Training

 

We have seen that the generator and discriminator have different training processes. So, you may be wondering, "How are GANs trained as a whole?" Well, the two networks are trained in an alternating fashion. Note that the discriminator needs to train for a few epochs before adversarial training starts, because it must first be able to actually classify the data (images) as real or fake.

 

While alternating the training of these two networks, we must also make sure that we update the weights of only one network at a time. For example, during the generator training phase, only the generator's weights are updated. Similarly, we keep the generator's weights constant during the discriminator training phase and update only the discriminator's weights. Because discriminator training tries to figure out how to distinguish real data from fake data, the discriminator has to learn to recognize the generator's flaws. That is a different problem against a thoroughly trained generator than against an untrained generator that produces random output. This back-and-forth training enables GANs to tackle otherwise intractable generative problems.
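Putting the two phases together, here is a hedged end-to-end sketch of alternating training on a deliberately tiny 1-D problem. Everything in it is an illustrative assumption: real data drawn from N(1, 1), a one-parameter generator G(z) = z + b, a logistic discriminator, five discriminator steps per generator step, and a short discriminator warm-up before the adversarial loop begins. Each phase reads the other network's parameters but updates only its own.

```python
import math
import random

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def train_toy_gan(real_mean=1.0, pretrain_steps=50, iterations=500,
                  d_steps=5, batch=64, lr_d=0.2, lr_g=0.1, seed=0):
    """Alternating GAN training on a 1-D toy problem (illustrative assumptions).

    Real data: N(real_mean, 1). Generator: G(z) = z + b with z ~ N(0, 1).
    Discriminator: D(x) = sigmoid(w*x + c). Returns (b, w, c).
    """
    rng = random.Random(seed)
    b = -3.0          # generator parameter: starts far from the real data
    w, c = 0.0, 0.0   # discriminator parameters

    def d_update():
        # Discriminator phase: b is only read, never updated.
        nonlocal w, c
        real = [rng.gauss(real_mean, 1.0) for _ in range(batch)]
        fake = [rng.gauss(0.0, 1.0) + b for _ in range(batch)]
        labeled = [(x, 1.0) for x in real] + [(x, 0.0) for x in fake]
        gw = gc = 0.0
        for x, label in labeled:
            err = sigmoid(w * x + c) - label  # d(BCE)/d(logit)
            gw += err * x
            gc += err
        w -= lr_d * gw / len(labeled)
        c -= lr_d * gc / len(labeled)

    def g_update():
        # Generator phase: w and c are only read, never updated.
        nonlocal b
        gb = 0.0
        for _ in range(batch):
            x = rng.gauss(0.0, 1.0) + b
            gb += -(1.0 - sigmoid(w * x + c)) * w  # d/db of -log D(G(z))
        b -= lr_g * gb / batch

    for _ in range(pretrain_steps):   # warm up D so it can classify before the game starts
        d_update()
    for _ in range(iterations):       # then alternate: several D steps, one G step
        for _ in range(d_steps):
            d_update()
        g_update()
    return b, w, c

b, w, c = train_toy_gan()
print(f"generator shift b = {b:.2f} (real mean 1.0), D weight w = {w:.3f}")
```

Because this toy generator can only shift its output and both distributions share the same spread, a logistic discriminator is enough to tell them apart; when b reaches the real mean, w shrinks toward 0 and D can no longer separate real from fake. With realistic data, both networks would be deep and each step would draw minibatches from a dataset.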

 

Objective Function

 

Recall that the generator network aims to generate data similar to the real data; to measure how well each network is doing, we use objective functions. Both networks have their own objectives, which they try to optimize during training. We train the discriminator D to maximize the probability of assigning the correct label to both real training instances and samples from the generator G. We simultaneously train G to minimize log(1 - D(G(z))), that is, to minimize the discriminator's reward for recognizing its samples as fake. In other words, D and G play a two-player minimax game with the following value function V(D, G):

 

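$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_x}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]$$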

 

In this equation, E_x represents the expected value over real data instances x drawn from p_x, the real data distribution. E_z represents the expected value over noise vectors z drawn from p_z, the distribution of the random noise fed into the generator; it is therefore also an expectation over the generated fake instances G(z). D(x) represents the discriminator's output, that is, its estimated probability that x came from the real training data rather than from the generator, and G(z) represents the generator's output given noise z.
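The value function can be estimated with Monte Carlo averages over samples. The sketch below is a hedged illustration with made-up one-dimensional choices: stand-in real samples from N(2, 1), a noise prior N(0, 1), and hypothetical generators and discriminators written as plain functions. When the generator matches the real data and D outputs 1/2 everywhere, V equals log(1/2) + log(1/2) = -log 4, the value of the game at its theoretical optimum.

```python
import math
import random

def value_function(D, G, real_samples, noise_samples):
    """Monte Carlo estimate of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]."""
    real_term = sum(math.log(D(x)) for x in real_samples) / len(real_samples)
    fake_term = sum(math.log(1.0 - D(G(z))) for z in noise_samples) / len(noise_samples)
    return real_term + fake_term

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

rng = random.Random(0)
real = [rng.gauss(2.0, 1.0) for _ in range(1000)]   # stand-in samples from p_x
noise = [rng.gauss(0.0, 1.0) for _ in range(1000)]  # samples from the noise prior p_z

def G_match(z):   # hypothetical generator whose outputs follow N(2, 1), like p_x
    return z + 2.0

def G_poor(z):    # hypothetical generator centered in the wrong place
    return z - 2.0

def D_chance(x):  # cannot tell real from fake
    return 0.5

def D_sharp(x):   # log-odds 4x: the optimal D for N(2, 1) versus N(-2, 1)
    return sigmoid(4.0 * x)

print(value_function(D_chance, G_match, real, noise))  # -log 4, about -1.386
print(value_function(D_sharp, G_poor, real, noise))    # much closer to 0
```

A discriminator that easily separates a poor generator's samples pushes V up toward 0, which is exactly what D wants (maximize) and what G tries to prevent (minimize).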

 

During training, the discriminator wants to maximize the objective function, whereas the generator wants to minimize it. In this way, the generator and the discriminator repeatedly respond to each other's improvements, and ideally the GAN reaches a Nash equilibrium, in which the following conditions are met:

 

  • The discriminator is unable to distinguish between real and fake instances.
  • The generator can produce fake samples that are indistinguishable from the real data in the training data set.
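The original GAN paper (Goodfellow et al., 2014) shows that for a fixed generator with distribution p_g, the optimal discriminator is D*(x) = p_x(x) / (p_x(x) + p_g(x)). The sketch below evaluates that formula for made-up normal distributions using Python's statistics.NormalDist; the distributions are assumptions chosen for illustration. When p_g equals p_x, D* is exactly 1/2 everywhere, which is the "unable to distinguish" condition listed above.

```python
from statistics import NormalDist

def optimal_discriminator(p_data, p_gen):
    """Return D*(x) = p_data(x) / (p_data(x) + p_gen(x)) for a fixed generator."""
    def D(x):
        real_density = p_data.pdf(x)
        return real_density / (real_density + p_gen.pdf(x))
    return D

p_x = NormalDist(mu=2.0, sigma=1.0)                                  # real data distribution
D_early = optimal_discriminator(p_x, NormalDist(mu=-2.0, sigma=1.0))  # generator far off
D_equilibrium = optimal_discriminator(p_x, NormalDist(mu=2.0, sigma=1.0))  # generator matches p_x

print(D_early(2.0), D_early(-2.0))              # close to 1 and close to 0
print(D_equilibrium(2.0), D_equilibrium(-5.0))  # 0.5 and 0.5
```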

 

Advantages of GANs

 

GANs offer quite a few advantages:

 

  • GANs can be used to generate fake data with or without labels. If you have labeled data, you can use GANs to generate more synthetic labeled data; if you have unlabeled data, you can use GANs to generate more synthetic unlabeled data.
  • GANs generate data that are similar to real data. They can generate images, text, audio, video, and tabular data.
  • GANs learn the internal representations of data. They can learn messy and complicated distributions of data.
  • In adversarial networks, the generator network is not updated directly with data examples, but only with gradients flowing through the discriminator. This means that components of the input are not copied directly into the generator’s parameters.

 

Common Challenges

 

There are also a number of challenges:

 

  • Mode collapse: Usually, you expect your GAN to produce a wide variety of outputs for different noise inputs. However, if the discriminator is not powerful enough, the generator may find that it can easily trick the discriminator with a small set of nearly identical outputs. This form of GAN failure is known as mode collapse.
  • Vanishing gradients: During backpropagation, the gradient flows backward from the final layer to the first layer, and it can become smaller as it goes. Sometimes the gradient becomes so small that the initial layers learn very slowly or stop learning completely: their weights barely change, so training of the initial layers effectively stops. This is known as the vanishing gradients problem. In GANs it can also arise when the discriminator becomes too good: if it rejects generated samples with high confidence, the generator loss log(1 - D(G(z))) saturates and passes almost no gradient back to the generator.
  • There is no explicit representation of the generator's distribution over the data. You can only draw samples from it by feeding random noise through the generator, so the likelihood of a given data point cannot be evaluated directly.
  • GANs can require a lot of computational resources and can be slow to train, especially for high-resolution images or large data sets.
  • GANs can reflect the biases and unfairness present in the training data, leading to discriminatory or biased synthetic data.
  • GANs can be difficult to interpret or explain, making it challenging to ensure transparency, accountability, and fairness in their applications.
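One GAN-specific source of vanishing gradients, noted in the original GAN paper, is a discriminator that rejects fakes too confidently. The sketch below compares gradients with respect to the discriminator's logit s (assuming D = sigmoid(s), an illustrative choice) for the minimax generator loss log(1 - D(G(z))) and for the non-saturating alternative -log D(G(z)) that the paper recommends in practice. As s becomes very negative, meaning D is sure the sample is fake, the first gradient vanishes while the second stays near -1.

```python
import math

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def grad_saturating(logit):
    # d/ds of log(1 - sigmoid(s)), the term the generator minimizes in the minimax game.
    return -sigmoid(logit)

def grad_non_saturating(logit):
    # d/ds of -log(sigmoid(s)), the alternative generator loss.
    return -(1.0 - sigmoid(logit))

for s in [0.0, -5.0, -10.0]:  # more negative logit = D more confident the sample is fake
    print(s, grad_saturating(s), grad_non_saturating(s))
```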

Read more in the post from Jason Colón, Generating Synthetic Data Using Generative Adversarial Networks.

 

 

Find more articles from SAS Global Enablement and Learning here.

Version history
Last update:
‎11-28-2023 10:21 AM
Updated by:
