Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

knn/tnn npu adapter added #3124

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
// transpose known from [B, N, 3] to [B, 3, N]
at::Tensor source = xyz.transpose(1, 2).contiguous();
at::Tensor target = new_xyz.contiguous();

bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2);

REGISTER_NPU_IMPL(knn_forward_impl, knn_forward_npu);
30 changes: 30 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "pytorch_npu_helper.hpp"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"

using namespace NPU_NAME_SPACE;
using namespace std;

void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
// transpose known [B, N, 3] -> [B, 3, N]
at::Tensor source = known.transpose(1, 2).contiguous();
at::Tensor target = unknown.contiguous();
auto originDtype = source.scalar_type();
if (originDtype == at::kHalf) {
source = source.to(at::kFloat);
target = target.to(at::kFloat);
}

bool is_from_knn = false;
uint32_t nsample = 3;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
if (originDtype == at::kHalf) {
dist2 = dist2.to(at::kHalf);
}
}

void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx);

REGISTER_NPU_IMPL(three_nn_forward_impl, three_nn_forward_npu);
5 changes: 3 additions & 2 deletions mmcv/ops/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def forward(ctx,
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
if xyz.device.type != 'npu':
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)

B, npoint, _ = center_xyz.shape
N = xyz.shape[1]
Expand Down
Loading