-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataProcessor.py
140 lines (104 loc) · 5.06 KB
/
dataProcessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
import chess
from matplotlib import pyplot as plt
import numpy as np
import torch
from random import shuffle
from tqdm import tqdm
import os
# Directory that contains this module; the dataset path is resolved relative
# to it so that it does not depend on the current working directory.
current_dir = os.path.dirname(__file__)
# Path of the games file located next to this module. Despite the ".txt"
# extension, DataPovider.loadData parses it with json.load.
DATASET_LOCATION = os.path.join(current_dir, "games.txt")
class DataPovider:
    """Load chess games from disk and serve them as fixed-length token batches.

    Every game is a sequence of move strings. Games are left-padded with
    ``PAD_TOKEN`` and a single ``BOS_TOKEN`` so that each stored sequence has
    the same length as ``getStartInput()``; games longer than that keep only
    their most recent tokens. Stored games are integer-encoded with a
    vocabulary built from the moves seen in the dataset.

    NOTE: the class name spelling ("DataPovider") is kept as-is because
    callers import it under this name.
    """

    # Token used to left-pad sequences up to the fixed length.
    PAD_TOKEN = "."
    # Token placed immediately before the first move of every game.
    BOS_TOKEN = "<BOS>"
    # Fraction of the (shuffled) games that forms the training split.
    TRAIN_FRACTION = 0.8
    # Start position used when replaying games for the statistics report.
    START_FEN = "8/8/8/4k3/8/8/8/5BKN w - - 0 1"

    def __init__(self, seqLength) -> None:
        """Create an empty provider.

        Args:
            seqLength: Number of tokens in every padded game sequence.
        """
        self.games = []
        self.vocab = None
        self.stoi = None
        self.itos = None
        self.seqLength = seqLength

    def loadData(self, generateStats=False):
        """Read the dataset, build the vocabulary and store the encoded games.

        The file at ``DATASET_LOCATION`` is parsed as JSON and is expected to
        be a list of batches, each batch being a list of games, each game a
        list of move strings. Any previously loaded games are replaced, so
        calling this method twice does not duplicate the dataset.

        Args:
            generateStats: When true, print dataset statistics and show a
                bar chart of move frequencies before encoding.
        """
        print("Reading file...")
        with open(DATASET_LOCATION, "r", encoding="utf-8") as f:
            data = json.load(f)

        print("Parsing data...")
        allGamesList = [game for batch in data for game in batch]
        moveFrequency = {}
        paddedGames = []
        for game in tqdm(allGamesList):
            for move in game:
                moveFrequency[move] = moveFrequency.get(move, 0) + 1
            paddedGames.append(self._padGame(game))

        if generateStats:
            self._printStats(allGamesList, moveFrequency)

        # Special tokens come first; moves follow in first-seen order.
        self.loadVocab([self.PAD_TOKEN, self.BOS_TOKEN] + list(moveFrequency))

        encodedGames = [
            [self.encode(move) for move in game] for game in tqdm(paddedGames)
        ]
        # Shuffle once so the index-based train/validation split in getBatch
        # is a random partition of the games.
        shuffle(encodedGames)
        self.games = encodedGames

    def _padGame(self, game):
        """Return ``game`` left-padded (or left-truncated) to the fixed length."""
        start = self.getStartInput()
        # Keeping the tail of start + moves is equivalent to sliding a
        # fixed-size window one token at a time over the moves.
        return (start + list(game))[-len(start):]

    def _printStats(self, allGamesList, moveFrequency):
        """Print summary statistics for the raw games and plot move counts."""
        print("Total number of games: " + str(len(allGamesList)))
        uniqueGames = {''.join(game) for game in allGamesList}
        print("Number of unique games: " + str(len(uniqueGames)))

        uniqueBoards = set()
        for game in tqdm(allGamesList):
            board = chess.Board(self.START_FEN)
            for move in game:
                # Some moves are stored with a lowercase piece letter
                # (presumably marking the other side's moves - TODO confirm);
                # capitalize() restores the upper-case piece letter SAN needs.
                board.push_san(move.capitalize())
                uniqueBoards.add(board.fen())
        print("Number of unique positions: " + str(len(uniqueBoards)))

        gameLengths = [len(game) for game in allGamesList]
        print("Total number of moves: " + str(sum(gameLengths)))
        # max/min/average are undefined for an empty dataset; skip them
        # instead of crashing.
        if gameLengths:
            print("Longest game: " + str(max(gameLengths)))
            print("Shortest game: " + str(min(gameLengths)))
            print("Average game length: " + str(sum(gameLengths) / len(gameLengths)))

        print("Number of unique moves: " + str(len(moveFrequency)))
        sortedMoves = sorted(moveFrequency.items(), key=lambda x: x[1])
        print("Most common moves: " + str(sortedMoves[-10:]))
        print("Least common moves: " + str(sortedMoves[:10]))
        self.plotMoveFrequencyBarChart(moveFrequency)

    def encode(self, move):
        """Return the integer id of ``move``.

        Raises:
            KeyError: If ``move`` is not in the vocabulary.
        """
        return self.stoi[move]

    def decode(self, move):
        """Return the token string for the integer id ``move``.

        Raises:
            KeyError: If the id is not in the vocabulary.
        """
        return self.itos[move]

    def getBatch(self, train, batchSize):
        """Sample ``batchSize`` games (with replacement) from one split.

        Args:
            train: True to sample from the first ``TRAIN_FRACTION`` of the
                stored games, False to sample from the remainder.
            batchSize: Number of games to draw.

        Returns:
            A ``torch.Tensor`` built from the selected encoded games, one
            game per row.

        Raises:
            ValueError: If the requested split contains no games.
        """
        splitIdx = int(len(self.games) * self.TRAIN_FRACTION)
        if train:
            minIdx, maxIdx = 0, splitIdx
        else:
            minIdx, maxIdx = splitIdx, len(self.games)
        if minIdx >= maxIdx:
            # numpy would raise an opaque "low >= high" ValueError here.
            raise ValueError(
                "No games available in the "
                + ("training" if train else "validation")
                + " split; load more data before requesting a batch."
            )
        idxs = np.random.randint(minIdx, maxIdx, size=batchSize)
        return torch.tensor([self.games[i] for i in idxs])

    def getStartInput(self):
        """Return the token sequence that represents a game with no moves."""
        return [self.PAD_TOKEN] * (self.seqLength - 1) + [self.BOS_TOKEN]

    def getVocab(self):
        """Return the vocabulary list (index == token id), or None if unset."""
        return self.vocab

    def loadVocab(self, vocab):
        """Install ``vocab`` and rebuild the token<->id lookup tables.

        Args:
            vocab: Sequence of token strings; a token's position is its id.
        """
        self.vocab = vocab
        self.stoi = {token: idx for idx, token in enumerate(vocab)}
        self.itos = {idx: token for idx, token in enumerate(vocab)}

    def isGameInSet(self, gameToCheck):
        """Return True if ``gameToCheck`` equals one of the stored games.

        Stored games are integer-encoded, so string tokens in
        ``gameToCheck`` are encoded with the current vocabulary before
        comparing. Already-encoded games are compared unchanged.

        Args:
            gameToCheck: A padded game given as token ids and/or token
                strings.

        Returns:
            True if an identical encoded game is stored, otherwise False.
        """
        candidate = []
        for token in gameToCheck:
            if isinstance(token, str):
                # A token missing from the vocabulary cannot be part of any
                # stored game.
                if self.stoi is None or token not in self.stoi:
                    return False
                token = self.stoi[token]
            candidate.append(token)
        return any(game == candidate for game in self.games)

    def plotMoveFrequencyBarChart(self, moveFrequency):
        """Show a bar chart of move counts, most frequent move first.

        Args:
            moveFrequency: Mapping of move string to occurrence count.
        """
        sortedMoves = sorted(moveFrequency.items(), key=lambda x: x[1], reverse=True)
        sortedKeys = [move for move, _ in sortedMoves]
        sortedValues = [count for _, count in sortedMoves]
        plt.bar(sortedKeys, sortedValues)
        plt.show()
if __name__ == "__main__":
    # Manual smoke test: load the dataset with statistics, look up one known
    # game, then draw a training batch and print its first game decoded.
    seqLength = 100
    dp = DataPovider(seqLength)
    dp.loadData(generateStats=True)

    testMoves = [
        'Kg2', 'kd4', 'Kh3', 'ke4', 'Bg2', 'kf4', 'Bb7', 'ke5', 'Kg3', 'kf5',
        'Kf3', 'ke5', 'Ng3', 'kd6', 'Be4', 'kc5', 'Nf5', 'kc4', 'Ke3', 'kc5',
        'Kd2', 'kc4', 'Bb7', 'kc5', 'Kc3', 'kb6', 'Bd5', 'ka5', 'Ng7', 'kb6',
        'Ne6', 'ka5', 'Bc6', 'kb6', 'Be8', 'ka6', 'Kb4', 'kb7',
    ]
    # Derive the padding from the provider instead of hard-coding the '.'
    # tokens, so the test game always matches the configured sequence length.
    testGame = (dp.getStartInput() + testMoves)[-seqLength:]

    # Stored games are integer-encoded, so encode the test game before the
    # membership check; comparing raw move strings could never match.
    try:
        encodedTestGame = [dp.encode(move) for move in testGame]
    except KeyError:
        # A move missing from the vocabulary cannot be part of a stored game.
        print(False)
    else:
        print(dp.isGameInSet(encodedTestGame))

    games = dp.getBatch(True, 10)
    decodedGame = [dp.decode(x.item()) for x in games[0]]
    print(decodedGame)