This is the result of the project "Reproduce Neural ODE and SDE" from the HuggingFace Flax/JAX Community Week.
main.py trains a ResNet or an OdeNet (Neural ODE) classifier on the MNIST dataset, or a Continuous Normalizing Flow on a sample dataset.
To install JAX, follow the official JAX installation guide, or simply run:
pip install jax jaxlib
To install Flax:
pip install flax
The MNIST dataset is downloaded automatically through TensorFlow Datasets (tensorflow-datasets).
To train a (small) ResNet:
python main.py --model=resnet --lr=1e-4 --n_epoch=20 --batch_size=64
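The link between the two MNIST models is that a residual block, h_next = h + f(h), is exactly one explicit Euler step of the ODE dh/dt = f(h) with step size 1. The sketch below shows this in plain Python; `toy_layer` is a hypothetical fixed tanh nonlinearity standing in for a learned layer, and none of these names come from main.py.

```python
import math
from typing import Callable, List

Vector = List[float]


def residual_block(h: Vector, f: Callable[[Vector], Vector]) -> Vector:
    """ResNet-style update: h_next = h + f(h)."""
    return [hi + fi for hi, fi in zip(h, f(h))]


def euler_step(h: Vector, f: Callable[[Vector], Vector], dt: float) -> Vector:
    """One explicit Euler step of dh/dt = f(h): h_next = h + dt * f(h)."""
    return [hi + dt * fi for hi, fi in zip(h, f(h))]


def toy_layer(h: Vector) -> Vector:
    """Hypothetical stand-in for a learned layer (no weights, just tanh)."""
    return [math.tanh(x) for x in h]


h0 = [0.5, -1.0]
# With dt = 1.0 the two updates coincide.
print(residual_block(h0, toy_layer) == euler_step(h0, toy_layer, dt=1.0))
```

Shrinking dt and taking more steps turns a deep stack of residual blocks into a numerical solution of the ODE, which is the idea behind OdeNet.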
To train a Neural ODE (OdeNet):
python main.py --model=odenet --lr=1e-4 --n_epoch=20 --batch_size=64
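An ODE block replaces a stack of residual layers with an ODE solve: the input is h(0) and the output is h(1), where dh/dt = f(h, t). The sketch below uses a fixed-step classical Runge-Kutta (RK4) solver in plain Python as an illustration only; the solver used by main.py is not stated here (JAX itself provides an adaptive solver, `jax.experimental.ode.odeint`). `ode_block` and `decay` are hypothetical names, and `decay` is a toy field with a known solution so the output can be checked.

```python
import math
from typing import Callable, List

Vector = List[float]
Field = Callable[[Vector, float], Vector]


def _axpy(a: float, x: Vector, y: Vector) -> Vector:
    """Return y + a * x elementwise."""
    return [yi + a * xi for xi, yi in zip(x, y)]


def odeint_rk4(f: Field, h0: Vector, t0: float, t1: float, n_steps: int) -> Vector:
    """Integrate dh/dt = f(h, t) from t0 to t1 with fixed-step classical RK4."""
    dt = (t1 - t0) / n_steps
    h, t = list(h0), t0
    for _ in range(n_steps):
        k1 = f(h, t)
        k2 = f(_axpy(dt / 2, k1, h), t + dt / 2)
        k3 = f(_axpy(dt / 2, k2, h), t + dt / 2)
        k4 = f(_axpy(dt, k3, h), t + dt)
        h = [hi + dt / 6 * (a + 2 * b + 2 * c + d)
             for hi, a, b, c, d in zip(h, k1, k2, k3, k4)]
        t += dt
    return h


def ode_block(h0: Vector, f: Field, n_steps: int = 20) -> Vector:
    """Hypothetical ODE block: map the input h(0) to the output h(1)."""
    return odeint_rk4(f, h0, 0.0, 1.0, n_steps)


def decay(h: Vector, t: float) -> Vector:
    """Toy field dh/dt = -h, whose exact solution is h(t) = h(0) * exp(-t)."""
    return [-x for x in h]


out = ode_block([1.0, 2.0], decay)
print(out, [math.exp(-1), 2 * math.exp(-1)])
```

In a real OdeNet, f is a small neural network whose parameters are trained by backpropagating through (or adjointly around) the solver.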
To train a Continuous Normalizing Flow (CNF):
python main.py --model=cnf --sample_dataset=circles
The sample dataset can be set to circles, moons, or scurve.
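A CNF transports samples with dz/dt = f(z, t) and tracks the log-density with the instantaneous change of variables, d log p(z(t))/dt = -tr(df/dz). The sketch below assumes this standard Neural ODE formulation and uses a hypothetical linear field f(z) = A z, whose Jacobian is A, so the trace is constant and the final log-density change is exactly -tr(A) * t1. A trained CNF uses a neural network for f, where the trace must be computed or estimated at every step; the Euler integration here is an illustrative choice, not necessarily what main.py does.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(A: Matrix, z: Vector) -> Vector:
    """Matrix-vector product A @ z."""
    return [sum(a * x for a, x in zip(row, z)) for row in A]


def trace(A: Matrix) -> float:
    """Sum of the diagonal entries of a square matrix."""
    return sum(A[i][i] for i in range(len(A)))


def cnf_forward(A: Matrix, z0: Vector, logp0: float,
                t1: float = 1.0, n_steps: int = 1000) -> Tuple[Vector, float]:
    """Euler-integrate dz/dt = A z jointly with d(log p)/dt = -tr(A)."""
    dt = t1 / n_steps
    z, logp = list(z0), logp0
    tr = trace(A)  # Jacobian of f(z) = A z is A, so its trace is constant
    for _ in range(n_steps):
        dz = matvec(A, z)
        z = [zi + dt * dzi for zi, dzi in zip(z, dz)]
        logp -= dt * tr
    return z, logp


A = [[0.5, 0.0], [0.0, -0.2]]  # hypothetical field: expand x, contract y
z1, logp1 = cnf_forward(A, [1.0, 1.0], logp0=0.0)
print(z1, logp1)  # log-density drops by tr(A) * t1 = 0.3
```

Because the flow expands volume when tr(A) > 0, density must decrease by the same factor, which is what the -tr term accounts for.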