diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c3b896672fd..9e1576b5a303 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added implementation of `Batch.{from_empty,from_batch_list,set_attr,set_edge_index}` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) ### Changed @@ -64,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Convert `Batch.index_select` to a full slicing operation returning a new batch instead of a list of `Data` ([#8414](https://github.com/pyg-team/pytorch_geometric/pull/8414)) - Add args to Taobao multi-GPU example and move item-item compute to dataset ([#9550](https://github.com/pyg-team/pytorch_geometric/pull/9550)) - Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618)) - Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541)) diff --git a/test/data/test_batch.py b/test/data/test_batch.py index b893c1e4dd11..75c71ecaa327 100644 --- a/test/data/test_batch.py +++ b/test/data/test_batch.py @@ -63,24 +63,27 @@ def test_batch_basic(): assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] assert batch.ptr.tolist() == [0, 3, 5, 9] - assert str(batch[0]) == ("Data(x=[3], edge_index=[2, 4], y=[1], " - "string='1', array=[2], num_nodes=3)") - assert str(batch[1]) == ("Data(x=[2], edge_index=[2, 2], y=[1], " - "string='2', array=[3], num_nodes=2)") - assert str(batch[2]) == ("Data(x=[4], edge_index=[2, 6], y=[1], " - "string='3', array=[4], num_nodes=4)") - - assert len(batch.index_select([1, 0])) == 2 - assert len(batch.index_select(torch.tensor([1, 0]))) == 2 - assert
len(batch.index_select(torch.tensor([True, False]))) == 1 - assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2 - assert len(batch.index_select(np.array([True, False]))) == 1 - assert len(batch[:2]) == 2 + graphs = batch[0], batch[1], batch[2] + assert all([isinstance(e, Data) for e in graphs]) + assert [e.x.shape[0] for e in graphs] == [3, 2, 4] + assert [e.y.shape[0] for e in graphs] == [1, 1, 1] + assert [list(e.edge_index.shape) for e in graphs] == [[2, 4], [2, 2], + [2, 6]] + assert [e.string for e in graphs] == ["1", "2", "3"] + assert [e.num_nodes for e in graphs] == [3, 2, 4] + + assert batch.index_select([1, 0]).num_graphs == 2 + assert batch.index_select(torch.tensor([1, 0])).num_graphs == 2 + assert batch.index_select(torch.tensor([True, False, + False])).num_graphs == 1 + assert batch.index_select(np.array([1, 0], dtype=np.int64)).num_graphs == 2 + assert batch.index_select(np.array([True, False, False])).num_graphs == 1 + assert batch[:2].num_graphs == 2 data_list = batch.to_data_list() assert len(data_list) == 3 - assert len(data_list[0]) == 6 + assert set(data_list[0].keys()) == set(data1.keys()) assert data_list[0].x.tolist() == [1, 2, 3] assert data_list[0].y.tolist() == [1] assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] @@ -111,21 +114,24 @@ def test_batch_basic(): def test_index(): index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True) + index3 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True) data1 = Data(index=index1, num_nodes=3) data2 = Data(index=index2, num_nodes=4) + data3 = Data(index=index1, num_nodes=3) - batch = Batch.from_data_list([data1, data2]) + batch = Batch.from_data_list([data1, data2, data3]) - assert len(batch) == 2 - assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) - assert batch.ptr.equal(torch.tensor([0, 3, 7])) + assert len(batch) == 3 + assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 
2])) + assert batch.ptr.equal(torch.tensor([0, 3, 7, 10])) assert isinstance(batch.index, Index) - assert batch.index.equal(torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6])) - assert batch.index.dim_size == 7 + assert batch.index.equal( + torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6, 7, 8, 8, 9])) + assert batch.index.dim_size == 10 assert batch.index.is_sorted - for i, index in enumerate([index1, index2]): + for i, index in enumerate([index1, index2, index3]): data = batch[i] assert isinstance(data.index, Index) assert data.index.equal(index) @@ -149,18 +155,16 @@ def test_edge_index(): data1 = Data(edge_index=edge_index1) data2 = Data(edge_index=edge_index2) - batch = Batch.from_data_list([data1, data2]) + batch = Batch.from_data_list([data1, data2, data1.clone()]) - assert len(batch) == 2 - assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1])) - assert batch.ptr.equal(torch.tensor([0, 3, 7])) + assert len(batch) == 3 + assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])) + assert batch.ptr.equal(torch.tensor([0, 3, 7, 10])) assert isinstance(batch.edge_index, EdgeIndex) assert batch.edge_index.equal( - torch.tensor([ - [0, 1, 1, 2, 4, 3, 5, 4, 6, 5], - [1, 0, 2, 1, 3, 4, 4, 5, 5, 6], - ])) - assert batch.edge_index.sparse_size() == (7, 7) + torch.tensor([[0, 1, 1, 2, 4, 3, 5, 4, 6, 5, 7, 8, 8, 9], + [1, 0, 2, 1, 3, 4, 4, 5, 5, 6, 8, 7, 9, 8]])) + assert batch.edge_index.sparse_size() == (10, 10) assert batch.edge_index.sort_order is None assert not batch.edge_index.is_undirected diff --git a/test/data/test_batch_manipulation.py b/test/data/test_batch_manipulation.py new file mode 100644 index 000000000000..8809c02e30a5 --- /dev/null +++ b/test/data/test_batch_manipulation.py @@ -0,0 +1,234 @@ +import torch + +from torch_geometric.data import Batch, Data + +device = torch.device("cpu") + + +def test_batch_set_attr(): + batch_size = 3 + node_range = (1, 10) + nodes_v = torch.randint(*node_range, (batch_size, )) + x_list = [torch.rand(n, 
def test_batch_set_edge_index():
    """`Batch.set_edge_index` must reproduce a collated `edge_index`.

    The edges of one batch are overwritten three times in a row. Each round
    uses both call styles (a flat edge tensor plus a per-edge graph
    assignment, and a per-graph list of edge tensors) and compares the
    outcome against `Batch.from_data_list`.
    """
    num_graphs = 4
    low, high = 2, 5
    nodes_v = torch.randint(low, high, (num_graphs, )).to(device)
    x_list = [torch.rand(count, 3).to(device) for count in nodes_v]
    flat_batch = Batch.from_data_list([Data(x=feat) for feat in x_list])
    list_batch = flat_batch.clone()

    for _ in range(3):
        # Draw a random, non-zero number of edges for every graph.
        edge_counts = torch.cat([
            torch.randint(1, count, size=(1, )).to(device)
            for count in nodes_v
        ]).to(device)

        edges_list = []
        for count, num_edges in zip(nodes_v, edge_counts):
            row = torch.randint(0, count, size=(num_edges, ))
            col = torch.randint(0, count, size=(num_edges, ))
            edges_list.append(torch.vstack([row, col]).to(device))

        expected = Batch.from_data_list([
            Data(x=feat, edge_index=edges)
            for feat, edges in zip(x_list, edges_list)
        ])

        # One graph id per edge, ordered like the concatenated edge tensor.
        edge_to_graph = torch.cat([
            torch.ones(num_edges).long().to(device) * graph_id
            for graph_id, num_edges in enumerate(edge_counts)
        ])
        flat_batch.set_edge_index(torch.hstack(edges_list), edge_to_graph)
        list_batch.set_edge_index(edges_list)

        for candidate in (flat_batch, list_batch):
            compare(expected, candidate)
            graph_pairs = zip(expected.to_data_list(),
                              candidate.to_data_list())
            for want, got in graph_pairs:
                assert torch.allclose(want.x, got.x)
                assert torch.allclose(want.edge_index, got.edge_index)


def test_batch_set_edge_attr():
    """A batch assembled via `from_empty`, `set_attr` and `set_edge_index`
    must equal the batch collated from the equivalent `Data` objects.
    """
    num_graphs = 4
    low, high = 2, 5
    nodes_v = torch.randint(low, high, (num_graphs, )).to(device)
    edge_counts = torch.cat([
        torch.randint(1, count, size=(1, )).to(device) for count in nodes_v
    ])
    x_list = [torch.rand(count, 3).to(device) for count in nodes_v]

    edges_list = []
    for count, num_edges in zip(nodes_v, edge_counts):
        row = torch.randint(0, count, size=(num_edges, ))
        col = torch.randint(0, count, size=(num_edges, ))
        edges_list.append(torch.vstack([row, col]).to(device))

    edge_attr_list = [
        torch.rand(num_edges).to(device) for num_edges in edge_counts
    ]

    graphs = [
        Data(x=feat, edge_index=edges, edge_attr=weights) for feat, edges,
        weights in zip(x_list, edges_list, edge_attr_list)
    ]
    expected = Batch.from_data_list(graphs)

    expected.batch
    rebuilt = Batch.from_empty([graph.num_nodes for graph in graphs])
    rebuilt.set_attr("x", torch.vstack(x_list))

    edge_to_graph = torch.cat([
        torch.ones(num_edges).to(device).long() * graph_id
        for graph_id, num_edges in enumerate(edge_counts)
    ])
    rebuilt.set_edge_index(torch.hstack(edges_list), edge_to_graph)
    rebuilt.set_attr("edge_attr", torch.hstack(edge_attr_list), "edge")
    compare(expected, rebuilt)
def test_batch_add_graph_attr():
    """A graph-level attribute added via `set_attr(..., "graph")` must match
    the same attribute collated by `Batch.from_data_list`.
    """
    num_graphs = 3
    low, high = 1, 10
    nodes_v = torch.randint(low, high, (num_graphs, )).to(device)
    x_list = [torch.rand(count, 3).to(device) for count in nodes_v]
    graph_attr = torch.rand(num_graphs).to(device)

    graphs = [
        Data(x=feat, ga=value) for feat, value in zip(x_list, graph_attr)
    ]
    expected = Batch.from_data_list(graphs)

    expected.batch

    rebuilt = Batch.from_empty([graph.num_nodes for graph in graphs])
    rebuilt.set_attr("x", torch.vstack(x_list))
    rebuilt.set_attr("ga", graph_attr, "graph")

    compare(expected, rebuilt)


def test_from_batch_list():
    """Merging four sub-batches with `Batch.from_batch_list` must equal
    collating all twelve graphs at once.
    """
    num_graphs = 12
    low, high = 2, 5
    nodes_v = torch.randint(low, high, (num_graphs, )).to(device)
    edge_counts = torch.cat([
        torch.randint(1, count, size=(1, )).to(device) for count in nodes_v
    ])
    x_list = [torch.rand(count, 3).to(device) for count in nodes_v]

    edges_list = []
    for count, num_edges in zip(nodes_v, edge_counts):
        row = torch.randint(0, count, size=(num_edges, ))
        col = torch.randint(0, count, size=(num_edges, ))
        edges_list.append(torch.vstack([row, col]).to(device))

    edge_attr_list = [
        torch.rand(num_edges).to(device) for num_edges in edge_counts
    ]
    graph_attr = torch.rand(num_graphs).to(device)

    graphs = [
        Data(x=feat, edge_index=edges, edge_attr=weights, ga=value)
        for feat, edges, weights, value in zip(x_list, edges_list,
                                               edge_attr_list, graph_attr)
    ]
    expected = Batch.from_data_list(graphs)

    # Split points of the four consecutive sub-batches: [0:3], [3:5], [5:7]
    # and [7:].
    bounds = [0, 3, 5, 7, len(graphs)]
    merged = Batch.from_batch_list([
        Batch.from_data_list(graphs[start:stop])
        for start, stop in zip(bounds, bounds[1:])
    ])

    compare(expected, merged)
def test_batch_slice():
    """Slicing a batch with a boolean mask must equal collating only the
    graphs the mask keeps.
    """
    batch_size = 9
    node_range = (3, 8)
    nodes_v = torch.randint(*node_range, (batch_size, )).to(device)
    edges_per_graph = torch.cat([
        torch.randint(1, num_nodes, size=(1, )).to(device)
        for num_nodes in nodes_v
    ])
    x_list = [torch.rand(num_nodes, 3).to(device) for num_nodes in nodes_v]
    edges_list = [
        torch.vstack([
            torch.randint(0, num_nodes, size=(num_edges, )),
            torch.randint(0, num_nodes, size=(num_edges, )),
        ]).to(device) for num_nodes, num_edges in zip(nodes_v, edges_per_graph)
    ]
    edge_attr_list = [
        torch.rand(num_edges).to(device) for num_edges in edges_per_graph
    ]
    graph_attr_list = torch.rand(batch_size, 1, 5).to(device)

    batch_list = [
        Data(x=x, edge_index=edges, edge_attr=ea, ga=ga)
        for x, edges, ea, ga in zip(x_list, edges_list, edge_attr_list,
                                    graph_attr_list)
    ]

    bslice = torch.FloatTensor(batch_size).uniform_() > 0.4
    # An all-False mask would ask for an empty batch, which cannot be
    # collated; always keep at least one graph so the test is deterministic
    # in what it exercises.
    if not bool(bslice.any()):
        bslice[0] = True

    batch_full = Batch.from_data_list(batch_list)
    # `flatten()` keeps the index one-dimensional even when exactly one graph
    # is selected; `squeeze()` would collapse it to a 0-d tensor, which cannot
    # be iterated.
    kept = bslice.nonzero().flatten().tolist()
    batch_truth = Batch.from_data_list([batch_list[i] for i in kept])
    batch_new = batch_full[bslice]
    compare(batch_truth, batch_new)


def compare(ba: Batch, bb: Batch):
    """Assert that two batches hold the same data and bookkeeping.

    Compares the key sets, the `batch` and `ptr` vectors, every stored
    attribute and the matching `_slice_dict` / `_inc_dict` entries.

    Raises:
        Exception: If the key sets differ or any attribute, slice or
            increment entry differs (the failing key is named in the
            message).
        AssertionError: If `batch` or `ptr` differ.
    """
    keys_a, keys_b = set(ba.keys()), set(bb.keys())
    if keys_a != keys_b:
        raise Exception(
            f"Batch keys differ (only in one batch: {sorted(keys_a ^ keys_b)})"
        )

    # Check shapes first so that a length mismatch cannot be hidden by
    # broadcasting.
    assert ba.batch.shape == bb.batch.shape
    assert (ba.batch == bb.batch).all()
    assert ba.ptr.shape == bb.ptr.shape
    assert (ba.ptr == bb.ptr).all()

    for k in ba.keys():
        try:
            rec_comp(ba[k], bb[k])
        except Exception as e:
            raise Exception(
                f"Batch comparison failed tensor for key {k}.") from e

        for dict_name in ("_slice_dict", "_inc_dict"):
            dict_a = getattr(ba, dict_name)
            dict_b = getattr(bb, dict_name)
            if k in dict_a or k in dict_b:
                try:
                    rec_comp(dict_a[k], dict_b[k])
                except Exception as e:
                    raise Exception(
                        f"Batch comparison failed {dict_name} for key {k}."
                    ) from e
def rec_comp(a, b):
    """Recursively check that `a` and `b` hold equal content.

    Dictionaries are compared key by key, lists and tuples element by
    element, tensors by shape and value, and plain scalars/strings by `==`.
    Values of any other type are only checked for having the same type.

    Raises:
        Exception: On the first difference found; the message describes it.
    """
    if type(a) is not type(b):
        raise Exception(
            f"Type mismatch: {type(a).__name__} vs {type(b).__name__}")

    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            raise Exception(
                f"Dictionary keys differ: {set(a.keys()) ^ set(b.keys())}")
        for k in a:
            rec_comp(a[k], b[k])
    elif isinstance(a, torch.Tensor):
        # Compare shapes explicitly: `a == b` broadcasts, so e.g. a tensor of
        # shape [1] would otherwise "equal" a tensor of shape [3].
        if a.shape != b.shape:
            raise Exception(
                f"Tensor shapes differ: {tuple(a.shape)} vs {tuple(b.shape)}")
        if not bool((a == b).all()):
            raise Exception("Tensor values differ")
    elif isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise Exception(f"Sequence lengths differ: {len(a)} vs {len(b)}")
        for item_a, item_b in zip(a, b):
            rec_comp(item_a, item_b)
    elif isinstance(a, (str, int, float, bool)):
        if a != b:
            raise Exception(f"Values differ: {a!r} vs {b!r}")
@classmethod
def from_empty(cls, num_nodes: Union[Tensor, Sequence]) -> Self:
    r"""Constructs a :class:`~torch_geometric.data.Batch` that only holds
    the assignment vector :obj:`batch` and the :obj:`ptr` vector for graphs
    of the given sizes. Attributes can be attached afterwards.

    Args:
        num_nodes (torch.Tensor or Sequence): The number of nodes of each
            graph. A tensor must be one-dimensional and of dtype
            :obj:`torch.long`; a sequence is converted to such a tensor.

    Raises:
        TypeError: If :obj:`num_nodes` is not of dtype :obj:`torch.long`.
        ValueError: If :obj:`num_nodes` is not one-dimensional or contains
            negative entries.
    """
    if isinstance(num_nodes, Sequence):
        num_nodes = torch.tensor(num_nodes, dtype=torch.long)
    if num_nodes.dtype != torch.long:
        raise TypeError('`num_nodes` dtype must be torch.long')
    if num_nodes.dim() != 1:
        raise ValueError('`num_nodes` must have one dimension')
    if (num_nodes < 0).any():
        raise ValueError('`num_nodes` must be non-negative')

    batch = Batch()
    batch.batch, batch.ptr = cls._batch_ptr_from_num_nodes(cls, num_nodes)
    # Count the graphs from the input: `batch.batch.max() + 1` fails for an
    # empty input and misses trailing graphs that have zero nodes.
    batch._num_graphs = int(num_nodes.numel())
    batch._slice_dict = defaultdict(dict)
    batch._inc_dict = defaultdict(dict)
    return batch


@classmethod
def from_batch_list(cls, batches: List[Self]) -> Self:
    r"""Constructs a :class:`~torch_geometric.data.Batch` object by
    concatenating a list of existing
    :class:`~torch_geometric.data.Batch` objects.
    Works like :meth:`~Batch.from_data_list`, but every input already holds
    several graphs. The assignment vector :obj:`batch`, the :obj:`ptr`
    vector and the slice/increment bookkeeping are rebuilt so that they
    refer to the individual graphs of all inputs.

    Args:
        batches (List[Batch]): The batches to concatenate, in order.

    Raises:
        ValueError: If :obj:`batches` is empty.
    """
    if len(batches) == 0:
        raise ValueError('`batches` must contain at least one batch')

    batch = cls.from_data_list(batches)

    # `batch` and `ptr` are rebuilt from the per-graph node counts below, so
    # the bookkeeping entries created by the collation are dropped.
    del batch._slice_dict['batch'], batch._inc_dict['batch']

    batch.batch, batch.ptr = cls._batch_ptr_from_num_nodes(
        cls, torch.concat([g.ptr.diff() for g in batches]))
    # NOTE(review): `from_data_list` is not visible here; it presumably counts
    # one "graph" per input batch. Store the real number of graphs instead.
    batch._num_graphs = int(batch.ptr.numel()) - 1

    for k in set(batch.keys()) - {'batch', 'ptr'}:
        # Per-graph element counts of all inputs, turned back into offsets.
        batch._slice_dict[k] = batch._pad_zero(
            torch.concat([be._slice_dict[k].diff()
                          for be in batches]).cumsum(0))
        if k != 'edge_index':
            inc_shift = batch._pad_zero(
                torch.tensor([sum(be._inc_dict[k])
                              for be in batches])).cumsum(0)
        else:
            # Edge indices are shifted by the nodes of all preceding batches.
            inc_shift = batch._pad_zero(
                torch.tensor([be.num_nodes for be in batches])).cumsum(0)

        batch._inc_dict[k] = torch.cat([
            be._inc_dict[k] + inc_shift[ibatch]
            for ibatch, be in enumerate(batches)
        ])
    return batch
def index_select(self, idx: IndexType) -> Self:
    r"""Creates a new :class:`~torch_geometric.data.Batch` that holds the
    graphs selected by :obj:`idx`.

    Args:
        idx: A slice, a list/tuple, or a :obj:`torch.Tensor` /
            :obj:`np.ndarray` of dtype long or bool. Boolean input is
            treated as a mask with one entry per graph; integer input may
            contain negative indices, counted from the end.

    Returns:
        The sub-batch with rebuilt :obj:`batch`, :obj:`ptr`,
        :obj:`_slice_dict` and :obj:`_inc_dict`.

    Raises:
        IndexError: If :obj:`idx` has an unsupported type or dtype, is not
            one-dimensional, is a mask of the wrong length, or refers to a
            graph outside of the batch.
    """
    # `str` is a `Sequence` as well, but never a valid graph index.
    if isinstance(idx, str) or not isinstance(
            idx, (slice, Sequence, torch.Tensor, np.ndarray)):
        raise IndexError(
            f"Only slices (':'), list, tuples, torch.tensor and "
            f'np.ndarray of dtype long or bool are valid indices (got '
            f"'{type(idx).__name__}')")

    index: torch.Tensor
    num_graphs = self.num_graphs

    def _to_tidx(o):
        return torch.tensor(o, device=self.ptr.device)

    # Convert numpy arrays to torch tensors.
    if isinstance(idx, np.ndarray):
        idx = _to_tidx(idx)

    if isinstance(idx, slice):
        index = _to_tidx(range(num_graphs)[idx]).long()
    elif isinstance(idx, Sequence):
        index = _to_tidx(idx)
        # A list of booleans is a mask; casting it to long would turn it
        # into the indices 0 and 1.
        if index.dtype != torch.bool:
            index = index.long()
    else:
        index = idx.to(self.ptr.device)

    if index.dtype == torch.bool:
        if index.dim() != 1 or index.numel() != num_graphs:
            raise IndexError(
                f'Boolean vector length does not match number of graphs'
                f' (got {index.numel()} vector size '
                f'vs. {num_graphs} graphs).')
        index = index.nonzero().flatten()
    elif index.dtype != torch.long:
        raise IndexError(
            f"Could not convert index (got '{type(idx).__name__}')")

    if index.dim() != 1:
        raise IndexError(
            f'Index must have a single dimension (got {index.dim()})')

    if index.numel() > 0:
        if int(index.min()) < -num_graphs or int(index.max()) >= num_graphs:
            raise IndexError(
                f'Index out of range for a batch of {num_graphs} graphs')
        # The slice arithmetic below needs non-negative positions.
        index = torch.where(index < 0, index + num_graphs, index)

    subbatch = separate(
        cls=self.__class__,
        batch=self,
        idx=index,
        slice_dict=self._slice_dict,
        inc_dict=self._inc_dict,
        decrement=True,
    )

    nodes_per_graph = self.ptr.diff()
    num_nodes = nodes_per_graph[index]

    subbatch.batch, subbatch.ptr = self._batch_ptr_from_num_nodes(num_nodes)

    # Rebuild `_slice_dict` and `_inc_dict` for the selected graphs.
    subbatch._slice_dict = defaultdict(dict)
    subbatch._inc_dict = defaultdict(dict)
    for k in set(self.keys()) - {'ptr', 'batch'}:
        if k not in self._slice_dict:
            continue
        slices = self._slice_dict[k]
        subbatch._slice_dict[k] = pad(
            slices.diff()[index.to(slices.device)], (1, 0)).cumsum(0)

        if k not in self._inc_dict:
            continue
        inc = self._inc_dict[k]
        if inc is None:
            subbatch._inc_dict[k] = None
            continue
        if inc.dim() == 1 and not bool(inc.any()):
            # Nothing is incremented for this key; avoid indexing
            # `inc.diff()`, which has no entry for the last graph.
            subbatch._inc_dict[k] = inc.new_zeros(index.numel())
            continue
        # NOTE(review): `inc.diff()` has `num_graphs - 1` entries, so this
        # raises when the last graph of the batch is selected in a non-final
        # position; the increment of the last graph cannot be derived from
        # `_inc_dict` alone.
        subbatch._inc_dict[k] = pad(
            inc.diff()[index[:-1].to(inc.device)], (1, 0)).cumsum(0)
    return subbatch
def set_attr(self, attrname: str, attr: Tensor,
             attrtype: Literal["node", "graph", "edge"] = "node") -> None:
    r"""Set an attribute for the nodes, graphs, or edges.

    Besides storing the tensor, the matching :obj:`_slice_dict` and
    :obj:`_inc_dict` entries are written so the attribute can be split per
    graph again.

    Args:
        attrname (str): Name of the attribute.
        attr (torch.Tensor): The attribute tensor. Its first dimension must
            have one entry per node, graph or edge, matching
            :obj:`attrtype`.
        attrtype (str): Indicates if the attribute belongs to the nodes,
            graphs or edges (`node`, `graph`, `edge`), default: `node`.

    Raises:
        Exception: If :obj:`attrname` is `edge_attr` with a non-edge
            :obj:`attrtype`, or if :obj:`attrname` is `edge_index`.
        ValueError: If :obj:`attrtype` is unknown, the device or the first
            dimension of :obj:`attr` does not fit the batch, or an edge
            attribute is set on a batch without a long `edge_index`.
    """
    if attrname == "edge_attr" and attrtype != "edge":
        raise Exception(
            "For the attribute `edge_attr`, the `attrtype` must be `edge`.")

    if attrname == "edge_index":
        raise Exception("""To overwrite the edges, use `set_edge_index`.""")

    if attrtype not in ("node", "graph", "edge"):
        raise ValueError(
            f"`attrtype` must be `node`, `graph` or `edge` (got {attrtype!r})")

    if attr.device != self.batch.device:
        raise ValueError(
            f"`attr` is on device {attr.device}, but the batch is on "
            f"{self.batch.device}")

    num_graphs = self.num_graphs

    # Validate and derive the slices before anything is stored, so a failed
    # call leaves the batch untouched.
    if attrtype == "node":
        if attr.shape[0] != self.batch.numel():
            raise ValueError(
                f"Expected one row per node ({self.batch.numel()}), "
                f"got {attr.shape[0]}")
        # `bincount` with `minlength` also counts graphs without any node;
        # `unique(return_counts=True)` would silently skip them.
        nodes_per_graph = torch.bincount(self.batch, minlength=num_graphs)
        slices = self._pad_zero(nodes_per_graph.cumsum(dim=0)).cpu()
    elif attrtype == "graph":
        if attr.shape[0] != num_graphs:
            raise ValueError(
                f"Expected one row per graph ({num_graphs}), "
                f"got {attr.shape[0]}")
        slices = torch.arange(num_graphs + 1, dtype=torch.long)
    else:
        if (not hasattr(self, 'edge_index')
                or self['edge_index'].dtype != torch.long):
            raise ValueError(
                "Edge attributes require a long `edge_index`; "
                "call `set_edge_index` first")
        num_edges = self['edge_index'].size(-1)
        if attr.shape[0] != num_edges:
            raise ValueError(
                f"Expected one row per edge ({num_edges}), "
                f"got {attr.shape[0]}")
        slices = self._slice_dict['edge_index'].clone()

    self[attrname] = attr
    # Use `attrname` for every attribute type; edge attributes are not
    # necessarily called `edge_attr`.
    self._slice_dict[attrname] = slices
    self._inc_dict[attrname] = torch.zeros(num_graphs, dtype=torch.long)


def set_edge_index(
    self,
    edge_index: Union[List[LongTensor], LongTensor],
    batchidx_per_edge: Optional[LongTensor] = None,
) -> None:
    r"""Overwrites the :obj:`edge_index`.
    :obj:`~Batch.ptr` will be used to assign the elements to
    the correct graph.

    Args:
        edge_index (List[LongTensor] | LongTensor): Either a list with the
            new edges of each graph (one :obj:`[2, num_edges]` tensor per
            graph, using graph-local node indices), or a single
            :obj:`[2, num_edges]` tensor containing the new edges of all
            graphs. In the latter case the assignment to the graphs
            must be given by :obj:`batchidx_per_edge`.
        batchidx_per_edge (LongTensor): The index tensor that maps
            each of the edges to a graph. Must be sorted.

    Raises:
        ValueError: If the arguments are inconsistent with each other or
            with the batch (shape, dtype, device, ordering, graph index
            range, or a missing :obj:`batchidx_per_edge`).
    """
    num_graphs = self.num_graphs
    edges_per_graph = None

    if isinstance(edge_index, list):
        if len(edge_index) == 0 or len(edge_index) != num_graphs:
            raise ValueError(
                f"Expected one edge tensor per graph ({num_graphs}), "
                f"got {len(edge_index)}")
        device = edge_index[0].device
        edges_per_graph = torch.tensor([e.shape[1] for e in edge_index],
                                       dtype=torch.long)
        batchidx_per_edge = torch.arange(num_graphs).repeat_interleave(
            edges_per_graph).to(device)
        edge_index = torch.hstack(edge_index)
    elif batchidx_per_edge is None:
        raise ValueError(
            "`batchidx_per_edge` is required when `edge_index` is a tensor")

    if edge_index.dim() != 2 or edge_index.shape[0] != 2:
        raise ValueError("`edge_index` must have shape [2, num_edges]")
    if batchidx_per_edge.shape != torch.Size((edge_index.shape[1], )):
        raise ValueError(
            "`batchidx_per_edge` must hold exactly one entry per edge")
    if not (edge_index.dtype == batchidx_per_edge.dtype == torch.long):
        raise ValueError(
            "`edge_index` and `batchidx_per_edge` must be of dtype torch.long")
    if not (edge_index.device == batchidx_per_edge.device ==
            self.batch.device):
        raise ValueError(
            "`edge_index`, `batchidx_per_edge` and the batch must be on the "
            "same device")
    if not bool((batchidx_per_edge.diff() >= 0).all()):
        raise ValueError('Edges must be ordered by batch')
    if batchidx_per_edge.numel() > 0 and (
            int(batchidx_per_edge.min()) < 0
            or int(batchidx_per_edge.max()) >= num_graphs):
        raise ValueError(
            f"`batchidx_per_edge` must lie in [0, {num_graphs - 1}]")

    if edges_per_graph is None:
        # `minlength` keeps a zero count for graphs without edges;
        # `unique(return_counts=True)` would drop them and misalign the
        # slices of all following graphs.
        edges_per_graph = torch.bincount(batchidx_per_edge,
                                         minlength=num_graphs)

    # Edges must be shifted by the number of nodes in the previous graphs.
    self.edge_index = edge_index + self.ptr[batchidx_per_edge]
    # Fix _slice_dict and _inc_dict.
    self._slice_dict['edge_index'] = self._pad_zero(
        edges_per_graph.cumsum(0)).cpu()
    self._inc_dict['edge_index'] = self.ptr[:-1].cpu()
batchidx_per_edge: None | LongTensor = None) -> None: + r"""Overwrites the :obj:`edge_index`. + :obj:`~Batch.ptr` will be used to assign the elements to + the correct graph. + + Args: + edge_index (List[LongTensor] | LongTensor): Either a list of the + new edges for each graph, or a tensor containing + the new edges. In this case the assignment to the graphs + must be give by `batchidx_per_edge`. + batchidx_per_edge (LongTensor): The index tensor that maps + each of the edges to a graph. + """ + if isinstance(edge_index, list): + device = edge_index[0].device + edges_per_graph = torch.tensor([e.shape[1] for e in edge_index]) + batchidx_per_edge = torch.arange( + self.num_graphs).repeat_interleave(edges_per_graph).to(device) + edge_index = torch.hstack(edge_index) + else: + edges_per_graph = batchidx_per_edge.unique(return_counts=True)[1] + + assert (batchidx_per_edge.diff() + >= 0).all(), 'Edges must be ordered by batch' + assert batchidx_per_edge.shape == torch.Size( + (edge_index.shape[1], )) + + assert edge_index.dim() == 2 + assert edge_index.shape[0] == 2 + + assert edge_index.dtype == batchidx_per_edge.dtype == torch.long + assert (edge_index.device == batchidx_per_edge.device == + self.batch.device) + + # Edges must be shifted by the number sum of the nodes + # in the previous graphs + self.edge_index = edge_index + self.ptr[batchidx_per_edge] + # Fix _slice_dict + self._slice_dict['edge_index'] = self._pad_zero( + edges_per_graph.cumsum(0)).cpu() + self._inc_dict['edge_index'] = self.ptr[:-1].cpu() + + def _pad_zero(self, arr: torch.Tensor) -> torch.Tensor: + return torch.cat([ + torch.tensor(0, dtype=arr.dtype, device=arr.device).unsqueeze(0), + arr, + ]) + + def _batch_ptr_from_num_nodes( + self, num_nodes: Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if not num_nodes.dtype == torch.long: + raise TypeError("Argument must have type `torch.long`") + batch = torch.arange( + len(num_nodes), + device=num_nodes.device).repeat_interleave(num_nodes) + ptr = 
pad(num_nodes, (1, 0), "constant", 0).cumsum(0) + return batch, ptr + @property def num_graphs(self) -> int: """Returns the number of graphs in the batch.""" @@ -202,7 +413,7 @@ def num_graphs(self) -> int: elif hasattr(self, 'batch'): return int(self.batch.max()) + 1 else: - raise ValueError("Can not infer the number of graphs") + raise ValueError('Can not infer the number of graphs') @property def batch_size(self) -> int: diff --git a/torch_geometric/data/in_memory_dataset.py b/torch_geometric/data/in_memory_dataset.py index 9f307bf14d5d..3e54d71af89b 100644 --- a/torch_geometric/data/in_memory_dataset.py +++ b/torch_geometric/data/in_memory_dataset.py @@ -107,14 +107,11 @@ def get(self, idx: int) -> BaseData: self._data_list = self.len() * [None] elif self._data_list[idx] is not None: return copy.copy(self._data_list[idx]) + self._data._num_graphs = self.len() - data = separate( - cls=self._data.__class__, - batch=self._data, - idx=idx, - slice_dict=self.slices, - decrement=False, - ) + data = separate(cls=self._data.__class__, batch=self._data, + idx=torch.tensor([idx]).long(), slice_dict=self.slices, + decrement=False, return_batch=False) self._data_list[idx] = copy.copy(data) @@ -158,6 +155,7 @@ def collate( increment=False, add_batch=False, ) + data._num_graphs = len(data_list) return data, slices diff --git a/torch_geometric/data/separate.py b/torch_geometric/data/separate.py index 2910b6679f60..d3522941605e 100644 --- a/torch_geometric/data/separate.py +++ b/torch_geometric/data/separate.py @@ -1,24 +1,26 @@ from collections.abc import Mapping, Sequence from typing import Any, Type, TypeVar +import torch from torch import Tensor +from torch.nn.functional import pad from torch_geometric import EdgeIndex, Index from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from torch_geometric.typing import SparseTensor, TensorFrame -from torch_geometric.utils import narrow T = TypeVar('T') def separate( cls: Type[T], - 
batch: Any, - idx: int, + batch: BaseData, + idx: torch.Tensor, slice_dict: Any, inc_dict: Any = None, decrement: bool = True, + return_batch: bool = True, ) -> T: # Separates the individual element from a `batch` at index `idx`. # `separate` can handle both homogeneous and heterogeneous data objects by @@ -50,8 +52,16 @@ def separate( # The `num_nodes` attribute needs special treatment, as we cannot infer # the real number of nodes from the total number of nodes alone: + # TODO make less ugly if hasattr(batch_store, '_num_nodes'): - data_store.num_nodes = batch_store._num_nodes[idx] + if return_batch: + data_store._num_nodes = [ + batch_store._num_nodes[i] for i in idx + ] + else: + data_store._num_nodes = batch_store._num_nodes[idx[0]] + if hasattr(batch_store, 'num_nodes'): + data_store.num_nodes = data_store._num_nodes return data @@ -59,43 +69,82 @@ def separate( def _separate( key: str, values: Any, - idx: int, + idx: torch.Tensor, slices: Any, incs: Any, batch: BaseData, store: BaseStorage, decrement: bool, ) -> Any: - if isinstance(values, Tensor): + idx = idx.to(values.device) + graph_slice = torch.concat([ + torch.arange(int(slices[i]), int(slices[i + 1])) for i in idx + ]).to(values.device) + valid_inc = incs is not None and (incs.dim() > 1 + or any(incs[idx] != 0)) + # Narrow a `torch.Tensor` based on `slices`. # NOTE: We need to take care of decrementing elements appropriately. 
key = str(key) cat_dim = batch.__cat_dim__(key, values, store) - start, end = int(slices[idx]), int(slices[idx + 1]) - value = narrow(values, cat_dim or 0, start, end - start) + value = torch.index_select(values, cat_dim or 0, graph_slice) value = value.squeeze(0) if cat_dim is None else value - if isinstance(values, Index) and values._cat_metadata is not None: - # Reconstruct original `Index` metadata: - value._dim_size = values._cat_metadata.dim_size[idx] - value._is_sorted = values._cat_metadata.is_sorted[idx] - - if isinstance(values, EdgeIndex) and values._cat_metadata is not None: - # Reconstruct original `EdgeIndex` metadata: - value._sparse_size = values._cat_metadata.sparse_size[idx] - value._sort_order = values._cat_metadata.sort_order[idx] - value._is_undirected = values._cat_metadata.is_undirected[idx] - - if (decrement and incs is not None - and (incs.dim() > 1 or int(incs[idx]) != 0)): - value = value - incs[idx].to(value.device) + if (decrement and incs is not None and valid_inc): + # remove the old offset + nelem_new = slices.diff()[idx] + if len(idx) == 1: + old_offset = incs[idx[0]] + new_offset = torch.zeros_like(old_offset) + shift = torch.ones_like(value) * (-old_offset + new_offset) + else: + old_offset = incs[idx] + # add the new offset + # for this we compute the number of nodes in the batch before + new_offset = pad(incs.diff()[idx[:-1]], (1, 0)).cumsum(0) + shift = (-old_offset + new_offset).repeat_interleave( + nelem_new, dim=cat_dim or 0) + value = value + shift + + if hasattr(values, + "_cat_metadata") and values._cat_metadata is not None: + + def _pad_diff(a): + a = torch.tensor(a) + return torch.concat([a[:1], a.diff(dim=0)]) + + if isinstance(values, Index): + # Reconstruct original `Index` metadata: + if decrement: + value._dim_size = _pad_diff( + values._cat_metadata.dim_size)[idx].squeeze() + else: + value._dim_size = values._cat_metadata.dim_size[idx] + value._is_sorted = values._cat_metadata.is_sorted[idx] + + if 
isinstance(values, EdgeIndex): + # Reconstruct original `EdgeIndex` metadata: + def _to_tup(a): + return tuple(a.flatten().tolist()) + + if decrement: + value._sparse_size = _to_tup( + _pad_diff(values._cat_metadata.sparse_size)[idx]) + else: + value._sparse_size = values._cat_metadata.sparse_size[idx] + value._sort_order = values._cat_metadata.sort_order[idx] + value._is_undirected = values._cat_metadata.is_undirected[idx] return value elif isinstance(values, SparseTensor) and decrement: # Narrow a `SparseTensor` based on `slices`. # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking. + if len(idx) > 1: + raise NotImplementedError + idx = idx[0] + key = str(key) cat_dim = batch.__cat_dim__(key, values, store) cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim @@ -107,8 +156,8 @@ def _separate( elif isinstance(values, TensorFrame): key = str(key) start, end = int(slices[idx]), int(slices[idx + 1]) - value = values[start:end] - return value + values = values[start:end] + return values elif isinstance(values, Mapping): # Recursively separate elements of dictionaries. @@ -131,6 +180,9 @@ def _separate( and not isinstance(values[0], str) and len(values[0]) > 0 and isinstance(values[0][0], (Tensor, SparseTensor)) and isinstance(slices, Sequence)): + if len(idx) > 1: + raise NotImplementedError + idx = idx[0] # Recursively separate elements of lists of lists. 
return [value[idx] for value in values] @@ -150,6 +202,10 @@ def _separate( decrement=decrement, ) for i, value in enumerate(values) ] - + elif isinstance(values, list) and batch._num_graphs == len(values): + if len(idx) == 1: + return values[idx[0]] + else: + return [values[i] for i in idx] else: return values[idx] diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py index 07a52fdfc21a..c71ba627ffdb 100644 --- a/torch_geometric/data/storage.py +++ b/torch_geometric/data/storage.py @@ -420,7 +420,7 @@ def can_infer_num_nodes(self) -> bool: @property def num_nodes(self) -> Optional[int]: # We sequentially access attributes that reveal the number of nodes. - if 'num_nodes' in self: + if 'num_nodes' in self._mapping: return self['num_nodes'] for key, value in self.items(): if isinstance(value, Tensor) and key in N_KEYS: