diff --git a/CHANGELOG.md b/CHANGELOG.md
index d48ff31e5..f8b163f6d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [0.0.28.post3] - TBD
+### Fixed:
+- Creating a `LowerTriangularMask` no longer creates a CUDA tensor
 ### Removed:
 - Following PyTorch, xFormers no longer builds binaries for conda. Pip is now the only recommended way to get xFormers
 
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index ae1ff62fd..1466bb240 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -1589,18 +1589,8 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias):
     _subtensor: torch.Tensor
 
     @staticmethod
-    def __new__(cls, *, _subtensor=None):
-        if _subtensor is None:
-            _subtensor = torch.empty((0,), device=_get_default_bias_device())
-        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            [],
-            device=_subtensor.device,
-            dtype=_subtensor.dtype,
-            requires_grad=False,
-        )
-        tensor._subtensor = _subtensor
-        return tensor
+    def __new__(cls, *, _subtensor=None, device=None, **kwargs):
+        raise NotImplementedError()
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
@@ -1667,6 +1657,24 @@ class LowerTriangularMask(AttentionBiasSubTensor):
 
     HOLDS_DENSE_TENSOR = False
 
+    @staticmethod
+    def __new__(cls, *, _subtensor=None, device="cpu", **kwargs):
+        """
+        Note: create on CPU by default to avoid initializing CUDA context
+        by mistake.
+        """
+        if _subtensor is None:
+            _subtensor = torch.empty((0,), device=device)
+        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            [],
+            device=_subtensor.device,
+            dtype=_subtensor.dtype,
+            requires_grad=False,
+        )
+        tensor._subtensor = _subtensor
+        return tensor
+
     def materialize(
         self,
         shape: Tuple[int, ...],
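
Quick sanity check of the new behavior (a sketch for illustration, not part of the patch; it assumes an xFormers build with this change applied and the existing `materialize(shape, dtype, device)` keyword defaults):

import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMask

# The bias is now backed by a tiny CPU tensor by default, so merely
# constructing it should not initialize the CUDA context.
bias = LowerTriangularMask()
print(bias.device)                   # cpu
print(torch.cuda.is_initialized())   # False (assuming nothing else touched CUDA)

# Materializing a dense mask still targets whatever device/dtype is requested.
dense = bias.materialize((4, 4), dtype=torch.float32, device="cpu")
print(dense)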