[EdgeIndex] Clean-up cache design (#8513)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
rusty1s and pre-commit-ci[bot] authored Dec 3, 2023
1 parent 08c87e5 commit 390942f
Showing 2 changed files with 210 additions and 242 deletions.
test/data/test_edge_index.py — 100 changes: 49 additions & 51 deletions
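Every hunk in this file tracks the same rename of EdgeIndex's private cache attributes: the layout-specific pairs _rowptr/_colptr, _csr_col/_csc_row, and _csr2csc/_csc2csr are unified into _indptr plus the transpose-oriented caches _T_indptr, _T_index, and _T_perm, and the sparse_size property becomes a sparse_size() method.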
@@ -36,7 +36,7 @@ def test_basic(dtype, device):
     assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],')
     assert adj.dtype == dtype
     assert adj.device == device
-    assert adj.sparse_size == (3, 3)
+    assert adj.sparse_size() == (3, 3)

     assert adj.sort_order is None
     assert not adj.is_sorted
@@ -64,9 +64,9 @@ def test_undirected(dtype, device):
     assert isinstance(adj, EdgeIndex)
     assert adj.is_undirected

-    assert adj.sparse_size == (None, None)
+    assert adj.sparse_size() == (None, None)
     adj.get_num_rows()
-    assert adj.sparse_size == (3, 3)
+    assert adj.sparse_size() == (3, 3)

     adj.validate()

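A minimal sketch of the lazy sizing behavior the two hunks above test. The import path and constructor arguments are assumptions inferred from the test file path (test/data/test_edge_index.py):

    from torch_geometric.data import EdgeIndex

    # Symmetric edge list, so the undirected flag is valid.
    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], is_undirected=True)
    assert adj.sparse_size() == (None, None)  # Nothing computed yet.

    adj.get_num_rows()                  # Computes and caches the row count.
    assert adj.sparse_size() == (3, 3)  # Undirected: num_cols follows suit.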
@@ -81,40 +81,40 @@ def test_fill_cache_(dtype, device, is_undirected):
     kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)
     adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
     adj.validate().fill_cache_()
-    assert adj.sparse_size == (3, 3)
-    assert adj._rowptr.dtype == dtype
-    assert adj._rowptr.equal(tensor([0, 1, 3, 4], device=device))
-    assert adj._csr_col is None
-    assert adj._csr2csc.dtype == torch.int64
-    assert (adj._csr2csc.equal(tensor([1, 0, 3, 2], device=device))
-            or adj._csr2csc.equal(tensor([1, 3, 0, 2], device=device)))
+    assert adj.sparse_size() == (3, 3)
+    assert adj._indptr.dtype == dtype
+    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))
+    assert adj._T_perm.dtype == torch.int64
+    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))
+            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))
+    assert adj._T_index[0].dtype == dtype
+    assert (adj._T_index[0].equal(tensor([1, 0, 2, 1], device=device))
+            or adj._T_index[0].equal(tensor([1, 2, 0, 1], device=device)))
+    assert adj._T_index[1].dtype == dtype
+    assert adj._T_index[1].equal(tensor([0, 1, 1, 2], device=device))
     if is_undirected:
-        assert adj._colptr is None
+        assert adj._T_indptr is None
     else:
-        assert adj._colptr.dtype == dtype
-        assert adj._colptr.equal(tensor([0, 1, 3, 4], device=device))
-        assert adj._csc_row.dtype == dtype
-        assert (adj._csc_row.equal(tensor([1, 0, 2, 1], device=device))
-                or adj._csc_row.equal(tensor([1, 2, 0, 1], device=device)))
-        assert adj._csc2csr is None
+        assert adj._T_indptr.dtype == dtype
+        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))

     adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)
     adj.validate().fill_cache_()
-    assert adj.sparse_size == (3, 3)
-    assert adj._colptr.dtype == dtype
-    assert adj._colptr.equal(tensor([0, 1, 3, 4], device=device))
-    assert adj._csc_row is None
-    assert (adj._csc2csr.equal(tensor([1, 0, 3, 2], device=device))
-            or adj._csc2csr.equal(tensor([1, 3, 0, 2], device=device)))
+    assert adj.sparse_size() == (3, 3)
+    assert adj._indptr.dtype == dtype
+    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))
+    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))
+            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))
+    assert adj._T_index[0].dtype == dtype
+    assert adj._T_index[0].equal(tensor([0, 1, 1, 2], device=device))
+    assert adj._T_index[1].dtype == dtype
+    assert (adj._T_index[1].equal(tensor([1, 0, 2, 1], device=device))
+            or adj._T_index[1].equal(tensor([1, 2, 0, 1], device=device)))
     if is_undirected:
-        assert adj._rowptr is None
+        assert adj._T_indptr is None
     else:
-        assert adj._rowptr.dtype == dtype
-        assert adj._rowptr.equal(tensor([0, 1, 3, 4], device=device))
-        assert adj._csr_col.dtype == dtype
-        assert (adj._csr_col.equal(tensor([1, 0, 2, 1], device=device))
-                or adj._csr_col.equal(tensor([1, 2, 0, 1], device=device)))
-        assert adj._csr2csc is None
+        assert adj._T_indptr.dtype == dtype
+        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))


 @withCUDA
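A sketch of the unified cache layout the hunk above exercises. These are private attributes, so treat the names as internals rather than API; construction mirrors the test:

    import torch
    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    adj.validate().fill_cache_()

    # One compressed pointer vector for the primary sort order ...
    assert adj._indptr.equal(torch.tensor([0, 1, 3, 4]))
    # ... plus transpose-side caches replacing the old CSR/CSC-specific pairs:
    assert adj._T_perm.dtype == torch.int64                   # Permutation to the transpose.
    assert adj._T_index[1].equal(torch.tensor([0, 1, 1, 2]))  # Transposed index.
    assert adj._T_indptr.equal(torch.tensor([0, 1, 3, 4]))    # Directed case only.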
@@ -150,15 +150,15 @@ def test_to(dtype, device, is_undirected):
     adj = adj.to(device)
     assert isinstance(adj, EdgeIndex)
     assert adj.device == device
-    assert adj._rowptr.device == device
-    assert adj._csr2csc.device == device
+    assert adj._indptr.device == device
+    assert adj._T_perm.device == device

     out = adj.to(torch.int)
     assert out.dtype == torch.int
     if torch_geometric.typing.WITH_PT20:
         assert isinstance(out, EdgeIndex)
-        assert out._rowptr.dtype == torch.int
-        assert out._csr2csc.dtype == torch.int
+        assert out._indptr.dtype == torch.int
+        assert out._T_perm.dtype == torch.int
     else:
         assert not isinstance(out, EdgeIndex)

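The hunk above pins down the dtype-conversion semantics. A sketch, assuming the same setup as in the test:

    import torch
    import torch_geometric.typing
    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    adj.fill_cache_()

    out = adj.to(torch.int)
    assert out.dtype == torch.int
    if torch_geometric.typing.WITH_PT20:
        assert isinstance(out, EdgeIndex)      # Subclass survives on PyTorch >= 2.0.
        assert out._indptr.dtype == torch.int  # Caches are cast along with the data.
    else:
        assert not isinstance(out, EdgeIndex)  # Falls back to a plain Tensor.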
@@ -203,7 +203,7 @@ def test_share_memory(dtype, device):
     adj = adj.share_memory_()
     assert isinstance(adj, EdgeIndex)
     assert adj.is_shared()
-    assert adj._rowptr.is_shared()
+    assert adj._indptr.is_shared()


 @withCUDA
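A sketch of the shared-memory behavior tested above, under the same assumed setup:

    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    adj.fill_cache_()

    adj = adj.share_memory_()
    assert isinstance(adj, EdgeIndex)
    assert adj.is_shared()
    assert adj._indptr.is_shared()  # Cached tensors move to shared memory too.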
@@ -247,24 +247,24 @@ def test_sort_by(dtype, device, is_undirected):
     adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)

     out, perm = adj.sort_by('col')
-    assert adj._csr2csc is not None  # Check caches.
-    assert adj._csc_row is not None
+    assert adj._T_perm is not None  # Check caches.
+    assert adj._T_index[0] is not None and adj._T_index[1] is not None
     assert (out[0].equal(tensor([1, 0, 2, 1], device=device))
             or out[0].equal(tensor([1, 2, 0, 1], device=device)))
     assert out[1].equal(tensor([0, 1, 1, 2], device=device))
     assert (perm.equal(tensor([1, 0, 3, 2], device=device))
             or perm.equal(tensor([1, 3, 0, 2], device=device)))
-    assert out._csr2csc is None
-    assert out._csc2csr is None
+    assert out._T_perm is None
+    assert out._T_index[0] is None and out._T_index[1] is None

     out, perm = out.sort_by('row')
     assert out[0].equal(tensor([0, 1, 1, 2], device=device))
     assert (out[1].equal(tensor([1, 0, 2, 1], device=device))
             or out[1].equal(tensor([1, 2, 0, 1], device=device)))
     assert (perm.equal(tensor([1, 0, 3, 2], device=device))
             or perm.equal(tensor([2, 3, 0, 1], device=device)))
-    assert out._csr2csc is None
-    assert out._csc2csr is None
+    assert out._T_perm is None
+    assert out._T_index[0] is None and out._T_index[1] is None


 @withCUDA
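A sketch of the round trip this hunk tests: sorting fills the transpose caches on the input as a by-product, while the returned, freshly sorted EdgeIndex starts with empty caches of its own:

    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

    out, perm = adj.sort_by('col')  # perm maps the row order to column order.
    assert adj._T_perm is not None  # Cached on the input.
    assert out._T_perm is None      # Not yet computed on the output.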
@@ -278,7 +278,7 @@ def test_cat(dtype, device, is_undirected):
     out = torch.cat([adj1, adj2], dim=1)
     assert out.size() == (2, 8)
     assert isinstance(out, EdgeIndex)
-    assert out.sparse_size == (4, 4)
+    assert out.sparse_size() == (4, 4)
     assert not out.is_sorted
     assert out.is_undirected == is_undirected

@@ -298,18 +298,16 @@ def test_flip(dtype, device, is_undirected):
     out = adj.flip(0)
     assert isinstance(out, EdgeIndex)
     assert out.equal(tensor([[1, 0, 2, 1], [0, 1, 1, 2]], device=device))
-    assert out.sparse_size == (3, 3)
     assert out.is_sorted_by_col
     assert out.is_undirected == is_undirected
-    assert out._colptr.equal(tensor([0, 1, 3, 4], device=device))
+    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))

     out = adj.flip([0, 1])
     assert isinstance(out, EdgeIndex)
     assert out.equal(tensor([[1, 2, 0, 1], [2, 1, 1, 0]], device=device))
-    assert out.sparse_size == (3, 3)
     assert not out.is_sorted
     assert out.is_undirected == is_undirected
-    assert out._colptr is None
+    assert out._T_indptr is None


 @withCUDA
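A sketch of the flip semantics above: flipping dimension 0 swaps rows and columns, so a row-sorted instance becomes column-sorted and the transpose pointer survives; flipping both dimensions merely reverses the edges, which discards sortedness and that cache. The setup is assumed to match the earlier tests:

    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    adj.fill_cache_()

    out = adj.flip(0)
    assert out.is_sorted_by_col
    assert out._T_indptr is not None  # Pointer cache carries over.

    out = adj.flip([0, 1])
    assert not out.is_sorted
    assert out._T_indptr is None      # Cache is dropped.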
@@ -455,7 +453,7 @@ def test_to_sparse_csr(dtype, device):
     assert out.device == device
     assert out.layout == torch.sparse_csr
     assert out.size() == (3, 3)
-    assert adj._rowptr.equal(out.crow_indices())
+    assert adj._indptr.equal(out.crow_indices())
     assert adj[1].equal(out.col_indices())

@@ -476,7 +474,7 @@ def test_to_sparse_csc(dtype, device):
     assert out.device == device
     assert out.layout == torch.sparse_csc
     assert out.size() == (3, 3)
-    assert adj._colptr.equal(out.ccol_indices())
+    assert adj._indptr.equal(out.ccol_indices())
     assert adj[0].equal(out.row_indices())

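Both conversion hunks above read the same unified _indptr cache: it serves as crow_indices for a row-sorted input converted to CSR, and as ccol_indices for a column-sorted input converted to CSC. A sketch for the CSR side; the to_sparse_csr() call is an assumption based on the test name, since the call site sits outside the shown hunks:

    import torch
    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    out = adj.to_sparse_csr()  # Assumed call; presumably fills values with ones.

    assert out.layout == torch.sparse_csr
    assert adj._indptr.equal(out.crow_indices())  # Cache doubles as crow_indices.
    assert adj[1].equal(out.col_indices())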
@@ -573,7 +571,7 @@ def test_save_and_load(dtype, device, tmp_path):
     adj.fill_cache_()

     assert adj.sort_order == 'row'
-    assert adj._rowptr is not None
+    assert adj._indptr is not None

     path = osp.join(tmp_path, 'edge_index.pt')
     torch.save(adj, path)
@@ -582,7 +580,7 @@ def test_save_and_load(dtype, device, tmp_path):
     assert isinstance(out, EdgeIndex)
     assert out.equal(adj)
     assert out.sort_order == 'row'
-    assert out._rowptr.equal(adj._rowptr)
+    assert out._indptr.equal(adj._indptr)


 @pytest.mark.parametrize('dtype', DTYPES)
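A sketch of the persistence round trip tested in the two hunks above; torch.save/torch.load keep the subclass, the sort order, and the filled caches (newer PyTorch releases may require weights_only=False when loading):

    import os.path as osp
    import tempfile

    import torch
    from torch_geometric.data import EdgeIndex

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    adj.fill_cache_()

    path = osp.join(tempfile.gettempdir(), 'edge_index.pt')
    torch.save(adj, path)
    out = torch.load(path)

    assert isinstance(out, EdgeIndex)
    assert out.sort_order == 'row'
    assert out._indptr.equal(adj._indptr)  # Cache survives the round trip.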
@@ -608,7 +606,7 @@ def test_data_loader(dtype, num_workers):
     assert isinstance(adj, EdgeIndex)
     assert adj.dtype == dtype
     assert adj.is_shared() == (num_workers > 0)
-    assert adj._rowptr.is_shared() == (num_workers > 0)
+    assert adj._indptr.is_shared() == (num_workers > 0)


 def test_torch_script():
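A sketch of the multi-process loading behavior tested above. The DataLoader setup (a dataset of EdgeIndex tensors with an identity-style collate) is an assumption reconstructed from the visible assertions:

    from torch.utils.data import DataLoader

    from torch_geometric.data import EdgeIndex

    def collate_first(batch):
        return batch[0]  # Keep the EdgeIndex subclass instead of stacking.

    if __name__ == '__main__':
        adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
        adj.fill_cache_()

        loader = DataLoader([adj] * 2, batch_size=1, num_workers=2,
                            collate_fn=collate_first)
        for batch in loader:
            assert isinstance(batch, EdgeIndex)
            assert batch.is_shared()          # Workers ship tensors via shared memory.
            assert batch._indptr.is_shared()  # The cache travels alongside.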
(Diff for the second changed file did not load: 161 additions & 191 deletions.)