Skip to content

Commit

Permalink
[SYCLomatic] Add run length encode nontrivial runs API (#1083)
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Michel <[email protected]>
  • Loading branch information
mmichel11 authored Jul 24, 2023
1 parent d078ea1 commit 4d5fb54
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions clang/runtime/dpct-rt/include/dpl_extras/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1746,6 +1746,96 @@ segmented_reduce_argmax(Policy &&policy, Iter1 keys_in, Iter2 keys_out,
policy.queue().wait();
}

template <typename ExecutionPolicy, typename InputIterator,
typename OutputIterator1, typename OutputIterator2,
typename OutputIterator3>
void nontrivial_run_length_encode(ExecutionPolicy &&policy,
InputIterator input_beg,
OutputIterator1 offsets_out,
OutputIterator2 lengths_out,
OutputIterator3 num_runs,
::std::int64_t num_items) {
using oneapi::dpl::make_transform_iterator;
using oneapi::dpl::make_zip_iterator;
using offsets_t =
typename ::std::iterator_traits<OutputIterator1>::value_type;
using lengths_t =
typename ::std::iterator_traits<OutputIterator2>::value_type;

auto input_end = input_beg + num_items;
// First element must be nontrivial run (start of first segment)
auto first_adj_it = oneapi::dpl::adjacent_find(policy, input_beg, input_end);
auto first_adj_idx = ::std::distance(input_beg, first_adj_it);
if (first_adj_it == input_end) {
::std::fill(policy, num_runs, num_runs + 1, 0);
return;
}
auto get_prev_idx_element = [first_adj_idx](const auto &idx) {
auto out_idx = idx + first_adj_idx;
return (out_idx == 0) ? 0 : out_idx - 1;
};
auto get_next_idx_element = [first_adj_idx, num_items](const auto &idx) {
auto out_idx = idx + first_adj_idx;
return (out_idx == num_items - 1) ? num_items - 1 : out_idx + 1;
};
// TODO: Use shifted view to pad range once oneDPL ranges is non-experimental
auto left_shifted_input_beg =
oneapi::dpl::make_permutation_iterator(input_beg, get_prev_idx_element);
auto right_shifted_input_beg =
oneapi::dpl::make_permutation_iterator(input_beg, get_next_idx_element);
// Segment type for ith idx consists of zip of iterators at (i-1, i, i+1)
// padded at the ends
auto zipped_keys_beg = make_zip_iterator(
left_shifted_input_beg, input_beg, right_shifted_input_beg,
oneapi::dpl::counting_iterator<offsets_t>(0));
// Set flag at the beginning of new nontrivial run (ex: (2, 3, 3) -> 1)
auto key_flags_beg =
make_transform_iterator(zipped_keys_beg, [num_items](const auto &zipped) {
using ::std::get;
bool last_idx_mask = get<3>(zipped) != num_items - 1;
return (get<0>(zipped) != get<1>(zipped) &&
get<1>(zipped) == get<2>(zipped)) &&
last_idx_mask;
});
auto count_beg = oneapi::dpl::counting_iterator<offsets_t>(0);
auto const_it = dpct::make_constant_iterator(lengths_t(1));
// Check for presence of nontrivial element at current index
auto tr_nontrivial_flags = make_transform_iterator(
make_zip_iterator(left_shifted_input_beg, input_beg),
[](const auto &zip) {
using ::std::get;
return get<0>(zip) == get<1>(zip);
});
auto zipped_vals_beg =
make_zip_iterator(tr_nontrivial_flags, count_beg, const_it);
auto pred = [](bool lhs, bool rhs) { return !rhs; };
auto op = [](auto lhs, const auto &rhs) {
using ::std::get;

// Update length count of run.
// The first call to this op will use the first element of the input as lhs
// and second element as rhs. get<0>(first_element) is ignored in favor of a
// constant `1` in get<2>, avoiding the need for special casing the first
// element. The constant `1` utilizes the knowledge that each segment begins
// with a nontrivial run.
get<2>(lhs) += get<0>(rhs);

// A run's starting index is stored in get<1>(lhs) as the initial value in
// the segment and is preserved throughout the segment's reduction as the
// nontrivial run's offset.

return ::std::move(lhs);
};
auto zipped_out_beg = make_zip_iterator(oneapi::dpl::discard_iterator(),
offsets_out, lengths_out);
auto [_, zipped_out_vals_end] = oneapi::dpl::reduce_by_segment(
policy, key_flags_beg + first_adj_idx, key_flags_beg + num_items,
zipped_vals_beg + first_adj_idx, oneapi::dpl::discard_iterator(),
zipped_out_beg, pred, op);
auto ret_dist = ::std::distance(zipped_out_beg, zipped_out_vals_end);
::std::fill(policy, num_runs, num_runs + 1, ret_dist);
}

} // end namespace dpct

#endif

0 comments on commit 4d5fb54

Please sign in to comment.