From 728de8aa879d19603d4b2bfb3accc6de5d1bcce8 Mon Sep 17 00:00:00 2001 From: Castellana Date: Mon, 19 Aug 2024 18:33:26 +0200 Subject: [PATCH 01/17] Update negative sampling to work directly on GPU --- test/utils/test_negative_sampling.py | 101 ++++-- torch_geometric/utils/_negative_sampling.py | 358 ++++++++++++-------- 2 files changed, 290 insertions(+), 169 deletions(-) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index 709452fe60a5..6bd50302f38e 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -8,10 +8,13 @@ structured_negative_sampling, structured_negative_sampling_feasible, to_undirected, + erdos_renyi_graph, + stochastic_blockmodel_graph ) + from torch_geometric.utils._negative_sampling import ( - edge_index_to_vector, - vector_to_edge_index, + edge_index_to_vector_id, + vector_id_to_edge_index ) @@ -31,35 +34,19 @@ def is_negative(edge_index, neg_edge_index, size, bipartite): def test_edge_index_to_vector_and_vice_versa(): # Create a fully-connected graph: - N = 10 - row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1) - col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1) + N1, N2 = 13, 17 + row = torch.arange(N1).view(-1, 1).repeat(1, N2).view(-1) + col = torch.arange(N2).view(1, -1).repeat(N1, 1).view(-1) edge_index = torch.stack([row, col], dim=0) - idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True) - assert population == N * N - assert idx.tolist() == list(range(population)) - edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True) - assert is_undirected(edge_index2) + idx = edge_index_to_vector_id(edge_index, (N1, N2)) + assert idx.tolist() == list(range(N1*N2)) + edge_index2 = torch.stack(vector_id_to_edge_index(idx, (N1, N2)), dim=0) assert edge_index.tolist() == edge_index2.tolist() - idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False) - assert population == N * N - N - assert idx.tolist() == list(range(population)) - mask = edge_index[0] != edge_index[1] # Remove self-loops. - edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False) - assert is_undirected(edge_index2) - assert edge_index[:, mask].tolist() == edge_index2.tolist() - - idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False, - force_undirected=True) - assert population == (N * (N + 1)) / 2 - N - assert idx.tolist() == list(range(population)) - mask = edge_index[0] != edge_index[1] # Remove self-loops. - edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False, - force_undirected=True) - assert is_undirected(edge_index2) - assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist() + vector_id = torch.arange(N1*N2) + edge_index3 = torch.stack(vector_id_to_edge_index(vector_id, (N1, N2)), dim=0) + assert edge_index.tolist() == edge_index3.tolist() def test_negative_sampling(): @@ -69,10 +56,6 @@ def test_negative_sampling(): assert neg_edge_index.size(1) == edge_index.size(1) assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) - neg_edge_index = negative_sampling(edge_index, method='dense') - assert neg_edge_index.size(1) == edge_index.size(1) - assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) - neg_edge_index = negative_sampling(edge_index, num_neg_samples=2) assert neg_edge_index.size(1) == 2 assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) @@ -97,6 +80,26 @@ def test_bipartite_negative_sampling(): assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True) +def test_negative_sampling_with_different_edge_density(): + for num_nodes in [10, 100, 1000]: + for p in [0.1, 0.3, 0.5, 0.8]: + for is_directed in [False, True]: + edge_index = erdos_renyi_graph(num_nodes, p, is_directed) + neg_edge_index = negative_sampling(edge_index, num_nodes, force_undirected=not is_directed) + assert is_negative(edge_index, neg_edge_index, (num_nodes, num_nodes), bipartite=False) + + +def test_bipartite_negative_sampling_with_different_edge_density(): + for num_nodes in [10, 100, 1000]: + for p in [0.1, 0.3, 0.5, 0.8]: + size = (num_nodes, int(num_nodes*1.2)) + n_edges = int(p * size[0] * size[1]) + row, col = torch.randint(size[0], (n_edges,)), torch.randint(size[1], (n_edges,)) + edge_index = torch.stack([row, col], dim=0) + neg_edge_index = negative_sampling(edge_index, size) + assert is_negative(edge_index, neg_edge_index, size, bipartite=True) + + def test_batched_negative_sampling(): edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) edge_index = torch.cat([edge_index, edge_index + 4], dim=1) @@ -153,13 +156,47 @@ def test_structured_negative_sampling(): assert (adj & neg_adj).sum() == 0 # Test with no self-loops: - edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]]) + #edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]]) i, j, k = structured_negative_sampling(edge_index, num_nodes=4, contains_neg_self_loops=False) neg_edge_index = torch.vstack([i, k]) assert not contains_self_loops(neg_edge_index) +def test_structured_negative_sampling_sparse(): + num_nodes = 1000 + edge_index = erdos_renyi_graph(num_nodes, 0.1) + + i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, contains_neg_self_loops=True) + assert i.size(0) == edge_index.size(1) + assert j.size(0) == edge_index.size(1) + assert k.size(0) == edge_index.size(1) + + assert torch.all(torch.ne(k, -1)) + adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + adj[i, j] = 1 + + neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + neg_adj[i, k] = 1 + assert (adj & neg_adj).sum() == 0 + + i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, contains_neg_self_loops=False) + assert i.size(0) == edge_index.size(1) + assert j.size(0) == edge_index.size(1) + assert k.size(0) == edge_index.size(1) + + assert torch.all(torch.ne(k, -1)) + adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + adj[i, j] = 1 + + neg_adj = torch.zeros(num_nodes, num_nodes, dtype=torch.bool) + neg_adj[i, k] = 1 + assert (adj & neg_adj).sum() == 0 + + neg_edge_index = torch.vstack([i, k]) + assert not contains_self_loops(neg_edge_index) + + def test_structured_negative_sampling_feasible(): edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 3deb6bcec3d0..30a81a40994a 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -1,19 +1,23 @@ import random +import warnings from typing import Optional, Tuple, Union import numpy as np import torch from torch import Tensor -from torch_geometric.utils import coalesce, cumsum, degree, remove_self_loops +from torch_geometric.utils import coalesce, cumsum, degree, remove_self_loops, to_undirected, index_sort from torch_geometric.utils.num_nodes import maybe_num_nodes +_MAX_NUM_EDGES = 10**4 + + def negative_sampling( edge_index: Tensor, num_nodes: Optional[Union[int, Tuple[int, int]]] = None, num_neg_samples: Optional[int] = None, - method: str = "sparse", + method: str = "auto", force_undirected: bool = False, ) -> Tensor: r"""Samples random negative edges of a graph given by :attr:`edge_index`. @@ -30,11 +34,12 @@ def negative_sampling( If set to :obj:`None`, will try to return a negative edge for every positive edge. (default: :obj:`None`) method (str, optional): The method to use for negative sampling, - *i.e.* :obj:`"sparse"` or :obj:`"dense"`. + *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. This is a memory/runtime trade-off. - :obj:`"sparse"` will work on any graph of any size, while - :obj:`"dense"` can perform faster true-negative checks. - (default: :obj:`"sparse"`) + :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples + :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"auto"` will automatically choose the best method + (default: :obj:`"auto"`) force_undirected (bool, optional): If set to :obj:`True`, sampled negative edges will be undirected. (default: :obj:`False`) @@ -53,7 +58,7 @@ def negative_sampling( tensor([[0, 2, 2, 1], [2, 2, 1, 3]]) """ - assert method in ['sparse', 'dense'] + assert method in ['sparse', 'dense', 'auto'] if num_nodes is None: num_nodes = maybe_num_nodes(edge_index, num_nodes) @@ -66,52 +71,56 @@ def negative_sampling( bipartite = True force_undirected = False - idx, population = edge_index_to_vector(edge_index, size, bipartite, - force_undirected) - - if idx.numel() >= population: - return edge_index.new_empty((2, 0)) + num_edges = edge_index.size(1) + num_tot_edges = (size[0] * size[1]) if num_neg_samples is None: - num_neg_samples = edge_index.size(1) + num_neg_samples = num_edges + if force_undirected: num_neg_samples = num_neg_samples // 2 - prob = 1. - idx.numel() / population # Probability to sample a negative. - sample_size = int(1.1 * num_neg_samples / prob) # (Over)-sample size. - - neg_idx: Optional[Tensor] = None - if method == 'dense': - # The dense version creates a mask of shape `population` to check for - # invalid samples. - mask = idx.new_ones(population, dtype=torch.bool) - mask[idx] = False - for _ in range(3): # Number of tries to sample negative indices. - rnd = sample(population, sample_size, idx.device) - rnd = rnd[mask[rnd]] # Filter true negatives. - neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd]) - if neg_idx.numel() >= num_neg_samples: - neg_idx = neg_idx[:num_neg_samples] - break - mask[neg_idx] = False - - else: # 'sparse' - # The sparse version checks for invalid samples via `np.isin`. - idx = idx.to('cpu') - for _ in range(3): # Number of tries to sample negative indices. - rnd = sample(population, sample_size, device='cpu') - mask = np.isin(rnd.numpy(), idx.numpy()) # type: ignore - if neg_idx is not None: - mask |= np.isin(rnd, neg_idx.to('cpu')) - mask = torch.from_numpy(mask).to(torch.bool) - rnd = rnd[~mask].to(edge_index.device) - neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd]) - if neg_idx.numel() >= num_neg_samples: - neg_idx = neg_idx[:num_neg_samples] - break - - assert neg_idx is not None - return vector_to_edge_index(neg_idx, size, bipartite, force_undirected) + # transform a pair (u,v) in an edge id + edge_id = edge_index_to_vector_id(edge_index, size) + edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) # TODO: is this O(1) if the input is already sorted? + + method = get_method(method, size) + + k = None + prob = 1 - (num_edges / num_tot_edges) + if method == 'sparse': + if prob >= 0.3: + # the probability of sampling non-existing edge is high, so the sparse method should be ok + k = int(num_neg_samples / (prob - 0.1)) + else: + # the probability is too low, but the graph is too big for the exact sampling. + # we perform the sparse sampling but we raise a warning + k = int(min(10 * num_neg_samples, max(num_neg_samples / (prob - 0.1), 10 ** -3))) + + warnings.warn('The probability of sampling a negative edge is too low! ' + 'It could be that the number of sampled edges is smaller than the numbers you required!') + + guess_edge_index, guess_edge_id = sample_almost_k_edges(size, k, + force_undirected=force_undirected, + remove_self_loops=not bipartite, + method=method, device=edge_index.device) + + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) + + # we fiter the guessed id to maintain only the negative ones + neg_edge_index = guess_edge_index[:, neg_edge_mask] + + if neg_edge_index.shape[-1] > num_neg_samples: + neg_edge_index = neg_edge_index[:, :num_neg_samples] + + assert neg_edge_index is not None + + #print(f'{prob} - {method} - {k} - {num_neg_samples} - {neg_edge_index.shape[-1]}') + + if force_undirected: + neg_edge_index = to_undirected(neg_edge_index) + + return neg_edge_index def batched_negative_sampling( @@ -132,14 +141,15 @@ def batched_negative_sampling( If given as a tuple, then :obj:`edge_index` is interpreted as a bipartite graph connecting two different node types. num_neg_samples (int, optional): The number of negative samples to - return. If set to :obj:`None`, will try to return a negative edge + return for each graph in the batch. If set to :obj:`None`, will try to return a negative edge for every positive edge. (default: :obj:`None`) method (str, optional): The method to use for negative sampling, - *i.e.* :obj:`"sparse"` or :obj:`"dense"`. + *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. This is a memory/runtime trade-off. - :obj:`"sparse"` will work on any graph of any size, while - :obj:`"dense"` can perform faster true-negative checks. - (default: :obj:`"sparse"`) + :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples + :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"auto"` will automatically choose the best method + (default: :obj:`"auto"`) force_undirected (bool, optional): If set to :obj:`True`, sampled negative edges will be undirected. (default: :obj:`False`) @@ -210,6 +220,7 @@ def structured_negative_sampling( edge_index: Tensor, num_nodes: Optional[int] = None, contains_neg_self_loops: bool = True, + method: str = "auto" ) -> Tuple[Tensor, Tensor, Tensor]: r"""Samples a negative edge :obj:`(i,k)` for every positive edge :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a @@ -222,6 +233,13 @@ def structured_negative_sampling( contains_neg_self_loops (bool, optional): If set to :obj:`False`, sampled negative edges will not contain self loops. (default: :obj:`True`) + method (str, optional): The method to use for negative sampling, + *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. + This is a memory/runtime trade-off. + :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples + :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"auto"` will automatically choose the best method + (default: :obj:`"auto"`) :rtype: (LongTensor, LongTensor, LongTensor) @@ -233,27 +251,48 @@ def structured_negative_sampling( """ num_nodes = maybe_num_nodes(edge_index, num_nodes) + size = (num_nodes, num_nodes) + num_edges = edge_index.size(1) + num_tot_edges = size[0] * size[1] + device = edge_index.device + deg = degree(edge_index[0], num_nodes, dtype=torch.long) + + # transform a pair (u,v) in an edge id + edge_id = edge_index_to_vector_id(edge_index, size) + edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) # TODO: is this O(1) if the input is already sorted? + + # select the method + method = get_method(method, size) + + k = None + prob = torch.min(1 - deg/num_nodes) + if method == 'sparse': + k = 10 + # TODO: can we compute the prob to find the k? + + + guess_col, guess_edge_id = sample_k_structured_edges(edge_index, num_nodes, k, + not contains_neg_self_loops, method, device) + + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) + if not torch.all(torch.any(neg_edge_mask.view(-1,k), dim=1)): + warnings.warn('We were not able to sample all negative edges requested!') + + if method == 'sparse': + neg_col = -torch.ones_like(edge_index[0]) + neg_edge_mask = get_first_k_true_values_for_each_row(neg_edge_mask.view(num_edges, k), 1) + ok_edges = torch.any(neg_edge_mask, dim=1) # this is the mask of edges for which we have obtained a neg sample + col_to_save = guess_col.view(num_edges, k)[neg_edge_mask] # this the list of neg samples + neg_col[ok_edges] = col_to_save.view(-1) + else: + shape = (num_nodes, num_nodes if contains_neg_self_loops else num_nodes-1) + neg_edge_mask = get_first_k_true_values_for_each_row(neg_edge_mask.view(*shape), deg) + col_to_save = guess_col.view(*shape)[neg_edge_mask] # this the list of neg samples + neg_col = col_to_save.view(-1) - row, col = edge_index.cpu() - pos_idx = row * num_nodes + col - if not contains_neg_self_loops: - loop_idx = torch.arange(num_nodes) * (num_nodes + 1) - pos_idx = torch.cat([pos_idx, loop_idx], dim=0) - - rand = torch.randint(num_nodes, (row.size(0), ), dtype=torch.long) - neg_idx = row * num_nodes + rand - - mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool) - rest = mask.nonzero(as_tuple=False).view(-1) - while rest.numel() > 0: # pragma: no cover - tmp = torch.randint(num_nodes, (rest.size(0), ), dtype=torch.long) - rand[rest] = tmp - neg_idx = row[rest] * num_nodes + tmp - - mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool) - rest = rest[mask] + assert neg_col.size(0) == edge_index.size(1) - return edge_index[0], edge_index[1], rand.to(edge_index.device) + return edge_index[0], edge_index[1], neg_col def structured_negative_sampling_feasible( @@ -296,92 +335,137 @@ def structured_negative_sampling_feasible( max_num_neighbors -= 1 # Reduce number of valid neighbors deg = degree(edge_index[0], num_nodes) - # True if there exists no node that is connected to all other nodes. - return bool(torch.all(deg < max_num_neighbors)) + # structured sample is feasible if, for each node, deg > max_neigh/2 + return bool(torch.all(2*deg <= max_num_neighbors)) ############################################################################### +def get_method(method, size): + # select the method + tot_num_edges = size[0] * size[1] + auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small + method = auto_method if method == 'auto' else method -def sample( - population: int, + if method == 'dense' and tot_num_edges >= _MAX_NUM_EDGES: + warnings.warn(f'You choose the dense method on a graph with {tot_num_edges} possible edges! ' + f'It could require a lot of memory!') + + return method + + +def sample_almost_k_edges( + size: Tuple[int, int], k: int, + force_undirected: bool, + remove_self_loops: bool, + method: str, device: Optional[Union[torch.device, str]] = None, -) -> Tensor: - if population <= k: - return torch.arange(population, device=device) - else: - return torch.tensor(random.sample(range(population), k), device=device) +) -> Tuple[Tensor, Tensor]: + assert method in ['sparse', 'dense'] + N1, N2 = size -def edge_index_to_vector( - edge_index: Tensor, - size: Tuple[int, int], - bipartite: bool, - force_undirected: bool = False, -) -> Tuple[Tensor, int]: + if method == 'sparse': + k = 2*k if force_undirected else k + p = torch.tensor([1.], device=device).expand(N1*N2) + if k > N1 * N2: + k = N1*N2 + new_edge_id = torch.multinomial(p, k, replacement=False) - row, col = edge_index + else: + new_edge_id = torch.randperm(N1*N2, device=device) - if bipartite: # No need to account for self-loops. - idx = (row * size[1]).add_(col) - population = size[0] * size[1] - return idx, population + new_edge_index = torch.stack(vector_id_to_edge_index(new_edge_id, size), dim=0) - elif force_undirected: - assert size[0] == size[1] - num_nodes = size[0] + if remove_self_loops: + not_in_diagonal = new_edge_index[0] != new_edge_index[1] + new_edge_index = new_edge_index[:, not_in_diagonal] + new_edge_id = new_edge_id[not_in_diagonal] - # We only operate on the upper triangular matrix: - mask = row < col - row, col = row[mask], col[mask] - offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row] - idx = row.mul_(num_nodes).add_(col).sub_(offset) - population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes - return idx, population + if force_undirected: + # we consider only the upper part, i.e. col_idx > row_idx + in_upper_part = new_edge_index[1] > new_edge_index[0] + new_edge_index = new_edge_index[:, in_upper_part] + new_edge_id = new_edge_id[in_upper_part] + + return new_edge_index, new_edge_id - else: - assert size[0] == size[1] - num_nodes = size[0] - # We remove self-loops as we do not want to take them into account - # when sampling negative values. - mask = row != col - row, col = row[mask], col[mask] - col[row < col] -= 1 - idx = row.mul_(num_nodes - 1).add_(col) - population = num_nodes * num_nodes - num_nodes - return idx, population +def sample_k_structured_edges( + edge_index: Tensor, + num_nodes: int, + k: int, + remove_self_loops: bool, + method: str, + device: Optional[Union[torch.device, str]] = None, +) -> Tuple[Tensor, Tensor]: + assert method in ['sparse', 'dense'] + row, col = edge_index + num_edges = edge_index.size(1) + size = (num_nodes, num_nodes) -def vector_to_edge_index( - idx: Tensor, + if method == 'sparse': + new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint(num_nodes, (num_edges, k)) + new_edge_id = edge_index_to_vector_id((new_row, new_col), size) + else: + new_edge_id = torch.randperm(num_nodes*num_nodes, device=device) + new_row, new_col = vector_id_to_edge_index(new_edge_id, size) + new_row, idx_to_sort = index_sort(new_row) + new_edge_id = new_edge_id[idx_to_sort] + new_col = new_col[idx_to_sort] + + if remove_self_loops: + # instead of filterin (which break the shapes), we just add +1 + in_diagonal = torch.eq(new_row, new_col) + if method == 'dense': + not_in_diagonal = torch.logical_not(in_diagonal) + new_col = new_col[not_in_diagonal] + new_edge_id = new_edge_id[not_in_diagonal] + else: + new_v = (new_col[in_diagonal] + 1) % num_nodes + diff_v = new_v - new_col[in_diagonal] + new_col[in_diagonal] = new_v + new_edge_id[in_diagonal] += diff_v + + return new_col.view(-1), new_edge_id.view(-1) + + +def get_first_k_true_values_for_each_row(input_mask, k): + if isinstance(k, Tensor): + k = k.unsqueeze(1) + cum_mask = torch.le(torch.cumsum(input_mask, dim=1), k) + return torch.logical_and(input_mask, cum_mask) + + +def get_neg_edge_mask(edge_id, guess_edge_id): + num_edges = edge_id.size(0) + pos = torch.searchsorted(edge_id, guess_edge_id) + # pos contains the position where to insert the guessed id to maintain the edge_id sort. + # 1) if pos == num_edges, it means that we should add the guessed if at the end of the vector -> the id is new! + # 2) if pos != num_edges but the id in position pos != from the guessed one -> the id is new! + neg_edge_mask = torch.eq(pos, num_edges) # negative edge from case 1) + not_neg_edge_mask = torch.logical_not(neg_edge_mask) + # negative edge from case 2) + neg_edge_mask[not_neg_edge_mask] = edge_id[pos[not_neg_edge_mask]] != guess_edge_id[not_neg_edge_mask] + return neg_edge_mask + + +def edge_index_to_vector_id( + edge_index: Tensor | Tuple[Tensor,Tensor], size: Tuple[int, int], - bipartite: bool, - force_undirected: bool = False, ) -> Tensor: - if bipartite: # No need to account for self-loops. - row = idx.div(size[1], rounding_mode='floor') - col = idx % size[1] - return torch.stack([row, col], dim=0) + row, col = edge_index + return (row * size[1]).add_(col) - elif force_undirected: - assert size[0] == size[1] - num_nodes = size[0] - offset = torch.arange(1, num_nodes, device=idx.device).cumsum(0) - end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes, - device=idx.device) - row = torch.bucketize(idx, end.sub_(offset), right=True) - col = offset[row].add_(idx) % num_nodes - return torch.stack([torch.cat([row, col]), torch.cat([col, row])], 0) +def vector_id_to_edge_index( + vector_id: Tensor, + size: Tuple[int, int], +) -> Tuple[Tensor,Tensor]: - else: - assert size[0] == size[1] - num_nodes = size[0] + row, col = vector_id // size[1], vector_id % size[1] + return row, col - row = idx.div(num_nodes - 1, rounding_mode='floor') - col = idx % (num_nodes - 1) - col[row <= col] += 1 - return torch.stack([row, col], dim=0) From 7ba4732a29bdecac3a0b52a09546b5197df5a2ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:39:04 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/utils/test_negative_sampling.py | 33 +++--- torch_geometric/utils/_negative_sampling.py | 111 ++++++++++++-------- 2 files changed, 85 insertions(+), 59 deletions(-) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index 6bd50302f38e..1550def876f1 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -3,18 +3,16 @@ from torch_geometric.utils import ( batched_negative_sampling, contains_self_loops, + erdos_renyi_graph, is_undirected, negative_sampling, structured_negative_sampling, structured_negative_sampling_feasible, to_undirected, - erdos_renyi_graph, - stochastic_blockmodel_graph ) - from torch_geometric.utils._negative_sampling import ( edge_index_to_vector_id, - vector_id_to_edge_index + vector_id_to_edge_index, ) @@ -40,12 +38,13 @@ def test_edge_index_to_vector_and_vice_versa(): edge_index = torch.stack([row, col], dim=0) idx = edge_index_to_vector_id(edge_index, (N1, N2)) - assert idx.tolist() == list(range(N1*N2)) + assert idx.tolist() == list(range(N1 * N2)) edge_index2 = torch.stack(vector_id_to_edge_index(idx, (N1, N2)), dim=0) assert edge_index.tolist() == edge_index2.tolist() - vector_id = torch.arange(N1*N2) - edge_index3 = torch.stack(vector_id_to_edge_index(vector_id, (N1, N2)), dim=0) + vector_id = torch.arange(N1 * N2) + edge_index3 = torch.stack(vector_id_to_edge_index(vector_id, (N1, N2)), + dim=0) assert edge_index.tolist() == edge_index3.tolist() @@ -85,19 +84,23 @@ def test_negative_sampling_with_different_edge_density(): for p in [0.1, 0.3, 0.5, 0.8]: for is_directed in [False, True]: edge_index = erdos_renyi_graph(num_nodes, p, is_directed) - neg_edge_index = negative_sampling(edge_index, num_nodes, force_undirected=not is_directed) - assert is_negative(edge_index, neg_edge_index, (num_nodes, num_nodes), bipartite=False) + neg_edge_index = negative_sampling( + edge_index, num_nodes, force_undirected=not is_directed) + assert is_negative(edge_index, neg_edge_index, + (num_nodes, num_nodes), bipartite=False) def test_bipartite_negative_sampling_with_different_edge_density(): for num_nodes in [10, 100, 1000]: for p in [0.1, 0.3, 0.5, 0.8]: - size = (num_nodes, int(num_nodes*1.2)) + size = (num_nodes, int(num_nodes * 1.2)) n_edges = int(p * size[0] * size[1]) - row, col = torch.randint(size[0], (n_edges,)), torch.randint(size[1], (n_edges,)) + row, col = torch.randint(size[0], (n_edges, )), torch.randint( + size[1], (n_edges, )) edge_index = torch.stack([row, col], dim=0) neg_edge_index = negative_sampling(edge_index, size) - assert is_negative(edge_index, neg_edge_index, size, bipartite=True) + assert is_negative(edge_index, neg_edge_index, size, + bipartite=True) def test_batched_negative_sampling(): @@ -167,7 +170,8 @@ def test_structured_negative_sampling_sparse(): num_nodes = 1000 edge_index = erdos_renyi_graph(num_nodes, 0.1) - i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, contains_neg_self_loops=True) + i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, + contains_neg_self_loops=True) assert i.size(0) == edge_index.size(1) assert j.size(0) == edge_index.size(1) assert k.size(0) == edge_index.size(1) @@ -180,7 +184,8 @@ def test_structured_negative_sampling_sparse(): neg_adj[i, k] = 1 assert (adj & neg_adj).sum() == 0 - i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, contains_neg_self_loops=False) + i, j, k = structured_negative_sampling(edge_index, num_nodes=num_nodes, + contains_neg_self_loops=False) assert i.size(0) == edge_index.size(1) assert j.size(0) == edge_index.size(1) assert k.size(0) == edge_index.size(1) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 30a81a40994a..774e22f1d89b 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -1,15 +1,19 @@ -import random import warnings from typing import Optional, Tuple, Union -import numpy as np import torch from torch import Tensor -from torch_geometric.utils import coalesce, cumsum, degree, remove_self_loops, to_undirected, index_sort +from torch_geometric.utils import ( + coalesce, + cumsum, + degree, + index_sort, + remove_self_loops, + to_undirected, +) from torch_geometric.utils.num_nodes import maybe_num_nodes - _MAX_NUM_EDGES = 10**4 @@ -82,7 +86,9 @@ def negative_sampling( # transform a pair (u,v) in an edge id edge_id = edge_index_to_vector_id(edge_index, size) - edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) # TODO: is this O(1) if the input is already sorted? + edge_id, _ = index_sort( + edge_id, max_value=num_tot_edges + ) # TODO: is this O(1) if the input is already sorted? method = get_method(method, size) @@ -95,15 +101,19 @@ def negative_sampling( else: # the probability is too low, but the graph is too big for the exact sampling. # we perform the sparse sampling but we raise a warning - k = int(min(10 * num_neg_samples, max(num_neg_samples / (prob - 0.1), 10 ** -3))) + k = int( + min(10 * num_neg_samples, + max(num_neg_samples / (prob - 0.1), 10**-3))) - warnings.warn('The probability of sampling a negative edge is too low! ' - 'It could be that the number of sampled edges is smaller than the numbers you required!') + warnings.warn( + 'The probability of sampling a negative edge is too low! ' + 'It could be that the number of sampled edges is smaller than the numbers you required!' + ) - guess_edge_index, guess_edge_id = sample_almost_k_edges(size, k, - force_undirected=force_undirected, - remove_self_loops=not bipartite, - method=method, device=edge_index.device) + guess_edge_index, guess_edge_id = sample_almost_k_edges( + size, k, force_undirected=force_undirected, + remove_self_loops=not bipartite, method=method, + device=edge_index.device) neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) @@ -217,11 +227,9 @@ def batched_negative_sampling( def structured_negative_sampling( - edge_index: Tensor, - num_nodes: Optional[int] = None, - contains_neg_self_loops: bool = True, - method: str = "auto" -) -> Tuple[Tensor, Tensor, Tensor]: + edge_index: Tensor, num_nodes: Optional[int] = None, + contains_neg_self_loops: bool = True, + method: str = "auto") -> Tuple[Tensor, Tensor, Tensor]: r"""Samples a negative edge :obj:`(i,k)` for every positive edge :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a tuple of the form :obj:`(i,j,k)`. @@ -259,35 +267,44 @@ def structured_negative_sampling( # transform a pair (u,v) in an edge id edge_id = edge_index_to_vector_id(edge_index, size) - edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) # TODO: is this O(1) if the input is already sorted? + edge_id, _ = index_sort( + edge_id, max_value=num_tot_edges + ) # TODO: is this O(1) if the input is already sorted? # select the method method = get_method(method, size) k = None - prob = torch.min(1 - deg/num_nodes) + torch.min(1 - deg / num_nodes) if method == 'sparse': k = 10 # TODO: can we compute the prob to find the k? - - guess_col, guess_edge_id = sample_k_structured_edges(edge_index, num_nodes, k, - not contains_neg_self_loops, method, device) + guess_col, guess_edge_id = sample_k_structured_edges( + edge_index, num_nodes, k, not contains_neg_self_loops, method, device) neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - if not torch.all(torch.any(neg_edge_mask.view(-1,k), dim=1)): - warnings.warn('We were not able to sample all negative edges requested!') + if not torch.all(torch.any(neg_edge_mask.view(-1, k), dim=1)): + warnings.warn( + 'We were not able to sample all negative edges requested!') if method == 'sparse': neg_col = -torch.ones_like(edge_index[0]) - neg_edge_mask = get_first_k_true_values_for_each_row(neg_edge_mask.view(num_edges, k), 1) - ok_edges = torch.any(neg_edge_mask, dim=1) # this is the mask of edges for which we have obtained a neg sample - col_to_save = guess_col.view(num_edges, k)[neg_edge_mask] # this the list of neg samples + neg_edge_mask = get_first_k_true_values_for_each_row( + neg_edge_mask.view(num_edges, k), 1) + ok_edges = torch.any( + neg_edge_mask, dim=1 + ) # this is the mask of edges for which we have obtained a neg sample + col_to_save = guess_col.view( + num_edges, k)[neg_edge_mask] # this the list of neg samples neg_col[ok_edges] = col_to_save.view(-1) else: - shape = (num_nodes, num_nodes if contains_neg_self_loops else num_nodes-1) - neg_edge_mask = get_first_k_true_values_for_each_row(neg_edge_mask.view(*shape), deg) - col_to_save = guess_col.view(*shape)[neg_edge_mask] # this the list of neg samples + shape = (num_nodes, + num_nodes if contains_neg_self_loops else num_nodes - 1) + neg_edge_mask = get_first_k_true_values_for_each_row( + neg_edge_mask.view(*shape), deg) + col_to_save = guess_col.view( + *shape)[neg_edge_mask] # this the list of neg samples neg_col = col_to_save.view(-1) assert neg_col.size(0) == edge_index.size(1) @@ -336,20 +353,22 @@ def structured_negative_sampling_feasible( deg = degree(edge_index[0], num_nodes) # structured sample is feasible if, for each node, deg > max_neigh/2 - return bool(torch.all(2*deg <= max_num_neighbors)) + return bool(torch.all(2 * deg <= max_num_neighbors)) ############################################################################### + def get_method(method, size): # select the method tot_num_edges = size[0] * size[1] - auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small + auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small method = auto_method if method == 'auto' else method if method == 'dense' and tot_num_edges >= _MAX_NUM_EDGES: - warnings.warn(f'You choose the dense method on a graph with {tot_num_edges} possible edges! ' - f'It could require a lot of memory!') + warnings.warn( + f'You choose the dense method on a graph with {tot_num_edges} possible edges! ' + f'It could require a lot of memory!') return method @@ -367,16 +386,17 @@ def sample_almost_k_edges( N1, N2 = size if method == 'sparse': - k = 2*k if force_undirected else k - p = torch.tensor([1.], device=device).expand(N1*N2) + k = 2 * k if force_undirected else k + p = torch.tensor([1.], device=device).expand(N1 * N2) if k > N1 * N2: - k = N1*N2 + k = N1 * N2 new_edge_id = torch.multinomial(p, k, replacement=False) else: - new_edge_id = torch.randperm(N1*N2, device=device) + new_edge_id = torch.randperm(N1 * N2, device=device) - new_edge_index = torch.stack(vector_id_to_edge_index(new_edge_id, size), dim=0) + new_edge_index = torch.stack(vector_id_to_edge_index(new_edge_id, size), + dim=0) if remove_self_loops: not_in_diagonal = new_edge_index[0] != new_edge_index[1] @@ -407,10 +427,11 @@ def sample_k_structured_edges( size = (num_nodes, num_nodes) if method == 'sparse': - new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint(num_nodes, (num_edges, k)) + new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint( + num_nodes, (num_edges, k)) new_edge_id = edge_index_to_vector_id((new_row, new_col), size) else: - new_edge_id = torch.randperm(num_nodes*num_nodes, device=device) + new_edge_id = torch.randperm(num_nodes * num_nodes, device=device) new_row, new_col = vector_id_to_edge_index(new_edge_id, size) new_row, idx_to_sort = index_sort(new_row) new_edge_id = new_edge_id[idx_to_sort] @@ -448,12 +469,13 @@ def get_neg_edge_mask(edge_id, guess_edge_id): neg_edge_mask = torch.eq(pos, num_edges) # negative edge from case 1) not_neg_edge_mask = torch.logical_not(neg_edge_mask) # negative edge from case 2) - neg_edge_mask[not_neg_edge_mask] = edge_id[pos[not_neg_edge_mask]] != guess_edge_id[not_neg_edge_mask] + neg_edge_mask[not_neg_edge_mask] = edge_id[ + pos[not_neg_edge_mask]] != guess_edge_id[not_neg_edge_mask] return neg_edge_mask def edge_index_to_vector_id( - edge_index: Tensor | Tuple[Tensor,Tensor], + edge_index: Tensor | Tuple[Tensor, Tensor], size: Tuple[int, int], ) -> Tensor: @@ -464,8 +486,7 @@ def edge_index_to_vector_id( def vector_id_to_edge_index( vector_id: Tensor, size: Tuple[int, int], -) -> Tuple[Tensor,Tensor]: +) -> Tuple[Tensor, Tensor]: row, col = vector_id // size[1], vector_id % size[1] return row, col - From c5c388979ae7ce064798b9f01654f1ee37982e76 Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 22 Aug 2024 18:17:12 +0200 Subject: [PATCH 03/17] Add compatibility to Python 3.8 --- torch_geometric/utils/_negative_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 774e22f1d89b..d9ef93d53dc8 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -475,7 +475,7 @@ def get_neg_edge_mask(edge_id, guess_edge_id): def edge_index_to_vector_id( - edge_index: Tensor | Tuple[Tensor, Tensor], + edge_index: Union[Tensor, Tuple[Tensor, Tensor]], size: Tuple[int, int], ) -> Tensor: From c6db6d79bdf2acc7ab6e93fc8fa7d72f6a1795b3 Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 22 Aug 2024 18:59:37 +0200 Subject: [PATCH 04/17] Correct structured_edge_sampling feasibility test --- test/utils/test_negative_sampling.py | 27 +++++++++++++++++---- torch_geometric/utils/_negative_sampling.py | 12 +++++---- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index 1550def876f1..e8f646af7751 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -203,8 +203,25 @@ def test_structured_negative_sampling_sparse(): def test_structured_negative_sampling_feasible(): - edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2], - [1, 2, 0, 2, 0, 1, 1]]) - assert not structured_negative_sampling_feasible(edge_index, 3, False) - assert structured_negative_sampling_feasible(edge_index, 3, True) - assert structured_negative_sampling_feasible(edge_index, 4, False) + + def create_ring_graph(num_nodes): + forward_edges = torch.stack([torch.arange(0, num_nodes, dtype=torch.long), + (torch.arange(0, num_nodes, dtype=torch.long) + 1) % num_nodes], dim=0) + backward_edges = torch.stack([torch.arange(0, num_nodes, dtype=torch.long), + (torch.arange(0, num_nodes, dtype=torch.long) - 1) % num_nodes], dim=0) + return torch.concat([forward_edges, backward_edges], dim=1) + + # ring 3 is always unfeasible + ring_3 = create_ring_graph(3) + assert not structured_negative_sampling_feasible(ring_3, 3, False) + assert not structured_negative_sampling_feasible(ring_3, 3, True) + + # ring 4 is feasible only if we consider self loops + ring_4 = create_ring_graph(4) + assert not structured_negative_sampling_feasible(ring_4, 4, False) + assert structured_negative_sampling_feasible(ring_4, 4, True) + + # ring 5 is always feasible + ring_5 = create_ring_graph(5) + assert structured_negative_sampling_feasible(ring_5, 5, False) + assert structured_negative_sampling_feasible(ring_5, 5, True) \ No newline at end of file diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index d9ef93d53dc8..5a24890af1fc 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -88,7 +88,7 @@ def negative_sampling( edge_id = edge_index_to_vector_id(edge_index, size) edge_id, _ = index_sort( edge_id, max_value=num_tot_edges - ) # TODO: is this O(1) if the input is already sorted? + ) method = get_method(method, size) @@ -269,7 +269,7 @@ def structured_negative_sampling( edge_id = edge_index_to_vector_id(edge_index, size) edge_id, _ = index_sort( edge_id, max_value=num_tot_edges - ) # TODO: is this O(1) if the input is already sorted? + ) # select the method method = get_method(method, size) @@ -284,9 +284,6 @@ def structured_negative_sampling( edge_index, num_nodes, k, not contains_neg_self_loops, method, device) neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - if not torch.all(torch.any(neg_edge_mask.view(-1, k), dim=1)): - warnings.warn( - 'We were not able to sample all negative edges requested!') if method == 'sparse': neg_col = -torch.ones_like(edge_index[0]) @@ -298,6 +295,11 @@ def structured_negative_sampling( col_to_save = guess_col.view( num_edges, k)[neg_edge_mask] # this the list of neg samples neg_col[ok_edges] = col_to_save.view(-1) + + if not torch.all(ok_edges): + warnings.warn( + 'We were not able to sample all negative edges requested!') + else: shape = (num_nodes, num_nodes if contains_neg_self_loops else num_nodes - 1) From 4d764fadabdfba8d9c0ee22022eb59d0cc078d2c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:04:24 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/utils/test_negative_sampling.py | 15 +++++++++------ torch_geometric/utils/_negative_sampling.py | 8 ++------ 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index e8f646af7751..1dac1fb78565 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -203,12 +203,15 @@ def test_structured_negative_sampling_sparse(): def test_structured_negative_sampling_feasible(): - def create_ring_graph(num_nodes): - forward_edges = torch.stack([torch.arange(0, num_nodes, dtype=torch.long), - (torch.arange(0, num_nodes, dtype=torch.long) + 1) % num_nodes], dim=0) - backward_edges = torch.stack([torch.arange(0, num_nodes, dtype=torch.long), - (torch.arange(0, num_nodes, dtype=torch.long) - 1) % num_nodes], dim=0) + forward_edges = torch.stack([ + torch.arange(0, num_nodes, dtype=torch.long), + (torch.arange(0, num_nodes, dtype=torch.long) + 1) % num_nodes + ], dim=0) + backward_edges = torch.stack([ + torch.arange(0, num_nodes, dtype=torch.long), + (torch.arange(0, num_nodes, dtype=torch.long) - 1) % num_nodes + ], dim=0) return torch.concat([forward_edges, backward_edges], dim=1) # ring 3 is always unfeasible @@ -224,4 +227,4 @@ def create_ring_graph(num_nodes): # ring 5 is always feasible ring_5 = create_ring_graph(5) assert structured_negative_sampling_feasible(ring_5, 5, False) - assert structured_negative_sampling_feasible(ring_5, 5, True) \ No newline at end of file + assert structured_negative_sampling_feasible(ring_5, 5, True) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 5a24890af1fc..0b7be87c4647 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -86,9 +86,7 @@ def negative_sampling( # transform a pair (u,v) in an edge id edge_id = edge_index_to_vector_id(edge_index, size) - edge_id, _ = index_sort( - edge_id, max_value=num_tot_edges - ) + edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) method = get_method(method, size) @@ -267,9 +265,7 @@ def structured_negative_sampling( # transform a pair (u,v) in an edge id edge_id = edge_index_to_vector_id(edge_index, size) - edge_id, _ = index_sort( - edge_id, max_value=num_tot_edges - ) + edge_id, _ = index_sort(edge_id, max_value=num_tot_edges) # select the method method = get_method(method, size) From 10d440c555bd23c58a3dd41c894486f8aa88ddcc Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 09:42:53 +0200 Subject: [PATCH 06/17] Add all type annotations --- torch_geometric/utils/_negative_sampling.py | 29 +++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index a9a79e8bfcc6..2fc02c0cb9d7 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -270,18 +270,13 @@ def structured_negative_sampling( # select the method method = get_method(method, size) - k = None - torch.min(1 - deg / num_nodes) if method == 'sparse': - k = 10 # TODO: can we compute the prob to find the k? + k = 10 + guess_col, guess_edge_id = sample_k_structured_edges( + edge_index, num_nodes, k, not contains_neg_self_loops, method, device) + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - guess_col, guess_edge_id = sample_k_structured_edges( - edge_index, num_nodes, k, not contains_neg_self_loops, method, device) - - neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - - if method == 'sparse': neg_col = -torch.ones_like(edge_index[0]) neg_edge_mask = get_first_k_true_values_for_each_row( neg_edge_mask.view(num_edges, k), 1) @@ -297,6 +292,10 @@ def structured_negative_sampling( 'We were not able to sample all negative edges requested!') else: + guess_col, guess_edge_id = sample_k_structured_edges( + edge_index, num_nodes, None, not contains_neg_self_loops, method, device) + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) + shape = (num_nodes, num_nodes if contains_neg_self_loops else num_nodes - 1) neg_edge_mask = get_first_k_true_values_for_each_row( @@ -357,7 +356,7 @@ def structured_negative_sampling_feasible( ############################################################################### -def get_method(method, size): +def get_method(method: str, size: Tuple[int, int]) -> str: # select the method tot_num_edges = size[0] * size[1] auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small @@ -373,7 +372,7 @@ def get_method(method, size): def sample_almost_k_edges( size: Tuple[int, int], - k: int, + k: Optional[int], force_undirected: bool, remove_self_loops: bool, method: str, @@ -384,6 +383,7 @@ def sample_almost_k_edges( N1, N2 = size if method == 'sparse': + assert k is not None k = 2 * k if force_undirected else k p = torch.tensor([1.], device=device).expand(N1 * N2) if k > N1 * N2: @@ -413,7 +413,7 @@ def sample_almost_k_edges( def sample_k_structured_edges( edge_index: Tensor, num_nodes: int, - k: int, + k: Optional[int], remove_self_loops: bool, method: str, device: Optional[Union[torch.device, str]] = None, @@ -425,6 +425,7 @@ def sample_k_structured_edges( size = (num_nodes, num_nodes) if method == 'sparse': + assert k is not None new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint( num_nodes, (num_edges, k)) new_edge_id = edge_index_to_vector_id((new_row, new_col), size) @@ -451,14 +452,14 @@ def sample_k_structured_edges( return new_col.view(-1), new_edge_id.view(-1) -def get_first_k_true_values_for_each_row(input_mask, k): +def get_first_k_true_values_for_each_row(input_mask: Tensor, k: Union[Tensor, int]) -> Tensor: if isinstance(k, Tensor): k = k.unsqueeze(1) cum_mask = torch.le(torch.cumsum(input_mask, dim=1), k) return torch.logical_and(input_mask, cum_mask) -def get_neg_edge_mask(edge_id, guess_edge_id): +def get_neg_edge_mask(edge_id: Tensor, guess_edge_id: Tensor) -> Tensor: num_edges = edge_id.size(0) pos = torch.searchsorted(edge_id, guess_edge_id) # pos contains the position where to insert the guessed id to maintain the edge_id sort. From 0926a61de708abc6731bc89eef91af3d23bb48ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 09:10:09 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/utils/_negative_sampling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 2fc02c0cb9d7..1e33fee7c41c 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -274,7 +274,8 @@ def structured_negative_sampling( # TODO: can we compute the prob to find the k? k = 10 guess_col, guess_edge_id = sample_k_structured_edges( - edge_index, num_nodes, k, not contains_neg_self_loops, method, device) + edge_index, num_nodes, k, not contains_neg_self_loops, method, + device) neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) neg_col = -torch.ones_like(edge_index[0]) @@ -293,7 +294,8 @@ def structured_negative_sampling( else: guess_col, guess_edge_id = sample_k_structured_edges( - edge_index, num_nodes, None, not contains_neg_self_loops, method, device) + edge_index, num_nodes, None, not contains_neg_self_loops, method, + device) neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) shape = (num_nodes, @@ -452,7 +454,8 @@ def sample_k_structured_edges( return new_col.view(-1), new_edge_id.view(-1) -def get_first_k_true_values_for_each_row(input_mask: Tensor, k: Union[Tensor, int]) -> Tensor: +def get_first_k_true_values_for_each_row(input_mask: Tensor, + k: Union[Tensor, int]) -> Tensor: if isinstance(k, Tensor): k = k.unsqueeze(1) cum_mask = torch.le(torch.cumsum(input_mask, dim=1), k) From f919b2db30e445c79efd9fc2e1df1f7e5ba657e5 Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 12:27:22 +0200 Subject: [PATCH 08/17] Add check for structured sampling feasibility. If the method fails to retrieve the requested number of edges, it raises an exception. --- torch_geometric/utils/_negative_sampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 1e33fee7c41c..e08d6d6a88ff 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -256,6 +256,9 @@ def structured_negative_sampling( (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2])) """ + if not structured_negative_sampling_feasible(edge_index, num_nodes, contains_neg_self_loops): + raise ValueError('Structured sampling is not feasible!') + num_nodes = maybe_num_nodes(edge_index, num_nodes) size = (num_nodes, num_nodes) num_edges = edge_index.size(1) @@ -289,8 +292,7 @@ def structured_negative_sampling( neg_col[ok_edges] = col_to_save.view(-1) if not torch.all(ok_edges): - warnings.warn( - 'We were not able to sample all negative edges requested!') + raise ValueError('Sparse method was not able to sample all negative edges requested!') else: guess_col, guess_edge_id = sample_k_structured_edges( @@ -344,8 +346,6 @@ def structured_negative_sampling_feasible( num_nodes = maybe_num_nodes(edge_index, num_nodes) max_num_neighbors = num_nodes - edge_index = coalesce(edge_index, num_nodes=num_nodes) - if not contains_neg_self_loops: edge_index, _ = remove_self_loops(edge_index) max_num_neighbors -= 1 # Reduce number of valid neighbors From 5e99af8e2349a162d6951619af3ca68afb2ab7c8 Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 12:28:04 +0200 Subject: [PATCH 09/17] Change the deafult method of negative_sampling to "auto" in RandomLinkSplit transform. --- torch_geometric/transforms/random_link_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/random_link_split.py b/torch_geometric/transforms/random_link_split.py index 20c64fb6997f..44884f915f92 100644 --- a/torch_geometric/transforms/random_link_split.py +++ b/torch_geometric/transforms/random_link_split.py @@ -236,7 +236,7 @@ def forward( size = size[0] neg_edge_index = negative_sampling(edge_index, size, num_neg_samples=num_neg, - method='sparse') + method='auto') # Adjust ratio if not enough negative edges exist if neg_edge_index.size(1) < num_neg: From afe67e2f17630ec90c1e9345be21de824f426eab Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 12:29:24 +0200 Subject: [PATCH 10/17] test_add_random_edge was based on fixing the seed. Now it checks that added edges are actually new. --- test/utils/test_augmentation.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/test/utils/test_augmentation.py b/test/utils/test_augmentation.py index 68c4e343cdc6..687a560d0e8f 100644 --- a/test/utils/test_augmentation.py +++ b/test/utils/test_augmentation.py @@ -78,26 +78,20 @@ def test_add_random_edge(): assert out[0].tolist() == edge_index.tolist() assert out[1].tolist() == [[], []] - seed_everything(5) + def _edge_idx_to_set(e: torch.Tensor) -> set: + return set([tuple(v) for v in e.tolist()]) + out = add_random_edge(edge_index, p=0.5) - assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2], - [1, 0, 2, 1, 3, 2, 0, 3, 0]] - assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]] + assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1])) - seed_everything(6) out = add_random_edge(edge_index, p=0.5, force_undirected=True) - assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3], - [1, 0, 2, 1, 3, 2, 3, 1]] - assert out[1].tolist() == [[1, 3], [3, 1]] + assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1])) assert is_undirected(out[0]) assert is_undirected(out[1]) # Test for bipartite graph: - seed_everything(7) edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]]) with pytest.raises(RuntimeError, match="not supported for bipartite"): add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5)) out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5)) - assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2], - [2, 3, 1, 4, 2, 1, 0, 4, 2]] - assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]] + assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1])) From 2f3c2e7dbb597b0ca695ba2228ec7ea71f7c460c Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 12:31:26 +0200 Subject: [PATCH 11/17] The generation of the graph test_signed_gcn was based on randint. We have no guarantee that there were no repeated edges. Now it is based on randperm. Also, the number of nodes is doubled to ensure that structured sampling is (almost) always feasible. --- test/nn/models/test_signed_gcn.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/nn/models/test_signed_gcn.py b/test/nn/models/test_signed_gcn.py index 391ef558118f..e5d32ff07ff4 100644 --- a/test/nn/models/test_signed_gcn.py +++ b/test/nn/models/test_signed_gcn.py @@ -8,9 +8,11 @@ def test_signed_gcn(): model = SignedGCN(8, 16, num_layers=2, lamb=5) assert str(model) == 'SignedGCN(8, 16, num_layers=2)' - - pos_index = torch.randint(high=10, size=(2, 40), dtype=torch.long) - neg_index = torch.randint(high=10, size=(2, 40), dtype=torch.long) + N, E = 20, 40 + all_index = torch.randperm(N*N, dtype=torch.long) + all_index = torch.stack([all_index // N, all_index % N], dim=0) + pos_index = all_index[:, :E] + neg_index = all_index[:, E:2*E] train_pos_index, test_pos_index = model.split_edges(pos_index) train_neg_index, test_neg_index = model.split_edges(neg_index) @@ -24,14 +26,14 @@ def test_signed_gcn(): x = model.create_spectral_features( train_pos_index, train_neg_index, - num_nodes=10, + num_nodes=N, ) - assert x.size() == (10, 8) + assert x.size() == (N, 8) else: - x = torch.randn(10, 8) + x = torch.randn(N, 8) z = model(x, train_pos_index, train_neg_index) - assert z.size() == (10, 16) + assert z.size() == (N, 16) loss = model.loss(z, train_pos_index, train_neg_index) assert loss.item() >= 0 From d837eb8465d7032c0793c023f88c01c4bd42c9e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 10:33:52 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/nn/models/test_signed_gcn.py | 4 ++-- test/utils/test_augmentation.py | 3 +-- torch_geometric/utils/_negative_sampling.py | 8 +++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/nn/models/test_signed_gcn.py b/test/nn/models/test_signed_gcn.py index e5d32ff07ff4..4dfbe286b313 100644 --- a/test/nn/models/test_signed_gcn.py +++ b/test/nn/models/test_signed_gcn.py @@ -9,10 +9,10 @@ def test_signed_gcn(): model = SignedGCN(8, 16, num_layers=2, lamb=5) assert str(model) == 'SignedGCN(8, 16, num_layers=2)' N, E = 20, 40 - all_index = torch.randperm(N*N, dtype=torch.long) + all_index = torch.randperm(N * N, dtype=torch.long) all_index = torch.stack([all_index // N, all_index % N], dim=0) pos_index = all_index[:, :E] - neg_index = all_index[:, E:2*E] + neg_index = all_index[:, E:2 * E] train_pos_index, test_pos_index = model.split_edges(pos_index) train_neg_index, test_neg_index = model.split_edges(neg_index) diff --git a/test/utils/test_augmentation.py b/test/utils/test_augmentation.py index 687a560d0e8f..b6d38dae20aa 100644 --- a/test/utils/test_augmentation.py +++ b/test/utils/test_augmentation.py @@ -1,7 +1,6 @@ import pytest import torch -from torch_geometric import seed_everything from torch_geometric.utils import ( add_random_edge, is_undirected, @@ -79,7 +78,7 @@ def test_add_random_edge(): assert out[1].tolist() == [[], []] def _edge_idx_to_set(e: torch.Tensor) -> set: - return set([tuple(v) for v in e.tolist()]) + return {tuple(v) for v in e.tolist()} out = add_random_edge(edge_index, p=0.5) assert _edge_idx_to_set(out[0]).isdisjoint(_edge_idx_to_set(out[1])) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index e08d6d6a88ff..976ae355bba8 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -5,7 +5,6 @@ from torch import Tensor from torch_geometric.utils import ( - coalesce, cumsum, degree, index_sort, @@ -256,7 +255,8 @@ def structured_negative_sampling( (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2])) """ - if not structured_negative_sampling_feasible(edge_index, num_nodes, contains_neg_self_loops): + if not structured_negative_sampling_feasible(edge_index, num_nodes, + contains_neg_self_loops): raise ValueError('Structured sampling is not feasible!') num_nodes = maybe_num_nodes(edge_index, num_nodes) @@ -292,7 +292,9 @@ def structured_negative_sampling( neg_col[ok_edges] = col_to_save.view(-1) if not torch.all(ok_edges): - raise ValueError('Sparse method was not able to sample all negative edges requested!') + raise ValueError( + 'Sparse method was not able to sample all negative edges requested!' + ) else: guess_col, guess_edge_id = sample_k_structured_edges( From 6d466ccac66df48c037674253c1577e92cb5fecb Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 13:45:04 +0200 Subject: [PATCH 13/17] Adjust the number of trials for sparse negative sampling. --- torch_geometric/utils/_negative_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 976ae355bba8..1c83860c3e85 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -94,7 +94,7 @@ def negative_sampling( if method == 'sparse': if prob >= 0.3: # the probability of sampling non-existing edge is high, so the sparse method should be ok - k = int(num_neg_samples / (prob - 0.1)) + k = max(int(1.5*num_neg_samples), int(num_neg_samples / (prob - 0.1))) else: # the probability is too low, but the graph is too big for the exact sampling. # we perform the sparse sampling but we raise a warning From a2f5a719cd9ce1eb877bfb9f64ee7a8b16646d63 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 11:46:47 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/utils/_negative_sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 1c83860c3e85..7925e544d1fc 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -94,7 +94,8 @@ def negative_sampling( if method == 'sparse': if prob >= 0.3: # the probability of sampling non-existing edge is high, so the sparse method should be ok - k = max(int(1.5*num_neg_samples), int(num_neg_samples / (prob - 0.1))) + k = max(int(1.5 * num_neg_samples), + int(num_neg_samples / (prob - 0.1))) else: # the probability is too low, but the graph is too big for the exact sampling. # we perform the sparse sampling but we raise a warning From 27d607968942d0e13e4de964578220ae57386490 Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 14:02:20 +0200 Subject: [PATCH 15/17] Adjust PEP8 --- torch_geometric/utils/_negative_sampling.py | 53 +++++++++++++-------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 7925e544d1fc..29564aeb1dda 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -39,8 +39,10 @@ def negative_sampling( method (str, optional): The method to use for negative sampling, *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. This is a memory/runtime trade-off. - :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples - :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"sparse"` will work on any graph of any size, but it could + retrieve a different number of negative samples + :obj:`"dense"` will work only on small graphs since it enumerates + all possible edges :obj:`"auto"` will automatically choose the best method (default: :obj:`"auto"`) force_undirected (bool, optional): If set to :obj:`True`, sampled @@ -93,11 +95,13 @@ def negative_sampling( prob = 1 - (num_edges / num_tot_edges) if method == 'sparse': if prob >= 0.3: - # the probability of sampling non-existing edge is high, so the sparse method should be ok + # the probability of sampling non-existing edge is high, + # so the sparse method should be ok k = max(int(1.5 * num_neg_samples), int(num_neg_samples / (prob - 0.1))) else: - # the probability is too low, but the graph is too big for the exact sampling. + # the probability is too low, but the graph is too big + # for the exact sampling. # we perform the sparse sampling but we raise a warning k = int( min(10 * num_neg_samples, @@ -105,7 +109,8 @@ def negative_sampling( warnings.warn( 'The probability of sampling a negative edge is too low! ' - 'It could be that the number of sampled edges is smaller than the numbers you required!' + 'It could be that the number of sampled edges is smaller ' + 'than the numbers you required!' ) guess_edge_index, guess_edge_id = sample_almost_k_edges( @@ -123,8 +128,6 @@ def negative_sampling( assert neg_edge_index is not None - #print(f'{prob} - {method} - {k} - {num_neg_samples} - {neg_edge_index.shape[-1]}') - if force_undirected: neg_edge_index = to_undirected(neg_edge_index) @@ -149,13 +152,16 @@ def batched_negative_sampling( If given as a tuple, then :obj:`edge_index` is interpreted as a bipartite graph connecting two different node types. num_neg_samples (int, optional): The number of negative samples to - return for each graph in the batch. If set to :obj:`None`, will try to return a negative edge + return for each graph in the batch. If set to :obj:`None`, + will try to return a negative edge for every positive edge. (default: :obj:`None`) method (str, optional): The method to use for negative sampling, *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. This is a memory/runtime trade-off. - :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples - :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"sparse"` will work on any graph of any size, but it could + retrieve a different number of negative samples + :obj:`"dense"` will work only on small graphs since it enumerates + all possible edges :obj:`"auto"` will automatically choose the best method (default: :obj:`"auto"`) force_undirected (bool, optional): If set to :obj:`True`, sampled @@ -242,8 +248,11 @@ def structured_negative_sampling( method (str, optional): The method to use for negative sampling, *i.e.* :obj:`"sparse"`, :obj:`"dense"`, or :obj:`"auto"`. This is a memory/runtime trade-off. - :obj:`"sparse"` will work on any graph of any size, but it could retrieve a different number of negative samples - :obj:`"dense"` will work only on small graphs since it enumerates all possible edges + :obj:`"sparse"` will work on any graph of any size, but it could + retrieve a different number of negative + samples + :obj:`"dense"` will work only on small graphs since it enumerates + all possible edges :obj:`"auto"` will automatically choose the best method (default: :obj:`"auto"`) @@ -294,7 +303,8 @@ def structured_negative_sampling( if not torch.all(ok_edges): raise ValueError( - 'Sparse method was not able to sample all negative edges requested!' + 'Sparse method was not able to sample ' + 'all negative edges requested!' ) else: @@ -364,13 +374,14 @@ def structured_negative_sampling_feasible( def get_method(method: str, size: Tuple[int, int]) -> str: # select the method tot_num_edges = size[0] * size[1] - auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small + # prefer dense method if the graph is small + auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' method = auto_method if method == 'auto' else method if method == 'dense' and tot_num_edges >= _MAX_NUM_EDGES: warnings.warn( - f'You choose the dense method on a graph with {tot_num_edges} possible edges! ' - f'It could require a lot of memory!') + f'You choose the dense method on a graph with {tot_num_edges} ' + f'possible edges! It could require a lot of memory!') return method @@ -468,10 +479,12 @@ def get_first_k_true_values_for_each_row(input_mask: Tensor, def get_neg_edge_mask(edge_id: Tensor, guess_edge_id: Tensor) -> Tensor: num_edges = edge_id.size(0) pos = torch.searchsorted(edge_id, guess_edge_id) - # pos contains the position where to insert the guessed id to maintain the edge_id sort. - # 1) if pos == num_edges, it means that we should add the guessed if at the end of the vector -> the id is new! - # 2) if pos != num_edges but the id in position pos != from the guessed one -> the id is new! - neg_edge_mask = torch.eq(pos, num_edges) # negative edge from case 1) + # pos contains the position where to insert the guessed id + # to maintain the edge_id sort. There are two cases for new_id: + # 1) if pos == num_edges (it means that we should add it at the end) + # 2) if pos != num_edges but the id in position pos != from the guessed one + # negative edge from case 1) + neg_edge_mask = torch.eq(pos, num_edges) not_neg_edge_mask = torch.logical_not(neg_edge_mask) # negative edge from case 2) neg_edge_mask[not_neg_edge_mask] = edge_id[ From 92ff5382757cf857f56a7b4ac208438c115e1a74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 12:15:53 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/utils/_negative_sampling.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index 29564aeb1dda..267b819830f6 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -110,8 +110,7 @@ def negative_sampling( warnings.warn( 'The probability of sampling a negative edge is too low! ' 'It could be that the number of sampled edges is smaller ' - 'than the numbers you required!' - ) + 'than the numbers you required!') guess_edge_index, guess_edge_id = sample_almost_k_edges( size, k, force_undirected=force_undirected, @@ -302,10 +301,8 @@ def structured_negative_sampling( neg_col[ok_edges] = col_to_save.view(-1) if not torch.all(ok_edges): - raise ValueError( - 'Sparse method was not able to sample ' - 'all negative edges requested!' - ) + raise ValueError('Sparse method was not able to sample ' + 'all negative edges requested!') else: guess_col, guess_edge_id = sample_k_structured_edges( From 038f81abbdebeedfab9a7ca6845fec9ed34d28b7 Mon Sep 17 00:00:00 2001 From: Jinu Sunil Date: Sat, 7 Sep 2024 12:57:10 -0700 Subject: [PATCH 17/17] lint issue fix. --- test/utils/test_negative_sampling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/utils/test_negative_sampling.py b/test/utils/test_negative_sampling.py index 1dac1fb78565..d2baf5aabb56 100644 --- a/test/utils/test_negative_sampling.py +++ b/test/utils/test_negative_sampling.py @@ -159,7 +159,6 @@ def test_structured_negative_sampling(): assert (adj & neg_adj).sum() == 0 # Test with no self-loops: - #edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]]) i, j, k = structured_negative_sampling(edge_index, num_nodes=4, contains_neg_self_loops=False) neg_edge_index = torch.vstack([i, k])