Commit 574a1b5
Update FusedSDPA changes from habana_main (#234)
Update changes to enable FusedSDPA from habana_main (#168)
shepark authored Sep 3, 2024
1 parent 8a4ad89 commit 574a1b5
Showing 3 changed files with 77 additions and 24 deletions.
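
The changes below gate the FusedSDPA prompt-attention path behind the VLLM_PROMPT_USE_FUSEDSDPA environment variable, which is read in both habana_attn.py and habana_model_runner.py. A minimal sketch (not part of the commit) of how the flag would be enabled, based only on the parsing shown in the diff ('1' or 'true', case-insensitive; default '0'):

import os

# Enable the FusedSDPA prompt path before the attention backend is constructed;
# '1' or 'true' (case-insensitive) enables it, anything else keeps the native
# matmul/softmax implementation (default '0').
os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '1'

prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                  '0').lower() in ['1', 'true']
assert prefill_use_fusedsdpa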
27 changes: 19 additions & 8 deletions vllm/attention/backends/habana_attn.py
@@ -2,6 +2,7 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

@@ -122,6 +123,12 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']
if self.prefill_usefusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

        supported_head_sizes = HabanaPagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
raise ValueError(
@@ -174,14 +181,17 @@ def forward(

if attn_metadata.is_prompt:
# Prompt run.
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
if not self.prefill_usefusedsdpa:
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward!'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None and \
self.position_bias is not None:
attn_bias.add_(self.position_bias[:, :,
-attn_bias.size(2):,
-attn_bias.size(3):])
else:
attn_bias = None

query_shape = (batch_size, seq_len, self.num_heads,
self.head_size)
@@ -197,6 +207,7 @@ def forward(
qk_matmul_op=self.qk_matmul,
softmax_op=self.softmax,
av_matmul_op=self.av_matmul,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
60 changes: 48 additions & 12 deletions vllm/hpu/ops.py
@@ -13,6 +14

import vllm.hpu.utils as hpu_utils

HPUFusedSDPA = None
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
HPUFusedSDPA = FusedSDPA
except ImportError:
logger.warning("Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation.")

PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')


@@ -133,6 +141,21 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
return final_hidden_states.view(-1, D)


#TODO: remove after fusedsdpa fix for query_head != kv_head
def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The kv tensor goes from (batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = kv.shape
if n_rep == 1:
return kv
kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
head_dim)
return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
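
As a quick, illustrative check (not part of the commit) that the helper above matches the torch.repeat_interleave call named in its docstring; the import path assumes the patched tree with repeat_kv added to vllm/hpu/ops.py:

import torch

from vllm.hpu.ops import repeat_kv  # the helper added above

# (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_attention_heads, seqlen, head_dim)
kv = torch.randn(2, 4, 16, 64)
expanded = repeat_kv(kv, n_rep=3)  # shape (2, 12, 16, 64)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))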


def prompt_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -143,23 +166,36 @@ def prompt_attention(
qk_matmul_op = torch.matmul,
softmax_op = torch.softmax,
av_matmul_op = torch.matmul,
valid_seq_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
query_heads = query.size(1)
kv_heads = key.size(1)
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
attn_bias = attn_bias.unsqueeze(2)
attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = av_matmul_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
if attn_bias is not None or HPUFusedSDPA is None:
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = av_matmul_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
#TODO: remove after fusedsdpa fix for query_heads != kv_heads
if query_heads != kv_heads:
key = repeat_kv(key, int(query_heads // kv_heads))
value = repeat_kv(value, int(query_heads // kv_heads))
softmax_mode = 'fast'
recompute_mode = True
attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
scale, softmax_mode, recompute_mode,
valid_seq_lengths, 'right')
attn_weights = attn_weights.transpose(1, 2)
return attn_weights
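
To summarize the dispatch above (an illustrative restatement, not additional diff content): the fused kernel runs only when no bias tensor was supplied and the FusedSDPA import succeeded; any bias, or a failed import, falls back to the existing qk_matmul/softmax/av_matmul path.

from typing import Optional

import torch

def uses_fused_path(attn_bias: Optional[torch.Tensor], hpu_fused_sdpa) -> bool:
    # The fused branch is taken only when no bias was built (the model runner
    # skips _set_attn_bias when VLLM_PROMPT_USE_FUSEDSDPA is on) and the import
    # of habana_frameworks' FusedSDPA succeeded.
    return attn_bias is None and hpu_fused_sdpa is not None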
14 changes: 10 additions & 4 deletions vllm/worker/habana_model_runner.py
@@ -161,14 +161,21 @@ class HpuModelAdapter():

def __init__(self, model, block_size, enforce_eager):
self.model = model
self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'0').lower() in ['1', 'true']

self.block_size = block_size
if not htorch.utils.internal.is_lazy() and not enforce_eager:
self.model = torch.compile(self.model,
backend='hpu_backend',
dynamic=False)

def _set_attn_bias(self, metadata, batch_size, seq_len, device, dtype):
seq_lens_t = metadata.seq_lens_tensor
prefill_metadata = metadata
if prefill_metadata is None or self.prefill_use_fusedsdpa:
return metadata

seq_lens_t = prefill_metadata.seq_lens_tensor
len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32)
.view(1, seq_len)
.ge(seq_lens_t.unsqueeze(-1))
@@ -180,7 +187,8 @@ def _set_attn_bias(self, metadata, batch_size, seq_len, device, dtype):
mask = causal_mask.logical_or(len_mask)
attn_bias = (torch.zeros_like(mask, dtype=dtype)
.masked_fill_(mask, -math.inf))
return metadata._replace(attn_bias=attn_bias)
metadata = prefill_metadata._replace(attn_bias=attn_bias)
return metadata
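
For a concrete picture of what _set_attn_bias produces when the fused path is off, here is a small worked sketch (not part of the commit; the batch dimension is dropped and the causal mask is assumed to be the usual upper-triangular one, since that line sits in the collapsed context): with a bucketed seq_len of 4 and a real sequence length of 2, padded columns and future positions both get -inf.

import math

import torch

seq_len, real_len = 4, 2
# Padding mask: columns at or beyond the real sequence length.
len_mask = (torch.arange(0, seq_len, dtype=torch.int32)
            .view(1, seq_len)
            .ge(torch.tensor([real_len]).unsqueeze(-1)))
# Assumed causal mask: positions above the diagonal.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
mask = causal_mask.logical_or(len_mask)
attn_bias = (torch.zeros_like(mask, dtype=torch.float32)
             .masked_fill_(mask, -math.inf))
# attn_bias:
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0., -inf, -inf]]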

def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = torch.arange(0, self.block_size, device=device, dtype=torch.int32).unsqueeze(0)
@@ -611,7 +619,6 @@ def _prepare_prompt(
# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
@@ -679,7 +686,6 @@ def _prepare_prompt(
max_prompt_len = max(
find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
self.block_size)

input_tokens = make_tensor_with_pad(input_tokens,
max_len=max_prompt_len,
pad=0,