From 9b0557d587dc23ccfd73189dda363cb5411a7e79 Mon Sep 17 00:00:00 2001
From: Saman Ghannadzadeh
Date: Thu, 15 Aug 2024 14:37:34 +0100
Subject: [PATCH] feat: add `clear_tasks()` (#69)

* feat: add `clear_tasks()`

* test: add tests for `thread_safe_queue::clear()`

* test: add tests for `thread_pool::clear_tasks()`
---
 include/thread_pool/thread_pool.h       | 16 ++++
 include/thread_pool/thread_safe_queue.h |  8 +++
 test/source/thread_pool.cpp             | 85 +++++++++++++++++++++++++
 test/source/thread_safe_queue.cpp       | 35 ++++++++++
 4 files changed, 144 insertions(+)

diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h
index 48b00e6..14f841e 100644
--- a/include/thread_pool/thread_pool.h
+++ b/include/thread_pool/thread_pool.h
@@ -241,6 +241,22 @@ namespace dp {
             }
         }
 
+        /**
+         * @brief Makes a best-effort attempt to clear all tasks from the thread_pool.
+         * @details Note that this does not guarantee that all tasks will be cleared, as currently
+         * running tasks could add additional tasks, and a thread could steal a task from another
+         * thread's queue while the clear is in progress.
+         * @return The number of tasks cleared.
+         */
+        size_t clear_tasks() {
+            size_t removed_task_count{0};
+            for (auto &task_list : tasks_) removed_task_count += task_list.tasks.clear();
+            in_flight_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
+            unassigned_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
+
+            return removed_task_count;
+        }
+
       private:
         template <typename Function>
         void enqueue_task(Function &&f) {
diff --git a/include/thread_pool/thread_safe_queue.h b/include/thread_pool/thread_safe_queue.h
index db02fa1..9057452 100644
--- a/include/thread_pool/thread_safe_queue.h
+++ b/include/thread_pool/thread_safe_queue.h
@@ -44,6 +44,14 @@ namespace dp {
             return data_.empty();
         }
 
+        size_type clear() {
+            std::scoped_lock lock(mutex_);
+            auto size = data_.size();
+            data_.clear();
+
+            return size;
+        }
+
         [[nodiscard]] std::optional<T> pop_front() {
             std::scoped_lock lock(mutex_);
             if (data_.empty()) return std::nullopt;
diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp
index dc16faa..1cffabb 100644
--- a/test/source/thread_pool.cpp
+++ b/test/source/thread_pool.cpp
@@ -5,9 +5,11 @@
 #include 
 #include 
+#include <barrier>
 #include 
 #include 
 #include 
+#include <shared_mutex>
 #include 
 #include 
@@ -560,3 +562,86 @@ TEST_CASE("Initialization function is called") {
     }
     CHECK_EQ(counter.load(), 4);
 }
+
+TEST_CASE("Check clear_tasks() can be called from a task") {
+    // Here:
+    // - we use a barrier to trigger clear_tasks() once all threads are busy;
+    // - to prevent race conditions (e.g. clear_tasks() getting called whilst we are still adding
+    //   tasks), we use a mutex to keep the tasks from running until all tasks have been added
+    //   to the pool.
+
+    unsigned int thread_count = 0;
+
+    SUBCASE("with single thread") { thread_count = 1; }
+    SUBCASE("with multiple threads") { thread_count = 4; }
+
+    std::atomic counter = 0;
+    dp::thread_pool pool(thread_count);
+    std::shared_mutex mutex;
+
+    {
+        /* Clear the thread_pool when the barrier is hit; this must not throw. */
+        auto clear_func = [&pool]() noexcept {
+            try {
+                pool.clear_tasks();
+            } catch (...) {
+            }
+        };
+        std::barrier sync_point(thread_count, clear_func);
+
+        auto func = [&counter, &sync_point, &mutex]() {
+            std::shared_lock lock(mutex);
+            counter.fetch_add(1);
+            sync_point.arrive_and_wait();
+        };
+
+        {
+            std::unique_lock lock(mutex);
+            for (int i = 0; i < 10; i++) pool.enqueue_detach(func);
+        }
+
+        pool.wait_for_tasks();
+    }
+
+    CHECK_EQ(counter.load(), thread_count);
+}
+
+TEST_CASE("Check clear_tasks() clears tasks") {
+    // Here we:
+    // - add twice as many tasks to the pool as can be run simultaneously;
+    // - use a lock to prevent race conditions (e.g. clear_tasks() running whilst another task is
+    //   being added).
+
+    unsigned int thread_count{4};
+    size_t cleared_tasks{0};
+    std::atomic counter{0};
+
+    SUBCASE("with no thread") { thread_count = 0; }
+    SUBCASE("with single thread") { thread_count = 1; }
+    SUBCASE("with multiple threads") { thread_count = 4; }
+
+    {
+        std::mutex mutex;
+        dp::thread_pool pool(thread_count);
+
+        std::function<void()> func;
+        func = [&counter, &mutex]() {
+            counter.fetch_add(1);
+            std::lock_guard lock(mutex);
+        };
+
+        {
+            /* Fill the thread_pool twice over, then wait until all threads are running and locked
+             * in a task. */
+            std::lock_guard lock(mutex);
+            for (unsigned int i = 0; i < 2 * thread_count; i++) pool.enqueue_detach(func);
+
+            while (counter != thread_count)
+                std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+            cleared_tasks = pool.clear_tasks();
+        }
+    }
+    CHECK_EQ(cleared_tasks, static_cast<size_t>(thread_count));
+    CHECK_EQ(thread_count, counter.load());
+}
diff --git a/test/source/thread_safe_queue.cpp b/test/source/thread_safe_queue.cpp
index a67fed2..07270b0 100644
--- a/test/source/thread_safe_queue.cpp
+++ b/test/source/thread_safe_queue.cpp
@@ -51,3 +51,38 @@ TEST_CASE("Ensure insert and pop works with thread contention") {
     CHECK_NE(res2, res3);
     CHECK_NE(res3, res1);
 }
+
+TEST_CASE("Ensure clear() works and returns correct count") {
+    // Create a synchronization barrier to ensure our threads have started before executing the
+    // code that clears the queue.
+
+    // Here, we check that:
+    // - the queue is cleared;
+    // - clear() returns the correct number of removed items.
+
+    std::barrier barrier(3);
+    std::atomic removed_count{0};
+
+    dp::thread_safe_queue<int> queue;
+    {
+        std::jthread t1([&queue, &barrier, &removed_count] {
+            queue.push_front(1);
+            barrier.arrive_and_wait();
+            removed_count = queue.clear();
+            barrier.arrive_and_wait();
+        });
+        std::jthread t2([&queue, &barrier] {
+            queue.push_front(2);
+            barrier.arrive_and_wait();
+            barrier.arrive_and_wait();
+        });
+        std::jthread t3([&queue, &barrier] {
+            queue.push_front(3);
+            barrier.arrive_and_wait();
+            barrier.arrive_and_wait();
+        });
+    }
+
+    CHECK(queue.empty());
+    CHECK_EQ(removed_count, 3);
+}
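
Below is a minimal usage sketch of the new `clear_tasks()` call. It is illustrative only and not part of the patch; it assumes only the dp::thread_pool members exercised by the tests above (enqueue_detach, clear_tasks, wait_for_tasks), and the task body and counts are made up for the example.

#include <thread_pool/thread_pool.h>

#include <atomic>
#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    std::atomic<int> executed{0};
    dp::thread_pool pool(2);

    // Queue far more work than the pool can start right away.
    for (int i = 0; i < 100; ++i) {
        pool.enqueue_detach([&executed] {
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            executed.fetch_add(1);
        });
    }

    // Best effort: drops tasks that have not started yet; tasks that are
    // already running (or get stolen mid-clear) will still complete.
    const auto cleared = pool.clear_tasks();
    pool.wait_for_tasks();

    std::printf("executed=%d cleared=%zu\n", executed.load(), cleared);
    return 0;
}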