diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index dc16faa..f931eed 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -5,9 +5,11 @@ #include #include +#include #include #include #include +#include #include #include @@ -560,3 +562,90 @@ TEST_CASE("Initialization function is called") { } CHECK_EQ(counter.load(), 4); } + +TEST_CASE("Check clear_tasks() can be called from a task") { + // Here: + // - we use a barrier to trigger tasks_clear() once all threads are busy; + // - to prevent race conditions (e.g. task_clear() getting called whilst we are still adding + // tasks), we use a mutex to prevent the tasks from running, until all tasks have been added + // to the pool. + + unsigned int thread_count = 0; + + SUBCASE("with single thread") { thread_count = 1; } + SUBCASE("with multiple threads") { thread_count = 4; } + + std::atomic counter = 0; + dp::thread_pool pool(thread_count); + std::shared_mutex mutex; + + { + /* Clear thread_pool when barrier is hit, this must not throw */ + auto clear_func = [&pool]() noexcept { + try { + pool.clear_tasks(); + } catch (...) { + } + }; + std::barrier sync_point(thread_count, clear_func); + + auto func = [&counter, &sync_point, &mutex]() { + std::shared_lock lock(mutex); + counter.fetch_add(1); + sync_point.arrive_and_wait(); + }; + + { + std::unique_lock lock(mutex); + for (int i = 0; i < 10; i++) pool.enqueue_detach(func); + } + + pool.wait_for_tasks(); + } + + CHECK_EQ(counter.load(), thread_count); +} + +TEST_CASE("Check clear_tasks() clears tasks") { + // Here we: + // - add twice as many tasks to the pool as can be run simultaniously + // - use a lock to prevent race conditions (e.g. clear_task() running whilst the another task is + // being added) + + unsigned int thread_count{4}; + size_t cleared_tasks {0}; + std::atomic counter{0}; + + SUBCASE("with no thread") { thread_count = 0; } + SUBCASE("with single thread") { thread_count = 1; } + SUBCASE("with multiple threads") { thread_count = 4; } + + { + std::mutex mutex; + dp::thread_pool pool(thread_count); + + std::function func; + func = [&counter, &mutex]() { + counter.fetch_add(1); + std::lock_guard lock(mutex); + }; + + { + /* fill the thread_pool twice over, and wait until all threads running and locked in a + * task */ + std::lock_guard lock(mutex); + for (unsigned int i = 0; i < 2 * thread_count; i++) pool.enqueue_detach(func); + + while (counter != thread_count) + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + cleared_tasks = pool.clear_tasks(); + } + } + CHECK_EQ(cleared_tasks, static_cast(thread_count)); + CHECK_EQ(thread_count, counter.load()); +} + + + +