Skip to content

Commit

Permalink
RoiAlignRotatedV2算子调用适配层 (#3177)
Browse files Browse the repository at this point in the history
  • Loading branch information
ason-rob authored Sep 26, 2024
1 parent bd0c65e commit 73fa88a
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 17 deletions.
36 changes: 19 additions & 17 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from .riroi_align_rotated import RiRoIAlignRotated, riroi_align_rotated
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
from .roi_align_rotated_v2 import RoIAlignRotatedV2, roi_align_rotated_v2
from .roi_pool import RoIPool, roi_pool
from .roiaware_pool3d import RoIAwarePool3d
from .roipoint_pool3d import RoIPointPool3d
Expand Down Expand Up @@ -92,23 +93,24 @@
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
'contour_expand', 'three_nn', 'three_interpolate',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'gather_points', 'furthest_point_sample', 'nms_quadri',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d', 'SparseConv2d',
'SparseConv3d', 'SparseConvTranspose2d', 'SparseConvTranspose3d',
'SparseInverseConv2d', 'SparseInverseConv3d', 'SubMConv2d', 'SubMConv3d',
'SparseModule', 'SparseSequential', 'SparseMaxPool2d', 'SparseMaxPool3d',
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance',
'PrRoIPool', 'prroi_pool', 'bias_act', 'filtered_lrelu', 'conv2d',
'conv_transpose2d', 'filter2d', 'upsample2d', 'BezierAlign', 'bezier_align'
'RoIAlignRotatedV2', 'roi_align_rotated_v2', 'pixel_group',
'QueryAndGroup', 'GroupAll', 'grouping_operation', 'contour_expand',
'three_nn', 'three_interpolate', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'nms_quadri', 'furthest_point_sample_with_dist', 'PointsSampler',
'Correlation', 'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev',
'nms_bev', 'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization',
'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
'SparseConv2d', 'SparseConv3d', 'SparseConvTranspose2d',
'SparseConvTranspose3d', 'SparseInverseConv2d', 'SparseInverseConv3d',
'SubMConv2d', 'SubMConv3d', 'SparseModule', 'SparseSequential',
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d',
'chamfer_distance', 'PrRoIPool', 'prroi_pool', 'bias_act',
'filtered_lrelu', 'conv2d', 'conv_transpose2d', 'filter2d', 'upsample2d',
'BezierAlign', 'bezier_align'
]

if IS_MLU_AVAILABLE:
Expand Down
52 changes: 52 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void roi_align_rotated_v2_forward_npu(const Tensor input, Tensor rois_map,
Tensor output,
double spatial_scale,
int32_t sampling_ratio,
int32_t pooled_height,
int32_t pooled_width,
bool aligned,
bool clockwise) {
at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous();
at::Tensor rois = rois_map.permute({1, 0}).contiguous();
EXEC_NPU_CMD(aclnnRoiAlignRotatedV2, feature_map, rois, spatial_scale, sampling_ratio, pooled_height, pooled_width, aligned, clockwise, output);
}

void roi_align_rotated_v2_forward_impl(const Tensor input, Tensor rois,
Tensor output,
double spatial_scale,
int32_t sampling_ratio,
int32_t pooled_height,
int32_t pooled_width,
bool aligned,
bool clockwise);

REGISTER_NPU_IMPL(roi_align_rotated_v2_forward_impl, roi_align_rotated_v2_forward_npu);

void roi_align_rotated_v2_backward_npu(const Tensor input, Tensor rois,
Tensor grad_output, Tensor grad_input,
int32_t pooled_height,
int32_t pooled_width,
double spatial_scale,
int32_t sampling_ratio,
bool aligned,
bool clockwise) {
EXEC_NPU_CMD(aclnnRoiAlignRotatedGradV2, input, rois, grad_output,
pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise,
grad_input);
}

void roi_align_rotated_v2_backward_impl(const Tensor input, Tensor rois,
Tensor grad_output, Tensor grad_input,
int32_t pooled_height,
int32_t pooled_width,
double spatial_scale,
int32_t sampling_ratio,
bool aligned,
bool clockwise);

REGISTER_NPU_IMPL(roi_align_rotated_v2_backward_impl, roi_align_rotated_v2_backward_npu);
19 changes: 19 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned);

void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output,
double spatial_scale, int sampling_ratio,
int aligned_height, int aligned_width,
bool aligned, bool clockwise);

void roi_align_rotated_v2_backward(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input,
int pooled_height, int pooled_width, double spatial_scale,
int sampling_ratio, bool aligned, bool clockwise);

void roi_pool_forward(Tensor input, Tensor rois, Tensor output, Tensor argmax,
int pooled_height, int pooled_width, float spatial_scale);

Expand Down Expand Up @@ -792,6 +801,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"),
py::arg("sampling_ratio"), py::arg("aligned"), py::arg("clockwise"));
m.def("roi_align_rotated_v2_forward", &roi_align_rotated_v2_forward,
"roi_align_rotated_v2_forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("spatial_scale"), py::arg("sampling_ratio"),
py::arg("pooled_height"), py::arg("pooled_width"),
py::arg("aligned"), py::arg("clockwise"));
m.def("roi_align_rotated_v2_backward", &roi_align_rotated_v2_backward,
"roi_align_rotated_v2_backward", py::arg("input"), py::arg("rois"),
py::arg("grad_output"), py::arg("grad_input"), py::arg("pooled_height"),
py::arg("pooled_width"), py::arg("spatial_scale"), py::arg("sampling_ratio"),
py::arg("aligned"), py::arg("clockwise"));
m.def("dynamic_point_to_voxel_forward", &dynamic_point_to_voxel_forward,
"dynamic_point_to_voxel_forward", py::arg("feats"), py::arg("coors"),
py::arg("reduce_type"));
Expand Down
37 changes: 37 additions & 0 deletions mmcv/ops/csrc/pytorch/roi_align_rotated_v2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

void roi_align_rotated_v2_forward_impl(Tensor input, Tensor rois, Tensor output,
double spatial_scale, int sampling_ratio,
int pooled_height, int pooled_width,
bool aligned, bool clockwise) {
DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_forward_impl, input, rois, output,
spatial_scale, sampling_ratio, pooled_height, pooled_width,
aligned, clockwise);
}


void roi_align_rotated_v2_forward(Tensor input, Tensor rois, Tensor output,
double spatial_scale, int sampling_ratio,
int pooled_height, int pooled_width,
bool aligned, bool clockwise) {
roi_align_rotated_v2_forward_impl(input, rois, output, spatial_scale, sampling_ratio,
pooled_height, pooled_width, aligned, clockwise);
}


void roi_align_rotated_v2_backward_impl(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input,
int pooled_height, int pooled_width, double spatial_scale,
int sampling_ratio, bool aligned, bool clockwise) {
DISPATCH_DEVICE_IMPL(roi_align_rotated_v2_backward_impl, input, rois, grad_output, grad_input,
pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise);
}


void roi_align_rotated_v2_backward(Tensor input, Tensor rois, Tensor grad_output, Tensor grad_input,
int pooled_height, int pooled_width, double spatial_scale,
int sampling_ratio, bool aligned, bool clockwise) {
roi_align_rotated_v2_backward_impl(input, rois, grad_output, grad_input,
pooled_height, pooled_width, spatial_scale, sampling_ratio, aligned, clockwise);
}
166 changes: 166 additions & 0 deletions mmcv/ops/roi_align_rotated_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any

import torch
import torch.nn as nn
from mmengine.utils import deprecated_api_warning
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['roi_align_rotated_v2_forward'])


class RoIAlignRotatedV2Function(Function):

@staticmethod
def symbolic(g, input, rois, spatial_scale, sampling_ratio, pooled_height,
pooled_width, aligned, clockwise):
return g.op(
'mmcv::MMCVRoIAlignRotatedV2',
input,
rois,
spatial_scale_f=spatial_scale,
sampling_ratio_i=sampling_ratio,
pooled_height=pooled_height,
pooled_width=pooled_width,
aligned_i=aligned,
clockwise_i=clockwise)

@staticmethod
def forward(ctx: Any,
input: torch.Tensor,
rois: torch.Tensor,
spatial_scale: float,
sampling_ratio: int,
pooled_height: int,
pooled_width: int,
aligned: bool = True,
clockwise: bool = False) -> torch.Tensor:
ctx.pooled_height = pooled_height
ctx.pooled_width = pooled_width
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
ctx.aligned = aligned
ctx.clockwise = clockwise
ctx.save_for_backward(input, rois)
ctx.feature_size = input.size()
batch_size, num_channels, data_height, data_width = input.size()
num_rois = rois.size(0)

output = input.new_zeros(num_rois, ctx.pooled_height, ctx.pooled_width,
num_channels)

ext_module.roi_align_rotated_v2_forward(
input,
rois,
output,
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
pooled_height=ctx.pooled_height,
pooled_width=ctx.pooled_width,
aligned=ctx.aligned,
clockwise=ctx.clockwise)
output = output.transpose(2, 3).transpose(1, 2).contiguous()
return output

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor):
input, rois = ctx.saved_tensors
rois_trans = torch.permute(rois, (1, 0)).contiguous()
grad_output_trans = torch.permute(grad_output,
(0, 2, 3, 1)).contiguous()
grad_input = input.new_zeros(
input.size(0), input.size(2), input.size(3), input.size(1))
ext_module.roi_align_rotated_v2_backward(
input, rois_trans, grad_output_trans, grad_input,
ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
ctx.sampling_ratio, ctx.aligned, ctx.clockwise)
grad_input = grad_input.permute(0, 3, 1, 2).contiguous()

return grad_input, None, None, None, None, None, None, None


roi_align_rotated_v2 = RoIAlignRotatedV2Function.apply


class RoIAlignRotatedV2(nn.Module):
"""RoI align pooling layer for rotated proposals.
It accepts a feature map of shape (N, C, H, W) and rois with shape
(n, 6) with each roi decoded as (batch_index, center_x, center_y,
w, h, angle). The angle is in radian.
Args:
output_size (tuple): h, w
spatial_scale (float): scale the input boxes by this number
sampling_ratio(int): number of inputs samples to take for each
output sample. 0 to take samples densely for current models.
aligned (bool): if False, use the legacy implementation in
MMDetection. If True, align the results more perfectly.
Default: True.
clockwise (bool): If True, the angle in each proposal follows a
clockwise fashion in image space, otherwise, the angle is
counterclockwise. Default: False.
Note:
The implementation of RoIAlign when aligned=True is modified from
https://github.com/facebookresearch/detectron2/
The meaning of aligned=True:
Given a continuous coordinate c, its two neighboring pixel
indices (in our pixel model) are computed by floor(c - 0.5) and
ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
indices [0] and [1] (which are sampled from the underlying signal
at continuous coordinates 0.5 and 1.5). But the original roi_align
(aligned=False) does not subtract the 0.5 when computing
neighboring pixel indices and therefore it uses pixels with a
slightly incorrect alignment (relative to our pixel model) when
performing bilinear interpolation.
With `aligned=True`,
we first appropriately scale the ROI and then shift it by -0.5
prior to calling roi_align. This produces the correct neighbors;
The difference does not make a difference to the model's
performance if ROIAlign is used together with conv layers.
"""

@deprecated_api_warning(
{
'out_size': 'output_size',
'sample_num': 'sampling_ratio'
},
cls_name='RoIAlignRotatedV2')
def __init__(self,
spatial_scale: float,
sampling_ratio: int,
pooled_height: int,
pooled_width: int,
aligned: bool = True,
clockwise: bool = False):
super().__init__()

self.pooled_height = int(pooled_height)
self.pooled_width = int(pooled_width)
self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio)
self.aligned = aligned
self.clockwise = clockwise

def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return RoIAlignRotatedV2Function.apply(input, rois, self.spatial_scale,
self.sampling_ratio,
self.pooled_height,
self.pooled_width, self.aligned,
self.clockwise)

def __repr__(self):
s = self.__class__.__name__
s += f'(pooled_height={self.pooled_height}, '
s += f'spatial_scale={self.spatial_scale}, '
s += f'sampling_ratio={self.sampling_ratio}, '
s += f'aligned={self.aligned}, '
s += f'clockwise={self.clockwise})'
return s

0 comments on commit 73fa88a

Please sign in to comment.