Skip to content

Commit

Permalink
Improvements to onedpl_test_sort_by_key (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
danhoeflinger authored Aug 26, 2024
1 parent c35096d commit e052e64
Showing 1 changed file with 18 additions and 35 deletions.
53 changes: 18 additions & 35 deletions help_function/src/onedpl_test_sort_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ int main() {
// #14 SORT BY KEY TEST //

{
sycl::queue myQueue;
const int N = 6;
sycl::buffer<int, 1> keys_buf{ sycl::range<1>(N) };
sycl::buffer<int, 1> values_buf{ sycl::range<1>(N) };
Expand All @@ -116,21 +117,21 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();

keys[0] = 1; keys[1] = 4; keys[2] = 2; keys[3] = 8; keys[4] = 5; keys[5] = 7;
values[0] = 'a'; values[1] = 'b'; values[2] = 'c'; values[3] = 'd'; values[4] = 'e';values[5] = 'f';
}

// call algorithm:
dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it + N, values_it);
dpct::sort(oneapi::dpl::execution::make_device_policy<class kernel1>(myQueue), keys_it, keys_it + N, values_it);

// keys is now { 1, 2, 4, 5, 7, 8}
// values is now {'a', 'c', 'b', 'e', 'f', 'd'}
{
test_name = "Regular call to sort";
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().get_host_access();
num_failing += ASSERT_EQUAL(test_name, values[0], 'a');
num_failing += ASSERT_EQUAL(test_name, values[1], 'c');
num_failing += ASSERT_EQUAL(test_name, values[2], 'b');
Expand Down Expand Up @@ -170,7 +171,7 @@ int main() {
auto keys_end = dpct::device_pointer<int>(keysArray + 10);
auto values_begin = dpct::device_pointer<int>(valuesArray);
// call algorithm
dpct::sort(oneapi::dpl::execution::make_device_policy<>(myQueue), keys_begin, keys_end, values_begin);
dpct::sort(oneapi::dpl::execution::make_device_policy<class kernel2>(myQueue), keys_begin, keys_end, values_begin);
}

// copy back
Expand Down Expand Up @@ -201,40 +202,22 @@ int main() {

{
// Test Two, test calls to dpct::sort using device vectors
dpct::device_vector<int> keys_vec(10);
dpct::device_vector<int> values_vec(10);

std::vector<int> keys_data{4, 8, 5, 3, 0, 9, 7, 2, 1, 6};
std::vector<int> values_data{13, 16, 17, 11, 19, 14, 12, 18, 10, 15};

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_vec.data(), keys_data.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_vec.data(), values_data.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();
dpct::device_vector<int> keys_vec(keys_data);
dpct::device_vector<int> values_vec(values_data);

auto keys_it = keys_vec.begin();
auto keys_it_end = keys_vec.end();
auto values_it = values_vec.begin();
{
// call algorithm
dpct::sort(oneapi::dpl::execution::make_device_policy<>(dpct::get_default_queue()), keys_it, keys_it_end, values_it);
dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it_end, values_it);
// keys is now = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
// values is now = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14}
}

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_data.data(), keys_vec.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_data.data(), values_vec.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();

{
int check_keys[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int check_values[10] = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14};
Expand All @@ -243,8 +226,8 @@ int main() {
// check that values and keys are correct

for (int i = 0; i != 10; ++i) {
num_failing += ASSERT_EQUAL(test_name, values_data[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_data[i], check_keys[i]);
num_failing += ASSERT_EQUAL(test_name, values_vec[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_vec[i], check_keys[i]);
}

failed_tests += test_passed(num_failing, test_name);
Expand All @@ -264,8 +247,8 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
keys[0] = 1; keys[1] = 4; keys[2] = 2; keys[3] = 8; keys[4] = 5; keys[5] = 7;
values[0] = 'a'; values[1] = 'b'; values[2] = 'c'; values[3] = 'd'; values[4] = 'e';values[5] = 'f';
}
Expand All @@ -277,7 +260,7 @@ int main() {
// values is now {'a', 'c', 'b', 'e', 'f', 'd'}
{
test_name = "Regular call to stable_sort";
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().get_host_access();

num_failing += ASSERT_EQUAL(test_name, values[0], 'a');
num_failing += ASSERT_EQUAL(test_name, values[1], 'c');
Expand Down Expand Up @@ -357,8 +340,8 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
// keys = {8, 3, 0, 2, 6, 5, 1, 8, 9, 10, 7, 4, 5, 2, 2, 10}
keys[0] = 8; keys[1] = 3; keys[2] = 0; keys[3] = 2; keys[4] = 6; keys[5] = 5;
keys[6] = 1; keys[7] = 8; keys[8] = 9; keys[9] = 10; keys[10] = 7; keys[11] = 4;
Expand All @@ -375,8 +358,8 @@ int main() {
// keys is now = {0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 10}
// values is now = {'k', 'n', 'g', 'j', 'l', 'm', 'p', 'c', 'o', 'd', 'h', 'b', 'f', 'i', 'e', 'a'}
{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
int check_values[16] = {'k', 'n', 'g', 'j', 'l', 'm', 'p', 'c', 'o', 'd', 'h', 'b', 'f', 'i', 'e', 'a'};
int check_keys[16] = {0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 10};
// check that values and keys are correct
Expand Down

0 comments on commit e052e64

Please sign in to comment.