Skip to content

Commit

Permalink
Bugfix for set_intersection copying from second iterator range (#983)
Browse files Browse the repository at this point in the history
The standard for set_intersection mentions that "elements will be copied from [first1, last1) to the output range, preserving order."  This PR fixes a bug which allowed the implementation to copy instead from the second set of elements in some cases. 

Signed-off-by: Dan Hoeflinger <[email protected]>
  • Loading branch information
danhoeflinger authored Apr 12, 2024
1 parent 3b561ee commit ecb7cd5
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 11 deletions.
12 changes: 8 additions & 4 deletions include/oneapi/dpl/pstl/algorithm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3549,8 +3549,10 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _
__result, __comp, [](_DifferenceType __n, _DifferenceType __m) { return ::std::min(__n, __m); },
[](_RandomAccessIterator1 __first1, _RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2,
_RandomAccessIterator2 __last2, _T* __result, _Compare __comp) {
return oneapi::dpl::__utils::__set_intersection_construct(__first1, __last1, __first2, __last2,
__result, __comp);
return oneapi::dpl::__utils::__set_intersection_construct(
__first1, __last1, __first2, __last2, __result, __comp,
oneapi::dpl::__internal::__op_uninitialized_copy<_ExecutionPolicy>{},
/*CopyFromFirstSet = */ std::true_type{});
});
});
}
Expand All @@ -3565,8 +3567,10 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _
__result, __comp, [](_DifferenceType __n, _DifferenceType __m) { return ::std::min(__n, __m); },
[](_RandomAccessIterator1 __first1, _RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2,
_RandomAccessIterator2 __last2, _T* __result, _Compare __comp) {
return oneapi::dpl::__utils::__set_intersection_construct(__first2, __last2, __first1, __last1,
__result, __comp);
return oneapi::dpl::__utils::__set_intersection_construct(
__first2, __last2, __first1, __last1, __result, __comp,
oneapi::dpl::__internal::__op_uninitialized_copy<_ExecutionPolicy>{},
/*CopyFromFirstSet = */ std::false_type{});
});
return __result;
});
Expand Down
8 changes: 5 additions & 3 deletions include/oneapi/dpl/pstl/memory_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ struct __op_uninitialized_copy<_ExecutionPolicy>
void
operator()(_SourceT&& __source, _TargetT& __target) const
{
using _TargetValueType = ::std::decay_t<_TargetT>;

::new (::std::addressof(__target)) _TargetValueType(::std::forward<_SourceT>(__source));
using _TargetValueType = std::decay_t<_TargetT>;
if constexpr (std::is_trivial_v<_TargetValueType>)
__target = std::forward<_SourceT>(__source);
else
::new (std::addressof(__target)) _TargetValueType(std::forward<_SourceT>(__source));
}
};

Expand Down
18 changes: 14 additions & 4 deletions include/oneapi/dpl/pstl/parallel_backend_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <utility>
#include <cassert>
#include "utils.h"
#include "memory_fwd.h"

namespace oneapi
{
Expand Down Expand Up @@ -203,13 +204,14 @@ __set_union_construct(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _Fo
return __cc_range(__first2, __last2, __result);
}

template <typename _ForwardIterator1, typename _ForwardIterator2, typename _OutputIterator, typename _Compare>
template <typename _ForwardIterator1, typename _ForwardIterator2, typename _OutputIterator, typename _Compare,
typename _CopyFunc, typename _CopyFromFirstSet>
_OutputIterator
__set_intersection_construct(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2,
_ForwardIterator2 __last2, _OutputIterator __result, _Compare __comp)
_ForwardIterator2 __last2, _OutputIterator __result, _Compare __comp, _CopyFunc _copy,
_CopyFromFirstSet)
{
using _Tp = typename ::std::iterator_traits<_OutputIterator>::value_type;

for (; __first1 != __last1 && __first2 != __last2;)
{
if (__comp(*__first1, *__first2))
Expand All @@ -218,7 +220,15 @@ __set_intersection_construct(_ForwardIterator1 __first1, _ForwardIterator1 __las
{
if (!__comp(*__first2, *__first1))
{
::new (::std::addressof(*__result)) _Tp(*__first1);

if constexpr (_CopyFromFirstSet::value)
{
_copy(*__first1, *__result);
}
else
{
_copy(*__first2, *__result);
}
++__result;
++__first1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ struct test_set_union
}
};

// Compare the first of a tuple using the supplied comparator
template <typename _Comp>
struct comp_select_first
{
_Comp comp;
comp_select_first(_Comp __comp) : comp(__comp) {}
template <typename _T1, typename _T2>
bool
operator()(_T1&& t1, _T2&& t2) const
{
return comp(std::get<0>(std::forward<_T1>(t1)), std::get<0>(std::forward<_T2>(t2)));
}
};

template <typename Type>
struct test_set_intersection
{
Expand All @@ -135,6 +149,36 @@ struct test_set_intersection

EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_intersection");
EXPECT_EQ_N(expect.begin(), out.begin(), ::std::distance(out.begin(), res), "wrong set_intersection effect");

if constexpr (TestUtils::is_base_of_iterator_category<std::random_access_iterator_tag,
InputIterator1>::value &&
TestUtils::is_base_of_iterator_category<std::random_access_iterator_tag, InputIterator2>::value)
{
// Check that set_intersection always copies from the first list to result.
// Will fail to compile if the second range is used to copy to the output.
// Comparator is designed to only compare the first element of a zip iterator.

auto zip_first1 = oneapi::dpl::make_zip_iterator(first1, oneapi::dpl::counting_iterator<int>(0));
auto zip_last1 = oneapi::dpl::make_zip_iterator(
last1, oneapi::dpl::counting_iterator<int>(std::distance(first1, last1)));

// Second value should be ignored and discarded in range 2 because the result should be copied from range 1
auto zip_first2 = oneapi::dpl::make_zip_iterator(first2, oneapi::dpl::discard_iterator());
auto zip_last2 = oneapi::dpl::make_zip_iterator(last2, oneapi::dpl::discard_iterator());

Sequence<int> expect_ints(std::distance(first1, last1) + std::distance(first2, last2));
Sequence<int> out_ints(std::distance(first1, last1) + std::distance(first2, last2));

auto zip_expect = oneapi::dpl::make_zip_iterator(sequences.first.begin(), expect_ints.begin());
auto zip_out = oneapi::dpl::make_zip_iterator(sequences.second.begin(), out_ints.begin());

auto zip_expect_res = std::set_intersection(zip_first1, zip_last1, zip_first2, zip_last2, zip_expect,
comp_select_first(comp));
auto zip_res = std::set_intersection(exec, zip_first1, zip_last1, zip_first2, zip_last2, zip_out,
comp_select_first(comp));
EXPECT_TRUE(zip_expect_res - zip_expect == zip_res - zip_out, "wrong result for zipped set_intersection");
EXPECT_EQ_N(zip_expect, zip_out, std::distance(zip_out, zip_res), "wrong zipped set_intersection effect");
}
}

template <typename Policy, typename InputIterator1, typename InputIterator2>
Expand Down

0 comments on commit ecb7cd5

Please sign in to comment.