Posted 22 Mar 2021

Training and Running a GAN for Fashion Design Generation

In this article, we train our GAN to generate realistic-looking clothing images, similar to the ones found in the DeepFashion dataset.

Introduction

The availability of datasets like DeepFashion opens up new possibilities for the fashion industry. In this series of articles, we’ll showcase an AI-powered deep learning system that can revolutionize the fashion design industry by helping us better understand customers’ needs.

In this project, we’ll use:

We are assuming that you are familiar with the concepts of deep learning, as well as with Jupyter Notebooks and TensorFlow. If you’re new to Jupyter Notebooks, start with this tutorial. You are welcome to download the project code.

In the previous article, we designed and built a Generative Adversarial Network (GAN). In this article, we’ll train our GAN to generate realistic-looking clothing images, similar to the ones found in the DeepFashion dataset.

Training a GAN

Training alternates between two updates. The discriminator D is trained to maximize log(D(x)) + log(1 - D(G(z))), that is, to score real images x high and generated images G(z) low. The generator G is then trained to maximize log(D(G(z))), that is, to produce images that the discriminator scores as real. This kind of network needs many iterations before the generated images start to resemble the real ones, so we’ll use a large number of epochs: we’ll start with 40 and see what results this brings. We’ll train the network on our customized dataset. The quantities tracked and printed during training are as follows:

  • G_losses: a list of the generator loss (errG) recorded at every training iteration
  • D_losses: a list of the discriminator loss (errD) recorded at every training iteration; each value is the sum of the losses on the real batch and on the fake batch
  • D(G(z)): the average discriminator output over the current fake batch, printed twice: before (D_G_z1) and after (D_G_z2) the discriminator update
  • D(x): the average discriminator output over the current real batch
Python
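# Assumes torch, torchvision.utils (as vutils), netG, netD, criterion, the optimizers,
# dataloader, device, nz, num_epochs, real_label, fake_label and fixed_noise are already defined
# Lists to keep track of progress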
img_list = []  
G_losses = []  
D_losses = [] 
iters = 0
####################################################################
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
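
The link between the loss calls in the loop and the log terms in its comments can be checked by hand. The sketch below is a minimal illustration, assuming that criterion is a binary cross-entropy loss (for example, nn.BCELoss) and that real_label and fake_label are 1 and 0; neither is shown in this article. The discriminator outputs d_x and d_g_z are made-up numbers, and bce is a hypothetical helper written in plain Python so that no model is needed.

```python
import math

def bce(prediction, label):
    """Binary cross-entropy for one probability and one 0/1 label."""
    return -(label * math.log(prediction) + (1 - label) * math.log(1 - prediction))

real_label, fake_label = 1.0, 0.0   # assumed label values

d_x = 0.9      # hypothetical discriminator output on a real image
d_g_z = 0.2    # hypothetical discriminator output on a generated image

# Discriminator step: the two terms that the loop sums into errD
err_d_real = bce(d_x, real_label)      # equals -log(D(x))
err_d_fake = bce(d_g_z, fake_label)    # equals -log(1 - D(G(z)))
err_d = err_d_real + err_d_fake

# Generator step: the fake image is scored against the *real* label
err_g = bce(d_g_z, real_label)         # equals -log(D(G(z)))

print(f"errD = {err_d:.4f}, errG = {err_g:.4f}")
```

Minimizing these values is the same as maximizing the log terms named in the loop's comments, which is why the generator step fills the label tensor with real_label.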

Image 1

As you can see, after 40 epochs D(G(z)), the average discriminator output on the fake batches, has decreased. At this point, the GAN is good enough to generate images similar to those in the dataset. For reference, a perfectly balanced GAN would drive both D(x) and D(G(z)) toward 0.5, so it is worth judging progress by the generated images as well as by these numbers. If you want even better images, increase the number of epochs and train again.

We can also plot a graph for the generator and discriminator loss during training.

Python
import matplotlib.pyplot as plt

# Plot the per-iteration losses recorded in the training loop
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
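
Per-iteration GAN losses tend to be noisy, which can make the trend hard to read. One optional way to smooth the curves before plotting is a trailing moving average. The helper below is a sketch, not part of the original project: moving_average and the window size are hypothetical choices, and the sample values are made up.

```python
def moving_average(values, window=50):
    """Return the trailing mean of `values` over at most `window` points."""
    smoothed = []
    running_sum = 0.0
    for index, value in enumerate(values):
        running_sum += value
        # Drop the value that just fell out of the window
        if index >= window:
            running_sum -= values[index - window]
        count = min(index + 1, window)
        smoothed.append(running_sum / count)
    return smoothed

example_losses = [4.0, 2.0, 6.0, 0.0, 2.0]   # hypothetical loss values
print(moving_average(example_losses, window=2))  # [4.0, 3.0, 4.0, 3.0, 1.0]
```

To use it with the plot above, pass moving_average(G_losses) and moving_average(D_losses) to plt.plot instead of the raw lists.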

Image 2

Visualizing Generated Images During Training

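The image grids saved in img_list during training were built with torchvision's make_grid. Matplotlib's ArtistAnimation lets us play them back as an animation to see how the generator's output evolves.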

Python
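#%%capture
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# One animation frame per saved grid, converted from (C, H, W) to (H, W, C)
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())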
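
The np.transpose call in the animation code is needed because PyTorch stores images as (channels, height, width), while Matplotlib's imshow expects (height, width, channels). The sketch below shows the axis reordering on a dummy array; the grid size and the marked pixel are made-up values used only for illustration.

```python
import numpy as np

# Hypothetical image grid in PyTorch layout: (channels, height, width)
grid_chw = np.zeros((3, 64, 128), dtype=np.float32)
grid_chw[0, 10, 20] = 1.0   # mark one pixel in channel 0 at row 10, column 20

# Reorder the axes to (height, width, channels) for imshow
grid_hwc = np.transpose(grid_chw, (1, 2, 0))

print(grid_chw.shape)  # (3, 64, 128)
print(grid_hwc.shape)  # (64, 128, 3)
```

The marked pixel keeps its row and column; only the position of the channel axis changes.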

Image 3

Generating Fashion Images from a Trained GAN

After our GAN has been trained, we can compare a batch of real images from the dataloader with the last grid of fashion images the generator produced, using this code.

Python
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

Image 4

Looks like our GAN was able to generate some fashion images that were similar to those found in the training dataset.

Some fast and easy ways to further improve the GAN performance are:

  • Build a deeper generator using transposed convolutions or upsampling layers
  • Experiment with the distribution of the generator input noise (the training loop above already samples Gaussian noise with torch.randn)
  • Build a deeper discriminator to improve its prediction performance
  • Train longer using a larger number of epochs and more images
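
When deepening a generator with transposed convolutions, it helps to check what spatial size each layer produces. The sketch below uses the output-size formula documented for PyTorch's ConvTranspose2d. The layer settings are a hypothetical DCGAN-style stack chosen for illustration, not necessarily the generator built in the previous article, and conv_transpose_out is a helper name introduced here.

```python
def conv_transpose_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """Spatial output size of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Hypothetical generator stack: (kernel, stride, padding) for each layer
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1          # the latent vector enters as a 1x1 feature map
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)
print(sizes)      # [4, 8, 16, 32, 64]

# One extra stride-2 layer doubles the output resolution
deeper = conv_transpose_out(size, kernel=4, stride=2, padding=1)
print(deeper)     # 128
```

Under these assumptions, each added stride-2 layer doubles the resolution, so a deeper generator also needs a discriminator and dataset images that match the larger size.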

Next Steps

We’ve reached the end of our series! We achieved our goals: to create and train a deep network for fashion design category classification, and to generate new fashion designs using a GAN.

Still, the results we achieved can be improved upon. For example, you can train your deep network on more images that contain the various clothing categories. You can also expand your project with a deep network that can detect different types of clothes in the same image using a Region Proposal Network (RPN). Such a network would classify clothing items using a pre-trained model like the one we created in this series.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


About the Author

Abdulkader Helwan
Engineer
Lebanon
Dr. Helwan is a machine learning and medical image analysis enthusiast.

His research interests include, but are not limited to, machine and deep learning in medicine, medical computational intelligence, biomedical image processing, and biomedical engineering and systems.
