Fix swiglu with newer torch compile (fairinternal/xformers#1272)
__original_commit__ = fairinternal/xformers@2e1a714
danthe3rd authored and xFormers Bot committed Dec 20, 2024
1 parent ca7bc31 commit 4b035ad
Showing 5 changed files with 22 additions and 41 deletions.
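
At a high level, the commit turns xformers::gemm_fused_operand_sum from an out-variant that writes into caller-allocated buffers into a functional op that allocates and returns its results, a form that plays better with newer torch.compile, per the commit title. A minimal sketch of the two calling conventions, where xformers_op stands in for the registered operator and the shapes follow the diffs below:

import torch

dy = torch.randn(32, 64)  # upstream gradient of a linear layer: [batch, out_features]
x = torch.randn(32, 16)   # layer input: [batch, in_features]

# Before: the caller pre-allocates dw/db and the op mutates them in place.
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
# xformers_op(dy.transpose(-2, -1), x, dw, db)

# After: the op allocates and returns fresh tensors.
# dw, db = xformers_op(dy.transpose(-2, -1), x)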
5 changes: 1 addition & 4 deletions tests/test_swiglu.py
@@ -403,14 +403,11 @@ def test_gemm_fused_operand_sum_compile(dtype, device) -> None:
         [shape[0], shape[2]], device=device, dtype=dtype, requires_grad=False
     )
     dy = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=False)
-    db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
-    dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)

     GemmFusedSumOp = xformers.ops.common.get_xformers_operator("gemm_fused_operand_sum")

     def fn(x):
-        GemmFusedSumOp(dy.transpose(-2, -1), x, dw, db)
-        return [dw, db]
+        return GemmFusedSumOp(dy.transpose(-2, -1), x)

     # Eager
     output = fn(x)
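
The updated test compiles the functional form directly. A rough sketch of the kind of check it performs, using a plain-PyTorch stand-in for the fused kernel (the real test calls the registered CUDA operator via get_xformers_operator):

import torch

def gemm_fused_operand_sum_ref(a, b):
    # Stand-in for the fused kernel: a matmul plus a row-sum of the first operand.
    return a @ b, a.sum(dim=-1)

def fn(x, dy):
    return gemm_fused_operand_sum_ref(dy.transpose(-2, -1), x)

x = torch.randn(64, 32)
dy = torch.randn(64, 16)

dw_eager, db_eager = fn(x, dy)
dw_comp, db_comp = torch.compile(fn)(x, dy)

torch.testing.assert_close(dw_eager, dw_comp)
torch.testing.assert_close(db_eager, db_comp)
torch.testing.assert_close(dw_eager, dy.transpose(-2, -1) @ x)  # dw = dy^T @ x
torch.testing.assert_close(db_eager, dy.sum(dim=0))             # db = dy summed over the batch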
33 changes: 11 additions & 22 deletions xformers/csrc/swiglu/cuda/gemm_fused_operand_sum.cu
@@ -41,10 +41,10 @@ void gemm_fused_operand_sum_(
   using ElementOutput = scalar_t;

   using LayoutInputA = cutlass::layout::ColumnMajor;
+  TORCH_CHECK(a.stride(0) == 1);
   using LayoutInputB = cutlass::layout::RowMajor;
+  TORCH_CHECK(b.stride(1) == 1);
   using LayoutOutput = cutlass::layout::RowMajor;
-  TORCH_CHECK(a.stride(0) == 1);
-  TORCH_CHECK(b.stride(1) == 1);
   TORCH_CHECK(out_mm.stride(1) == 1);

   // Layout of the output vector
@@ -205,50 +205,39 @@ void gemm_fused_operand_sum_(
 template <bool kIsMeta = false>
 std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum(
     const at::Tensor& a,
-    const at::Tensor& b,
-    at::Tensor& out_mm,
-    at::Tensor& out_sum) {
+    const at::Tensor& b) {
   // TODO: Check all params. This would take a lot of lines of code...
   TORCH_CHECK(a.dim() == 2);
   TORCH_CHECK(b.dim() == 2);
-  TORCH_CHECK(out_mm.dim() == 2);
-  TORCH_CHECK(out_mm.sym_size(0) == a.sym_size(0));
-  TORCH_CHECK(out_mm.sym_size(1) == b.sym_size(1));
-  TORCH_CHECK(out_sum.dim() == 1);
+  TORCH_CHECK(a.stride(0) == 1);
+  TORCH_CHECK(b.stride(1) == 1);

-#define FWD_PARAMS a, b, out_mm, out_sum
+  auto out_sum = at::empty({a.size(0)}, a.options());
+  auto out_mm = at::empty({a.size(0), b.size(1)}, a.options());

   if (!kIsMeta) {
     if (a.scalar_type() == at::ScalarType::Half) {
       TORCH_CHECK(b.scalar_type() == at::ScalarType::Half);
-      TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::Half);
-      TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::Half);
-      gemm_fused_operand_sum_<cutlass::half_t>(FWD_PARAMS);
+      gemm_fused_operand_sum_<cutlass::half_t>(a, b, out_mm, out_sum);
     } else {
       TORCH_CHECK(
           a.scalar_type() == at::ScalarType::BFloat16,
           "Only supports bf16/f16");
       TORCH_CHECK(b.scalar_type() == at::ScalarType::BFloat16);
-      TORCH_CHECK(out_mm.scalar_type() == at::ScalarType::BFloat16);
-      TORCH_CHECK(out_sum.scalar_type() == at::ScalarType::BFloat16);
-      gemm_fused_operand_sum_<cutlass::bfloat16_t>(FWD_PARAMS);
+      gemm_fused_operand_sum_<cutlass::bfloat16_t>(a, b, out_mm, out_sum);
     }
   }
   return std::make_tuple(out_mm, out_sum);
 }

 std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum_autocast(
     const at::Tensor& a,
-    const at::Tensor& b,
-    at::Tensor& out_mm,
-    at::Tensor& out_sum) {
+    const at::Tensor& b) {
   c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
   auto exec_type = at::autocast::get_autocast_dtype(at::kCUDA);
   return gemm_fused_operand_sum(
       at::autocast::cached_cast(exec_type, a),
-      at::autocast::cached_cast(exec_type, b),
-      out_mm,
-      out_sum);
+      at::autocast::cached_cast(exec_type, b));
 }
 } // namespace

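The wrapper now owns the output allocation instead of validating tensors passed in by the caller. A rough Python mirror of the new shape and stride logic (a readability sketch, not the actual binding; the CUTLASS kernel is what fills the outputs):

import torch

def gemm_fused_operand_sum_shapes(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == 2 and b.dim() == 2
    assert a.stride(0) == 1  # a is expected column-major, e.g. a transposed view
    assert b.stride(1) == 1  # b is expected row-major
    out_sum = torch.empty(a.size(0), dtype=a.dtype, device=a.device)
    out_mm = torch.empty(a.size(0), b.size(1), dtype=a.dtype, device=a.device)
    # ...the CUTLASS kernel then fills out_mm = a @ b and out_sum = a.sum(-1)...
    return out_mm, out_sum

dy, x = torch.randn(8, 4), torch.randn(8, 3)
dw, db = gemm_fused_operand_sum_shapes(dy.transpose(-2, -1), x)  # dw: [4, 3], db: [4]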
6 changes: 4 additions & 2 deletions xformers/csrc/swiglu/swiglu_op.cpp
@@ -36,8 +36,10 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) {
       "xformers::dual_gemm_silu_identity_mul(Tensor x, Tensor w1, Tensor? b1, Tensor w2, Tensor? b2) -> (Tensor, Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::silu_bw_fused(Tensor x1, Tensor x2, Tensor dx4) -> (Tensor, Tensor)"));
-  m.def(TORCH_SELECTIVE_SCHEMA(
-      "xformers::gemm_fused_operand_sum(Tensor a, Tensor b, Tensor out_mm, Tensor out_sum) -> (Tensor, Tensor)"));
+  m.def(
+      TORCH_SELECTIVE_SCHEMA(
+          "xformers::gemm_fused_operand_sum(Tensor a, Tensor b) -> (Tensor, Tensor)"),
+      {at::Tag::needs_fixed_stride_order});
 }

 TORCH_LIBRARY_IMPL(xformers, Meta, m) {
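The schema gains at::Tag::needs_fixed_stride_order, which asks the compile stack to pass the op its inputs with the same stride order as in eager mode. That matters here because the kernel only accepts a column-major first operand, which at the call sites is a transposed view rather than a contiguous tensor; a small stride illustration in plain PyTorch:

import torch

dy = torch.randn(128, 64)       # contiguous: strides (64, 1)
a = dy.transpose(-2, -1)        # view with shape (64, 128) and strides (1, 64)
assert a.stride(0) == 1         # the column-major layout the kernel checks for

a_copy = a.contiguous()         # a re-laid-out copy has strides (128, 1)...
assert a_copy.stride(0) == 128  # ...and would fail TORCH_CHECK(a.stride(0) == 1)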
15 changes: 5 additions & 10 deletions xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -38,14 +38,12 @@ std::tuple<at::Tensor, at::Tensor> silu_bw_fused(
 }
 std::tuple<at::Tensor, at::Tensor> gemm_fused_operand_sum(
     const at::Tensor& a,
-    const at::Tensor& b,
-    at::Tensor& out_mm,
-    at::Tensor& out_sum) {
+    const at::Tensor& b) {
   static auto op =
       c10::Dispatcher::singleton()
           .findSchemaOrThrow("xformers::gemm_fused_operand_sum", "")
           .typed<decltype(gemm_fused_operand_sum)>();
-  return op.call(a, b, out_mm, out_sum);
+  return op.call(a, b);
 }

 bool shapesMatch(at::Tensor x, std::vector<int64_t> expectedShape) {
@@ -147,10 +145,8 @@ class SwiGLUPackedWeights
     at::Tensor db3, dw3;

     if (has_b3) {
-      db3 = torch::empty({O}, w3.options());
-      dw3 = torch::empty({O, H}, w3.options());
       TORCH_INTERNAL_ASSERT(dx5.size(0) == x4.size(0));
-      gemm_fused_operand_sum(dx5.transpose(-2, -1), x4, dw3, db3);
+      std::tie(dw3, db3) = gemm_fused_operand_sum(dx5.transpose(-2, -1), x4);
     } else {
       dw3 = torch::mm(dx5.transpose(-2, -1), x4);
     }
@@ -167,9 +163,8 @@ class SwiGLUPackedWeights
     // backward of linear1 + linear2 - packed
     at::Tensor dw1dw2, db1db2;
     if (has_b1b2) {
-      dw1dw2 = torch::empty({2 * H, I}, w1w2.options());
-      db1db2 = torch::empty({2 * H}, w1w2.options());
-      gemm_fused_operand_sum(dx1dx2.transpose(-2, -1), x, dw1dw2, db1db2);
+      std::tie(dw1dw2, db1db2) =
+          gemm_fused_operand_sum(dx1dx2.transpose(-2, -1), x);
       db1db2 = db1db2.view({2, H});
     } else {
       dw1dw2 = torch::mm(dx1dx2.transpose(-2, -1), x);
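In the packed-weight backward, the fused op now returns the gradients for the concatenated linear1/linear2 weights and biases, and the bias gradient is reshaped from [2*H] to [2, H] to split it back per layer. A plain-PyTorch sketch of the shapes involved (sizes are illustrative):

import torch

B, I, H = 16, 32, 24                    # batch, input dim, hidden dim
x = torch.randn(B, I)
dx1dx2 = torch.randn(B, 2 * H)          # gradient of the packed linear1/linear2 outputs

dw1dw2 = dx1dx2.transpose(-2, -1) @ x   # [2*H, I], the op's out_mm
db1db2 = dx1dx2.sum(dim=0)              # [2*H],   the op's out_sum
db1db2 = db1db2.view(2, H)              # split back into b1 and b2 gradients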
4 changes: 1 addition & 3 deletions xformers/ops/swiglu_op.py
@@ -125,9 +125,7 @@ def _linear_bw(
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         if not bias:
             return (dy.transpose(-2, -1) @ x), None
-        db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
-        dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
-        GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
+        dw, db = GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x)
         return dw, db

     @classmethod
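_linear_bw computes the weight and bias gradients of a linear layer, and the fused operator produces both in a single kernel. Its unfused equivalent, for reference (the bias=True branch only; the real path dispatches to the CUDA kernel):

import torch

def linear_bw_ref(dy: torch.Tensor, x: torch.Tensor):
    dw = dy.transpose(-2, -1) @ x   # weight gradient: [out_features, in_features]
    db = dy.sum(dim=0)              # bias gradient:   [out_features]
    return dw, db

dy = torch.randn(64, 8)
x = torch.randn(64, 5)
dw, db = linear_bw_ref(dy, x)       # dw: [8, 5], db: [8]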
