Skip to content

Commit

Permalink
Policy backward
Browse files Browse the repository at this point in the history
Bench: 9895581
  • Loading branch information
SzilBalazs committed Sep 2, 2023
1 parent f1529e0 commit 5a298cc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
10 changes: 5 additions & 5 deletions src/network/activations/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ namespace nn::activations::softmax {
}
}

template<size_t N>
void backward(const std::array<float, N> &s, std::array<float, N> &z_grad, size_t target) {
std::fill(z_grad.begin(), z_grad.end(), 0.0f);
void backward(const std::vector<float> &s, std::vector<float> &z_grad, size_t target) {
const size_t N = s.size();
z_grad.assign(N, 0.0f);

std::array<float, N> s_grad;
std::vector<float> s_grad(N);
for (size_t i = 0; i < N; i++) {
s_grad[i] = s[i] - (i == target);
s_grad[i] = s[i] - float(i == target);
}

for (size_t i = 0; i < N; i++) {
Expand Down
37 changes: 31 additions & 6 deletions src/network/layers/policy_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@

namespace nn::layers {

template<size_t IN>
struct PolicyLayerGradient {
    // Accumulated dL/dW for the policy layer's weight matrix, stored
    // row-major as [move_index][input_index]: OUT rows of IN floats.
    //
    // The buffer is value-initialized to zero ("{}"): consumers
    // (PolicyLayer::backward) only ever += into it, so it must start
    // cleared — plain `new float[...]` would leave garbage.
    PolicyLayerGradient() {
        weights = new float[IN * OUT]{};
    }

    ~PolicyLayerGradient() {
        delete[] weights;
    }

    // Owning raw buffer: a shallow copy would double-free, so copies
    // are forbidden and ownership is transferred on move instead.
    PolicyLayerGradient(const PolicyLayerGradient &) = delete;
    PolicyLayerGradient &operator=(const PolicyLayerGradient &) = delete;

    PolicyLayerGradient(PolicyLayerGradient &&other) noexcept : weights(other.weights) {
        other.weights = nullptr;
    }

    PolicyLayerGradient &operator=(PolicyLayerGradient &&other) noexcept {
        if (this != &other) {
            delete[] weights;
            weights = other.weights;
            other.weights = nullptr;
        }
        return *this;
    }

    float *weights;
    static constexpr size_t OUT = 64 * 64;
};

template<size_t IN>
class PolicyLayer {
public:
Expand All @@ -32,12 +46,27 @@ namespace nn::layers {

// Computes a probability distribution over the legal moves: each move's
// pre-activation z[i] is the dot product of the input features with that
// move's weight row, and softmax then normalizes z into `output`.
// (The scraped hunk interleaved the pre-commit lines — `output.assign`
// and the `get_weight` accumulation — with their replacements; only the
// post-commit version is kept here.)
void forward(const std::array<float, IN> &input, const std::vector<chess::Move> &moves, std::vector<float> &output) const {
    const size_t N = moves.size();
    std::vector<float> z(N); // value-initialized: accumulators start at 0

    for (size_t i = 0; i < N; i++) {
        size_t move_index = get_move_index(moves[i]);
        for (size_t j = 0; j < IN; j++) {
            z[i] += input[j] * weights[move_index * IN + j];
        }
    }
    activations::softmax::forward(z, output);
}

// Accumulates weight gradients for one training position.
// `output` is the softmax distribution produced by forward();
// `best_move_index` is the target class. softmax::backward fills z_grad
// with the gradient w.r.t. the pre-softmax activations (s - onehot, per
// its visible implementation), and each legal move's weight row then
// receives input[j] * z_grad[i].
void backward(const std::array<float, IN> &input, const std::vector<float> &output, const std::vector<chess::Move> &moves, size_t best_move_index, PolicyLayerGradient<IN> &gradient) {
    std::vector<float> z_grad;
    activations::softmax::backward(output, z_grad, best_move_index);

    const size_t move_count = moves.size();
    for (size_t m = 0; m < move_count; ++m) {
        const size_t row = get_move_index(moves[m]);
        float *row_grad = gradient.weights + row * IN;
        const float g = z_grad[m];
        for (size_t k = 0; k < IN; ++k) {
            row_grad[k] += input[k] * g;
        }
    }
}
Expand All @@ -46,10 +75,6 @@ namespace nn::layers {
static constexpr size_t OUT = 64 * 64;
float *weights;

float &get_weight(size_t in, size_t out) {
return weights[out * IN + in];
}

// Flattens the (from, to) square pair of a move into a single row
// index in [0, 64*64), matching the OUT = 64 * 64 weight layout.
size_t get_move_index(const chess::Move &move) {
    return static_cast<size_t>(move.get_from()) * 64 + static_cast<size_t>(move.get_to());
}
Expand Down
14 changes: 11 additions & 3 deletions src/network/policy_network.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ namespace nn {
class PolicyNetwork {
public:
// Network-level forward pass. Softmax is now applied inside the layer's
// own forward(), so this is a plain delegation. (The scraped hunk
// interleaved the pre-commit two-step version — local `z` plus an
// explicit softmax call — with the new one-liner; only the post-commit
// version is kept here.)
void forward(const std::array<float, IN> &input, const std::vector<chess::Move> &moves, std::vector<float> &output) {
    l.forward(input, moves, output);
}
// Translates the best move into its position in the legal-move list and
// delegates gradient accumulation to the layer.
// NOTE(review): if best_move is not present in `moves`, index 0 is used
// as the target — presumably callers guarantee membership; confirm.
void backward(const std::array<float, IN> &input, const std::vector<float> &output, const std::vector<chess::Move> &moves, const chess::Move &best_move, layers::PolicyLayerGradient<IN> &gradient) {
    size_t best_move_index = 0;
    for (auto it = moves.cbegin(); it != moves.cend(); ++it) {
        if (*it == best_move) {
            best_move_index = static_cast<size_t>(it - moves.cbegin());
            break;
        }
    }
    l.backward(input, output, moves, best_move_index, gradient);
}

private:
Expand Down

0 comments on commit 5a298cc

Please sign in to comment.