❓ Questions and Help
More specifically, I'm interested in combining BlockDiagonalMask with a tensor bias.
I hacked something together by creating a BlockDiagonalMaskWithTensorBias, but got RuntimeError: CUDA error: an illegal memory access was encountered with a large per-block kv length.
What's interesting is that with a small per-block kv length (768, 512, or 8) the kernel runs successfully, without a CUDA error, but with 1024, or a very large value such as 13824, I get the illegal memory access error. The CUDA memory checker reports:
========= Invalid __global__ read of size 16 bytes
========= at 0x2405c0 in /proc/self/cwd/external/cutlass/include/cutlass/arch/memory_sm80.h:369:cp_async_zfill
========= by thread (18,0,0) in block (0,7,0)
========= Address 0x7fe58f446de0 is out of bounds
========= and is 32 bytes before the nearest allocation at 0x7fe58f446e00 of size 37888 bytes
========= Device Frame:/proc/self/cwd/xxx/xformers/csrc/attention/cuda/fmha/gemm/mma_from_smem.h:1000:_prologue [0x2405d0]
========= Device Frame:/proc/self/cwd/xxx/xformers/csrc/attention/cuda/fmha/iterators/predicated_tile_access_iterator_residual_last.h:1040:operator() [0x24d0e0]
========= Device Frame:/proc/self/cwd/xxx/xformers/csrc/attention/cuda/fmha/kernel_forward.h:1046:attention_kernel [0x2b9fb0]
With cuda-gdb, I got this line: xformers/xformers/csrc/attention/cuda/fmha/kernel_forward.h, line 1046 (commit f7e46d5).
I'm not familiar with CUTLASS, and the C++ templates also somehow confuse cuda-gdb, so I can't set a breakpoint inside the kernel.
Could someone familiar with the kernel help me understand the root cause?
Is it fundamentally impossible to combine variable-sequence-length inputs (q/k/v) with a tensor bias, or is there some limitation I need to pay attention to?
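For reference, this is roughly the documented way to use the plain BlockDiagonalMask with variable sequence lengths but without a tensor bias (the pattern my test below is adapted from); the shapes and seqlens here are only illustrative:

import torch
from xformers.ops import fmha

device, dtype = "cuda", torch.float16
B, H, K = 1, 8, 16
q_seqlen = [3, 6, 2]   # one entry per block
kv_seqlen = [8, 8, 8]

# q/k/v are packed along the sequence dimension: [1, sum(seqlen), heads, head_dim]
q = torch.randn([B, sum(q_seqlen), H, K], dtype=dtype, device=device)
k = torch.randn([B, sum(kv_seqlen), H, K], dtype=dtype, device=device)
v = torch.randn([B, sum(kv_seqlen), H, K], dtype=dtype, device=device)

attn_bias = fmha.BlockDiagonalMask.from_seqlens(q_seqlen=q_seqlen, kv_seqlen=kv_seqlen)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
print(out.shape)  # torch.Size([1, 11, 8, 16])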
Here is the test I used to play around with the customized bias:
import pytest
import torch

# `cuda_only` marker (equivalent to the one used in xformers' test suite):
# skip the test when no CUDA device is available.
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")


@cuda_only
def test_attn_bias_blockdiag_with_tensor_bias() -> None:
    """Adapted from the example in the doc for `BlockDiagonalMask`."""
    from xformers.ops import fmha

    dtype = torch.float16
    device = "cuda"
    q_seqlen = [3, 6, 2]
    kv_seqlen = [8, 8, 8]
    B, H, M, K = 1, 8, 11, 16
    # per_block_kv_len = 13824  # failed: RuntimeError: CUDA error: an illegal memory access was encountered
    per_block_kv_len = 1024  # failed: RuntimeError: CUDA error: an illegal memory access was encountered
    # per_block_kv_len = 768  # passed
    # per_block_kv_len = 512  # passed
    # per_block_kv_len = 8  # passed
    M_kv, K_kv = per_block_kv_len * 3, 16
    q = torch.randn([B, M, H, K], dtype=dtype, device=device)
    k = v = torch.randn([B, M_kv, H, K_kv], dtype=dtype, device=device)
    attn_mask = torch.rand([B, H, M, M_kv], dtype=dtype, device=device)
    attn_bias = fmha.BlockDiagonalMaskWithTensorBias.from_seqlens(
        tensor_bias=attn_mask, q_seqlen=q_seqlen, kv_seqlen=kv_seqlen
    )
    print(q.shape, k.shape, v.shape)
    out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    print(out.shape, out)
    # list_out = attn_bias.split(out)
    # assert tuple(list_out[0].shape) == (1, 3, 1, K)
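For completeness, these are the shape/seqlen relationships I assume the packed layout needs; I haven't confirmed them against the kernel, so treat this as a guess. Something like this could be dropped into the test right before calling memory_efficient_attention:

def check_packed_layout(q, k, attn_mask, q_seqlen, kv_seqlen):
    # Assumed invariants (not confirmed against the kernel): the seqlens partition
    # the packed q/kv dimensions, and the tensor bias covers the full packed extent.
    assert sum(q_seqlen) == q.shape[1], "q_seqlen should sum to the packed query length"
    assert sum(kv_seqlen) == k.shape[1], "kv_seqlen should sum to the packed key/value length"
    assert attn_mask.shape[-2:] == (q.shape[1], k.shape[1]), (
        "tensor bias should cover the packed q x kv extent"
    )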
The bias is my own BlockDiagonalMaskWithTensorBias class, and cutlass.py is updated to support this new bias.