This repository contains a JAX implementation of the methods described in the paper *Gradients without Backpropagation*.
Sometimes, all we want is to get rid of backpropagation of errors and estimate an unbiased gradient of the loss function during a single inference pass :)
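For intuition, here is a minimal sketch of a single forward-gradient step in JAX (the function and variable names are illustrative, not this repository's actual API): sample a random tangent direction v, obtain the directional derivative with one `jax.jvp` forward pass, and scale v by it to get an unbiased estimate of the gradient.

```python
import jax


def forward_gradient(loss_fn, params, key):
    """Illustrative forward-gradient step; names are hypothetical, not the repo's API."""
    # Sample a random tangent direction v ~ N(0, I) with the same pytree structure as params.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    tangents = treedef.unflatten(
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(keys, leaves)]
    )
    # Single forward pass: jvp returns the loss and the directional derivative (grad(f) . v).
    loss, directional_deriv = jax.jvp(loss_fn, (params,), (tangents,))
    # Scaling the tangent by the directional derivative gives an unbiased gradient estimate.
    grad_estimate = jax.tree_util.tree_map(lambda v: directional_deriv * v, tangents)
    return loss, grad_estimate
```

The resulting `grad_estimate` can then be fed to an optax optimizer in a training loop just like a gradient obtained via backpropagation.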
The code demonstrates how to train a simple MLP on MNIST using either forward gradients (as described in the paper) or standard backpropagation for comparison. The network depth is controlled by the --num_layers parameter.
Note: this does not seem to scale efficiently beyond 10 layers, because the variance of the gradient estimate grows with the number of parameters in the network.
Requirements:

- JAX <3
- optax (for learning rate scheduling)
- wandb (optional, for logging)
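These can be installed with pip; a minimal sketch, assuming the standard PyPI package names and no pinned versions:

```bash
# wandb is optional (used only for logging); versions are not pinned here.
pip install jax jaxlib optax wandb
```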
To replicate MLP training with forward gradients on MNIST, simply execute the train.py script:
python train.py
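The network depth can be adjusted via the --num_layers parameter mentioned above; for example (any other options are assumed to keep their defaults):

```bash
# Train a 4-layer MLP; --num_layers is the flag referenced in the description above.
python train.py --num_layers 4
```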