Use a constant with clearly-defined type for log2e in fwd_kernel_splitK (#1181)

Summary:
Triton 3.2 changed how it interprets constants
(triton-lang/triton#4613). The change makes Triton more
consistent with pytorch/numpy, but causes some surprising issues with this
kernel. Specifically, it seems that log2e is interpreted as float32 in one
instance and float64 in another, which leads to reduced prediction accuracy in
some cases.

To prevent this, let's make log2e a constant and define it as float32.
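
The effect of mixing the two types can be sketched outside Triton. The snippet below is a plain-Python illustration, not the kernel code: `to_f32` is a hypothetical helper that rounds a value to float32, and `x` is an arbitrary example input. It assumes the mismatch amounts to multiplying by the float32 rounding of 1.44269504 in one place and dividing by the float64 value in another.

```python
import struct


def to_f32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


LOG2E_LITERAL = 1.44269504

log2e_f64 = LOG2E_LITERAL          # literal kept at float64 precision
log2e_f32 = to_f32(LOG2E_LITERAL)  # same literal rounded to float32

x = 10.0  # arbitrary example value

# Scale with one interpretation, unscale with the other: the round trip drifts.
mixed = (x * log2e_f32) / log2e_f64

# Scale and unscale with the same float32 constant: the round trip is exact here.
consistent = (x * log2e_f32) / log2e_f32

print("float32 vs float64 constant gap:", abs(log2e_f64 - log2e_f32))
print("round-trip error, mixed types:  ", abs(mixed - x))
print("round-trip error, single type:  ", abs(consistent - x))
```

The literal is not exactly representable in float32, so the two interpretations are different numbers. Pinning the constant to one type, as this commit does with `tl.full((), 1.44269504, tl.float32)`, keeps the scaling and unscaling steps in agreement.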
bertmaher authored Dec 23, 2024
1 parent 9a59df2 commit a2f37f8
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions xformers/ops/fmha/_triton/splitk_kernels.py
@@ -341,7 +341,13 @@ def _fwd_kernel_splitK(
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
- qk_scale = sm_scale * 1.44269504
+ #
+ # We declare log2e as a constant with a precisely-specified type to guarantee that
+ # triton will use the exact same value in all instances below, rather than sometimes
+ # using float32 and sometimes using float64. For more discussion see:
+ # https://github.com/triton-lang/triton/issues/5466
+ log2e = tl.full((), 1.44269504, tl.float32)
+ qk_scale = sm_scale * log2e
# load q: it will stay in SRAM throughout
q: "VAR_ARGS_ARRAY" # noqa: F821
for i in range(len(acc)): # noqa: F821
@@ -468,7 +474,7 @@ def _fwd_kernel_splitK(
additive_bias_block_ptr,
boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,),
)
- qk += loaded_bias.to(tl.float32) * 1.44269504
+ qk += loaded_bias.to(tl.float32) * log2e
additive_bias_block_ptr = tl.advance(additive_bias_block_ptr, (0, BLOCK_N))

# TODO: This is slow, and only needed at the last iteration.
@@ -548,7 +554,7 @@ def _fwd_kernel_splitK(
lse_dtype = LSE_splitk.dtype.element_ty
tl.store(
LSE_splitk_ptr,
- (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / 1.44269504,
+ (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / log2e,
mask=mask,
)
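
For reference, the three touched lines form one round trip: scores are scaled by log2(e) so the loop can use 2^x instead of exp, and the stored log-sum-exp is divided by the same constant to return to natural-log units. A minimal plain-Python sketch of that identity follows. It is not the kernel itself; the function names and the sample inputs are hypothetical, and it runs in float64 throughout.

```python
import math

LOG2E = math.log2(math.e)  # the exact counterpart of the 1.44269504 literal


def logsumexp_base2(scores, sm_scale):
    """log(sum(exp(sm_scale * s))) computed with 2**x and log2 only."""
    qk_scale = sm_scale * LOG2E               # fold log2(e) into the softmax scale
    scaled = [s * qk_scale for s in scores]   # scores in base-2 units
    m = max(scaled)                           # max, used for numerical stability
    l = sum(2.0 ** (v - m) for v in scaled)   # 2**x in place of exp
    return (math.log2(l) + m) / LOG2E         # convert back to natural-log units


def logsumexp_reference(scores, sm_scale):
    """Direct natural-log formulation, for comparison."""
    return math.log(sum(math.exp(sm_scale * s) for s in scores))


scores = [0.5, -1.0, 2.0]
sm_scale = 0.125
base2_result = logsumexp_base2(scores, sm_scale)
reference_result = logsumexp_reference(scores, sm_scale)
print(base2_result, reference_result)
```

The two results agree only because the multiply and the divide use the same value of log2(e), which is the property the typed constant is meant to guarantee in the kernel.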
