Fix NaN issue in SetTransformerAggregation (#7902)
Fixes #7899
rusty1s authored Aug 18, 2023
1 parent c0c060c commit deff5a4
Showing 3 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -85,6 +85,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed an issue where `SetTransformerAggregation` produced `NaN` values for isolated nodes ([#7902](https://github.com/pyg-team/pytorch_geometric/pull/7902))
- Fixed `model_summary` on modules with uninitialized parameters ([#7884](https://github.com/pyg-team/pytorch_geometric/pull/7884))
- Updated `QM9` data pre-processing to include the SMILES string ([#7867](https://github.com/pyg-team/pytorch_geometric/pull/7867))
- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
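For context, the `NaN` values are most plausibly explained by attention over an empty set: an index that receives no elements ends up as a fully masked row, and softmax over all `-inf` scores computes 0 / 0. Below is a minimal plain-PyTorch sketch of that assumed failure mode and of the `nan_to_num` remedy; it illustrates the mechanism, not the PyG internals.

```python
import torch

# Hedged sketch: an isolated node contributes no elements, so its attention
# scores are fully masked (all -inf). Softmax then yields 0 / 0 = NaN.
scores = torch.full((1, 3), float('-inf'))
attn = scores.softmax(dim=-1)
print(attn)               # tensor([[nan, nan, nan]])

# The commit clamps such values back to zero before returning:
print(attn.nan_to_num())  # tensor([[0., 0., 0.]])
```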
6 changes: 4 additions & 2 deletions test/nn/aggr/test_set_transformer.py
@@ -6,15 +6,17 @@

def test_set_transformer_aggregation():
    x = torch.randn(6, 16)
-    index = torch.tensor([0, 0, 1, 1, 1, 2])
+    index = torch.tensor([0, 0, 1, 1, 1, 3])

    aggr = SetTransformerAggregation(16, num_seed_points=2, heads=2)
    aggr.reset_parameters()
    assert str(aggr) == ('SetTransformerAggregation(16, num_seed_points=2, '
                         'heads=2, layer_norm=False, dropout=0.0)')

    out = aggr(x, index)
-    assert out.size() == (3, 2 * 16)
+    assert out.size() == (4, 2 * 16)
+    assert out.isnan().sum() == 0
+    assert out[2].abs().sum() == 0

    if is_full_test():
        jit = torch.jit.script(aggr)
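Read as a standalone script, the updated test exercises exactly the isolated-node case: index `2` never appears, so the third group stays empty. A usage sketch along those lines (it mirrors the test above rather than adding new API):

```python
import torch
from torch_geometric.nn.aggr import SetTransformerAggregation

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 3])  # group 2 receives no elements

aggr = SetTransformerAggregation(16, num_seed_points=2, heads=2)
out = aggr(x, index)

assert out.size() == (4, 2 * 16)  # 4 groups, 2 seed points concatenated
assert out.isnan().sum() == 0     # no NaNs after the fix
assert out[2].abs().sum() == 0    # the empty group maps to an all-zero row
```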
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/set_transformer.py
@@ -105,6 +105,8 @@ def forward(
        for decoder in self.decoders:
            x = decoder(x)

+        x = x.nan_to_num()
+
        return x.flatten(1, 2) if self.concat else x.mean(dim=1)

    def __repr__(self) -> str:
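For reference, `Tensor.nan_to_num()` with default arguments maps `NaN` to `0.0` (and `±inf` to the largest finite values of the dtype), which is why an isolated node now yields an all-zero row instead of propagating `NaN`:

```python
import torch

# nan_to_num() defaults: NaN -> 0.0, so the decoder output of an empty
# group becomes a zero row rather than NaN.
x = torch.tensor([float('nan'), 1.0, float('nan')])
print(x.nan_to_num())  # tensor([0., 1., 0.])
```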
