Possible bug in half-precision scatter #269

Open
noahstier opened this issue Feb 13, 2022 · 4 comments

I am getting different results when computing the mean of a half-precision tensor with torch.mean vs. with torch_scatter.scatter using reduce='mean'.

Is this expected behavior?

import torch
import torch_scatter

# Same data in float32 and float16; a single index group covering all elements.
src_float = torch.randn(1_000_000).float()
src_half = src_float.half()
idx = torch.zeros(len(src_float), dtype=torch.long)

# Mean via scatter vs. mean computed directly with torch.mean.
result_float = torch_scatter.scatter(src_float, idx, reduce='mean')
result_half = torch_scatter.scatter(src_half, idx, reduce='mean')

result_float_manual = torch.mean(src_float)
result_half_manual = torch.mean(src_half)

print(result_float)
print(result_float_manual)
print(result_half)
print(result_half_manual)

prints:

tensor([-0.0014])
tensor(-0.0014)
tensor([-0.7241], dtype=torch.float16)
tensor(-0., dtype=torch.float16)
rusty1s (Owner) commented Feb 14, 2022

Are you sure result_half_manual provides the correct result? To me, this looks more like an overflow problem, since 1_000_000 (the number of elements) is not representable in torch.half. For example,

src_half = src_float.half().abs()
result_half_manual = torch.mean(src_half)

produces tensor(nan, dtype=torch.float16). Everything also looks fine if you reduce the number of elements to, e.g., 1_000.
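
One plausible reading of the tensor(-0., dtype=torch.float16) result above (I have not checked the torch.mean internals, so treat this as a sketch): the divisor itself overflows. The largest finite float16 value is 65504, so casting the element count 1_000_000 to half gives inf; a finite negative sum divided by inf is -0., and an overflowed (inf) sum of absolute values divided by inf is nan.

import torch

# float16 cannot represent 1_000_000 (the largest finite value is 65504),
# so the element count overflows to inf when cast to half ...
n = torch.tensor(1_000_000.0).half()
print(n)  # tensor(inf, dtype=torch.float16)

# ... a finite negative sum divided by inf gives -0., and an overflowed
# (inf) sum of absolute values divided by inf gives nan.
print(torch.tensor(-123.0, dtype=torch.half) / n)        # tensor(-0., dtype=torch.float16)
print(torch.tensor(float('inf'), dtype=torch.half) / n)  # tensor(nan, dtype=torch.float16)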

noahstier (Author) commented:

Using 10_000 elements is a better example:

tensor([0.0024])
tensor(0.0024)
tensor([0.0119], dtype=torch.float16)
tensor(0.0024, dtype=torch.float16)

Here, result_half_manual is correct but result_half is not. Even in the 1_000_000-element case, though, it was the fact that the two results differ that drew my eye, even if neither is correct there.

rusty1s (Owner) commented Feb 14, 2022

Indeed. Seems to be a bug in torch.scatter_add, where

N = 10000
src = torch.ones(N, dtype=torch.half)
out = torch.zeros(1, dtype=torch.half)
index = torch.zeros(N, dtype=torch.long)
out.scatter_add_(0, index, src)

cannot produce an output larger than 2048.
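
For reference, the 2048 ceiling is consistent with plain float16 rounding rather than anything scatter-specific: above 2048 the spacing between representable half values is 2, so adding 1.0 to a running half-precision sum that has already reached 2048 rounds back down to 2048. A minimal sketch, assuming the accumulation really happens in float16:

import torch

# Above 2048, consecutive float16 values are 2 apart (the format has
# 11 significand bits), so a running sum of ones that reaches 2048 stalls.
x = torch.tensor(2048.0, dtype=torch.half)
one = torch.tensor(1.0, dtype=torch.half)
print(x + one)                             # tensor(2048., dtype=torch.float16)
print(torch.finfo(torch.half).eps * 2048)  # spacing near 2048: 2.0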

github-actions (bot) commented:

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

github-actions bot added the stale label on Aug 14, 2022
rusty1s added the bug label and removed the stale label on Aug 15, 2022