Skip to content

Commit

Permalink
[fbsync] Update usages of torch.library APIs (#8384)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D58283859

fbshipit-source-id: e882d7dbc22ec3e04edea50f95b6a30456f8fd2b

Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
3 people authored and facebook-github-bot committed Jun 7, 2024
1 parent c4d1728 commit 0b43746
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torchvision/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def meta_ps_roi_pool_backward(
return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/ops/nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ at::Tensor nms(
}

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.set_python_module("torchvision._meta_registrations");
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
}
Expand Down

0 comments on commit 0b43746

Please sign in to comment.