Skip to content

Commit

Permalink
EdgeIndex documentation (#8515)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rusty1s and pre-commit-ci[bot] authored Dec 3, 2023
1 parent 390942f commit bd4c99a
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 16 deletions.
13 changes: 13 additions & 0 deletions docs/source/modules/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ Data Objects
{{ name }}
{% endfor %}

Tensor Objects
--------------

.. currentmodule:: torch_geometric.data

.. autosummary::
:nosignatures:
:toctree: ../generated

{% for name in torch_geometric.data.tensor_classes %}
{{ name }}
{% endfor %}

Remote Backend Interfaces
-------------------------

Expand Down
3 changes: 2 additions & 1 deletion test/data/test_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torch import Tensor, tensor

import torch_geometric
from torch_geometric.data.edge_index import SUPPORTED_DTYPES, EdgeIndex
from torch_geometric.data import EdgeIndex
from torch_geometric.data.edge_index import SUPPORTED_DTYPES
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
Expand Down
8 changes: 7 additions & 1 deletion torch_geometric/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .hetero_data import HeteroData
from .batch import Batch
from .temporal import TemporalData
from .edge_index import EdgeIndex
from .database import Database, SQLiteDatabase, RocksDatabase
from .dataset import Dataset
from .in_memory_dataset import InMemoryDataset
Expand All @@ -26,6 +27,10 @@
'OnDiskDataset',
]

tensor_classes = [
'EdgeIndex',
]

remote_backend_classes = [
'FeatureStore',
'GraphStore',
Expand All @@ -48,7 +53,8 @@
'extract_gz',
]

__all__ = data_classes + remote_backend_classes + helper_functions
__all__ = (data_classes + tensor_classes + remote_backend_classes +
helper_functions)

lightning = LazyLoader('lightning', globals(),
'torch_geometric.data.lightning')
Expand Down
121 changes: 107 additions & 14 deletions torch_geometric/data/edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,15 @@ def wrapper(*args, **kwargs):


class EdgeIndex(Tensor):
r"""An advanced :obj:`edge_index` representation with additional (meta)data
attached.
r"""An COO :obj:`edge_index` tensor with additional (meta)data attached.
:class:`EdgeIndex` is a :pytorch:`PyTorch` tensor, that holds an
:class:`EdgeIndex` is a :pytorch:`null` class:`torch.Tensor`, that holds an
:obj:`edge_index` representation of shape :obj:`[2, num_edges]`.
Edges are given as pairwise source and destination node indices in sparse
COO format.
While :class:`EdgeIndex` sub-classes a general :pytorch:`PyTorch` tensor,
it can hold additional (meta)data, *i.e.*:
While :class:`EdgeIndex` sub-classes a general :pytorch:`null`
:class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:
* :obj:`sparse_size`: The underlying sparse matrix size
* :obj:`sort_order`: The sort order (if present), either by row or column.
Expand All @@ -126,6 +125,42 @@ class EdgeIndex(Tensor):
This representation ensures for optimal computation in GNN message passing
schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
workflows.
.. code-block:: python
from torch_geometric.data import EdgeIndex
edge_index = EdgeIndex(
[[0, 1, 1, 2],
[1, 0, 2, 1]]
sparse_size=(3, 3),
sort_order='row',
is_undirected=True,
device='cpu',
)
>>> EdgeIndex([[0, 1, 1, 2],
... [1, 0, 2, 1]])
assert edge_index.is_sorted_by_row
assert not edge_index.is_undirected
# Flipping order:
edge_index = edge_index.flip(0)
>>> EdgeIndex([[1, 0, 2, 1],
... [0, 1, 1, 2]])
assert edge_index.is_sorted_by_col
assert not edge_index.is_undirected
# Filtering:
mask = torch.tensor([True, True, True, False])
edge_index = edge_index[:, mask]
>>> EdgeIndex([[1, 0, 2],
... [0, 1, 1]])
assert edge_index.is_sorted_by_col
assert not edge_index.is_undirected
# Sparse-Dense Matrix Multiplication:
out = edge_index @ torch.randn(3, 16)
assert out.size() == (3, 16)
"""
# See "https://pytorch.org/docs/stable/notes/extending.html"
# for a basic tutorial on how to subclass `torch.Tensor`.
Expand Down Expand Up @@ -198,9 +233,13 @@ def __new__(
# Validation ##############################################################

def validate(self) -> 'EdgeIndex':
r"""Validates the :class:`EdgeIndex` representation, i.e., it ensures
* that :class:`EdgeIndex` only holds valid entries.
* that the sort order is correctly set.
r"""Validates the :class:`EdgeIndex` representation.
In particular, it ensures that
* it only holds valid indices.
* the sort order is correctly set.
* indices are bidirectional in case it is specified as undirected.
"""
assert_valid_dtype(self)
assert_two_dimensional(self)
Expand Down Expand Up @@ -334,6 +373,9 @@ def get_num_cols(self) -> int:

@assert_sorted
def get_indptr(self) -> Tensor:
r"""Returns the compressed index representation in case
:class:`EdgeIndex` is sorted.
"""
if self._indptr is not None:
return self._indptr

Expand Down Expand Up @@ -368,6 +410,9 @@ def _sort_by_transpose(self) -> Tuple[Tuple[Tensor, Tensor], Tensor]:

@assert_sorted
def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
r"""Returns the compressed CSR representation
:obj:`(rowptr, col), perm` in case :class:`EdgeIndex` is sorted.
"""
if self.is_sorted_by_row:
return (self.get_indptr(), self[1]), slice(None, None, None)

Expand All @@ -389,6 +434,9 @@ def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:

@assert_sorted
def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Union[Tensor, slice]]:
r"""Returns the compressed CSC representation
:obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted.
"""
if self.is_sorted_by_col:
return (self.get_indptr(), self[0]), slice(None, None, None)

Expand Down Expand Up @@ -505,9 +553,20 @@ def sort_by(
def to_dense(
self,
value: Optional[Tensor] = None,
fill_value: float = 0.0,
dtype: Optional[torch.dtype] = None,
) -> Tensor:
r"""Converts :class:`EdgeIndex` into a dense :class:`torch.Tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
fill_value (float, optional): The fill value for remaining elements
in the dense matrix. (default: :obj:`0.0`)
dtype (torch.dtype, optional): The data type of the returned
tensor. (default: :obj:`None`)
"""
# TODO Respect duplicated edges.

dtype = value.dtype if value is not None else dtype
Expand All @@ -516,12 +575,20 @@ def to_dense(
if value is not None and value.dim() > 1:
size = size + value.shape[1:]

out = torch.zeros(size, dtype=dtype, device=self.device)
out = torch.full(size, fill_value, dtype=dtype, device=self.device)
out[self[0], self[1]] = value if value is not None else 1

return out

def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor:
r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
:class:`torch.sparse_coo_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
"""
value = self._get_value() if value is None else value
out = torch.sparse_coo_tensor(
indices=self.as_tensor(),
Expand All @@ -537,6 +604,14 @@ def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor:
return out

def to_sparse_csr(self, value: Optional[Tensor] = None) -> Tensor:
r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
:class:`torch.sparse_csr_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
"""
(rowptr, col), perm = self.get_csr()
value = self._get_value() if value is None else value[perm]

Expand All @@ -550,6 +625,14 @@ def to_sparse_csr(self, value: Optional[Tensor] = None) -> Tensor:
)

def to_sparse_csc(self, value: Optional[Tensor] = None) -> Tensor:
r"""Converts :class:`EdgeIndex` into a :pytorch:`null`
:class:`torch.sparse_csc_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
"""
if not torch_geometric.typing.WITH_PT112:
raise NotImplementedError(
"'to_sparse_csc' not supported for PyTorch < 1.12")
Expand All @@ -569,10 +652,20 @@ def to_sparse_csc(self, value: Optional[Tensor] = None) -> Tensor:
def to_sparse(
self,
*,
layout: Optional[torch.layout] = None,
layout: torch.layout = torch.sparse_coo,
value: Optional[Tensor] = None,
) -> Tensor:
r"""Converts :class:`EdgeIndex` into a
:pytorch:`null` :class:`torch.sparse` tensor.
Args:
layout (torch.layout, optional): The desired sparse layout. One of
:obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, or
:obj:`torch.sparse_csc`. (default: :obj:`torch.sparse_coo`)
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
"""
if layout is None or layout == torch.sparse_coo:
return self.to_sparse_coo(value)
if layout == torch.sparse_csr:
Expand All @@ -586,12 +679,12 @@ def to_sparse_tensor(
self,
value: Optional[Tensor] = None,
) -> SparseTensor:
r"""Converts the :class:`EdgeIndex` representation to a
:class:`torch_sparse.SparseTensor`. Requires that :obj:`torch-sparse`
is installed.
r"""Converts :class:`EdgeIndex` into a
:class:`torch_sparse.SparseTensor`.
Requires that :obj:`torch-sparse` is installed.
Args:
value (torch.Tensor, optional): The values of non-zero indices.
value (torch.Tensor, optional): The values for sparse indices.
(default: :obj:`None`)
"""
return SparseTensor(
Expand Down

0 comments on commit bd4c99a

Please sign in to comment.