random_agent.py
import tensorflow as tf

import constants as c
from random import randint


# noinspection PyAttributeOutsideInit
class RandomAgent(object):
    """An agent that selects actions uniformly at random. Its TF graph is a
    dummy that gives the train loop ops to run; it doesn't learn anything."""

    def __init__(self, sess, args, n_actions):
        self.sess = sess
        self.args = args
        self.n_actions = n_actions

        self.define_graph()

    def define_graph(self):
        self.input = tf.placeholder(tf.float32, (None, c.IN_HEIGHT, c.IN_WIDTH, c.IN_CHANNELS))
        self.w = tf.Variable(tf.truncated_normal(
            (c.IN_HEIGHT * c.IN_WIDTH * c.IN_CHANNELS, self.n_actions), stddev=0.01))
        self.b = tf.Variable(tf.constant(0.1, shape=(self.n_actions,)))

        self.preds = self.get_preds(self.input)

        # Ignore this. Doesn't really mean anything:
        self.global_step = tf.Variable(0, trainable=False)
        self.loss = tf.reduce_sum(self.preds * tf.random_uniform([self.n_actions]))
        optimizer = tf.train.AdamOptimizer(learning_rate=1)
        self.train_op = optimizer.minimize(self.loss, global_step=self.global_step)

    def get_preds(self, inputs):
        preds = tf.contrib.layers.flatten(inputs)
        preds = tf.matmul(preds, self.w) + self.b

        return preds

    #
    # API for train loop:
    #

    def get_action(self, state):
        """
        :param state: A numpy array with a single state (shape: (1, 84, 84, 4))

        :return: The action from the policy (an int in [0, self.n_actions - 1])
        """
        return randint(0, self.n_actions - 1)

    def train_step(self, states, actions, rewards, terminal):
        """
        :param states: A numpy array with a batch of states (shape: (batch, 84, 84, 4))
        :param actions: A numpy array with a batch of actions (shape: (batch,))
        :param rewards: A numpy array with a batch of rewards (shape: (batch,))
        :param terminal: A numpy array with a batch of terminal flags (shape: (batch,))

        :return: The global step after this train step (an int)
        """
        _, global_step = self.sess.run([self.train_op, self.global_step],
                                       feed_dict={self.input: states})

        if global_step % 100 == 0:
            print('Step:', global_step)

        return global_step
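
For context, a minimal sketch of a caller follows. It assumes this file is importable as random_agent, that c.IN_HEIGHT, c.IN_WIDTH, and c.IN_CHANNELS are 84, 84, and 4 (as the docstrings suggest), and uses a made-up action count and dummy batch; only the constructor and the two API methods above come from this file.

# Usage sketch (assumptions noted above; not the repo's real train loop).
import numpy as np
import tensorflow as tf

from random_agent import RandomAgent

sess = tf.Session()
agent = RandomAgent(sess, args=None, n_actions=4)  # action count is assumed
sess.run(tf.global_variables_initializer())

# Pick an action for a single state:
state = np.zeros((1, 84, 84, 4), dtype=np.float32)
action = agent.get_action(state)

# Run one (intentionally meaningless) train step on a dummy batch:
batch = 32
states = np.zeros((batch, 84, 84, 4), dtype=np.float32)
actions = np.zeros((batch,), dtype=np.int32)
rewards = np.zeros((batch,), dtype=np.float32)
terminal = np.zeros((batch,), dtype=np.bool_)
step = agent.train_step(states, actions, rewards, terminal)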