This repository has been archived by the owner on Mar 21, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.py
66 lines (47 loc) · 1.67 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import chess
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from networkx.drawing.nx_pydot import graphviz_layout
def square_color(square: chess.Square):
    """Return the checkerboard display colour for *square*.

    Squares whose file index plus rank index is even (a1, c1, b2, ...) are
    drawn ``"gray"``; all other squares use the lighter ``"#C5C5C5"``.
    """
    # Parity of the raw square index (square % 2) only alternates along one
    # rank: a1 (0) and a2 (8) would both be "gray". The checkerboard pattern
    # depends on file + rank.
    parity = (chess.square_file(square) + chess.square_rank(square)) % 2
    return "gray" if parity == 0 else "#C5C5C5"
def to_node(board: chess.Board):
    """Describe *board* as a networkx node.

    Returns a ``(node_id, attributes)`` pair where the node id is the board's
    FEN string and the attributes dict sets ``color`` to ``"blue"``.
    """
    node_id = board.fen()
    attributes = dict(color="blue")
    return node_id, attributes
def to_edge(board: chess.Board, move: chess.Move):
    """Describe playing *move* on *board* as a networkx edge.

    Returns a ``(source_fen, target_fen, attributes)`` triple. The source is
    the FEN of *board*. The target is the FEN after the move, computed on a
    copy via ``play``, so *board* itself is not modified. The ``label``
    attribute combines the moving piece and the move in UCI notation.
    """
    moving_piece = board.piece_at(move.from_square)
    source_fen = board.fen()
    target_fen = play(board, move).fen()
    edge_label = f"{moving_piece}: {move.uci()}"
    return source_fen, target_fen, {"label": edge_label}
def turn_color(board: chess.Board):
    """Return a display colour for the side to move on *board*.

    Positions with Black to move get ``"gray"``; every other position gets
    ``"#C5C5C5"``.
    """
    if board.turn != chess.BLACK:
        return "#C5C5C5"
    return "gray"
def play(board: chess.Board, move: chess.Move):
    """Return a new board with *move* applied, leaving *board* unchanged.

    The copy is taken with ``stack=False``, so it does not carry over the
    move stack of *board*.
    """
    successor = board.copy(stack=False)
    successor.push(move)
    return successor
def add_layer(
    graph,
    board: chess.Board,
    depth: int,
    *,
    branching: int = 5,
    full_depth: int = 3,
):
    """Recursively add a randomly pruned game tree rooted at *board* to *graph*.

    Every legal move from *board* is added as an edge (via ``to_edge``). Only
    some of those moves are expanded further. When ``depth == full_depth``,
    all of them are expanded. Otherwise a random sample of at most
    *branching* distinct moves is expanded.

    Args:
        graph: networkx graph that receives the edges (mutated in place).
        board: position to expand. Moves are pushed and popped on it during
            recursion, and it is restored before this function returns, even
            if an exception propagates.
        depth: number of plies still to expand; values <= 0 add nothing.
        branching: maximum number of moves expanded at a pruned node.
        full_depth: depth value at which every move is expanded instead of
            a random sample. The default 3 matches the original hard-coded
            behaviour, i.e. the root of a 3-ply call.
    """
    if depth <= 0:
        return
    # Materialize once: the list is reused for edges and expansion, and we
    # must not iterate the live legal-move generator while pushing/popping.
    moves = list(board.legal_moves)
    if not moves:
        # Terminal position (checkmate/stalemate): nothing to add or expand.
        return
    graph.add_edges_from(to_edge(board, move) for move in moves)
    if depth == full_depth:
        expanded = moves
    else:
        # Never request more samples than there are moves: replace=False
        # raises ValueError otherwise. Sampling indices also keeps numpy from
        # building an object array out of chess.Move instances.
        sample_size = min(branching, len(moves))
        picks = np.random.choice(len(moves), sample_size, replace=False)
        expanded = [moves[i] for i in picks]
    for move in expanded:
        board.push(move)
        try:
            add_layer(
                graph,
                board,
                depth - 1,
                branching=branching,
                full_depth=full_depth,
            )
        finally:
            board.pop()
def plot(graph, save_as=None, prog="twopi"):
    """Draw *graph* using a Graphviz layout, then save or show it.

    Args:
        graph: networkx graph to draw.
        save_as: if given, the figure is written to this path and closed.
            Otherwise it is shown with ``plt.show()``.
        prog: Graphviz layout program. The default "twopi" gives a radial
            layout; use "dot" for a top-down tree.
    """
    pos = graphviz_layout(graph, prog=prog)
    # Use a fresh figure per call; reusing figure 1 made repeated calls draw
    # on top of each other.
    fig = plt.figure(figsize=(20, 20))
    nx.draw(graph, pos, node_size=100, width=1)
    plt.axis("off")
    if save_as is None:
        plt.show()
    else:
        fig.savefig(save_as)
        # Free the figure once it is on disk so batch runs do not leak memory.
        plt.close(fig)