Skip to content

Commit

Permalink
Add all type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
danielecastellana22 committed Sep 5, 2024
1 parent 7db30a6 commit 10d440c
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions torch_geometric/utils/_negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,18 +270,13 @@ def structured_negative_sampling(
# select the method
method = get_method(method, size)

k = None
torch.min(1 - deg / num_nodes)
if method == 'sparse':
k = 10
# TODO: can we compute the prob to find the k?
k = 10
guess_col, guess_edge_id = sample_k_structured_edges(
edge_index, num_nodes, k, not contains_neg_self_loops, method, device)
neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id)

guess_col, guess_edge_id = sample_k_structured_edges(
edge_index, num_nodes, k, not contains_neg_self_loops, method, device)

neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id)

if method == 'sparse':
neg_col = -torch.ones_like(edge_index[0])
neg_edge_mask = get_first_k_true_values_for_each_row(
neg_edge_mask.view(num_edges, k), 1)
Expand All @@ -297,6 +292,10 @@ def structured_negative_sampling(
'We were not able to sample all negative edges requested!')

else:
guess_col, guess_edge_id = sample_k_structured_edges(
edge_index, num_nodes, None, not contains_neg_self_loops, method, device)
neg_edge_mask = get_neg_edge_mask(edge_id, guess_edge_id)

shape = (num_nodes,
num_nodes if contains_neg_self_loops else num_nodes - 1)
neg_edge_mask = get_first_k_true_values_for_each_row(
Expand Down Expand Up @@ -357,7 +356,7 @@ def structured_negative_sampling_feasible(
###############################################################################


def get_method(method, size):
def get_method(method: str, size: Tuple[int, int]) -> str:
# select the method
tot_num_edges = size[0] * size[1]
auto_method = 'dense' if tot_num_edges < _MAX_NUM_EDGES else 'sparse' # prefer dense method if the graph is small
Expand All @@ -373,7 +372,7 @@ def get_method(method, size):

def sample_almost_k_edges(
size: Tuple[int, int],
k: int,
k: Optional[int],
force_undirected: bool,
remove_self_loops: bool,
method: str,
Expand All @@ -384,6 +383,7 @@ def sample_almost_k_edges(
N1, N2 = size

if method == 'sparse':
assert k is not None
k = 2 * k if force_undirected else k
p = torch.tensor([1.], device=device).expand(N1 * N2)
if k > N1 * N2:
Expand Down Expand Up @@ -413,7 +413,7 @@ def sample_almost_k_edges(
def sample_k_structured_edges(
edge_index: Tensor,
num_nodes: int,
k: int,
k: Optional[int],
remove_self_loops: bool,
method: str,
device: Optional[Union[torch.device, str]] = None,
Expand All @@ -425,6 +425,7 @@ def sample_k_structured_edges(
size = (num_nodes, num_nodes)

if method == 'sparse':
assert k is not None
new_row, new_col = row.view(-1, 1).expand(-1, k), torch.randint(
num_nodes, (num_edges, k))
new_edge_id = edge_index_to_vector_id((new_row, new_col), size)
Expand All @@ -451,14 +452,14 @@ def sample_k_structured_edges(
return new_col.view(-1), new_edge_id.view(-1)


def get_first_k_true_values_for_each_row(input_mask, k):
def get_first_k_true_values_for_each_row(input_mask: Tensor, k: Union[Tensor, int]) -> Tensor:
if isinstance(k, Tensor):
k = k.unsqueeze(1)
cum_mask = torch.le(torch.cumsum(input_mask, dim=1), k)
return torch.logical_and(input_mask, cum_mask)


def get_neg_edge_mask(edge_id, guess_edge_id):
def get_neg_edge_mask(edge_id: Tensor, guess_edge_id: Tensor) -> Tensor:
num_edges = edge_id.size(0)
pos = torch.searchsorted(edge_id, guess_edge_id)
# pos contains the position where to insert the guessed id to maintain the edge_id sort.
Expand Down

0 comments on commit 10d440c

Please sign in to comment.