Skip to content

Commit

Permalink
Merge pull request #23 from SzilBalazs/nnue
Browse files Browse the repository at this point in the history
NNUE

NNUE vs master at STC:
ELO   | 197.81 +- 40.85 (95%)
SPRT  | 10.0+0.10s Threads=1 Hash=16MB
LLR   | 2.99 (-2.94, 2.94) [0.00, 5.00]
GAMES | N: 336 W: 228 L: 55 D: 53

NNUE vs master at LTC:
ELO   | 210.21 +- 43.13 (95%)
SPRT  | 60.0+0.60s Threads=1 Hash=256MB
LLR   | 2.98 (-2.94, 2.94) [0.00, 5.00]
GAMES | N: 320 W: 223 L: 50 D: 47

Bench: 3623609
  • Loading branch information
SzilBalazs authored Jul 18, 2023
2 parents de132cf + ec0985c commit e31f3c9
Show file tree
Hide file tree
Showing 21 changed files with 351 additions and 459 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ endif

$(OUTPUT_BINARY): $(HEADERS) $(SOURCES) $(INCBIN_TOOL)
ifeq ($(uname_S), Windows)
@./$(INCBIN_TOOL) src/main.cpp -o src/corenet.cpp
@./$(INCBIN_TOOL) src/network/nnue.h -o src/corenet.cpp
endif
@echo Compiling $(NAME)
@$(CXX) $(TARGET_FLAGS) $(CXXFLAGS) -o $@ src/*.cpp
Expand Down
Binary file modified corenet.bin
Binary file not shown.
60 changes: 38 additions & 22 deletions src/core/board.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
#include "bitboard.h"
#include "board_state.h"
#include "move.h"
#include "../network/nnue.h"

#include <sstream>
#include <algorithm>
#include <sstream>
#include <vector>

#define state states.back()
Expand Down Expand Up @@ -133,7 +134,7 @@ namespace core {
states.pop_back();
}

inline void make_move(Move move) {
inline void make_move(Move move, nn::NNUE *nnue = nullptr) {
const Square from = move.get_from();
const Square to = move.get_to();
Piece piece_moved = piece_at(from);
Expand Down Expand Up @@ -163,7 +164,7 @@ namespace core {

if (move.eq_flag(EP_CAPTURE)) {
state.piece_captured = Piece(PAWN, xstm);
square_clear(to + DOWN);
square_clear(to + DOWN, nnue);
} else {
state.piece_captured = piece_at(to);
}
Expand Down Expand Up @@ -191,19 +192,19 @@ namespace core {
}
}

move_piece(piece_moved, from, to);
move_piece(piece_moved, from, to, nnue);

if (move.eq_flag(KING_CASTLE)) {
if (stm == WHITE) {
move_piece(Piece(ROOK, WHITE), H1, F1);
move_piece(Piece(ROOK, WHITE), H1, F1, nnue);
} else {
move_piece(Piece(ROOK, BLACK), H8, F8);
move_piece(Piece(ROOK, BLACK), H8, F8, nnue);
}
} else if (move.eq_flag(QUEEN_CASTLE)) {
if (stm == WHITE) {
move_piece(Piece(ROOK, WHITE), A1, D1);
move_piece(Piece(ROOK, WHITE), A1, D1, nnue);
} else {
move_piece(Piece(ROOK, BLACK), A8, D8);
move_piece(Piece(ROOK, BLACK), A8, D8, nnue);
}
}

Expand All @@ -223,7 +224,7 @@ namespace core {
state.hash.xor_castle(state.rights);
}

inline void undo_move(Move move) {
inline void undo_move(Move move, nn::NNUE *nnue = nullptr) {
const Square from = move.get_from();
const Square to = move.get_to();
Piece piece_moved = piece_at(to);
Expand All @@ -240,24 +241,24 @@ namespace core {

if (move.eq_flag(KING_CASTLE)) {
if (stm == WHITE) {
move_piece(Piece(ROOK, WHITE), F1, H1);
move_piece(Piece(ROOK, WHITE), F1, H1, nnue);
} else {
move_piece(Piece(ROOK, BLACK), F8, H8);
move_piece(Piece(ROOK, BLACK), F8, H8, nnue);
}
} else if (move.eq_flag(QUEEN_CASTLE)) {
if (stm == WHITE) {
move_piece(Piece(ROOK, WHITE), D1, A1);
move_piece(Piece(ROOK, WHITE), D1, A1, nnue);
} else {
move_piece(Piece(ROOK, BLACK), D8, A8);
move_piece(Piece(ROOK, BLACK), D8, A8, nnue);
}
}

move_piece(piece_moved, to, from);
move_piece(piece_moved, to, from, nnue);

if (move.eq_flag(EP_CAPTURE)) {
square_set(to + DOWN, state.piece_captured);
square_set(to + DOWN, state.piece_captured, nnue);
} else if (move.is_capture()) {
square_set(to, state.piece_captured);
square_set(to, state.piece_captured, nnue);
}

states.pop_back();
Expand Down Expand Up @@ -378,13 +379,24 @@ namespace core {
<< std::endl;
}

inline std::vector<unsigned int> to_features() const {
std::vector<unsigned int> result;
Bitboard bb = occupied();
while (bb) {
Square sq = bb.pop_lsb();
Piece piece = piece_at(sq);
result.emplace_back(nn::NNUE::get_feature_index(piece, sq));
}
return result;
}

private:
Piece mailbox[64];
Bitboard bb_pieces[6], bb_colors[2];

std::vector<BoardState> states;

inline void square_clear(Square square) {
inline void square_clear(Square square, nn::NNUE *nnue = nullptr) {
const Piece piece = piece_at(square);
if (piece.is_null()) return;

Expand All @@ -393,24 +405,28 @@ namespace core {
mailbox[square] = NULL_PIECE;

state.hash.xor_piece(square, piece);

if (nnue) nnue->deactivate(piece, square);
}

inline void square_set(Square square, Piece piece) {
inline void square_set(Square square, Piece piece, nn::NNUE *nnue = nullptr) {
assert(piece.is_ok());
square_clear(square);
square_clear(square, nnue);

bb_colors[piece.color].set(square);
bb_pieces[piece.type].set(square);
mailbox[square] = piece;

state.hash.xor_piece(square, piece);

if (nnue) nnue->activate(piece, square);
}

inline void move_piece(Piece piece, Square from, Square to) {
inline void move_piece(Piece piece, Square from, Square to, nn::NNUE *nnue = nullptr) {
assert(piece.is_ok());

square_clear(from);
square_set(to, piece);
square_clear(from, nnue);
square_set(to, piece, nnue);
}

inline void board_clear() {
Expand Down
23 changes: 1 addition & 22 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
#include "tests/tests.h"
#include "uci/uci.h"
#include "utils/bench.h"
#include "external/incbin/incbin.h"

namespace nn {
QNetwork net;
INCBIN(DefaultNetwork, "corenet.bin");
}

namespace core {
// Declarations
Expand Down Expand Up @@ -61,28 +55,13 @@ int main(int argc, char *argv[]) {
mode = std::string(argv[1]);
}

nn::net = nn::QNetwork(nn::gDefaultNetworkData);
init_all();

if (mode == "test") {
test::run();
} else if (mode == "bench") {
run_bench();
} else if (mode == "viz") {
nn::net = nn::QNetwork("corenet.bin");
for (Color color : {WHITE, BLACK}) {
for (PieceType pt : {KING, PAWN, BISHOP, KNIGHT, ROOK, QUEEN}) {
for (bool eg : {true, false}) {
std::cout << char_from_piece(Piece(pt, color)) << std::endl;
for (unsigned int i = 0; i < 64; i++) {
unsigned int feature = nn::Network::get_feature_index(Piece(pt, color), i);
std::cout << round(400 * nn::net.pst.weights[feature * 2 + eg]) << " ";
if (i % 8 == 7) std::cout << std::endl;
}
}
}
}
} else {
} else {
uci::UCI protocol;
protocol.start();
}
Expand Down
11 changes: 6 additions & 5 deletions src/network/activations/crelu.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@
#include <algorithm>

namespace nn::activations {
template<typename T, int INT_UPPER_BOUND>
struct crelu {

static constexpr float UPPER_BOUND = 1.0f;
static constexpr T UPPER_BOUND = static_cast<T>(INT_UPPER_BOUND);

static float forward(float value) {
return std::clamp(value, 0.0f, UPPER_BOUND);
static T forward(T value) {
return std::clamp(value, static_cast<T>(0), UPPER_BOUND);
}

static constexpr float backward(float value) {
return 0.0f < value && value < UPPER_BOUND;
static constexpr T backward(T value) {
return static_cast<T>(0) < value && value < UPPER_BOUND;
}
};
} // namespace nn::activations
7 changes: 4 additions & 3 deletions src/network/activations/none.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
#pragma once

namespace nn::activations {
template<typename T>
struct none {
static constexpr float forward(float value) {
static constexpr T forward(T value) {
return value;
}

static constexpr float backward(float value) {
return 1.0f;
static constexpr T backward(T value) {
return static_cast<T>(1);
}
};
} // namespace nn::activations
9 changes: 5 additions & 4 deletions src/network/activations/relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
#include <algorithm>

namespace nn::activations {
template<typename T>
struct relu {
static float forward(float value) {
return std::max(0.0f, value);
static T forward(T value) {
return std::max(static_cast<T>(0), value);
}

static constexpr float backward(float value) {
return 0.0f < value;
static constexpr T backward(T value) {
return static_cast<T>(0) < value;
}
};
} // namespace nn::activations
4 changes: 4 additions & 0 deletions src/network/adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ namespace nn {
update(network.l1.biases, m_gradient.l1.biases, v_gradient.l1.biases, total.l1.biases);
}

void reduce_learning_rate(float rate) {
LR *= rate;
}

private:
static constexpr float BETA1 = 0.9f;
static constexpr float BETA2 = 0.999f;
Expand Down
2 changes: 2 additions & 0 deletions src/network/data_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// along with this program. If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include "../core/constants.h"
#include "../utils/utilities.h"
#include "activations/sigmoid.h"
Expand Down
Loading

0 comments on commit e31f3c9

Please sign in to comment.