From f40e2ee10eeb8d86e91243f55e4b21c8a3db4652 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Jun 2024 04:41:48 -0700 Subject: [PATCH] Add k_activation.py --- mamba_ssm/__init__.py | 2 +- mamba_ssm/ops/triton/k_activation.py | 153 +++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 mamba_ssm/ops/triton/k_activation.py diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index 0204ce2f..7972f0b4 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.1" +__version__ = "2.0.2" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/ops/triton/k_activation.py b/mamba_ssm/ops/triton/k_activation.py new file mode 100644 index 00000000..9ded3b66 --- /dev/null +++ b/mamba_ssm/ops/triton/k_activation.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 32}), + triton.Config({'BLOCK_N': 64}), + triton.Config({'BLOCK_N': 128}), + triton.Config({'BLOCK_N': 256}), + triton.Config({'BLOCK_N': 512}), + triton.Config({'BLOCK_N': 1024}), + ], + key=['ncols'], +) +@triton.jit +def _swiglu_fwd_kernel( + X, + Y, + OUT, + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_out_row, + ncols, + BLOCK_N: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + start_col = tl.program_id(1) * BLOCK_N + X += row * stride_x_row + Y += row * stride_y_row + OUT += row * stride_out_row + cols = start_col + tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) + out = x * tl.sigmoid(x) * y + tl.store(OUT + cols, out, mask=cols < ncols) + + +def _swiglu_fwd(xy, out=None): + if xy.stride(-1) != 1: + xy = xy.contiguous() + batch_shape = xy.shape[:-1] + xy = xy.reshape(-1, xy.shape[-1]) + x, y = xy.chunk(2, dim=-1) + if out is None: + out = torch.empty_like(x) + else: + out = out.reshape(-1, out.shape[-1]) + assert out.shape == x.shape + assert out.stride(-1) == 1 + M, N = x.shape + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) + with torch.cuda.device(x.device.index): + _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N) + return out.reshape(*batch_shape, out.shape[-1]) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_N': 32}), + triton.Config({'BLOCK_N': 64}), + triton.Config({'BLOCK_N': 128}), + triton.Config({'BLOCK_N': 256}), + triton.Config({'BLOCK_N': 512}), + triton.Config({'BLOCK_N': 1024}), + ], + key=['ncols'], +) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None}) +@triton.jit +def _swiglu_bwd_kernel( + X, + Y, + DOUT, + OUT, + DX, + DY, + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dout_row, + stride_out_row, + stride_dx_row, + stride_dy_row, + ncols, + BLOCK_N: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + start_col = tl.program_id(1) * BLOCK_N + X += row * stride_x_row + Y += row * stride_y_row + DOUT += row * stride_dout_row + if RECOMPUTE_OUTPUT: + OUT += row * stride_out_row + DX += row * stride_dx_row + DY += row * stride_dy_row + cols = start_col + tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32) + y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32) + dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32) + x_sigmoid = tl.sigmoid(x) + dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout + dy = x * x_sigmoid * dout + tl.store(DX + cols, dx, mask=cols < ncols) + tl.store(DY + cols, dy, mask=cols < ncols) + if RECOMPUTE_OUTPUT: + out = x * x_sigmoid * y + tl.store(OUT + cols, out, mask=cols < ncols) + + +def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None): + if xy.stride(-1) != 1: + xy = xy.contiguous() + if dout.stride(-1) != 1: + dout = dout.contiguous() + batch_shape = xy.shape[:-1] + xy = xy.reshape(-1, xy.shape[-1]) + x, y = xy.chunk(2, dim=-1) + dout = dout.reshape(-1, dout.shape[-1]) + assert dout.shape == x.shape + if dxy is None: + dxy = torch.empty_like(xy) + else: + dxy = dxy.reshape(-1, dxy.shape[-1]) + assert dxy.shape == xy.shape + dx, dy = dxy.chunk(2, dim=-1) + assert dx.stride(-1) == 1 + assert dy.stride(-1) == 1 + if recompute_output: + if out is None: + out = torch.empty_like(x) + else: + out = out.reshape(-1, out.shape[-1]) + assert out.shape == x.shape + assert out.stride(-1) == 1 + M, N = x.shape + grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N'])) + with torch.cuda.device(x.device.index): + _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy, + x.stride(0), y.stride(0), dout.stride(0), + out.stride(0) if recompute_output else 0, + dx.stride(0), dy.stride(0), + N) + if not recompute_output: + return dxy.reshape(*batch_shape, dxy.shape[-1]) + else: + return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])