Commit

Creating a LowerTriangularMask no longer creates a CUDA tensor (fai…
danthe3rd authored and xFormers Bot committed Dec 26, 2024
1 parent a2f37f8 commit 8fa35b5
Showing 2 changed files with 22 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## [0.0.28.post3] - TBD
+### Fixed:
+- Creating a `LowerTriangularMask` no longer creates a CUDA tensor
 ### Removed:
 - Following PyTorch, xFormers no longer builds binaries for conda. Pip is now the only recommended way to get xFormers
 
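The entry above is directly observable: after this commit, constructing the default mask should leave the CUDA context untouched. A minimal check along these lines (a sketch, not part of the commit; assumes a CUDA-enabled PyTorch build):

import torch
from xformers.ops.fmha.attn_bias import LowerTriangularMask

assert not torch.cuda.is_initialized()
mask = LowerTriangularMask()  # now backed by an empty CPU tensor
assert not torch.cuda.is_initialized()  # before the fix, this could fail
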
32 changes: 20 additions & 12 deletions xformers/ops/fmha/attn_bias.py
@@ -1589,18 +1589,8 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias):
     _subtensor: torch.Tensor
 
     @staticmethod
-    def __new__(cls, *, _subtensor=None):
-        if _subtensor is None:
-            _subtensor = torch.empty((0,), device=_get_default_bias_device())
-        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            [],
-            device=_subtensor.device,
-            dtype=_subtensor.dtype,
-            requires_grad=False,
-        )
-        tensor._subtensor = _subtensor
-        return tensor
+    def __new__(cls, *, _subtensor=None, device=None, **kwargs):
+        raise NotImplementedError()
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__()
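
Both the removed constructor and its replacement in LowerTriangularMask below rely on torch.Tensor._make_wrapper_subclass, a private PyTorch API that builds a shape-[] tensor "shell" carrying only metadata (device, dtype) mirrored from the backing _subtensor, without allocating dense storage. A standalone sketch of that pattern, using hypothetical names rather than xFormers code:

import torch

class WrapperSketch(torch.Tensor):
    # Hypothetical illustration of the wrapper-subclass pattern: the
    # instance is a metadata-only shell; the real data lives in an
    # ordinary tensor stored as an attribute.
    @staticmethod
    def __new__(cls, backing: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, [], device=backing.device, dtype=backing.dtype, requires_grad=False
        )
        t._backing = backing
        return t

    def __init__(self, backing: torch.Tensor) -> None:
        super().__init__()

w = WrapperSketch(torch.empty((0,), device="cpu"))
assert w.device.type == "cpu" and w.shape == ()  # no CUDA context created
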
@@ -1667,6 +1657,24 @@ class LowerTriangularMask(AttentionBiasSubTensor):
 
     HOLDS_DENSE_TENSOR = False
 
+    @staticmethod
+    def __new__(cls, *, _subtensor=None, device="cpu", **kwargs):
+        """
+        Note: create on CPU by default to avoid initializing CUDA context
+        by mistake.
+        """
+        if _subtensor is None:
+            _subtensor = torch.empty((0,), device=device)
+        tensor = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            [],
+            device=_subtensor.device,
+            dtype=_subtensor.dtype,
+            requires_grad=False,
+        )
+        tensor._subtensor = _subtensor
+        return tensor
+
     def materialize(
         self,
         shape: Tuple[int, ...],
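
With the override above, construction defaults to CPU, and callers who want the bias metadata on an accelerator can opt in through the new device keyword (usage sketch inferred from the signature above):

from xformers.ops.fmha.attn_bias import LowerTriangularMask

mask = LowerTriangularMask()                    # device="cpu": no CUDA init
cuda_mask = LowerTriangularMask(device="cuda")  # explicit opt-in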