diff --git a/source/adapters/hip/enqueue.cpp b/source/adapters/hip/enqueue.cpp index 68e3e665d2..95b3c2a58d 100644 --- a/source/adapters/hip/enqueue.cpp +++ b/source/adapters/hip/enqueue.cpp @@ -89,33 +89,46 @@ ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size, ur_usm_advice_flags_t URAdviceFlags, hipDevice_t Device) { // Handle unmapped memory advice flags + // FIXME: Temporarily use UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY and + // UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY to control coarse-grain + // memory until a more appropriate flag is introduced. Add them back to the + // unsupported list when that happens. if (URAdviceFlags & - (UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY | - UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY | - UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) { + (UR_USM_ADVICE_FLAG_BIAS_CACHED | UR_USM_ADVICE_FLAG_BIAS_UNCACHED)) { return UR_RESULT_ERROR_INVALID_ENUMERATION; } using ur_to_hip_advice_t = std::pair; - static constexpr std::array - URToHIPMemAdviseDeviceFlags{ - std::make_pair(UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, - hipMemAdviseSetReadMostly), - std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY, - hipMemAdviseUnsetReadMostly), - std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION, - hipMemAdviseSetPreferredLocation), - std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION, - hipMemAdviseUnsetPreferredLocation), - std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE, - hipMemAdviseSetAccessedBy), - std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE, - hipMemAdviseUnsetAccessedBy), - }; - for (auto &FlagPair : URToHIPMemAdviseDeviceFlags) { - if (URAdviceFlags & FlagPair.first) { - UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, FlagPair.second, Device)); +#if defined(__HIP_PLATFORM_AMD__) + constexpr size_t DeviceFlagCount = 8; +#else + constexpr size_t DeviceFlagCount = 6; +#endif + static constexpr std::array + URToHIPMemAdviseDeviceFlags { 
std::make_pair(UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, + hipMemAdviseSetReadMostly), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY, + hipMemAdviseUnsetReadMostly), + std::make_pair(UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION, + hipMemAdviseSetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION, + hipMemAdviseUnsetPreferredLocation), + std::make_pair(UR_USM_ADVICE_FLAG_SET_ACCESSED_BY_DEVICE, + hipMemAdviseSetAccessedBy), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_ACCESSED_BY_DEVICE, + hipMemAdviseUnsetAccessedBy), +#if defined(__HIP_PLATFORM_AMD__) + std::make_pair(UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY, + hipMemAdviseSetCoarseGrain), + std::make_pair(UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY, + hipMemAdviseUnsetCoarseGrain), +#endif + }; + for (const auto &[URAdvice, HIPAdvice] : URToHIPMemAdviseDeviceFlags) { + if (URAdviceFlags & URAdvice) { + UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, HIPAdvice, Device)); } } @@ -130,10 +143,9 @@ ur_result_t setHipMemAdvise(const void *DevPtr, const size_t Size, hipMemAdviseUnsetAccessedBy), }; - for (auto &FlagPair : URToHIPMemAdviseHostFlags) { - if (URAdviceFlags & FlagPair.first) { - UR_CHECK_ERROR( - hipMemAdvise(DevPtr, Size, FlagPair.second, hipCpuDeviceId)); + for (const auto &[URAdvice, HIPAdvice] : URToHIPMemAdviseHostFlags) { + if (URAdviceFlags & URAdvice) { + UR_CHECK_ERROR(hipMemAdvise(DevPtr, Size, HIPAdvice, hipCpuDeviceId)); } } @@ -1615,6 +1627,10 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size, pMem, size, hipMemAdviseUnsetPreferredLocation, DeviceID)); UR_CHECK_ERROR( hipMemAdvise(pMem, size, hipMemAdviseUnsetAccessedBy, DeviceID)); +#if defined(__HIP_PLATFORM_AMD__) + UR_CHECK_ERROR( + hipMemAdvise(pMem, size, hipMemAdviseUnsetCoarseGrain, DeviceID)); +#endif } else { Result = setHipMemAdvise(HIPDevicePtr, size, advice, DeviceID); // UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid but