diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index ee767cf3e..7acb0b03e 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -182,9 +182,13 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
                 # self.trace_tensor("attn_mask", attention_mask)
                 attn_weights = attn_weights + attention_mask
 
-            attn_weights = ops.softmax(ops.to(attn_weights, dtype=torch.float32), dim=-1)
+            attn_weights = ops.softmax(
+                ops.to(attn_weights, dtype=torch.float32), dim=-1
+            )
             attn_weights = ops.to(attn_weights, dtype=xq.dtype)
-            attn_output = ops.matmul(attn_weights, values)  # (bs, heads, slen, head_dim)
+            attn_output = ops.matmul(
+                attn_weights, values
+            )  # (bs, heads, slen, head_dim)
         else:
             is_causal = attention_mask is None and batch_seq_len == 1
             attn_output = torch.nn.functional.scaled_dot_product_attention(
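The hunk only rewraps the decomposed-attention lines, but for context, here is a minimal sketch of the two paths it touches: the explicit softmax-in-float32 matmul path and the fused scaled_dot_product_attention fallback. This uses plain PyTorch rather than sharktank's ops.* wrappers, and the function names, signatures, and shape assumptions are illustrative, not the actual module code.

import math
import torch

# Assumed shapes: xq, keys, values are (bs, heads, seq_len, head_dim).

def decomposed_attention(xq, keys, values, attention_mask=None):
    # Explicit softmax(QK^T / sqrt(d)) V; the softmax is computed in float32
    # for numerical stability, then cast back to the query dtype before the
    # final matmul, mirroring the diff above.
    attn_weights = torch.matmul(xq, keys.transpose(-2, -1)) / math.sqrt(xq.shape[-1])
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights.to(torch.float32), dim=-1)
    attn_weights = attn_weights.to(xq.dtype)
    return torch.matmul(attn_weights, values)  # (bs, heads, slen, head_dim)

def fused_attention(xq, keys, values, attention_mask=None, batch_seq_len=1):
    # Fused path: only treat the attention as causal when no explicit mask is
    # given and batch_seq_len == 1, matching the `is_causal` check in the hunk.
    is_causal = attention_mask is None and batch_seq_len == 1
    return torch.nn.functional.scaled_dot_product_attention(
        xq, keys, values, attn_mask=attention_mask, is_causal=is_causal
    )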