[Pallas] Support FA sm_scale (#7035)
alanwaketan authored May 8, 2024
1 parent 825ba0d commit 1c31cde
Showing 2 changed files with 68 additions and 14 deletions.
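
The change threads a new sm_scale argument through the Pallas flash-attention wrapper so callers can scale the attention logits inside the kernel instead of pre-scaling the query. A minimal usage sketch, assuming a TPU device and the post-commit positional signature; the shapes and the 1/sqrt(head_dim) scale below are illustrative, not required:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
# (batch, num_heads, seq_len, head_dim)
q = torch.randn(4, 2, 128, 8).to(device)
k = torch.randn(4, 2, 128, 8).to(device)
v = torch.randn(4, 2, 128, 8).to(device)

# Positional order after this commit:
# flash_attention(q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ...)
sm_scale = 1.0 / (q.shape[-1] ** 0.5)
o = flash_attention(q, k, v, False, None, None, sm_scale)
xm.mark_step()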
58 changes: 58 additions & 0 deletions test/test_pallas.py
@@ -736,6 +736,64 @@ def test_flash_attention_backward_segment_ids(self):
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_sm_scale(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")
sm_scale = 0.7
o = flash_attention(q, k, v, False, None, None, sm_scale)

expected_o = self._attention(q * sm_scale, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_sm_scale_backward(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
sm_scale = 0.7
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = flash_attention(q, k, v, False, None, None, sm_scale)
loss = o.sum()
loss.backward()
xm.mark_step()

q_grad = q.grad
k_grad = k.grad
v_grad = v.grad

torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()

o = self._attention(q * sm_scale, k, v)
loss = o.sum()
loss.backward()
xm.mark_step()

# The gradients match even though the autograd graphs differ; scaling q by sm_scale up front is equivalent to scaling the logits inside the kernel.
for i in [(q, q_grad), (k, k_grad), (v, v_grad)]:
self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
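Both new tests compare the kernel output against a reference that pre-scales q, relying on the identity softmax((s * q) @ k.T) == softmax(s * (q @ k.T)). A small CPU-only sketch of that identity; ref_attention below is a hypothetical stand-in for the test class's _attention helper:

import torch
import torch.nn.functional as F

def ref_attention(q, k, v):
  # Plain softmax attention with no internal scaling.
  return F.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
s = 0.7

# Scaling the logits equals scaling q beforehand, which is why the tests check
# flash_attention(..., sm_scale=s) against _attention(q * s, k, v).
out_scaled_logits = F.softmax(s * (q @ k.transpose(-2, -1)), dim=-1) @ v
out_scaled_q = ref_attention(q * s, k, v)
assert torch.allclose(out_scaled_logits, out_scaled_q, atol=1e-6)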
24 changes: 10 additions & 14 deletions torch_xla/experimental/custom_kernel.py
@@ -198,22 +198,16 @@ def prepare_segment_ids(q_segment_ids, kv_segment_ids):
return segment_ids, q_segment_ids, kv_segment_ids

@staticmethod
-  def forward(ctx,
-              q,
-              k,
-              v,
-              causal=False,
-              q_segment_ids=None,
-              kv_segment_ids=None,
-              partition_spec=None,
-              mesh=None):
+  def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale,
+              partition_spec, mesh):
    # Import JAX within the function so that we don't need to call jax_import_guard()
    # in the global scope, which could cause problems for xmp.spawn.
jax_import_guard()
import jax
from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_impl

ctx.causal = causal
+    ctx.sm_scale = sm_scale
ctx.partition_spec = partition_spec
ctx.mesh = mesh
ctx.full_shape = None
@@ -258,7 +252,7 @@ def forward(ctx,
segment_ids,
save_residuals,
causal,
-        1.0,
+        sm_scale,
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
@@ -300,6 +294,7 @@ def backward(ctx, grad_output):

q, k, v, o, l, m, q_segment_ids, kv_segment_ids = ctx.saved_tensors
causal = ctx.causal
+    sm_scale = ctx.sm_scale
partition_spec = ctx.partition_spec
mesh = ctx.mesh
full_shape = ctx.full_shape
@@ -350,7 +345,7 @@ def backward(ctx, grad_output):
k.shape[2]),
block_k=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"],
k.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
causal=causal,
mask_value=FlashAttention.DEFAULT_MASK_VALUE,
debug=False,
@@ -388,7 +383,7 @@ def backward(ctx, grad_output):
k.shape[2]),
block_q=min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"],
q.shape[2]),
-          sm_scale=1.0,
+          sm_scale=sm_scale,
causal=causal,
mask_value=FlashAttention.DEFAULT_MASK_VALUE,
debug=False,
@@ -418,7 +413,7 @@ def backward(ctx, grad_output):
grad_v = xs.disable_manual_sharding(
grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

-    return grad_q, grad_k, grad_v, None, None, None, None, None
+    return grad_q, grad_k, grad_v, None, None, None, None, None, None


def flash_attention(
@@ -428,12 +423,13 @@ def flash_attention(
causal=False,
q_segment_ids=None,
kv_segment_ids=None,
+    sm_scale=1.0,
*,
partition_spec=None,
mesh=None):
# TODO: support SPMD and Dynamo with segment_ids.
return FlashAttention.apply(q, k, v, causal, q_segment_ids, kv_segment_ids,
-                              partition_spec, mesh)
+                              sm_scale, partition_spec, mesh)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
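For context on why backward now returns nine values: torch.autograd.Function requires backward to return one entry per forward input, with None for non-differentiable inputs, so the new sm_scale argument adds one more None slot. A minimal, hypothetical example of the same pattern (ScaledMatmul is illustrative and not part of this commit):

import torch

class ScaledMatmul(torch.autograd.Function):

  @staticmethod
  def forward(ctx, a, b, scale):
    ctx.save_for_backward(a, b)
    ctx.scale = scale
    return scale * (a @ b)

  @staticmethod
  def backward(ctx, grad_out):
    a, b = ctx.saved_tensors
    grad_a = ctx.scale * grad_out @ b.t()
    grad_b = ctx.scale * a.t() @ grad_out
    # One return value per forward input; the non-tensor scale gets None.
    return grad_a, grad_b, None

# Usage: ScaledMatmul.apply(a, b, 0.5).sum().backward()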
