diff --git a/test/nn/test_sequential.py b/test/nn/test_sequential.py index 086e194c377c..43adf1b27a8a 100644 --- a/test/nn/test_sequential.py +++ b/test/nn/test_sequential.py @@ -9,8 +9,10 @@ GCNConv, JumpingKnowledge, MessagePassing, + SAGEConv, Sequential, global_mean_pool, + to_hetero, ) from torch_geometric.typing import SparseTensor @@ -142,3 +144,33 @@ def test_sequential_with_ordered_dict(): x = model(x, edge_index) assert x.size() == (4, 64) + + +def test_sequential_to_hetero(): + model = Sequential('x, edge_index', [ + (SAGEConv((-1, -1), 32), 'x, edge_index -> x1'), + ReLU(), + (SAGEConv((-1, -1), 64), 'x1, edge_index -> x2'), + ReLU(), + ]) + + x_dict = { + 'paper': torch.randn(100, 16), + 'author': torch.randn(100, 16), + } + edge_index_dict = { + ('paper', 'cites', 'paper'): + torch.randint(100, (2, 200), dtype=torch.long), + ('paper', 'written_by', 'author'): + torch.randint(100, (2, 200), dtype=torch.long), + ('author', 'writes', 'paper'): + torch.randint(100, (2, 200), dtype=torch.long), + } + metadata = list(x_dict.keys()), list(edge_index_dict.keys()) + + model = to_hetero(model, metadata, debug=False) + + out_dict = model(x_dict, edge_index_dict) + assert isinstance(out_dict, dict) and len(out_dict) == 2 + assert out_dict['paper'].size() == (100, 64) + assert out_dict['author'].size() == (100, 64)