From 0cf1281a38bdd8c84b11549ffc49a63eb3a92d40 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Wed, 16 Oct 2024 08:30:17 -0700
Subject: [PATCH] Update README.md for float8

---
 torchao/float8/README.md | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index b9b40d7e4..34dee659f 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
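
For context, a minimal end-to-end sketch of the flow the patched README describes. `convert_to_float8_training`, `torch.compile`, and the `x`/`optimizer` setup come from the hunk context above; the `nn.Sequential` layer shapes and the 10-step loop are illustrative assumptions, and running this assumes a CUDA GPU with hardware float8 support.

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# toy model; shapes are illustrative, chosen divisible by 16 so every
# linear is eligible for float8 conversion
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).to("cuda", torch.bfloat16)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# swap eligible `torch.nn.Linear` modules for `Float8Linear` in place
convert_to_float8_training(m)

# fuse the float8 scaling and casting kernels
m = torch.compile(m)

# toy training loop (step count is an illustrative assumption)
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```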