Skip to content

Commit

Permalink
Better selfplay scaling
Browse files Browse the repository at this point in the history
Bench: 2266395
  • Loading branch information
SzilBalazs committed Aug 9, 2023
1 parent b97c4ae commit 8c28180
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
31 changes: 19 additions & 12 deletions src/selfplay/selfplay.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace selfplay {
constexpr unsigned int PROGRESS_BAR_WIDTH = 25;
constexpr unsigned int BLOCK_SIZE = 100000;

std::atomic<uint64_t> game_count, position_count;
std::vector<uint64_t> game_count_vec, position_count_vec;

std::optional<GameResult> get_game_result(const core::Board &board) {

Expand All @@ -52,7 +52,7 @@ namespace selfplay {
return std::nullopt;
}

void run_game(Engine &engine, const search::Limits &limits, const std::string &starting_fen, std::vector<DataEntry> &entries, unsigned int hash_size = DEFAULT_HASH_SIZE, const unsigned int thread_count = DEFAULT_THREAD_COUNT) {
void run_game(Engine &engine, const search::Limits &limits, const std::string &starting_fen, std::vector<DataEntry> &entries, size_t thread_id, unsigned int hash_size = DEFAULT_HASH_SIZE, const unsigned int thread_count = DEFAULT_THREAD_COUNT) {
engine.init(hash_size, thread_count);
core::Board board;
board.load(starting_fen);
Expand All @@ -65,7 +65,7 @@ namespace selfplay {

if (!board.is_check() && move.is_quiet() && std::abs(eval) < WORST_MATE) {
tmp.emplace_back(board.get_fen(), ply, move, eval, std::nullopt);
position_count++;
position_count_vec[thread_id]++;
}

board.make_move(move);
Expand All @@ -80,7 +80,7 @@ namespace selfplay {
}
}

void gen_games(const search::Limits &limits, const std::vector<std::string> &starting_fens, const std::string &output_path) {
void gen_games(const search::Limits &limits, const std::vector<std::string> &starting_fens, const std::string &output_path, size_t thread_id) {

Engine engine;

Expand All @@ -90,8 +90,8 @@ namespace selfplay {

std::vector<DataEntry> entries;
for (const std::string &fen : starting_fens) {
run_game(engine, limits, fen, entries);
game_count++;
run_game(engine, limits, fen, entries, thread_id);
game_count_vec[thread_id]++;

if (entries.size() >= BLOCK_SIZE) {
std::shuffle(entries.begin(), entries.end(), g);
Expand Down Expand Up @@ -144,6 +144,9 @@ namespace selfplay {
}

std::string get_run_name(const search::Limits &limits, const std::string &id) {

uint64_t position_count = std::reduce(position_count_vec.begin(), position_count_vec.end());

std::stringstream ss;
ss << id << "_" << limits.to_string() << "_" << (position_count / 1000) << "k";

Expand All @@ -165,11 +168,15 @@ namespace selfplay {
void print_progress(const uint64_t games_to_play) {

const int64_t start_time = now();
uint64_t game_count = 0, position_count = 0;

while (game_count != games_to_play) {

std::this_thread::sleep_for(std::chrono::seconds(1));

game_count = std::reduce(game_count_vec.begin(), game_count_vec.end());
position_count = std::reduce(position_count_vec.begin(), position_count_vec.end());

const int64_t current_time = now();
const int64_t elapsed_time = current_time - start_time + 1;

Expand Down Expand Up @@ -199,10 +206,10 @@ namespace selfplay {
std::cout << std::endl;
}

void start_generation(const search::Limits &limits, const uint64_t games_to_play, const unsigned int thread_count) {
void start_generation(const search::Limits &limits, uint64_t games_to_play, size_t thread_count) {

game_count = 0;
position_count = 0;
game_count_vec.assign(thread_count, 0);
position_count_vec.assign(thread_count, 0);

const std::string run_id = rng::gen_id();
const std::string directory_path = "selfplay/" + run_id;
Expand All @@ -216,12 +223,12 @@ namespace selfplay {

populate_starting_fens(games_to_play, starting_fens);

for (unsigned int id = 0; id < thread_count; id++) {
for (size_t id = 0; id < thread_count; id++) {
std::vector<std::string> workload;
for (unsigned int i = id; i < starting_fens.size(); i += thread_count) {
for (size_t i = id; i < starting_fens.size(); i += thread_count) {
workload.emplace_back(starting_fens[i]);
}
workers.emplace_back(gen_games, limits, workload, directory_path + "/" + std::to_string(id) + ".plain");
workers.emplace_back(gen_games, limits, workload, directory_path + "/" + std::to_string(id) + ".plain", id);
}

print_progress(games_to_play);
Expand Down
2 changes: 1 addition & 1 deletion src/uci/uci.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ namespace uci {
search::Limits limits;
limits.max_nodes = find_element<int64_t>(tokens, "nodes");
limits.depth = find_element<int64_t>(tokens, "depth");
std::optional<int> thread_count = find_element<int>(tokens, "threads");
std::optional<size_t> thread_count = find_element<size_t>(tokens, "threads");
std::optional<int> games_to_play = find_element<int>(tokens, "games");
selfplay::start_generation(limits, games_to_play.value_or(100'000), thread_count.value_or(1));
});
Expand Down

0 comments on commit 8c28180

Please sign in to comment.