Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redefine transform_reduce's scratch & result mem #1354

Merged
merged 15 commits into from
Apr 12, 2024
66 changes: 31 additions & 35 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,25 +128,23 @@ struct __parallel_transform_reduce_small_submitter<_Tp, __work_group_size, __ite
_Commutative>{__reduce_op, __transform_op};
auto __reduce_pattern = unseq_backend::reduce_over_group<_ExecutionPolicy, _ReduceOp, _Tp>{__reduce_op};

__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp> __res_container(__exec, 1);
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, 0);

sycl::event __reduce_event = __exec.queue().submit([&, __n](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); // get an access to data under SYCL buffer
auto __res_acc = __res_container.__get_acc(__cgh);
auto __res_acc = __scratch_container.__get_result_acc(__cgh);
::std::size_t __local_mem_size = __reduce_pattern.local_mem_req(__work_group_size);
__dpl_sycl::__local_accessor<_Tp> __temp_local(sycl::range<1>(__local_mem_size), __cgh);
__cgh.parallel_for<_Name...>(
sycl::nd_range<1>(sycl::range<1>(__work_group_size), sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __res_ptr =
__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp>::__get_usm_host_or_buffer_accessor_ptr(
__res_acc);
auto __res_ptr = __res_acc.__get_pointer();
__work_group_reduce_kernel<_Tp>(__item_id, __n, __transform_pattern, __reduce_pattern, __init,
__temp_local, __res_ptr, __rngs...);
});
});

return __future(__reduce_event, __res_container);
return __future(__reduce_event, __scratch_container);
}
}; // struct __parallel_transform_reduce_small_submitter

Expand Down Expand Up @@ -181,11 +179,11 @@ struct __parallel_transform_reduce_device_kernel_submitter<_Tp, __work_group_siz
__internal::__optional_kernel_name<_KernelName...>>
{
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _TransformOp, typename _InitType,
typename... _Ranges>
typename _ExecutionPolicy2, typename... _Ranges>
auto
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Size __n,
_ReduceOp __reduce_op, _TransformOp __transform_op, _InitType __init, sycl::buffer<_Tp>& __temp,
_Ranges&&... __rngs) const
_ReduceOp __reduce_op, _TransformOp __transform_op, _InitType __init,
__result_and_scratch_storage<_ExecutionPolicy2, _Tp> __scratch_container, _Ranges&&... __rngs) const
{
auto __transform_pattern =
unseq_backend::transform_reduce<_ExecutionPolicy, __iters_per_work_item, _ReduceOp, _TransformOp, _Tp,
Expand All @@ -198,14 +196,15 @@ struct __parallel_transform_reduce_device_kernel_submitter<_Tp, __work_group_siz

return __exec.queue().submit([&, __n](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); // get an access to data under SYCL buffer
sycl::accessor __temp_acc{__temp, __cgh, sycl::write_only, __dpl_sycl::__no_init{}};
::std::size_t __local_mem_size = __reduce_pattern.local_mem_req(__work_group_size);
__dpl_sycl::__local_accessor<_Tp> __temp_local(sycl::range<1>(__local_mem_size), __cgh);
auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh);
__cgh.parallel_for<_KernelName...>(
sycl::nd_range<1>(sycl::range<1>(__n_groups * __work_group_size), sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __temp_ptr = __temp_acc.__get_pointer();
__device_reduce_kernel<_Tp>(__item_id, __n, __transform_pattern, __reduce_pattern, __temp_local,
__temp_acc, __rngs...);
__temp_ptr, __rngs...);
});
});
}
Expand All @@ -223,11 +222,12 @@ template <typename _Tp, ::std::uint16_t __work_group_size, ::std::uint8_t __iter
struct __parallel_transform_reduce_work_group_kernel_submitter<
_Tp, __work_group_size, __iters_per_work_item, _Commutative, __internal::__optional_kernel_name<_KernelName...>>
{
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _TransformOp, typename _InitType>
template <typename _ExecutionPolicy, typename _Size, typename _ReduceOp, typename _TransformOp, typename _InitType,
typename _ExecutionPolicy2>
auto
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, sycl::event& __reduce_event,
_Size __n, _ReduceOp __reduce_op, _TransformOp __transform_op, _InitType __init,
sycl::buffer<_Tp>& __temp) const
__result_and_scratch_storage<_ExecutionPolicy2, _Tp> __scratch_container) const
{
using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>;
auto __transform_pattern =
Expand All @@ -247,28 +247,25 @@ struct __parallel_transform_reduce_work_group_kernel_submitter<
}
}

__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp> __res_container(__exec, 1);

__reduce_event = __exec.queue().submit([&, __n](sycl::handler& __cgh) {
__cgh.depends_on(__reduce_event);

sycl::accessor __temp_acc{__temp, __cgh, sycl::read_only};
auto __res_acc = __res_container.__get_acc(__cgh);
auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh);
auto __res_acc = __scratch_container.__get_result_acc(__cgh);
::std::size_t __local_mem_size = __reduce_pattern.local_mem_req(__work_group_size2);
__dpl_sycl::__local_accessor<_Tp> __temp_local(sycl::range<1>(__local_mem_size), __cgh);

__cgh.parallel_for<_KernelName...>(
sycl::nd_range<1>(sycl::range<1>(__work_group_size2), sycl::range<1>(__work_group_size2)),
[=](sycl::nd_item<1> __item_id) {
auto __res_ptr =
__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp>::__get_usm_host_or_buffer_accessor_ptr(
__res_acc);
auto __temp_ptr = __temp_acc.__get_pointer();
auto __res_ptr = __res_acc.__get_pointer();
__work_group_reduce_kernel<_Tp>(__item_id, __n, __transform_pattern, __reduce_pattern, __init,
__temp_local, __res_ptr, __temp_acc);
__temp_local, __res_ptr, __temp_ptr);
});
});

return __future(__reduce_event, __res_container);
return __future(__reduce_event, __scratch_container);
}
}; // struct __parallel_transform_reduce_work_group_kernel_submitter

Expand All @@ -293,19 +290,19 @@ __parallel_transform_reduce_mid_impl(oneapi::dpl::__internal::__device_backend_t
// number of buffer elements processed within workgroup
constexpr _Size __size_per_work_group = __iters_per_work_item_device_kernel * __work_group_size;
const _Size __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_work_group);
sycl::buffer<_Tp> __temp{sycl::range<1>(__n_groups)};
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, __n_groups);

sycl::event __reduce_event =
__parallel_transform_reduce_device_kernel_submitter<_Tp, __work_group_size, __iters_per_work_item_device_kernel,
_Commutative, _ReduceDeviceKernel>()(
__backend_tag, __exec, __n, __reduce_op, __transform_op, __init, __temp,
__backend_tag, __exec, __n, __reduce_op, __transform_op, __init, __scratch_container,
::std::forward<_Ranges>(__rngs)...);

__n = __n_groups; // Number of preliminary results from the device kernel.
return __parallel_transform_reduce_work_group_kernel_submitter<
_Tp, __work_group_size, __iters_per_work_item_work_group_kernel, _Commutative, _ReduceWorkGroupKernel>()(
__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), __reduce_event, __n, __reduce_op, __transform_op,
__init, __temp);
__init, __scratch_container);
}

// General implementation using a tree reduction
Expand Down Expand Up @@ -343,8 +340,8 @@ struct __parallel_transform_reduce_impl
_Size __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_work_group);

// Create temporary global buffers to store temporary values
sycl::buffer<_Tp> __temp(sycl::range<1>(2 * __n_groups));
__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp> __res_container(__exec, 1);
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, 2 * __n_groups);

// __is_first == true. Reduce over each work_group
// __is_first == false. Reduce between work groups
bool __is_first = true;
Expand All @@ -360,11 +357,11 @@ struct __parallel_transform_reduce_impl
__reduce_event = __exec.queue().submit([&, __is_first, __offset_1, __offset_2, __n,
__n_groups](sycl::handler& __cgh) {
__cgh.depends_on(__reduce_event);
auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh);
auto __res_acc = __scratch_container.__get_result_acc(__cgh);

// get an access to data under SYCL buffer
oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
sycl::accessor __temp_acc{__temp, __cgh, sycl::read_write};
auto __res_acc = __res_container.__get_acc(__cgh);
::std::size_t __local_mem_size = __reduce_pattern.local_mem_req(__work_group_size);
__dpl_sycl::__local_accessor<_Tp> __temp_local(sycl::range<1>(__local_mem_size), __cgh);
#if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
Expand All @@ -377,9 +374,8 @@ struct __parallel_transform_reduce_impl
sycl::nd_range<1>(sycl::range<1>(__n_groups * __work_group_size),
sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __res_ptr =
__usm_host_or_buffer_storage<_ExecutionPolicy, _Tp>::__get_usm_host_or_buffer_accessor_ptr(
__res_acc);
auto __temp_ptr = __temp_acc.__get_pointer();
auto __res_ptr = __res_acc.__get_pointer();
auto __local_idx = __item_id.get_local_id(0);
auto __group_idx = __item_id.get_group(0);
// 1. Initialization (transform part). Fill local memory
Expand All @@ -397,7 +393,7 @@ struct __parallel_transform_reduce_impl
}
else
{
__transform_pattern2(__item_id, __n, __offset_2, __result, __temp_acc);
__transform_pattern2(__item_id, __n, __offset_2, __result, __temp_ptr);
__n_items = __transform_pattern2.output_size(__n, __work_group_size);
}
// 2. Reduce within work group using local memory
Expand All @@ -411,7 +407,7 @@ struct __parallel_transform_reduce_impl
__res_ptr[0] = __result.__v;
}

__temp_acc[__offset_1 + __group_idx] = __result.__v;
__temp_ptr[__offset_1 + __group_idx] = __result.__v;
}
__result.__v.~_Tp();
});
Expand All @@ -422,7 +418,7 @@ struct __parallel_transform_reduce_impl
__n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_work_group);
} while (__n > 1);

return __future(__reduce_event, __res_container);
return __future(__reduce_event, __scratch_container);
}
}; // struct __parallel_transform_reduce_impl

Expand Down
Loading
Loading