Skip to content
This repository has been archived by the owner on Aug 23, 2023. It is now read-only.

Commit

Permalink
Fix backward propagation of Fsa.get_arc_post. (k2-fsa#970)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored May 10, 2022
1 parent 510cd72 commit ecfe7bd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
5 changes: 3 additions & 2 deletions k2/python/k2/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,10 @@ def backward(
else _k2.backprop_get_arc_post_float)

incoming_arcs = fsas._get_incoming_arcs()
forward_scores_grad, backward_scores_grad = bprop_func(
fsas.arcs, incoming_arcs, arc_post_grad)

arc_scores_grad = arc_post_grad.detach().clone()
forward_scores_grad, backward_scores_grad = bprop_func(
fsas.arcs, incoming_arcs, arc_scores_grad)

return (
None, # fsas
Expand Down
20 changes: 20 additions & 0 deletions k2/python/tests/get_arc_post_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,26 @@ def test_simple_fsa_vec(self):
assert torch.allclose(fsa1.grad, scores1.grad, atol=1e-5)
assert torch.allclose(fsa2.grad, scores2.grad, atol=1e-5)

def test_simple_fsa_vec_2(self):
# test https://github.com/k2-fsa/k2/issues/969
s = '''
0 1 1 0.1
1 2 3 0.2
2 3 -1 0.3
3
'''
for device in self.devices:
for use_double_scores in [True, False]:
for log_semiring in [True, False]:
fsa = k2.Fsa.from_str(s).to(device).requires_grad_(True)
fsas = k2.Fsa.from_fsas([fsa])

arc_post = fsas.get_arc_post(
use_double_scores=use_double_scores,
log_semiring=log_semiring)
arc_post = arc_post.sum()
(-arc_post).backward()


if __name__ == '__main__':
unittest.main()

0 comments on commit ecfe7bd

Please sign in to comment.