In Part 1 on GANs, we started to build intuition regarding what GANs are, why we need them, and how the entire point behind training GANs is to create a generator model that knows how to convert a random noise vector into a (beautiful) almost real image. Since we have already discussed the pseudocode in great depth in Part 1, be sure to check that out as there will be a lot of references to it!
In case you would like to follow along, here is the Github Notebook containing the source code for training GANs using the PyTorch framework.
The whole idea behind training a GAN network is to obtain a Generator network (with the most optimal model weights, layers, etc.) that is excellent at spewing out fakes that look real!
Note: I would like to take a moment to truly appreciate Nathan Inkawhich for writing a superb article explaining the inner workings of DCGANs and the official Github repository for Pytorch that helped me with the code implementations, especially the network architectures for both Generator and Discriminator. Hopefully, the explanations I have presented in this article help you gain even further clarity (than already present in the aforementioned blogs) regarding GANs and implement them even better for your own use-case!
Preparing the image dataset
One of the main reasons I started writing this article was because I wanted to try coding GANs on a custom image dataset. Most tutorials I came across were using one of the popular datasets (such as MNIST, CIFAR-10, Celeb-A, etc) that come pre-installed into the framework and ready to be used out-of-the-box. Instead, we will be working with the Human Faces dataset available on Kaggle, containing approximately 7k images — with a wide variety of side/frontal poses, age groups, gender, etc — that were scraped from the web.
After unzipping and loading the image folder (called Humans) into your current working directory, let's start writing our code in the Jupyter Notebook:
import os
import numpy as np
from PIL import Image

# path to the image directory
dir_data = "Humans"

# setting image shape to 32x32
img_shape = (32, 32, 3)

# listing out all file names
nm_imgs = np.sort(os.listdir(dir_data))
We have downsized our high-def images to a smaller resolution, i.e. 32×32, for ease of processing. Once you are through with this tutorial, you are free to re-run the code after changing the img_shape parameter and a few other things (which we will discuss towards the end of the article).
Next, we will convert all our images into NumPy arrays and store them collectively in X_train. Also, it is always a good idea to explicitly convert images into RGB format (just in case an image looks grayscale but, in reality, isn't!).
X_train = []
for file in nm_imgs:
    try:
        img = Image.open(dir_data + '/' + file)
        img = img.convert('RGB')
        img = img.resize((32, 32))
        img = np.asarray(img) / 255
        X_train.append(img)
    except:
        print("something went wrong")

X_train = np.array(X_train)
X_train.shape
************* OUTPUT ***********
(7218, 32, 32, 3)
Beware: there are some files that do not contain a valid image (or are corrupted), and that's why I tend to use try-except blocks while coding.
Note: The process of converting images to their respective NumPy format might take a while. Hence, it is a good idea to store X_train locally as a .npz file for future use. To do so:
from numpy import savez_compressed

# save to npz file
savez_compressed('kaggle_images_32x32.npz', X_train)
To re-load the file at a later time:
# load dict of arrays
dict_data = np.load('kaggle_images_32x32.npz')

# extract the first array
data = dict_data['arr_0']

# print the array
print(data)
Importing libraries
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.nn import Module, Sequential, Conv2d, ConvTranspose2d, LeakyReLU, BatchNorm2d, ReLU, Tanh, Sigmoid, BCELoss

%matplotlib inline
Setting up the GPU support
While the code we are working on today will run on both CPU and GPU, it is always advisable to check availability and use the latter, when possible.
# Always good to check if gpu support is available or not
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)
Using torch.cuda.is_available(), we check if a GPU is available and, if so, set it as the device using the torch.device function.
Defining helper function
We mainly require one helper function, plot_images(), that takes as input a NumPy array of images and displays them in a 5×5 grid.
# plot images in an n x n grid
def plot_images(imgs, grid_size=5):
    """
    imgs: vector containing all the numpy images
    grid_size: 2x2 or 5x5 grid containing images
    """
    fig = plt.figure(figsize=(8, 8))
    columns = rows = grid_size
    plt.title("Training Images")

    for i in range(1, columns * rows + 1):
        fig.add_subplot(rows, columns, i)
        plt.axis("off")
        plt.imshow(imgs[i])

    plt.show()
In order to lay out images in the form of a grid, we add a subplot to our plotting area for each image we want to display. Using fig.add_subplot, the subplot takes the ith position on a grid with r rows and c columns. Finally, the entire grid is displayed using plt.show().
To see if our function works as intended, we can display a few images from our training set:
# load the numpy vector containing image representations
imgs = np.load('kaggle_images_32x32.npz')

# pls ignore the poor quality of the images as we are working with 32x32 sized images
plot_images(imgs['arr_0'], 3)
Note: Kindly ignore the distorted quality (as compared to the original data we saw on Kaggle) since we are working with 32×32 images instead of superior resolution!
Preparing custom dataset class
I know what you’re thinking — why do I need to create a special class for my dataset? What’s wrong with using my dataset as is?
Well, the simple answer is — that's just how PyTorch likes it! For a detailed answer, you can read this article here, which nicely explains how to use the torch.utils.data.Dataset class in PyTorch to create a custom Dataset object for any dataset.
At a very basic level, the Dataset class you extend for your own dataset should have __init__, __len__, and __getitem__ methods.
In case you need further help creating the dataset class, do check out the PyTorch documentation here.
class HumanFacesDataset(Dataset):
    """Human Faces dataset."""

    def __init__(self, npz_imgs):
        """
        Args:
            npz_imgs (string): npz file with all the images (created in gan.ipynb)
        """
        self.imgs = npz_imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.imgs[idx]
        return image
Creating Dataloader to load images
Again — why the heck do I need a DataLoader?
Of the many reasons listed on the documentation page, such as customizing the data loading order and automatic memory pinning, the DataLoader is essentially useful for creating batches (for both the train and test sets) to be sent as input to the model.
This is how a DataLoader is defined in PyTorch:
dataloader = DataLoader(dataset = dset, batch_size = batch_size, shuffle = shuffle)
dset is basically an object of the HumanFacesDataset class we created earlier. Let's define it, along with our batch_size and shuffle variables.
# Preparing dataloader for training
transpose_imgs = np.transpose(
    np.float32(imgs['arr_0']),  # imp step to convert double -> float (by default numpy uses double as its data type)
    (0, 3, 1, 2)                # tuple describing how to rearrange the dimensions: (7218, 32, 32, 3) -> (7218, 3, 32, 32)
)

dset = HumanFacesDataset(transpose_imgs)  # passing the transposed images to the class constructor

batch_size = 32
shuffle = True

dataloader = DataLoader(dataset=dset, batch_size=batch_size, shuffle=shuffle)
An important thing to note is that while creating the constructor for the class, we do not simply pass the image array to it. Instead, we pass transpose_imgs, which has some computations performed on our original image array. Mainly, we are going to:
- (a) explicitly set the image representations to float using np.float32() — this is because by default NumPy uses double as its data type (you can verify this for a single image using imgs['arr_0'][0].dtype, and the output would be float64), whereas the model we will be creating will have its weights as float32; and
- (b) reorder the dimensions of each image from (32 x 32 x 3) to (3 x 32 x 32) using np.transpose(), because that's how the layers in PyTorch models expect them to be.
Finally, we opt for a small batch_size and set shuffle to True to rule out the possibility of any bias at the time of data collection.
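As a quick sanity check (my own addition, not in the original notebook), you can confirm the dtype, the transposed shape, and the shape of a single batch pulled from the dataloader:

# optional sanity check on dtype, array shape and batch shape
print(transpose_imgs.dtype)    # expected: float32
print(transpose_imgs.shape)    # expected: (7218, 3, 32, 32)

sample_batch = next(iter(dataloader))
print(sample_batch.shape)      # expected: torch.Size([32, 3, 32, 32])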
Defining the Generator class
As discussed in Part 1, the Generator is a neural network that is trying to produce (hopefully) realistic-looking images. To do so, it takes as input a random noise vector z (say a vector of size 100; where the choice of 100 is arbitrary), passes it through several hidden layers in the network, and finally outputs an RGB image with the same size as the training images i.e. a tensor of shape (3, 32, 32).
Frankly speaking, it took me a while to grasp the precise permutation and combination of layers (and associated parameter values) that should go into my Generator class. To lay it down in simpler terms for you, we will be working with the following layers only:
- ConvTranspose2d is the MVP that is going to help you upsample your random noise to create an image i.e. going from 100x1x1 to 3x32x32.
From the documentation, one can see that it takes the form:
ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias)
- BatchNorm2d layer, as the name suggests, is used for applying batch normalization over the input. Going by the documentation, it takes as input the number of features, or num_features, which can be easily calculated based on the shape of the output from the preceding layer. Mainly, its value is C if the expected input is of size (N, C, H, W). For instance, if the shape of the output from the previous layer is (batch_size, 512, 4, 4), then num_features = 512.
- Tanh is another activation function that is applied at the very end of the Generator network to transform the input into the [-1, 1] range.
Finally, the Generator class with all the aforementioned layers would look something like this. In case you need a beginner’s guide on how to create networks in Pytorch, check out this article here.
At a very basic level, the Module class you extend for your own network model should have __init__ and forward methods.
class Generator(Module):
    def __init__(self):
        # calling constructor of parent class
        super().__init__()

        self.gen = Sequential(
            ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),
            # the output from the above will be b_size, 512, 4, 4
            BatchNorm2d(num_features=512),  # from an input of size (b_size, C, H, W), pick num_features = C
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            # the output from the above will be b_size, 256, 8, 8
            BatchNorm2d(num_features=256),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            # the output from the above will be b_size, 128, 16, 16
            BatchNorm2d(num_features=128),
            ReLU(inplace=True),

            ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            # the output from the above will be b_size, 3, 32, 32
            Tanh()
        )

    def forward(self, input):
        return self.gen(input)
Note: It is very important to understand what is going on under the hood of this network, because if you want to work with images of, say, 64×64 or 128×128, this architecture (mainly the associated parameters) must be updated!
To begin with, we are creating a Sequential model, which is a linear stack of layers. That is, the output from each layer acts as the input for the next layer.
In the first convolution, we begin with a ConvTranspose2d layer that takes 100 in_channels. Why 100, you might ask? This is because the input to the Generator network is going to be something like (batch_size, 100, 1, 1), which according to PyTorch roughly translates to a 1×1 image with 100 channels. Consequently, these many channels will go into the ConvTranspose2d layer, and so in_channels = 100.
The logic behind setting out_channels to 512 is completely arbitrary, something I picked up from the tutorials/blogs I mentioned in the beginning. The idea is to pick a large number for out_channels in the beginning and subsequently reduce it (by a factor of 2) for each ConvTranspose2d layer, until you reach the very last layer where you can set out_channels = 3, which is the precise number of channels we require to generate an RGB image of size 32×32.
Now, the output from this layer will have a shape of (b_size, out_channels, height, width), where the height and width can be calculated according to the formula given on the documentation page:

H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

(and analogously for W_out).
Plugging in the respective values in the formula above we get:
H_out = (1 - 1) * 1 - 2 * 0 + 1 * (4 - 1) + 0 + 1 = 4
W_out = (1 - 1) * 1 - 2 * 0 + 1 * (4 - 1) + 0 + 1 = 4
And that is what you see as the spatial size (written as comments in the code above):
ConvTranspose2d(in_channels = 100, out_channels = 512, kernel_size = 4, stride = 1, padding = 0, bias = False), # the output from the above will be b_size ,512, 4,4
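If you would rather script these calculations, here is a small helper of my own (the function name is mine, not from the original notebook) that implements the same formula:

# helper to compute the spatial output size of a ConvTranspose2d layer (same formula as the PyTorch docs)
def conv_transpose2d_out_size(h_in, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

print(conv_transpose2d_out_size(1, kernel_size=4, stride=1, padding=0))  # first Generator layer: 1 -> 4
print(conv_transpose2d_out_size(4, kernel_size=4, stride=2, padding=1))  # second Generator layer: 4 -> 8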
Now if you are not a mathematics wizard or feeling a little lazy to do the above calculations, you can even check the output from a layer by simply creating a dummy Generator network and passing it any random input.
For instance, we will be creating a dummy network with just the first convolutional layer:
class ExampleGenerator(Module):
    def __init__(self):
        # calling constructor of parent class
        super().__init__()
        self.gen = Sequential(
            ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False)
        )

    def forward(self, input):
        return self.gen(input)

# defining class object
ex_gen = ExampleGenerator()

# defining random input for the model: b_size = 2 here
t = torch.randn(2, 100, 1, 1)

# checking the shape of the output from model
ex_gen(t).shape
************* OUTPUT ***********
torch.Size([2, 512, 4, 4])
The next layer we come across is BatchNorm2d. Now, if you have been following the tutorial carefully, it should be amply clear why num_features is set to 512. To recap, it's because the output from the previous layer has a size of (b_size, 512, 4, 4).
Finally, we conclude the first (of the four) convolution with a ReLU activation.
The rest of the three convolutions follow the same pattern, more or less. I strongly encourage you to test out the calculations by hand to see how the spatial size of the input changes when it passes through a layer. This will help you set the values correctly for in_channels, out_channels, stride, kernel_size, etc., in your Generator and Discriminator networks when you are working with a differently-sized image dataset (i.e. something other than 32x32x3).
An important thing to note here is that BatchNorm2d, ReLU, and Tanh layers do not alter the spatial size of the input, and that is why the second ConvTranspose2d layer in the network begins with in_channels = 512.
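If you would rather not do all four stages by hand, one quick optional check of my own is to push a dummy noise batch through the Generator's Sequential stack one layer at a time and print the shape after each layer:

# optional: print the output shape after every layer of the Generator
x = torch.randn(2, 100, 1, 1)  # dummy noise batch with b_size = 2
for layer in Generator().gen:
    x = layer(x)
    print(layer.__class__.__name__, tuple(x.shape))
# the four ConvTranspose2d lines should read (2, 512, 4, 4), (2, 256, 8, 8), (2, 128, 16, 16) and (2, 3, 32, 32)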
Defining the Discriminator class
As discussed in Part 1, the Discriminator is essentially a binary classification network that takes as input an image and returns a scalar probability that the input image is real (as opposed to fake).
The main layers involved in a Discriminator network are as follows:
- Conv2d: as opposed to the ConvTranspose2d layer, which helps in upscaling an image, a Conv2d layer helps in downscaling an image, i.e. reducing an image of size 32×32 to 16×16 to 8×8 and so on, continuing all the way until we are left with 1×1, i.e. a scalar value.
- LeakyReLU: a major advantage of using a LeakyReLU layer over a ReLU layer is that it helps mitigate the vanishing gradient problem. In simpler terms, when the input is negative, a ReLU layer will output 0, whereas LeakyReLU will output a small non-zero value. Consequently, LeakyReLU will contribute a small gradient update (instead of a zero gradient) when the input is negative, so the model can continue to learn and be updated (see the short check right after this list).
- Sigmoid: this is another activation layer through which we pass our inputs to transform our data into the [0, 1] range.
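As a tiny illustration of that difference (my own snippet, using the same 0.2 negative slope as the Discriminator below):

x = torch.tensor([-1.0, 0.5])
print(ReLU()(x))            # tensor([0.0000, 0.5000]) -- negative input is zeroed out
print(LeakyReLU(0.2)(x))    # tensor([-0.2000, 0.5000]) -- negative input keeps a small non-zero value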
Finally, this is what the Discriminator class looks like:
# Defining the Discriminator class
class Discriminator(Module):
    def __init__(self):
        super().__init__()
        self.dis = Sequential(

            # input is (3, 32, 32)
            Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 32, 16, 16
            LeakyReLU(0.2, inplace=True),

            Conv2d(in_channels=32, out_channels=32*2, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 32*2, 8, 8
            BatchNorm2d(32 * 2),
            LeakyReLU(0.2, inplace=True),

            Conv2d(in_channels=32*2, out_channels=32*4, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 32*4, 4, 4
            BatchNorm2d(32 * 4),
            LeakyReLU(0.2, inplace=True),

            Conv2d(in_channels=32*4, out_channels=32*8, kernel_size=4, stride=2, padding=1, bias=False),
            # output from above layer is b_size, 256, 2, 2
            # NOTE: spatial size of this layer is 2x2, hence in the final layer, the kernel size must be 2 (or smaller) instead of 4
            BatchNorm2d(32 * 8),
            LeakyReLU(0.2, inplace=True),

            Conv2d(in_channels=32*8, out_channels=1, kernel_size=2, stride=2, padding=0, bias=False),
            # output from above layer is b_size, 1, 1, 1
            Sigmoid()
        )

    def forward(self, input):
        return self.dis(input)
It might look very similar to the Generator network and, in some sense, it is. That is, we are again working with a Sequential network built out of strided convolutions. However, we are now dealing with Conv2d layers instead of ConvTranspose2d layers. Within each of them, we set out_channels to a small value initially and gradually increase it by a factor of 2, while the spatial size shrinks until we reach the desired 1×1 output (i.e. H_out = W_out = 1), at which point we set out_channels = 1.
An important thing to note here is that the shape of the input going into the last convolution layer is (b_size, 256, 2, 2). Because the kernel_size cannot be larger than the spatial size of the input (in this case 2×2, given padding = 0), we must set kernel_size = 2 for the last layer (as opposed to kernel_size = 4 in the previous layers). Failure to do so will result in a runtime error!
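For reference, the spatial output size of a Conv2d layer follows the standard formula from the PyTorch docs; a small helper of my own, mirroring the ConvTranspose2d one above, lets you check each Discriminator stage:

import math

# helper to compute the spatial output size of a Conv2d layer (same formula as the PyTorch docs)
def conv2d_out_size(h_in, kernel_size, stride=1, padding=0, dilation=1):
    return math.floor((h_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

print(conv2d_out_size(32, kernel_size=4, stride=2, padding=1))  # first Discriminator layer: 32 -> 16
print(conv2d_out_size(2, kernel_size=2, stride=2, padding=0))   # last Discriminator layer:  2 -> 1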
You might be wondering what happened to the scalar value that was promised as the output from the Discriminator, whereas what we have here is a tensor of shape (b_size, 1, 1, 1) (i.e. the output from the final layer). The good news is that we can easily convert this to a single vector containing b_size values using view(-1). For instance, t.view(-1) reshapes a 4-d tensor t of shape (2, 1, 1, 1) to a 1-d tensor with 2 values only. We will see its usage in action in the later sections!
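Here is a minimal illustration of that flattening step (again, my own snippet rather than part of the original notebook):

t = torch.zeros(2, 1, 1, 1)     # stand-in for the Discriminator output on a batch of 2 images
print(t.shape)                  # torch.Size([2, 1, 1, 1])
print(t.view(-1).shape)         # torch.Size([2])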
Now that we have defined the classes for both networks, we can initialize an object for each of them.
# creating gen and disc
netG = Generator().to(device)
netD = Discriminator().to(device)
Initializing weights and biases
It is essential to initialize a neural network with random weights rather than letting them all be 0. That’s because all neurons with the same initial weight will learn the same features during training i.e. during subsequent iterations weights will be the same. In short, no improvements to the model whatsoever!
Based on the several blog posts I came across, the weights for the ConvTranspose2d layers are randomly initialized from a Normal distribution with mean = 0 and standard deviation = 0.02. For the BatchNorm2d layers, the mean and standard deviation of the distribution are 1 and 0.02, respectively. This applies to both the Generator and Discriminator networks.
To simultaneously initialize all the different layers in a network, we need to:
(a) define a function that takes as input a model
def init_weights(m):
    if type(m) == ConvTranspose2d:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif type(m) == BatchNorm2d:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
(b) then use .apply() to recursively initialize all layers
# initializing the weights
netD.apply(init_weights)
netG.apply(init_weights)
Setting up the optimizers for Generator and Discriminator
Optimizers are useful for performing parameter updates in a network using the optimizer.step() method.
# Setting up optimizers for both Generator and Discriminator
opt_D = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_G = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
Setting up the Loss function
In order to check how far the predicted label for an image is from the real label, we will be using BCELoss.
# Setting up the loss function - BCELoss (to check how far the predicted value is from the real value)
loss = BCELoss()
Training GANs
Pseudocode
In Part 1, we discussed the main steps involved in training a GAN. To refresh our memory, here is the pseudocode (generated using the open-source code made available by PyTorch):
for each epoch:
    for each batch b of input images:

        ######################################
        ## Part 1: Update Discriminator - D ##
        ######################################

        # loss on real images
        clear gradients of D
        pred_labels_real = pass b through D to compute outputs
        true_labels_real = [1,1,1....1]
        calculate loss(pred_labels_real, true_labels_real)
        calculate gradients using this loss

        # loss on fake images
        generate batch of size b of fake images (b_fake) using G
        pred_labels_fake = pass b_fake through D
        true_labels_fake = [0,0,....0]
        calculate loss(pred_labels_fake, true_labels_fake)
        calculate gradients using this loss

        update weights of D

        ######################################
        #### Part 2: Update Generator - G ####
        ######################################

        clear gradients of G
        pred_labels = pass b_fake through D
        true_labels = [1,1,....1]
        calculate loss(pred_labels, true_labels)
        calculate gradient using this loss
        update weights of G

        ################################################
        ### Part 3: Plot a batch of Generator images ###
        ################################################
Now keeping this in mind, let’s start building our training function step-by-step. The coding will be divided into three parts — Part 1 dedicated to updating the discriminator, Part 2 for updating the generator, and (an optional) Part 3 for plotting a batch of generator images using the helper function we defined at the beginning of the article.
Part 1: Updating the discriminator
The process includes calculating loss on real and fake images.
Code for calculating loss on real images
# Loss on real images

# clear the gradient
opt_D.zero_grad()  # set the gradients to 0 at start of each loop because gradients are accumulated on subsequent backward passes

# compute the D model output
yhat = netD(b.to(device)).view(-1)  # view(-1) reshapes a 4-d tensor of shape (2,1,1,1) to a 1-d tensor with 2 values only

# specify target labels or true labels
target = torch.ones(len(b), dtype=torch.float, device=device)

# calculate loss
loss_real = loss(yhat, target)

# calculate gradients - or rather accumulation of gradients on loss tensor
loss_real.backward()
We begin by clearing the gradients for the discriminator using zero_grad(). It is necessary to set the gradients to 0 at the start of each loop because the gradients are accumulated on subsequent backward passes (i.e. when loss.backward() is called). Next, we store the output from the discriminator model when it is fed a batch b of real images (i.e. images from our training set). Remember, the shape of b is (32, 3, 32, 32).
It is important to note that rather than simply passing the images to the network as netD(b), we first call b.to(device) on the batch. This is because we must put the image tensor on the same device as the model. While it may not matter much if you are running your code on a CPU, not doing so may throw a runtime error on a GPU.
Finally, as previously stated, we use view(-1) on the output from the model to reshape the 4-d tensor to a 1-d tensor that contains the likelihood of each image being real.
We define the true labels for the real images, or targets, as a tensor of length len(b) containing all 1s. We explicitly set this to float32 so that it matches the type of images in the batch b. Finally, we ensure the target labels are also on the same device as the model.
Next, BCELoss is calculated from the predicted values and target labels, and gradients are calculated and accumulated using backward().
Code for calculating loss on fake images
# Loss on fake images

# generate batch of fake images using G
# Step 1: creating noise to be fed as input to G
noise = torch.randn(len(b), 100, 1, 1, device=device)

# Step 2: feed noise to G to create a fake img (this will be reused when updating G)
fake_img = netG(noise)

# compute D model output on fake images
# .cuda() is needed if the input (fake_img) is on the GPU but the model isn't (a RuntimeError is thrown otherwise);
# detach() is imp: only track steps on your generator optimizer when training the generator, NOT the discriminator
yhat = netD.cuda()(fake_img.detach()).view(-1)

# specify target labels
target = torch.zeros(len(b), dtype=torch.float, device=device)

# calculate loss
loss_fake = loss(yhat, target)

# calculate gradients
loss_fake.backward()

# total error on D
loss_disc = loss_real + loss_fake

# Update weights of D
opt_D.step()
To generate a batch of fake images, we first need a batch of random noise vectors, noise, which is fed to the Generator to create fake_img. Next, we calculate the output from the Discriminator on these fake images and store it in yhat. The .cuda() call is there in case our input, i.e. fake_img, is on the GPU but the model is not, in which case a runtime error would be thrown.
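Since netD was already moved to device when we created it, a device-agnostic alternative (a small tweak of mine that also lets the code run on a CPU-only machine, where .cuda() would fail) is to simply drop the .cuda() call:

# device-agnostic alternative: netD and fake_img already live on `device`
yhat = netD(fake_img.detach()).view(-1)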
An important thing to note is that we used detach() on the batch of fake images. The reason for doing so is that while we want to use the services of the Generator, we do not want to update it just yet (we will do so once we are done updating the Discriminator).
Why use detach()? Basically, we must track steps on our generator optimizer only when training the generator, NOT the discriminator.
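As a minimal sketch (not from the original code) of what detach() does to the autograd graph:

x = torch.randn(3, requires_grad=True)
y = (x * 2).detach()
print(y.requires_grad)  # False -- no gradients will flow back through y into x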
Based on the explanation in Part 1, the target labels in this case would be a tensor of length b containing all zeros. The remaining steps remain the same as in the previous code snippet.
Part 2: Updating the Generator
The steps are roughly the same as in the case of Discriminator. The main difference is that now target labels are set to ones (instead of zeros), even though they are fake images. A detailed explanation of why we are doing so has been provided in Part 1. In short:
the Generator wants the Discriminator to think it is churning out real images, and so it uses the true labels as 1 during training.
##########################
#### Update Generator ####
##########################

# clear gradient
opt_G.zero_grad()

# pass fake image through D
yhat = netD.cuda()(fake_img).view(-1)

# specify target variables - remember G wants D *to think* these are real images so label is 1
target = torch.ones(len(b), dtype=torch.float, device=device)

# calculate loss
loss_gen = loss(yhat, target)

# calculate gradients
loss_gen.backward()

# update weights on G
opt_G.step()
Part 3: Plotting a batch of images from Generator
In order to see how well our Generator is doing with each passing epoch, we will plot a bunch of images at every 10th iteration using the helper function plot_images().
####################################
#### Plot some Generator images ####
####################################

# during every epoch, print images at every 10th iteration
if i % 10 == 0:
    # convert the fake images from (b_size, 3, 32, 32) to (b_size, 32, 32, 3) for plotting
    img_plot = np.transpose(fake_img.detach().cpu(), (0, 2, 3, 1))  # .detach().cpu() is imp for copying fake_img tensor to host memory first
    plot_images(img_plot)
    print("********************")
    print(" Epoch %d and iteration %d " % (e, i))
Now you may notice that the dimensions in fake_img are reordered using np.transpose() before being passed to the plotting function plot_images(). This is because the plt.imshow() method (used in plot_images()) requires the images passed to it to be in the form (height, width, channels). However, the images output by the Generator take the form (channels, height, width), which is standard in PyTorch. To fix this, we must transpose the dimensions of the fake images so that they take the shape (b_size, 32, 32, 3).
Another thing to keep in mind is that calling .detach().cpu() is important for copying the fake_img tensor to host memory first, before we can pass it to the plotting function.
This is what the final block of code for training a GAN — including Part 1, Part 2, and Part 3 — looks like:
# TRAINING GANS
epochs = 1000  # going over the entire dataset 1000 times

for e in range(epochs):

    # pick each batch b of input images: shape of each batch is (32, 3, 32, 32)
    for i, b in enumerate(dataloader):

        ##########################
        ## Update Discriminator ##
        ##########################

        # Loss on real images

        # clear the gradient
        opt_D.zero_grad()  # set the gradients to 0 at start of each loop because gradients are accumulated on subsequent backward passes

        # compute the D model output
        yhat = netD(b.to(device)).view(-1)  # view(-1) reshapes the 4-d tensor (b_size, 1, 1, 1) to a 1-d tensor with b_size values

        # specify target labels or true labels
        target = torch.ones(len(b), dtype=torch.float, device=device)

        # calculate loss
        loss_real = loss(yhat, target)

        # calculate gradients - or rather accumulation of gradients on loss tensor
        loss_real.backward()

        # Loss on fake images

        # generate batch of fake images using G
        # Step 1: creating noise to be fed as input to G
        noise = torch.randn(len(b), 100, 1, 1, device=device)

        # Step 2: feed noise to G to create a fake img (this will be reused when updating G)
        fake_img = netG(noise)

        # compute D model output on fake images
        # .cuda() is needed if the input (fake_img) is on the GPU but the model isn't;
        # detach() is imp: only track steps on your generator optimizer when training the generator, NOT the discriminator
        yhat = netD.cuda()(fake_img.detach()).view(-1)

        # specify target labels
        target = torch.zeros(len(b), dtype=torch.float, device=device)

        # calculate loss
        loss_fake = loss(yhat, target)

        # calculate gradients
        loss_fake.backward()

        # total error on D
        loss_disc = loss_real + loss_fake

        # Update weights of D
        opt_D.step()

        ##########################
        #### Update Generator ####
        ##########################

        # clear gradient
        opt_G.zero_grad()

        # pass fake image through D
        yhat = netD.cuda()(fake_img).view(-1)

        # specify target variables - remember G wants D *to think* these are real images so label is 1
        target = torch.ones(len(b), dtype=torch.float, device=device)

        # calculate loss
        loss_gen = loss(yhat, target)

        # calculate gradients
        loss_gen.backward()

        # update weights on G
        opt_G.step()

        ####################################
        #### Plot some Generator images ####
        ####################################

        # during every epoch, print images at every 10th iteration
        if i % 10 == 0:
            # convert the fake images from (b_size, 3, 32, 32) to (b_size, 32, 32, 3) for plotting
            img_plot = np.transpose(fake_img.detach().cpu(), (0, 2, 3, 1))  # .detach().cpu() copies the fake_img tensor to host memory first
            plot_images(img_plot)
            print("********************")
            print(" Epoch %d and iteration %d " % (e, i))
And there we have it, we have implemented a vanilla GAN from scratch using our custom image dataset! Wohoooo…
To give you a rough estimate of the quality of the images generated by our GAN:
Epoch 0 (iteration 160): Nice to see that the generator is picking up on the fact that faces exist in the center of the image.
Common runtime errors
While the code I have shared with you in the Github Notebook is error-free, I would like to take a moment and discuss a few runtime errors that I encountered while learning to train GANs from scratch.
- Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
Here, the weight type usually refers to the weights in your model, which are of type float32. The reason you might be seeing this error is that you are feeding something to your model that is probably float64 instead of float32, i.e. a type mismatch problem.
In my case, I came across this error when I tried to pass a batch of images via the dataloader to the Discriminator model without first explicitly converting them to float using np.float32.
- Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
The error itself is self-explanatory if you carefully observe the input type (i.e. torch.FloatTensor) and the weight type (i.e. torch.cuda.FloatTensor) — only one of them contains the word 'cuda'. What it means is that your model is on the GPU whereas the input data is still on the CPU. To rectify this error, simply send your input data tensor to the GPU using .to(device).
I encountered this error during Part 1 of GAN training when I was computing the model outputs for a batch of input images using yhat = netD(b).view(-1) (for calculating the discriminator loss on real images). The fix is simple: yhat = netD(b.to(device)).view(-1).
Congrats on making it this far. Hopefully, this tutorial (along with Part 1) was a warm intro to a super useful yet super complex deep learning concept that GANs are known to be.
Until next time 🙂