Commit
Enable FA3 for the BW pass (fairinternal/xformers#1268)
__original_commit__ = fairinternal/xformers@9731771
danthe3rd authored and xFormers Bot committed Dec 16, 2024
1 parent 839c4ec commit 3e62181
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions xformers/ops/fmha/dispatch.py
@@ -139,10 +139,6 @@ def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
     )
 
 
-def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
-    return False
-
-
 def _dispatch_bw(
     inp: Inputs, varlen_lse_packed: Optional[bool]
 ) -> Type[AttentionBwOpBase]:
@@ -151,6 +147,8 @@ def _dispatch_bw(
             flash.BwOp,
             cutlass.BwOp,
         ]
+        if _get_use_fa3():
+            priority_list_ops = [flash3.BwOp] + priority_list_ops
     else:
         priority_list_ops = [
             ck.BwOp,
@@ -178,9 +176,6 @@ def _dispatch_bw(
         priority_list_ops = [
             op for op in priority_list_ops if op.VARLEN_LSE_PACKED == varlen_lse_packed
         ]
-    if torch.version.cuda and _is_cutlassB_faster_than_flash(inp):
-        priority_list_ops.remove(cutlass.BwOp)
-        priority_list_ops.insert(0, cutlass.BwOp)
     return _run_priority_list(
         "memory_efficient_attention_backward", priority_list_ops, inp
     )
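
For context on what the change does: _dispatch_bw builds an ordered priority list of candidate backward operators, and _run_priority_list returns the first one that supports the given inputs. With this commit, flash3.BwOp is prepended to that list whenever _get_use_fa3() is enabled, so FA3 handles the backward pass when it can; the unused _is_cutlassB_faster_than_flash stub and the cutlassB re-prioritization it gated are removed. The sketch below is a minimal illustration of that priority-list dispatch pattern only; the operator classes, the supports() hook, and use_fa3() are simplified stand-ins for this example, not the real xformers API.

# Minimal sketch of priority-list dispatch, mirroring the pattern used by
# _dispatch_bw/_run_priority_list in xformers/ops/fmha/dispatch.py.
# All names below are illustrative stand-ins, not the actual xformers classes.
from typing import List, Type


class BwOpBase:
    NAME = "base"

    @classmethod
    def supports(cls, inp) -> bool:
        # The real ops inspect dtype, shapes, device capability, etc.
        return True


class Flash3BwOp(BwOpBase):
    NAME = "flash3"


class FlashBwOp(BwOpBase):
    NAME = "flash"


class CutlassBwOp(BwOpBase):
    NAME = "cutlass"


def use_fa3() -> bool:
    # Stand-in for _get_use_fa3(); in xformers this is a module-level toggle.
    return True


def dispatch_bw(inp) -> Type[BwOpBase]:
    # Default CUDA priority order, as in the diff: flash first, then cutlass.
    priority: List[Type[BwOpBase]] = [FlashBwOp, CutlassBwOp]
    if use_fa3():
        # The commit prepends FA3 so it wins whenever it supports the inputs.
        priority = [Flash3BwOp] + priority
    for op in priority:
        if op.supports(inp):
            return op
    raise NotImplementedError("no backward operator supports these inputs")


if __name__ == "__main__":
    print(dispatch_bw(inp=None).NAME)  # prints "flash3" when FA3 is enabled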
