Hello guys,

I'm getting the following error when trying to set `mixed_precision=True`, `mp_dtype=torch.float16`. It happens in the GNN module:
```
    664 # Average the updates for each junction (requires torch > 1.12)
--> 665 update0 = update0.scatter_reduce_(
    666     dim=2,
    667     index=lines_junc_idx0[:, None].repeat(1, dim, 1),
    668     src=lupdate0,
    669     reduce="mean",
    670     include_self=False,
    671 )

RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype
```
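For context, here is a minimal standalone sketch (with made-up shapes, not the repo's actual tensors) reproducing this error class: `scatter_reduce_` requires `self` and `src` to share a dtype, so any float32/float16 mix raises the same `RuntimeError`, and casting either side resolves it:

```python
import torch

dst = torch.zeros(2, 4, dtype=torch.float32)   # stays fp32
src = torch.randn(2, 4, dtype=torch.float16)   # mimics a half-precision GNN update
idx = torch.randint(0, 4, (2, 4))

try:
    dst.scatter_reduce_(dim=1, index=idx, src=src,
                        reduce="mean", include_self=False)
except RuntimeError as e:
    print(e)  # prints the dtype-mismatch message seen in the traceback

# Casting either side so the dtypes agree makes the call succeed:
dst.scatter_reduce_(dim=1, index=idx, src=src.float(),
                    reduce="mean", include_self=False)
```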
Then I changed the call to:

```python
desc0, desc1 = self.gnn(desc0.half(), desc1.half(), line_enc0, line_enc1, lines_junc_idx0, lines_junc_idx1)
```

and it worked.
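A less invasive alternative to hard-casting the descriptors at the call site (just a sketch of the idea, not the repo's actual fix; the function name and shapes here are hypothetical) would be to cast `src` to the destination tensor's dtype at the scatter site itself, so the reduction works whichever dtype mixed precision produces:

```python
import torch

def scatter_mean_update(update, lupdate, lines_junc_idx, dim):
    """Average per-line updates onto their junctions (requires torch >= 1.12).

    Casting src to update.dtype avoids the scatter() dtype-mismatch error
    when mixed precision leaves the two tensors in different dtypes.
    """
    return update.scatter_reduce_(
        dim=2,
        index=lines_junc_idx[:, None].repeat(1, dim, 1),
        src=lupdate.to(update.dtype),
        reduce="mean",
        include_self=False,
    )

# Hypothetical shapes: batch=1, dim=4, 6 junctions, 8 line endpoints.
b, dim, n_junc, n_lines = 1, 4, 6, 8
update0 = torch.zeros(b, dim, n_junc, dtype=torch.float16)
lupdate0 = torch.randn(b, dim, n_lines, dtype=torch.float32)  # still fp32
lines_junc_idx0 = torch.randint(0, n_junc, (b, n_lines))

out = scatter_mean_update(update0, lupdate0, lines_junc_idx0, dim)
print(out.dtype)  # torch.float16
```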
It seems like the output from the SuperPoint network comes in float32. However, I don't think hard-casting the descriptors like this is the way to go... Has anyone encountered this problem before?

My torch version is `'2.2.0+cu118'`.