Update README.md for float8
vkuzo authored Oct 16, 2024
1 parent 7a35695 commit 0cf1281
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions torchao/float8/README.md
@@ -34,21 +34,12 @@ m = nn.Sequential(
 x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
 optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
 
-# optional: filter modules from being eligible for float8 conversion
-def module_filter_fn(mod: torch.nn.Module, fqn: str):
-    # don't convert the last module
-    if fqn == "1":
-        return False
-    # don't convert linear modules with weight dimensions not divisible by 16
-    if isinstance(mod, torch.nn.Linear):
-        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
-            return False
-    return True
-
-# convert specified `torch.nn.Linear` modules to `Float8Linear`
-convert_to_float8_training(m, module_filter_fn=module_filter_fn)
-
-# enable torch.compile for competitive performance
+# convert specified `torch.nn.Linear` modules to `Float8Linear`, with compute
+# and optionally distributed communications in float8
+convert_to_float8_training(m)
+
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop
@@ -94,7 +85,8 @@ config = Float8LinearConfig(
 # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying custom scaling behavior
 convert_to_float8_training(m, config=config)
 
-# enable torch.compile for competitive performance
+# enable torch.compile to generate fused kernels for float8 scaling and casting,
+# which improves performance
 m = torch.compile(m)
 
 # toy training loop

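For reference, below is a minimal sketch of what the simplified README snippet assembles to end to end after this commit. The model definition and the body of the toy training loop are illustrative assumptions (the diff only implies their shapes); it assumes torchao is installed and a float8-capable GPU (e.g. H100) is available.

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training

# illustrative toy model; dimensions kept divisible by 16, matching the
# constraint mentioned in the removed module_filter_fn example
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
).to("cuda").to(torch.bfloat16)

x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# convert `torch.nn.Linear` modules to `Float8Linear`, with compute
# (and optionally distributed communications) in float8
convert_to_float8_training(m)

# enable torch.compile to generate fused kernels for float8 scaling and casting
m = torch.compile(m)

# toy training loop (illustrative)
for _ in range(10):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

Note that the `module_filter_fn` argument removed from the README example still appears in the old code shown in the diff, so presumably it can still be passed to `convert_to_float8_training` to skip specific modules.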