Skip to content

Commit

Permalink
feat: add clear_tasks() (#69)
Browse files Browse the repository at this point in the history
* feat: add `clear_tasks()`

* test: add tests for `thread_safe_queue::clear()`

* test: add tests for `thread_pool::clear_tasks()`
  • Loading branch information
samangh authored Aug 15, 2024
1 parent 8bd7661 commit 9b0557d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
16 changes: 16 additions & 0 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,22 @@ namespace dp {
}
}

/**
* @brief Makes best-case attempt to clear all tasks from the thread_pool
* @details Note that this does not guarantee that all tasks will be cleared, as currently
* running tasks could add additional tasks. Also a thread could steal a task from another
* in the middle of this.
* @return number of tasks cleared
*/
size_t clear_tasks() {
size_t removed_task_count{0};
for (auto &task_list : tasks_) removed_task_count += task_list.tasks.clear();
in_flight_tasks_.fetch_sub(removed_task_count, std::memory_order_release);
unassigned_tasks_.fetch_sub(removed_task_count, std::memory_order_release);

return removed_task_count;
}

private:
template <typename Function>
void enqueue_task(Function &&f) {
Expand Down
8 changes: 8 additions & 0 deletions include/thread_pool/thread_safe_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ namespace dp {
return data_.empty();
}

size_type clear() {
std::scoped_lock lock(mutex_);
auto size = data_.size();
data_.clear();

return size;
}

[[nodiscard]] std::optional<T> pop_front() {
std::scoped_lock lock(mutex_);
if (data_.empty()) return std::nullopt;
Expand Down
85 changes: 85 additions & 0 deletions test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

#include <algorithm>
#include <array>
#include <barrier>
#include <iostream>
#include <numeric>
#include <random>
#include <shared_mutex>
#include <string>
#include <thread>

Expand Down Expand Up @@ -560,3 +562,86 @@ TEST_CASE("Initialization function is called") {
}
CHECK_EQ(counter.load(), 4);
}

TEST_CASE("Check clear_tasks() can be called from a task") {
// Here:
// - we use a barrier to trigger tasks_clear() once all threads are busy;
// - to prevent race conditions (e.g. task_clear() getting called whilst we are still adding
// tasks), we use a mutex to prevent the tasks from running, until all tasks have been added
// to the pool.

unsigned int thread_count = 0;

SUBCASE("with single thread") { thread_count = 1; }
SUBCASE("with multiple threads") { thread_count = 4; }

std::atomic<unsigned int> counter = 0;
dp::thread_pool pool(thread_count);
std::shared_mutex mutex;

{
/* Clear thread_pool when barrier is hit, this must not throw */
auto clear_func = [&pool]() noexcept {
try {
pool.clear_tasks();
} catch (...) {
}
};
std::barrier sync_point(thread_count, clear_func);

auto func = [&counter, &sync_point, &mutex]() {
std::shared_lock lock(mutex);
counter.fetch_add(1);
sync_point.arrive_and_wait();
};

{
std::unique_lock lock(mutex);
for (int i = 0; i < 10; i++) pool.enqueue_detach(func);
}

pool.wait_for_tasks();
}

CHECK_EQ(counter.load(), thread_count);
}

TEST_CASE("Check clear_tasks() clears tasks") {
// Here we:
// - add twice as many tasks to the pool as can be run simultaniously
// - use a lock to prevent race conditions (e.g. clear_task() running whilst the another task is
// being added)

unsigned int thread_count{4};
size_t cleared_tasks{0};
std::atomic<unsigned int> counter{0};

SUBCASE("with no thread") { thread_count = 0; }
SUBCASE("with single thread") { thread_count = 1; }
SUBCASE("with multiple threads") { thread_count = 4; }

{
std::mutex mutex;
dp::thread_pool pool(thread_count);

std::function<void(void)> func;
func = [&counter, &mutex]() {
counter.fetch_add(1);
std::lock_guard lock(mutex);
};

{
/* fill the thread_pool twice over, and wait until all threads running and locked in a
* task */
std::lock_guard lock(mutex);
for (unsigned int i = 0; i < 2 * thread_count; i++) pool.enqueue_detach(func);

while (counter != thread_count)
std::this_thread::sleep_for(std::chrono::milliseconds(100));

cleared_tasks = pool.clear_tasks();
}
}
CHECK_EQ(cleared_tasks, static_cast<size_t>(thread_count));
CHECK_EQ(thread_count, counter.load());
}
35 changes: 35 additions & 0 deletions test/source/thread_safe_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,38 @@ TEST_CASE("Ensure insert and pop works with thread contention") {
CHECK_NE(res2, res3);
CHECK_NE(res3, res1);
}

TEST_CASE("Ensure clear() works and returns correct count") {
// create a synchronization barrier to ensure our threads have started before executing code to
// clear the queue

// here, we check that:
// - the queue is cleared
// - that clear() return the correct number

std::barrier barrier(3);
std::atomic<size_t> removed_count{0};

dp::thread_safe_queue<int> queue;
{
std::jthread t1([&queue, &barrier, &removed_count] {
queue.push_front(1);
barrier.arrive_and_wait();
removed_count = queue.clear();
barrier.arrive_and_wait();
});
std::jthread t2([&queue, &barrier] {
queue.push_front(2);
barrier.arrive_and_wait();
barrier.arrive_and_wait();
});
std::jthread t3([&queue, &barrier] {
queue.push_front(3);
barrier.arrive_and_wait();
barrier.arrive_and_wait();
});
}

CHECK(queue.empty());
CHECK_EQ(removed_count, 3);
}

0 comments on commit 9b0557d

Please sign in to comment.