diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b3b005160da7e..fcadd8e92629c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -23,7 +23,8 @@ _PARTITION_SIZE_ROCM = 512 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName _ON_NAVI = "gfx1" in _GPU_ARCH -_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]) +_ON_MI250_MI300 = any(arch in _GPU_ARCH + for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"]) class ROCmFlashAttentionBackend(AttentionBackend): @@ -663,7 +664,8 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, block_size: int, gqa_ratio: int, max_seq_len: int) -> bool: # rocm custom page attention not support on navi (gfx1*) - return (_ON_MI250_MI300 and not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + return (_ON_MI250_MI300 and not _ON_NAVI + and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)