Skip to content

Commit

Permalink
Added a test with to_hetero for Sequential models (#7927)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Aug 24, 2023
1 parent 1be2217 commit d6951e7
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test/nn/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
GCNConv,
JumpingKnowledge,
MessagePassing,
SAGEConv,
Sequential,
global_mean_pool,
to_hetero,
)
from torch_geometric.typing import SparseTensor

Expand Down Expand Up @@ -142,3 +144,33 @@ def test_sequential_with_ordered_dict():

x = model(x, edge_index)
assert x.size() == (4, 64)


def test_sequential_to_hetero():
model = Sequential('x, edge_index', [
(SAGEConv((-1, -1), 32), 'x, edge_index -> x1'),
ReLU(),
(SAGEConv((-1, -1), 64), 'x1, edge_index -> x2'),
ReLU(),
])

x_dict = {
'paper': torch.randn(100, 16),
'author': torch.randn(100, 16),
}
edge_index_dict = {
('paper', 'cites', 'paper'):
torch.randint(100, (2, 200), dtype=torch.long),
('paper', 'written_by', 'author'):
torch.randint(100, (2, 200), dtype=torch.long),
('author', 'writes', 'paper'):
torch.randint(100, (2, 200), dtype=torch.long),
}
metadata = list(x_dict.keys()), list(edge_index_dict.keys())

model = to_hetero(model, metadata, debug=False)

out_dict = model(x_dict, edge_index_dict)
assert isinstance(out_dict, dict) and len(out_dict) == 2
assert out_dict['paper'].size() == (100, 64)
assert out_dict['author'].size() == (100, 64)

0 comments on commit d6951e7

Please sign in to comment.