From 1e2bd6ff0676f8dc7aca0ab8ff39f7ff400c4e2e Mon Sep 17 00:00:00 2001 From: chang-l Date: Thu, 17 Oct 2024 17:19:04 -0700 Subject: [PATCH] Minor fix for typos and comments --- .../distributed/wholegraph/feature_store.py | 17 +------------ .../distributed/wholegraph/graph_store.py | 25 ++++++++----------- 2 files changed, 11 insertions(+), 31 deletions(-) diff --git a/examples/distributed/wholegraph/feature_store.py b/examples/distributed/wholegraph/feature_store.py index c2d7ee9efd44..895adfe48258 100644 --- a/examples/distributed/wholegraph/feature_store.py +++ b/examples/distributed/wholegraph/feature_store.py @@ -42,7 +42,7 @@ def __init__(self, pyg_data): self._store = { } # A dictionary of tuple to hold the feature embeddings - if dist_shmem.get_local_rank() == dist.get_rank(): + if dist_shmem.get_local_size() == dist.get_world_size(): self.backend = 'vmm' else: self.backend = 'vmm' if nvlink_network() else 'nccl' @@ -63,17 +63,6 @@ def __init__(self, pyg_data): self.put_tensor(pyg_data[group_name][attr_name], group_name=group_name, attr_name=attr_name, index=None) - # This is a hack for MAG240M dataset, to add node features for 'institution' and 'author' nodes. - # This should not be presented in the upstream code. - elif attr_name == 'num_nodes': - feature_dim = 768 - num_nodes = group[attr_name] - shape = [num_nodes, feature_dim] - self[group_name, 'x', - None] = DistEmbedding(shape=shape, - dtype=torch.float16, - device="cpu", - backend=self.backend) else: raise TypeError( "Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData." @@ -82,10 +71,6 @@ def __init__(self, pyg_data): def _put_tensor(self, tensor: torch.Tensor, attr): """Creates and stores features (either DistTensor or DistEmbedding) from the given tensor, using a key derived from the group and attribute name. - - Args: - tensor (torch.Tensor): The tensor to be passed to the feature store. - attr: PyG's TensorAttr to fully specify each feature store. """ key = (attr.group_name, attr.attr_name) out = self._store.get(key) diff --git a/examples/distributed/wholegraph/graph_store.py b/examples/distributed/wholegraph/graph_store.py index de2cbe11dba0..c1cdb74394b3 100644 --- a/examples/distributed/wholegraph/graph_store.py +++ b/examples/distributed/wholegraph/graph_store.py @@ -16,7 +16,7 @@ class WholeGraphEdgeAttr(EdgeAttr): def __init__( self, edge_type: Optional[ - EdgeType] = None, # use string to represent edge type for simplicity + EdgeType] = None, is_sorted: bool = False, size: Optional[Tuple[int, int]] = None, ): @@ -41,7 +41,7 @@ def __init__(self, pyg_data, format='wholegraph'): if format == 'wholegraph': pinned_shared = False - if dist_shmem.get_local_rank() == dist.get_rank(): + if dist_shmem.get_local_size() == dist.get_world_size(): backend = 'vmm' else: backend = 'vmm' if nvlink_network() else 'nccl' @@ -57,12 +57,11 @@ def __init__(self, pyg_data, format='wholegraph'): if 'adj_t' not in pyg_data: row, col = None, None if dist_shmem.get_local_rank() == 0: - row, col, _ = pyg_data.csc() # discard permutation for now + row, col, _ = pyg_data.csc() row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = pyg_data.size() else: - # issue: it wont work if adj_t is a SparseTensor col = pyg_data.adj_t.crow_indices() row = pyg_data.adj_t.col_indices() size = pyg_data.adj_t.size()[::-1] @@ -87,12 +86,11 @@ def __init__(self, pyg_data, format='wholegraph'): row, col = None, None if dist_shmem.get_local_rank() == 0: row, col, _ = edge_store.csc( - ) # discard permutation for now + ) row = dist_shmem.to_shmem(row) col = dist_shmem.to_shmem(col) size = edge_store.size() else: - # issue: this will also if adj_t is a SparseTensor col = edge_store.adj_t.crow_indices() row = edge_store.adj_t.col_indices() size = edge_store.adj_t.size()[::-1] @@ -106,21 +104,18 @@ def __init__(self, pyg_data, format='wholegraph'): self.put_adj_t(graph, edge_type=edge_type, size=size) def put_adj_t(self, adj_t: DistGraphCSC, *args, **kwargs) -> bool: - r"""Synchronously adds an :obj:`edge_index` tuple to the - :class:`GraphStore`. + """Add an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC` + to :class:`WholeGraphGraphStore`. Returns whether insertion was successful. - - Args: - edge_index (Tuple[torch.Tensor, torch.Tensor]): The - :obj:`edge_index` tuple in a format specified in - :class:`EdgeAttr`. - *args: Arguments passed to :class:`EdgeAttr`. - **kwargs: Keyword arguments passed to :class:`EdgeAttr`. """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) return self._put_adj_t(adj_t, edge_attr) def get_adj_t(self, *args, **kwargs) -> DistGraphCSC: + """Retrieves an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC` + from :class:`WholeGraphGraphStore`. + Return: :obj:`DistGraphCSC` + """ edge_attr = self._edge_attr_cls.cast(*args, **kwargs) graph_adj_t = self._get_adj_t(edge_attr) if graph_adj_t is None: