Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Add run length encode nontrivial runs API #1083

Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions clang/runtime/dpct-rt/include/dpl_extras/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1746,6 +1746,95 @@ segmented_reduce_argmax(Policy &&policy, Iter1 keys_in, Iter2 keys_out,
policy.queue().wait();
}

template <typename ExecutionPolicy, typename InputIterator,
typename OutputIterator1, typename OutputIterator2,
typename OutputIterator3>
::std::pair<OutputIterator1, OutputIterator2> nontrivial_run_length_encode(
ExecutionPolicy &&policy, InputIterator input_beg, InputIterator input_end,
OutputIterator1 offsets_out, OutputIterator2 lengths_out,
OutputIterator3 num_runs) {
using oneapi::dpl::make_transform_iterator;
using oneapi::dpl::make_zip_iterator;
using offsets_t =
typename ::std::iterator_traits<OutputIterator1>::value_type;
using lengths_t =
typename ::std::iterator_traits<OutputIterator2>::value_type;

auto n = ::std::distance(input_beg, input_end);
// First element must be nontrivial run (start of first segment)
auto first_adj_it = oneapi::dpl::adjacent_find(policy, input_beg, input_end);
auto first_adj_idx = ::std::distance(input_beg, first_adj_it);
if (first_adj_it == input_end) {
::std::fill(policy, num_runs, num_runs + 1, 0);
return ::std::make_pair(offsets_out, lengths_out);
}
auto get_prev_idx_element = [first_adj_idx](const auto &idx) {
auto out_idx = idx + first_adj_idx;
return (out_idx == 0) ? 0 : out_idx - 1;
};
auto get_next_idx_element = [first_adj_idx, n](const auto &idx) {
auto out_idx = idx + first_adj_idx;
return (out_idx == n - 1) ? n - 1 : out_idx + 1;
};
// TODO: Use shifted view to pad range once oneDPL ranges is non-experimental
auto left_shifted_input_beg =
oneapi::dpl::make_permutation_iterator(input_beg, get_prev_idx_element);
auto right_shifted_input_beg =
oneapi::dpl::make_permutation_iterator(input_beg, get_next_idx_element);
// Segment type for ith idx consists of zip of iterators at (i-1, i, i+1)
// padded at the ends
auto zipped_keys_beg = make_zip_iterator(
left_shifted_input_beg, input_beg, right_shifted_input_beg,
oneapi::dpl::counting_iterator<offsets_t>(0));
// Set flag at the beginning of new nontrivial run (ex: (2, 3, 3) -> 1)
auto key_flags_beg =
make_transform_iterator(zipped_keys_beg, [n](const auto &zipped) {
using ::std::get;
bool last_idx_mask = get<3>(zipped) != n - 1;
return (get<0>(zipped) != get<1>(zipped) &&
get<1>(zipped) == get<2>(zipped)) &&
last_idx_mask;
});
auto count_beg = oneapi::dpl::counting_iterator<offsets_t>(0);
auto const_it = dpct::make_constant_iterator(lengths_t(1));
// Check for presence of nontrivial element at current index
auto tr_nontrivial_flags = make_transform_iterator(
make_zip_iterator(left_shifted_input_beg, input_beg),
[](const auto &zip) {
using ::std::get;
return get<0>(zip) == get<1>(zip);
});
auto zipped_vals_beg =
make_zip_iterator(tr_nontrivial_flags, count_beg, const_it);
auto pred = [](bool lhs, bool rhs) { return !rhs; };
auto op = [](auto lhs, const auto &rhs) {
using ::std::get;

// Update length count of run.
// The first call to this op will use the first element of the input as lhs
// and second element as rhs. get<0>(first_element) is ignored in favor of a
// constant `1` in get<2>, avoiding the need for special casing the first
// element. The constant `1` utilizes the knowledge that each segment begins
// with a nontrivial run.
get<2>(lhs) += get<0>(rhs);
danhoeflinger marked this conversation as resolved.
Show resolved Hide resolved

// A run's starting index is stored in get<1>(lhs) as the initial value in
// the segment and is preserved throughout the segment's reduction as the
// nontrivial run's offset.

return ::std::move(lhs);
};
auto zipped_out_beg = make_zip_iterator(oneapi::dpl::discard_iterator(),
offsets_out, lengths_out);
auto [_, zipped_out_vals_end] = oneapi::dpl::reduce_by_segment(
policy, key_flags_beg + first_adj_idx, key_flags_beg + n,
zipped_vals_beg + first_adj_idx, oneapi::dpl::discard_iterator(),
zipped_out_beg, pred, op);
auto ret_dist = ::std::distance(zipped_out_beg, zipped_out_vals_end);
::std::fill(policy, num_runs, num_runs + 1, ret_dist);
return ::std::make_pair(offsets_out + ret_dist, lengths_out + ret_dist);
}

} // end namespace dpct

#endif