Added validation data
Bench: 3126270
SzilBalazs committed Jul 15, 2023
1 parent 068ea45 commit de132cf
Showing 4 changed files with 104 additions and 16 deletions.
56 changes: 47 additions & 9 deletions src/network/train.h
@@ -33,7 +33,8 @@ namespace nn {
class Trainer {
public:

Trainer(const std::string &training_data, const std::optional<std::string> &network_path, float learning_rate, int epochs, int batch_size, int thread_count) : adam(learning_rate), parser(training_data), batch_size(batch_size), thread_count(thread_count) {
Trainer(const std::string &training_data, const std::string &validation_data, const std::optional<std::string> &network_path, float learning_rate, int epochs, int batch_size, int thread_count) :
adam(learning_rate), training_parser(training_data), validation_parser(validation_data), batch_size(batch_size), thread_count(thread_count) {

if (!std::filesystem::exists("networks")) {
std::filesystem::create_directory("networks");
@@ -53,7 +54,7 @@ namespace nn {
entries_next = new TrainingEntry[batch_size];

bool _;
parser.read_batch(batch_size, entries_next, _);
training_parser.read_batch(batch_size, entries_next, _);

int64_t iter = 0;
for (int epoch = 1; epoch <= epochs; epoch++) {
@@ -76,11 +77,11 @@ namespace nn {
errors.assign(thread_count, 0.0f);
accuracy.assign(thread_count, 0);

std::thread th_loading = std::thread(&DataParser::read_batch, &parser, batch_size, entries_next, std::ref(is_new_epoch));
std::thread th_loading = std::thread(&DataParser::read_batch, &training_parser, batch_size, entries_next, std::ref(is_new_epoch));

std::vector<std::thread> ths;
for (int id = 0; id < thread_count; id++) {
ths.emplace_back(&Trainer::process_batch, this, id);
ths.emplace_back(&Trainer::process_batch<true>, this, id);
}

for (std::thread &th : ths) {
@@ -99,6 +100,7 @@ namespace nn {
if (iter % 20 == 0) {
float average_error = checkpoint_error / float(batch_size * checkpoint_iter);
float average_accuracy = float(checkpoint_accuracy) / float(batch_size * checkpoint_iter);
auto [val_loss, val_acc] = test_validation();

int64_t current_time = now();
int64_t elapsed_time = current_time - start_time;
@@ -124,7 +126,7 @@ namespace nn {
std::cout << "] - Epoch " << epoch << " - Iteration " << iter << " - Error " << average_error << " - ETA " << (eta / 1000) << "s - " << pos_per_s << " pos/s \r" << std::flush;


log_file << iter << " " << average_error << " " << pos_per_s << " " << average_accuracy << "\n";
log_file << iter << " " << average_error << " " << pos_per_s << " " << average_accuracy << " " << val_loss << " " << val_acc << "\n";
log_file.flush();
checkpoint_error = 0.0f;
checkpoint_accuracy = 0;
@@ -142,14 +144,48 @@ namespace nn {
private:
Network network;
Adam adam;
DataParser parser;
DataParser training_parser;
DataParser validation_parser;
unsigned int entry_count;
int batch_size, thread_count;
std::vector<Gradient> gradients;
std::vector<float> errors;
std::vector<int> accuracy;
TrainingEntry *entries, *entries_next;

std::pair<float, float> test_validation() {
bool _;
delete[] entries;
entries = new TrainingEntry[batch_size];
validation_parser.read_batch(batch_size, entries, _);

errors.assign(thread_count, 0.0f);
accuracy.assign(thread_count, 0);

std::vector<std::thread> ths;
for (int id = 0; id < thread_count; id++) {
ths.emplace_back(&Trainer::process_batch<false>, this, id);
}

for (std::thread &th : ths) {
if (th.joinable()) th.join();
}

float val_loss = 0.0f;
                int correct = 0;

for (int id = 0; id < thread_count; id++) {
val_loss += errors[id];
correct += accuracy[id];
}

val_loss /= batch_size;

float val_acc = float(correct) / float(batch_size);
return {val_loss, val_acc};
}

template<bool train>
void process_batch(int id) {
Gradient &g = gradients[id];

@@ -168,10 +204,12 @@ namespace nn {
errors[id] += error;
accuracy[id] += ((entry.wdl - 0.5f) * (prediction - 0.5f) > 0.0f) || std::abs(entry.wdl - prediction) < 0.05f;

std::array<float, 1> l1_loss = {(1 - EVAL_INFLUENCE) * 2.0f * (prediction - entry.wdl) + EVAL_INFLUENCE * 2.0f * (prediction - entry.eval)};
if constexpr (train) {
std::array<float, 1> l1_loss = {(1 - EVAL_INFLUENCE) * 2.0f * (prediction - entry.wdl) + EVAL_INFLUENCE * 2.0f * (prediction - entry.eval)};

network.l1.backward(l1_loss, l0_output, l1_output, l0_loss, g.l1);
network.l0.backward(l0_loss, entry.features, l0_output, g.l0);
network.l1.backward(l1_loss, l0_output, l1_output, l0_loss, g.l1);
network.l0.backward(l0_loss, entry.features, l0_output, g.l0);
}
}
}

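With test_validation() wired into the logging path, every 20th iteration now appends two extra columns to each log line: validation loss and validation accuracy. A sketch of the resulting line format, with illustrative values rather than output from a real run — the columns are iteration, training loss, positions per second, training accuracy, validation loss, validation accuracy:

    200 0.0941 1250000 0.8630 0.0957 0.8541

trainer.py (see the last file in this diff) reads the final two fields as the validation metrics.
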
17 changes: 13 additions & 4 deletions src/uci/uci.h
@@ -22,6 +22,7 @@
#include "../selfplay/selfplay.h"
#include "../network/train.h"
#include "../tests/perft.h"
#include "../utils/split.h"
#include "../utils/logger.h"
#include "../utils/utilities.h"
#include "command.h"
@@ -99,17 +100,25 @@ namespace uci {
std::optional<int> dropout = find_element<int>(tokens, "dropout");
selfplay::start_generation(limits, book.value_or("book.epd"), output.value_or("data.plain"), thread_count.value_or(1), dropout.value_or(1));
});
commands.emplace_back("split", [&](context tokens){
std::optional<std::string> input_data = find_element<std::string>(tokens, "input");
std::optional<std::string> output_data1 = find_element<std::string>(tokens, "output1");
std::optional<std::string> output_data2 = find_element<std::string>(tokens, "output2");
std::optional<int> rate = find_element<int>(tokens, "rate");
split_data(input_data.value_or("data.plain"), output_data1.value_or("train.plain"), output_data2.value_or("validation.plain"), rate.value_or(10));
});
commands.emplace_back("train", [&](context tokens){
std::optional<std::string> network_path = find_element<std::string>(tokens, "network");
std::optional<std::string> training_data = find_element<std::string>(tokens, "in");
std::optional<std::string> training_data = find_element<std::string>(tokens, "training_data");
std::optional<std::string> validation_data = find_element<std::string>(tokens, "validation_data");
std::optional<float> learning_rate = find_element<float>(tokens, "lr");
std::optional<int> epochs = find_element<int>(tokens, "epochs");
std::optional<int> batch_size = find_element<int>(tokens, "batch");
std::optional<int> threads = find_element<int>(tokens, "threads");
nn::Trainer trainer(training_data.value_or("data.plain"), network_path, learning_rate.value_or(0.001f),
nn::Trainer trainer(training_data.value_or("train.plain"), validation_data.value_or("validation.plain"), network_path, learning_rate.value_or(0.001f),
epochs.value_or(10), batch_size.value_or(16384), threads.value_or(4));
});
commands.emplace_back("learn", [&](context tokens){
/*commands.emplace_back("learn", [&](context tokens){
std::optional<int> thread_count = find_element<int>(tokens, "threads");
std::optional<int> iterations = find_element<int>(tokens, "iter");
std::optional<int> nodes = find_element<int>(tokens, "nodes");
@@ -130,7 +139,7 @@ namespace uci {
std::this_thread::sleep_for(std::chrono::seconds(1));
nn::net = nn::QNetwork("corenet.bin");
}
});
});*/
commands.emplace_back("perft", [&](context tokens) {
int depth = find_element<int>(tokens, "perft").value_or(5);
U64 node_count = test::perft<true, false>(board, depth);
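For reference, a minimal usage sketch of the two affected UCI commands; every argument is optional, and the values shown are simply the value_or defaults from the handlers above:

    split input data.plain output1 train.plain output2 validation.plain rate 10
    train training_data train.plain validation_data validation.plain lr 0.001 epochs 10 batch 16384 threads 4
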
40 changes: 40 additions & 0 deletions src/utils/split.h
@@ -0,0 +1,40 @@
// WhiteCore is a C++ chess engine
// Copyright (c) 2023 Balázs Szilágyi
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include <fstream>
#include <random>
#include <string>

inline void split_data(const std::string &input, const std::string &output1, const std::string &output2, int rate) {
std::ifstream in(input, std::ios::in);
std::ofstream out1(output1, std::ios::out | std::ios::app);
std::ofstream out2(output2, std::ios::out | std::ios::app);

std::random_device rd;
std::mt19937 g(rd());
std::uniform_int_distribution<int> dist(0, rate);

std::string line;
while (std::getline(in, line)) {
if (dist(g) == 0) out2 << line << "\n";
else out1 << line << "\n";
}

in.close();
out1.close();
out2.close();
}
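
A minimal sketch of calling split_data directly; the include path is an assumption, and the call mirrors the defaults of the split command above:

    #include "utils/split.h"

    int main() {
        // dist(0, rate) draws from rate + 1 equally likely values, so with rate = 10
        // roughly 1 in 11 (~9%) of the input lines land in validation.plain and the
        // rest in train.plain. Both outputs are opened in append mode, so re-running
        // adds to existing files rather than overwriting them.
        split_data("data.plain", "train.plain", "validation.plain", 10);
        return 0;
    }
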
7 changes: 4 additions & 3 deletions train/trainer.py
@@ -27,8 +27,7 @@
config={
"learning_rate": 0.001,
"architecture": 3,
"dataset": "data.plain",
"epochs": 10,
"epochs": 30,
"batch_size": 16384,
"thread_count": 4
}
@@ -38,7 +37,7 @@
current_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
wandb.run.name = f"exp-{commit_hash}-{current_time}"

command_string = f"train in {wandb.run.config['dataset']} lr {wandb.run.config['learning_rate']} " \
command_string = f"train lr {wandb.run.config['learning_rate']} " \
f"epochs {wandb.run.config['epochs']} batch {wandb.run.config['batch_size']} " \
f"threads {wandb.run.config['thread_count']}\n"

@@ -64,6 +63,8 @@
wandb.run.summary["iterations"] = iteration
wandb.log({"training loss": float(data[1]),
"training accuracy": float(data[3]),
"validation loss": float(data[4]),
"validation accuracy": float(data[5]),
"positions per second": int(data[2])})
f.close()
except Exception as e:
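With the dataset key dropped from the wandb config, the generated command no longer passes any data paths; under the config above it expands to

    train lr 0.001 epochs 30 batch 16384 threads 4

so the engine falls back to the train.plain / validation.plain defaults defined in uci.h.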
