Commit: Add k_activation.py
tridao committed Jun 3, 2024
1 parent 33aeaa0 commit f40e2ee
Showing 2 changed files with 154 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
__version__ = "2.0.1"
__version__ = "2.0.2"

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
from mamba_ssm.modules.mamba_simple import Mamba
153 changes: 153 additions & 0 deletions mamba_ssm/ops/triton/k_activation.py
@@ -0,0 +1,153 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)
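

# (Illustrative sketch, not part of this commit.) A minimal eager-mode
# reference for the elementwise math in _swiglu_fwd_kernel above, assuming x
# and y have already been split out of the concatenated input. The name
# _swiglu_fwd_ref is hypothetical; it is only useful as a correctness
# cross-check for the Triton kernel.
def _swiglu_fwd_ref(x, y):
    # silu(x) = x * sigmoid(x), so this matches out = x * sigmoid(x) * y above.
    return torch.nn.functional.silu(x) * y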


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])
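

# (Illustrative usage sketch, not part of this commit.) _swiglu_fwd expects the
# last dimension of `xy` to hold x and y concatenated (so it must be even) and
# returns a tensor of half that width. The helper name and the shapes/dtype
# below are arbitrary assumptions, chosen only to cross-check the kernel
# against the eager reference above on a CUDA device.
def _check_swiglu_fwd(batch=4, seqlen=1024, d=2048):
    xy = torch.randn(batch, seqlen, 2 * d, device="cuda", dtype=torch.float16)
    out = _swiglu_fwd(xy)  # last dim halves: (batch, seqlen, d)
    ref = _swiglu_fwd_ref(*xy.chunk(2, dim=-1))
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)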


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)
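

# (Illustrative sketch, not part of this commit.) Eager-mode version of the
# gradient math in _swiglu_bwd_kernel, for cross-checking. With s = sigmoid(x):
#   d(out)/dx = s * (1 + x * (1 - s)) * y,   d(out)/dy = x * s.
# The name _swiglu_bwd_ref is hypothetical.
def _swiglu_bwd_ref(x, y, dout):
    x_sigmoid = torch.sigmoid(x)
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    return dx, dy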


def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
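

# (Illustrative sketch, not part of this commit.) The diff above only adds the
# forward/backward helpers; it does not show how they are exposed to autograd.
# One plausible way to wire them up is a torch.autograd.Function along these
# lines. The class name _SwiGLUSketch is hypothetical.
class _SwiGLUSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, xy):
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        # _swiglu_bwd returns the gradient w.r.t. the concatenated input xy.
        return _swiglu_bwd(xy, dout)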
