Generative Models
The idea behind generative models is to learn the probability distribution of the training set, so that new samples can be drawn from it. This idea has several use cases:
A data augmentation system that can create new samples from the original data.
Reinforcement learning systems, where the generator could act as a simulator of the environment, simulating possible futures when planning and reasoning about a decision.
If you are able to generate the data distribution, you have probably captured its underlying causal factors. In principle, you are then in the best possible position to answer any question about that data.
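As a minimal illustration of "learn the distribution, then sample from it", the sketch below fits a 1-D Gaussian to a toy training set by maximum likelihood and draws new samples from the fit. This is not a GAN: a GAN never writes down an explicit density, it learns a generator that samples from the distribution directly. The data values and the helper names `fit_gaussian` and `generate` are hypothetical.

```python
import random
import statistics


def fit_gaussian(samples):
    """Maximum-likelihood fit of a 1-D Gaussian: return (mean, std)."""
    mu = statistics.fmean(samples)
    # Population standard deviation is the maximum-likelihood estimate.
    sigma = statistics.pstdev(samples, mu)
    return mu, sigma


def generate(mu, sigma, n, seed=None):
    """Draw n new samples from the fitted Gaussian (hypothetical helper)."""
    rng = random.Random(seed)
    return [rng.gauss(mu, sigma) for _ in range(n)]


# Toy "training set" (made-up values for illustration only).
training_set = [2.1, 1.9, 2.4, 1.6, 2.0, 2.3, 1.7]
mu, sigma = fit_gaussian(training_set)
new_data = generate(mu, sigma, n=5, seed=0)
print(f"mu={mu:.3f} sigma={sigma:.3f}")
print(new_data)
```

The generated values are "more data" in the augmentation sense, but they can only be as rich as the model family; a Gaussian cannot capture images, which is why deep generative models are used instead.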
Generative Adversarial Networks (GANs) do this by having two neural networks play against each other: a generator and a discriminator.
The discriminator network is basically a convolutional neural network (CNN) classifier that has to classify each image as real or fake.
The generator receives a random vector z at its input and transforms it into an image; it is a deconvolutional network, built from transposed (fractional-strided) convolutions.
When training a GAN we have two loss functions: one that encourages the generator to create more realistic images, and one that encourages the discriminator to distinguish generated images from real ones. During training we optimize both objectives, hoping to eventually reach a Nash equilibrium in which both networks do their job well.
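Reaching that Nash equilibrium is not guaranteed just by running gradient steps for both players. A toy stand-in (not a GAN) shows this: in the bilinear game f(x, y) = x * y, player x minimizes f, player y maximizes it, and the only equilibrium is (0, 0). With simultaneous gradient descent/ascent, each step multiplies the distance from the equilibrium by sqrt(1 + lr^2), so the players spiral outward instead of converging. The function name `simultaneous_gda` and the step size are illustrative choices.

```python
import math


def simultaneous_gda(x, y, lr, steps):
    """Simultaneous gradient descent (player x) / ascent (player y) on f(x, y) = x * y.

    x minimizes f, y maximizes f; the unique Nash equilibrium is (0, 0).
    Returns the list of (x, y) positions, starting point included.
    """
    trajectory = [(x, y)]
    for _ in range(steps):
        grad_x = y  # df/dx
        grad_y = x  # df/dy
        # Both players update at the same time from the same (x, y).
        x, y = x - lr * grad_x, y + lr * grad_y
        trajectory.append((x, y))
    return trajectory


traj = simultaneous_gda(1.0, 1.0, lr=0.1, steps=100)
start = math.hypot(*traj[0])
end = math.hypot(*traj[-1])
print(f"distance from equilibrium: start={start:.3f}, end={end:.3f}")
```

Real GAN training is far more complex than this game, but the same effect (players chasing each other around an equilibrium rather than settling into it) is one reason GAN losses can oscillate.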
In this MNIST example, the discriminator uses two binary cross-entropy losses: one on real images (target label 1) and one on generated images (target label 0).
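A sketch of the two binary cross-entropy terms on the discriminator, plus a generator loss, written over plain Python lists of discriminator outputs (sigmoid probabilities) rather than any framework's tensors. It assumes the common choice of averaging each term over the batch and summing the two; some implementations average them instead. The generator loss uses target 1 on generated images (the "non-saturating" form), which is one common variant and not necessarily the exact one used in the MNIST example.

```python
import math

EPS = 1e-12  # clamp probabilities to avoid log(0)


def bce(p, target):
    """Binary cross-entropy for one predicted probability p and a 0/1 target."""
    p = min(max(p, EPS), 1.0 - EPS)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def discriminator_loss(d_real, d_fake):
    """Two BCE terms: real images should be classified 1, generated images 0."""
    real_loss = sum(bce(p, 1) for p in d_real) / len(d_real)
    fake_loss = sum(bce(p, 0) for p in d_fake) / len(d_fake)
    return real_loss + fake_loss


def generator_loss(d_fake):
    """The generator wants its images classified as real (target 1)."""
    return sum(bce(p, 1) for p in d_fake) / len(d_fake)


# d_real / d_fake: discriminator outputs on a batch of real / generated images.
print(discriminator_loss([0.5, 0.5], [0.5, 0.5]))  # ≈ 1.386 (= 2 ln 2)
print(generator_loss([0.1, 0.2]))
```

A discriminator that cannot tell the two apart (outputs 0.5 everywhere) scores 2 ln 2; the loss drops toward 0 as it separates real from fake, while the generator's loss rises.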
Training two networks at the same time is hard, so expect some difficulty when training GANs. You do not want the discriminator to overpower the generator, or vice versa. For example, if the discriminator becomes much stronger than the generator and rejects generated images with absolute certainty, it leaves no gradient for the generator to descend.
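To see why a certain discriminator leaves no gradient, take D = sigmoid(logit) and differentiate the generator's loss with respect to the discriminator's logit on a generated image. For the original minimax loss log(1 - D), the gradient is -D, which vanishes as D(G(z)) approaches 0. The non-saturating loss -log(D), proposed in the original GAN paper as a fix, has gradient -(1 - D), which stays close to -1. This is a scalar sketch; in a real network this gradient is further backpropagated into the generator's weights.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def grad_minimax(logit):
    """d/d(logit) of the original generator loss log(1 - D), with D = sigmoid(logit)."""
    return -sigmoid(logit)


def grad_non_saturating(logit):
    """d/d(logit) of the non-saturating generator loss -log(D)."""
    return -(1.0 - sigmoid(logit))


# More negative logits = discriminator more certain the image is fake.
for logit in [0.0, -2.0, -5.0, -10.0]:
    d = sigmoid(logit)
    print(f"D(G(z))={d:.5f}  minimax grad={grad_minimax(logit):+.5f}  "
          f"non-saturating grad={grad_non_saturating(logit):+.5f}")
```

As the discriminator grows more confident, the minimax gradient shrinks toward zero while the non-saturating one does not, which is why many implementations train the generator with target label 1 instead.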
Another problem, called mode collapse, happens when the generator discovers and exploits some weakness in the discriminator and keeps producing the same few outputs. You can recognize mode collapse in your GAN if it generates many very similar images regardless of variation in the generator input z.
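One rough way to check for this is to feed several different z vectors through the generator and measure how far apart the outputs are. The helper `mean_pairwise_distance` below is a hypothetical heuristic, not a standard metric, and the sample vectors are made up: a value near 0 suggests collapse, but what counts as "near 0" depends on the data, so compare it with the same statistic computed on a batch of real images.

```python
import itertools
import math


def mean_pairwise_distance(samples):
    """Average Euclidean distance between all pairs of generated samples.

    samples: list of flattened images (lists of floats), one per input z.
    A value close to 0 means the outputs barely change with z.
    """
    pairs = list(itertools.combinations(samples, 2))
    if not pairs:
        return 0.0
    return sum(math.dist(a, b) for a, b in pairs) / len(pairs)


# Toy 2-pixel "images": one varied batch, one nearly identical batch.
diverse = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
collapsed = [[0.5, 0.5], [0.5, 0.51], [0.49, 0.5]]
print(mean_pairwise_distance(diverse), mean_pairwise_distance(collapsed))
```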
Mode collapse can sometimes be corrected by "strengthening" the discriminator in some way, for instance by adjusting its training rate or by reconfiguring its layers (for example, giving the discriminator more parameters).
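Architecture guidelines that help stabilize convolutional GANs (these follow the DCGAN paper):
Use batchnorm in both networks (the DCGAN paper leaves it out of the generator output layer and the discriminator input layer).
Use ReLU in the generator for all layers except the output layer, which uses tanh.
Use LeakyReLU in the discriminator for all layers.
Remove fully connected hidden layers to allow deeper architectures.
Replace pooling layers with strided convolutions in the discriminator.
Replace unpooling layers with fractional-strided (transposed) convolutions in the generator.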
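Without pooling and unpooling, the strided and fractional-strided convolutions are what change the spatial resolution. The sketch below computes output sizes with the standard convolution size formula and the transposed-convolution formula (dilation 1). The MNIST layout is an assumption for illustration: kernel 4, stride 2, padding 1 on every layer, with z already projected and reshaped into a 7x7 feature map.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a strided convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding, output_padding=0):
    """Spatial output size of a fractional-strided (transposed) convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Hypothetical MNIST layout: kernel 4, stride 2, padding 1 on every layer.
g_sizes = [7]  # generator: upsample 7x7 -> 28x28
for _ in range(2):
    g_sizes.append(conv_transpose_out(g_sizes[-1], kernel=4, stride=2, padding=1))

d_sizes = [28]  # discriminator: downsample 28x28 -> 7x7
for _ in range(2):
    d_sizes.append(conv_out(d_sizes[-1], kernel=4, stride=2, padding=1))

print("generator:", g_sizes)      # [7, 14, 28]
print("discriminator:", d_sizes)  # [28, 14, 7]
```

With this kernel/stride/padding choice each layer exactly doubles (generator) or halves (discriminator) the resolution, so the two networks mirror each other.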