-
Notifications
You must be signed in to change notification settings - Fork 28
/
LSTM.hpp
84 lines (66 loc) · 2.49 KB
/
LSTM.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#pragma once
#include "Matrix.hpp"
#include "Rand.hpp"
#include <fstream>
// Long Short-Term Memory (LSTM) unit declaration: input/forget/output gates
// plus a memory-cell candidate, with an optional second ("additional") input
// stream (Wa* weights) and per-stream dropout rates. Member functions are
// defined out-of-line (presumably in LSTM.cpp — not visible here).
class LSTM{
public:
  LSTM(){}
  LSTM(const int inputDim, const int hiddenDim);
  LSTM(const int inputDim, const int additionalInputDim, const int hiddenDim);

  // This class declares virtual functions (forward/backward), so a virtual
  // destructor is required: without it, deleting a derived LSTM through an
  // LSTM* is undefined behavior. The class already has a vtable, so adding
  // this is backward compatible.
  virtual ~LSTM(){}

  class State; // per-timestep activations (defined below)
  class Grad;  // parameter gradients (defined below)

  // Dropout keep/drop rates for the word input (X), additional input (A),
  // and recurrent hidden input (H). Semantics depend on dropout() impl —
  // not visible here.
  Real dropoutRateX;
  Real dropoutRateA;
  Real dropoutRateH;

  // Gate parameters: Wx* multiplies the input x_t, Wh* the previous hidden
  // state h_{t-1}, b* is the bias.
  MatD Wxi, Whi; VecD bi; //for the input gate
  MatD Wxf, Whf; VecD bf; //for the forget gate
  MatD Wxo, Who; VecD bo; //for the output gate
  MatD Wxu, Whu; VecD bu; //for the memory cell

  void init(Rand& rnd, const Real scale = 1.0);
  void activate(LSTM::State* cur);
  void activate(const LSTM::State* prev, LSTM::State* cur);

  // Single-input forward/backward; the prev-less overloads handle t = 0
  // (no previous state).
  virtual void forward(const VecD& xt, const LSTM::State* prev, LSTM::State* cur);
  virtual void forward(const VecD& xt, LSTM::State* cur);
  virtual void backward(LSTM::State* prev, LSTM::State* cur, LSTM::Grad& grad, const VecD& xt);
  virtual void backward(LSTM::State* cur, LSTM::Grad& grad, const VecD& xt);

  void sgd(const LSTM::Grad& grad, const Real learningRate);
  void save(std::ofstream& ofs);
  void load(std::ifstream& ifs);

  // Weights for the additional input stream a_t (used by the 3-arg ctor).
  MatD Wai, Waf, Wao, Wau; //for additional input

  // Two-input (x_t, a_t) forward/backward overloads.
  virtual void forward(const VecD& xt, const VecD& at, const LSTM::State* prev, LSTM::State* cur);
  virtual void forward(const VecD& xt, const VecD& at, LSTM::State* cur);
  virtual void backward(LSTM::State* prev, LSTM::State* cur, LSTM::Grad& grad, const VecD& xt, const VecD& at);
  virtual void backward(LSTM::State* cur, LSTM::Grad& grad, const VecD& xt, const VecD& at);

  void dropout(bool isTest);

  // Element-wise parameter accumulation / scaling (e.g. for averaging
  // parameters across model replicas).
  void operator += (const LSTM& lstm);
  void operator /= (const Real val);
};
// Per-timestep activation state of an LSTM cell, plus the buffers needed
// for dropout and backpropagation. clear() is defined out-of-line.
class LSTM::State{
public:
// NOTE(review): calling a virtual function from a destructor dispatches
// statically — this->clear() here always runs LSTM::State::clear, never a
// derived override. Derived States relying on their own clear() in
// destruction must call it from their own destructor.
virtual ~State() {this->clear();};
// Gate/cell activations: h = hidden output, c = memory cell, u = cell
// candidate, i/f/o = input/forget/output gate activations.
VecD h, c, u, i, f, o;
// Cached tanh(c), presumably reused by backward() — confirm in the .cpp.
VecD cTanh;
VecD maskXt, maskAt, maskHt; //for dropout
VecD delh, delc, delx, dela; //for backprop
virtual void clear();
};
// Gradients of all LSTM parameters (mirrors the weight/bias members of
// LSTM one-to-one), plus the update rules that consume them.
class LSTM::Grad{
public:
Grad(): gradHist(0) {}
Grad(const LSTM& lstm);
// NOTE(review): raw pointer, zero-initialized, and no destructor is
// declared — presumably squared-gradient history for adagrad()/momentum(),
// allocated lazily elsewhere; ownership/lifetime is not visible in this
// header, so verify it is freed somewhere (possible leak otherwise).
LSTM::Grad* gradHist;
MatD Wxi, Whi; VecD bi;
MatD Wxf, Whf; VecD bf;
MatD Wxo, Who; VecD bo;
MatD Wxu, Whu; VecD bu;
MatD Wai, Waf, Wao, Wau;
void init();
Real norm();
// L2 regularization: the (lambda, lstm, target) overload presumably
// penalizes distance from `target` rather than from zero — confirm in .cpp.
void l2reg(const Real lambda, const LSTM& lstm);
void l2reg(const Real lambda, const LSTM& lstm, const LSTM& target);
void sgd(const Real learningRate, LSTM& lstm);
void adagrad(const Real learningRate, LSTM& lstm, const Real initVal = 1.0);
void momentum(const Real learningRate, const Real m, LSTM& lstm);
// Gradient accumulation / averaging across examples or threads.
void operator += (const LSTM::Grad& grad);
void operator /= (const Real val);
};