Skip to content

Commit

Permalink
Move GpuDriver::GetPointerMemorySpace to the appropriate Executor cla…
Browse files Browse the repository at this point in the history
…sses.

PiperOrigin-RevId: 685292532
  • Loading branch information
klucke authored and Google-ML-Automation committed Oct 13, 2024
1 parent 98fdb43 commit ea3993e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 48 deletions.
16 changes: 0 additions & 16 deletions xla/stream_executor/cuda/cuda_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -944,22 +944,6 @@ int GpuDriver::GetDeviceCount() {
return device_count;
}

absl::StatusOr<MemoryType> GpuDriver::GetPointerMemorySpace(
CUdeviceptr pointer) {
unsigned int value;
TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute(
&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer)));
switch (value) {
case CU_MEMORYTYPE_DEVICE:
return MemoryType::kDevice;
case CU_MEMORYTYPE_HOST:
return MemoryType::kHost;
default:
return absl::InternalError(
absl::StrCat("unknown memory space provided by CUDA API: ", value));
}
}

absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr,
CUdeviceptr* base,
size_t* size) {
Expand Down
17 changes: 17 additions & 0 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1171,5 +1171,22 @@ CudaExecutor::CreateDeviceDescription(int device_ordinal) {
return std::make_unique<DeviceDescription>(std::move(desc));
}

absl::StatusOr<MemoryType> CudaExecutor::GetPointerMemorySpace(
const void* ptr) {
CUdeviceptr pointer = reinterpret_cast<CUdeviceptr>(const_cast<void*>(ptr));
unsigned int value;
TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute(
&value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer)));
switch (value) {
case CU_MEMORYTYPE_DEVICE:
return MemoryType::kDevice;
case CU_MEMORYTYPE_HOST:
return MemoryType::kHost;
default:
return absl::InternalError(
absl::StrCat("unknown memory space provided by CUDA API: ", value));
}
}

} // namespace gpu
} // namespace stream_executor
5 changes: 1 addition & 4 deletions xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,7 @@ class CudaExecutor : public GpuExecutor {
bool HostMemoryRegister(void* location, uint64_t size) override;
bool HostMemoryUnregister(void* location) override;

absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
return GpuDriver::GetPointerMemorySpace(
reinterpret_cast<GpuDevicePtr>(const_cast<void*>(ptr)));
}
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override;

Stream* FindAllocatedStream(void* gpu_stream) override {
absl::MutexLock lock(&alive_gpu_streams_mu_);
Expand Down
3 changes: 0 additions & 3 deletions xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,6 @@ class GpuDriver {

// -- Pointer-specific calls.

// Returns the memory space addressed by pointer.
static absl::StatusOr<MemoryType> GetPointerMemorySpace(GpuDevicePtr pointer);

// Returns the base address and size of the device pointer dptr.
static absl::Status GetPointerAddressRange(GpuDevicePtr dptr,
GpuDevicePtr* base, size_t* size);
Expand Down
21 changes: 0 additions & 21 deletions xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -766,27 +766,6 @@ absl::Status GpuDriver::GetPointerAddressRange(hipDeviceptr_t dptr,
reinterpret_cast<void*>(dptr), ToString(result).c_str()));
}

absl::StatusOr<MemoryType> GpuDriver::GetPointerMemorySpace(
hipDeviceptr_t pointer) {
unsigned int value;
hipError_t result = wrap::hipPointerGetAttribute(
&value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
if (result == hipSuccess) {
switch (value) {
case hipMemoryTypeDevice:
return MemoryType::kDevice;
case hipMemoryTypeHost:
return MemoryType::kHost;
default:
return absl::InternalError(
absl::StrCat("unknown memory space provided by ROCM API: ", value));
}
}

return absl::InternalError(absl::StrCat(
"failed to query device pointer for memory space: ", ToString(result)));
}

absl::StatusOr<int32_t> GpuDriver::GetDriverVersion() {
int32_t version;
TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version),
Expand Down
24 changes: 24 additions & 0 deletions xla/stream_executor/rocm/rocm_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "rocm/include/hip/driver_types.h"
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_version.h"
#include "rocm/rocm_config.h"
Expand Down Expand Up @@ -975,6 +976,29 @@ RocmExecutor::CreateDeviceDescription(int device_ordinal) {
return std::make_unique<DeviceDescription>(std::move(desc));
}

absl::StatusOr<MemoryType> RocmExecutor::GetPointerMemorySpace(
const void* ptr) {
hipDeviceptr_t pointer =
reinterpret_cast<hipDeviceptr_t>(const_cast<void*>(ptr));
unsigned int value;
hipError_t result = wrap::hipPointerGetAttribute(
&value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer);
if (result == hipSuccess) {
switch (value) {
case hipMemoryTypeDevice:
return MemoryType::kDevice;
case hipMemoryTypeHost:
return MemoryType::kHost;
default:
return absl::InternalError(
absl::StrCat("unknown memory space provided by ROCM API: ", value));
}
}

return absl::InternalError(absl::StrCat(
"failed to query device pointer for memory space: ", ToString(result)));
}

} // namespace gpu

} // namespace stream_executor
Expand Down
5 changes: 1 addition & 4 deletions xla/stream_executor/rocm/rocm_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,7 @@ class RocmExecutor : public GpuExecutor {
return GpuDriver::HostDeallocate(gpu_context(), location);
}

absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override {
return GpuDriver::GetPointerMemorySpace(
reinterpret_cast<GpuDevicePtr>(const_cast<void*>(ptr)));
}
absl::StatusOr<MemoryType> GetPointerMemorySpace(const void* ptr) override;

Stream* FindAllocatedStream(void* gpu_stream) override {
absl::MutexLock lock(&alive_gpu_streams_mu_);
Expand Down

0 comments on commit ea3993e

Please sign in to comment.