From 10d440c555bd23c58a3dd41c894486f8aa88ddcc Mon Sep 17 00:00:00 2001 From: Castellana Date: Thu, 5 Sep 2024 09:42:53 +0200 Subject: [PATCH] Add all type annotations --- torch_geometric/utils/_negative_sampling.py | 29 +++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/torch_geometric/utils/_negative_sampling.py b/torch_geometric/utils/_negative_sampling.py index a9a79e8bfcc6..2fc02c0cb9d7 100644 --- a/torch_geometric/utils/_negative_sampling.py +++ b/torch_geometric/utils/_negative_sampling.py @@ -270,18 +270,13 @@ def structured_negative_sampling( # select the method method = get_method(method, size) - k = None - torch.min(1 - deg / num_nodes) if method == 'sparse': - k = 10 # TODO: can we compute the prob to find the k? + k = 10 + guess_col, guess_edge_id = sample_k_structured_edges( + edge_index, num_nodes, k, not contains_neg_self_loops, method, device) + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - guess_col, guess_edge_id = sample_k_structured_edges( - edge_index, num_nodes, k, not contains_neg_self_loops, method, device) - - neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) - - if method == 'sparse': neg_col = -torch.ones_like(edge_index[0]) neg_edge_mask = get_first_k_true_values_for_each_row( neg_edge_mask.view(num_edges, k), 1) @@ -297,6 +292,10 @@ def structured_negative_sampling( 'We were not able to sample all negative edges requested!') else: + guess_col, guess_edge_id = sample_k_structured_edges( + edge_index, num_nodes, None, not contains_neg_self_loops, method, device) + neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id) + shape = (num_nodes, num_nodes if contains_neg_self_loops else num_nodes - 1) neg_edge_mask = get_first_k_true_values_for_each_row( @@ -357,7 +356,7 @@ def structured_negative_sampling_feasible( ############################################################################### -def get_method(method, size): +def get_method(method: str, size: Tuple[int, int]) -> str: # select the method tot_num_edges = size[0] * size[1] auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small @@ -373,7 +372,7 @@ def get_method(method, size): def sample_almost_k_edges( size: Tuple[int, int], - k: int, + k: Optional[int], force_undirected: bool, remove_self_loops: bool, method: str, @@ -384,6 +383,7 @@ def sample_almost_k_edges( N1, N2 = size if method == 'sparse': + assert k is not None k = 2 * k if force_undirected else k p = torch.tensor([1.], device=device).expand(N1 * N2) if k > N1 * N2: @@ -413,7 +413,7 @@ def sample_almost_k_edges( def sample_k_structured_edges( edge_index: Tensor, num_nodes: int, - k: int, + k: Optional[int], remove_self_loops: bool, method: str, device: Optional[Union[torch.device, str]] = None, @@ -425,6 +425,7 @@ def sample_k_structured_edges( size = (num_nodes, num_nodes) if method == 'sparse': + assert k is not None new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint( num_nodes, (num_edges, k)) new_edge_id = edge_index_to_vector_id((new_row, new_col), size) @@ -451,14 +452,14 @@ def sample_k_structured_edges( return new_col.view(-1), new_edge_id.view(-1) -def get_first_k_true_values_for_each_row(input_mask, k): +def get_first_k_true_values_for_each_row(input_mask: Tensor, k: Union[Tensor, int]) -> Tensor: if isinstance(k, Tensor): k = k.unsqueeze(1) cum_mask = torch.le(torch.cumsum(input_mask, dim=1), k) return torch.logical_and(input_mask, cum_mask) -def get_neg_edge_mask(edge_id, guess_edge_id): +def get_neg_edge_mask(edge_id: Tensor, guess_edge_id: Tensor) -> Tensor: num_edges = edge_id.size(0) pos = torch.searchsorted(edge_id, guess_edge_id) # pos contains the position where to insert the guessed id to maintain the edge_id sort.