Optimize merge algorithm for data sizes equal or greater than 4M items #1933
base: main
Changes from all commits: a2e142d, b33656a, d4721ca, 93fd2e8, 82167d5
@@ -45,39 +45,143 @@ namespace __par_backend_hetero
// | ---->
// 3 | 0 0 0 0 0 |
template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
auto
__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1,
                   const _Index __n2, _Compare __comp)
std::pair<_Index, _Index>
__find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rng1_to, const _Rng2& __rng2,
                      const _Index __rng2_from, _Index __rng2_to, const _Index __i_elem, _Compare __comp)
{
    //searching for the first '1', a lower bound for a diagonal [0, 0,..., 0, 1, 1,.... 1, 1]
    oneapi::dpl::counting_iterator<_Index> __diag_it(0);
    assert(__rng1_from <= __rng1_to);
    assert(__rng2_from <= __rng2_to);

    assert(__rng1_to > 0 || __rng2_to > 0);

    if constexpr (!std::is_pointer_v<_Rng1>)
        assert(__rng1_to <= __rng1.size());
    if constexpr (!std::is_pointer_v<_Rng2>)
        assert(__rng2_to <= __rng2.size());

    assert(__i_elem >= 0);

    if (__i_elem < __n2) //a condition to specify upper or lower part of the merge matrix to be processed
    // ----------------------- EXAMPLE ------------------------
    // Let's consider the following input data:
    //    rng1.size() = 10
    //    rng2.size() = 6
    //    i_diag = 9
    // Let's define the following ranges for processing:
    //    rng1: [3, ..., 9) -> __rng1_from = 3, __rng1_to = 9
    //    rng2: [1, ..., 4) -> __rng2_from = 1, __rng2_to = 4
    //
    // The goal: required to process only X' items of the merge matrix
    //           as intersection of rng1[3, ..., 9) and rng2[1, ..., 4)
    //
    // --------------------------------------------------------
    //
    // __diag_it_begin(rng1) __diag_it_end(rng1)
    // (init state) (dest state) (init state, dest state)
    // | | |
    // V V V
    // + + + + + +
    // \ rng1 0 1 2 3 4 5 6 7 8 9
    // rng2 +--------------------------------------+
    // 0 | ^ ^ ^ X | <--- __diag_it_end(rng2) (init state)
    // + 1 | <----------------- + + X'2 ^ | <--- __diag_it_end(rng2) (dest state)
    // + 2 | <----------------- + X'1 | |
    // + 3 | <----------------- X'0 | | <--- __diag_it_begin(rng2) (dest state)
    // 4 | X ^ | |
    // 5 | X | | | <--- __diag_it_begin(rng2) (init state)
    // +-------AX-----------+-----------+-----+
    // AX | |
    // AX | |
    // Run lower_bound:[from = 5, to = 8)
    //
    // AX - absent items in rng2
    //
    // We have three points on the diagonal for comparison calls:
    //    X'0 : call __comp(rng1[5], rng2[3])  // 5 + 3 == 9 - 1 == 8
    //    X'1 : call __comp(rng1[6], rng2[2])  // 6 + 2 == 9 - 1 == 8
    //    X'2 : call __comp(rng1[7], rng2[1])  // 7 + 1 == 9 - 1 == 8
    //  - where for every compared pair idx(rng1) + idx(rng2) == i_diag - 1

    ////////////////////////////////////////////////////////////////////////////////////
    // Process the corner case: for the first diagonal with the index 0 split point
    // is equal to (0, 0) regardless of the size and content of the data.
    if (__i_elem > 0)
    {
        const _Index __q = __i_elem;                          //diagonal index
        const _Index __n_diag = std::min<_Index>(__q, __n1);  //diagonal size
        auto __res =
            std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/,
                             [&__rng2, &__rng1, __q, __comp](const auto& __i_diag, const auto& __value) mutable {
                                 const auto __zero_or_one = __comp(__rng2[__q - __i_diag - 1], __rng1[__i_diag]);
                                 return __zero_or_one < __value;
                             });
        return std::make_pair(*__res, __q - *__res);
        ////////////////////////////////////////////////////////////////////////////////////
        // Taking into account the specified constraints of the range of processed data
        const auto __index_sum = __i_elem - 1;

        using _IndexSigned = std::make_signed_t<_Index>;

        _IndexSigned idx1_from = __rng1_from;
        _IndexSigned idx1_to = __rng1_to;
        assert(idx1_from <= idx1_to);

        _IndexSigned idx2_from = __index_sum - (__rng1_to - 1);
        _IndexSigned idx2_to = __index_sum - __rng1_from + 1;
        assert(idx2_from <= idx2_to);

        const _IndexSigned idx2_from_diff =
            idx2_from < (_IndexSigned)__rng2_from ? (_IndexSigned)__rng2_from - idx2_from : 0;
        const _IndexSigned idx2_to_diff = idx2_to > (_IndexSigned)__rng2_to ? idx2_to - (_IndexSigned)__rng2_to : 0;

        idx1_to -= idx2_from_diff;
        idx1_from += idx2_to_diff;

        idx2_from = __index_sum - (idx1_to - 1);
        idx2_to = __index_sum - idx1_from + 1;

        assert(idx1_from <= idx1_to);
        assert(__rng1_from <= idx1_from && idx1_to <= __rng1_to);

        assert(idx2_from <= idx2_to);
        assert(__rng2_from <= idx2_from && idx2_to <= __rng2_to);

        ////////////////////////////////////////////////////////////////////////////////////
        // Run search of split point on diagonal

        using __it_t = oneapi::dpl::counting_iterator<_Index>;

        __it_t __diag_it_begin(idx1_from);
        __it_t __diag_it_end(idx1_to);

        constexpr int kValue = 1;
        const __it_t __res =
            std::lower_bound(__diag_it_begin, __diag_it_end, kValue, [&](_Index __idx, const auto& __value) {
                const auto __rng1_idx = __idx;
                const auto __rng2_idx = __index_sum - __idx;

                assert(__rng1_from <= __rng1_idx && __rng1_idx < __rng1_to);
                assert(__rng2_from <= __rng2_idx && __rng2_idx < __rng2_to);
                assert(__rng1_idx + __rng2_idx == __index_sum);

                const auto __zero_or_one = __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]);
                return __zero_or_one < kValue;
            });

        const std::pair<_Index, _Index> __result = std::make_pair(*__res, __index_sum - *__res + 1);
        assert(__result.first + __result.second == __i_elem);

        assert(__rng1_from <= __result.first && __result.first <= __rng1_to);
        assert(__rng2_from <= __result.second && __result.second <= __rng2_to);

        return __result;
    }
    else
    {
        const _Index __q = __i_elem - __n2;                          //diagonal index
        const _Index __n_diag = std::min<_Index>(__n1 - __q, __n2);  //diagonal size
        auto __res =
            std::lower_bound(__diag_it, __diag_it + __n_diag, 1 /*value to find*/,
                             [&__rng2, &__rng1, __n2, __q, __comp](const auto& __i_diag, const auto& __value) mutable {
                                 const auto __zero_or_one = __comp(__rng2[__n2 - __i_diag - 1], __rng1[__q + __i_diag]);
                                 return __zero_or_one < __value;
                             });
        return std::make_pair(__q + *__res, __n2 - *__res);
        assert(__rng1_from == 0);
        assert(__rng2_from == 0);
        return std::make_pair(__rng1_from, __rng2_from);
    }
}

template <typename _Rng1, typename _Rng2, typename _Index, typename _Compare>
std::pair<_Index, _Index>
__find_start_point(const _Rng1& __rng1, const _Rng2& __rng2, const _Index __i_elem, const _Index __n1,
                   const _Index __n2, _Compare __comp)
{
    return __find_start_point_in(__rng1, (_Index)0, __n1, __rng2, (_Index)0, __n2, __i_elem, __comp);
}
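Side note on the diagonal search used above: along each cross-diagonal of the merge matrix, the comparison of the two sequences forms a monotone 0/1 sequence, so std::lower_bound over the diagonal finds the split point. A minimal standalone sketch of that idea on raw vectors (the find_split helper is hypothetical and assumes operator< as the comparator; it is not the library code above):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Find the split point (i, j) with i + j == diag such that merging a[0..i) and
// b[0..j) yields the first `diag` elements of the merged output (merge-path idea).
std::pair<std::size_t, std::size_t>
find_split(const std::vector<int>& a, const std::vector<int>& b, std::size_t diag)
{
    std::size_t lo = diag > b.size() ? diag - b.size() : 0; // smallest feasible index into a
    std::size_t hi = std::min(diag, a.size());              // largest feasible index into a
    while (lo < hi)
    {
        const std::size_t mid = lo + (hi - lo) / 2;
        // Monotone predicate along the diagonal: 0 while b[diag - i - 1] >= a[i], then 1.
        if (b[diag - mid - 1] < a[mid])
            hi = mid; // predicate already 1: answer is mid or earlier
        else
            lo = mid + 1;
    }
    return {lo, diag - lo};
}

int main()
{
    const std::vector<int> a{1, 3, 5, 7, 9};
    const std::vector<int> b{2, 4, 6, 8};
    const auto [i, j] = find_split(a, b, 4); // first 4 merged elements: {1, 2, 3, 4}
    assert(i == 2 && j == 2);                // two of them come from a, two from b
    return 0;
}
```

The new __find_start_point_in additionally clamps the searched diagonal segment to the rectangle [__rng1_from, __rng1_to) x [__rng2_from, __rng2_to), which is what lets the second kernel below search only between two precomputed base split points.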

// Do serial merge of the data from rng1 (starting from start1) and rng2 (starting from start2) and writing
// to rng3 (starting from start3) in 'chunk' steps, but do not exceed the total size of the sequences (n1 and n2)
template <typename _Rng1, typename _Rng2, typename _Rng3, typename _Index, typename _Compare>
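The body of __serial_merge is unchanged and therefore not shown in this hunk. As a rough standalone model of what the comment describes (merge at most `chunk` output elements starting from the given positions, never reading past n1/n2), one might write the following; the names and the std::vector interface are illustrative only, not the library's implementation:

```cpp
#include <cstddef>
#include <vector>

// Illustrative bounded serial merge: writes at most `chunk` elements of the merged
// sequence to out[start3...], reading a[start1...] and b[start2...] but never past n1/n2.
template <typename T, typename Compare>
void serial_merge_chunk(const std::vector<T>& a, const std::vector<T>& b, std::vector<T>& out,
                        std::size_t start1, std::size_t start2, std::size_t start3,
                        std::size_t chunk, std::size_t n1, std::size_t n2, Compare comp)
{
    std::size_t i = start1, j = start2, k = start3;
    const std::size_t k_end = start3 + chunk;
    while (k < k_end && (i < n1 || j < n2))
    {
        // Take from b only when a is exhausted or b's element strictly precedes a's (stable merge).
        if (i >= n1 || (j < n2 && comp(b[j], a[i])))
            out[k++] = b[j++];
        else
            out[k++] = a[i++];
    }
}
```

In the kernels below each work-item calls __serial_merge in this spirit, with __i_elem as the output offset and __chunk as the number of elements to produce.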
@@ -133,6 +237,9 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index _
template <typename _IdType, typename _Name>
struct __parallel_merge_submitter;

template <typename _IdType, typename _CustomName, typename _Name>
struct __parallel_merge_submitter_large;

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_Name...>>
{

@@ -166,9 +273,119 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
    }
};

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint32_t;

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint64_t;

template <typename _IdType, typename _CustomName, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, _CustomName, __internal::__optional_kernel_name<_Name...>>
{
    template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
    auto
    operator()(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _Range3&& __rng3, _Compare __comp) const
    {
        const _IdType __n1 = __rng1.size();
        const _IdType __n2 = __rng2.size();
        const _IdType __n = __n1 + __n2;

        assert(__n1 > 0 || __n2 > 0);

        _PRINT_INFO_IN_DEBUG_MODE(__exec);

        using _FindSplitPointsOnMidDiagonalKernel =
            std::conditional_t<std::is_same_v<_IdType, std::uint32_t>,
                               oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
                                   _find_split_points_kernel_on_mid_diagonal_uint32_t, _CustomName, _ExecutionPolicy,
                                   _Range1, _Range2, _Range3, _Compare>,
                               oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
                                   _find_split_points_kernel_on_mid_diagonal_uint64_t, _CustomName, _ExecutionPolicy,
                                   _Range1, _Range2, _Range3, _Compare>>;

        // Empirical number of values to process per work-item
        const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;

        const _IdType __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);
        const _IdType __base_diag_count = 1'024 * 32;
        const _IdType __base_diag_part = oneapi::dpl::__internal::__dpl_ceiling_div(__steps, __base_diag_count);

        using _split_point_t = std::pair<_IdType, _IdType>;

        using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _split_point_t>;
        __result_and_scratch_storage_t __result_and_scratch{__exec, 0, __base_diag_count + 1};

        sycl::event __event = __exec.queue().submit([&](sycl::handler& __cgh) {
            oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2);
            auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::write>(
                __cgh, __dpl_sycl::__no_init{});

            __cgh.parallel_for<_FindSplitPointsOnMidDiagonalKernel>(
                sycl::range</*dim=*/1>(__base_diag_count + 1), [=](sycl::item</*dim=*/1> __item_id) {
                    auto __global_idx = __item_id.get_linear_id();
                    auto __scratch_ptr =
                        __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);

                    if (__global_idx == 0)
                    {
                        __scratch_ptr[0] = std::make_pair((_IdType)0, (_IdType)0);
                    }
                    else if (__global_idx == __base_diag_count)
                    {
                        __scratch_ptr[__base_diag_count] = std::make_pair(__n1, __n2);
                    }
                    else
                    {
                        const _IdType __i_elem = __global_idx * __base_diag_part * __chunk;
                        __scratch_ptr[__global_idx] = __find_start_point(__rng1, __rng2, __i_elem, __n1, __n2, __comp);
                    }
                });
        });

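For a sense of the sizes involved, here is the work decomposition arithmetic with illustrative numbers (a 16M-element merge on a GPU device; the concrete values are only an example, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    const std::uint64_t n = 16'777'216;  // total elements to merge (illustrative)
    const std::uint64_t chunk = 4;       // per-work-item chunk on a GPU device
    const std::uint64_t steps = (n + chunk - 1) / chunk;               // work-items in the merge kernel
    const std::uint64_t base_diag_count = 1'024 * 32;                  // precomputed base diagonals
    const std::uint64_t base_diag_part = (steps + base_diag_count - 1) / base_diag_count;

    assert(steps == 4'194'304);
    assert(base_diag_part == 128); // kernel 1 precomputes every 128th diagonal's split point
    return 0;
}
```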
        __event = __exec.queue().submit([&](sycl::handler& __cgh) {
            oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2, __rng3);
            auto __scratch_acc = __result_and_scratch.template __get_scratch_acc<sycl::access_mode::read>(__cgh);

            __cgh.depends_on(__event);

            __cgh.parallel_for<_Name...>(sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
                auto __global_idx = __item_id.get_linear_id();
                const _IdType __i_elem = __global_idx * __chunk;

                auto __scratch_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__scratch_acc);
                auto __scratch_idx = __global_idx / __base_diag_part;

                _split_point_t __start;
                if (__global_idx % __base_diag_part != 0)

Review comment: We discussed offline the approach of partitioning based on SLM size and then working within the partitioned blocks in the second kernel. One advantage of this method (beyond working within SLM for all diagonals in this kernel) would be that there is no work-item divergence from a branch and mod operation like this. The first partitioning kernel would be lightweight, basically only establishing bounds for the second kernel. The second kernel would then work within SLM-loaded data, search for all diagonals within that block, and then serial merge, so all work-items could do the same work (with a possible exception for the zeroth work-item).

                {
                    // Check that we fit into size of scratch
                    assert(__scratch_idx + 1 < __base_diag_count + 1);

                    const _split_point_t __sp_left = __scratch_ptr[__scratch_idx];
                    const _split_point_t __sp_right = __scratch_ptr[__scratch_idx + 1];

                    __start = __find_start_point_in(__rng1, __sp_left.first, __sp_right.first, __rng2, __sp_left.second,
                                                    __sp_right.second, __i_elem, __comp);
                }
                else
                {
                    __start = __scratch_ptr[__scratch_idx];
                }

                __serial_merge(__rng1, __rng2, __rng3, __start.first, __start.second, __i_elem, __chunk, __n1, __n2,
                               __comp);
            });
        });
        return __future(__event);
    }
};
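To recap the structure of this submitter: kernel 1 computes split points only for every __base_diag_part-th diagonal and stores them (plus the (0, 0) and (n1, n2) sentinels) in scratch; kernel 2 refines each work-item's own diagonal by searching only between the two neighbouring base split points, then serial-merges __chunk elements. A serial, host-side model of the same two-phase scheme is sketched below; it is a simplified illustration (plain vectors, operator<, no SYCL, hypothetical names), not the implementation above:

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Serial model of the two-phase merge: precompute split points on coarse "base"
// diagonals, then let each chunk refine its own split point between two
// neighbouring base points before merging `chunk` output elements.
std::vector<int>
merge_two_phase(const std::vector<int>& a, const std::vector<int>& b, std::size_t chunk, std::size_t base_part)
{
    const std::size_t n1 = a.size(), n2 = b.size(), n = n1 + n2;
    const std::size_t steps = (n + chunk - 1) / chunk;
    const std::size_t base_count = (steps + base_part - 1) / base_part;

    // Split point of diagonal `diag`, searched only inside a[lo1..hi1) x b[lo2..hi2).
    auto split_in = [&](std::size_t diag, std::size_t lo1, std::size_t hi1, std::size_t lo2, std::size_t hi2) {
        std::size_t lo = std::max(lo1, diag > hi2 ? diag - hi2 : std::size_t(0));
        std::size_t hi = std::min(hi1, diag > lo2 ? diag - lo2 : std::size_t(0));
        while (lo < hi)
        {
            const std::size_t mid = lo + (hi - lo) / 2;
            if (b[diag - mid - 1] < a[mid])
                hi = mid;
            else
                lo = mid + 1;
        }
        return std::pair<std::size_t, std::size_t>{lo, diag - lo};
    };

    // Phase 1: split points on base diagonals only (analogue of the first kernel).
    std::vector<std::pair<std::size_t, std::size_t>> base(base_count + 1);
    base[0] = {0, 0};
    base[base_count] = {n1, n2};
    for (std::size_t k = 1; k < base_count; ++k)
        base[k] = split_in(std::min(k * base_part * chunk, n), 0, n1, 0, n2);

    // Phase 2: each "work-item" refines between neighbouring base points, then merges its chunk.
    std::vector<int> out(n);
    for (std::size_t g = 0; g < steps; ++g)
    {
        const std::size_t diag = g * chunk;
        const std::size_t k = g / base_part;
        auto [i, j] = (g % base_part == 0)
                          ? base[k]
                          : split_in(diag, base[k].first, base[k + 1].first, base[k].second, base[k + 1].second);
        for (std::size_t o = diag; o < std::min(diag + chunk, n); ++o)
            out[o] = (i < n1 && (j >= n2 || !(b[j] < a[i]))) ? a[i++] : b[j++];
    }
    return out;
}
```

In the actual kernels the phase-2 refinement goes through __find_start_point_in, whose range arguments play the same role as split_in's lo1/hi1/lo2/hi2 clamps here.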

template <typename... _Name>
class __merge_kernel_name;

template <typename... _Name>
class __merge_kernel_name_large;

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
@@ -177,23 +394,47 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy
    using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

    const auto __n = __rng1.size() + __rng2.size();
    if (__n <= std::numeric_limits<std::uint32_t>::max())
    if (__n < 4 * 1'048'576)
    {
        using _WiIndex = std::uint32_t;
        using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
            __merge_kernel_name<_CustomName, _WiIndex>>;
        return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
            std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
            std::forward<_Range3>(__rng3), __comp);
        if (__n <= std::numeric_limits<std::uint32_t>::max())
        {
            using _WiIndex = std::uint32_t;
            using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
                __merge_kernel_name<_CustomName, _WiIndex>>;
            return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
                std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
                std::forward<_Range3>(__rng3), __comp);
        }
        else
        {
            using _WiIndex = std::uint64_t;
            using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
                __merge_kernel_name<_CustomName, _WiIndex>>;
            return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
                std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
                std::forward<_Range3>(__rng3), __comp);
        }
    }
    else
    {
        using _WiIndex = std::uint64_t;
        using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
            __merge_kernel_name<_CustomName, _WiIndex>>;
        return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
            std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
            std::forward<_Range3>(__rng3), __comp);
        if (__n <= std::numeric_limits<std::uint32_t>::max())
        {
            using _WiIndex = std::uint32_t;
            using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
                __merge_kernel_name_large<_CustomName, _WiIndex>>;
            return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
                std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
                std::forward<_Range3>(__rng3), __comp);
        }
        else
        {
            using _WiIndex = std::uint64_t;
            using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
                __merge_kernel_name_large<_CustomName, _WiIndex>>;
            return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
                std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
                std::forward<_Range3>(__rng3), __comp);
        }
    }
}
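The dispatch therefore makes two independent choices: the 4M-element threshold (4 * 1'048'576 = 4'194'304) selects the basic single-kernel submitter versus the new two-kernel "large" submitter, and the comparison against std::numeric_limits<std::uint32_t>::max() selects the work-item index width. A stripped-down model of that decision tree (the enum and function names are hypothetical, for illustration only):

```cpp
#include <cstdint>
#include <limits>

enum class MergePath { BasicU32, BasicU64, LargeU32, LargeU64 };

// Mirrors the selection logic in __parallel_merge: the 4M threshold picks the
// algorithm, the uint32_t limit picks the index type.
MergePath
select_merge_path(std::uint64_t n)
{
    const bool use_large = n >= 4ull * 1'048'576;                         // 4'194'304 elements and up
    const bool fits_u32 = n <= std::numeric_limits<std::uint32_t>::max();
    if (!use_large)
        return fits_u32 ? MergePath::BasicU32 : MergePath::BasicU64;
    return fits_u32 ? MergePath::LargeU32 : MergePath::LargeU64;
}
```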
Review comment: @rarutyun I have fixed the error here. Is it the correct way? I am using __kernel_name_generator here because I need two kernel names: one passed as the template parameter pack, and a second name that I have to create inside.

Reply: I haven't yet looked at this in detail, but can't we just pass the _IdType to __kernel_name_generator directly, and use a single _find_split_points_kernel_on_mid_diagonal type?

Reply: Will discuss with @rarutyun, but as far as I understood it's not possible, because we should have a mandatory type pack as a template parameter for the class, and for this reason we can't specialize it with something else.

Reply: Aha, this is actually an interesting one, because the type differs based on a runtime parameter rather than on the types the user provides, so we need both for a single public API call in the code. I believe I see why it's necessary and we can't just add it to the pack now. Maybe something like this to make it cleaner / clearer.

Reply: Unfortunately this approach has compile errors; discussed offline.