diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 4ff55280e89..f75bfb77a7f 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -160,7 +160,7 @@ def meta_ps_roi_pool_backward( return grad.new_empty((batch_size, channels, height, width)) -@torch._custom_ops.impl_abstract("torchvision::nms") +@torch.library.register_fake("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 07a934bce5a..5ecf8812f1b 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -19,6 +19,7 @@ at::Tensor nms( } TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.set_python_module("torchvision._meta_registrations"); m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); }