sarsa.py
"""Applies SARSA algorithm to grid-world environments.
"""
import sys
import numpy as np
import itertools
import copy
import random
import matplotlib.pyplot as plt
from collections import defaultdict
from grid_world import GridWorld
class SARSA:
"""Creates value and policy containers for SARSA evaluation of a
grid-world instance.
Attributes:
alpha (float): learning rate
c_map (TYPE): grid-world instance
discount_factor (float): discount factor
epsilon (float): policy choice probability
policy (TYPE): action policy
possible_actions (TYPE): possible actions for an agent to take
Q (TYPE): state action value storage
x_lim (TYPE): grid x-length
y_lim (TYPE): grid y-length
"""
def __init__(self, c_map, possible_actions, world, Q=None):
"""Initializes SARSA parameters
Args:
c_map (TYPE): grid-world instance
possible_actions (TYPE): possible actions for an agent to take
x_lim (TYPE): grid x-length
y_lim (TYPE): grid y-length
"""
self.alpha = 0.5
self.discount_factor = 0.9
self.c_map = c_map
self.possible_actions = possible_actions
self.world = world
self.x_lim = self.world
self.y_lim = self.world
if Q is None:
Q = dict()
self.Q = Q
self.policy = dict()
def epsilon_greedy_random_action(self, state, step=0, exploit=False):
"""Chooses action based on a greedy policy, but allows for
exploration by a 1 - epsilon probability.
Args:
state (TYPE): current state
Returns:
TYPE: action to be taken
"""
if exploit:
epsilon = 0.0
else:
epsilon = 1.0 / np.sqrt(step + 1)
p = np.random.random()
if p < epsilon:
action = random.choice(self.possible_actions)
else:
            q_all = [self.Q.get((state, a), 0.0)
                     for a in self.possible_actions]
            # Break ties among equally valued actions by pairing each action
            # with its value rather than indexing q_all by the action label.
            max_a = [a for a, q in zip(self.possible_actions, q_all)
                     if q == max(q_all)]
            if len(max_a) > 1:
                action = random.choice(max_a)
            else:
                action = max_a[0]
self.policy[state] = action
return action
def update_Q(self, state, action, new_state, new_action, reward):
"""Updates state value function
Args:
state (TYPE): current state
action (TYPE): action taken from current state
new_state (TYPE): new state
new_action (TYPE): action taken from new state
reward (TYPE): reward received upon taking an action from
the current state
"""
q = self.Q.get((state, action), 0.0)
self.Q[state, action] = q + self.alpha * \
(reward + self.discount_factor *
self.Q.get((new_state, new_action), 0.0) - q)
def take_step(self, state, action):
"""Agent performs action to transition to next state on grid-world.
Args:
state (TYPE): current state
action (TYPE): action taken from current state
Returns:
TYPE: new state and reward received from said state
"""
# from action [U, D, R, L] to state
x, y = state[0], state[1]
        if self.c_map[state]['actions'][action]:
            if action == 0:        # up
                new_state = (x, y + 1)
            elif action == 1:      # down
                new_state = (x, y - 1)
            elif action == 2:      # right
                new_state = (x + 1, y)
            else:                  # action == 3, left
                new_state = (x - 1, y)
        else:
            new_state = state
return new_state, self.c_map[new_state]['reward']
def print_values(self):
"""Print state value function to terminal
"""
for y in range(self.y_lim - 1, -1, -1):
print('-' * self.x_lim * 9)
for x in range(self.x_lim):
value = np.mean([self.Q.get(((x, y), a), 0.0)
for a in self.possible_actions])
if value >= 0:
print(' {:.4f}'.format(value), "|", end="")
else:
print('{:.4f}'.format(value), "|", end="")
print("")
print('-' * self.x_lim * 9)
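

# A minimal, illustrative training-loop sketch (not part of the original
# module): it drives the SARSA class above against an already-constructed
# grid-world map. The parameters `start_state`, `terminal_states`, and
# `n_episodes` are hypothetical placeholders, and how `c_map` is produced
# (e.g. via GridWorld) is assumed rather than shown here.
def run_sarsa_example(c_map, possible_actions, world, start_state,
                      terminal_states, n_episodes=500):
    """Runs SARSA episodes on a prepared grid-world map and prints values."""
    agent = SARSA(c_map, possible_actions, world)
    for episode in range(n_episodes):
        state = start_state
        # Choose the first action before entering the loop: each update needs
        # the (state, action, reward, next state, next action) tuple.
        action = agent.epsilon_greedy_random_action(state, step=episode)
        while state not in terminal_states:
            new_state, reward = agent.take_step(state, action)
            new_action = agent.epsilon_greedy_random_action(new_state,
                                                            step=episode)
            agent.update_Q(state, action, new_state, new_action, reward)
            state, action = new_state, new_action
    agent.print_values()
    return agent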