diff --git a/CHANGELOG.md b/CHANGELOG.md index 6255fe97e8..97d89fd899 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved performance of copying operation to C-/F-contig array, with optimization for batch of square matrices [gh-1850](https://github.com/IntelPython/dpctl/pull/1850) * Improved performance of `tensor.argsort` function for all types [gh-1859](https://github.com/IntelPython/dpctl/pull/1859) * Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866) +* Implement radix sort algorithm to be used in `dpt.sort` and `dpt.argsort` [gh-1867](https://github.com/IntelPython/dpctl/pull/1867) ### Fixed * Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877) diff --git a/dpctl/tensor/CMakeLists.txt b/dpctl/tensor/CMakeLists.txt index 31d4eba03d..2bc811a1c9 100644 --- a/dpctl/tensor/CMakeLists.txt +++ b/dpctl/tensor/CMakeLists.txt @@ -112,10 +112,14 @@ set(_reduction_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp ) set(_sorting_sources - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp ) +set(_sorting_radix_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp +) set(_static_lib_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp ) @@ -151,6 +155,10 @@ set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} ) +set(_tensor_sorting_radix_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp + ${_sorting_radix_sources} +) set(_linalg_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp @@ -160,10 +168,10 @@ set(_tensor_linalg_impl_sources ${_linalg_sources} ) set(_accumulator_sources -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp -${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp ) set(_tensor_accumulation_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp @@ -205,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_sorting_radix_impl) +pybind11_add_module(${python_module_name} MODULE 
${_tensor_sorting_radix_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(python_module_name _tensor_linalg_impl) pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) diff --git a/dpctl/tensor/_sorting.py b/dpctl/tensor/_sorting.py index d9b22cff3e..d5026a6ee8 100644 --- a/dpctl/tensor/_sorting.py +++ b/dpctl/tensor/_sorting.py @@ -25,11 +25,26 @@ _sort_ascending, _sort_descending, ) +from ._tensor_sorting_radix_impl import ( + _radix_argsort_ascending, + _radix_argsort_descending, + _radix_sort_ascending, + _radix_sort_descending, + _radix_sort_dtype_supported, +) __all__ = ["sort", "argsort"] -def sort(x, /, *, axis=-1, descending=False, stable=True): +def _get_mergesort_impl_fn(descending): + return _sort_descending if descending else _sort_ascending + + +def _get_radixsort_impl_fn(descending): + return _radix_sort_descending if descending else _radix_sort_ascending + + +def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None): """sort(x, axis=-1, descending=False, stable=True) Returns a sorted copy of an input array `x`. @@ -49,7 +64,10 @@ def sort(x, /, *, axis=-1, descending=False, stable=True): relative order of `x` values which compare as equal. If `False`, the returned array may or may not maintain the relative order of `x` values which compare as equal. Default: `True`. - + kind (Optional[Literal["stable", "mergesort", "radixsort"]]): + Sorting algorithm. The default is `"stable"`, which uses parallel + merge-sort or parallel radix-sort algorithms depending on the + array data type. Returns: usm_ndarray: a sorted array. The returned array has the same data type and @@ -74,10 +92,33 @@ def sort(x, /, *, axis=-1, descending=False, stable=True): axis, ] arr = dpt.permute_dims(x, perm) + if kind is None: + kind = "stable" + if not isinstance(kind, str) or kind not in [ + "stable", + "radixsort", + "mergesort", + ]: + raise ValueError( + "Unsupported kind value. 
Expected 'stable', 'mergesort', " + f"or 'radixsort', but got '{kind}'" + ) + if kind == "mergesort": + impl_fn = _get_mergesort_impl_fn(descending) + elif kind == "radixsort": + if _radix_sort_dtype_supported(x.dtype.num): + impl_fn = _get_radixsort_impl_fn(descending) + else: + raise ValueError(f"Radix sort is not supported for {x.dtype}") + else: + dt = x.dtype + if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]: + impl_fn = _get_radixsort_impl_fn(descending) + else: + impl_fn = _get_mergesort_impl_fn(descending) exec_q = x.sycl_queue _manager = du.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - impl_fn = _sort_descending if descending else _sort_ascending if arr.flags.c_contiguous: res = dpt.empty_like(arr, order="C") ht_ev, impl_ev = impl_fn( @@ -109,7 +150,15 @@ def sort(x, /, *, axis=-1, descending=False, stable=True): return res -def argsort(x, axis=-1, descending=False, stable=True): +def _get_mergeargsort_impl_fn(descending): + return _argsort_descending if descending else _argsort_ascending + + +def _get_radixargsort_impl_fn(descending): + return _radix_argsort_descending if descending else _radix_argsort_ascending + + +def argsort(x, axis=-1, descending=False, stable=True, kind=None): """argsort(x, axis=-1, descending=False, stable=True) Returns the indices that sort an array `x` along a specified axis. @@ -129,6 +178,10 @@ def argsort(x, axis=-1, descending=False, stable=True): relative order of `x` values which compare as equal. If `False`, the returned array may or may not maintain the relative order of `x` values which compare as equal. Default: `True`. + kind (Optional[Literal["stable", "mergesort", "radixsort"]]): + Sorting algorithm. The default is `"stable"`, which uses parallel + merge-sort or parallel radix-sort algorithms depending on the + array data type. Returns: usm_ndarray: @@ -157,10 +210,33 @@ def argsort(x, axis=-1, descending=False, stable=True): axis, ] arr = dpt.permute_dims(x, perm) + if kind is None: + kind = "stable" + if not isinstance(kind, str) or kind not in [ + "stable", + "radixsort", + "mergesort", + ]: + raise ValueError( + "Unsupported kind value. 
Expected 'stable', 'mergesort', " + f"or 'radixsort', but got '{kind}'" + ) + if kind == "mergesort": + impl_fn = _get_mergeargsort_impl_fn(descending) + elif kind == "radixsort": + if _radix_sort_dtype_supported(x.dtype.num): + impl_fn = _get_radixargsort_impl_fn(descending) + else: + raise ValueError(f"Radix sort is not supported for {x.dtype}") + else: + dt = x.dtype + if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]: + impl_fn = _get_radixargsort_impl_fn(descending) + else: + impl_fn = _get_mergeargsort_impl_fn(descending) exec_q = x.sycl_queue _manager = du.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - impl_fn = _argsort_descending if descending else _argsort_ascending index_dt = ti.default_device_index_type(exec_q) if arr.flags.c_contiguous: res = dpt.empty_like(arr, dtype=index_dt, order="C") diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp similarity index 95% rename from dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp rename to dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp index 1432139bf3..f3b5030c48 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/merge_sort.hpp @@ -32,7 +32,7 @@ #include #include "kernels/dpctl_tensor_types.hpp" -#include "kernels/sorting/sort_detail.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" namespace dpctl { @@ -41,9 +41,11 @@ namespace tensor namespace kernels { -namespace sort_detail +namespace merge_sort_detail { +using namespace dpctl::tensor::kernels::search_sorted_detail; + /*! @brief Merge two contiguous sorted segments */ template void merge_impl(const std::size_t offset, @@ -699,18 +701,7 @@ merge_sorted_block_contig_impl(sycl::queue &q, return dep_ev; } -} // end of namespace sort_detail - -typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &, - size_t, - size_t, - const char *, - char *, - ssize_t, - ssize_t, - ssize_t, - ssize_t, - const std::vector &); +} // end of namespace merge_sort_detail template > sycl::event stable_sort_axis1_contig_impl( @@ -741,8 +732,8 @@ sycl::event stable_sort_axis1_contig_impl( if (sort_nelems < sequential_sorting_threshold) { // equal work-item sorts entire row sycl::event sequential_sorting_ev = - sort_detail::sort_base_step_contig_impl( + merge_sort_detail::sort_base_step_contig_impl( exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp, sort_nelems, depends); @@ -753,8 +744,8 @@ sycl::event stable_sort_axis1_contig_impl( // Sort segments of the array sycl::event base_sort_ev = - sort_detail::sort_over_work_group_contig_impl( + merge_sort_detail::sort_over_work_group_contig_impl( exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp, sorted_block_size, // modified in place with size of sorted // block size @@ -762,7 +753,7 @@ sycl::event stable_sort_axis1_contig_impl( // Merge segments in parallel until all elements are sorted sycl::event merges_ev = - sort_detail::merge_sorted_block_contig_impl( + merge_sort_detail::merge_sorted_block_contig_impl( exec_q, iter_nelems, sort_nelems, res_tp, comp, sorted_block_size, {base_sort_ev}); @@ -816,8 +807,7 @@ sycl::event stable_argsort_axis1_contig_impl( const IndexComp index_comp{arg_tp, ValueComp{}}; static constexpr size_t determine_automatically = 0; - size_t sorted_block_size = - (sort_nelems >= 512) ? 
512 : determine_automatically; + size_t sorted_block_size = determine_automatically; const size_t total_nelems = iter_nelems * sort_nelems; @@ -837,13 +827,15 @@ sycl::event stable_argsort_axis1_contig_impl( }); // Sort segments of the array - sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl( - exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp, - sorted_block_size, // modified in place with size of sorted block size - {populate_indexed_data_ev}); + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp, + sorted_block_size, // modified in place with size of sorted block + // size + {populate_indexed_data_ev}); // Merge segments in parallel until all elements are sorted - sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl( + sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl( exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size, {base_sort_ev}); @@ -851,7 +843,8 @@ sycl::event stable_argsort_axis1_contig_impl( cgh.depends_on(merges_ev); auto temp_acc = - sort_detail::GetReadOnlyAccess{}(res_tp, cgh); + merge_sort_detail::GetReadOnlyAccess{}(res_tp, + cgh); using KernelName = index_map_to_rows_krn; diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp new file mode 100644 index 0000000000..b578de7e2b --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -0,0 +1,1921 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
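The kernels added in this header hinge on one idea: every supported dtype is first mapped to an unsigned integer whose natural ordering matches the requested ordering of the original values (the `order_preserving_cast` overloads that follow), so the radix passes only ever have to order unsigned keys digit by digit. A minimal host-side sketch of that mapping, in plain standard C++ with illustrative names (`to_ordered_u32` is not part of this diff), assuming ascending order and no NaNs:

```cpp
// Illustrative host-side sketch of an order-preserving key mapping;
// the device code below uses the order_preserving_cast<> overloads instead.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Signed 32-bit integer -> unsigned key with the same ordering:
// flipping the sign bit maps INT_MIN..INT_MAX monotonically onto 0..UINT_MAX.
static std::uint32_t to_ordered_u32(std::int32_t v)
{
    return static_cast<std::uint32_t>(v) ^ 0x80000000u;
}

// IEEE-754 float -> unsigned key with the same ordering: non-negative values
// get the sign bit set; negative values are bit-inverted so that more
// negative floats map to smaller keys.
static std::uint32_t to_ordered_u32(float v)
{
    std::uint32_t u;
    std::memcpy(&u, &v, sizeof(u));
    return (u >> 31) ? ~u : (u | 0x80000000u);
}

int main()
{
    const std::vector<float> vals{-2.5f, -0.0f, 0.0f, 1.0f, -100.0f, 3.5f};

    // Sorting by the mapped unsigned keys gives the same order as sorting
    // the original floats (NaNs excluded in this tiny example).
    std::vector<float> by_key(vals), by_value(vals);
    std::sort(by_key.begin(), by_key.end(), [](float a, float b) {
        return to_ordered_u32(a) < to_ordered_u32(b);
    });
    std::sort(by_value.begin(), by_value.end());
    assert(by_key == by_value);

    assert(to_ordered_u32(std::int32_t{-1}) < to_ordered_u32(std::int32_t{0}));
    return 0;
}
```

The device code additionally canonicalizes NaNs to a single quiet-NaN bit pattern and applies a different mask when `is_ascending` is false, so descending order is handled by the same machinery.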
+//===--------------------------------------------------------------------===// + +// Implementation in this file were adapted from oneDPL's radix sort +// implementation, license Apache-2.0 WITH LLVM-exception + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "utils/sycl_alloc_utils.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +namespace radix_sort_details +{ + +template +class radix_sort_count_kernel; + +template +class radix_sort_scan_kernel; + +template +class radix_sort_reorder_peer_kernel; + +template +class radix_sort_reorder_kernel; + +//---------------------------------------------------------- +// bitwise order-preserving conversions to unsigned integers +//---------------------------------------------------------- + +template bool order_preserving_cast(bool val) +{ + if constexpr (is_ascending) + return val; + else + return !val; +} + +template , int> = 0> +UIntT order_preserving_cast(UIntT val) +{ + if constexpr (is_ascending) { + return val; + } + else { + // bitwise invert + return (~val); + } +} + +template && std::is_signed_v, + int> = 0> +std::make_unsigned_t order_preserving_cast(IntT val) +{ + using UIntT = std::make_unsigned_t; + // ascending_mask: 100..0 + constexpr UIntT ascending_mask = + (UIntT(1) << std::numeric_limits::digits); + // descending_mask: 011..1 + constexpr UIntT descending_mask = (std::numeric_limits::max() >> 1); + + constexpr UIntT mask = (is_ascending) ? ascending_mask : descending_mask; + const UIntT uint_val = sycl::bit_cast(val); + + return (uint_val ^ mask); +} + +template std::uint16_t order_preserving_cast(sycl::half val) +{ + using UIntT = std::uint16_t; + + const UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() + : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 15)); + + constexpr UIntT zero_mask = UIntT(0x8000u); + constexpr UIntT nonzero_mask = UIntT(0xFFFFu); + + constexpr UIntT inv_zero_mask = static_cast(~zero_mask); + constexpr UIntT inv_nonzero_mask = static_cast(~nonzero_mask); + + if constexpr (is_ascending) { + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + } + else { + mask = (zero_fp_sign_bit) ? (inv_zero_mask) : (inv_nonzero_mask); + } + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint32_t), + int> = 0> +std::uint32_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint32_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() : val); + + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 31)); + + constexpr UIntT zero_mask = UIntT(0x80000000u); + constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint64_t), + int> = 0> +std::uint64_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint64_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? 
std::numeric_limits::quiet_NaN() : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 63)); + + constexpr UIntT zero_mask = UIntT(0x8000000000000000u); + constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +//----------------- +// bucket functions +//----------------- + +template constexpr std::size_t number_of_bits_in_type() +{ + constexpr std::size_t type_bits = + (sizeof(T) * std::numeric_limits::digits); + return type_bits; +} + +// the number of buckets (size of radix bits) in T +template +constexpr std::uint32_t number_of_buckets_in_type(std::uint32_t radix_bits) +{ + constexpr std::size_t type_bits = number_of_bits_in_type(); + return (type_bits + radix_bits - 1) / radix_bits; +} + +// get bits value (bucket) in a certain radix position +template +std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset) +{ + static_assert(std::is_unsigned_v); + + return (val >> radix_offset) & T(radix_mask); +} + +//-------------------------------- +// count kernel (single iteration) +//-------------------------------- + +template +sycl::event +radix_sort_count_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::size_t wg_size, + std::uint32_t radix_offset, + std::size_t n_values, + ValueT *vals_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const Proj &proj_op, + const bool is_ascending, + const std::vector &dependency_events) +{ + // bin_count = radix_states used for an array storing bucket state counters + constexpr std::uint32_t radix_states = (std::uint32_t(1) << radix_bits); + constexpr std::uint32_t radix_mask = radix_states - 1; + + // iteration space info + const std::size_t n = n_values; + // each segment is processed by a work-group + const std::size_t elems_per_segment = (n + n_segments - 1) / n_segments; + const std::size_t no_op_flag_id = n_counts - 1; + + assert(n_counts == (n_segments + 1) * radix_states + 1); + + sycl::event local_count_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + + sycl::local_accessor counts_lacc(wg_size * radix_states, + cgh); + + sycl::nd_range<1> ndRange(n_iters * n_segments * wg_size, wg_size); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + // 0 <= lid < wg_size + const std::size_t lid = ndit.get_local_id(0); + // 0 <= group_id < n_segments * n_iters + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t val_iter_offset = iter_id * n; + // 0 <= wgr_id < n_segments + const std::size_t wgr_id = group_id - iter_id * n_segments; + + const std::size_t seg_start = elems_per_segment * wgr_id; + + // count per work-item: create a private array for storing count + // values here bin_count = radix_states + std::array counts_arr = {CountT{0}}; + + // count per work-item: count values and write result to private + // count array + const std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n); + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) + { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + val_id])); + const std::uint32_t bucket_id = + 
get_bucket_id(val, radix_offset); + + // increment counter for this bit bucket + ++counts_arr[bucket_id]; + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) + { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + val_id])); + const std::uint32_t bucket_id = + get_bucket_id(val, radix_offset); + + // increment counter for this bit bucket + ++counts_arr[bucket_id]; + } + } + + // count per work-item: write private count array to local count + // array counts_lacc is concatenation of private count arrays from + // each work-item in the order of their local ids + const std::uint32_t count_start_id = radix_states * lid; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + counts_lacc[count_start_id + radix_state_id] = + counts_arr[radix_state_id]; + } + + sycl::group_barrier(ndit.get_group()); + + // count per work-group: reduce till count_lacc[] size > wg_size + // all work-items in the work-group do the work. + for (std::uint32_t i = 1; i < radix_states; ++i) { + // Since we interested in computing total count over work-group + // for each radix state, the correct result is only assured if + // wg_size >= radix_states + counts_lacc[lid] += counts_lacc[wg_size * i + lid]; + } + + sycl::group_barrier(ndit.get_group()); + + // count per work-group: reduce until count_lacc[] size > + // radix_states (n_witems /= 2 per iteration) + for (std::uint32_t n_witems = (wg_size >> 1); + n_witems >= radix_states; n_witems >>= 1) + { + if (lid < n_witems) + counts_lacc[lid] += counts_lacc[n_witems + lid]; + + sycl::group_barrier(ndit.get_group()); + } + + const std::size_t iter_counter_offset = iter_id * n_counts; + + // count per work-group: write local count array to global count + // array + if (lid < radix_states) { + // move buckets with the same id to adjacent positions, + // thus splitting count array into radix_states regions + counts_ptr[iter_counter_offset + (n_segments + 1) * lid + + wgr_id] = counts_lacc[lid]; + } + + // side work: reset 'no-operation-flag', signaling to skip re-order + // phase + if (wgr_id == 0 && lid == 0) { + CountT &no_op_flag = + counts_ptr[iter_counter_offset + no_op_flag_id]; + no_op_flag = 0; + } + }); + }); + + return local_count_ev; +} + +//----------------------------------------------------------------------- +// radix sort: scan kernel (single iteration) +//----------------------------------------------------------------------- + +template +sycl::event radix_sort_scan_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::size_t wg_size, + std::size_t n_values, + std::size_t n_counts, + CountT *counts_ptr, + const std::vector depends) +{ + const std::size_t no_op_flag_id = n_counts - 1; + + // Scan produces local offsets using count values. 
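To make the bookkeeping around the count and scan phases concrete, here is a small serial model in plain standard C++ (illustrative names only, a single 2-bit digit, equally sized segments) of how the per-segment counts written by the count kernel, the per-state exclusive scan performed here, and the cross-state accumulation done later in the reorder kernel combine into stable write offsets:

```cpp
// Serial model of one counting-sort pass with segmented counts:
// counts[s * scan_size + g] = number of keys with digit s in segment g.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    constexpr std::uint32_t radix_bits = 2;
    constexpr std::uint32_t radix_states = 1u << radix_bits; // 4 buckets
    const std::vector<std::uint32_t> keys{3, 1, 0, 2, 1, 3, 0, 1, 2};
    const std::size_t n_segments = 3;
    const std::size_t seg_len = keys.size() / n_segments;
    const std::size_t scan_size = n_segments + 1;

    std::vector<std::size_t> counts(radix_states * scan_size, 0);
    for (std::size_t i = 0; i < keys.size(); ++i) {
        const std::uint32_t digit = keys[i] & (radix_states - 1);
        counts[digit * scan_size + i / seg_len] += 1;
    }

    // Exclusive scan inside each radix-state region; the extra trailing slot
    // of a region ends up holding the total count for that digit.
    for (std::uint32_t s = 0; s < radix_states; ++s) {
        std::size_t running = 0;
        for (std::size_t g = 0; g < scan_size; ++g) {
            const std::size_t c = counts[s * scan_size + g];
            counts[s * scan_size + g] = running;
            running += c;
        }
    }

    // Global offset of (digit s, segment g) = total of lower digits (taken
    // from the trailing region slots) + local offset within digit s.
    std::vector<std::size_t> next(radix_states * n_segments);
    std::size_t lower_digit_total = 0;
    for (std::uint32_t s = 0; s < radix_states; ++s) {
        if (s > 0)
            lower_digit_total += counts[s * scan_size - 1];
        for (std::size_t g = 0; g < n_segments; ++g)
            next[s * n_segments + g] =
                lower_digit_total + counts[s * scan_size + g];
    }

    // Scatter: the result is the keys stably sorted by their digit.
    std::vector<std::uint32_t> out(keys.size());
    for (std::size_t i = 0; i < keys.size(); ++i) {
        const std::uint32_t digit = keys[i] & (radix_states - 1);
        out[next[digit * n_segments + i / seg_len]++] = keys[i];
    }
    const std::vector<std::uint32_t> expected{0, 0, 1, 1, 1, 2, 2, 3, 3};
    assert(out == expected);
    return 0;
}
```

In the kernels the same arithmetic is distributed over work-groups, which is why each radix-state region is scanned over `n_segments + 1` positions, as the comments continue to explain below.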
+ // There are no local offsets for the first segment, but the rest segments + // should be scanned with respect to the count value in the first segment + // what requires n + 1 positions + const std::size_t scan_size = n_segments + 1; + wg_size = std::min(scan_size, wg_size); + + constexpr std::uint32_t radix_states = std::uint32_t(1) << radix_bits; + + // compilation of the kernel prevents out of resources issue, which may + // occur due to usage of collective algorithms such as joint_exclusive_scan + // even if local memory is not explicitly requested + sycl::event scan_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::nd_range<1> ndRange(n_iters * radix_states * wg_size, wg_size); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / radix_states; + const std::size_t wgr_id = group_id - iter_id * radix_states; + // find borders of a region with a specific bucket id + auto begin_ptr = + counts_ptr + scan_size * wgr_id + iter_id * n_counts; + + sycl::joint_exclusive_scan(ndit.get_group(), begin_ptr, + begin_ptr + scan_size, begin_ptr, + CountT(0), sycl::plus{}); + + const auto lid = ndit.get_local_linear_id(); + + // NB: No race condition here, because the condition may ever be + // true for only on one WG, one WI. + if ((lid == wg_size - 1) && (begin_ptr[scan_size - 1] == n_values)) + { + // set flag, since all the values got into one + // this is optimization, may happy often for + // higher radix offsets (all zeros) + auto &no_op_flag = + counts_ptr[iter_id * n_counts + no_op_flag_id]; + no_op_flag = 1; + } + }); + }); + + return scan_ev; +} + +//----------------------------------------------------------------------- +// radix sort: group level reorder algorithms +//----------------------------------------------------------------------- + +struct empty_storage +{ + template empty_storage(T &&...) {} +}; + +// Number with `n` least significant bits of uint32_t +inline std::uint32_t n_ls_bits_set(std::uint32_t n) noexcept +{ + constexpr std::uint32_t zero{}; + constexpr std::uint32_t all_bits_set = ~zero; + + return ~(all_bits_set << n); +} + +enum class peer_prefix_algo +{ + subgroup_ballot, + atomic_fetch_or, + scan_then_broadcast +}; + +template struct peer_prefix_helper; + +template auto get_accessor_pointer(const AccT &acc) +{ + return acc.template get_multi_ptr().get(); +} + +template +struct peer_prefix_helper +{ + using AtomicT = sycl::atomic_ref; + using TempStorageT = sycl::local_accessor; + + sycl::sub_group sgroup; + std::uint32_t lid; + std::uint32_t item_mask; + AtomicT atomic_peer_mask; + + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT lacc) + : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()), + item_mask(n_ls_bits_set(lid)), atomic_peer_mask(lacc[0]) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) + { + // reset mask for each radix state + if (lid == 0) + atomic_peer_mask.store(std::uint32_t{0}); + sycl::group_barrier(sgroup); + + const std::uint32_t uint_contrib{wi_bit_set ? 
std::uint32_t{1} + : std::uint32_t{0}}; + + // set local id's bit to 1 if the bucket value matches the radix state + atomic_peer_mask.fetch_or(uint_contrib << lid); + sycl::group_barrier(sgroup); + std::uint32_t peer_mask_bits = atomic_peer_mask.load(); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask_bits &= item_mask; + new_offset_id |= wi_bit_set + ? (offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT{0}; + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ + using TempStorageT = empty_storage; + using ItemType = sycl::nd_item<1>; + using SubGroupType = sycl::sub_group; + + SubGroupType sgroup; + std::uint32_t sg_size; + + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0]) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) + { + const std::uint32_t contrib{wi_bit_set ? std::uint32_t{1} + : std::uint32_t{0}}; + + std::uint32_t sg_item_offset = sycl::exclusive_scan_over_group( + sgroup, contrib, sycl::plus{}); + + new_offset_id |= + (wi_bit_set ? (offset_prefix + sg_item_offset) : OffsetT(0)); + + // the last scanned value does not contain number of all copies, thus + // adding contribution + std::uint32_t sg_total_offset = sycl::group_broadcast( + sgroup, sg_item_offset + contrib, sg_size - 1); + + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ +private: + sycl::ext::oneapi::sub_group_mask mask_builder(std::uint32_t mask, + std::uint32_t sg_size) + { + return sycl::detail::Builder::createSubGroupMask< + sycl::ext::oneapi::sub_group_mask>(mask, sg_size); + } + +public: + using TempStorageT = empty_storage; + + sycl::sub_group sgroup; + std::uint32_t lid; + sycl::ext::oneapi::sub_group_mask item_sg_mask; + + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()), + item_sg_mask( + mask_builder(n_ls_bits_set(lid), sgroup.get_local_linear_range())) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) + { + // set local id's bit to 1 if the bucket value matches the radix state + auto peer_mask = sycl::ext::oneapi::group_ballot(sgroup, wi_bit_set); + std::uint32_t peer_mask_bits{}; + + peer_mask.extract_bits(peer_mask_bits); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask &= item_sg_mask; + peer_mask.extract_bits(peer_mask_bits); + + new_offset_id |= wi_bit_set + ? 
(offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT(0); + + return sg_total_offset; + } +}; + +template +void copy_func_for_radix_sort(const std::size_t n_segments, + const std::size_t elems_per_segment, + const std::size_t sg_size, + const std::uint32_t lid, + const std::size_t wgr_id, + const InputT *input_ptr, + const std::size_t n_values, + OutputT *output_ptr) +{ + // item info + const std::size_t seg_start = elems_per_segment * wgr_id; + + std::size_t seg_end = sycl::min(seg_start + elems_per_segment, n_values); + + // ensure that each work item in a subgroup does the same number of loop + // iterations + const std::uint16_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + // find offsets for the same values within a segment and fill the resulting + // buffer + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) + { + output_ptr[val_id] = std::move(input_ptr[val_id]); + } + + if (tail_size > 0 && lid < tail_size) { + const std::size_t val_id = seg_end + lid; + output_ptr[val_id] = std::move(input_ptr[val_id]); + } +} + +//----------------------------------------------------------------------- +// radix sort: reorder kernel (per iteration) +//----------------------------------------------------------------------- +template +sycl::event +radix_sort_reorder_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_offset, + std::size_t n_values, + const InputT *input_ptr, + OutputT *output_ptr, + std::size_t n_offsets, + OffsetT *offset_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector dependency_events) +{ + using ValueT = InputT; + using PeerHelper = peer_prefix_helper; + + constexpr std::uint32_t radix_states = std::uint32_t{1} << radix_bits; + constexpr std::uint32_t radix_mask = radix_states - 1; + const std::size_t elems_per_segment = + (n_values + n_segments - 1) / n_segments; + + const std::size_t no_op_flag_id = n_offsets - 1; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + sycl::event reorder_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + cgh.use_kernel_bundle(kb); + + using StorageT = typename PeerHelper::TempStorageT; + + StorageT peer_temp(1, cgh); + + sycl::range<1> lRange{sg_size}; + sycl::range<1> gRange{n_iters * n_segments * sg_size}; + + sycl::nd_range<1> ndRange{gRange, lRange}; + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = group_id - iter_id * n_segments; + + auto b_offset_ptr = offset_ptr + iter_id * n_offsets; + auto b_input_ptr = input_ptr + iter_id * n_values; + auto b_output_ptr = output_ptr + iter_id * n_values; + + const std::uint32_t lid = ndit.get_local_id(0); + + auto &no_op_flag = b_offset_ptr[no_op_flag_id]; + if (no_op_flag) { + // no reordering necessary, simply copy + copy_func_for_radix_sort( + n_segments, elems_per_segment, sg_size, lid, segment_id, + b_input_ptr, n_values, b_output_ptr); + return; + } + + // create a private array for storing offset values + // and add total offset and offset for compute unit + // for a 
certain radix state + std::array offset_arr{}; + const std::size_t scan_size = n_segments + 1; + + OffsetT scanned_bin = 0; + + /* find cumulative offset */ + constexpr std::uint32_t zero_radix_state_id = 0; + offset_arr[zero_radix_state_id] = b_offset_ptr[segment_id]; + + for (std::uint32_t radix_state_id = 1; + radix_state_id < radix_states; ++radix_state_id) + { + const std::uint32_t local_offset_id = + segment_id + scan_size * radix_state_id; + + // scan bins serially + const std::size_t last_segment_bucket_id = + radix_state_id * scan_size - 1; + scanned_bin += b_offset_ptr[last_segment_bucket_id]; + + offset_arr[radix_state_id] = + scanned_bin + b_offset_ptr[local_offset_id]; + } + + const std::size_t seg_start = elems_per_segment * segment_id; + std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n_values); + // ensure that each work item in a subgroup does the same number of + // loop iterations + const std::uint32_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + PeerHelper peer_prefix_hlp(ndit, peer_temp); + + // find offsets for the same values within a segment and fill the + // resulting buffer + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) + { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) + { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + if (tail_size > 0) { + ValueT in_val; + + // default: is greater than any actual radix state + std::uint32_t bucket_id = radix_states; + if (lid < tail_size) { + in_val = std::move(b_input_ptr[seg_end + lid]); + + const auto proj_val = proj_op(in_val); + const auto mapped_val = + (is_ascending) + ? 
order_preserving_cast( + proj_val) + : order_preserving_cast( + proj_val); + bucket_id = + get_bucket_id(mapped_val, radix_offset); + } + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + new_offset_id, offset_arr[radix_state_id], + is_current_bucket); + + offset_arr[radix_state_id] += sg_total_offset; + } + + if (lid < tail_size) { + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + }); + }); + + return reorder_ev; +} + +template +sizeT _slm_adjusted_work_group_size(sycl::queue &exec_q, + sizeT required_slm_bytes_per_wg, + sizeT wg_size) +{ + const auto &dev = exec_q.get_device(); + + if (wg_size == 0) + wg_size = + dev.template get_info(); + + const auto local_mem_sz = + dev.template get_info(); + + return sycl::min(local_mem_sz / required_slm_bytes_per_wg, wg_size); +} + +//----------------------------------------------------------------------- +// radix sort: one iteration +//----------------------------------------------------------------------- + +template +struct parallel_radix_sort_iteration_step +{ + template + using count_phase = radix_sort_count_kernel; + template + using local_scan_phase = radix_sort_scan_kernel; + template + using reorder_peer_phase = + radix_sort_reorder_peer_kernel; + template + using reorder_phase = radix_sort_reorder_kernel; + + template + static sycl::event submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_iter, + std::size_t n_values, + const InputT *in_ptr, + OutputT *out_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &dependency_events) + { + using _RadixCountKernel = count_phase; + using _RadixLocalScanKernel = + local_scan_phase; + using _RadixReorderPeerKernel = + reorder_peer_phase; + using _RadixReorderKernel = + reorder_phase; + + const auto &supported_sub_group_sizes = + exec_q.get_device() + .template get_info(); + const std::size_t max_sg_size = + (supported_sub_group_sizes.empty() + ? 0 + : supported_sub_group_sizes.back()); + const std::size_t reorder_sg_size = max_sg_size; + const std::size_t scan_wg_size = + exec_q.get_device() + .template get_info(); + + constexpr std::size_t two_mils = (std::size_t(1) << 21); + std::size_t count_wg_size = + ((max_sg_size > 0) && (n_values > two_mils) ? 128 : max_sg_size); + + constexpr std::uint32_t radix_states = std::uint32_t(1) << radix_bits; + + // correct count_wg_size according to local memory limit in count phase + const auto max_count_wg_size = _slm_adjusted_work_group_size( + exec_q, sizeof(CountT) * radix_states, count_wg_size); + count_wg_size = + static_cast<::std::size_t>((max_count_wg_size / radix_states)) * + radix_states; + + // work-group size must be a power of 2 and not less than the number of + // states, for scanning to work correctly + + const std::size_t rounded_down_count_wg_size = + std::size_t{1} << (number_of_bits_in_type() - + sycl::clz(count_wg_size) - 1); + count_wg_size = + sycl::max(rounded_down_count_wg_size, std::size_t(radix_states)); + + // Compute the radix position for the given iteration + std::uint32_t radix_offset = radix_iter * radix_bits; + + // 1. 
Count Phase + sycl::event count_ev = + radix_sort_count_submit<_RadixCountKernel, radix_bits>( + exec_q, n_iters, n_segments, count_wg_size, radix_offset, + n_values, in_ptr, n_counts, counts_ptr, proj_op, is_ascending, + dependency_events); + + // 2. Scan Phase + sycl::event scan_ev = + radix_sort_scan_submit<_RadixLocalScanKernel, radix_bits>( + exec_q, n_iters, n_segments, scan_wg_size, n_values, n_counts, + counts_ptr, {count_ev}); + + // 3. Reorder Phase + sycl::event reorder_ev{}; + if (reorder_sg_size == 8 || reorder_sg_size == 16 || + reorder_sg_size == 32) + { + constexpr auto peer_algorithm = peer_prefix_algo::subgroup_ballot; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderPeerKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + else { + constexpr auto peer_algorithm = + peer_prefix_algo::scan_then_broadcast; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + + return reorder_ev; + } +}; // struct parallel_radix_sort_iteration + +template +class radix_sort_one_wg_krn; + +template +struct subgroup_radix_sort +{ +private: + class use_slm_tag + { + }; + class use_global_mem_tag + { + }; + +public: + template + sycl::event operator()(sycl::queue &exec_q, + size_t n_iters, + size_t n_to_sort, + ValueT *input_ptr, + OutputT *output_ptr, + ProjT proj_op, + const bool is_ascending, + const std::vector &depends) + { + static_assert(std::is_same_v, OutputT>); + + using _SortKernelLoc = + radix_sort_one_wg_krn; + using _SortKernelPartGlob = + radix_sort_one_wg_krn; + using _SortKernelGlob = + radix_sort_one_wg_krn; + + constexpr std::size_t max_concurrent_work_groups = 128U; + + // Choose this to occupy the entire accelerator + const std::size_t n_work_groups = + std::min(n_iters, max_concurrent_work_groups); + + // determine which temporary allocation can be accommodated in SLM + const auto &SLM_availability = + check_slm_size(exec_q, n_to_sort); + + const std::size_t n_batch_size = n_work_groups; + + switch (SLM_availability) { + case temp_allocations::both_in_slm: + { + constexpr auto storage_for_values = use_slm_tag{}; + constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelLoc>()( + exec_q, n_iters, n_iters, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + case temp_allocations::counters_in_slm: + { + constexpr auto storage_for_values = use_global_mem_tag{}; + constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelPartGlob>()( + exec_q, n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + default: + { + constexpr auto storage_for_values = use_global_mem_tag{}; + constexpr auto storage_for_counters = use_global_mem_tag{}; + + return one_group_submitter<_SortKernelGlob>()( + exec_q, n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + } + } + +private: + template class TempBuf; + + template class TempBuf + { + const std::size_t buf_size; + + public: + TempBuf(std::size_t, std::size_t n) : buf_size(n) {} + auto get_acc(sycl::handler &cgh) + { + return 
sycl::local_accessor(buf_size, cgh); + } + + std::size_t get_iter_stride() const { return std::size_t{0}; } + }; + + template class TempBuf + { + sycl::buffer buf; + const std::size_t iter_stride; + + public: + TempBuf(std::size_t n_iters, std::size_t n) + : buf(n_iters * n), iter_stride(n) + { + } + auto get_acc(sycl::handler &cgh) + { + return sycl::accessor(buf, cgh, sycl::read_write, sycl::no_init); + } + std::size_t get_iter_stride() const { return iter_stride; } + }; + + static_assert(wg_size <= 1024); + static constexpr uint16_t bin_count = (1 << radix); + static constexpr uint16_t counter_buf_sz = wg_size * bin_count + 1; + + enum class temp_allocations + { + both_in_slm, + counters_in_slm, + both_in_global_mem + }; + + template + temp_allocations check_slm_size(const sycl::queue &exec_q, SizeT n) + { + // the kernel is designed for data size <= 64K + assert(n <= (SizeT(1) << 16)); + + constexpr auto req_slm_size_counters = + counter_buf_sz * sizeof(uint32_t); + + const auto &dev = exec_q.get_device(); + + // Pessimistically only use half of the memory to take into account + // a SYCL group algorithm might use a portion of SLM + const std::size_t max_slm_size = + dev.template get_info() / 2; + + const auto n_uniform = 1 << (std::uint32_t(std::log2(n - 1)) + 1); + const auto req_slm_size_val = sizeof(T) * n_uniform; + + return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size) + ? + // the values and the counters are placed in SLM + temp_allocations::both_in_slm + : (req_slm_size_counters <= max_slm_size) + ? + // the counters are placed in SLM, the values - in the + // global memory + temp_allocations::counters_in_slm + : + // the values and the counters are placed in the global + // memory + temp_allocations::both_in_global_mem; + } + + template struct one_group_submitter + { + template + sycl::event operator()(sycl::queue &exec_q, + size_t n_iters, + size_t n_batch_size, + size_t n_values, + InputT *input_arr, + OutputT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + SLM_value_tag, + SLM_counter_tag, + const std::vector &depends) + { + assert(!(n_values >> 16)); + + assert(n_values <= static_cast(block_size) * + static_cast(wg_size)); + + uint16_t n = static_cast(n_values); + static_assert(std::is_same_v, OutputT>); + + using ValueT = OutputT; + + using KeyT = std::invoke_result_t; + + TempBuf buf_val( + n_batch_size, static_cast(block_size * wg_size)); + TempBuf buf_count( + n_batch_size, static_cast(counter_buf_sz)); + + sycl::range<1> lRange{wg_size}; + + sycl::event sort_ev; + std::vector deps = depends; + + std::size_t n_batches = (n_iters + n_batch_size - 1) / n_batch_size; + + for (size_t batch_id = 0; batch_id < n_batches; ++batch_id) { + + const std::size_t block_start = batch_id * n_batch_size; + + // input_arr/output_arr each has shape (n_iters, n) + InputT *this_input_arr = input_arr + block_start * n_values; + OutputT *this_output_arr = output_arr + block_start * n_values; + + const std::size_t block_end = + std::min(block_start + n_batch_size, n_iters); + + sycl::range<1> gRange{(block_end - block_start) * wg_size}; + sycl::nd_range ndRange{gRange, lRange}; + + sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + + // allocation to use for value exchanges + auto exchange_acc = buf_val.get_acc(cgh); + const std::size_t exchange_acc_iter_stride = + buf_val.get_iter_stride(); + + // allocation for counters + auto counter_acc = buf_count.get_acc(cgh); + const std::size_t counter_acc_iter_stride = + 
buf_count.get_iter_stride(); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> + ndit) { + ValueT values[block_size]; + + const std::size_t iter_id = ndit.get_group(0); + const std::size_t iter_val_offset = + iter_id * static_cast(n); + const std::size_t iter_counter_offset = + iter_id * counter_acc_iter_stride; + const std::size_t iter_exchange_offset = + iter_id * exchange_acc_iter_stride; + + uint16_t wi = ndit.get_local_linear_id(); + uint16_t begin_bit = 0; + + constexpr uint16_t end_bit = + number_of_bits_in_type(); + +// copy from input array into values +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + if (id < n) + values[i] = std::move( + this_input_arr[iter_val_offset + + static_cast( + id)]); + } + + while (true) { + // indices for indirect access in the "re-order" + // phase + uint16_t indices[block_size]; + { + // pointers to bucket's counters + uint32_t *counters[block_size]; + + // counting phase + auto pcounter = + get_accessor_pointer(counter_acc) + + static_cast(wi) + + iter_counter_offset; + +// initialize counters +#pragma unroll + for (uint16_t i = 0; i < bin_count; ++i) + pcounter[i * wg_size] = std::uint32_t{0}; + + sycl::group_barrier(ndit.get_group()); + + if (is_ascending) { +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + constexpr uint16_t bin_mask = + bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const uint16_t bin = + (id < n) + ? get_bucket_id( + order_preserving_cast< + /* is_ascending */ + true>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + } + } + else { +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + constexpr uint16_t bin_mask = + bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const uint16_t bin = + (id < n) + ? 
get_bucket_id( + order_preserving_cast< + /* is_ascending */ + false>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + } + } + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan phase + { + + // scan contiguous numbers + uint16_t bin_sum[bin_count]; + bin_sum[0] = + counter_acc[iter_counter_offset + + static_cast( + wi * bin_count)]; + +#pragma unroll + for (uint16_t i = 1; i < bin_count; ++i) + bin_sum[i] = + bin_sum[i - 1] + + counter_acc + [iter_counter_offset + + static_cast( + wi * bin_count + i)]; + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan local sum + uint16_t sum_scan = + sycl::exclusive_scan_over_group( + ndit.get_group(), + bin_sum[bin_count - 1], + sycl::plus()); + +// add to local sum, generate exclusive scan result +#pragma unroll + for (uint16_t i = 0; i < bin_count; ++i) + counter_acc[iter_counter_offset + + static_cast( + wi * bin_count + i + + 1)] = + sum_scan + bin_sum[i]; + + if (wi == 0) + counter_acc[iter_counter_offset + 0] = + std::uint32_t{0}; + + sycl::group_barrier(ndit.get_group()); + } + +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + // a global index is a local offset plus a + // global base index + indices[i] += *counters[i]; + } + } + + begin_bit += radix; + + // "re-order" phase + sycl::group_barrier(ndit.get_group()); + if (begin_bit >= end_bit) { +// the last iteration - writing out the result +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t r = indices[i]; + if (r < n) { + // move the values to source range and + // destroy the values + this_output_arr + [iter_val_offset + + static_cast(r)] = + std::move(values[i]); + } + } + + return; + } + +// data exchange +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t r = indices[i]; + if (r < n) + exchange_acc[iter_exchange_offset + + static_cast(r)] = + std::move(values[i]); + } + + sycl::group_barrier(ndit.get_group()); + +#pragma unroll + for (uint16_t i = 0; i < block_size; ++i) { + const uint16_t id = wi * block_size + i; + if (id < n) + values[i] = std::move( + exchange_acc[iter_exchange_offset + + static_cast( + id)]); + } + + sycl::group_barrier(ndit.get_group()); + } + }); + }); + + deps = {sort_ev}; + } + + return sort_ev; + } + }; +}; + +template struct OneWorkGroupRadixSortKernel; + +//----------------------------------------------------------------------- +// radix sort: main function +//----------------------------------------------------------------------- +template +sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_to_sort, + const ValueT *input_arr, + ValueT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &depends) +{ + assert(n_to_sort > 1); + + using KeyT = std::remove_cv_t< + std::remove_reference_t>>; + + // radix bits represent number of processed bits in each value during one + // iteration + constexpr std::uint32_t radix_bits = 4; + + sycl::event sort_ev{}; + + const auto &dev = exec_q.get_device(); + const auto max_wg_size = + dev.template get_info(); + + constexpr std::uint16_t ref_wg_size = 64; + if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) { + using _RadixSortKernel = OneWorkGroupRadixSortKernel; + + if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) { + // wg_size * block_size == 64 * 1 * 1 == 64 + constexpr 
std::uint16_t wg_size = ref_wg_size; + constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 128 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 1 == 128 + constexpr std::uint16_t wg_size = ref_wg_size * 2; + constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 256 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 2 == 256 + constexpr std::uint16_t wg_size = ref_wg_size * 2; + constexpr std::uint16_t block_size = 2; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 512 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 4 == 512 + constexpr std::uint16_t wg_size = ref_wg_size * 2; + constexpr std::uint16_t block_size = 4; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 1024 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 8 == 1024 + constexpr std::uint16_t wg_size = ref_wg_size * 2; + constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 2048 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 8 == 2048 + constexpr std::uint16_t wg_size = ref_wg_size * 4; + constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 4096 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 16 == 4096 + constexpr std::uint16_t wg_size = ref_wg_size * 4; + constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 8192 && ref_wg_size * 8 <= max_wg_size) { + // wg_size * block_size == 64 * 8 * 16 == 8192 + constexpr std::uint16_t wg_size = ref_wg_size * 8; + constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else { + // wg_size * block_size == 64 * 8 * 32 == 16384 + constexpr std::uint16_t wg_size = ref_wg_size * 8; + constexpr std::uint16_t block_size = 32; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + } + else { + constexpr std::uint32_t radix_iters = + number_of_buckets_in_type(radix_bits); + constexpr std::uint32_t radix_states = std::uint32_t(1) << radix_bits; + + constexpr std::size_t bound_512k = (std::size_t(1) 
<< 19); + constexpr std::size_t bound_2m = (std::size_t(1) << 21); + + const auto wg_sz_k = (n_to_sort < bound_512k) ? 8 + : (n_to_sort <= bound_2m) ? 4 + : 1; + const std::size_t wg_size = max_wg_size / wg_sz_k; + + const std::size_t n_segments = (n_to_sort + wg_size - 1) / wg_size; + + // Additional radix_states elements are used for getting local offsets + // from count values + no_op flag; 'No operation' flag specifies whether + // to skip re-order phase if the all keys are the same (lie in one bin) + const std::size_t n_counts = + (n_segments + 1) * radix_states + 1 /*no_op flag*/; + + using CountT = std::uint32_t; + + // memory for storing count and offset values + CountT *count_ptr = + sycl::malloc_device(n_iters * n_counts, exec_q); + if (nullptr == count_ptr) { + throw std::runtime_error("Could not allocate USM-device memory"); + } + + constexpr std::uint32_t zero_radix_iter{0}; + + if constexpr (std::is_same_v) { + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, output_arr, + n_counts, count_ptr, proj_op, + is_ascending, depends); + + sort_ev = exec_q.submit([=](sycl::handler &cgh) { + cgh.depends_on(sort_ev); + const sycl::context &ctx = exec_q.get_context(); + + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task( + [ctx, count_ptr]() { sycl_free_noexcept(count_ptr, ctx); }); + }); + + return sort_ev; + } + + ValueT *tmp_arr = + sycl::malloc_device(n_iters * n_to_sort, exec_q); + if (nullptr == tmp_arr) { + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + sycl_free_noexcept(count_ptr, exec_q); + throw std::runtime_error("Could not allocate USM-device memory"); + } + + // iterations per each bucket + assert("Number of iterations must be even" && radix_iters % 2 == 0); + assert(radix_iters > 0); + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, tmp_arr, n_counts, + count_ptr, proj_op, is_ascending, + depends); + + for (std::uint32_t radix_iter = 1; radix_iter < radix_iters; + ++radix_iter) + { + if (radix_iter % 2 == 0) { + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/true>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, output_arr, + tmp_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + else { + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/false>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, tmp_arr, + output_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + } + + sort_ev = exec_q.submit([=](sycl::handler &cgh) { + cgh.depends_on(sort_ev); + + const sycl::context &ctx = exec_q.get_context(); + + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task([ctx, count_ptr, tmp_arr]() { + sycl_free_noexcept(tmp_arr, ctx); + sycl_free_noexcept(count_ptr, ctx); + }); + }); + } + + return sort_ev; +} + +struct IdentityProj +{ + constexpr IdentityProj() {} + + template constexpr T operator()(T val) const { return val; } +}; + +template struct ValueProj +{ + constexpr ValueProj() {} + + constexpr ValueT operator()(const std::pair &pair) const + { + return pair.first; + } +}; + +template struct IndexedProj +{ + IndexedProj(const ValueT *arg_ptr, const ProjT &proj_op) + : ptr(arg_ptr), value_projector(proj_op) + { + } + + auto operator()(IndexT i) const { return value_projector(ptr[i]); } + +private: + const ValueT *ptr; + 
ProjT value_projector; +}; + +} // end of namespace radix_sort_details + +template +sycl::event +radix_sort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of rows in a + // matrix when sorting over rows) + size_t iter_nelems, + // size of each array to sort (length of rows, + // i.e. number of columns) + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + argTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + using Proj = radix_sort_details::IdentityProj; + constexpr Proj proj_op{}; + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, + sort_ascending, depends); + + return radix_sort_ev; +} + +template +class populate_indexed_data_for_radix_sort_krn; + +template +class index_write_out_for_radix_sort_krn; + +template +sycl::event +radix_argsort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of rows in + // a matrix when sorting over rows) + size_t iter_nelems, + // size of each array to sort (length of rows, + // i.e. number of columns) + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + IndexTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + using ValueIndexT = std::pair; + + const std::size_t total_nelems = iter_nelems * sort_nelems; + const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; + ValueIndexT *workspace = sycl::malloc_device( + padded_total_nelems + total_nelems, exec_q); + + if (nullptr == workspace) { + throw std::runtime_error("Could not allocate workspace on device"); + } + + ValueIndexT *indexed_data_tp = workspace; + ValueIndexT *temp_tp = workspace + padded_total_nelems; + + using Proj = radix_sort_details::ValueProj; + constexpr Proj proj_op{}; + + sycl::event populate_indexed_data_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + populate_indexed_data_for_radix_sort_krn; + + cgh.parallel_for( + sycl::range<1>(total_nelems), [=](sycl::id<1> id) { + size_t i = id[0]; + IndexTy sort_id = static_cast(i % sort_nelems); + indexed_data_tp[i] = std::make_pair(arg_tp[i], sort_id); + }); + }); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, indexed_data_tp, temp_tp, proj_op, + sort_ascending, {populate_indexed_data_ev}); + + sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(radix_sort_ev); + + using KernelName = index_write_out_for_radix_sort_krn; + + cgh.parallel_for( + sycl::range<1>(total_nelems), + [=](sycl::id<1> id) { res_tp[id] = std::get<1>(temp_tp[id]); }); + }); + + sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(write_out_ev); + + const sycl::context &ctx = exec_q.get_context(); + + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); + }); + + return cleanup_ev; 
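// The populate and write-out kernels above mirror a simple host-side recipe:
// pair every element with its position inside its row, sort the pairs by
// value only (the device code does this with the radix kernel and ValueProj;
// std::stable_sort is used here purely as a stand-in), then copy out the
// indices. A minimal sketch under those assumptions; host_argsort_by_pairs
// is an illustrative name, not a dpctl function.
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

template <typename T>
std::vector<std::size_t> host_argsort_by_pairs(const std::vector<T> &values,
                                               std::size_t iter_nelems,
                                               std::size_t sort_nelems)
{
    // pair each element with its position within its own row
    std::vector<std::pair<T, std::size_t>> indexed(values.size());
    for (std::size_t i = 0; i < values.size(); ++i) {
        indexed[i] = {values[i], i % sort_nelems};
    }
    // sort every row segment independently, comparing values only;
    // stability preserves the original order of equal values, as radix does
    for (std::size_t row = 0; row < iter_nelems; ++row) {
        auto first = indexed.begin() + row * sort_nelems;
        std::stable_sort(first, first + sort_nelems,
                         [](const auto &a, const auto &b) {
                             return a.first < b.first;
                         });
    }
    // the second member of each sorted pair is the argsort result
    std::vector<std::size_t> result(values.size());
    for (std::size_t i = 0; i < values.size(); ++i) {
        result[i] = indexed[i].second;
    }
    return result;
}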
+} + +template class iota_for_radix_sort_krn; + +template +sycl::event +radix_argsort_axis1_contig_alt_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of + // rows in a matrix when sorting over rows) + size_t iter_nelems, + // size of each array to sort (length of + // rows, i.e. number of columns) + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + IndexTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + const std::size_t total_nelems = iter_nelems * sort_nelems; + const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; + IndexTy *workspace = sycl::malloc_device( + padded_total_nelems + total_nelems, exec_q); + + if (nullptr == workspace) { + throw std::runtime_error("Could not allocate workspace on device"); + } + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp, IdentityProjT{}}; + + sycl::event iota_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = iota_for_radix_sort_krn; + + cgh.parallel_for( + sycl::range<1>(total_nelems), [=](sycl::id<1> id) { + size_t i = id[0]; + IndexTy sort_id = static_cast(i); + workspace[i] = sort_id; + }); + }); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, + sort_ascending, {iota_ev}); + + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(radix_sort_ev); + + using KernelName = index_write_out_for_radix_sort_krn; + + cgh.parallel_for( + sycl::range<1>(total_nelems), [=](sycl::id<1> id) { + IndexTy linear_index = res_tp[id]; + res_tp[id] = (linear_index % sort_nelems); + }); + }); + + sycl::event cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(map_back_ev); + + const sycl::context &ctx = exec_q.get_context(); + + using dpctl::tensor::alloc_utils::sycl_free_noexcept; + cgh.host_task([ctx, workspace] { sycl_free_noexcept(workspace, ctx); }); + }); + + return cleanup_ev; +} + +} // end of namespace kernels +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_detail.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp similarity index 98% rename from dpctl/tensor/libtensor/include/kernels/sorting/sort_detail.hpp rename to dpctl/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp index b286f04dfe..8d8b080ce5 100644 --- a/dpctl/tensor/libtensor/include/kernels/sorting/sort_detail.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp @@ -35,7 +35,7 @@ namespace tensor namespace kernels { -namespace sort_detail +namespace search_sorted_detail { template T quotient_ceil(T n, T m) { return (n + m - 1) / m; } @@ -111,7 +111,7 @@ std::size_t upper_bound_indexed_impl(const Acc acc, acc_indexer); } -} // namespace sort_detail +} // namespace search_sorted_detail } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/searchsorted.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/searchsorted.hpp index 4c1f5c5c93..494d5d4f10 100644 --- 
a/dpctl/tensor/libtensor/include/kernels/sorting/searchsorted.hpp +++ b/dpctl/tensor/libtensor/include/kernels/sorting/searchsorted.hpp @@ -31,7 +31,7 @@ #include #include "kernels/dpctl_tensor_types.hpp" -#include "kernels/sorting/sort_detail.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" #include "utils/offset_utils.hpp" namespace dpctl @@ -91,7 +91,7 @@ struct SearchSortedFunctor // lower_bound returns the first pos such that bool(hay[pos] < // needle_v) is false, i.e. needle_v <= hay[pos] - pos = sort_detail::lower_bound_indexed_impl( + pos = search_sorted_detail::lower_bound_indexed_impl( hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); } else { @@ -100,7 +100,7 @@ struct SearchSortedFunctor // upper_bound returns the first pos such that bool(needle_v < // hay[pos]) is true, i.e. needle_v < hay[pos] - pos = sort_detail::upper_bound_indexed_impl( + pos = search_sorted_detail::upper_bound_indexed_impl( hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); } diff --git a/dpctl/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp b/dpctl/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp new file mode 100644 index 0000000000..c9868093c5 --- /dev/null +++ b/dpctl/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp @@ -0,0 +1,50 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace dpctl +{ +namespace tensor +{ +namespace kernels +{ + +typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &, + size_t, + size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +} +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/merge_argsort.cpp b/dpctl/tensor/libtensor/source/sorting/merge_argsort.cpp new file mode 100644 index 0000000000..15b59f0368 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/merge_argsort.cpp @@ -0,0 +1,150 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
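// sort_contig_fn_ptr_t (declared in sort_impl_fn_ptr_t.hpp above) is stored
// in per-dtype dispatch tables where a nullptr slot means no kernel was
// instantiated for that type, so the Python binding raises instead. A
// reduced sketch of that lookup with a simplified signature; type_id,
// sort_fn_t and dispatch_sort are illustrative names only.
#include <array>
#include <cstddef>
#include <stdexcept>

enum class type_id : std::size_t { int32 = 0, int64, float32, count };

using sort_fn_t = void (*)(const char *src, char *dst, std::size_t n);

// filled once at module initialization, one slot per dtype
static std::array<sort_fn_t, static_cast<std::size_t>(type_id::count)>
    sort_dispatch_vector{};

void dispatch_sort(type_id tid, const char *src, char *dst, std::size_t n)
{
    sort_fn_t fn = sort_dispatch_vector[static_cast<std::size_t>(tid)];
    if (fn == nullptr) {
        // corresponds to the py::value_error raised for unsupported dtypes
        throw std::runtime_error("Not implemented for dtype of input array");
    }
    fn(src, dst, n);
}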
+// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "utils/math_utils.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" +#include "rich_comparisons.hpp" + +#include "merge_argsort.hpp" +#include "py_argsort_common.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct AscendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using Comp = typename DescendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +void init_merge_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder< + sort_contig_fn_ptr_t, DescendingArgSortContigFactory, td_ns::num_types> + dtb2; + dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table); +} + +void init_merge_argsort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_merge_argsort_dispatch_tables(); + + auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_argsort_contig_dispatch_table); + }; + m.def("_argsort_ascending", py_argsort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_argsort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + descending_argsort_contig_dispatch_table); + }; + m.def("_argsort_descending", py_argsort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // end of namespace py_internal +} // end of namespace tensor +} // end of namespace dpctl diff --git 
a/dpctl/tensor/libtensor/source/sorting/merge_argsort.hpp b/dpctl/tensor/libtensor/source/sorting/merge_argsort.hpp new file mode 100644 index 0000000000..d85cabcd85 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/merge_argsort.hpp @@ -0,0 +1,42 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_merge_argsort_functions(py::module_); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/merge_sort.cpp b/dpctl/tensor/libtensor/source/sorting/merge_sort.cpp new file mode 100644 index 0000000000..e3773510e9 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/merge_sort.cpp @@ -0,0 +1,131 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
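// The dispatch tables in merge_argsort.cpp above are populated by factory
// templates whose get() returns either a typed kernel or nullptr, guarded by
// an if-constexpr support check. The same mechanism in a self-contained
// form; KernelFactory, typed_kernel and populate are illustrative names.
#include <cstdint>
#include <type_traits>

using kernel_fn_t = void (*)();

template <typename T> void typed_kernel()
{
    // a real factory would return a kernel instantiated for T
}

template <typename T> struct KernelFactory
{
    kernel_fn_t get()
    {
        // in this sketch only 32- and 64-bit signed integers are "supported"
        if constexpr (std::is_same_v<T, std::int32_t> ||
                      std::is_same_v<T, std::int64_t>)
        {
            return typed_kernel<T>;
        }
        else {
            return nullptr;
        }
    }
};

// analogous to DispatchVectorBuilder/DispatchTableBuilder population
inline void populate(kernel_fn_t (&vec)[3])
{
    vec[0] = KernelFactory<std::int32_t>{}.get();
    vec[1] = KernelFactory<std::int64_t>{}.get();
    vec[2] = KernelFactory<float>{}.get(); // nullptr: float not in the set
}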
+//===--------------------------------------------------------------------===// + +#include + +#include "dpctl4pybind11.hpp" +#include +#include + +#include "utils/math_utils.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "merge_sort.hpp" +#include "py_sort_common.hpp" +#include "rich_comparisons.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_sort_contig_dispatch_vector[td_ns::num_types]; + +template struct AscendingSortContigFactory +{ + fnT get() + { + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +template struct DescendingSortContigFactory +{ + fnT get() + { + using Comp = typename DescendingSorter::type; + using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +void init_merge_sort_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder + dtv1; + dtv1.populate_dispatch_vector(ascending_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_sort_contig_dispatch_vector); +} + +void init_merge_sort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_merge_sort_dispatch_vectors(); + + auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal::ascending_sort_contig_dispatch_vector); + }; + m.def("_sort_ascending", py_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal::descending_sort_contig_dispatch_vector); + }; + m.def("_sort_descending", py_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // end of namespace py_internal +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/sort.hpp b/dpctl/tensor/libtensor/source/sorting/merge_sort.hpp similarity index 95% rename from dpctl/tensor/libtensor/source/sorting/sort.hpp rename to dpctl/tensor/libtensor/source/sorting/merge_sort.hpp index 2d25116dc6..2c5f43aa78 100644 --- a/dpctl/tensor/libtensor/source/sorting/sort.hpp +++ b/dpctl/tensor/libtensor/source/sorting/merge_sort.hpp @@ -35,7 +35,7 @@ namespace tensor namespace py_internal { -extern void init_sort_functions(py::module_); +extern void init_merge_sort_functions(py::module_); } // 
namespace py_internal } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/sorting/argsort.cpp b/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp similarity index 55% rename from dpctl/tensor/libtensor/source/sorting/argsort.cpp rename to dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp index b5a052ef94..cae18aed25 100644 --- a/dpctl/tensor/libtensor/source/sorting/argsort.cpp +++ b/dpctl/tensor/libtensor/source/sorting/py_argsort_common.hpp @@ -32,10 +32,6 @@ #include "utils/output_validation.hpp" #include "utils/type_dispatch.hpp" -#include "argsort.hpp" -#include "kernels/sorting/sort.hpp" -#include "rich_comparisons.hpp" - namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl @@ -52,7 +48,7 @@ py_argsort(const dpctl::tensor::usm_ndarray &src, const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, const std::vector &depends, - const sorting_contig_impl_fnT &stable_sort_contig_fns) + const sorting_contig_impl_fnT &sort_contig_fns) { int src_nd = src.get_ndim(); int dst_nd = dst.get_ndim(); @@ -131,10 +127,10 @@ py_argsort(const dpctl::tensor::usm_ndarray &src, if (is_src_c_contig && is_dst_c_contig) { static constexpr py::ssize_t zero_offset = py::ssize_t(0); - auto fn = stable_sort_contig_fns[src_typeid][dst_typeid]; + auto fn = sort_contig_fns[src_typeid][dst_typeid]; if (fn == nullptr) { - throw py::value_error("Not implemented for given index type"); + throw py::value_error("Not implemented for dtypes of input arrays"); } sycl::event comp_ev = @@ -151,103 +147,6 @@ py_argsort(const dpctl::tensor::usm_ndarray &src, "Both source and destination arrays must be C-contiguous"); } -using dpctl::tensor::kernels::sort_contig_fn_ptr_t; -static sort_contig_fn_ptr_t - ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; -static sort_contig_fn_ptr_t - descending_argsort_contig_dispatch_table[td_ns::num_types] - [td_ns::num_types]; - -template -struct AscendingArgSortContigFactory -{ - fnT get() - { - if constexpr (std::is_same_v || - std::is_same_v) - { - using Comp = typename AscendingSorter::type; - - using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; - return stable_argsort_axis1_contig_impl; - } - else { - return nullptr; - } - } -}; - -template -struct DescendingArgSortContigFactory -{ - fnT get() - { - if constexpr (std::is_same_v || - std::is_same_v) - { - using Comp = typename DescendingSorter::type; - - using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; - return stable_argsort_axis1_contig_impl; - } - else { - return nullptr; - } - } -}; - -void init_argsort_dispatch_tables(void) -{ - using dpctl::tensor::kernels::sort_contig_fn_ptr_t; - - td_ns::DispatchTableBuilder - dtb1; - dtb1.populate_dispatch_table(ascending_argsort_contig_dispatch_table); - - td_ns::DispatchTableBuilder< - sort_contig_fn_ptr_t, DescendingArgSortContigFactory, td_ns::num_types> - dtb2; - dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table); -} - -void init_argsort_functions(py::module_ m) -{ - dpctl::tensor::py_internal::init_argsort_dispatch_tables(); - - auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src, - const int trailing_dims_to_sort, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends) - -> std::pair { - return dpctl::tensor::py_internal::py_argsort( - src, trailing_dims_to_sort, dst, exec_q, depends, - dpctl::tensor::py_internal:: - ascending_argsort_contig_dispatch_table); - }; - m.def("_argsort_ascending", py_argsort_ascending, 
py::arg("src"), - py::arg("trailing_dims_to_sort"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto py_argsort_descending = [](const dpctl::tensor::usm_ndarray &src, - const int trailing_dims_to_sort, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends) - -> std::pair { - return dpctl::tensor::py_internal::py_argsort( - src, trailing_dims_to_sort, dst, exec_q, depends, - dpctl::tensor::py_internal:: - descending_argsort_contig_dispatch_table); - }; - m.def("_argsort_descending", py_argsort_descending, py::arg("src"), - py::arg("trailing_dims_to_sort"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - return; -} - } // end of namespace py_internal } // end of namespace tensor } // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/sort.cpp b/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp similarity index 59% rename from dpctl/tensor/libtensor/source/sorting/sort.cpp rename to dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp index d79555d602..d261adb352 100644 --- a/dpctl/tensor/libtensor/source/sorting/sort.cpp +++ b/dpctl/tensor/libtensor/source/sorting/py_sort_common.hpp @@ -22,6 +22,8 @@ /// extension. //===--------------------------------------------------------------------===// +#pragma once + #include #include "dpctl4pybind11.hpp" @@ -33,10 +35,6 @@ #include "utils/output_validation.hpp" #include "utils/type_dispatch.hpp" -#include "kernels/sorting/sort.hpp" -#include "rich_comparisons.hpp" -#include "sort.hpp" - namespace td_ns = dpctl::tensor::type_dispatch; namespace dpctl @@ -53,7 +51,7 @@ py_sort(const dpctl::tensor::usm_ndarray &src, const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, const std::vector &depends, - const sorting_contig_impl_fnT &stable_sort_contig_fns) + const sorting_contig_impl_fnT &sort_contig_fns) { int src_nd = src.get_ndim(); int dst_nd = dst.get_ndim(); @@ -130,7 +128,12 @@ py_sort(const dpctl::tensor::usm_ndarray &src, if (is_src_c_contig && is_dst_c_contig) { constexpr py::ssize_t zero_offset = py::ssize_t(0); - auto fn = stable_sort_contig_fns[src_typeid]; + auto fn = sort_contig_fns[src_typeid]; + + if (nullptr == fn) { + throw py::value_error( + "Not implemented for the dtype of input arrays"); + } sycl::event comp_ev = fn(exec_q, iter_nelems, sort_nelems, src.get_data(), dst.get_data(), @@ -146,83 +149,6 @@ py_sort(const dpctl::tensor::usm_ndarray &src, "Both source and destination arrays must be C-contiguous"); } -using dpctl::tensor::kernels::sort_contig_fn_ptr_t; -static sort_contig_fn_ptr_t - ascending_sort_contig_dispatch_vector[td_ns::num_types]; -static sort_contig_fn_ptr_t - descending_sort_contig_dispatch_vector[td_ns::num_types]; - -template struct AscendingSortContigFactory -{ - fnT get() - { - using Comp = typename AscendingSorter::type; - - using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; - return stable_sort_axis1_contig_impl; - } -}; - -template struct DescendingSortContigFactory -{ - fnT get() - { - using Comp = typename DescendingSorter::type; - using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; - return stable_sort_axis1_contig_impl; - } -}; - -void init_sort_dispatch_vectors(void) -{ - using dpctl::tensor::kernels::sort_contig_fn_ptr_t; - - td_ns::DispatchVectorBuilder - dtv1; - dtv1.populate_dispatch_vector(ascending_sort_contig_dispatch_vector); - - td_ns::DispatchVectorBuilder - dtv2; - 
dtv2.populate_dispatch_vector(descending_sort_contig_dispatch_vector); -} - -void init_sort_functions(py::module_ m) -{ - dpctl::tensor::py_internal::init_sort_dispatch_vectors(); - - auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, - const int trailing_dims_to_sort, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends) - -> std::pair { - return dpctl::tensor::py_internal::py_sort( - src, trailing_dims_to_sort, dst, exec_q, depends, - dpctl::tensor::py_internal::ascending_sort_contig_dispatch_vector); - }; - m.def("_sort_ascending", py_sort_ascending, py::arg("src"), - py::arg("trailing_dims_to_sort"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - auto py_sort_descending = [](const dpctl::tensor::usm_ndarray &src, - const int trailing_dims_to_sort, - const dpctl::tensor::usm_ndarray &dst, - sycl::queue &exec_q, - const std::vector &depends) - -> std::pair { - return dpctl::tensor::py_internal::py_sort( - src, trailing_dims_to_sort, dst, exec_q, depends, - dpctl::tensor::py_internal::descending_sort_contig_dispatch_vector); - }; - m.def("_sort_descending", py_sort_descending, py::arg("src"), - py::arg("trailing_dims_to_sort"), py::arg("dst"), - py::arg("sycl_queue"), py::arg("depends") = py::list()); - - return; -} - } // end of namespace py_internal } // end of namespace tensor } // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp new file mode 100644 index 0000000000..a98e5677b2 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/radix_argsort.cpp @@ -0,0 +1,186 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
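// The init_sort_functions hook being removed above (and its merge/radix
// replacements) all follow the same wiring: each backend exposes an
// init_*_functions(py::module_) entry point and the extension module calls
// them in turn. A minimal pybind11 sketch of that layout; the module name
// _example_sorting_impl and the _noop binding are purely illustrative.
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace example
{
void init_sort_functions(py::module_ m)
{
    // a real hook would register _sort_ascending / _sort_descending here
    m.def("_noop", []() { return 0; });
}
} // namespace example

PYBIND11_MODULE(_example_sorting_impl, m)
{
    example::init_sort_functions(m);
}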
+//===--------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpctl4pybind11.hpp" +#include +#include + +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_argsort_common.hpp" +#include "radix_argsort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + +static sort_contig_fn_ptr_t + ascending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +namespace +{ + +template +sycl::event argsort_axis1_contig_caller(sycl::queue &q, + size_t iter_nelems, + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_argsort_axis1_contig_alt_impl; + + return radix_argsort_axis1_contig_alt_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +template +struct AscendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ true, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ false, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +void init_radix_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_radix_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + descending_radix_argsort_contig_dispatch_table); +} + +void init_radix_argsort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_radix_argsort_dispatch_tables(); + + auto py_radix_argsort_ascending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_ascending", py_radix_argsort_ascending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_radix_argsort_descending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector 
&depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + descending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_descending", py_radix_argsort_descending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace py_internal +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/radix_argsort.hpp b/dpctl/tensor/libtensor/source/sorting/radix_argsort.hpp new file mode 100644 index 0000000000..131c5ea048 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/radix_argsort.hpp @@ -0,0 +1,42 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +extern void init_radix_argsort_functions(py::module_); + +} // namespace py_internal +} // namespace tensor +} // namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp new file mode 100644 index 0000000000..09eb75d1f1 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/radix_sort.cpp @@ -0,0 +1,186 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
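// argsort_axis1_contig_caller above exists so that kernels parameterized by
// a compile-time ascending/descending flag can still be stored behind one
// uniform function-pointer signature in the dispatch tables. A reduced
// version of that adapter; sort_impl, sort_caller and the int element type
// are illustrative simplifications.
#include <algorithm>
#include <cstddef>
#include <functional>

template <bool ascending> void sort_impl(int *data, std::size_t n)
{
    if constexpr (ascending) {
        std::sort(data, data + n);
    }
    else {
        std::sort(data, data + n, std::greater<int>{});
    }
}

// uniform signature suitable for a runtime dispatch table entry
using sort_caller_t = void (*)(int *, std::size_t);

template <bool ascending> void sort_caller(int *data, std::size_t n)
{
    sort_impl<ascending>(data, n);
}

// the ascending and descending tables simply store different instantiations
constexpr sort_caller_t ascending_entry = sort_caller<true>;
constexpr sort_caller_t descending_entry = sort_caller<false>;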
+//===--------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpctl4pybind11.hpp" +#include +#include + +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_sort_common.hpp" +#include "radix_sort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_radix_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_sort_contig_dispatch_vector[td_ns::num_types]; + +namespace +{ + +template +sycl::event sort_axis1_contig_caller(sycl::queue &q, + size_t iter_nelems, + size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_sort_axis1_contig_impl; + + return radix_sort_axis1_contig_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +template struct AscendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +template struct DescendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +void init_radix_sort_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder< + sort_contig_fn_ptr_t, AscendingRadixSortContigFactory, td_ns::num_types> + dtv1; + dtv1.populate_dispatch_vector(ascending_radix_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_radix_sort_contig_dispatch_vector); +} + +bool py_radix_sort_defined(int typenum) +{ + const auto &array_types = td_ns::usm_ndarray_types(); + + try { + int type_id = array_types.typenum_to_lookup_id(typenum); + return (nullptr != + ascending_radix_sort_contig_dispatch_vector[type_id]); + } catch (const std::exception &e) { + return false; + } +} + +void init_radix_sort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_radix_sort_dispatch_vectors(); + + auto py_radix_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_ascending", py_radix_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_radix_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const 
dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + descending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_descending", py_radix_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + m.def("_radix_sort_dtype_supported", py_radix_sort_defined); + + return; +} + +} // namespace py_internal +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/sorting/argsort.hpp b/dpctl/tensor/libtensor/source/sorting/radix_sort.hpp similarity index 95% rename from dpctl/tensor/libtensor/source/sorting/argsort.hpp rename to dpctl/tensor/libtensor/source/sorting/radix_sort.hpp index 6802ccc311..3f535f40fe 100644 --- a/dpctl/tensor/libtensor/source/sorting/argsort.hpp +++ b/dpctl/tensor/libtensor/source/sorting/radix_sort.hpp @@ -35,7 +35,7 @@ namespace tensor namespace py_internal { -extern void init_argsort_functions(py::module_); +extern void init_radix_sort_functions(py::module_); } // namespace py_internal } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/sorting/radix_sort_support.hpp b/dpctl/tensor/libtensor/source/sorting/radix_sort_support.hpp new file mode 100644 index 0000000000..9e42669b96 --- /dev/null +++ b/dpctl/tensor/libtensor/source/sorting/radix_sort_support.hpp @@ -0,0 +1,71 @@ +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===--------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===--------------------------------------------------------------------===// + +#pragma once + +#include + +#include + +namespace dpctl +{ +namespace tensor +{ +namespace py_internal +{ + +template +struct TypeDefinedEntry : std::bool_constant> +{ + static constexpr bool is_defined = true; +}; + +struct NotDefinedEntry : std::true_type +{ + static constexpr bool is_defined = false; +}; + +template struct RadixSortSupportVector +{ + using resolver_t = + typename std::disjunction, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + NotDefinedEntry>; + + static constexpr bool is_defined = resolver_t::is_defined; +}; + +} // end of namespace py_internal +} // end of namespace tensor +} // end of namespace dpctl diff --git a/dpctl/tensor/libtensor/source/tensor_sorting.cpp b/dpctl/tensor/libtensor/source/tensor_sorting.cpp index 6f2f965285..52d3ab67b4 100644 --- a/dpctl/tensor/libtensor/source/tensor_sorting.cpp +++ b/dpctl/tensor/libtensor/source/tensor_sorting.cpp @@ -25,15 +25,15 @@ #include -#include "sorting/argsort.hpp" +#include "sorting/merge_argsort.hpp" +#include "sorting/merge_sort.hpp" #include "sorting/searchsorted.hpp" -#include "sorting/sort.hpp" namespace py = pybind11; PYBIND11_MODULE(_tensor_sorting_impl, m) { - dpctl::tensor::py_internal::init_sort_functions(m); - dpctl::tensor::py_internal::init_argsort_functions(m); + dpctl::tensor::py_internal::init_merge_sort_functions(m); + dpctl::tensor::py_internal::init_merge_argsort_functions(m); dpctl::tensor::py_internal::init_searchsorted_functions(m); } diff --git a/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp new file mode 100644 index 0000000000..b5ef49e0ac --- /dev/null +++ b/dpctl/tensor/libtensor/source/tensor_sorting_radix.cpp @@ -0,0 +1,37 @@ +//===-- tensor_sorting.cpp - -----*-C++-*-/===// +// Implementation of _tensor_reductions_impl module +// +// Data Parallel Control (dpctl) +// +// Copyright 2020-2024 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
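// RadixSortSupportVector above is a compile-time whitelist: std::disjunction
// selects the first entry whose condition is true, and the trailing
// NotDefinedEntry (always true, is_defined == false) catches everything
// else. A reduced, standalone version with only a few types; the template
// parameter lists are reconstructed here as an assumption, since the header
// above enumerates the full set of supported dtypes.
#include <cstdint>
#include <type_traits>

template <typename T, typename ArgTy>
struct TypeDefinedEntry : std::bool_constant<std::is_same_v<T, ArgTy>>
{
    static constexpr bool is_defined = true;
};

struct NotDefinedEntry : std::true_type
{
    static constexpr bool is_defined = false;
};

template <typename T> struct SupportVector
{
    using resolver_t = std::disjunction<TypeDefinedEntry<T, std::int32_t>,
                                        TypeDefinedEntry<T, std::uint32_t>,
                                        TypeDefinedEntry<T, float>,
                                        NotDefinedEntry>;

    // resolver_t inherits from the matched entry, exposing its is_defined
    static constexpr bool is_defined = resolver_t::is_defined;
};

static_assert(SupportVector<float>::is_defined);
static_assert(!SupportVector<long double>::is_defined);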
+// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include + +#include "sorting/radix_argsort.hpp" +#include "sorting/radix_sort.hpp" + +namespace py = pybind11; + +PYBIND11_MODULE(_tensor_sorting_radix_impl, m) +{ + dpctl::tensor::py_internal::init_radix_sort_functions(m); + dpctl::tensor::py_internal::init_radix_argsort_functions(m); +} diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index 088780d103..c738a80fef 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -153,81 +153,117 @@ def test_sort_validation(): dpt.sort(dict()) +def test_sort_validation_kind(): + get_queue_or_skip() + + x = dpt.ones(128, dtype="u1") + + with pytest.raises(ValueError): + dpt.sort(x, kind=Ellipsis) + + with pytest.raises(ValueError): + dpt.sort(x, kind="invalid") + + def test_argsort_validation(): with pytest.raises(TypeError): dpt.argsort(dict()) -def test_sort_axis0(): +def test_argsort_validation_kind(): + get_queue_or_skip() + + x = dpt.arange(127, stop=0, step=-1, dtype="i1") + + with pytest.raises(ValueError): + dpt.argsort(x, kind=Ellipsis) + + with pytest.raises(ValueError): + dpt.argsort(x, kind="invalid") + + +_all_kinds = ["stable", "mergesort", "radixsort"] + + +@pytest.mark.parametrize("kind", _all_kinds) +def test_sort_axis0(kind): get_queue_or_skip() n, m = 200, 30 xf = dpt.arange(n * m, 0, step=-1, dtype="i4") x = dpt.reshape(xf, (n, m)) - s = dpt.sort(x, axis=0) + s = dpt.sort(x, axis=0, kind=kind) assert dpt.all(s[:-1, :] <= s[1:, :]) -def test_argsort_axis0(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_argsort_axis0(kind): get_queue_or_skip() n, m = 200, 30 xf = dpt.arange(n * m, 0, step=-1, dtype="i4") x = dpt.reshape(xf, (n, m)) - idx = dpt.argsort(x, axis=0) + idx = dpt.argsort(x, axis=0, kind=kind) s = dpt.take_along_axis(x, idx, axis=0) assert dpt.all(s[:-1, :] <= s[1:, :]) -def test_argsort_axis1(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_argsort_axis1(kind): get_queue_or_skip() n, m = 200, 30 xf = dpt.arange(n * m, 0, step=-1, dtype="i4") x = dpt.reshape(xf, (n, m)) - idx = dpt.argsort(x, axis=1) + idx = dpt.argsort(x, axis=1, kind=kind) s = dpt.take_along_axis(x, idx, axis=1) assert dpt.all(s[:, :-1] <= s[:, 1:]) -def test_sort_strided(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_sort_strided(kind): get_queue_or_skip() x_orig = dpt.arange(100, dtype="i4") x_flipped = dpt.flip(x_orig, axis=0) - s = dpt.sort(x_flipped) + s = dpt.sort(x_flipped, kind=kind) assert dpt.all(s == x_orig) -def test_argsort_strided(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_argsort_strided(kind): get_queue_or_skip() x_orig = dpt.arange(100, dtype="i4") x_flipped = dpt.flip(x_orig, axis=0) - idx = dpt.argsort(x_flipped) + idx = dpt.argsort(x_flipped, kind=kind) s = dpt.take_along_axis(x_flipped, idx, axis=0) assert dpt.all(s == x_orig) -def test_sort_0d_array(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_sort_0d_array(kind): get_queue_or_skip() x = dpt.asarray(1, dtype="i4") - assert dpt.sort(x) == 1 + expected = dpt.asarray(1, dtype="i4") + assert dpt.sort(x, kind=kind) == expected -def test_argsort_0d_array(): +@pytest.mark.parametrize("kind", _all_kinds) +def test_argsort_0d_array(kind): get_queue_or_skip() x = dpt.asarray(1, 
dtype="i4") - assert dpt.argsort(x) == 0 + expected = dpt.asarray(0, dtype="i4") + assert dpt.argsort(x, kind=kind) == expected @pytest.mark.parametrize( @@ -238,14 +274,15 @@ def test_argsort_0d_array(): "f8", ], ) -def test_sort_real_fp_nan(dtype): +@pytest.mark.parametrize("kind", _all_kinds) +def test_sort_real_fp_nan(dtype, kind): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) x = dpt.asarray( [-0.0, 0.1, dpt.nan, 0.0, -0.1, dpt.nan, 0.2, -0.3], dtype=dtype ) - s = dpt.sort(x) + s = dpt.sort(x, kind=kind) expected = dpt.asarray( [-0.3, -0.1, -0.0, 0.0, 0.1, 0.2, dpt.nan, dpt.nan], dtype=dtype @@ -253,7 +290,7 @@ def test_sort_real_fp_nan(dtype): assert dpt.allclose(s, expected, equal_nan=True) - s = dpt.sort(x, descending=True) + s = dpt.sort(x, descending=True, kind=kind) expected = dpt.asarray( [dpt.nan, dpt.nan, 0.2, 0.1, -0.0, 0.0, -0.1, -0.3], dtype=dtype