diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index 7e0015329..4906e2842 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -340,7 +340,9 @@ def forward(
         values = xv.transpose(1, 2)
 
         # Flash attention.
-        attn_output = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, attention_mask)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            xq, keys, values, attention_mask
+        )
         attn_output = attn_output.transpose(1, 2).reshape(bs, batch_seq_len, -1)
 
         # Project.
diff --git a/sharktank/sharktank/models/llama/llama_ref.py b/sharktank/sharktank/models/llama/llama_ref.py
index 1b8d5afee..52da4d5fd 100644
--- a/sharktank/sharktank/models/llama/llama_ref.py
+++ b/sharktank/sharktank/models/llama/llama_ref.py
@@ -209,7 +209,9 @@ def forward(
         values = values.transpose(1, 2)
 
         # Flash attention.
-        attn_output = torch.nn.functional.scaled_dot_product_attention(xq, keys, values, attention_mask)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            xq, keys, values, attention_mask
+        )
         attn_output = attn_output.transpose(1, 2).reshape(bs, q_len, -1)
 
         # Project.
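
Note: both hunks are formatting-only; they reflow the same torch.nn.functional.scaled_dot_product_attention call across multiple lines without changing behavior. For reference, a minimal standalone sketch of that call follows (tensor shapes and names here are illustrative, not taken from the model code):

import torch

# Illustrative shapes only: batch, heads, sequence length, head dim.
bs, n_heads, seq_len, head_dim = 2, 8, 16, 64

xq = torch.randn(bs, n_heads, seq_len, head_dim)      # queries
keys = torch.randn(bs, n_heads, seq_len, head_dim)    # keys
values = torch.randn(bs, n_heads, seq_len, head_dim)  # values

# Additive mask broadcastable to (bs, n_heads, seq_len, seq_len):
# zeros keep positions, -inf blocks them.
attention_mask = torch.zeros(bs, 1, seq_len, seq_len)

# Same call the patch reflows; the mask is the attn_mask argument.
attn_output = torch.nn.functional.scaled_dot_product_attention(
    xq, keys, values, attention_mask
)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])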