Resolved alibi bias issue due to porting flat PA pr
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 committed Nov 4, 2024
1 parent 6643aa6 commit 9604c73
Showing 2 changed files with 29 additions and 6 deletions.
13 changes: 12 additions & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -118,6 +118,7 @@ def __init__(
             alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                dtype=torch.bfloat16)
             self.alibi_slopes = alibi_slopes_tensor
+        self.max_seq_len = max_seq_len
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -235,6 +236,14 @@ def forward(
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
+            self.position_bias = None
+            attn_bias = attn_metadata.attn_bias
+            if self.alibi_slopes is not None:
+                self.position_bias = _make_alibi_bias(self.alibi_slopes,
+                                                      self.num_kv_heads,
+                                                      attn_bias.dtype,
+                                                      self.max_seq_len if self.max_seq_len is not None else attn_bias.shape[-1])
+
             output = HPUPagedAttention.forward_decode(
                 query=query,
                 key_cache=key_cache,
@@ -245,10 +254,12 @@
                 block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
                 scale=self.scale,
+                alibi_slopes=self.position_bias,
                 matmul_qk_op=self.matmul_qk,
                 matmul_av_op=self.matmul_av,
                 keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                values_fetch_func=self.v_cache.fetch_from_cache,
+            )
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)

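For readers unfamiliar with ALiBi, the bias that _make_alibi_bias produces is a per-head penalty that grows linearly with the distance between query and key positions. The following is a minimal sketch of that idea only, not the helper actually defined in hpu_attn.py; the signature is modeled on the call site above, and details such as grouped-query head handling and batch expansion are omitted.

import torch

def make_alibi_bias_sketch(alibi_slopes: torch.Tensor,
                           num_kv_heads: int,
                           dtype: torch.dtype,
                           seq_len: int) -> torch.Tensor:
    # Hypothetical sketch only: ALiBi adds a bias proportional to the
    # distance between query position i and key position j.
    positions = torch.arange(seq_len)
    # distance[i, j] = j - i, so keys further in the past (j < i) receive a
    # more negative bias once scaled by each head's positive slope.
    distance = (positions[None, :] - positions[:, None]).to(dtype)
    # Result shape: [num_heads, seq_len, seq_len]. num_kv_heads would matter
    # for grouped-query attention layouts, which this sketch ignores.
    return alibi_slopes.to(dtype)[:, None, None] * distance[None, :, :]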
22 changes: 17 additions & 5 deletions vllm/worker/hpu_model_runner.py
@@ -1183,6 +1183,10 @@ def _prepare_decode(
                                     dtype=self.model_config.dtype,
                                     device=self.device)

+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=self.device)
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
@@ -1193,7 +1197,7 @@
             block_scales=block_scales,
             block_groups=block_groups,
             attn_bias=None,
-            seq_lens_tensor=None,
+            seq_lens_tensor=seq_lens_tensor,
             context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
@@ -1402,10 +1406,18 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
-            'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
-            'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
-            'block_groups'
+            'attn_bias',
+            'seq_lens_tensor',
+            'context_lens_tensor',
+            'block_list',
+            'block_mapping',
+            'block_usage',
+            'slot_mapping',
+            'is_prompt',
+            'block_indices',
+            'block_offsets',
+            'block_scales',
+            'block_groups',
         ])
         return attention_metadata

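The trim_attn_metadata hunk above only reflows the field list, but the pattern it feeds is worth noting: the helper copies just the listed attributes into a plain named tuple so the trimmed metadata acts as a small, order-sensitive structure (hence the input_hash(123) != input_hash(321) comment). A minimal sketch of that idea, using a hypothetical helper rather than the repository's actual subtuple implementation:

from collections import namedtuple
from typing import Sequence

def subtuple_sketch(obj: object, typename: str, fields: Sequence[str]):
    # Hypothetical sketch: copy only the listed attributes into a plain
    # namedtuple so downstream code (hashing, bucketing, graph capture)
    # sees a small, fixed, order-sensitive structure instead of the full
    # AttentionMetadata object.
    Trimmed = namedtuple(typename, fields)
    return Trimmed(**{name: getattr(obj, name) for name in fields})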
