Incorrect attention output with SparseCS mask #1124
After more troubleshooting, it seems that the conversion to a sparse matrix is incorrect:

```python
>>> mask = torch.rand((M, N)).cuda() < 0.01
>>> sparse = xf.SparseCS(mask, device=mask.device)
>>> (sparse.to_dense() != mask).nonzero()
tensor([[   0, 1023,  418],
        [   0, 1023,  573],
        [   0, 1023,  583]], device='cuda:0')
```

The issue comes from `_round_nnz` in xformers/xformers/sparse/utils.py (lines 116 to 117 at 4a9dd7e), which forces the number of nonzero entries to a multiple of 4 by dropping the trailing ones; that is why exactly three entries go missing from the last row (1023) here.
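A quick sanity check on this reading (a sketch, not from the original report; it reuses the `mask` from the snippet above, so the printed values hold only for that particular random draw):

```python
nnz = int(torch.count_nonzero(mask))
print(nnz % 4)  # 3: matches the number of mismatching entries above

# The dropped entries are the last nonzero positions in row-major order,
# which is why they all sit in the final row of the mask.
rows, cols = mask.nonzero(as_tuple=True)
print(rows[-3:], cols[-3:])  # row 1023; columns 418, 573, 583 above
```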
Modifying `_round_nnz` so that it rounds the count up by adding dummy entries, instead of dropping real ones, fixes the conversion:

```python
def monkey_round_nnz(mask, divisible_by=4):
    nnz = torch.count_nonzero(mask)
    # Cumulative count of zero entries in row-major order; the first
    # (-nnz) % divisible_by of them get flipped to True as padding.
    cunz = torch.cumsum(~mask.flatten(), dim=0)
    flip = cunz <= (-nnz) % divisible_by
    return torch.logical_or(flip.reshape_as(mask), mask)

xformers.sparse.utils._round_nnz = monkey_round_nnz
```

However, `_masked_matmul` in xformers/xformers/sparse/csr_tensor.py (line 177 at 4a9dd7e) then also needs to neutralize those padding entries.
Taking the values of the mask into account there and setting the scores of the padding entries to `-inf` (so they vanish in the subsequent softmax) fixes the attention output:

```python
@classmethod
def _masked_matmul(cls, a, b, mask):
    if not (type(a) is torch.Tensor and type(b) is torch.Tensor):
        return NotImplemented
    assert mask.shape[1] == a.shape[1]
    assert mask.shape[2] == b.shape[2]
    values = mask.__values
    row_indices = mask.__row_indices
    row_offsets = mask.__row_offsets
    column_indices = mask.__column_indices
    tansp_info = mask.__transp_info
    out = _csr_ops._sddmm.apply(
        a.contiguous(),
        b.transpose(-2, -1).contiguous(),
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
    # The padding entries added by the patched _round_nnz carry the
    # value False; replace their scores with -inf so that they do not
    # contribute after the softmax.
    out = torch.where(values, out, float("-inf"))
    return cls._wrap(
        mask.shape,
        out,
        row_indices,
        row_offsets,
        column_indices,
        tansp_info,
    )
```
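With both changes in place, the round-trip from the first snippet should hold exactly, since the padding positions are stored with the value `False`. A minimal check (a sketch: it assumes `SparseCS` is importable from `xformers.components.attention.core` and that `monkey_round_nnz` from above is in scope; the `_masked_matmul` change has to be applied in csr_tensor.py itself, because the name-mangled `mask.__values` attribute is only accessible from inside the class):

```python
import torch
import xformers.sparse.utils
from xformers.components.attention.core import SparseCS  # assumed import path

# Apply the conversion patch from above.
xformers.sparse.utils._round_nnz = monkey_round_nnz

mask = torch.rand((1024, 1024), device="cuda") < 0.01
sparse = SparseCS(mask, device=mask.device)

# Padding entries are stored with value False, so the dense
# round-trip should now reproduce the original mask exactly.
assert (sparse.to_dense().squeeze(0) == mask).all()
```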
🐛 Bug

The output of `scaled_dot_product_attention` is wrong when the mask is a `SparseCS` matrix. In particular, the last element of the sequence is incorrect, while the others are correct.

To Reproduce
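The original reproduction script is not preserved in this excerpt; the following is a minimal sketch under stated assumptions: `xf` stands for `xformers.components.attention.core` (the actual import is never shown above), the xformers call takes the mask as its fourth positional argument, and PyTorch is recent enough to provide `F.scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F
from xformers.components.attention import core as xf  # assumed alias

B, M, K = 1, 1024, 64
q = torch.randn(B, M, K, device="cuda")
k = torch.randn(B, M, K, device="cuda")
v = torch.randn(B, M, K, device="cuda")

# Boolean attention mask (True = may attend); keep one column always
# allowed so that no softmax row is empty.
mask = torch.rand(M, M, device="cuda") < 0.01
mask[:, 0] = True
sparse_mask = xf.SparseCS(mask, device=mask.device)

ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = xf.scaled_dot_product_attention(q, k, v, sparse_mask)

# On the affected revision, only the last sequence element deviates.
diff = (ref - out).abs().amax(dim=-1)  # per-position max error, shape (B, M)
print(torch.allclose(ref, out, atol=1e-5), int(diff.argmax()))
```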
Expected behavior

The output of `torch.nn.functional.scaled_dot_product_attention` and `xf.scaled_dot_product_attention` should be the same (up to some tolerance).

Environment
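The usual details can be collected with `python -m torch.utils.collect_env` and `python -m xformers.info`.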