Commit 1e2bd6f

Minor fix for typos and comments

chang-l committed Oct 18, 2024
1 parent 9f170b0 commit 1e2bd6f
Showing 2 changed files with 11 additions and 31 deletions.
17 changes: 1 addition & 16 deletions examples/distributed/wholegraph/feature_store.py
@@ -42,7 +42,7 @@ def __init__(self, pyg_data):
         self._store = {
         }  # A dictionary of tuple to hold the feature embeddings
 
-        if dist_shmem.get_local_rank() == dist.get_rank():
+        if dist_shmem.get_local_size() == dist.get_world_size():
            self.backend = 'vmm'
         else:
            self.backend = 'vmm' if nvlink_network() else 'nccl'
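The one-line change above corrects how a single-node job is detected. dist_shmem.get_local_rank() == dist.get_rank() holds for ranks on the first node (where local and global ranks coincide) and for no rank elsewhere, so ranks could disagree on the backend; dist_shmem.get_local_size() == dist.get_world_size() is instead true on every rank exactly when the whole job runs on one node. A minimal sketch of the selection logic under that reading — select_backend is a hypothetical helper (the example inlines the logic), and nvlink_network() is taken from the example and assumed to report whether nodes are connected by NVLink:

def select_backend(local_size: int, world_size: int,
                   has_nvlink: bool) -> str:
    """Pick the memory backend for distributed graph/feature storage."""
    if local_size == world_size:
        # Single-node job: every rank shares the node, so the
        # VMM (virtual memory management) backend is always usable.
        return 'vmm'
    # Multi-node job: prefer VMM only over an NVLink-connected
    # fabric; otherwise fall back to NCCL-based access.
    return 'vmm' if has_nvlink else 'nccl'

# e.g.: backend = select_backend(dist_shmem.get_local_size(),
#                                dist.get_world_size(), nvlink_network())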
@@ -63,17 +63,6 @@ def __init__(self, pyg_data):
                     self.put_tensor(pyg_data[group_name][attr_name],
                                     group_name=group_name,
                                     attr_name=attr_name, index=None)
-                # This is a hack for MAG240M dataset, to add node features for 'institution' and 'author' nodes.
-                # This should not be presented in the upstream code.
-                elif attr_name == 'num_nodes':
-                    feature_dim = 768
-                    num_nodes = group[attr_name]
-                    shape = [num_nodes, feature_dim]
-                    self[group_name, 'x',
-                         None] = DistEmbedding(shape=shape,
-                                               dtype=torch.float16,
-                                               device="cpu",
-                                               backend=self.backend)
         else:
             raise TypeError(
                 "Expected pyg_data to be of type torch_geometric.data.Data or torch_geometric.data.HeteroData."
@@ -82,10 +71,6 @@ def __init__(self, pyg_data):
     def _put_tensor(self, tensor: torch.Tensor, attr):
         """Creates and stores features (either DistTensor or DistEmbedding) from the given tensor,
         using a key derived from the group and attribute name.
-        Args:
-            tensor (torch.Tensor): The tensor to be passed to the feature store.
-            attr: PyG's TensorAttr to fully specify each feature store.
         """
         key = (attr.group_name, attr.attr_name)
         out = self._store.get(key)
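For orientation, a hypothetical round trip through this store's (group_name, attr_name) keying — the class name WholeGraphFeatureStore and the 'paper'/'x' names are illustrative assumptions, while put_tensor/get_tensor follow PyG's FeatureStore interface as used in the diff:

import torch

store = WholeGraphFeatureStore(pyg_data)  # assumed class name
store.put_tensor(pyg_data['paper'].x,     # stored under ('paper', 'x')
                 group_name='paper', attr_name='x', index=None)
rows = store.get_tensor(group_name='paper', attr_name='x',
                        index=torch.tensor([0, 1, 2]))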
25 changes: 10 additions & 15 deletions examples/distributed/wholegraph/graph_store.py
@@ -16,7 +16,7 @@ class WholeGraphEdgeAttr(EdgeAttr):
     def __init__(
         self,
         edge_type: Optional[
-            EdgeType] = None,  # use string to represent edge type for simplicity
+            EdgeType] = None,
         is_sorted: bool = False,
         size: Optional[Tuple[int, int]] = None,
     ):
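The stale comment goes away because edge_type is typed as a full PyG EdgeType (a (src, relation, dst) tuple), not a plain string. A hypothetical construction — the ('paper', 'cites', 'paper') type and the size are made-up values:

attr = WholeGraphEdgeAttr(edge_type=('paper', 'cites', 'paper'),
                          is_sorted=False, size=(1000, 1000))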
@@ -41,7 +41,7 @@ def __init__(self, pyg_data, format='wholegraph'):
 
         if format == 'wholegraph':
             pinned_shared = False
-            if dist_shmem.get_local_rank() == dist.get_rank():
+            if dist_shmem.get_local_size() == dist.get_world_size():
                 backend = 'vmm'
             else:
                 backend = 'vmm' if nvlink_network() else 'nccl'
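This mirrors the single-node detection fix in feature_store.py above: the backend choice now depends on whether all ranks share one node, not on a per-rank coincidence of local and global rank.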
@@ -57,12 +57,11 @@ def __init__(self, pyg_data, format='wholegraph'):
             if 'adj_t' not in pyg_data:
                 row, col = None, None
                 if dist_shmem.get_local_rank() == 0:
-                    row, col, _ = pyg_data.csc()  # discard permutation for now
+                    row, col, _ = pyg_data.csc()
                 row = dist_shmem.to_shmem(row)
                 col = dist_shmem.to_shmem(col)
                 size = pyg_data.size()
             else:
-                # issue: it wont work if adj_t is a SparseTensor
                 col = pyg_data.adj_t.crow_indices()
                 row = pyg_data.adj_t.col_indices()
                 size = pyg_data.adj_t.size()[::-1]
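In this branch only local rank 0 of each node performs the CSC conversion; the other local ranks attach to the result through dist_shmem.to_shmem, so the graph is materialized once per node. csc() returns compressed pointers, indices, and an edge permutation; the permutation (which would re-order edge attributes into CSC order) is discarded. A tiny illustration of what the compressed pair encodes — the toy graph is made up, and the colptr/row naming is mine, not the example's:

import torch

# CSC of a 3-node graph with edges (1 -> 0), (2 -> 0), (0 -> 1):
colptr = torch.tensor([0, 2, 3, 3])  # len = num_dst_nodes + 1
row = torch.tensor([1, 2, 0])        # source node of each edge
for dst in range(colptr.numel() - 1):
    srcs = row[colptr[dst]:colptr[dst + 1]].tolist()
    print(f"in-neighbors of {dst}: {srcs}")  # 0: [1, 2], 1: [0], 2: []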
@@ -87,12 +86,11 @@ def __init__(self, pyg_data, format='wholegraph'):
                     row, col = None, None
                     if dist_shmem.get_local_rank() == 0:
                         row, col, _ = edge_store.csc(
-                        )  # discard permutation for now
+                        )
                     row = dist_shmem.to_shmem(row)
                     col = dist_shmem.to_shmem(col)
                     size = edge_store.size()
                 else:
-                    # issue: this will also if adj_t is a SparseTensor
                     col = edge_store.adj_t.crow_indices()
                     row = edge_store.adj_t.col_indices()
                     size = edge_store.adj_t.size()[::-1]
@@ -106,21 +104,18 @@ def __init__(self, pyg_data, format='wholegraph'):
                 self.put_adj_t(graph, edge_type=edge_type, size=size)
 
     def put_adj_t(self, adj_t: DistGraphCSC, *args, **kwargs) -> bool:
-        r"""Synchronously adds an :obj:`edge_index` tuple to the
-        :class:`GraphStore`.
+        """Add an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC`
+        to :class:`WholeGraphGraphStore`.
         Returns whether insertion was successful.
         Args:
-            edge_index (Tuple[torch.Tensor, torch.Tensor]): The
-                :obj:`edge_index` tuple in a format specified in
-                :class:`EdgeAttr`.
             *args: Arguments passed to :class:`EdgeAttr`.
             **kwargs: Keyword arguments passed to :class:`EdgeAttr`.
         """
         edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
         return self._put_adj_t(adj_t, edge_attr)
 
     def get_adj_t(self, *args, **kwargs) -> DistGraphCSC:
         """Retrieves an adj_t (adj with transpose) matrix, :obj:`DistGraphCSC`
         from :class:`WholeGraphGraphStore`.
         Return: :obj:`DistGraphCSC`
         """
         edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
         graph_adj_t = self._get_adj_t(edge_attr)
         if graph_adj_t is None:
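A hypothetical round trip through the adj_t API documented above — the store construction and the edge type are illustrative, and dist_csc stands for some existing DistGraphCSC instance; per the docstring, put_adj_t returns whether insertion succeeded:

graph_store = WholeGraphGraphStore(pyg_data)  # assumed constructor
ok = graph_store.put_adj_t(dist_csc,
                           edge_type=('paper', 'cites', 'paper'),
                           size=(1000, 1000))
adj_t = graph_store.get_adj_t(edge_type=('paper', 'cites', 'paper'))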
