Swarm training "framework" using Haiku + JAX + Ray.

Designed for training large language models in a model-parallel fashion across unreliable, heterogeneous nodes (eventually).
See swarm_run.py for an example of training a character-level transformer on enwik8.
Done:
- Forward passes
- Backward passes with activation reconstruction
- Run optimizer
- Logging
- Checkpointing

TODO:
- Actually do pipelining
- fp16 with static loss scaling
- Integer quantization for activations and gradients between layers
- Eliminate pipeline stalls from running the optimizer
- Data parallelism with multiple nodes per layer and gradient/weight aggregation
- Heterogeneous nodes, potentially with multiple layers per node
- Handle unbalanced and unreliable nodes (LayerDrop)
- Dynamic node addition
- 1T or bust?
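"Backward passes with activation reconstruction" above amounts to gradient checkpointing: activations are recomputed during the backward pass instead of being stored. A minimal JAX sketch of the idea using `jax.checkpoint` (a toy layer, not this repo's actual implementation):

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    # Toy "layer"; its intermediate activations are what we avoid storing.
    return jnp.tanh(x @ w)

def loss_remat(x, w):
    # jax.checkpoint (aka jax.remat) discards the layer's internals after the
    # forward pass and recomputes them when the backward pass needs them.
    y = jax.checkpoint(layer)(x, w)
    return jnp.sum(y ** 2)

def loss_plain(x, w):
    return jnp.sum(layer(x, w) ** 2)

x = jnp.ones((2, 4))
w = jnp.full((4, 4), 0.1)

# Gradients are identical to the un-checkpointed version; only peak
# memory during the backward pass differs.
g_remat = jax.grad(loss_remat, argnums=1)(x, w)
g_plain = jax.grad(loss_plain, argnums=1)(x, w)
```

The same trick also lets a pipeline stage rebuild activations it never received in full, at the cost of extra compute.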
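For the "fp16 with static loss scaling" item, the standard recipe is to multiply the loss by a fixed constant before differentiating (so small gradients don't underflow in fp16) and divide the gradients by the same constant afterwards. A sketch under assumed names (`LOSS_SCALE` and the toy `loss_fn` are illustrative, not from the repo); computation is kept in fp32 here for clarity, whereas in practice the forward/backward would run in fp16:

```python
import jax
import jax.numpy as jnp

LOSS_SCALE = 2.0 ** 15  # static scale chosen up front (hypothetical value)

def loss_fn(w, x):
    return jnp.mean((x @ w) ** 2)

def scaled_grads(w, x):
    # Differentiate the scaled loss, then unscale the gradients.
    g = jax.grad(lambda w_: loss_fn(w_, x) * LOSS_SCALE)(w)
    return jax.tree_util.tree_map(lambda t: t / LOSS_SCALE, g)

w = jnp.ones((4, 2), jnp.float32)
x = jnp.ones((3, 4), jnp.float32)

g = scaled_grads(w, x)
g_ref = jax.grad(loss_fn)(w, x)  # unscaled reference
```

Because the scale is a power of two, scaling and unscaling are exact in floating point, so the recovered gradients match the unscaled ones.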
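"Integer quantization for activations and gradients between layers" would shrink the tensors shipped between nodes to one byte per element. A minimal affine uint8 quantize/dequantize sketch (function names are illustrative; the repo's eventual scheme may differ, e.g. per-channel scales or stochastic rounding):

```python
import jax.numpy as jnp

def quantize_uint8(x):
    # Map the tensor's [min, max] range onto 0..255; only the uint8 tensor
    # plus two scalars (scale, offset) need to cross the network.
    lo, hi = x.min(), x.max()
    scale = (hi - lo) / 255.0
    q = jnp.round((x - lo) / scale).astype(jnp.uint8)
    return q, scale, lo

def dequantize_uint8(q, scale, lo):
    # Invert the affine map on the receiving node.
    return q.astype(jnp.float32) * scale + lo

x = jnp.linspace(-1.0, 1.0, 16)
q, scale, lo = quantize_uint8(x)
x_hat = dequantize_uint8(q, scale, lo)
```

The round-trip error is bounded by half the quantization step, which is typically acceptable for inter-layer activations and gradients.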