- Include more architectures for the critic network
- More comments in functions
- Testing code when X and Z are images
- Include unit tests
In this repo, I offer a ready-to-use implementation of state-of-the-art variational methods for mutual information estimation in TensorFlow. This includes:
- MINE estimator [1], based on the Donsker-Varadhan representation of the KL divergence
- NWJ estimator [2], based on the f-divergence variational representation of the KL divergence
- NCE estimator [3], based on the noise contrastive estimation principle
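For reference, these estimators are based on the following variational lower bounds on I(X;Z), written here in their standard forms for a critic function T(x,z) (the exact parametrization used in this repo may differ slightly):

```latex
% Donsker-Varadhan (MINE) bound:
I(X;Z) \geq \mathbb{E}_{p(x,z)}[T(x,z)] - \log \mathbb{E}_{p(x)p(z)}\big[e^{T(x,z)}\big]

% NWJ (f-divergence) bound:
I(X;Z) \geq \mathbb{E}_{p(x,z)}[T(x,z)] - e^{-1}\,\mathbb{E}_{p(x)p(z)}\big[e^{T(x,z)}\big]

% NCE (InfoNCE) bound, for a batch of K pairs (x_i, z_i):
I(X;Z) \geq \mathbb{E}\Big[\tfrac{1}{K}\sum_{i=1}^{K} \log \frac{e^{T(x_i,z_i)}}{\tfrac{1}{K}\sum_{j=1}^{K} e^{T(x_i,z_j)}}\Big]
```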
To define an estimator, one needs to provide a few arguments.
from estimator import Mi_estimator

my_mi_estimator = Mi_estimator(regu_name=...,                   # name of the estimator, e.g. 'mine'
                               dim_x=...,                       # shape of x, e.g. [dim_x]
                               dim_z=...,                       # shape of z, e.g. [dim_z]
                               batch_size=...,
                               critic_layers=[256, 256, 256],   # hidden layer sizes of the critic
                               critic_lr=1e-4,                  # learning rate of the critic
                               critic_activation='relu',
                               critic_type='joint',
                               ema_decay=0.99,
                               negative_samples=1)
import numpy as np
from estimator import Mi_estimator

dim_x = 10
dim_z = 15
data_size = 100000  # number of samples (illustrative value)

x_data = np.random.rand(data_size, dim_x)
z_data = np.random.rand(data_size, dim_z)

my_mi_estimator = Mi_estimator(regu_name='mine',
                               dim_x=[dim_x],
                               dim_z=[dim_z],
                               batch_size=128,
                               critic_layers=[256, 256, 256],
                               critic_lr=1e-4,
                               critic_activation='relu',
                               critic_type='joint',
                               ema_decay=0.99,
                               negative_samples=1)

mi_estimate = my_mi_estimator.fit(x_data, z_data)
import tensorflow as tf
from estimator import Mi_estimator

dim = 10
batch_size = 128

x = tf.placeholder(tf.float32, shape=[batch_size, dim])
z = tf.placeholder(tf.float32, shape=[batch_size, dim])

my_mi_estimator = Mi_estimator(regu_name='mine',
                               dim_x=[dim],
                               dim_z=[dim],
                               batch_size=batch_size,
                               critic_layers=[256, 256, 256],
                               critic_lr=1e-4,
                               critic_activation='relu',
                               critic_type='joint',
                               ema_decay=0.99,
                               negative_samples=1)
estimator_train_op, estimator_quantities = my_mi_estimator(x, z)
# Define the regularized loss
loss = loss_unregularized + beta * estimator_quantities['mi_for_grads']

# Define the main training op
main_train_op = optimizer.minimize(loss, var_list=...)
...
# At every training iteration:
# Perform the main training operation to optimize the loss
sess.run([main_train_op], feed_dict={x: ..., z: ..., ...})
# Then update the estimator
sess.run(estimator_train_op, feed_dict={x: ..., z: ...})
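Putting these pieces together, a minimal training loop might look like the sketch below (num_steps and the sample_batch data helper are hypothetical placeholders, not part of this repo):

```python
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):  # num_steps: hypothetical number of iterations
        # sample_batch is a hypothetical helper returning arrays of shape [batch_size, dim]
        x_batch, z_batch = sample_batch(batch_size)
        # 1) optimize the main (regularized) objective
        sess.run(main_train_op, feed_dict={x: x_batch, z: z_batch})
        # 2) update the critic of the MI estimator
        sess.run(estimator_train_op, feed_dict={x: x_batch, z: z_batch})
```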
We provide code showcasing the two functionalities described above.
In the case of correlated Gaussian random variables, we have access to the ground-truth mutual information:
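Assuming each pair of components (X_i, Z_i) is jointly Gaussian with correlation ρ and the d components are independent (the usual setup for this kind of experiment; the exact configuration lives in demo_gaussian.py), the closed form is:

```latex
I(X;Z) = -\frac{d}{2}\,\log\!\left(1 - \rho^2\right)
```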
First, go to the example directory:
cd examples/correlated_gaussians/
Then, run the tests to make sure the code runs without errors:
python3 run_test.py
Finally, to run the experiments, check the available options in "demo_gaussian.py" and loop over any parameters by modifying the header of run_exp.py. When that is done, simply run:
python3 run_exp.py
After training, plots are available in /plots/. You should obtain something like:
The plots above show the estimated MI for the three estimation methods after 10 epochs of training on a 100k-sample dataset with batch size 128.
The most interesting use case of these bounds is mutual information maximization. A typical example is reducing mode collapse in GANs: there, the mutual information I(X;Z) between the generator output X and the noise vector Z is used as a proxy for the generator entropy H(X). Since I(X;Z) = H(X) - H(X|Z) ≤ H(X), pushing I(X;Z) up encourages a higher generator entropy, i.e. more diverse samples. Concretely, one would like to minimize:
loss_regularized = loss_gan - beta * I(X;Z)
Naively differentiating the loss_regularized term above will probably fail to yield the expected result: the gradients of I(X;Z) typically have a much larger scale than those of loss_gan. To keep the gradient scales balanced, one can use the adaptive gradient clipping proposed in [1]. This simple trick scales the MI term so that its gradient norm does not exceed that of the original loss. Concretely, one should instead minimize:
scale = min( ||G(loss_gan)||, ||G(I(X;Z))|| ) / ||G(I(X;Z))||
loss_regularized = loss_gan - beta * scale * I(X;Z)
where G(·) denotes the gradient operator with respect to the parameters being optimized.
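As an illustration, here is a minimal TensorFlow 1.x sketch of this scaling. The tensors loss_gan and mi_term, the list gen_vars, and the scalar beta are hypothetical stand-ins; in practice, mi_term would be the quantity returned by the estimator for gradient purposes (estimator_quantities['mi_for_grads'] above):

```python
import tensorflow as tf

# loss_gan: scalar GAN loss; mi_term: scalar MI estimate used for gradients
# gen_vars: list of generator variables; beta: regularization weight (float)
grads_gan = tf.gradients(loss_gan, gen_vars)
grads_mi = tf.gradients(mi_term, gen_vars)

norm_gan = tf.global_norm(grads_gan)
norm_mi = tf.global_norm(grads_mi)

# Cap the MI gradient norm at the GAN loss gradient norm
scale = tf.minimum(norm_gan, norm_mi) / (norm_mi + 1e-8)
scale = tf.stop_gradient(scale)  # treat the scale as a constant during backprop

loss_regularized = loss_gan - beta * scale * mi_term
main_train_op = tf.train.AdamOptimizer(1e-4).minimize(loss_regularized, var_list=gen_vars)
```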
We provide a simple 2D example, referred to as the "25 Gaussians" experiment, where the target distribution is:
With the provided generator and discriminator architectures, a plain GAN produces distributions like:
While the plot above clearly exposes mode collapse, one can substantially reduce it by adding an MI regularization term:
To see the code for this example, first go to the example directory:
cd examples/gan/
Then, run the tests to make sure the code runs without errors:
python3 run_test.py
Finally, to run the experiments, check the available options in "demo_gaussian.py" and loop over any parameters by modifying the header of run_exp.py. Then simply run:
python3 run_exp.py
- Malik Boudiaf
[1] Ishmael Belghazi, Sai Rajeswar, Aristide Baratin, R. Devon Hjelm, and Aaron C. Courville. MINE: Mutual Information Neural Estimation. https://arxiv.org/abs/1801.04062
[2] Sebastian Nowozin, Botond Cseke, and Ryota Tomioka. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. https://arxiv.org/abs/1606.00709
[3] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. Representation Learning with Contrastive Predictive Coding. https://arxiv.org/abs/1807.03748