kingoflolz/swarm-jax
Pipelined Swarm Training

Swarm training "framework" using Haiku + JAX + Ray.

Designed (eventually) for training large language models in a model-parallel fashion on unreliable, heterogeneous nodes.

See `swarm_run.py` for an example of training a character-level transformer on enwik8.
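For intuition, layer parallelism of this kind can be sketched with plain Python objects standing in for Ray actors. This is a conceptual sketch only; the names `LayerActor` and `Pipeline` are illustrative, not the repo's API, and scalar arithmetic stands in for Haiku/JAX layers:

```python
# Conceptual sketch of layer-parallel pipelining: each "actor" owns one
# layer's parameters, and activations are relayed stage to stage, so in
# a real swarm each stage can work on a different microbatch at the same
# time. In swarm-jax the stages are Ray actors running Haiku/JAX layers;
# here plain Python objects and scalar math stand in for both.

class LayerActor:
    """One pipeline stage holding a single layer (here: y = w*x + b)."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def forward(self, x):
        return self.w * x + self.b

class Pipeline:
    """Chains stages sequentially; a real pipeline overlaps microbatches."""
    def __init__(self, stages):
        self.stages = stages

    def forward(self, x):
        for stage in self.stages:
            x = stage.forward(x)
        return x

pipe = Pipeline([LayerActor(2.0, 0.0), LayerActor(1.0, 3.0)])
print(pipe.forward(5.0))  # (5*2 + 0)*1 + 3 = 13.0
```

Because each stage holds only its own layer's weights, a node only needs enough memory for its slice of the model, which is what makes heterogeneous nodes viable.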

TODOs

  • Forward passes
  • Backward passes with activation reconstruction
  • Run optimizer
  • Logging
  • Checkpointing
  • Actually do pipelining
  • fp16 with static loss scaling
  • Integer quantization for activations and gradients between layers
  • Get rid of pipeline stalls from running optimizer
  • Data parallelism with multiple nodes per layer and gradient/weight aggregation
  • Heterogeneous nodes with potentially multiple layers per node
  • Handle unbalanced and unreliable nodes (LayerDrop)
  • Dynamic node addition
  • 1T or bust?
