Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/radix sort #1867

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
830f428
Renamed some files in sorting folders, in preparation for introductio…
oleksandr-pavlyk Oct 16, 2024
a6d51af
Add header file implementing radix sort
oleksandr-pavlyk Oct 16, 2024
3517745
Add a comment next to radix_sort_contig_fn_ptr_t type definition
oleksandr-pavlyk Oct 17, 2024
dbe62d0
Added Python API to exercise radix sort functions
oleksandr-pavlyk Oct 17, 2024
6d33867
Add license headers
oleksandr-pavlyk Oct 17, 2024
e274c2d
Rename argument variable from stable_sort_fns to sort_fns
oleksandr-pavlyk Oct 17, 2024
c2f8486
Add Python API to check if radix sort is supported for given dtype
oleksandr-pavlyk Oct 17, 2024
a15e4aa
Add support for kind keyword in sort/argsort
oleksandr-pavlyk Oct 17, 2024
2684651
Parametrize sorting tests by kind
oleksandr-pavlyk Oct 17, 2024
662b46b
Use sycl_free_noexcept
oleksandr-pavlyk Oct 21, 2024
fef0fe4
Remove unused include statement
oleksandr-pavlyk Oct 21, 2024
3c05c1b
Add entry to changelog about radix sort algorithm
oleksandr-pavlyk Oct 22, 2024
bbe1019
Change to pass sorting direction as call argument, not template param…
oleksandr-pavlyk Oct 29, 2024
ec6a930
Moved radix sort Python API to dedicated module, _tensor_sorting_radi…
oleksandr-pavlyk Oct 29, 2024
d63dd70
Merge pull request #1883 from IntelPython/change-descending-from-temp…
oleksandr-pavlyk Oct 29, 2024
446ce05
Address PR feedback
oleksandr-pavlyk Nov 3, 2024
93db58a
Renamed n_values->n_to_sort for readability per review
oleksandr-pavlyk Nov 3, 2024
350738b
Merge remote-tracking branch 'origin/master' into feature/radix-sort
oleksandr-pavlyk Nov 3, 2024
09236c9
Use sycl_free_noexcept instead of sycl::free
oleksandr-pavlyk Nov 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Improved performance of copying operation to C-/F-contig array, with optimization for batch of square matrices [gh-1850](https://github.com/IntelPython/dpctl/pull/1850)
* Improved performance of `tensor.argsort` function for all types [gh-1859](https://github.com/IntelPython/dpctl/pull/1859)
* Improved performance of `tensor.sort` and `tensor.argsort` for short arrays in the range [16, 64] elements [gh-1866](https://github.com/IntelPython/dpctl/pull/1866)
* Implement radix sort algorithm to be used in `dpt.sort` and `dpt.argsort` [gh-1867](https://github.com/IntelPython/dpctl/pull/1867)

### Fixed
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877)
Expand Down
26 changes: 20 additions & 6 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,14 @@ set(_reduction_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_sorting_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
)
set(_sorting_radix_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
)
set(_static_lib_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
)
Expand Down Expand Up @@ -151,6 +155,10 @@ set(_tensor_sorting_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp
${_sorting_sources}
)
set(_tensor_sorting_radix_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting_radix.cpp
${_sorting_radix_sources}
)
set(_linalg_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp
Expand All @@ -160,10 +168,10 @@ set(_tensor_linalg_impl_sources
${_linalg_sources}
)
set(_accumulator_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
)
set(_tensor_accumulation_impl_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
Expand Down Expand Up @@ -205,6 +213,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
list(APPEND _py_trgts ${python_module_name})

set(python_module_name _tensor_sorting_radix_impl)
pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_radix_impl_sources})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_radix_impl_sources})
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
list(APPEND _py_trgts ${python_module_name})

set(python_module_name _tensor_linalg_impl)
pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources})
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources})
Expand Down
86 changes: 81 additions & 5 deletions dpctl/tensor/_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,26 @@
_sort_ascending,
_sort_descending,
)
from ._tensor_sorting_radix_impl import (
_radix_argsort_ascending,
_radix_argsort_descending,
_radix_sort_ascending,
_radix_sort_descending,
_radix_sort_dtype_supported,
)

__all__ = ["sort", "argsort"]


def sort(x, /, *, axis=-1, descending=False, stable=True):
def _get_mergesort_impl_fn(descending):
return _sort_descending if descending else _sort_ascending


def _get_radixsort_impl_fn(descending):
return _radix_sort_descending if descending else _radix_sort_ascending


def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
"""sort(x, axis=-1, descending=False, stable=True)

Returns a sorted copy of an input array `x`.
Expand All @@ -49,7 +64,10 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
relative order of `x` values which compare as equal. If `False`,
the returned array may or may not maintain the relative order of
`x` values which compare as equal. Default: `True`.

kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
Sorting algorithm. The default is `"stable"`, which uses parallel
merge-sort or parallel radix-sort algorithms depending on the
array data type.
Returns:
usm_ndarray:
a sorted array. The returned array has the same data type and
Expand All @@ -74,10 +92,33 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
axis,
]
arr = dpt.permute_dims(x, perm)
if kind is None:
kind = "stable"
if not isinstance(kind, str) or kind not in [
"stable",
"radixsort",
"mergesort",
]:
raise ValueError(
"Unsupported kind value. Expected 'stable', 'mergesort', "
f"or 'radixsort', but got '{kind}'"
)
ndgrigorian marked this conversation as resolved.
Show resolved Hide resolved
if kind == "mergesort":
impl_fn = _get_mergesort_impl_fn(descending)
elif kind == "radixsort":
if _radix_sort_dtype_supported(x.dtype.num):
impl_fn = _get_radixsort_impl_fn(descending)
else:
raise ValueError(f"Radix sort is not supported for {x.dtype}")
else:
dt = x.dtype
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
impl_fn = _get_radixsort_impl_fn(descending)
else:
impl_fn = _get_mergesort_impl_fn(descending)
exec_q = x.sycl_queue
_manager = du.SequentialOrderManager[exec_q]
dep_evs = _manager.submitted_events
impl_fn = _sort_descending if descending else _sort_ascending
if arr.flags.c_contiguous:
res = dpt.empty_like(arr, order="C")
ht_ev, impl_ev = impl_fn(
Expand Down Expand Up @@ -109,7 +150,15 @@ def sort(x, /, *, axis=-1, descending=False, stable=True):
return res


def argsort(x, axis=-1, descending=False, stable=True):
def _get_mergeargsort_impl_fn(descending):
return _argsort_descending if descending else _argsort_ascending


def _get_radixargsort_impl_fn(descending):
return _radix_argsort_descending if descending else _radix_argsort_ascending


def argsort(x, axis=-1, descending=False, stable=True, kind=None):
"""argsort(x, axis=-1, descending=False, stable=True)

Returns the indices that sort an array `x` along a specified axis.
Expand All @@ -129,6 +178,10 @@ def argsort(x, axis=-1, descending=False, stable=True):
relative order of `x` values which compare as equal. If `False`,
the returned array may or may not maintain the relative order of
`x` values which compare as equal. Default: `True`.
kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
Sorting algorithm. The default is `"stable"`, which uses parallel
merge-sort or parallel radix-sort algorithms depending on the
array data type.

Returns:
usm_ndarray:
Expand Down Expand Up @@ -157,10 +210,33 @@ def argsort(x, axis=-1, descending=False, stable=True):
axis,
]
arr = dpt.permute_dims(x, perm)
if kind is None:
kind = "stable"
if not isinstance(kind, str) or kind not in [
"stable",
"radixsort",
"mergesort",
]:
raise ValueError(
"Unsupported kind value. Expected 'stable', 'mergesort', "
f"or 'radixsort', but got '{kind}'"
)
if kind == "mergesort":
impl_fn = _get_mergeargsort_impl_fn(descending)
elif kind == "radixsort":
if _radix_sort_dtype_supported(x.dtype.num):
impl_fn = _get_radixargsort_impl_fn(descending)
else:
raise ValueError(f"Radix sort is not supported for {x.dtype}")
else:
dt = x.dtype
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
impl_fn = _get_radixargsort_impl_fn(descending)
else:
impl_fn = _get_mergeargsort_impl_fn(descending)
exec_q = x.sycl_queue
_manager = du.SequentialOrderManager[exec_q]
dep_evs = _manager.submitted_events
impl_fn = _argsort_descending if descending else _argsort_ascending
index_dt = ti.default_device_index_type(exec_q)
if arr.flags.c_contiguous:
res = dpt.empty_like(arr, dtype=index_dt, order="C")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <vector>

#include "kernels/dpctl_tensor_types.hpp"
#include "kernels/sorting/sort_detail.hpp"
#include "kernels/sorting/search_sorted_detail.hpp"

namespace dpctl
{
Expand All @@ -41,9 +41,11 @@ namespace tensor
namespace kernels
{

namespace sort_detail
namespace merge_sort_detail
{

using namespace dpctl::tensor::kernels::search_sorted_detail;

/*! @brief Merge two contiguous sorted segments */
template <typename InAcc, typename OutAcc, typename Compare>
void merge_impl(const std::size_t offset,
Expand Down Expand Up @@ -699,18 +701,7 @@ merge_sorted_block_contig_impl(sycl::queue &q,
return dep_ev;
}

} // end of namespace sort_detail

typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &,
size_t,
size_t,
const char *,
char *,
ssize_t,
ssize_t,
ssize_t,
ssize_t,
const std::vector<sycl::event> &);
} // end of namespace merge_sort_detail

template <typename argTy, typename Comp = std::less<argTy>>
sycl::event stable_sort_axis1_contig_impl(
Expand Down Expand Up @@ -741,8 +732,8 @@ sycl::event stable_sort_axis1_contig_impl(
if (sort_nelems < sequential_sorting_threshold) {
// equal work-item sorts entire row
sycl::event sequential_sorting_ev =
sort_detail::sort_base_step_contig_impl<const argTy *, argTy *,
Comp>(
merge_sort_detail::sort_base_step_contig_impl<const argTy *,
argTy *, Comp>(
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
sort_nelems, depends);

Expand All @@ -753,16 +744,16 @@ sycl::event stable_sort_axis1_contig_impl(

// Sort segments of the array
sycl::event base_sort_ev =
sort_detail::sort_over_work_group_contig_impl<const argTy *,
argTy *, Comp>(
merge_sort_detail::sort_over_work_group_contig_impl<const argTy *,
argTy *, Comp>(
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
sorted_block_size, // modified in place with size of sorted
// block size
depends);

// Merge segments in parallel until all elements are sorted
sycl::event merges_ev =
sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
merge_sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
exec_q, iter_nelems, sort_nelems, res_tp, comp,
sorted_block_size, {base_sort_ev});

Expand Down Expand Up @@ -816,8 +807,7 @@ sycl::event stable_argsort_axis1_contig_impl(
const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};

static constexpr size_t determine_automatically = 0;
size_t sorted_block_size =
(sort_nelems >= 512) ? 512 : determine_automatically;
size_t sorted_block_size = determine_automatically;

const size_t total_nelems = iter_nelems * sort_nelems;

Expand All @@ -837,21 +827,24 @@ sycl::event stable_argsort_axis1_contig_impl(
});

// Sort segments of the array
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl(
exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
sorted_block_size, // modified in place with size of sorted block size
{populate_indexed_data_ev});
sycl::event base_sort_ev =
merge_sort_detail::sort_over_work_group_contig_impl(
exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
sorted_block_size, // modified in place with size of sorted block
// size
{populate_indexed_data_ev});

// Merge segments in parallel until all elements are sorted
sycl::event merges_ev = sort_detail::merge_sorted_block_contig_impl(
sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size,
{base_sort_ev});

sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(merges_ev);

auto temp_acc =
sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp, cgh);
merge_sort_detail::GetReadOnlyAccess<decltype(res_tp)>{}(res_tp,
cgh);

using KernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;

Expand Down
Loading
Loading