-
Notifications
You must be signed in to change notification settings - Fork 1
/
graph.py
78 lines (65 loc) · 2.5 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from collections import Counter
import random
"""
A graph module for Markov chains.
"""
class Graph:
    """
    A weighted, directed graph of states for building Markov chains.

    Each vertex is a state (any hashable, iterable value other than None);
    its neighbors are kept in a ``Counter`` mapping each successor state to
    the number of times that transition was added. The graph also tracks a
    "current" state, which ``add_edge`` advances while building the chain
    and ``get_random_token`` advances while walking it.
    """

    def __init__(self, state=None):
        """
        Initialize the vertex dict (key=state, val=neighbor Counter) and the
        current state.

        Any given state except None will be considered valid; passing None
        (the default) creates an empty graph with no current state.
        """
        # Always create the vertex dict so an empty graph is still a fully
        # formed object (e.g. __str__ works) rather than missing the attribute.
        if state is None:
            self._vertices = {}
        else:
            self._vertices = {state: Counter()}
        self._state_key = state

    def add_edge(self, next_state):
        """
        Construct an edge from the current state to ``next_state`` and make
        ``next_state`` the new current state.

        Adding the same transition again increases that edge's weight. If the
        graph has no current state yet, ``next_state`` just becomes the first
        vertex and no edge is created.

        Raises:
            TypeError: if ``next_state`` is not iterable.
        """
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``.
        if not hasattr(next_state, '__iter__'):
            raise TypeError(
                "next_state must be iterable, got {!r}".format(next_state))
        # Register the new state as a vertex if it is not already one.
        if next_state not in self._vertices:
            self._vertices[next_state] = Counter()
        if self._state_key is not None:
            # Weight the edge previous state -> new state.
            self._vertices[self._state_key][next_state] += 1
        self._state_key = next_state

    def get_random_token(self):
        """
        Move to a random neighbor of the current state and return the last
        element of that new state.

        If the current state has no outgoing edges, a new current state is
        picked uniformly at random from all vertices instead.

        Raises:
            ValueError: if no data has been added to the graph yet.
        """
        if self._state_key is None:
            raise ValueError("You must add some data to the graph!")
        try:
            self._state_key = self.yield_neighbor()
        except ValueError:
            # Dead end: restart the walk from a uniformly random vertex.
            self._state_key = random.choice(list(self._vertices))
        return self._state_key[-1]

    def yield_neighbor(self):
        """
        Return a random neighbor of the current state, weighted by edge count.

        Despite its name this returns a single value (it is not a generator)
        and does not change the current state.

        Raises:
            ValueError: if the current state has no neighbors.
        """
        neighbors = self._vertices[self._state_key]
        neighbor_count = sum(neighbors.values())
        # ``==`` rather than ``is``: int identity is an implementation detail
        # (and ``is`` with a literal is a SyntaxWarning on Python 3.8+).
        if neighbor_count == 0:
            raise ValueError("The current state has no neighbors.")
        # Roulette-wheel selection: pick a point in [1, total] and walk the
        # cumulative weights until that point is covered.
        selection = random.randint(1, neighbor_count)
        for neighbor, weight in neighbors.items():
            selection -= weight
            if selection <= 0:
                return neighbor
        # Weights are positive counts summing to neighbor_count, so the loop
        # always returns; raise explicitly so this is not stripped by -O.
        raise AssertionError("unreachable: weighted selection overran")

    def __str__(self):
        """Return a list-style summary of every state and its neighbor counts."""
        return str(["State: '{}', Neighbors: '{}'".format(k, v) for k, v
                    in self._vertices.items()])