From 649e084f08cb018b3eaeb2e47cf099305f155109 Mon Sep 17 00:00:00 2001 From: SzilBalazs Date: Wed, 16 Aug 2023 10:03:39 +0200 Subject: [PATCH] User sanitizer Bench: 2134743 --- src/chess/board.h | 22 +++++++++++++++++++- src/network/data_parser.h | 2 +- src/network/network.h | 10 ++++----- src/network/nnue.h | 4 ++-- src/network/train.h | 6 +++--- src/selfplay/selfplay.h | 8 ++++---- src/uci/uci.h | 27 +++++++++++++++--------- src/utils/bench.h | 4 ++-- src/utils/logger.h | 43 --------------------------------------- src/utils/utilities.h | 15 +++++++++++++- 10 files changed, 69 insertions(+), 72 deletions(-) delete mode 100644 src/utils/logger.h diff --git a/src/chess/board.h b/src/chess/board.h index 6cdcf5f..fce3405 100644 --- a/src/chess/board.h +++ b/src/chess/board.h @@ -24,6 +24,7 @@ #include "board_state.h" #include "move.h" +#include #include #include #include @@ -292,7 +293,13 @@ namespace chess { states.pop_back(); } - void load(const std::string &fen) { + void load(const std::string &fen, bool validate_fen = false) { + + if (validate_fen && !is_valid_fen(fen)) { + print("info", "error", "Invalid fen:", fen); + return; + } + board_clear(); std::stringstream ss(fen); @@ -322,6 +329,7 @@ namespace chess { state.rights = CastlingRights(rights); state.ep = square_from_string(ep); + if (!move50.empty() && std::all_of(move50.begin(), move50.end(), ::isdigit)) { state.move50 = std::stoi(move50); } else { @@ -473,6 +481,18 @@ namespace chess { states.clear(); states.emplace_back(); } + + static bool is_valid_fen(const std::string &fen) { + const static std::regex fen_regex("^" + "([rnbqkpRNBQKP1-8]+\\/){7}" + "([rnbqkpRNBQKP1-8]+)" + " [bw]" + " ([-KQkq]+|)" + " (([a-h][36])|-)" + " \\d+" + ".*"); + return std::regex_match(fen, fen_regex); + } }; diff --git a/src/network/data_parser.h b/src/network/data_parser.h index a977447..67c9b88 100644 --- a/src/network/data_parser.h +++ b/src/network/data_parser.h @@ -42,7 +42,7 @@ namespace nn { file.open(path, std::ios::in); if (!file.is_open()) { - Logger("Unable to open:", path); + print("Unable to open:", path); throw std::runtime_error("Unable to open: " + path); } } diff --git a/src/network/network.h b/src/network/network.h index 766c397..66dcdfc 100644 --- a/src/network/network.h +++ b/src/network/network.h @@ -52,7 +52,7 @@ namespace nn { Network(const std::string &network_path) { std::ifstream file(network_path, std::ios::in | std::ios::binary); if (!file.is_open()) { - Logger("Unable to open: ", network_path); + print("Unable to open: ", network_path); randomize(); return; } @@ -61,14 +61,14 @@ namespace nn { file.read(reinterpret_cast(&magic), sizeof(magic)); if (magic != MAGIC) { - Logger("Invalid network file: ", network_path, " with magic: ", magic); + print("Invalid network file: ", network_path, " with magic: ", magic); throw std::invalid_argument("Invalid network file with magic: " + std::to_string(magic)); } l0.load_from_file(file); l1.load_from_file(file); - Logger("Loaded network file: ", network_path); + print("Loaded network file: ", network_path); } Network() { @@ -90,7 +90,7 @@ namespace nn { void write_to_file(const std::string &output_path) { std::ofstream file(output_path, std::ios::out | std::ios::binary); if (!file.is_open()) { - Logger("Unable to open:", output_path); + print("Unable to open:", output_path); throw std::runtime_error("Unable to open: " + output_path); } @@ -107,7 +107,7 @@ namespace nn { void quantize(const std::string &output_path) { std::ofstream file(output_path, std::ios::out | std::ios::binary); if (!file.is_open()) { - Logger("Unable to open:", output_path); + print("Unable to open:", output_path); throw std::runtime_error("Unable to open: " + output_path); } diff --git a/src/network/nnue.h b/src/network/nnue.h index 7a56b22..0a950a6 100644 --- a/src/network/nnue.h +++ b/src/network/nnue.h @@ -18,8 +18,8 @@ #pragma once #include "../chess/constants.h" +#include "../utils/utilities.h" #include "../external/incbin/incbin.h" -#include "../utils/logger.h" #include "activations/crelu.h" #include "layers/accumulator.h" #include "layers/dense_layer.h" @@ -43,7 +43,7 @@ namespace nn { int offset = sizeof(int); if (magic != MAGIC) { - Logger("Invalid default network file with magic", magic); + print("Invalid default network file with magic", magic); throw std::invalid_argument("Invalid default network file with magic" + std::to_string(magic)); } diff --git a/src/network/train.h b/src/network/train.h index e65ae7b..85bbf40 100644 --- a/src/network/train.h +++ b/src/network/train.h @@ -17,7 +17,7 @@ #pragma once -#include "../utils/logger.h" +#include "../utils/utilities.h" #include "adam.h" #include "data_parser.h" @@ -217,7 +217,7 @@ namespace nn { } void index_training_data(const std::string &training_data) { - Logger("Indexing training data..."); + print("Indexing training data..."); std::string tmp; std::ifstream file(training_data, std::ios::in); @@ -226,7 +226,7 @@ namespace nn { } file.close(); - Logger("Found", entry_count, "positions"); + print("Found", entry_count, "positions"); } }; } // namespace nn \ No newline at end of file diff --git a/src/selfplay/selfplay.h b/src/selfplay/selfplay.h index 0d72c6e..64b28ef 100644 --- a/src/selfplay/selfplay.h +++ b/src/selfplay/selfplay.h @@ -114,7 +114,7 @@ namespace selfplay { void combine_data(const std::string &path, const std::string &output_file) { - Logger("Combining files..."); + print("Combining files..."); std::ofstream file(output_file, std::ios::app | std::ios::out); for (const auto &entry : std::filesystem::directory_iterator(path)) { @@ -127,7 +127,7 @@ namespace selfplay { } file.close(); - Logger("Finished combining"); + print("Finished combining"); } void compress_data(const std::string &input_path, const std::string &output_file) { @@ -136,11 +136,11 @@ namespace selfplay { ss << "zstd " << input_path << " -o " << output_file << " --rm #19"; const std::string cmd = ss.str(); - Logger(">", cmd); + print(">", cmd); system(cmd.c_str()); - Logger("Finished compressing"); + print("Finished compressing"); } std::string get_run_name(const search::Limits &limits, const std::string &id) { diff --git a/src/uci/uci.h b/src/uci/uci.h index ec9f5d8..312b3ac 100644 --- a/src/uci/uci.h +++ b/src/uci/uci.h @@ -22,7 +22,6 @@ #include "../search/search_manager.h" #include "../selfplay/selfplay.h" #include "../tests/perft.h" -#include "../utils/logger.h" #include "../utils/split.h" #include "../utils/utilities.h" #include "command.h" @@ -84,7 +83,7 @@ namespace uci { search::report::set_pretty_output(true); }); commands.emplace_back("isready", [&](context tokens) { - Logger("readyok"); + print("readyok"); }); commands.emplace_back("position", [&](context tokens) { parse_position(tokens); @@ -95,7 +94,7 @@ namespace uci { commands.emplace_back("eval", [&](context tokens) { nn::NNUE network{}; network.refresh(board.to_features()); - Logger("Eval:", eval::evaluate(board, network)); + print("Eval:", eval::evaluate(board, network)); }); commands.emplace_back("gen", [&](context tokens) { search::Limits limits; @@ -132,7 +131,7 @@ namespace uci { commands.emplace_back("perft", [&](context tokens) { int depth = find_element(tokens, "perft").value_or(5); uint64_t node_count = test::perft(board, depth); - Logger("Total node count: ", node_count); + print("Total node count: ", node_count); }); commands.emplace_back("go", [&](context tokens) { search::Limits limits = parse_limits(tokens); @@ -171,7 +170,7 @@ namespace uci { "Threads", "1", "spin", [&]() { sm.allocate_threads(get_option("Threads")); }, - 1, 128); + 1, 256); options.emplace_back( "MoveOverhead", "30", "spin", [&]() { @@ -204,23 +203,30 @@ namespace uci { break; } + std::vector tokens = convert_to_tokens(line); + bool found_match = false; for (const Command &cmd : commands) { if (cmd.is_match(tokens)) { cmd.func(tokens); + found_match = true; } } + + if (!found_match && !tokens.empty() && !tokens[0].empty()) { + print("info", "error", "Invalid uci command:", tokens[0]); + } } } void UCI::greetings() { - Logger("id", "name", "WhiteCore", VERSION); - Logger("id author Balazs Szilagyi"); + print("id", "name", "WhiteCore", VERSION); + print("id author Balazs Szilagyi"); for (const Option &opt : options) { - Logger(opt.to_string()); + print(opt.to_string()); } - Logger("uciok"); + print("uciok"); } search::Limits UCI::parse_limits(UCI::context tokens) { @@ -250,12 +256,13 @@ namespace uci { for (; idx < tokens.size() && tokens[idx] != "moves"; idx++) { fen += tokens[idx] + " "; } - board.load(fen); + board.load(fen, true); } if (idx < tokens.size() && tokens[idx] == "moves") idx++; for (; idx < tokens.size(); idx++) { chess::Move move = move_from_string(board, tokens[idx]); if (move == chess::NULL_MOVE) { + print("info", "error", "Invalid uci move:", tokens[idx]); break; } else { board.make_move(move); diff --git a/src/utils/bench.h b/src/utils/bench.h index 8d25368..086f2dd 100644 --- a/src/utils/bench.h +++ b/src/utils/bench.h @@ -69,7 +69,7 @@ void run_bench() { for (const std::string &fen : fens) { sm.tt_clear(); - board.load(fen); + board.load(fen, true); sm.set_limits(limits); sm.search(board); nodes += sm.get_node_count(); @@ -78,5 +78,5 @@ void run_bench() { int64_t end_time = now(); int64_t elapsed_time = end_time - start_time + 1; int64_t nps = calculate_nps(elapsed_time, nodes); - Logger(nodes, "nodes", nps, "nps"); + print(nodes, "nodes", nps, "nps"); } diff --git a/src/utils/logger.h b/src/utils/logger.h deleted file mode 100644 index c504fe7..0000000 --- a/src/utils/logger.h +++ /dev/null @@ -1,43 +0,0 @@ -// WhiteCore is a C++ chess engine -// Copyright (c) 2023 Balázs Szilágyi -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see . -// - -#pragma once - -#include -#include - -class Logger { -public: - template - explicit Logger(Args... args) { - print(args...); - std::cout << ss.str() << std::flush; - } - - template - void print(T a, Args... args) { - ss << a << " "; - print(args...); - } - - void print() { - ss << "\n"; - } - -private: - std::stringstream ss; -}; diff --git a/src/utils/utilities.h b/src/utils/utilities.h index 8df52c6..720de08 100644 --- a/src/utils/utilities.h +++ b/src/utils/utilities.h @@ -18,12 +18,25 @@ #pragma once #include "../chess/constants.h" -#include "logger.h" +#include +#include #include void init_all(); +template +void print(std::stringstream& ss, Args... args) { + ((ss << args << ' '), ...); +} + +template +void print(Args... args) { + std::stringstream ss; + print(ss, args...); + std::cout << ss.str() << std::endl; +} + template constexpr Color color_enemy() { if constexpr (color == WHITE)