Optimize merge algorithm for data sizes equal to or greater than 4M items #1933

Draft · wants to merge 5 commits into base: main

Changes from all commits
315 changes: 278 additions & 37 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
@@ -45,39 +45,143 @@ namespace __par_backend_hetero
// | ---->
// 3 | 0 0 0 0 0 |
template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
auto
__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1,
const _Index __n2, _Compare __comp)
std::pair<_Index, _Index>
__find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_to, const _Rng2& __rng2,
const _Index __rng2_from, _Index __rng2_to, const _Index __i_elem, _Compare __comp)
{
// Searching for the first '1': a lower bound for a diagonal [0, 0, ..., 0, 1, 1, ..., 1, 1]
oneapi::dpl::counting_iterator<_Index> __diag_it(0);
assert(__rng1_from <= __rng1_to);
assert(__rng2_from <= __rng2_to);

assert(__rng1_to > 0 || __rng2_to > 0);

if constexpr (!std::is_pointer_v<_Rng1>)
assert(__rng1_to <= __rng1.size());
if constexpr (!std::is_pointer_v<_Rng2>)
assert(__rng2_to <= __rng2.size());

assert(__i_elem >= 0);

if (__i_elem < __n2) //a condition to specify upper or lower part of the merge matrix to be processed
// ----------------------- EXAMPLE ------------------------
// Let's consider the following input data:
// rng1.size() = 10
// rng2.size() = 6
// i_diag = 9
// Let's define the following ranges for processing:
// rng1: [3, ..., 9) -> __rng1_from = 3, __rng1_to = 9
// rng2: [1, ..., 4) -> __rng2_from = 1, __rng2_to = 4
//
// The goal: we only need to process the X' items of the merge matrix
// that lie in the intersection of rng1[3, ..., 9) and rng2[1, ..., 4)
//
// --------------------------------------------------------
//
// __diag_it_begin(rng1) __diag_it_end(rng1)
// (init state) (dest state) (init state, dest state)
// | | |
// V V V
// + + + + + +
// \ rng1 0 1 2 3 4 5 6 7 8 9
// rng2 +--------------------------------------+
// 0 | ^ ^ ^ X | <--- __diag_it_end(rng2) (init state)
// + 1 | <----------------- + + X'2 ^ | <--- __diag_it_end(rng2) (dest state)
// + 2 | <----------------- + X'1 | |
// + 3 | <----------------- X'0 | | <--- __diag_it_begin(rng2) (dest state)
// 4 | X ^ | |
// 5 | X | | | <--- __diag_it_begin(rng2) (init state)
// +-------AX-----------+-----------+-----+
// AX | |
// AX | |
// Run lower_bound:[from = 5, to = 8)
//
// AX - absent items in rng2
//
// We have three points on the diagonal at which the comparison is called:
// X'0 : call __comp(rng1[5], rng2[3]) // 5 + 3 == 9 - 1 == 8
// X'1 : call __comp(rng1[6], rng2[2]) // 6 + 2 == 9 - 1 == 8
// X'2 : call __comp(rng1[7], rng2[1]) // 7 + 1 == 9 - 1 == 8
// - where, for every compared pair, idx(rng1) + idx(rng2) == i_diag - 1

////////////////////////////////////////////////////////////////////////////////////
// Process the corner case: for the first diagonal (index 0) the split point
// is (0, 0) regardless of the size and content of the data.
if (__i_elem > 0)
{
const _Index __q = __i_elem; //diagonal index
const _Index __n_diag = std::min<_Index>(__q, __n1); //diagonal size
auto __res =
std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/,
[&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const auto& __value) mutable {
const auto __zero_or_one = __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]);
return __zero_or_one < __value;
});
return std::make_pair(*__res, __q - *__res);
////////////////////////////////////////////////////////////////////////////////////
// Take into account the specified constraints on the range of processed data
const auto __index_sum = __i_elem - 1;

using _IndexSigned = std::make_signed_t<_Index>;

_IndexSigned idx1_from = __rng1_from;
_IndexSigned idx1_to = __rng1_to;
assert(idx1_from <= idx1_to);

_IndexSigned idx2_from = __index_sum - (__rng1_to - 1);
_IndexSigned idx2_to = __index_sum - __rng1_from + 1;
assert(idx2_from <= idx2_to);

const _IndexSigned idx2_from_diff =
idx2_from < (_IndexSigned)__rng2_from ? (_IndexSigned)__rng2_from - idx2_from : 0;
const _IndexSigned idx2_to_diff = idx2_to > (_IndexSigned)__rng2_to ? idx2_to - (_IndexSigned)__rng2_to : 0;

idx1_to -= idx2_from_diff;
idx1_from += idx2_to_diff;

idx2_from = __index_sum - (idx1_to - 1);
idx2_to = __index_sum - idx1_from + 1;

assert(idx1_from <= idx1_to);
assert(__rng1_from <= idx1_from && idx1_to <= __rng1_to);

assert(idx2_from <= idx2_to);
assert(__rng2_from <= idx2_from && idx2_to <= __rng2_to);

////////////////////////////////////////////////////////////////////////////////////
// Run search of split point on diagonal

using __it_t = oneapi::dpl::counting_iterator<_Index>;

__it_t __diag_it_begin(idx1_from);
__it_t __diag_it_end(idx1_to);

constexpr int kValue = 1;
const __it_t __res =
std::lower_bound(__diag_it_begin, __diag_it_end, kValue, [&](_Index __idx, const auto& __value) {
const auto __rng1_idx = __idx;
const auto __rng2_idx = __index_sum - __idx;

assert(__rng1_from <= __rng1_idx && __rng1_idx < __rng1_to);
assert(__rng2_from <= __rng2_idx && __rng2_idx < __rng2_to);
assert(__rng1_idx + __rng2_idx == __index_sum);

const auto __zero_or_one = __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]);
return __zero_or_one < kValue;
});

const std::pair<_Index, _Index> __result = std::make_pair(*__res, __index_sum - *__res + 1);
assert(__result.first + __result.second == __i_elem);

assert(__rng1_from <= __result.first && __result.first <= __rng1_to);
assert(__rng2_from <= __result.second && __result.second <= __rng2_to);

return __result;
}
else
{
const _Index __q = __i_elem - __n2; //diagonal index
const _Index __n_diag = std::min<_Index>(__n1 - __q, __n2); //diagonal size
auto __res =
std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/,
[&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const auto& __value) mutable {
const auto __zero_or_one = __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]);
return __zero_or_one < __value;
});
return std::make_pair(__q + *__res, __n2 - *__res);
assert(__rng1_from == 0);
assert(__rng2_from == 0);
return std::make_pair(__rng1_from, __rng2_from);
}
}

template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
std::pair<_Index, _Index>
__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1,
const _Index __n2, _Compare __comp)
{
return __find_start_point_in(__rng1, (_Index)0, __n1, __rng2, (_Index)0, __n2, __i_elem, __comp);
}
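
// ---------------------------------------------------------------------------------------------
// Illustrative standalone sketch (not part of this header): the same diagonal split-point search
// written for two plain std::vector<int> inputs on the host. The name find_split_point and the
// variable names are hypothetical. For a cross-diagonal index d (with d <= a.size() + b.size())
// it returns the co-rank pair (i, d - i): i elements come from 'a' and d - i elements from 'b'.
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

std::pair<std::size_t, std::size_t>
find_split_point(const std::vector<int>& a, const std::vector<int>& b, std::size_t d)
{
    // Feasible range of indices into 'a' on the cross diagonal d
    std::size_t lo = d > b.size() ? d - b.size() : 0;
    std::size_t hi = std::min(d, a.size());

    // Binary search for the smallest i with b[d - 1 - i] < a[i]
    // (the first '1' in the 0...0 1...1 pattern along the diagonal)
    while (lo < hi)
    {
        const std::size_t mid = lo + (hi - lo) / 2;
        if (b[d - 1 - mid] < a[mid])
            hi = mid;
        else
            lo = mid + 1;
    }
    return {lo, d - lo};
}
// For example, with a = {1, 3, 5, 7}, b = {2, 4, 6} and d = 4 this returns (2, 2):
// the first four merged values 1, 2, 3, 4 take two elements from each input.
// ---------------------------------------------------------------------------------------------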

// Do a serial merge of the data from rng1 (starting at start1) and rng2 (starting at start2), writing
// to rng3 (starting at start3) in 'chunk' steps, without exceeding the total sizes of the sequences (n1 and n2)
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
@@ -133,6 +237,9 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index _
template <typename _IdType, typename _Name>
struct __parallel_merge_submitter;

template <typename _IdType, typename _CustomName, typename _Name>
struct __parallel_merge_submitter_large;

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_Name...>>
{
Expand Down Expand Up @@ -166,9 +273,119 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
}
};

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint32_t;

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint64_t;

template <typename _IdType, typename _CustomName, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, _CustomName, __internal::__optional_kernel_name<_Name...>>
{
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
operator()(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
{
const _IdType __n1 = __rng1.size();
const _IdType __n2 = __rng2.size();
const _IdType __n = __n1 + __n2;

assert(__n1 > 0 || __n2 > 0);

_PRINT_INFO_IN_DEBUG_MODE(__exec);

using _FindSplitPointsOnMidDiagonalKernel =

SergeyKopienko (Contributor, Author):

@rarutyun I have fixed the error here. Is this the correct way?
I am using __kernel_name_generator here because I need two kernel names: one passed via the template parameter pack, and a second one that I have to create inside.


Contributor:

I haven't yet looked at this in detail, but can't we just pass the _IdType to __kernel_name_generator directly, and use a single _find_split_points_kernel_on_mid_diagonal type?


SergeyKopienko (Contributor, Author), Nov 8, 2024:

I will discuss with @rarutyun, but as far as I understand it's not possible, because we must have a mandatory type pack as the template parameter for the class, like:

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal;

For this reason we can't specialize it with anything else.


danhoeflinger (Contributor), Nov 8, 2024:

Aha, this is actually an interesting one, because the type differs based on a runtime parameter rather than the types the user provides, so we need both for a single public API call in the code. I believe I see why it's necessary and we can't just add it to the pack now.

Maybe something like this to make it cleaner / clearer.

Suggested change
using _FindSplitPointsOnMidDiagonalKernel =
// Runtime decision to pick the _IdType, so we need both kernels for a single public call, and must embed that in the base name
using _KernelNameWithType = std::conditional_t<std::is_same_v<_IdType, std::uint32_t>, _find_split_points_kernel_on_mid_diagonal_uint32_t, _find_split_points_kernel_on_mid_diagonal_uint64_t>;
using _FindSplitPointsOnMidDiagonalKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<_KernelNameWithType,_CustomName, _ExecutionPolicy,
_Range1, _Range2, _Range3, _Compare>;


SergeyKopienko (Contributor, Author):

Unfortunately this approach has compile errors; we discussed it offline.

std::conditional_t<std::is_same_v<_IdType, std::uint32_t>,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal_uint32_t, _CustomName, _ExecutionPolicy,
_Range1, _Range2, _Range3, _Compare>,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal_uint64_t, _CustomName, _ExecutionPolicy,
_Range1, _Range2, _Range3, _Compare>>;

// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;

const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);
const _IdType __base_diag_count = 1'024 * 32;
const _IdType __base_diag_part = oneapi::dpl::__internal::__dpl_ceiling_div(__steps, __base_diag_count);
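// For example (illustrative numbers only): on a GPU with __chunk == 4 and
// __n == 8'388'608 (8M items), __steps == 2'097'152, __base_diag_count == 32'768 and
// __base_diag_part == 64, i.e. a base split point is precomputed for every 64th
// work-item diagonal and the remaining diagonals are refined in the second kernel.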

using _split_point_t = std::pair<_IdType, _IdType>;

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _split_point_t>;
__result_and_scratch_storage_t __result_and_scratch{__exec, 0, __base_diag_count + 1};

sycl::event __event = __exec.queue().submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2);
auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::write>(
__cgh, __dpl_sycl::__no_init{});

__cgh.parallel_for<_FindSplitPointsOnMidDiagonalKernel>(
sycl::range</*dim=*/1>(__base_diag_count + 1), [=](sycl::item</*dim=*/1> __item_id) {
auto __global_idx = __item_id.get_linear_id();
auto __scratch_ptr =
__result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);

if (__global_idx == 0)
{
__scratch_ptr[0] = std::make_pair((_IdType)0, (_IdType)0);
}
else if (__global_idx == __base_diag_count)
{
__scratch_ptr[__base_diag_count] = std::make_pair(__n1, __n2);
}
else
{
const _IdType __i_elem = __global_idx * __base_diag_part * __chunk;
__scratch_ptr[__global_idx] = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
}
});
});

__event = __exec.queue().submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read>(__cgh);

__cgh.depends_on(__event);

__cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
auto __global_idx = __item_id.get_linear_id();
const _IdType __i_elem = __global_idx * __chunk;

auto __scratch_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);
auto __scratch_idx = __global_idx / __base_diag_part;

_split_point_t __start;
if (__global_idx % __base_diag_part != 0)

Contributor:

We discussed offline the approach of partitioning based on SLM size and then working within the partitioned blocks in the second kernel.

One advantage of this method (beyond working within SLM for all diagonals in this kernel) would be that there would be no work-item divergence from a branch and mod operation like this one. The first partitioning kernel would be lightweight, basically only establishing bounds for the second kernel. The second kernel would then work within SLM-loaded data, search for all diagonals within that block, and do the serial merge, and all work-items could behave the same (with a possible exception for the zeroth work-item).

{
// Check that we fit into the scratch size
assert(__scratch_idx + 1 < __base_diag_count + 1);

const _split_point_t __sp_left = __scratch_ptr[__scratch_idx];
const _split_point_t __sp_right = __scratch_ptr[__scratch_idx + 1];

__start = __find_start_point_in(__rng1, __sp_left.first, __sp_right.first, __rng2, __sp_left.second,
__sp_right.second, __i_elem, __comp);
}
else
{
__start = __scratch_ptr[__scratch_idx];
}

__serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2,
__comp);
});
});
return __future(__event);
}
};
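
// Summary of the large-size path above: the first kernel computes split points for
// __base_diag_count + 1 evenly spaced "base" diagonals and stores them in the scratch storage;
// in the second kernel each work-item either reuses a stored base split point or runs
// __find_start_point_in() restricted to the two bracketing base split points, and then
// serially merges its own __chunk-sized portion of the output.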

template <typename... _Name>
class __merge_kernel_name;

template <typename... _Name>
class __merge_kernel_name_large;

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
@@ -177,23 +394,47 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

const auto __n = __rng1.size() + __rng2.size();
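// 4 * 1'048'576 == 4'194'304, so the "large" submitter is used once the combined input
// reaches roughly 4M items (for example, two sorted ranges of 3M elements each),
// while smaller inputs keep the original single-kernel path.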
if (__n <= std::numeric_limits<std::uint32_t>::max())
if (__n < 4 * 1'048'576)
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
}
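
For reference, a minimal caller that would reach the new large-size path above. This is a sketch, assuming a SYCL compiler with oneDPL available, USM shared allocations, and the default device policy; the sizes and values are hypothetical and chosen only so that the combined input exceeds 4M items.

#include <oneapi/dpl/execution>
#include <oneapi/dpl/algorithm>
#include <sycl/sycl.hpp>

#include <cstddef>
#include <functional>

int main()
{
    auto policy = oneapi::dpl::execution::dpcpp_default;
    sycl::queue q = policy.queue();

    // 3M + 3M elements, so rng1.size() + rng2.size() >= 4 * 1'048'576
    const std::size_t n1 = 3'000'000, n2 = 3'000'000;
    int* a = sycl::malloc_shared<int>(n1, q);
    int* b = sycl::malloc_shared<int>(n2, q);
    int* out = sycl::malloc_shared<int>(n1 + n2, q);

    for (std::size_t i = 0; i < n1; ++i)
        a[i] = int(2 * i);       // even values, already sorted
    for (std::size_t i = 0; i < n2; ++i)
        b[i] = int(2 * i + 1);   // odd values, already sorted

    // Dispatches through __parallel_merge; for this size the large-size submitter is taken
    oneapi::dpl::merge(policy, a, a + n1, b, b + n2, out, std::less<int>{});

    sycl::free(a, q);
    sycl::free(b, q);
    sycl::free(out, q);
    return 0;
}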
