Skip to content

Commit

Permalink
更新knn和three_nn的NPU适配代码 (#3194)
Browse files Browse the repository at this point in the history
  • Loading branch information
huangyuan64 authored Nov 18, 2024
1 parent 71437a3 commit e1aab12
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
at::Tensor target = new_xyz.contiguous();

bool is_from_knn = true;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,

bool is_from_knn = false;
uint32_t nsample = 3;
EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
if (originDtype == at::kHalf) {
dist2 = dist2.to(at::kHalf);
}
Expand Down

0 comments on commit e1aab12

Please sign in to comment.