
Commit 313a944

flash-attention 2 fwd pass: only scale output at end of loop (facebookresearch#1142)

* flash-attention 2 fwd pass: only scale output at end of loop

* lint

* restore the third_party change (i think)
russellhowes authored Jun 26, 2024
1 parent 165642c commit 313a944
Showing 2 changed files with 31 additions and 28 deletions.
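
For intuition, here is a minimal NumPy sketch of the change described above (not taken from the kernel; the function name, shapes, and block size are illustrative, and the causal masking / finitef guards are omitted). The FlashAttention-2-style loop keeps the output accumulator unnormalized, updating only the running max m_i and running sum l_i per KV block, and divides by l_i once after the loop instead of rescaling p and acc by 1/l on every iteration.

import numpy as np

def attention_deferred_scaling(q, k, v, block_n=16):
    """Single query block attended against KV blocks with online softmax.

    `acc` stays unnormalized inside the loop; only the running max `m_i`
    and running sum `l_i` are updated per block, and `acc` is divided by
    `l_i` once at the end (the change this commit makes to the kernel).
    """
    sm_scale = 1.0 / np.sqrt(q.shape[-1])
    m_i = np.full(q.shape[0], -np.inf)         # running row-wise max
    l_i = np.zeros(q.shape[0])                 # running row-wise sum of exp
    acc = np.zeros((q.shape[0], v.shape[-1]))  # unnormalized output

    for start in range(0, k.shape[0], block_n):
        k_blk = k[start:start + block_n]
        v_blk = v[start:start + block_n]
        qk = (q @ k_blk.T) * sm_scale
        m_ij = np.maximum(qk.max(axis=1), m_i)   # new running max
        p = np.exp(qk - m_ij[:, None])           # weights w.r.t. new max
        alpha = np.exp(m_i - m_ij)               # correction for the old max
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v_blk   # no 1/l rescale inside the loop
        m_i = m_ij

    return acc / l_i[:, None]                    # scale the output once, at the end


# Quick check against a reference softmax.
rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 32, 64))
scores = (q @ k.T) / np.sqrt(64)
ref = np.exp(scores - scores.max(axis=1, keepdims=True))
ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
assert np.allclose(attention_deferred_scaling(q, k, v), ref)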
@@ -239,6 +239,18 @@ def benchmark_backward(
     del attn_bias, out


+# Similar to CASES, but no causal parameter
+CSR_CASES = list(
+    product_dict(
+        shape=SHAPES,
+        num_threads=NUM_THREADS,
+        sparsity=SPARSITIES,
+        expanded_indices=[False],
+        block_size=BLOCK_SIZE,
+    )
+)
+
+
 def benchmark_csr_transpose(
     shape, num_threads: int, sparsity: float, expanded_indices: bool, block_size: int
 ):
@@ -272,4 +284,4 @@ def benchmark_csr_transpose(

 benchmark_main_helper(benchmark_forward, CASES, min_run_time=min_run_time)
 benchmark_main_helper(benchmark_backward, CASES, min_run_time=min_run_time)
-benchmark_main_helper(benchmark_csr_transpose, CASES, min_run_time=min_run_time)
+benchmark_main_helper(benchmark_csr_transpose, CSR_CASES, min_run_time=min_run_time)
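
For reference, CSR_CASES is built the same way as CASES, just without the causal flag (the CSR-transpose benchmark does not take one). The sketch below uses a stand-in product_dict and hypothetical parameter lists purely to show what one generated case looks like; the real helper and constants are defined elsewhere in the benchmark module.

import itertools

def product_dict(**kwargs):
    # Stand-in for the benchmark helper: yield one dict per combination
    # of the supplied value lists (assumed behavior, for illustration).
    keys = kwargs.keys()
    for values in itertools.product(*kwargs.values()):
        yield dict(zip(keys, values))

# Hypothetical parameter lists, just to show the shape of one case.
SHAPES = [(8, 1024, 1024)]
NUM_THREADS = [1, 40]
SPARSITIES = [0.9, 0.99]
BLOCK_SIZE = [128]

CSR_CASES = list(
    product_dict(
        shape=SHAPES,
        num_threads=NUM_THREADS,
        sparsity=SPARSITIES,
        expanded_indices=[False],
        block_size=BLOCK_SIZE,
    )
)
print(len(CSR_CASES))   # 1 * 2 * 2 * 1 * 1 = 4 cases
print(CSR_CASES[0])     # {'shape': (8, 1024, 1024), 'num_threads': 1, ...}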
xformers/ops/triton_fairinternal/block_sparse_mem_eff_kernels.py (45 changes: 18 additions & 27 deletions)
@@ -84,8 +84,8 @@ def _fwd_kernel(
     # Initialize pointers to Q, K, V
     q_ptrs = Q + off_q
     # initialize pointer to m and l
-    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
     if EVEN_M:
@@ -114,56 +114,44 @@
             k = tl.load(
                 K + off_k, mask=offs_n[None, :] < kv_len - column_index, other=0.0
             )
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
-        qk *= sm_scale
+        qk = tl.dot(q, k) * sm_scale
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where((column_index + offs_n)[None, :] < kv_len, qk, float("-inf"))
         if IS_CAUSAL:
             qk = tl.where(
                 offs_m[:, None] >= (column_index + offs_n[None, :]), qk, float("-inf")
             )
-        # compute new m
-        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
-        # correct old l
-        if IS_CAUSAL:
-            tt = tl.exp(m_prev - m_curr)
-            l_prev *= tl.where(finitef(tt), tt, 0)
-        else:
-            l_prev *= tl.exp(m_prev - m_curr)
+        m_ij = tl.maximum(tl.max(qk, 1), m_i)
         # attention weights
-        p = tl.exp(qk - m_curr[:, None])
+        p = tl.exp(qk - m_ij[:, None])
         if IS_CAUSAL:
             p = tl.where(finitef(p), p, 0)
-        l_curr = tl.sum(p, 1) + l_prev
-        # rescale operands of matmuls
+        l_ij = tl.sum(p, 1)
+        alpha = tl.exp(m_i - m_ij)
         if IS_CAUSAL:
-            l_rcp = tl.where(l_curr > 0, 1.0 / l_curr, l_curr)
-        else:
-            l_rcp = 1.0 / l_curr
-        p *= l_rcp[:, None]
-        acc *= (l_prev * l_rcp)[:, None]
-        # update acc
+            alpha = tl.where(finitef(alpha), alpha, 0)
+        l_i = l_i * alpha + l_ij
+        # update p, acc, m_i
         p = p.to(q.dtype)
+        acc *= alpha[:, None]
         if EVEN_N:
             v = tl.load(V + off_v)
         else:
             v = tl.load(
                 V + off_v, mask=offs_n[:, None] < kv_len - column_index, other=0.0
             )
-        acc += tl.dot(p, v)
-        # update m_i and l_i
-        l_prev = l_curr
-        m_prev = m_curr
+        acc = tl.dot(p, v, acc)
+        m_i = m_ij
         # update pointers
         CI += 1
     # rematerialize offsets to save registers
     start_m = tl.program_id(0)
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    # write back l and m
+    # write back lse
     q_len_rounded = ((q_len + BLOCK_M - 1) // BLOCK_M) * BLOCK_M
     lse_ptrs = LSE + off_hz * q_len_rounded + offs_m
-    tl.store(lse_ptrs, m_prev + tl.log(l_prev))
+    tl.store(lse_ptrs, m_i + tl.log(l_i))
     # initialize pointers to output
     offs_n = tl.arange(0, BLOCK_DMODEL)
     off_o = (
@@ -173,6 +161,9 @@ def _fwd_kernel(
         + offs_n[None, :] * stride_on
     )
     out_ptrs = Out + off_o
+    # scale acc before outputting
+    l_rcp = tl.where(l_i > 0, 1.0 / l_i, l_i)
+    acc = acc * l_rcp[:, None]
     if EVEN_M:
         tl.store(out_ptrs, acc)
     else:
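
Why scaling only at the end is equivalent: after the first t KV blocks, the loop maintains, per row of the query block (a sketch in the new variable names, with s_{j,n} the scaled score for key n of block j):

\[
m_i = \max_{j \le t,\, n} s_{j,n}, \qquad
l_i = \sum_{j \le t,\, n} e^{\,s_{j,n} - m_i}, \qquad
\mathrm{acc} = \sum_{j \le t,\, n} e^{\,s_{j,n} - m_i}\, v_{j,n},
\]

so once the loop finishes, the output is acc / l_i and the log-sum-exp is m_i + log l_i, which is exactly what the new epilogue (l_rcp = tl.where(l_i > 0, 1.0 / l_i, l_i); acc = acc * l_rcp[:, None]) and the lse store (m_i + tl.log(l_i)) compute. The old code folded the running 1/l into p and acc on every iteration (p *= l_rcp[:, None]; acc *= (l_prev * l_rcp)[:, None]), which yields the same result but costs an extra reciprocal and an extra rescale of acc per KV block.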
