Skip to content

Commit

Permalink
Allow optional but untyped tensors in MessagePassing (#9494)
Browse files Browse the repository at this point in the history
Fixes #9492
  • Loading branch information
rusty1s authored Jul 8, 2024
1 parent c8cd4de commit fbafbc4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494))
- Added support for modifying `filename` of the stored partitioned file in `ClusterLoader` ([#9448](https://github.com/pyg-team/pytorch_geometric/pull/9448))
- Support other than two-dimensional inputs in `AttentionalAggregation` ([#9433](https://github.com/pyg-team/pytorch_geometric/pull/9433))
- Improved model performance of the `examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386), [#9445](https://github.com/pyg-team/pytorch_geometric/pull/9445))
Expand Down
21 changes: 21 additions & 0 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,24 @@ def test_pickle(tmp_path):

model = torch.load(path)
torch.jit.script(model)


class MyOptionalEdgeAttrConv(MessagePassing):
def __init__(self):
super().__init__()

def forward(self, x, edge_index, edge_attr=None):
return self.propagate(edge_index, x=x, edge_attr=edge_attr)

def message(self, x_j, edge_attr=None):
return x_j if edge_attr is None else x_j * edge_attr.view(-1, 1)


def test_my_optional_edge_attr_conv():
conv = MyOptionalEdgeAttrConv()

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

out = conv(x, edge_index)
assert out.size() == (4, 8)
9 changes: 6 additions & 3 deletions torch_geometric/nn/conv/collect.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,16 @@ def {{collect_name}}(

{%- if 'edge_weight' in collect_param_dict and
collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
assert edge_weight is not None
if torch.jit.is_scripting():
assert edge_weight is not None
{%- elif 'edge_attr' in collect_param_dict and
collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
assert edge_attr is not None
if torch.jit.is_scripting():
assert edge_attr is not None
{%- elif 'edge_type' in collect_param_dict and
collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
assert edge_type is not None
if torch.jit.is_scripting():
assert edge_type is not None
{%- endif %}

# Collect user-defined arguments:
Expand Down

0 comments on commit fbafbc4

Please sign in to comment.