-
Notifications
You must be signed in to change notification settings - Fork 0
/
memory.py
64 lines (52 loc) · 1.65 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import random
"""
Simple replay memory class, storing transitions and sampling from them.
Non-standard is the option to alter memory after the fact to allow for
knowing the state after the opponent action
which is necessary for discounted rewards.
"""
class ReplayMemory:
    """Bounded FIFO replay buffer of transitions with random sampling.

    Each stored transition is a 7-element list:
        [state, action, reward, next_state, terminated,
         boards_after_opponent_action, enemy_reward_after_action]

    Non-standard feature: the last-added transition can be patched after
    the fact (see ``update_memory_after_opponent_action``) so that the
    state following the opponent's move is known, which is necessary for
    computing discounted rewards.
    """

    # Indices of the fields patched after the opponent acts; must match
    # the list layout built in remember().
    _IDX_BOARDS_AFTER = 5
    _IDX_ENEMY_REWARD = 6

    def __init__(self, n_remember=10000):
        """Create an empty buffer holding at most ``n_remember`` transitions."""
        self.memory = []
        self.n_remember = n_remember

    def remember(
        self,
        state,
        action,
        reward,
        next_state,
        terminated,
        boards_after_opponent_action,
        enemy_reward_after_action,
    ):
        """Append one transition; evict the oldest when capacity is exceeded."""
        self.memory.append(
            [
                state,
                action,
                reward,
                next_state,
                terminated,
                boards_after_opponent_action,
                enemy_reward_after_action,
            ]
        )
        if len(self.memory) > self.n_remember:
            # FIFO eviction: drop the oldest transition.
            del self.memory[0]

    def update_memory_after_opponent_action(
        self, boards_after_opponent_action, enemy_reward_after_action
    ):
        """Patch the most recent transition with post-opponent-move data.

        No-op when the buffer is empty.
        """
        if not self.memory:
            return
        self.memory[-1][self._IDX_BOARDS_AFTER] = boards_after_opponent_action
        self.memory[-1][self._IDX_ENEMY_REWARD] = enemy_reward_after_action

    def sample(self, n=None):
        """Return ``n`` transitions sampled without replacement.

        With ``n`` of None, or ``n`` larger than the buffer, the whole
        buffer is returned (decorrelation only applies to true subsamples).
        """
        if n is None or n > len(self.memory):
            return self.memory
        return random.sample(self.memory, n)