Implementation of a Wasserstein Generative Adversarial Network (WGAN) that learns the distribution of a mixture of Gaussians. The WGAN loss requires a Lipschitz constraint, and to enforce it I implemented and tested two methods: weight clipping and Spectral Normalization. Since training state-of-the-art GANs is computationally expensive, this project uses a simple GAN with a linear generator and a dual variable.
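To make the setup concrete, here is a minimal NumPy sketch of the quantities involved. It assumes a 2-D mixture of equally weighted Gaussians as the real data, a linear generator G(z) = Wz + b and, purely for illustration, a linear dual variable f(x) = <v, x>. The names sample_mixture, generator, critic and wgan_objective are hypothetical and not taken from the repository. The estimated quantity is the WGAN objective E[f(x_real)] - E[f(G(z))], which the dual variable maximizes over 1-Lipschitz functions and the generator minimizes.

```python
import numpy as np

# Hypothetical sketch, not the repository's code.
rng = np.random.default_rng(0)

def sample_mixture(n, means, std=0.1):
    """Draw n points from an equally weighted mixture of isotropic Gaussians."""
    means = np.asarray(means, dtype=float)
    idx = rng.integers(len(means), size=n)  # pick a mixture component per sample
    return means[idx] + std * rng.standard_normal((n, means.shape[1]))

def generator(z, W, b):
    """Linear generator: maps each noise row z_i to W z_i + b."""
    return z @ W.T + b

def critic(x, v):
    """Illustrative linear dual variable f(x) = <v, x>; it is ||v||_2-Lipschitz."""
    return x @ v

def wgan_objective(x_real, x_fake, v):
    """Monte Carlo estimate of E[f(x_real)] - E[f(G(z))]."""
    return critic(x_real, v).mean() - critic(x_fake, v).mean()

if __name__ == "__main__":
    x_real = sample_mixture(1000, [[1.0, 0.0], [-1.0, 0.0]])
    z = rng.standard_normal((1000, 2))
    x_fake = generator(z, np.eye(2), np.zeros(2))
    print(wgan_objective(x_real, x_fake, np.array([1.0, 0.0])))
```

With a linear critic and ||v||_2 <= 1, the maximum of this objective is simply the distance between the two means, so a linear f can only compare first moments; that is why the dual variable is normally a richer function class whose Lipschitz constant must be controlled.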
For a more detailed explanation of the terms mentioned above, see Exercise instructions.pdf, which also contains some theoretical questions; their (handwritten) answers are in Answers.pdf.
The project was part of an assignment for the EPFL course EE-556 Mathematics of data: from theory to computation. The backbone of the code used to run the experiments was provided by the professor and his assistants; my task was to implement the core theoretical components that make the experiments work. Hence, every code file combines my own code with the code provided by the professor.
The following GIFs show the output of train.py. The first one (left) is obtained by using weight clipping to enforce the Lipschitz constraint, while the second one (right) is the result of using Spectral Normalization.
These results show that, with both techniques, the generated samples (red points) approach the distribution of the real data (blue points), and they keep oscillating between a less precise solution and a more accurate one. With Spectral Normalization, the generated samples reach a distribution similar to that of the real data after about 600 iterations, while weight clipping is faster and gets there after about 300 iterations. Spectral Normalization also degrades after 800 iterations and recovers a good fit after 1000, whereas weight clipping gives a more stable result.
Download this repository as a zip file and extract it into a folder. The easiest way to run the code is to install the Anaconda 3 distribution (available for Windows, macOS and Linux). To do so, follow the guidelines on the official website (select Python 3): https://www.anaconda.com/download/
The additional required packages are:
- pytorch
- matplotlib
- tqdm
- imageio
To install them, first run the following command in the Anaconda Prompt (anaconda3):
cd *THE_FOLDER_PATH_WHERE_YOU_DOWNLOADED_AND_EXTRACTED_THIS_REPOSITORY*
Then run, for each of the packages listed above:
conda install *PACKAGE_NAME*
Some packages might require more complex installation procedures (especially pytorch). If the command above does not work for a package, search for "How to install PACKAGE_NAME on YOUR_MACHINE'S_OS" and follow the corresponding guide.
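Before starting the training, a quick way to confirm that the packages are importable is a short check like the sketch below, which uses only the Python standard library. Note that the pytorch package is imported as torch; the missing_packages helper is hypothetical and not part of the repository.

```python
import importlib.util

# Map each required package to the module name used to import it.
REQUIRED = {"pytorch": "torch", "matplotlib": "matplotlib",
            "tqdm": "tqdm", "imageio": "imageio"}

def missing_packages(required=REQUIRED):
    """Return the packages whose top-level module cannot be found."""
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]

if __name__ == "__main__":
    missing = missing_packages()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```

find_spec only locates the module without importing it, so the check is fast even for a large package such as torch.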
Finally, run train.py:
python train.py
By default, the training uses Spectral Normalization. To use weight clipping instead, change line 110 of trainer.py from "f.enforce_lipschitz()" to "f.enforce_lipschitz(spectral_norm=False)".
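The two ways of enforcing the constraint can be sketched in NumPy as follows, assuming the dual variable's parameters are a plain list of weight matrices. Weight clipping clamps every entry to [-c, c] (c = 0.01 is only an illustrative default); Spectral Normalization divides each matrix by its largest singular value, estimated here with power iteration. All function names, including the enforce_lipschitz stand-in whose spectral_norm flag mirrors the call above, are hypothetical and are not the repository's implementation. PyTorch also ships a ready-made version, torch.nn.utils.spectral_norm, which keeps the power-iteration vector between forward passes.

```python
import numpy as np

def clip_weights(weights, c=0.01):
    """Weight clipping: force every entry of every matrix into [-c, c]."""
    return [np.clip(W, -c, c) for W in weights]

def estimate_spectral_norm(W, n_iters=50, seed=0):
    """Estimate the largest singular value of W with power iteration."""
    rng = np.random.default_rng(seed)
    u = rng.standard_normal(W.shape[0])
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)  # estimate of the top right singular vector
        u = W @ v
        u /= np.linalg.norm(u)  # estimate of the top left singular vector
    return float(u @ W @ v)     # sigma_max is approximately u^T W v

def spectral_normalize(weights, n_iters=50):
    """Spectral Normalization: rescale each matrix to spectral norm ~1."""
    return [W / estimate_spectral_norm(W, n_iters) for W in weights]

def enforce_lipschitz(weights, spectral_norm=True, c=0.01):
    """Hypothetical stand-in for f.enforce_lipschitz(); not the repo's code."""
    if spectral_norm:
        return spectral_normalize(weights)
    return clip_weights(weights, c)

if __name__ == "__main__":
    layers = [np.array([[3.0, 0.0], [0.0, 1.0]]), np.array([[0.5, -2.0]])]
    for name, flag in [("spectral norm", True), ("weight clipping", False)]:
        out = enforce_lipschitz(layers, spectral_norm=flag)
        print(name, [round(float(np.linalg.norm(W, 2)), 3) for W in out])
```

Normalizing every layer makes each linear map 1-Lipschitz, so with 1-Lipschitz activations such as ReLU the whole dual variable is 1-Lipschitz. Clipping only bounds the Lipschitz constant by a generally loose value that depends on c and the layer sizes.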
- code/src/: folder containing all the sub-components used for training
- code/train.py: main script that runs the training and creates the resulting GIF
- code/movie.gif: output of train.py (obtained with weight clipping)
- Answers.pdf: PDF with the answers and plots for the course assignment
- Exercise instructions.pdf: PDF with the questions of the course assignment
Python, PyTorch, Matplotlib. Machine learning, Wasserstein Generative Adversarial Network (WGAN) implementation, minimax problems, implementation of both weight clipping and Spectral Normalization for Lipschitz-constrained minimization.