train.py
# Train the neural network
from brain import Brain
from dqn import Dqn
from env import reset_env, step
import variables
from variables import *
import numpy as np
import matplotlib.pyplot as plt
game_display, _ = reset_env()
# Create the model
jump = int(SIZE_RECT)
table_size = int(SCREEN_SIZE[0] / jump)
Ai_brain = Brain((table_size, table_size, 1), learningRate)
Ai_model = Ai_brain.model
#Ai_model = Ai_brain.loadModel("./model_snack.h5")
# Create the DQN replay memory
Ai_memory = Dqn(variables.memSize, variables.gamma)
# Starting the main loop
epoch = 0
scores = list()
steps_ = list()
maxNCollected = 0
totNCollected = 0
totSteps = 0
while True:
    game_display, current_state = reset_env()
    epoch += 1
    gameOver = False
    i = 0
    while not gameOver:
        # Epsilon-greedy action selection
        if np.random.rand() < variables.epsilon:
            action = np.random.randint(0, 4)
        else:
            q_value = Ai_model.predict(current_state)
            action = np.argmax(q_value)
        # Updating the environment
        state, reward, gameOver = step(game_display, action)
        # Storing the transition and training on a replay batch
        Ai_memory.remember([current_state, action, reward, state], gameOver)
        inputs, targets = Ai_memory.get_batch(Ai_model, batchSize)
        current_state = state
        loss = Ai_model.train_on_batch(inputs, targets)
        i += 1
    # Tracking the score of the finished game
    print(variables.SCORE)
    if variables.SCORE > maxNCollected:
        maxNCollected = variables.SCORE
        if variables.SCORE > 2:
            Ai_model.save(variables.filepathToSave)
    totNCollected += variables.SCORE
    totSteps += i
    # Showing the results every 10 games
    if epoch % 10 == 0 and epoch != 0:
        scores.append(totNCollected / 10)
        steps_.append(totSteps / 10)
        totNCollected = 0
        totSteps = 0
        plt.plot(scores)
        plt.plot(steps_)
        plt.xlabel('Epoch / 10')
        plt.ylabel('Average')
        plt.savefig('stats.png')
        plt.close()
    # Lowering the epsilon
    if variables.epsilon > variables.minEpsilon:
        variables.epsilon -= variables.epsilonDecayRate
    # Showing the results each game
    print('Epoch: ' + str(epoch) + ' Current Best: ' + str(maxNCollected) + ' Epsilon: {:.5f}'.format(variables.epsilon))
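
train.py imports its settings from a separate variables module that is not shown on this page. The sketch below only lists the names the script actually references; every value is an illustrative placeholder, not the project's real configuration.

# variables.py -- hypothetical sketch: names taken from train.py, values are placeholders
SCREEN_SIZE = (400, 400)        # window size in pixels (assumed)
SIZE_RECT = 20                  # side length of one grid cell in pixels (assumed)
SCORE = 0                       # updated by the environment while playing

learningRate = 0.001            # optimizer learning rate (assumed)
memSize = 60000                 # replay-memory capacity (assumed)
gamma = 0.9                     # discount factor (assumed)
batchSize = 32                  # replay batch size (assumed)

epsilon = 1.0                   # initial exploration rate (assumed)
minEpsilon = 0.05               # lower bound for epsilon (assumed)
epsilonDecayRate = 0.0002       # subtracted from epsilon after every game (assumed)

filepathToSave = 'model_snack.h5'   # where the best model is saved (assumed)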
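
The loop also depends on Dqn.remember and Dqn.get_batch from dqn.py, which are likewise not shown. Below is a minimal sketch of how such a replay memory is commonly implemented, assuming standard Q-learning targets (reward plus the discounted maximum Q-value of the next state, with no bootstrap on game over). The class name ReplayMemory and all implementation details here are assumptions; the project's actual Dqn class may differ.

import numpy as np

class ReplayMemory:
    # Hypothetical stand-in for the Dqn class used above (standard experience replay).
    def __init__(self, max_memory, gamma):
        self.max_memory = max_memory
        self.gamma = gamma
        self.memory = []  # each entry: [[state, action, reward, next_state], game_over]

    def remember(self, transition, game_over):
        self.memory.append([transition, game_over])
        if len(self.memory) > self.max_memory:
            del self.memory[0]

    def get_batch(self, model, batch_size):
        num_samples = min(len(self.memory), batch_size)
        num_actions = model.output_shape[-1]
        state_shape = self.memory[0][0][0].shape[1:]
        inputs = np.zeros((num_samples,) + state_shape)
        targets = np.zeros((num_samples, num_actions))
        for i, idx in enumerate(np.random.randint(0, len(self.memory), size=num_samples)):
            (state, action, reward, next_state), game_over = self.memory[idx]
            inputs[i] = state[0]
            # Start from the model's current predictions so only the taken action is corrected
            targets[i] = model.predict(state)[0]
            if game_over:
                targets[i, action] = reward
            else:
                targets[i, action] = reward + self.gamma * np.max(model.predict(next_state)[0])
        return inputs, targets

Used this way, the train_on_batch(inputs, targets) call in the loop above only pushes the Q-value of the action that was actually taken toward its Bellman target, leaving the other action values unchanged.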