Commit

fix gamgm-backward testing error
rogerwwww committed Nov 7, 2023
1 parent d3c2bc2 commit 6ddb482
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions tests/test_multi_graph_solvers.py
@@ -237,13 +237,12 @@ def test_gamgm_backward():
     # This function is differentiable by the black-box trick
     W.requires_grad_(True)  # tell PyTorch to track the gradients
     X = pygm.gamgm(As, W)
-    matched = 0
+    loss = 0  # a random loss function
     for i, j in itertools.product(range(graph_num), repeat=2):
-        matched += (X[i, j] * X_gt[i, j]).sum()
-    acc = matched / X_gt.sum()
+        loss += (X[i, j] * torch.rand_like(X[i, j])).sum()

     # Backward pass via black-box trick
-    acc.backward()
+    loss.backward()
     assert torch.sum(W.grad != 0) > 0

     # Jittor
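The change replaces the accuracy-style objective (matched entries against X_gt) with a random loss. With the black-box trick, backpropagating the accuracy can leave W.grad entirely zero, for instance when the solver already reproduces the ground-truth matching, so the nonzero-gradient assertion could fail even though differentiation works; pairing each X[i, j] with random coefficients makes a nonzero gradient on W all but certain. A standalone sketch of the fixed PyTorch check follows. It is not part of the commit: the fixture construction via pygm.set_backend, pygm.utils.generate_isomorphic_graphs, and pygm.sinkhorn is assumed from the pygmtools documentation examples and may differ from what tests/test_multi_graph_solvers.py actually builds.

import itertools

import torch
import pygmtools as pygm

pygm.set_backend('pytorch')  # assumed pygmtools API for selecting the backend
torch.manual_seed(1)

# Ten isomorphic 4-node graphs with 20-dim node features (docs-style fixture)
graph_num = 10
As, X_gt, Fs = pygm.utils.generate_isomorphic_graphs(
    node_num=4, graph_num=graph_num, node_feat_dim=20)

# Pairwise node-to-node similarities, normalized with Sinkhorn
W = torch.matmul(Fs.unsqueeze(1), Fs.transpose(1, 2).unsqueeze(0))
W = pygm.sinkhorn(W.reshape(graph_num ** 2, 4, 4)).reshape(graph_num, graph_num, 4, 4)

W = W.detach().requires_grad_(True)  # make W a leaf so gradients accumulate on it
X = pygm.gamgm(As, W)                # differentiable via the black-box trick

loss = 0  # random loss: nonzero upstream gradient for every X[i, j]
for i, j in itertools.product(range(graph_num), repeat=2):
    loss += (X[i, j] * torch.rand_like(X[i, j])).sum()

loss.backward()
assert torch.sum(W.grad != 0) > 0  # some gradient must reach W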
@@ -271,14 +270,13 @@ def execute(self, As):
     W.start_grad()
     model = Model(W)
     X = model(As)
-    matched = 0
+    loss = 0
     for i, j in itertools.product(range(graph_num), repeat=2):
-        matched += (X[i, j] * X_gt[i, j]).sum()
-    acc = matched / X_gt.sum()
+        loss += (X[i, j] * jt.rand_like(X[i, j])).sum()

     # Backward pass via black-box trick
     optim = jt.nn.SGD(model.parameters(), lr=0.1)
-    optim.step(acc)
+    optim.step(loss)
     grad = W.opt_grad(optim)
     print(jt.sum(grad != 0))
     assert jt.sum(grad != 0) > 0
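The Jittor branch gets the same fix, with one backend difference worth noting: Jittor folds backpropagation into the optimizer, so the random loss is passed to optim.step(loss) and W's gradient is read back through W.opt_grad(optim). A sketch of the full pattern follows; it is not part of the commit, and the Model wrapper is a hypothetical reconstruction from the def execute(self, As) context in the hunk header (the test's actual class may differ). As, W, and graph_num are the same assumed fixtures as above, built with the Jittor backend.

import itertools

import jittor as jt
import pygmtools as pygm

pygm.set_backend('jittor')  # assumed pygmtools API for selecting the backend

class Model(jt.nn.Module):
    # Hypothetical wrapper: holds W and runs the solver in the forward pass
    def __init__(self, W):
        super().__init__()
        self.W = W

    def execute(self, As):  # Jittor's forward method
        return pygm.gamgm(As, self.W)

W.start_grad()  # Jittor's analogue of requires_grad_(True)
model = Model(W)
X = model(As)

loss = 0  # random loss, mirroring the PyTorch branch
for i, j in itertools.product(range(graph_num), repeat=2):
    loss += (X[i, j] * jt.rand_like(X[i, j])).sum()

# optim.step(loss) runs the backward pass and the parameter update together
optim = jt.nn.SGD(model.parameters(), lr=0.1)
optim.step(loss)
grad = W.opt_grad(optim)  # fetch W's gradient from the optimizer
assert jt.sum(grad != 0) > 0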
