Skip to content

Commit

Permalink
User sanitizer
Browse files Browse the repository at this point in the history
Bench: 2134743
  • Loading branch information
SzilBalazs committed Aug 16, 2023
1 parent 7e0f7b1 commit 649e084
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 72 deletions.
22 changes: 21 additions & 1 deletion src/chess/board.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "board_state.h"
#include "move.h"

#include <regex>
#include <algorithm>
#include <sstream>
#include <vector>
Expand Down Expand Up @@ -292,7 +293,13 @@ namespace chess {
states.pop_back();
}

void load(const std::string &fen) {
void load(const std::string &fen, bool validate_fen = false) {

if (validate_fen && !is_valid_fen(fen)) {
print("info", "error", "Invalid fen:", fen);
return;
}

board_clear();

std::stringstream ss(fen);
Expand Down Expand Up @@ -322,6 +329,7 @@ namespace chess {

state.rights = CastlingRights(rights);
state.ep = square_from_string(ep);

if (!move50.empty() && std::all_of(move50.begin(), move50.end(), ::isdigit)) {
state.move50 = std::stoi(move50);
} else {
Expand Down Expand Up @@ -473,6 +481,18 @@ namespace chess {
states.clear();
states.emplace_back();
}

static bool is_valid_fen(const std::string &fen) {
const static std::regex fen_regex("^"
"([rnbqkpRNBQKP1-8]+\\/){7}"
"([rnbqkpRNBQKP1-8]+)"
" [bw]"
" ([-KQkq]+|)"
" (([a-h][36])|-)"
" \\d+"
".*");
return std::regex_match(fen, fen_regex);
}
};


Expand Down
2 changes: 1 addition & 1 deletion src/network/data_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace nn {
file.open(path, std::ios::in);

if (!file.is_open()) {
Logger("Unable to open:", path);
print("Unable to open:", path);
throw std::runtime_error("Unable to open: " + path);
}
}
Expand Down
10 changes: 5 additions & 5 deletions src/network/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace nn {
Network(const std::string &network_path) {
std::ifstream file(network_path, std::ios::in | std::ios::binary);
if (!file.is_open()) {
Logger("Unable to open: ", network_path);
print("Unable to open: ", network_path);
randomize();
return;
}
Expand All @@ -61,14 +61,14 @@ namespace nn {
file.read(reinterpret_cast<char *>(&magic), sizeof(magic));

if (magic != MAGIC) {
Logger("Invalid network file: ", network_path, " with magic: ", magic);
print("Invalid network file: ", network_path, " with magic: ", magic);
throw std::invalid_argument("Invalid network file with magic: " + std::to_string(magic));
}

l0.load_from_file(file);
l1.load_from_file(file);

Logger("Loaded network file: ", network_path);
print("Loaded network file: ", network_path);
}

Network() {
Expand All @@ -90,7 +90,7 @@ namespace nn {
void write_to_file(const std::string &output_path) {
std::ofstream file(output_path, std::ios::out | std::ios::binary);
if (!file.is_open()) {
Logger("Unable to open:", output_path);
print("Unable to open:", output_path);
throw std::runtime_error("Unable to open: " + output_path);
}

Expand All @@ -107,7 +107,7 @@ namespace nn {
void quantize(const std::string &output_path) {
std::ofstream file(output_path, std::ios::out | std::ios::binary);
if (!file.is_open()) {
Logger("Unable to open:", output_path);
print("Unable to open:", output_path);
throw std::runtime_error("Unable to open: " + output_path);
}

Expand Down
4 changes: 2 additions & 2 deletions src/network/nnue.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#pragma once

#include "../chess/constants.h"
#include "../utils/utilities.h"
#include "../external/incbin/incbin.h"
#include "../utils/logger.h"
#include "activations/crelu.h"
#include "layers/accumulator.h"
#include "layers/dense_layer.h"
Expand All @@ -43,7 +43,7 @@ namespace nn {
int offset = sizeof(int);

if (magic != MAGIC) {
Logger("Invalid default network file with magic", magic);
print("Invalid default network file with magic", magic);
throw std::invalid_argument("Invalid default network file with magic" + std::to_string(magic));
}

Expand Down
6 changes: 3 additions & 3 deletions src/network/train.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#include "../utils/logger.h"
#include "../utils/utilities.h"
#include "adam.h"
#include "data_parser.h"

Expand Down Expand Up @@ -217,7 +217,7 @@ namespace nn {
}

void index_training_data(const std::string &training_data) {
Logger("Indexing training data...");
print("Indexing training data...");

std::string tmp;
std::ifstream file(training_data, std::ios::in);
Expand All @@ -226,7 +226,7 @@ namespace nn {
}
file.close();

Logger("Found", entry_count, "positions");
print("Found", entry_count, "positions");
}
};
} // namespace nn
8 changes: 4 additions & 4 deletions src/selfplay/selfplay.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ namespace selfplay {

void combine_data(const std::string &path, const std::string &output_file) {

Logger("Combining files...");
print("Combining files...");

std::ofstream file(output_file, std::ios::app | std::ios::out);
for (const auto &entry : std::filesystem::directory_iterator(path)) {
Expand All @@ -127,7 +127,7 @@ namespace selfplay {
}
file.close();

Logger("Finished combining");
print("Finished combining");
}

void compress_data(const std::string &input_path, const std::string &output_file) {
Expand All @@ -136,11 +136,11 @@ namespace selfplay {
ss << "zstd " << input_path << " -o " << output_file << " --rm #19";

const std::string cmd = ss.str();
Logger(">", cmd);
print(">", cmd);

system(cmd.c_str());

Logger("Finished compressing");
print("Finished compressing");
}

std::string get_run_name(const search::Limits &limits, const std::string &id) {
Expand Down
27 changes: 17 additions & 10 deletions src/uci/uci.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "../search/search_manager.h"
#include "../selfplay/selfplay.h"
#include "../tests/perft.h"
#include "../utils/logger.h"
#include "../utils/split.h"
#include "../utils/utilities.h"
#include "command.h"
Expand Down Expand Up @@ -84,7 +83,7 @@ namespace uci {
search::report::set_pretty_output(true);
});
commands.emplace_back("isready", [&](context tokens) {
Logger("readyok");
print("readyok");
});
commands.emplace_back("position", [&](context tokens) {
parse_position(tokens);
Expand All @@ -95,7 +94,7 @@ namespace uci {
commands.emplace_back("eval", [&](context tokens) {
nn::NNUE network{};
network.refresh(board.to_features());
Logger("Eval:", eval::evaluate(board, network));
print("Eval:", eval::evaluate(board, network));
});
commands.emplace_back("gen", [&](context tokens) {
search::Limits limits;
Expand Down Expand Up @@ -132,7 +131,7 @@ namespace uci {
commands.emplace_back("perft", [&](context tokens) {
int depth = find_element<int>(tokens, "perft").value_or(5);
uint64_t node_count = test::perft<true, false>(board, depth);
Logger("Total node count: ", node_count);
print("Total node count: ", node_count);
});
commands.emplace_back("go", [&](context tokens) {
search::Limits limits = parse_limits(tokens);
Expand Down Expand Up @@ -171,7 +170,7 @@ namespace uci {
"Threads", "1", "spin", [&]() {
sm.allocate_threads(get_option<int>("Threads"));
},
1, 128);
1, 256);

options.emplace_back(
"MoveOverhead", "30", "spin", [&]() {
Expand Down Expand Up @@ -204,23 +203,30 @@ namespace uci {
break;
}


std::vector<std::string> tokens = convert_to_tokens(line);

bool found_match = false;
for (const Command &cmd : commands) {
if (cmd.is_match(tokens)) {
cmd.func(tokens);
found_match = true;
}
}

if (!found_match && !tokens.empty() && !tokens[0].empty()) {
print("info", "error", "Invalid uci command:", tokens[0]);
}
}
}

void UCI::greetings() {
Logger("id", "name", "WhiteCore", VERSION);
Logger("id author Balazs Szilagyi");
print("id", "name", "WhiteCore", VERSION);
print("id author Balazs Szilagyi");
for (const Option &opt : options) {
Logger(opt.to_string());
print(opt.to_string());
}
Logger("uciok");
print("uciok");
}

search::Limits UCI::parse_limits(UCI::context tokens) {
Expand Down Expand Up @@ -250,12 +256,13 @@ namespace uci {
for (; idx < tokens.size() && tokens[idx] != "moves"; idx++) {
fen += tokens[idx] + " ";
}
board.load(fen);
board.load(fen, true);
}
if (idx < tokens.size() && tokens[idx] == "moves") idx++;
for (; idx < tokens.size(); idx++) {
chess::Move move = move_from_string(board, tokens[idx]);
if (move == chess::NULL_MOVE) {
print("info", "error", "Invalid uci move:", tokens[idx]);
break;
} else {
board.make_move(move);
Expand Down
4 changes: 2 additions & 2 deletions src/utils/bench.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void run_bench() {

for (const std::string &fen : fens) {
sm.tt_clear();
board.load(fen);
board.load(fen, true);
sm.set_limits(limits);
sm.search<true>(board);
nodes += sm.get_node_count();
Expand All @@ -78,5 +78,5 @@ void run_bench() {
int64_t end_time = now();
int64_t elapsed_time = end_time - start_time + 1;
int64_t nps = calculate_nps(elapsed_time, nodes);
Logger(nodes, "nodes", nps, "nps");
print(nodes, "nodes", nps, "nps");
}
43 changes: 0 additions & 43 deletions src/utils/logger.h

This file was deleted.

15 changes: 14 additions & 1 deletion src/utils/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,25 @@
#pragma once

#include "../chess/constants.h"
#include "logger.h"

#include <sstream>
#include <iostream>
#include <chrono>

void init_all();

template<typename... Args>
void print(std::stringstream& ss, Args... args) {
((ss << args << ' '), ...);
}

template<typename... Args>
void print(Args... args) {
std::stringstream ss;
print(ss, args...);
std::cout << ss.str() << std::endl;
}

template<Color color>
constexpr Color color_enemy() {
if constexpr (color == WHITE)
Expand Down

0 comments on commit 649e084

Please sign in to comment.