From 8fa35b59a9a31f1ed3a2bf267869cefe0425acfd Mon Sep 17 00:00:00 2001
From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>
Date: Thu, 26 Dec 2024 09:23:23 +0000
Subject: [PATCH] Creating a `LowerTriangularMask` no longer creates a CUDA
 tensor (fairinternal/xformers#1274)

__original_commit__ = fairinternal/xformers@4a6a2a1c7a9f865eeaa6d053e2a1c9d05c29cad4
---
 CHANGELOG.md                   |  2 ++
 xformers/ops/fmha/attn_bias.py | 32 ++++++++++++++++++++------------
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d48ff31e5f..f8b163f6d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [0.0.28.post3] - TBD
+### Fixed:
+- Creating a `LowerTriangularMask` no longer creates a CUDA tensor
 ### Removed:
 - Following PyTorch, xFormers no longer builds binaries for conda. Pip is now the only recommended way to get xFormers
 
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index ae1ff62fd4..1466bb240f 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -1589,18 +1589,8 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias):
     _subtensor: torch.Tensor
 
     @staticmethod
-    def __new__(cls, *, _subtensor=None):
-        if _subtensor is None:
-            _subtensor = torch.empty((0,), device=_get_default_bias_device())
-        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            [],
-            device=_subtensor.device,
-            dtype=_subtensor.dtype,
-            requires_grad=False,
-        )
-        tensor._subtensor = _subtensor
-        return tensor
+    def __new__(cls, *, _subtensor=None, device=None, **kwargs):
+        raise NotImplementedError()
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
@@ -1667,6 +1657,24 @@ class LowerTriangularMask(AttentionBiasSubTensor):
 
     HOLDS_DENSE_TENSOR = False
 
+    @staticmethod
+    def __new__(cls, *, _subtensor=None, device="cpu", **kwargs):
+        """
+        Note: create on CPU by default to avoid initializing CUDA context
+        by mistake.
+        """
+        if _subtensor is None:
+            _subtensor = torch.empty((0,), device=device)
+        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            [],
+            device=_subtensor.device,
+            dtype=_subtensor.dtype,
+            requires_grad=False,
+        )
+        tensor._subtensor = _subtensor
+        return tensor
+
     def materialize(
         self,
         shape: Tuple[int, ...],
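
A minimal usage sketch (not part of the patch) illustrating the behavior this change is meant to guarantee: constructing a `LowerTriangularMask` with no arguments now places its backing tensor on CPU instead of the default bias device, so merely building the mask no longer spins up a CUDA context. The sketch assumes a fresh Python process in which nothing has touched CUDA yet; `torch.cuda.is_initialized` and `torch.cuda.is_available` are standard PyTorch calls.

# sketch.py -- assumes a fresh process where CUDA has not been used yet
import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMask

mask = LowerTriangularMask()        # `device` now defaults to "cpu"
assert mask.device.type == "cpu"    # wrapper subclass inherits the CPU device

# The point of the patch: no CUDA context was created as a side effect.
if torch.cuda.is_available():
    assert not torch.cuda.is_initialized()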