Skip to content

Commit

Permalink
add wilson with solution
Browse files Browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Apr 8, 2023
1 parent 16dbb1f commit 0d252fe
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 11 deletions.
11 changes: 9 additions & 2 deletions maze_transformer/generation/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def gen_dfs(
)

@classmethod
def gen_dfs_with_solution(cls, grid_shape: Coord):
def gen_dfs_with_solution(cls, grid_shape: Coord) -> SolvedMaze:
maze = cls.gen_dfs(grid_shape)
solution = np.array(maze.generate_random_path())
solution = maze.get_shortest_path_between_random_points()

return SolvedMaze(maze, solution)

Expand Down Expand Up @@ -199,6 +199,13 @@ def neighbor(current: Coord, direction: int) -> Coord:
),
)

@classmethod
def gen_wilson_with_solution(cls, grid_shape: Coord) -> SolvedMaze:
maze = cls.gen_wilson(grid_shape)
solution = maze.get_shortest_path_between_random_points()

return SolvedMaze(maze, solution)


GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = {
"gen_dfs": LatticeMazeGenerators.gen_dfs,
Expand Down
4 changes: 2 additions & 2 deletions maze_transformer/generation/lattice_maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def find_shortest_path(
self,
c_start: CoordTup,
c_end: CoordTup,
) -> list[Coord]:
) -> list[CoordTup]:
"""find the shortest path between two coordinates, using A*"""

g_score: dict[
Expand Down Expand Up @@ -215,7 +215,7 @@ def get_nodes(self) -> list[Coord]:
for col in range(self.grid_shape[1])
]

def generate_random_path(self) -> list[Coord]:
def get_shortest_path_between_random_points(self) -> list[CoordTup]:
""" "return a path between randomly chosen start and end nodes"""

# we can't create a "path" in a single-node maze
Expand Down
8 changes: 4 additions & 4 deletions maze_transformer/training/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def maze_to_tokens(
maze: LatticeMaze,
solution: list[Coord],
solution: list[CoordTup],
node_token_map: dict[CoordTup, str],
) -> list[str]:
"""serialize maze and solution to tokens"""
Expand All @@ -29,9 +29,9 @@ def maze_to_tokens(
*chain.from_iterable(
[
[
node_token_map[tuple(c_s.tolist())],
node_token_map[tuple(list(c_s))],
SPECIAL_TOKENS["connector"],
node_token_map[tuple(c_e.tolist())],
node_token_map[tuple(list(c_e))],
SPECIAL_TOKENS["adjacency_endline"],
]
for c_s, c_e in maze.as_adj_list()
Expand All @@ -47,7 +47,7 @@ def maze_to_tokens(
node_token_map[tuple(solution[-1])],
SPECIAL_TOKENS["target_end"],
SPECIAL_TOKENS["path_start"],
*[node_token_map[tuple(c.tolist())] for c in solution],
*[node_token_map[c] for c in solution],
SPECIAL_TOKENS["path_end"],
]

Expand Down
2 changes: 1 addition & 1 deletion scripts/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def generate_solve_plot(
if start and end:
path = np.array(maze.find_shortest_path(start, end))
else:
path = np.array(maze.generate_random_path())
path = np.array(maze.get_shortest_path_between_random_points())

print(f"solving time: {time.time() - solution_start}")

Expand Down
6 changes: 6 additions & 0 deletions tests/unit/maze_transformer/generation/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ def test_gen_dfs_with_solution():
def test_wilson_generation():
maze = LatticeMazeGenerators.gen_wilson(np.array([2, 2]))
assert maze.connection_list.shape == (2, 2, 2)


def test_wilson_generation_with_solution():
maze, solution = LatticeMazeGenerators.gen_wilson_with_solution(np.array([2, 2]))
assert maze.connection_list.shape == (2, 2, 2)
assert len(solution[0]) == 2
4 changes: 2 additions & 2 deletions tests/unit/maze_transformer/generation/test_latticemaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_get_nodes():

def test_generate_random_path():
maze = LatticeMazeGenerators.gen_dfs((2, 2))
path = maze.generate_random_path()
path = maze.get_shortest_path_between_random_points()

# len > 1 ensures that we have unique start and end nodes
assert len(path) > 1
Expand All @@ -45,4 +45,4 @@ def test_generate_random_path():
def test_generate_random_path_size_1():
maze = LatticeMazeGenerators.gen_dfs((1, 1))
with pytest.raises(AssertionError):
maze.generate_random_path()
maze.get_shortest_path_between_random_points()

0 comments on commit 0d252fe

Please sign in to comment.