Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
momo609 committed Dec 2, 2023
1 parent b8ad2ca commit b5861b3
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions mmcv/ops/csrc/common/pytorch_npu_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ constexpr int kHashBufMaxSize = kHashBufSize + 1024;
extern thread_local char g_hashBuf[kHashBufSize];
extern thread_local int g_hashOffset;

// DEVICE_TYPE: the device-type value passed to c10::Device when building the
// target device for host-to-device tensor copies (see CopyTensorHostToDevice).
// It expands to at_npu::key::NativeDeviceType when MMCV_WITH_XLA is defined,
// and to c10::DeviceType::PrivateUse1 otherwise.
// NOTE(review): which build configurations define MMCV_WITH_XLA is not visible
// from this header -- confirm against the build scripts.
#ifdef MMCV_WITH_XLA
#define DEVICE_TYPE at_npu::key::NativeDeviceType
#else
#define DEVICE_TYPE c10::DeviceType::PrivateUse1
#endif

#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \
_(at::ScalarType::Byte, ACL_UINT8) \
_(at::ScalarType::Char, ACL_INT8) \
Expand Down Expand Up @@ -201,7 +207,7 @@ inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) {
at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory();
int deviceIndex = 0;
return cpuPinMemTensor.to(
c10::Device(at_npu::key::NativeDeviceType, deviceIndex),
c10::Device(DEVICE_TYPE, deviceIndex),
cpuPinMemTensor.scalar_type(), true, true);
}

Expand Down Expand Up @@ -259,13 +265,13 @@ inline aclTensor *ConvertType(const at::Tensor &at_tensor) {
return aclCreateTensor(
aclInput.sizes().data(), aclInput.sizes().size(), acl_data_type,
aclInput.strides().data(), aclInput.storage_offset(), format,
storageDims.data(), storageDims.size(), aclInput.storage().data());
storageDims.data(), storageDims.size(), const_cast<void*>(aclInput.storage().data()));
}

auto acl_tensor = aclCreateTensor(
at_tensor.sizes().data(), at_tensor.sizes().size(), acl_data_type,
at_tensor.strides().data(), at_tensor.storage_offset(), format,
storageDims.data(), storageDims.size(), at_tensor.storage().data());
storageDims.data(), storageDims.size(), const_cast<void*>(at_tensor.storage().data()));
return acl_tensor;
}

Expand Down Expand Up @@ -550,7 +556,7 @@ typedef void (*ReleaseHugeMem)(void *, bool);
at::TensorOptions(torch_npu::utils::get_npu_device_type()); \
auto workspace_tensor = \
at::empty({workspace_size}, options.dtype(kByte)); \
workspace_addr = workspace_tensor.storage().data(); \
workspace_addr = const_cast<void*>(workspace_tensor.storage().data()); \
} \
auto acl_call = [converted_params, workspace_addr, workspace_size, \
acl_stream, executor]() -> int { \
Expand Down

0 comments on commit b5861b3

Please sign in to comment.