-
Notifications
You must be signed in to change notification settings - Fork 191
/
game_ac_network.py
270 lines (211 loc) · 10.9 KB
/
game_ac_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
# Actor-Critic Network Base Class
# (Policy network and Value network)
class GameACNetwork(object):
    """Base class for the actor-critic networks.

    Subclasses build the graph in their constructor and must define
    ``self.pi`` (policy output) and ``self.v`` (value output) before
    ``prepare_loss`` is called.  They must also implement the ``run_*``
    methods and ``get_vars``.
    """

    def __init__(self,
                 action_size,
                 thread_index,  # -1 for global
                 device="/cpu:0"):
        """Store the parameters shared by every network variant.

        Args:
            action_size: number of discrete actions (width of the policy
                output and of the ``self.a`` placeholder).
            thread_index: index of the worker owning this network; -1 is used
                for the global network.  Subclasses use it to name the
                variable scope.
            device: TensorFlow device string that graph ops are placed on.
        """
        self._action_size = action_size
        self._thread_index = thread_index
        self._device = device

    def prepare_loss(self, entropy_beta):
        """Build the combined policy + value loss into ``self.total_loss``.

        Creates three placeholders that the caller must feed:
        ``self.a`` (taken actions, shape [None, action_size]; presumably
        one-hot -- TODO confirm with the trainer), ``self.td`` (R - V,
        shape [None]) and ``self.r`` (return R, shape [None]).

        Args:
            entropy_beta: weight of the policy-entropy bonus.
        """
        with tf.device(self._device):
            # taken action (input for policy)
            self.a = tf.placeholder("float", [None, self._action_size])

            # temporal difference (R - V) (input for policy)
            self.td = tf.placeholder("float", [None])

            # avoid NaN with clipping when a value in pi becomes zero
            log_pi = tf.log(tf.clip_by_value(self.pi, 1e-20, 1.0))

            # policy entropy, one value per batch row
            entropy = -tf.reduce_sum(self.pi * log_pi, axis=1)

            # policy loss (output).  The paper's objective is maximised by
            # gradient ascent; negate it because we use a gradient-descent
            # optimizer.
            policy_loss = -tf.reduce_sum(
                tf.reduce_sum(tf.multiply(log_pi, self.a), axis=1) * self.td
                + entropy * entropy_beta)

            # R (input for value)
            self.r = tf.placeholder("float", [None])

            # value loss (output)
            # (Learning rate for the critic is half of the actor's, so multiply
            # by 0.5.  Note tf.nn.l2_loss already computes sum(t ** 2) / 2.)
            value_loss = 0.5 * tf.nn.l2_loss(self.r - self.v)

            # gradients of policy and value are summed up
            self.total_loss = policy_loss + value_loss

    def run_policy_and_value(self, sess, s_t):
        """Return (policy, value) for a single state; subclass hook."""
        raise NotImplementedError()

    def run_policy(self, sess, s_t):
        """Return the policy for a single state; subclass hook."""
        raise NotImplementedError()

    def run_value(self, sess, s_t):
        """Return the value estimate for a single state; subclass hook."""
        raise NotImplementedError()

    def get_vars(self):
        """Return this network's variables in a fixed order; subclass hook."""
        raise NotImplementedError()

    def sync_from(self, src_netowrk, name=None):
        """Build an op that copies ``src_netowrk``'s variables into this one.

        Variables are paired by position in the two ``get_vars()`` lists.
        (The parameter name's spelling is kept for keyword compatibility.)

        Args:
            src_netowrk: network whose variable values are copied.
            name: optional name for the returned op; defaults to a scope
                derived from "GameACNetwork".

        Returns:
            A ``tf.group`` of the ``tf.assign`` ops.

        Raises:
            ValueError: if the two networks expose a different number of
                variables.  ``zip`` would otherwise silently leave the
                unmatched trailing variables unsynchronised.
        """
        src_vars = src_netowrk.get_vars()
        dst_vars = self.get_vars()
        if len(src_vars) != len(dst_vars):
            raise ValueError(
                "sync_from: source has %d variables but destination has %d"
                % (len(src_vars), len(dst_vars)))

        sync_ops = []

        with tf.device(self._device):
            with tf.name_scope(name, "GameACNetwork", []) as name:
                for src_var, dst_var in zip(src_vars, dst_vars):
                    sync_op = tf.assign(dst_var, src_var)
                    sync_ops.append(sync_op)

                return tf.group(*sync_ops, name=name)

    # weight initialization based on muupan's code
    # https://github.com/muupan/async-rl/blob/master/a3c_ale.py
    def _fc_variable(self, weight_shape):
        """Create (weight, bias) for a fully connected layer.

        Both are drawn uniformly from [-d, d] with d = 1 / sqrt(fan_in).

        Args:
            weight_shape: [input_channels, output_channels].
        """
        input_channels = weight_shape[0]
        output_channels = weight_shape[1]
        d = 1.0 / np.sqrt(input_channels)
        bias_shape = [output_channels]
        weight = tf.Variable(tf.random_uniform(weight_shape, minval=-d, maxval=d))
        bias = tf.Variable(tf.random_uniform(bias_shape, minval=-d, maxval=d))
        return weight, bias

    def _conv_variable(self, weight_shape):
        """Create (weight, bias) for a 2-D convolution.

        Both are drawn uniformly from [-d, d] with
        d = 1 / sqrt(kernel_w * kernel_h * input_channels).

        Args:
            weight_shape: [kernel_w, kernel_h, input_channels, output_channels].
        """
        w = weight_shape[0]
        h = weight_shape[1]
        input_channels = weight_shape[2]
        output_channels = weight_shape[3]
        d = 1.0 / np.sqrt(input_channels * w * h)
        bias_shape = [output_channels]
        weight = tf.Variable(tf.random_uniform(weight_shape, minval=-d, maxval=d))
        bias = tf.Variable(tf.random_uniform(bias_shape, minval=-d, maxval=d))
        return weight, bias

    def _conv2d(self, x, W, stride):
        """Apply a VALID-padded convolution with the same stride on both
        spatial axes.

        The strides list [1, s, s, 1] matches TF's default NHWC layout.
        """
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="VALID")
# Actor-Critic FF Network
class GameACFFNetwork(GameACNetwork):
    """Feed-forward actor-critic network.

    Graph: input [None, 84, 84, 4] -> conv 8x8 stride 4 (16 ch) ->
    conv 4x4 stride 2 (32 ch) -> fully connected 256 ->
    {softmax policy over ``action_size`` actions, scalar value}.
    """

    def __init__(self,
                 action_size,
                 thread_index,  # -1 for global
                 device="/cpu:0"):
        GameACNetwork.__init__(self, action_size, thread_index, device)

        net_scope = "net_" + str(self._thread_index)
        with tf.device(self._device), tf.variable_scope(net_scope) as scope:
            # Parameters.  Keep this creation order: TensorFlow assigns the
            # default variable names ("Variable", "Variable_1", ...) in order.
            self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16])  # stride=4
            self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32])  # stride=2
            self.W_fc1, self.b_fc1 = self._fc_variable([2592, 256])
            # policy head
            self.W_fc2, self.b_fc2 = self._fc_variable([256, action_size])
            # value head
            self.W_fc3, self.b_fc3 = self._fc_variable([256, 1])

            # state (input)
            self.s = tf.placeholder("float", [None, 84, 84, 4])

            conv1_out = tf.nn.relu(
                self._conv2d(self.s, self.W_conv1, 4) + self.b_conv1)
            conv2_out = tf.nn.relu(
                self._conv2d(conv1_out, self.W_conv2, 2) + self.b_conv2)
            # 84 -> 20 -> 9 spatially with VALID padding: 9 * 9 * 32 = 2592
            flat = tf.reshape(conv2_out, [-1, 2592])
            hidden = tf.nn.relu(tf.matmul(flat, self.W_fc1) + self.b_fc1)

            # policy (output)
            policy_logits = tf.matmul(hidden, self.W_fc2) + self.b_fc2
            self.pi = tf.nn.softmax(policy_logits)

            # value (output), flattened from [batch, 1] to [batch]
            value_column = tf.matmul(hidden, self.W_fc3) + self.b_fc3
            self.v = tf.reshape(value_column, [-1])

    def _single_state_feed(self, s_t):
        """Feed dict holding a batch made of the single state ``s_t``."""
        return {self.s: [s_t]}

    def run_policy_and_value(self, sess, s_t):
        """Return (policy vector, scalar value) for one state."""
        pi_batch, v_batch = sess.run([self.pi, self.v],
                                     feed_dict=self._single_state_feed(s_t))
        return (pi_batch[0], v_batch[0])

    def run_policy(self, sess, s_t):
        """Return the policy vector for one state."""
        pi_batch = sess.run(self.pi, feed_dict=self._single_state_feed(s_t))
        return pi_batch[0]

    def run_value(self, sess, s_t):
        """Return the scalar value estimate for one state."""
        v_batch = sess.run(self.v, feed_dict=self._single_state_feed(s_t))
        return v_batch[0]

    def get_vars(self):
        """Return the trainable variables in the order ``sync_from`` pairs them."""
        return [
            self.W_conv1, self.b_conv1,
            self.W_conv2, self.b_conv2,
            self.W_fc1, self.b_fc1,
            self.W_fc2, self.b_fc2,
            self.W_fc3, self.b_fc3,
        ]
# Actor-Critic LSTM Network
class GameACLSTMNetwork(GameACNetwork):
    """Recurrent actor-critic network.

    The conv/FC trunk is the same as in the feed-forward network.  A 256-unit
    BasicLSTMCell follows it, feeding the policy and value heads.  The
    recurrent state carried between forward calls lives in
    ``self.lstm_state_out`` and is zeroed by ``reset_state``.
    """

    def __init__(self,
                 action_size,
                 thread_index,  # -1 for global
                 device="/cpu:0"):
        GameACNetwork.__init__(self, action_size, thread_index, device)

        scope_name = "net_" + str(self._thread_index)
        with tf.device(self._device), tf.variable_scope(scope_name) as scope:
            # Keep this creation order: it determines TF's default variable
            # and op names.
            self.W_conv1, self.b_conv1 = self._conv_variable([8, 8, 4, 16])  # stride=4
            self.W_conv2, self.b_conv2 = self._conv_variable([4, 4, 16, 32])  # stride=2

            self.W_fc1, self.b_fc1 = self._fc_variable([2592, 256])

            # lstm
            self.lstm = tf.contrib.rnn.BasicLSTMCell(256, state_is_tuple=True)

            # weight for policy output layer
            self.W_fc2, self.b_fc2 = self._fc_variable([256, action_size])

            # weight for value output layer
            self.W_fc3, self.b_fc3 = self._fc_variable([256, 1])

            # state (input)
            self.s = tf.placeholder("float", [None, 84, 84, 4])

            h_conv1 = tf.nn.relu(self._conv2d(self.s, self.W_conv1, 4) + self.b_conv1)
            h_conv2 = tf.nn.relu(self._conv2d(h_conv1, self.W_conv2, 2) + self.b_conv2)

            h_conv2_flat = tf.reshape(h_conv2, [-1, 2592])
            h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, self.W_fc1) + self.b_fc1)
            # The fed states form one sequence, [T, 256] -> [1, T, 256],
            # e.g. (5, 256) -> (1, 5, 256) when training over 5 steps.
            h_fc1_reshaped = tf.reshape(h_fc1, [1, -1, 256])

            # placeholder for the LSTM unrolling time step size.
            self.step_size = tf.placeholder(tf.float32, [1])

            self.initial_lstm_state0 = tf.placeholder(tf.float32, [1, 256])
            self.initial_lstm_state1 = tf.placeholder(tf.float32, [1, 256])
            self.initial_lstm_state = tf.contrib.rnn.LSTMStateTuple(
                self.initial_lstm_state0, self.initial_lstm_state1)

            # Unroll the LSTM up to LOCAL_T_MAX time steps (= 5 time steps).
            # When an episode terminates, fewer steps are unrolled; the actual
            # count is fed through the self.step_size placeholder.
            # When forward propagating, step_size is 1.
            # (time_major = False, so output shape is
            # [batch_size, max_time, cell.output_size])
            lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(
                self.lstm,
                h_fc1_reshaped,
                initial_state=self.initial_lstm_state,
                sequence_length=self.step_size,
                time_major=False,
                scope=scope)

            # lstm_outputs: (1, 5, 256) for back prop, (1, 1, 256) for forward prop.
            lstm_outputs = tf.reshape(lstm_outputs, [-1, 256])

            # policy (output)
            self.pi = tf.nn.softmax(tf.matmul(lstm_outputs, self.W_fc2) + self.b_fc2)

            # value (output)
            v_ = tf.matmul(lstm_outputs, self.W_fc3) + self.b_fc3
            self.v = tf.reshape(v_, [-1])

            scope.reuse_variables()
            # TensorFlow 1.2 renamed RNN-cell variables from weights/biases
            # to kernel/bias.  Looking up a missing name with reuse enabled
            # raises ValueError, so try the old names first and fall back.
            try:
                self.W_lstm = tf.get_variable("basic_lstm_cell/weights")
                self.b_lstm = tf.get_variable("basic_lstm_cell/biases")
            except ValueError:
                self.W_lstm = tf.get_variable("basic_lstm_cell/kernel")
                self.b_lstm = tf.get_variable("basic_lstm_cell/bias")

        self.reset_state()

    def reset_state(self):
        """Zero the carried LSTM state (c and h, each shape [1, 256])."""
        self.lstm_state_out = tf.contrib.rnn.LSTMStateTuple(np.zeros([1, 256]),
                                                            np.zeros([1, 256]))

    def _single_step_feed(self, s_t):
        """Feed dict for one forward step from the carried LSTM state."""
        return {self.s: [s_t],
                self.initial_lstm_state0: self.lstm_state_out[0],
                self.initial_lstm_state1: self.lstm_state_out[1],
                self.step_size: [1]}

    def run_policy_and_value(self, sess, s_t):
        """Return (policy, value) for one state and advance the LSTM state.

        Used when forward propagating, so the step size is 1.
        """
        pi_out, v_out, self.lstm_state_out = sess.run(
            [self.pi, self.v, self.lstm_state],
            feed_dict=self._single_step_feed(s_t))
        # pi_out: (1, action_size), v_out: (1,)
        return (pi_out[0], v_out[0])

    def run_policy(self, sess, s_t):
        """Return the policy for one state and advance the LSTM state.

        Used for displaying the result with the display tool.
        """
        pi_out, self.lstm_state_out = sess.run(
            [self.pi, self.lstm_state],
            feed_dict=self._single_step_feed(s_t))
        return pi_out[0]

    def run_value(self, sess, s_t):
        """Return V(s_t) without advancing the LSTM state.

        Used to compute V for bootstrapping at the end of a LOCAL_T_MAX
        time-step sequence.  When the next sequence starts, V is computed
        again for the same state with updated weights, so the carried
        LSTM state must stay unchanged here.  Only ``self.v`` is fetched,
        which leaves ``self.lstm_state_out`` untouched.
        """
        v_out = sess.run(self.v, feed_dict=self._single_step_feed(s_t))
        return v_out[0]

    def get_vars(self):
        """Return the trainable variables in the order ``sync_from`` pairs them."""
        return [self.W_conv1, self.b_conv1,
                self.W_conv2, self.b_conv2,
                self.W_fc1, self.b_fc1,
                self.W_lstm, self.b_lstm,
                self.W_fc2, self.b_fc2,
                self.W_fc3, self.b_fc3]