diff --git a/include/oneapi/dpl/pstl/algorithm_impl.h b/include/oneapi/dpl/pstl/algorithm_impl.h index 7f9db008b45..b54b8796641 100644 --- a/include/oneapi/dpl/pstl/algorithm_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_impl.h @@ -3549,8 +3549,10 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _ __result, __comp, [](_DifferenceType __n, _DifferenceType __m) { return ::std::min(__n, __m); }, [](_RandomAccessIterator1 __first1, _RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2, _T* __result, _Compare __comp) { - return oneapi::dpl::__utils::__set_intersection_construct(__first1, __last1, __first2, __last2, - __result, __comp); + return oneapi::dpl::__utils::__set_intersection_construct( + __first1, __last1, __first2, __last2, __result, __comp, + oneapi::dpl::__internal::__op_uninitialized_copy<_ExecutionPolicy>{}, + /*CopyFromFirstSet = */ std::true_type{}); }); }); } @@ -3565,8 +3567,10 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _ __result, __comp, [](_DifferenceType __n, _DifferenceType __m) { return ::std::min(__n, __m); }, [](_RandomAccessIterator1 __first1, _RandomAccessIterator1 __last1, _RandomAccessIterator2 __first2, _RandomAccessIterator2 __last2, _T* __result, _Compare __comp) { - return oneapi::dpl::__utils::__set_intersection_construct(__first2, __last2, __first1, __last1, - __result, __comp); + return oneapi::dpl::__utils::__set_intersection_construct( + __first2, __last2, __first1, __last1, __result, __comp, + oneapi::dpl::__internal::__op_uninitialized_copy<_ExecutionPolicy>{}, + /*CopyFromFirstSet = */ std::false_type{}); }); return __result; }); diff --git a/include/oneapi/dpl/pstl/memory_impl.h b/include/oneapi/dpl/pstl/memory_impl.h index 8555c541399..3a01f298da0 100644 --- a/include/oneapi/dpl/pstl/memory_impl.h +++ b/include/oneapi/dpl/pstl/memory_impl.h @@ -118,9 +118,11 @@ struct __op_uninitialized_copy<_ExecutionPolicy> void operator()(_SourceT&& __source, _TargetT& __target) const { - using _TargetValueType = ::std::decay_t<_TargetT>; - - ::new (::std::addressof(__target)) _TargetValueType(::std::forward<_SourceT>(__source)); + using _TargetValueType = std::decay_t<_TargetT>; + if constexpr (std::is_trivial_v<_TargetValueType>) + __target = std::forward<_SourceT>(__source); + else + ::new (std::addressof(__target)) _TargetValueType(std::forward<_SourceT>(__source)); } }; diff --git a/include/oneapi/dpl/pstl/parallel_backend_utils.h b/include/oneapi/dpl/pstl/parallel_backend_utils.h index 502c5b92bdd..9a0e8fc0fef 100644 --- a/include/oneapi/dpl/pstl/parallel_backend_utils.h +++ b/include/oneapi/dpl/pstl/parallel_backend_utils.h @@ -20,6 +20,7 @@ #include #include #include "utils.h" +#include "memory_fwd.h" namespace oneapi { @@ -203,13 +204,14 @@ __set_union_construct(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _Fo return __cc_range(__first2, __last2, __result); } -template +template _OutputIterator __set_intersection_construct(_ForwardIterator1 __first1, _ForwardIterator1 __last1, _ForwardIterator2 __first2, - _ForwardIterator2 __last2, _OutputIterator __result, _Compare __comp) + _ForwardIterator2 __last2, _OutputIterator __result, _Compare __comp, _CopyFunc _copy, + _CopyFromFirstSet) { using _Tp = typename ::std::iterator_traits<_OutputIterator>::value_type; - for (; __first1 != __last1 && __first2 != __last2;) { if (__comp(*__first1, *__first2)) @@ -218,7 +220,15 @@ __set_intersection_construct(_ForwardIterator1 __first1, _ForwardIterator1 __las { if (!__comp(*__first2, *__first1)) { - ::new (::std::addressof(*__result)) _Tp(*__first1); + + if constexpr (_CopyFromFirstSet::value) + { + _copy(*__first1, *__result); + } + else + { + _copy(*__first2, *__result); + } ++__result; ++__first1; } diff --git a/test/parallel_api/algorithm/alg.sorting/alg.set.operations/set_common.h b/test/parallel_api/algorithm/alg.sorting/alg.set.operations/set_common.h index 35bea647e3b..fd4f9afc8be 100644 --- a/test/parallel_api/algorithm/alg.sorting/alg.set.operations/set_common.h +++ b/test/parallel_api/algorithm/alg.sorting/alg.set.operations/set_common.h @@ -119,6 +119,20 @@ struct test_set_union } }; +// Compare the first of a tuple using the supplied comparator +template +struct comp_select_first +{ + _Comp comp; + comp_select_first(_Comp __comp) : comp(__comp) {} + template + bool + operator()(_T1&& t1, _T2&& t2) const + { + return comp(std::get<0>(std::forward<_T1>(t1)), std::get<0>(std::forward<_T2>(t2))); + } +}; + template struct test_set_intersection { @@ -135,6 +149,36 @@ struct test_set_intersection EXPECT_TRUE(expect_res - expect.begin() == res - out.begin(), "wrong result for set_intersection"); EXPECT_EQ_N(expect.begin(), out.begin(), ::std::distance(out.begin(), res), "wrong set_intersection effect"); + + if constexpr (TestUtils::is_base_of_iterator_category::value && + TestUtils::is_base_of_iterator_category::value) + { + // Check that set_intersection always copies from the first list to result. + // Will fail to compile if the second range is used to copy to the output. + // Comparator is designed to only compare the first element of a zip iterator. + + auto zip_first1 = oneapi::dpl::make_zip_iterator(first1, oneapi::dpl::counting_iterator(0)); + auto zip_last1 = oneapi::dpl::make_zip_iterator( + last1, oneapi::dpl::counting_iterator(std::distance(first1, last1))); + + // Second value should be ignored and discarded in range 2 because the result should be copied from range 1 + auto zip_first2 = oneapi::dpl::make_zip_iterator(first2, oneapi::dpl::discard_iterator()); + auto zip_last2 = oneapi::dpl::make_zip_iterator(last2, oneapi::dpl::discard_iterator()); + + Sequence expect_ints(std::distance(first1, last1) + std::distance(first2, last2)); + Sequence out_ints(std::distance(first1, last1) + std::distance(first2, last2)); + + auto zip_expect = oneapi::dpl::make_zip_iterator(sequences.first.begin(), expect_ints.begin()); + auto zip_out = oneapi::dpl::make_zip_iterator(sequences.second.begin(), out_ints.begin()); + + auto zip_expect_res = std::set_intersection(zip_first1, zip_last1, zip_first2, zip_last2, zip_expect, + comp_select_first(comp)); + auto zip_res = std::set_intersection(exec, zip_first1, zip_last1, zip_first2, zip_last2, zip_out, + comp_select_first(comp)); + EXPECT_TRUE(zip_expect_res - zip_expect == zip_res - zip_out, "wrong result for zipped set_intersection"); + EXPECT_EQ_N(zip_expect, zip_out, std::distance(zip_out, zip_res), "wrong zipped set_intersection effect"); + } } template