Summary: `torch.compile` requires that custom operators have shape functions registered, since they are needed for tracing. In FBGEMM, such shape functions have been registered inconsistently. This diff cleans up and registers shape functions for a good number of commonly used operators. Notably, PyTorch allows two methods of registering shape functions for custom ops: in C++ via a Meta function, or in Python via `register_fake`. It turns out `register_fake` is the recommended and more powerful approach; for example, it is required for ops that cross devices (such as the car ops) and when exporting a traced graph. This diff therefore focuses on `register_fake` and converts a handful of Meta registrations to it. My hope is that this provides an easily extensible way for other kernel authors to register shape functions.

Differential Revision: D64147797
1 parent 700a7a6 · commit d0c968e
7 changed files with 417 additions and 208 deletions.
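Before the diff itself, a minimal, self-contained sketch of the `register_fake` pattern the summary recommends. Everything here is hypothetical illustration, not part of the commit: the op `mylib::scaled_copy`, its fake, and `f` exist only to show the mechanics.

import torch

# Hypothetical toy op, defined only to demonstrate the register_fake pattern.
@torch.library.custom_op("mylib::scaled_copy", mutates_args=())
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    return (x * scale).contiguous()

# The fake runs during tracing: it must produce an output with the right
# shape/dtype/device without doing any real compute.
@torch.library.register_fake("mylib::scaled_copy")
def scaled_copy_abstract(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.mylib.scaled_copy(x, 2.0)

print(f(torch.randn(4, 8)).shape)  # torch.Size([4, 8])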
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import torch

"""
This file contains manual shape registrations for communication custom operators.
These are needed for custom operators to be compatible with torch.compile.
In some cases, fake tensor handling can be done by registering a meta implementation
directly in C++. However, for more complicated functions, such as those that involve
cross-device synchronization, PyTorch requires that a full fake implementation be
registered in Python.
"""

# These collectives write their results into `dst` in place, so the fake
# implementations allocate nothing and simply return None.


@torch.library.register_fake("fbgemm::nccl_allreduce")
def nccl_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_allgather")
def nccl_allgather_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_alltoall")
def nccl_alltoall(
    dst: torch.Tensor,
    src: torch.Tensor,
    world_size: int,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_reducescatter")
def nccl_reducescatter(
    dst: torch.Tensor,
    src: torch.Tensor,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::one_shot_car_allreduce")
def one_shot_car_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::two_shot_car_allreduce")
def two_shot_car_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None
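Because these collectives mutate `dst` and return None, the fakes carry no shape information of their own; the payoff is that tracing never touches NCCL. A hedged sketch of how one of them behaves under FakeTensorMode, assuming the fbgemm comm ops (and the fakes above) have been loaded:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch only: assumes torch.ops.fbgemm.nccl_allreduce is registered.
with FakeTensorMode():
    dst = torch.empty(1024, dtype=torch.bfloat16)
    src = torch.empty(1024, dtype=torch.bfloat16)
    # Dispatches to nccl_allreduce_abstract; no communicator or GPU is touched.
    torch.ops.fbgemm.nccl_allreduce(dst, src)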
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional, Tuple

import torch

"""
This file contains manual shape registrations for quantize custom operators.
These are needed for custom operators to be compatible with torch.compile.
In some cases, fake tensor handling can be done by registering a meta implementation
directly in C++. However, for more complicated functions, such as those that involve
cross-device synchronization, PyTorch requires that a full fake implementation be
registered in Python.
"""


# The GEMM-style ops below all return an [M, N] bf16 output, where M comes
# from the activation (XQ) and N from the weight (WQ).
@torch.library.register_fake("fbgemm::f8f8bf16_blockwise")
def f8f8bf16_blockwise_abstract(
    XQ: torch.Tensor,
    WQ: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    block_m: int = 128,
    block_n: int = 128,
    block_k: int = 128,
) -> torch.Tensor:
    M = XQ.shape[0]
    N = WQ.shape[0]
    return torch.empty(
        [M, N],
        dtype=torch.bfloat16,
    )


@torch.library.register_fake("fbgemm::f8f8bf16_tensorwise")
def f8f8bf16_tensorwise_abstract(
    XQ: torch.Tensor,
    WQ: torch.Tensor,
    scale: float,
    use_fast_accum: bool = True,
) -> torch.Tensor:
    M = XQ.shape[0]
    N = WQ.shape[0]
    return torch.empty(
        [M, N],
        dtype=torch.bfloat16,
    )


@torch.library.register_fake("fbgemm::f8f8bf16_rowwise")
def f8f8bf16_rowwise_abstract(
    XQ: torch.Tensor,
    WQ: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    use_fast_accum: bool = True,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    M = XQ.shape[0]
    N = WQ.shape[0]
    return torch.empty(
        [M, N],
        dtype=torch.bfloat16,
    )


@torch.library.register_fake("fbgemm::quantize_fp8_per_tensor")
def quantize_fp8_per_tensor_abstract(
    input: torch.Tensor,
    bs: Optional[torch.Tensor] = None,
    scale_ub: Optional[torch.Tensor] = None,
    stochastic_rounding: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # HIP (AMD) uses the FP8 e4m3fnuz variant; CUDA uses e4m3fn.
    if torch.version.hip:
        fp8_dtype = torch.float8_e4m3fnuz
    else:
        fp8_dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, dtype=fp8_dtype)
    scale = torch.empty([], dtype=torch.bfloat16)
    return output, scale


@torch.library.register_fake("fbgemm::quantize_fp8_per_row")
def quantize_fp8_per_row_abstract(
    input: torch.Tensor,
    bs: Optional[torch.Tensor] = None,
    scale_ub: Optional[torch.Tensor] = None,
    output_dtype: Optional[torch.dtype] = None,
    stochastic_rounding: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if torch.version.hip:
        fp8_dtype = torch.float8_e4m3fnuz
    else:
        fp8_dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, dtype=fp8_dtype)
    scale = torch.empty([], dtype=torch.bfloat16)
    return output, scale


@torch.library.register_fake("fbgemm::quantize_fp8_per_col")
def quantize_fp8_per_col_abstract(
    input: torch.Tensor,
    bs: Optional[torch.Tensor] = None,
    scale_ub: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if torch.version.hip:
        fp8_dtype = torch.float8_e4m3fnuz
    else:
        fp8_dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, dtype=fp8_dtype)
    scale = torch.empty([], dtype=torch.bfloat16)
    return output, scale


# The following operators are not supported on AMD.
if not torch.version.hip:

    @torch.library.register_fake("fbgemm::i8i8bf16")
    def i8i8bf16_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        scale: float,
        split_k: int = 1,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16")
    def f8f8bf16_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        scale: torch.Tensor,
        use_fast_accum: bool = True,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16_cublas")
    def f8f8bf16_cublas_abstract(
        A: torch.Tensor,
        B: torch.Tensor,
        Ainvs: Optional[torch.Tensor] = None,
        Binvs: Optional[torch.Tensor] = None,
        use_fast_accum: bool = True,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        M = A.shape[0]
        N = B.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16_rowwise_batched")
    def f8f8bf16_rowwise_batched_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        x_scale: torch.Tensor,
        w_scale: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        use_fast_accum: bool = True,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8i4bf16_rowwise")
    def f8i4bf16_rowwise_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        x_scale: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise")
    def bf16i4bf16_rowwise_abstract(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = X.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise_batched")
    def bf16i4bf16_rowwise_batched_abstract(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = X.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )
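A quick way to sanity-check one of these registrations is to call the op on fake tensors and confirm the advertised output shape and dtype. A hedged sketch, assuming the fbgemm quantize ops are built and loaded:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Sketch only: assumes torch.ops.fbgemm.f8f8bf16_rowwise is registered.
with FakeTensorMode():
    xq = torch.empty(16, 64, dtype=torch.float8_e4m3fn)  # activation [M, K]
    wq = torch.empty(32, 64, dtype=torch.float8_e4m3fn)  # weight [N, K]
    x_scale = torch.empty(16, dtype=torch.float32)
    w_scale = torch.empty(32, dtype=torch.float32)
    out = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
    assert out.shape == (16, 32) and out.dtype == torch.bfloat16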