Better Shape Function Registration
Summary:
`torch.compile` requires that custom operators have shape functions registered, as they are needed for tracing.

In FBGEMM, we have registered such shape functions inconsistently. This diff cleans that up and registers shape functions for a good number of commonly used operators.

Notably, PyTorch allows two methods of registering shape functions for custom ops: in C++, you can use a Meta function, or in Python you can use `register_fake`. It turns out `register_fake` is the recommended and more powerful approach. For example, it is required for ops that cross devices (such as the car ops) and when exporting a traced graph.

This diff thus focuses on the `register_fake` method and converts a handful of Meta registrations to it. On the C++ side, each affected library now calls `m.set_python_module(...)` so the dispatcher knows which Python module provides the fake implementations, and mutated output arguments are annotated as `Tensor(a!)` in the schemas. My hope is that this provides an easily extensible way for other kernel authors to register shape functions.
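To make the pattern concrete, here is a minimal, self-contained sketch of the `register_fake` flow. The `mylib::row_gemm` op, its schema, and the CPU kernel are hypothetical stand-ins for illustration, not part of this diff; in FBGEMM the op definitions and real kernels live in C++.

import torch

# Define the op's schema (FBGEMM does this in C++ via m.def).
torch.library.define("mylib::row_gemm", "(Tensor XQ, Tensor WQ) -> Tensor")


# A real kernel for eager execution (hypothetical; FBGEMM registers CUDA kernels in C++).
@torch.library.register_kernel("mylib::row_gemm", "cpu")
def row_gemm_cpu(XQ: torch.Tensor, WQ: torch.Tensor) -> torch.Tensor:
    return (XQ.float() @ WQ.float().t()).to(torch.bfloat16)


# The shape function: compute only output metadata so torch.compile can trace the op.
@torch.library.register_fake("mylib::row_gemm")
def row_gemm_fake(XQ: torch.Tensor, WQ: torch.Tensor) -> torch.Tensor:
    return XQ.new_empty([XQ.shape[0], WQ.shape[0]], dtype=torch.bfloat16)


@torch.compile(fullgraph=True)
def f(XQ, WQ):
    return torch.ops.mylib.row_gemm(XQ, WQ)


out = f(torch.randn(16, 64), torch.randn(32, 64))
assert out.shape == (16, 32) and out.dtype == torch.bfloat16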

Differential Revision: D64147797
jwfromm authored and facebook-github-bot committed Oct 10, 2024
1 parent 700a7a6 commit d0c968e
Showing 7 changed files with 417 additions and 208 deletions.
4 changes: 4 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/__init__.py
@@ -37,12 +37,16 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:comm_ops"
 )
+from fbgemm_gpu.experimental.gen_ai import comm_ops  # noqa: F401
+
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:gemm_ops"
 )
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:quantize_ops"
 )
+from fbgemm_gpu.experimental.gen_ai import quantize_ops  # noqa: F401
+
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai:kv_cache_ops"
 )
79 changes: 79 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/comm_ops.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional

import torch

"""
This file contains manual shape registrations for communication custom operators.
These are needed for custom operators to be compatible with torch.compile.
In some cases, fake tensor handling can be done by registering a meta implementation
directly in C++. However, for more complicated functions, such as those that involve
cross-device synchronization, PyTorch requires that a full fake implementation be
registered in Python.
"""


@torch.library.register_fake("fbgemm::nccl_allreduce")
def nccl_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_allgather")
def nccl_allgather_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_alltoall")
def nccl_alltoall(
    dst: torch.Tensor,
    src: torch.Tensor,
    world_size: int,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::nccl_reducescatter")
def nccl_reducescatter(
    dst: torch.Tensor,
    src: torch.Tensor,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::one_shot_car_allreduce")
def one_shot_car_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None


@torch.library.register_fake("fbgemm::two_shot_car_allreduce")
def two_shot_car_allreduce_abstract(
    dst: torch.Tensor,
    src: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    comm_idx: int = 0,
) -> None:
    return None
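One way to see why these registrations matter (a sketch, assuming the comm_ops library above has been loaded and that your PyTorch build supports fake CUDA tensors on the current machine): under `FakeTensorMode`, calling a collective dispatches to the fake above rather than to NCCL, which is what lets `torch.compile` trace across devices without live communicators.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Fake tensors carry device/shape/dtype metadata but own no storage,
    # so no GPU or NCCL communicator is needed to trace this call.
    dst = torch.empty(8, 128, dtype=torch.bfloat16, device="cuda")
    src = torch.empty(8, 128, dtype=torch.bfloat16, device="cuda")
    torch.ops.fbgemm.nccl_allreduce(dst, src)  # hits nccl_allreduce_abstract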
228 changes: 228 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/gen_ai/quantize_ops.py
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional, Tuple

import torch

"""
This file contains manual shape registrations for quantize custom operators.
These are needed for custom operators to be compatible with torch.compile.
In some cases, fake tensor handling can be done by registering a meta implementation
directly in C++. However, for more complicated functions, such as those that involve
cross-device synchronization, PyTorch requires that a full fake implementation be
registered in Python.
"""


@torch.library.register_fake("fbgemm::f8f8bf16_blockwise")
def f8f8bf16_blockwise_abstract(
XQ: torch.Tensor,
WQ: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
block_m: int = 128,
block_n: int = 128,
block_k: int = 128,
) -> torch.Tensor:
M = XQ.shape[0]
N = WQ.shape[0]
return torch.empty(
[M, N],
dtype=torch.bfloat16,
)


@torch.library.register_fake("fbgemm::f8f8bf16_tensorwise")
def f8f8bf16_tensorwise_abstract(
XQ: torch.Tensor,
WQ: torch.Tensor,
scale: float,
use_fast_accum: bool = True,
) -> torch.Tensor:
M = XQ.shape[0]
N = WQ.shape[0]
return torch.empty(
[M, N],
dtype=torch.bfloat16,
)


@torch.library.register_fake("fbgemm::f8f8bf16_rowwise")
def f8f8bf16_rowwise_abstract(
XQ: torch.Tensor,
WQ: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
use_fast_accum: bool = True,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
M = XQ.shape[0]
N = WQ.shape[0]
return torch.empty(
[M, N],
dtype=torch.bfloat16,
)


@torch.library.register_fake("fbgemm::quantize_fp8_per_tensor")
def quantize_fp8_per_tensor_abstract(
input: torch.Tensor,
bs: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
stochastic_rounding: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.version.hip:
fp8_dtype = torch.float8_e4m3fnuz
else:
fp8_dtype = torch.float8_e4m3fn
output = torch.empty_like(input, dtype=fp8_dtype)
scale = torch.empty([], dtype=torch.bfloat16)
return output, scale


@torch.library.register_fake("fbgemm::quantize_fp8_per_row")
def quantize_fp8_per_row_abstract(
input: torch.Tensor,
bs: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
output_dtype: Optional[torch.dtype] = None,
stochastic_rounding: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.version.hip:
fp8_dtype = torch.float8_e4m3fnuz
else:
fp8_dtype = torch.float8_e4m3fn
output = torch.empty_like(input, dtype=fp8_dtype)
scale = torch.empty([], dtype=torch.bfloat16)
return output, scale


@torch.library.register_fake("fbgemm::quantize_fp8_per_col")
def quantize_fp8_per_col_abstract(
input: torch.Tensor,
bs: Optional[torch.Tensor] = None,
scale_ub: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if torch.version.hip:
fp8_dtype = torch.float8_e4m3fnuz
else:
fp8_dtype = torch.float8_e4m3fn
output = torch.empty_like(input, dtype=fp8_dtype)
scale = torch.empty([], dtype=torch.bfloat16)
return output, scale


# The following operators are not supported on AMD.
if not torch.version.hip:

    @torch.library.register_fake("fbgemm::i8i8bf16")
    def i8i8bf16_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        scale: float,
        split_k: int = 1,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16")
    def f8f8bf16_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        scale: torch.Tensor,
        use_fast_accum: bool = True,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16_cublas")
    def f8f8bf16_cublas_abstract(
        A: torch.Tensor,
        B: torch.Tensor,
        Ainvs: Optional[torch.Tensor] = None,
        Binvs: Optional[torch.Tensor] = None,
        use_fast_accum: bool = True,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        M = A.shape[0]
        N = B.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8f8bf16_rowwise_batched")
    def f8f8bf16_rowwise_batched_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        x_scale: torch.Tensor,
        w_scale: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        use_fast_accum: bool = True,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::f8i4bf16_rowwise")
    def f8i4bf16_rowwise_abstract(
        XQ: torch.Tensor,
        WQ: torch.Tensor,
        x_scale: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = XQ.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise")
    def bf16i4bf16_rowwise_abstract(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = X.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )

    @torch.library.register_fake("fbgemm::bf16i4bf16_rowwise_batched")
    def bf16i4bf16_rowwise_batched_abstract(
        X: torch.Tensor,
        WQ: torch.Tensor,
        w_scale: torch.Tensor,
        w_zp: torch.Tensor,
    ) -> torch.Tensor:
        M = X.shape[0]
        N = WQ.shape[0]
        return torch.empty(
            [M, N],
            dtype=torch.bfloat16,
        )
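A quick sanity check of one of these registrations (a sketch, assuming the quantize_ops library is loaded): under `FakeTensorMode`, the call below returns an [M, N] bf16 fake tensor from `f8f8bf16_rowwise_abstract` without launching any kernel.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    XQ = torch.empty(16, 64, dtype=torch.float8_e4m3fn, device="cuda")
    WQ = torch.empty(32, 64, dtype=torch.float8_e4m3fn, device="cuda")
    x_scale = torch.empty(16, device="cuda")
    w_scale = torch.empty(32, device="cuda")
    out = torch.ops.fbgemm.f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale)
    assert out.shape == (16, 32) and out.dtype == torch.bfloat16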
28 changes: 16 additions & 12 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -229,6 +229,8 @@ void two_shot_car_allreduce(
 at::Tensor car_tensor();
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.set_python_module("fbgemm_gpu.experimental.gen_ai.comm_ops");
+
   m.def(
       "nccl_init(int rank, int world_size, str rendevouz, int comm_idx=0) -> ()");
   m.impl("nccl_init", nccl_init);
@@ -240,20 +242,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "nccl_comm_init_rank(int world_size, int rank, Tensor id_, int comm_idx=0) -> ()");
   m.impl("nccl_comm_init_rank", nccl_comm_init_rank);
 
-  m.def("nccl_allgather(Tensor dst, Tensor src, int comm_idx=0) -> ()");
-  m.impl("nccl_allgather", nccl_allgather);
+  m.def("nccl_allgather(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
 
   m.def(
-      "nccl_alltoall(Tensor dst, Tensor src, int world_size, int comm_idx=0) -> ()");
-  m.impl("nccl_alltoall", nccl_alltoall);
+      "nccl_alltoall(Tensor(a!) dst, Tensor src, int world_size, int comm_idx=0) -> ()");
 
-  m.def("nccl_reducescatter(Tensor dst, Tensor src, int comm_idx=0) -> ()");
-  m.impl("nccl_reducescatter", nccl_reducescatter);
+  m.def("nccl_reducescatter(Tensor(a!) dst, Tensor src, int comm_idx=0) -> ()");
 
   m.def(
-      "nccl_allreduce(Tensor dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
-  m.impl("nccl_allreduce", nccl_allreduce);
-
+      "nccl_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
   // car: customized all reduce
   m.def("car_tensor() -> Tensor");
   m.impl("car_tensor", car_tensor);
@@ -266,11 +263,18 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.impl("car_init", car_init);
 
   m.def(
-      "one_shot_car_allreduce(Tensor dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
-  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
+      "one_shot_car_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 
   m.def(
-      "two_shot_car_allreduce(Tensor dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
+      "two_shot_car_allreduce(Tensor(a!) dst, Tensor src, Tensor? bias=None, int comm_idx=0) -> ()");
 }
 
+TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
+  m.impl("nccl_allreduce", nccl_allreduce);
+  m.impl("nccl_allgather", nccl_allgather);
+  m.impl("nccl_alltoall", nccl_alltoall);
+  m.impl("nccl_reducescatter", nccl_reducescatter);
+  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
+  m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+}
(Diffs for the remaining 3 changed files are not shown.)
