diff --git a/xformers/ops/fmha/_triton/splitk_kernels.py b/xformers/ops/fmha/_triton/splitk_kernels.py index cef4cb1740..6f27ebb4c0 100644 --- a/xformers/ops/fmha/_triton/splitk_kernels.py +++ b/xformers/ops/fmha/_triton/splitk_kernels.py @@ -341,7 +341,13 @@ def _fwd_kernel_splitK( # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 + # + # We declare log2e as a constant with a precisely-specified type to guarantee that + # triton will use the exact same value in all instances below, rather than sometimes + # using float32 and sometimes using float64. For more discussion see: + # https://github.com/triton-lang/triton/issues/5466 + log2e = tl.full((), 1.44269504, tl.float32) + qk_scale = sm_scale * log2e # load q: it will stay in SRAM throughout q: "VAR_ARGS_ARRAY" # noqa: F821 for i in range(len(acc)): # noqa: F821 @@ -468,7 +474,7 @@ def _fwd_kernel_splitK( additive_bias_block_ptr, boundary_check=(0, 1) if BOUNDS_CHECKS_N else (0,), ) - qk += loaded_bias.to(tl.float32) * 1.44269504 + qk += loaded_bias.to(tl.float32) * log2e additive_bias_block_ptr = tl.advance(additive_bias_block_ptr, (0, BLOCK_N)) # TODO: This is slow, and only needed at the last iteration. @@ -548,7 +554,7 @@ def _fwd_kernel_splitK( lse_dtype = LSE_splitk.dtype.element_ty tl.store( LSE_splitk_ptr, - (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / 1.44269504, + (tl.math.log2(l_i.to(lse_dtype)) + m_i.to(lse_dtype)) / log2e, mask=mask, )