diff --git a/src/network/train.h b/src/network/train.h
index f90c811..bacd5fc 100644
--- a/src/network/train.h
+++ b/src/network/train.h
@@ -33,7 +33,8 @@ namespace nn {
 
     class Trainer {
     public:
-        Trainer(const std::string &training_data, const std::optional<std::string> &network_path, float learning_rate, int epochs, int batch_size, int thread_count) : adam(learning_rate), parser(training_data), batch_size(batch_size), thread_count(thread_count) {
+        Trainer(const std::string &training_data, const std::string &validation_data, const std::optional<std::string> &network_path, float learning_rate, int epochs, int batch_size, int thread_count) :
+                adam(learning_rate), training_parser(training_data), validation_parser(validation_data), batch_size(batch_size), thread_count(thread_count) {
 
             if (!std::filesystem::exists("networks")) {
                 std::filesystem::create_directory("networks");
@@ -53,7 +54,7 @@ namespace nn {
             entries_next = new TrainingEntry[batch_size];
 
             bool _;
-            parser.read_batch(batch_size, entries_next, _);
+            training_parser.read_batch(batch_size, entries_next, _);
 
             int64_t iter = 0;
             for (int epoch = 1; epoch <= epochs; epoch++) {
@@ -76,11 +77,11 @@ namespace nn {
                 errors.assign(thread_count, 0.0f);
                 accuracy.assign(thread_count, 0);
 
-                std::thread th_loading = std::thread(&DataParser::read_batch, &parser, batch_size, entries_next, std::ref(is_new_epoch));
+                std::thread th_loading = std::thread(&DataParser::read_batch, &training_parser, batch_size, entries_next, std::ref(is_new_epoch));
 
                 std::vector<std::thread> ths;
                 for (int id = 0; id < thread_count; id++) {
-                    ths.emplace_back(&Trainer::process_batch, this, id);
+                    ths.emplace_back(&Trainer::process_batch<true>, this, id);
                 }
 
                 for (std::thread &th : ths) {
@@ -99,6 +100,7 @@ namespace nn {
                 if (iter % 20 == 0) {
                     float average_error = checkpoint_error / float(batch_size * checkpoint_iter);
                     float average_accuracy = float(checkpoint_accuracy) / float(batch_size * checkpoint_iter);
+                    auto [val_loss, val_acc] = test_validation();
 
                     int64_t current_time = now();
                     int64_t elapsed_time = current_time - start_time;
@@ -124,7 +126,7 @@ namespace nn {
                     std::cout << "] - Epoch " << epoch << " - Iteration " << iter << " - Error " << average_error << " - ETA " << (eta / 1000) << "s - " << pos_per_s << " pos/s \r" << std::flush;
 
-                    log_file << iter << " " << average_error << " " << pos_per_s << " " << average_accuracy << "\n";
+                    log_file << iter << " " << average_error << " " << pos_per_s << " " << average_accuracy << " " << val_loss << " " << val_acc << "\n";
                     log_file.flush();
                     checkpoint_error = 0.0f;
                     checkpoint_accuracy = 0;
@@ -142,7 +144,8 @@ namespace nn {
     private:
         Network network;
        Adam adam;
-        DataParser parser;
+        DataParser training_parser;
+        DataParser validation_parser;
         unsigned int entry_count;
         int batch_size, thread_count;
         std::vector<Gradient> gradients;
@@ -150,6 +153,39 @@ namespace nn {
         std::vector<int> accuracy;
         TrainingEntry *entries, *entries_next;
 
+        std::pair<float, float> test_validation() {
+            bool _;
+            delete[] entries;
+            entries = new TrainingEntry[batch_size];
+            validation_parser.read_batch(batch_size, entries, _);
+
+            errors.assign(thread_count, 0.0f);
+            accuracy.assign(thread_count, 0);
+
+            std::vector<std::thread> ths;
+            for (int id = 0; id < thread_count; id++) {
+                ths.emplace_back(&Trainer::process_batch<false>, this, id);
+            }
+
+            for (std::thread &th : ths) {
+                if (th.joinable()) th.join();
+            }
+
+            float val_loss = 0.0f;
+            int correct = 0;
+
+            for (int id = 0; id < thread_count; id++) {
+                val_loss += errors[id];
+                correct += accuracy[id];
+            }
+
+            val_loss /= batch_size;
+
+            float val_acc = float(correct) / float(batch_size);
+            return {val_loss, val_acc};
+        }
+
+        template<bool train>
         void process_batch(int id) {
             Gradient &g = gradients[id];
@@ -168,10 +204,12 @@ namespace nn {
                 errors[id] += error;
                 accuracy[id] += ((entry.wdl - 0.5f) * (prediction - 0.5f) > 0.0f) || std::abs(entry.wdl - prediction) < 0.05f;
 
-                std::array<float, 1> l1_loss = {(1 - EVAL_INFLUENCE) * 2.0f * (prediction - entry.wdl) + EVAL_INFLUENCE * 2.0f * (prediction - entry.eval)};
+                if constexpr (train) {
+                    std::array<float, 1> l1_loss = {(1 - EVAL_INFLUENCE) * 2.0f * (prediction - entry.wdl) + EVAL_INFLUENCE * 2.0f * (prediction - entry.eval)};
 
-                network.l1.backward(l1_loss, l0_output, l1_output, l0_loss, g.l1);
-                network.l0.backward(l0_loss, entry.features, l0_output, g.l0);
+                    network.l1.backward(l1_loss, l0_output, l1_output, l0_loss, g.l1);
+                    network.l0.backward(l0_loss, entry.features, l0_output, g.l0);
+                }
             }
         }
diff --git a/src/uci/uci.h b/src/uci/uci.h
index b7e9b18..4587a1b 100644
--- a/src/uci/uci.h
+++ b/src/uci/uci.h
@@ -22,6 +22,7 @@
 #include "../selfplay/selfplay.h"
 #include "../network/train.h"
 #include "../tests/perft.h"
+#include "../utils/split.h"
 #include "../utils/logger.h"
 #include "../utils/utilities.h"
 #include "command.h"
@@ -99,17 +100,25 @@ namespace uci {
             std::optional<int> dropout = find_element<int>(tokens, "dropout");
             selfplay::start_generation(limits, book.value_or("book.epd"), output.value_or("data.plain"), thread_count.value_or(1), dropout.value_or(1));
         });
+        commands.emplace_back("split", [&](context tokens) {
+            std::optional<std::string> input_data = find_element<std::string>(tokens, "input");
+            std::optional<std::string> output_data1 = find_element<std::string>(tokens, "output1");
+            std::optional<std::string> output_data2 = find_element<std::string>(tokens, "output2");
+            std::optional<int> rate = find_element<int>(tokens, "rate");
+            split_data(input_data.value_or("data.plain"), output_data1.value_or("train.plain"), output_data2.value_or("validation.plain"), rate.value_or(10));
+        });
         commands.emplace_back("train", [&](context tokens) {
             std::optional<std::string> network_path = find_element<std::string>(tokens, "network");
-            std::optional<std::string> training_data = find_element<std::string>(tokens, "in");
+            std::optional<std::string> training_data = find_element<std::string>(tokens, "training_data");
+            std::optional<std::string> validation_data = find_element<std::string>(tokens, "validation_data");
             std::optional<float> learning_rate = find_element<float>(tokens, "lr");
             std::optional<int> epochs = find_element<int>(tokens, "epochs");
             std::optional<int> batch_size = find_element<int>(tokens, "batch");
             std::optional<int> threads = find_element<int>(tokens, "threads");
-            nn::Trainer trainer(training_data.value_or("data.plain"), network_path, learning_rate.value_or(0.001f),
+            nn::Trainer trainer(training_data.value_or("train.plain"), validation_data.value_or("validation.plain"), network_path, learning_rate.value_or(0.001f),
                                 epochs.value_or(10), batch_size.value_or(16384), threads.value_or(4));
         });
-        commands.emplace_back("learn", [&](context tokens) {
+        /*commands.emplace_back("learn", [&](context tokens) {
             std::optional<int> thread_count = find_element<int>(tokens, "threads");
             std::optional<int> iterations = find_element<int>(tokens, "iter");
             std::optional<int> nodes = find_element<int>(tokens, "nodes");
@@ -130,7 +139,7 @@ namespace uci {
                 std::this_thread::sleep_for(std::chrono::seconds(1));
                 nn::net = nn::QNetwork("corenet.bin");
             }
-        });
+        });*/
         commands.emplace_back("perft", [&](context tokens) {
             int depth = find_element<int>(tokens, "perft").value_or(5);
             U64 node_count = test::perft(board, depth);
diff --git a/src/utils/split.h b/src/utils/split.h
new file mode 100644
index 0000000..98e6055
--- /dev/null
+++ b/src/utils/split.h
@@ -0,0 +1,44 @@
+// WhiteCore is a C++ chess engine
+// Copyright (c) 2023 Balázs Szilágyi
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <https://www.gnu.org/licenses/>.
+//
+
+#pragma once
+
+#include <fstream>
+#include <random>
+#include <string>
+
+// Splits the input file line by line: on average 1 line in (rate + 1) goes
+// to output2 (validation), the rest to output1 (training).
+inline void split_data(const std::string &input, const std::string &output1, const std::string &output2, int rate) {
+    std::ifstream in(input, std::ios::in);
+    std::ofstream out1(output1, std::ios::out | std::ios::app);
+    std::ofstream out2(output2, std::ios::out | std::ios::app);
+
+    std::random_device rd;
+    std::mt19937 g(rd());
+    std::uniform_int_distribution<int> dist(0, rate);
+
+    std::string line;
+    while (std::getline(in, line)) {
+        if (dist(g) == 0) out2 << line << "\n";
+        else out1 << line << "\n";
+    }
+
+    in.close();
+    out1.close();
+    out2.close();
+}
\ No newline at end of file
diff --git a/train/trainer.py b/train/trainer.py
index 14c2951..0d27bc8 100644
--- a/train/trainer.py
+++ b/train/trainer.py
@@ -27,8 +27,7 @@
 config={
     "learning_rate": 0.001,
     "architecture": 3,
-    "dataset": "data.plain",
-    "epochs": 10,
+    "epochs": 30,
     "batch_size": 16384,
     "thread_count": 4
 }
@@ -38,7 +37,7 @@
 current_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
 wandb.run.name = f"exp-{commit_hash}-{current_time}"
 
-command_string = f"train in {wandb.run.config['dataset']} lr {wandb.run.config['learning_rate']} " \
+command_string = f"train lr {wandb.run.config['learning_rate']} " \
                  f"epochs {wandb.run.config['epochs']} batch {wandb.run.config['batch_size']} " \
                  f"threads {wandb.run.config['thread_count']}\n"
 
@@ -64,6 +63,8 @@
         wandb.run.summary["iterations"] = iteration
         wandb.log({"training loss": float(data[1]),
                    "training accuracy": float(data[3]),
+                   "validation loss": float(data[4]),
+                   "validation accuracy": float(data[5]),
                    "positions per second": int(data[2])})
     f.close()
 except Exception as e:
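---
Usage sketch (not part of the patch): the snippet below exercises split_data() from
src/utils/split.h; the file names are hypothetical. With rate = 10 roughly 1 line in 11
lands in the validation file, because uniform_int_distribution(0, rate) is inclusive on
both bounds.

// split_sketch.cpp -- assumes src/utils/split.h is on the include path.
#include <fstream>

#include "src/utils/split.h"

int main() {
    // Write a small input file so the example is self-contained.
    {
        std::ofstream in("data.plain");
        for (int i = 0; i < 1000; i++)
            in << "position " << i << "\n";
    }

    // ~91% of the lines go to train.plain, ~9% to validation.plain.
    split_data("data.plain", "train.plain", "validation.plain", 10);
    return 0;
}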
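The train/validation switch in Trainer::process_batch relies on C++17 if constexpr, so
the backward pass is compiled only into the training instantiation. A standalone sketch
of the same pattern, with illustrative names rather than the engine's:

// constexpr_sketch.cpp -- compile with -std=c++17 or later.
#include <iostream>

template<bool train>
void step(float prediction, float target) {
    // Forward pass: squared error, computed in both instantiations.
    float loss = (prediction - target) * (prediction - target);
    std::cout << "loss " << loss << "\n";

    if constexpr (train) {
        // Backward pass: only step<true> contains this code; step<false>
        // pays neither the work nor a runtime branch for it.
        float gradient = 2.0f * (prediction - target);
        std::cout << "gradient " << gradient << "\n";
    }
}

int main() {
    step<true>(0.7f, 1.0f);  // training: loss + gradient
    step<false>(0.7f, 1.0f); // validation: loss only
    return 0;
}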