
DCGAN Tutorial#

Created On: Jul 31, 2018 | Last Updated: Jan 19, 2024 | Last Verified: Nov 05, 2024

Author: Nathan Inkawhich

Introduction#

This tutorial will give an introduction to DCGANs through an example. We will train a generative adversarial network (GAN) to generate new celebrities after showing it pictures of many real celebrities. Most of the code here is from the DCGAN implementation in pytorch/examples, and this document will give a thorough explanation of the implementation and shed light on how and why this model works. No prior knowledge of GANs is required, but a first-timer may need to spend some time reasoning about what is actually happening under the hood. Also, for the sake of time it will help to have a GPU, or two. Let’s start from the beginning.

Generative Adversarial Networks#

What is a GAN?#

GANs are a framework for teaching a deep learning model to capture the training data distribution so we can generate new data from that same distribution. GANs were invented by Ian Goodfellow in 2014 and first described in the paper Generative Adversarial Nets. They are made of two distinct models, a generator and a discriminator. The job of the generator is to spawn ‘fake’ images that look like the training images. The job of the discriminator is to look at an image and output whether it is a real training image or a fake image from the generator. During training, the generator is constantly trying to outsmart the discriminator by generating better and better fakes, while the discriminator is working to become a better detective and correctly classify the real and fake images. The equilibrium of this game is when the generator is generating perfect fakes that look as if they came directly from the training data, and the discriminator is left to always guess at 50% confidence whether the generator output is real or fake.

Now, let’s define some notation to be used throughout the tutorial, starting with the discriminator. Let \(x\) be data representing an image. \(D(x)\) is the discriminator network which outputs the (scalar) probability that \(x\) came from training data rather than the generator. Here, since we are dealing with images, the input to \(D(x)\) is an image of CHW size 3x64x64. Intuitively, \(D(x)\) should be HIGH when \(x\) comes from training data and LOW when \(x\) comes from the generator. \(D(x)\) can also be thought of as a traditional binary classifier.

For the generator’s notation, let \(z\) be a latent space vector sampled from a standard normal distribution. \(G(z)\) represents the generator function which maps the latent vector \(z\) to data-space. The goal of \(G\) is to estimate the distribution that the training data comes from (\(p_{data}\)) so it can generate fake samples from that estimated distribution (\(p_g\)).

So, \(D(G(z))\) is the probability (scalar) that the output of the generator \(G\) is a real image. As described in Goodfellow’s paper, \(D\) and \(G\) play a minimax game in which \(D\) tries to maximize the probability it correctly classifies reals and fakes (\(\log D(x)\)), and \(G\) tries to minimize the probability that \(D\) will predict its outputs are fake (\(\log(1-D(G(z)))\)). From the paper, the GAN loss function is

\[\underset{G}{\text{min}} \underset{D}{\text{max}} V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[\log(1-D(G(z)))\big]\]

In theory, the solution to this minimax game is where \(p_g = p_{data}\), and the discriminator guesses randomly whether the inputs are real or fake. However, the convergence theory of GANs is still being actively researched, and in reality models do not always train to this point.
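To see why, note (a known result from Goodfellow’s paper) that for a fixed \(G\) the optimal discriminator is

\[D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

so when \(p_g = p_{data}\) we get \(D^*(x) = \tfrac{1}{2}\) everywhere, and the value of the game becomes \(V = \log\tfrac{1}{2} + \log\tfrac{1}{2} = -\log 4\).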

What is a DCGAN?#

A DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks. The discriminator is made up of strided convolution layers, batch norm layers, and LeakyReLU activations. The input is a 3x64x64 input image and the output is a scalar probability that the input is from the real data distribution. The generator is composed of convolutional-transpose layers, batch norm layers, and ReLU activations. The input is a latent vector, \(z\), that is drawn from a standard normal distribution and the output is a 3x64x64 RGB image. The strided conv-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips about how to set up the optimizers, how to calculate the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
Random Seed:  999

Inputs#

Let’s define some inputs for the run:

  • dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.

  • workers - the number of worker processes for loading the data with the DataLoader.

  • batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128.

  • image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details.

  • nc - number of color channels in the input images. For color images this is 3.

  • nz - length of the latent vector.

  • ngf - relates to the depth of feature maps carried through the generator.

  • ndf - sets the depth of feature maps propagated through the discriminator.

  • num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take much longer.

  • lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002.

  • beta1 - beta1 hyperparameter for Adam optimizers. As described in the paper, this number should be 0.5.

  • ngpu - number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0 it will run on that number of GPUs.

# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data#

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded at the linked site, or in Google Drive. The dataset will download as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...
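If you prefer to script this setup, here is a minimal sketch using the standard library. It assumes the archive has already been downloaded into the working directory; the paths are hypothetical and should be adjusted to where you saved the file:

import os
import zipfile

# Hypothetical location of the downloaded archive; adjust as needed.
zip_path = "img_align_celeba.zip"
os.makedirs("data/celeba", exist_ok=True)

# Extract into data/celeba so the images land in
# data/celeba/img_align_celeba/*.jpg, i.e. one subdirectory under the root.
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("data/celeba")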

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset’s root folder. Now, we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(), (1,2,0)))
plt.show()
(Figure: a grid of sample training images, titled “Training Images”.)

Implementation#

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization#

From the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
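As a quick sanity check (a minimal sketch, not part of the original tutorial), you can apply weights_init to a single layer and confirm that the resulting weight statistics are roughly mean 0, std 0.02:

# Illustrative only: the sample statistics of a freshly reinitialized
# Conv2d layer should be close to mean=0, std=0.02.
conv = nn.Conv2d(3, 64, 4, 2, 1, bias=False)
conv.apply(weights_init)
print(conv.weight.data.mean().item(), conv.weight.data.std().item())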

Generator#

The generator, \(G\), is designed to map the latent space vector (\(z\)) to data-space. Since our data are images, converting \(z\) to data-space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two dimensional convolutional transpose layers, each paired with a 2d batch norm layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of \([-1,1]\). It is worth noting the existence of the batch norm functions after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

(Figure: the DCGAN generator architecture, from the DCGAN paper.)
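To see why each 4x4, stride-2, padding-1 conv-transpose layer doubles the spatial size, recall that ConvTranspose2d output size is (in - 1) * stride - 2 * padding + kernel_size. Here is a minimal shape check (illustrative only, reusing the imports above); the layer sizes mirror the first two stages of the generator below:

# Illustrative shape check for ConvTranspose2d output sizes.
z = torch.randn(1, 100, 1, 1)                # latent vector as a 1x1 "image"
up1 = nn.ConvTranspose2d(100, 512, 4, 1, 0)  # (1-1)*1 - 0 + 4 = 4
up2 = nn.ConvTranspose2d(512, 256, 4, 2, 1)  # (4-1)*2 - 2 + 4 = 8
print(up1(z).shape)       # torch.Size([1, 512, 4, 4])
print(up2(up1(z)).shape)  # torch.Size([1, 256, 8, 8])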

Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the z input vector, ngf relates to the size of the feature maps that are propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
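As a quick illustration (not part of the original tutorial), feeding a batch of latent vectors through the untrained generator confirms the expected output shape and range:

# Illustrative only: an untrained G already produces 3x64x64 outputs in [-1, 1].
with torch.no_grad():
    sample = netG(torch.randn(16, nz, 1, 1, device=device))
print(sample.shape)                              # torch.Size([16, 3, 64, 64])
print(sample.min().item(), sample.max().item())  # within [-1, 1] due to Tanh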

Discriminator#

As mentioned, the discriminator, \(D\), is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake). Here, \(D\) takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but there is significance to the use of the strided convolution, BatchNorm, and LeakyReLUs. The DCGAN paper mentions it is good practice to use strided convolution rather than pooling to downsample because it lets the network learn its own pooling function. Also, batch norm and leaky relu functions promote healthy gradient flow, which is critical for the learning process of both \(G\) and \(D\).

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model’s structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
# to ``mean=0``, ``stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
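Mirroring the generator check above (again, illustrative only and not from the original tutorial), the discriminator maps each 3x64x64 input to a single probability per sample:

# Illustrative only: D maps each 3x64x64 image to one probability in (0, 1).
with torch.no_grad():
    probs = netD(torch.randn(16, nc, image_size, image_size, device=device)).view(-1)
print(probs.shape)  # torch.Size([16])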

Loss Functions and Optimizers#

With \(D\) and \(G\) set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function, which is defined in PyTorch as:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]\]

Notice how this function provides the calculation of both log components in the objective function (i.e. \(\log(D(x))\) and \(\log(1-D(G(z)))\)). We can specify which part of the BCE equation to use with the \(y\) input. This is accomplished in the training loop, which is coming up soon, but it is important to understand how we can choose which component we wish to calculate just by changing \(y\) (i.e. GT labels).
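A tiny numeric illustration (not part of the original tutorial): with \(y=1\) BCELoss reduces to \(-\log(x)\), and with \(y=0\) it reduces to \(-\log(1-x)\):

# Illustrative only: BCELoss selects the log term via the target label y.
bce = nn.BCELoss()
x = torch.tensor([0.9])               # a "probability" output from D
print(bce(x, torch.ones(1)).item())   # -log(0.9)   ~= 0.105
print(bce(x, torch.zeros(1)).item())  # -log(1-0.9) ~= 2.303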

Next, we define our real label as 1 and the fake label as 0. These labels will be used when calculating the losses of \(D\) and \(G\), and this is also the convention used in the original GAN paper. Finally, we set up two separate optimizers, one for \(D\) and one for \(G\). As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. For keeping track of the generator’s learning progression, we will generate a fixed batch of latent vectors that are drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into \(G\), and over the iterations we will see images form out of the noise.

# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Training#

Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from Goodfellow’s paper, while abiding by some of the best practices shown in ganhacks. Namely, we will “construct different mini-batches for real and fake” images, and also adjust G’s objective function to maximize \(\log(D(G(z)))\). Training is split up into two main parts. Part 1 updates the Discriminator and Part 2 updates the Generator.

Part 1 - Train the Discriminator

Recall, the goal of training the discriminator is to maximize the probability of correctly classifying a given input as real or fake. In terms of Goodfellow, we wish to “update the discriminator by ascending its stochastic gradient”. Practically, we want to maximize \(\log(D(x)) + \log(1-D(G(z)))\). Due to the separate mini-batch suggestion from ganhacks, we will calculate this in two steps. First, we will construct a batch of real samples from the training set, forward pass through \(D\), calculate the loss (\(\log(D(x))\)), then calculate the gradients in a backward pass. Secondly, we will construct a batch of fake samples with the current generator, forward pass this batch through \(D\), calculate the loss (\(\log(1-D(G(z)))\)), and accumulate the gradients with a backward pass. Now, with the gradients accumulated from both the all-real and all-fake batches, we call a step of the Discriminator’s optimizer.
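The key mechanic here is that calling backward() twice without zeroing sums the gradients. A minimal sketch, using a hypothetical one-parameter "model" (not code from the tutorial):

# Illustrative only: two backward() calls accumulate (sum) gradients
# before a single optimizer step, just like the all-real / all-fake passes.
w = torch.ones(1, requires_grad=True)
(2 * w).sum().backward()  # contributes gradient 2
(3 * w).sum().backward()  # contributes gradient 3, summed with the previous
print(w.grad)             # tensor([5.])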

Part 2 - Train the Generator

As stated in the original paper, we want to train the Generator by minimizing \(\log(1-D(G(z)))\) in an effort to generate better fakes. As mentioned, this was shown by Goodfellow to not provide sufficient gradients, especially early in the learning process. As a fix, we instead wish to maximize \(\log(D(G(z)))\). In the code we accomplish this by: classifying the Generator output from Part 1 with the Discriminator, computing G’s loss using real labels as GT, computing G’s gradients in a backward pass, and finally updating G’s parameters with an optimizer step. It may seem counter-intuitive to use the real labels as GT labels for the loss function, but this allows us to use the \(\log(x)\) part of the BCELoss (rather than the \(\log(1-x)\) part) which is exactly what we want.
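To see why the original loss gives weak gradients early on, when \(D(G(z))\) is near 0, compare the two losses’ gradients at a small discriminator output (a numeric illustration, not part of the original tutorial):

# Illustrative only: at D(G(z)) ~ 0, minimizing log(1-x) gives a tiny gradient
# while minimizing -log(x) (i.e. maximizing log(x)) gives a large one.
x = torch.tensor([0.01], requires_grad=True)  # an early-training D(G(z))
torch.log(1 - x).sum().backward()
print(x.grad)   # d/dx log(1-x) = -1/(1-x) ~= -1.01  (weak signal)
x.grad = None
(-torch.log(x)).sum().backward()
print(x.grad)   # d/dx -log(x)  = -1/x     = -100    (strong signal)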

Finally, we will do some statistics reporting, and at the end of each epoch we will push our fixed_noise batch through the generator to visually track the progress of G’s training. The training statistics reported are:

  • Loss_D - discriminator loss calculated as the sum of losses for the all-real and all-fake batches (\(\log(D(x)) + \log(1 - D(G(z)))\)).

  • Loss_G - generator loss calculated as \(\log(D(G(z)))\).

  • D(x) - the average output (across the batch) of the discriminator for the all-real batch. This should start close to 1, then theoretically converge to 0.5 when G gets better. Think about why this is.

  • D(G(z)) - average discriminator outputs for the all-fake batch. The first number is before D is updated and the second number is after D is updated. These numbers should start near 0 and converge to 0.5 as G gets better. Think about why this is.

Note: This step might take a while, depending on how many epochs you run and whether you removed some data from the dataset.

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.4639  Loss_G: 6.9354  D(x): 0.7143    D(G(z)): 0.5876 / 0.0017
[0/5][50/1583]  Loss_D: 0.0851  Loss_G: 12.3804 D(x): 0.9438    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.6445  Loss_G: 5.7368  D(x): 0.8860    D(G(z)): 0.2482 / 0.0087
[0/5][150/1583] Loss_D: 1.4464  Loss_G: 7.9471  D(x): 0.9681    D(G(z)): 0.6656 / 0.0009
[0/5][200/1583] Loss_D: 1.3758  Loss_G: 6.7867  D(x): 0.9747    D(G(z)): 0.5740 / 0.0029
... (intermediate iterations omitted) ...
[4/5][1400/1583]        Loss_D: 0.9703  Loss_G: 3.4816  D(x): 0.9004    D(G(z)): 0.5150 / 0.0472
[4/5][1450/1583]        Loss_D: 0.6659  Loss_G: 1.3697  D(x): 0.6842    D(G(z)): 0.1955 / 0.3047
[4/5][1500/1583]        Loss_D: 0.5204  Loss_G: 2.5325  D(x): 0.7609    D(G(z)): 0.1713 / 0.1096
[4/5][1550/1583]        Loss_D: 0.5696  Loss_G: 1.7800  D(x): 0.6674    D(G(z)): 0.0904 / 0.2047

Results#

Finally, let’s check out how we did. Here, we will look at three different results. First, we will see how D and G’s losses changed during training. Second, we will visualize G’s output on the fixed_noise batch for every epoch. And third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G’s losses versus training iterations.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
(Figure: plot of Generator and Discriminator loss versus training iteration.)

Visualization of G’s progression

Remember how we saved the generator’s output on the fixed_noise batch after every epoch of training. Now, we can visualize the training progression of G with an animation. Press the play button to start the animation.

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())
(Animation: G’s output on the fixed_noise batch over the course of training.)


Real Images vs. Fake Images

Finally, let’s take a look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(), (1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1,2,0)))
plt.show()
(Figure: a grid of real images beside a grid of fake images from G.)

Where to Go Next#

We have reached the end of our journey, but there are several places you could go from here. You could:

  • Train for longer to see how good the results get

  • Modify this model to take a different dataset and possibly change the size of the images and the model architecture (a sketch of a 128x128 generator follows this list)

  • Check out some other cool GAN projects here

  • Create GANs that generate music
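As an example of the second point, extending the generator to 128x128 output is mostly a matter of adding one more upsampling stage. This is a hedged sketch reusing this tutorial’s globals (nz, ngf, nc), not code from the tutorial itself; the discriminator would need a mirrored extra downsampling layer, and the training hyperparameters may need retuning:

# Sketch of a 128x128 generator: same pattern as before, with one extra
# ConvTranspose2d stage so the spatial size doubles five times (4 -> ... -> 128).
class Generator128(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True),
            # state size. (ngf*16) x 4 x 4
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 64 x 64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 128 x 128
        )

    def forward(self, input):
        return self.main(input)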

Total running time of the script: (6 minutes 31.005 seconds)