-
Notifications
You must be signed in to change notification settings - Fork 28
/
GRU.hpp
50 lines (38 loc) · 946 Bytes
/
GRU.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#pragma once
#include "Matrix.hpp"
#include "Rand.hpp"
#include <fstream>
/**
 * Gated recurrent unit (GRU) layer.
 *
 * Holds the layer parameters and declares the forward/backward passes, the
 * SGD update and (de)serialization; all non-trivial members are defined
 * outside this header.
 *
 * Parameter naming presumably follows the usual GRU convention (confirm
 * against the implementation file): suffix r = reset gate, z = update gate,
 * u = candidate hidden state; Wx* are applied to the input, Wh* to the
 * previous hidden state, and b* are the biases.
 */
class GRU{
public:
  /// Leaves every parameter default-constructed.
  GRU(){}

  /// Builds a GRU from the input and hidden dimensions (defined elsewhere).
  GRU(const int inputDim, const int hiddenDim);

  /// forward()/backward() are virtual, so GRU can serve as a polymorphic
  /// base; a virtual destructor makes deletion through a GRU* well-defined.
  /// Note: under C++11 and later a user-declared destructor suppresses the
  /// implicitly generated move operations, so GRU objects are copied instead.
  virtual ~GRU(){}

  class State; ///< Per-step activations and backprop buffers.
  class Grad;  ///< Gradients mirroring the parameters below.

  // Parameters with suffix r.
  MatD Wxr;
  MatD Whr;
  VecD br;

  // Parameters with suffix z.
  MatD Wxz;
  MatD Whz;
  VecD bz;

  // Parameters with suffix u.
  MatD Wxu;
  MatD Whu;
  VecD bu;

  /// Initializes the parameters using `rnd`; `scale` defaults to 1.0.
  /// NOTE(review): the exact use of `scale` is not visible here - confirm.
  void init(Rand& rnd, const Real scale = 1.0);

  /// Forward step for input `xt`, reading `prev` and writing into `cur`.
  virtual void forward(const VecD& xt, const GRU::State* prev, GRU::State* cur);

  /// Backward step for input `xt`; accumulates parameter gradients in `grad`.
  /// `prev` and `cur` are non-const, presumably because their backprop
  /// buffers are updated - confirm against the implementation.
  virtual void backward(GRU::State* prev, GRU::State* cur, GRU::Grad& grad, const VecD& xt);

  /// Applies a gradient step with the given learning rate.
  void sgd(const GRU::Grad& grad, const Real learningRate);

  /// Writes the parameters to an already-opened output stream.
  void save(std::ofstream& ofs);

  /// Reads the parameters from an already-opened input stream.
  void load(std::ifstream& ifs);
};
/**
 * Per-time-step record passed to GRU::forward and GRU::backward.
 *
 * Field meanings are presumed from the names and the usual GRU formulation
 * (confirm against the implementation): h = hidden state, u = candidate
 * state, r = reset gate, z = update gate, rh = reset gate applied to the
 * previous hidden state.
 */
class GRU::State{
public:
  VecD h;
  VecD u;
  VecD r;
  VecD z;

  VecD rh;

  // Buffers used during backpropagation.
  VecD delh;
  VecD delx;

  /// Clears this state (defined elsewhere).
  void clear();
};
/**
 * Gradient container for a GRU: one member per GRU parameter, using the
 * same names as the parameters declared in GRU.
 */
class GRU::Grad{
public:
  /// Leaves every gradient member default-constructed.
  Grad() {}

  /// Builds a gradient object from `gru` (defined elsewhere; presumably
  /// sized to match its parameters - confirm).
  Grad(const GRU& gru);

  // Gradients with suffix r.
  MatD Wxr;
  MatD Whr;
  VecD br;

  // Gradients with suffix z.
  MatD Wxz;
  MatD Whz;
  VecD bz;

  // Gradients with suffix u.
  MatD Wxu;
  MatD Whu;
  VecD bu;

  /// (Re)initializes the gradient members (defined elsewhere).
  void init();

  /// Returns a norm of the gradient.
  /// NOTE(review): which norm (e.g. squared or not) is not visible here.
  Real norm();

  /// Accumulates `grad` into this object.
  void operator+=(const GRU::Grad& grad);
};