[Features] Add NMS Kernel support with Triton Implementation #8746
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8746
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchvision/ops/xpu/nms.py (Outdated)
```python
    picked.append(order[i])
    remove_box[i:] |= iou_keep_out_mask[i][i:]

return torch.as_tensor(picked)
```
Should this also respect the device of the boxes? (`remove_boxes` is allocated on `boxes.device`, while the return value is always on CPU.)
Thanks for the reminder! Yes, this should be on `boxes.device`. I will update it.
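For reference, a minimal sketch of the fix being discussed; `_picked_to_tensor` is a hypothetical helper name, not the PR's actual code:

```python
import torch

def _picked_to_tensor(picked: list, boxes: torch.Tensor) -> torch.Tensor:
    # Build the result on the same device as the input boxes instead of
    # relying on torch.as_tensor's CPU default.
    return torch.as_tensor(picked, dtype=torch.int64, device=boxes.device)
```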
Motivation
This PR follows RFC #8679, which proposes adding torchvision custom op support with Triton kernels.
Implementation Method
The Triton kernel mapping largely follows the CUDA kernels: the parallel portions of the native CUDA kernel are mapped into a Triton kernel, while logic that cannot run in parallel is implemented with Python and C++ ops. A rough illustrative sketch of such a kernel follows.
This PR contains the following parts:
- `torchvision/ops/triton/`: contains the common logic that can be implemented in Triton.
- `torchvision/ops/xpu/`: performs the op registration and combines the non-Triton ops with the Triton kernels into one big op (a registration sketch follows this list).
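A minimal sketch of what the registration step can look like, assuming PyTorch's `torch.library.register_kernel` API and torchvision's existing `torchvision::nms` schema; the body is a placeholder, not the PR's actual wrapper:

```python
import torch
import torchvision  # ensures the torchvision::nms op schema is registered

@torch.library.register_kernel("torchvision::nms", "xpu")
def _nms_xpu(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # Placeholder: the real kernel combines ATen pre-processing, the Triton
    # bit-mask kernel, and a serial CPU post-process into this single op.
    raise NotImplementedError("illustrative sketch only")
```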
Kernel Implementation Structure
The NMS kernel contains three parts; please see `torchvision/ops/xpu/nms.py` for details. It wraps:
1. Pre-processing such as `argsort`, called using PyTorch ATen ops.
2. The Triton kernel in `torchvision/ops/triton/nms.py`. It is a device-agnostic part that can be shared across devices.
3. The serialized post-process, which runs on the CPU (see the detail section below).
A wrapper sketch stitching these together is shown after this list.
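As a hedged illustration of how the three parts can be combined, reusing the hypothetical `nms_bitmask_kernel_sketch` from the method section and the `_serial_postprocess_sketch` helper sketched in the detail section (none of these names are the PR's actual code):

```python
import torch

def nms_wrapper_sketch(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    # Part 1: pre-processing with ATen ops -- sort boxes by descending score.
    order = scores.argsort(descending=True)
    boxes_sorted = boxes[order].contiguous()

    # Part 2: the device-agnostic Triton kernel fills the [N, ceil(N/32)] bit mask.
    n = boxes.shape[0]
    n_words = (n + 31) // 32
    iou_keep_out_mask = torch.zeros((n, n_words), dtype=torch.int32, device=boxes.device)
    nms_bitmask_kernel_sketch[(n, n_words)](
        boxes_sorted, iou_keep_out_mask, n, iou_threshold, BLOCK=32
    )

    # Part 3: serialized post-process on the CPU, then map back to the
    # original indices and return them on boxes.device (per the review above).
    keep = _serial_postprocess_sketch(iou_keep_out_mask.cpu(), n)
    return order[keep.to(boxes.device)]
```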
Kernel Implementation Detail
The Triton kernel computes an IoU "keep-out" bit mask that records whether box `j` should be excluded once we have already chosen box `i`. A naive implementation would produce an `[N, N]` matrix; for performance, the kernel instead packs the bit mask into 32-bit ints, so the output is `[N, N//32]`. When we iterate over row `i`, this means we choose box `i`; as a result, some boxes `j` will be excluded. That is what the post-process function does. To keep the kernel device-agnostic, we run this serialized process on the CPU; a rough sketch is given at the end of this description.

cc: @EikanWang
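Below is a hedged sketch of that serialized post-process, assuming the packed `[N, N//32]` mask layout described above. `_serial_postprocess_sketch` is an illustrative name, and the inner loop mirrors the `remove_box[i:] |= iou_keep_out_mask[i][i:]` pattern from the review excerpt:

```python
import torch

def _serial_postprocess_sketch(packed_mask: torch.Tensor, n: int) -> torch.Tensor:
    # packed_mask: [N, ceil(N/32)] int32 on CPU; bit j of row i set means
    # "box j overlaps box i above the IoU threshold" (boxes already sorted
    # by score, so row order equals pick priority).
    # Unpack the 32-bit words back into a boolean [N, N] matrix.
    shifts = torch.arange(32, dtype=torch.int32)
    bits = (packed_mask.unsqueeze(-1) >> shifts) & 1       # [N, W, 32]
    iou_keep_out_mask = bits.reshape(n, -1)[:, :n].bool()  # [N, N]

    remove_box = torch.zeros(n, dtype=torch.bool)
    picked = []
    for i in range(n):
        if remove_box[i]:
            continue
        picked.append(i)
        # Suppress every later box that overlaps the box we just kept.
        remove_box[i:] |= iou_keep_out_mask[i][i:]
    return torch.as_tensor(picked, dtype=torch.int64)
```

The returned indices are relative to the score-sorted order; the wrapper maps them back through `order` and onto `boxes.device`, as agreed in the review thread above.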