Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zanj integration: datasets & training #177

Merged
merged 89 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
89 commits
Select commit Hold shift + click to select a range
25d745d
wip
mivanit Mar 28, 2023
da2e05f
Merge branch 'zanj-integration' into zanj-integration-2
mivanit Mar 28, 2023
7fdbdb0
wip
mivanit Mar 28, 2023
f7abcb0
bump muutils to 0.3.3, some zanj tests working with that
mivanit Mar 28, 2023
a31d4ba
misc
mivanit Mar 29, 2023
705e1f6
something with layernorm is causing the tensor elements not to match up
mivanit Mar 30, 2023
34a62fc
???
mivanit Mar 30, 2023
a6a5b32
exact loading of model works!
mivanit Apr 1, 2023
0181b02
ugh not quite, only working if layernorm folding disabled
mivanit Apr 1, 2023
9e2fe97
wip
mivanit Apr 1, 2023
07aa160
zanj save/load tests passing?
mivanit Apr 2, 2023
e1b28b4
fixed some unit tests, test_eval_model still fails >:(
mivanit Apr 2, 2023
84d3ae8
so confused, test only fails when model generated via training?
mivanit Apr 3, 2023
2019ed4
merge with main (and bump muutils to 0.3.6)
mivanit Apr 6, 2023
570c2b1
fixed folding issue
mivanit Apr 6, 2023
1db5c61
Merge branch 'add-notebook-testing' into zanj-integration-2
mivanit Apr 6, 2023
075ff2b
bump muutils to 0.3.7
mivanit Apr 6, 2023
808e333
updated poetry.lock
mivanit Apr 6, 2023
04b9d09
prelim to/from ascii and pixels methods, might need to be moved
mivanit Apr 6, 2023
9ab36f7
run notebook
mivanit Apr 6, 2023
4548296
merge with add-notebook-testing
mivanit Apr 9, 2023
377724a
wip
mivanit Apr 9, 2023
2406dea
wip
mivanit Apr 9, 2023
70e99f5
this was some of the most paintful debugging ive ever done
mivanit Apr 10, 2023
a8a52af
format
mivanit Apr 10, 2023
8ab6e79
bump muutils
mivanit Apr 10, 2023
6bf592b
merge with main
mivanit Apr 10, 2023
820f0b3
fixes?
mivanit Apr 10, 2023
ecb1872
format
mivanit Apr 10, 2023
b650af9
update poetry lock
mivanit Apr 10, 2023
525c719
fixes
mivanit Apr 10, 2023
93a31aa
format
mivanit Apr 10, 2023
94c675d
reworked mazeplot init
mivanit Apr 10, 2023
e612f09
wip
mivanit Apr 11, 2023
3cf9041
add unit length parameter to MazePlot
canrager Apr 11, 2023
40f4efd
misspelled folder??
mivanit Apr 11, 2023
ea7a66a
wip, but unit tests passing!
mivanit Apr 11, 2023
b09e707
wip
mivanit Apr 12, 2023
e1b774f
incomprehensible upstream issue in muutils
mivanit Apr 12, 2023
e2d3799
reworking training script
mivanit Apr 12, 2023
16b5665
wip
mivanit Apr 12, 2023
c3a9d69
test_train_model working!
mivanit Apr 12, 2023
a8f8934
wip
mivanit Apr 13, 2023
5d8bd00
test_eval_model passing
mivanit Apr 13, 2023
5238158
format
mivanit Apr 13, 2023
56ce56d
wip refactor
mivanit Apr 14, 2023
bb04c45
SolvedMaze now inherits from TargetedLatticeMaze
mivanit Apr 14, 2023
09876b1
Really dumb bug tracked down, path would overwrite endpoints in as_pi…
mivanit Apr 14, 2023
cdb9ea7
format
mivanit Apr 14, 2023
ea20e9a
Merge branch 'add-maze-from-ascii' of https://github.com/AISC-underst…
mivanit Apr 14, 2023
20436ab
remove MazePlot.show()
mivanit Apr 14, 2023
f65abbe
aaaaA
mivanit Apr 15, 2023
134e0ea
wip
mivanit Apr 15, 2023
fe4eae6
merge
mivanit Apr 15, 2023
f248e5a
wip
mivanit Apr 15, 2023
22518df
wip filtering
mivanit Apr 15, 2023
2c0728e
more filtering wip
mivanit Apr 15, 2023
360c940
wip filters
mivanit Apr 15, 2023
1ae7d6e
filters working!
mivanit Apr 15, 2023
1742ee4
filteringgit add maze_transformer/ notebooks/!
mivanit Apr 15, 2023
41223af
removed debug printing
mivanit Apr 15, 2023
52c2042
format
mivanit Apr 15, 2023
92eae14
simplified decorator, minor change to notebook
mivanit Apr 15, 2023
2180d19
filtering improvements
mivanit Apr 16, 2023
f1e304c
format
mivanit Apr 16, 2023
cee6204
bump muutils to v0.3.9
mivanit Apr 18, 2023
f491f32
Add tests for MazeDataset
valedan Apr 19, 2023
e64119e
Test custom filters
valedan Apr 19, 2023
a8fd1e5
test dataset filters
valedan Apr 19, 2023
a7148e9
fixed minor bugs in tests from zanj-integration-datasets, needs to be…
mivanit Apr 20, 2023
da56b52
initial version of maze complexity evals
mivanit Apr 20, 2023
2c13e51
fixed bug in cut_percentile_shortest and ran formatting
mivanit Apr 20, 2023
a510d41
merging in from main
mivanit Apr 20, 2023
990dbb0
format, resolved a forgotten merge conflict
mivanit Apr 20, 2023
2d91858
MazePath dissapeared again???
mivanit Apr 20, 2023
45e75dd
format (removed jaxtyping import)
mivanit Apr 20, 2023
135435a
added a TODO of something to implement for constrained dfs kwargs
mivanit Apr 25, 2023
88002f6
dumb bug that probably doesnt matter since we will remove TargetedLat…
mivanit Apr 26, 2023
e0cd326
Revert "dumb bug that probably doesnt matter since we will remove Tar…
mivanit Apr 26, 2023
15070b6
Zanj datasets getitem (#182)
valedan Apr 26, 2023
88402cd
format
mivanit Apr 28, 2023
e8b7196
format
mivanit Apr 28, 2023
04486da
Constrained dfs, dataset modifications (#184)
canrager Apr 28, 2023
6d942ef
Merge branch 'zanj-integration-datasets' of https://github.com/AISC-u…
mivanit Apr 28, 2023
54ff5a0
fixed maze dataset config hash usage, removed print from parallel wor…
mivanit Apr 28, 2023
c99f652
format
mivanit Apr 28, 2023
e30f3f0
fixed notebook test
mivanit Apr 28, 2023
e58c348
bumpy pytest to 7.3.1 to resolve missing 'mocker' fixture
mivanit Apr 28, 2023
e2f9039
fix biased baseline
valedan Apr 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ convert_notebooks:

.PHONY: test_notebooks
test_notebooks: convert_notebooks
@echo "run tests on converted notebooks in $(CONVERTED_NOTEBOOKS_TEMP_DIR) using $(HELPERS_DIR)/run_notebook_tests.py"
python $(HELPERS_DIR)/run_notebook_tests.py --notebooks-dir=$(NOTEBOOKS_DIR) --converted-notebooks-temp-dir=$(CONVERTED_NOTEBOOKS_TEMP_DIR)


Expand Down
4 changes: 2 additions & 2 deletions maze_transformer/evaluation/baseline_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _predict_next_step(
unvisited_neighbors = [coord for coord in neighbors if coord not in path]

# if the current path is already as long as the solution, there can be no correct next step
correct_step = solution[len(path)] if len(solution) > len(path) else None
correct_step = tuple(solution[len(path)]) if len(solution) > len(path) else None

if len(unvisited_neighbors) == 0:
return SPECIAL_TOKENS["path_end"]
Expand Down Expand Up @@ -89,7 +89,7 @@ def _generate_path(
maze = LatticeMaze.from_tokens(tokens)
origin_coord = self.config.dataset_cfg.token_node_map[get_origin_token(tokens)]
target_coord = self.config.dataset_cfg.token_node_map[get_target_token(tokens)]
solution = maze.find_shortest_path(origin_coord, target_coord)
solution = maze.find_shortest_path(origin_coord, target_coord).tolist()

existing_path = tokens_to_coords(
get_path_tokens(tokens), self.config.dataset_cfg
Expand Down
14 changes: 5 additions & 9 deletions maze_transformer/evaluation/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from maze_transformer.evaluation.path_evals import PathEvalFunction, PathEvals
from maze_transformer.generation.constants import SPECIAL_TOKENS
from maze_transformer.generation.lattice_maze import SolvedMaze
from maze_transformer.training.config import ConfigHolder
from maze_transformer.training.maze_dataset import MazeDataset, MazeDatasetConfig
from maze_transformer.training.training import TRAIN_SAVE_FILES
Expand Down Expand Up @@ -150,15 +149,12 @@ def evaluate_model(
name: StatCounter() for name in eval_functions.keys()
}

for batch in chunks(dataset.mazes_tokens, batch_size):
# TODO: This won't be needed after #124, then we can call mazes_objs instead
# https://github.com/orgs/AISC-understanding-search/projects/1/views/1?pane=issue&itemId=23879308
solved_mazes: SolvedMaze = [
SolvedMaze.from_tokens(tokens, dataset.cfg) for tokens in batch
for maze_batch in chunks(dataset, batch_size):
tokens_batch = [
maze.as_tokens(dataset.cfg.node_token_map) for maze in maze_batch
]

predictions = predict_maze_paths(
tokens_batch=batch,
tokens_batch=tokens_batch,
data_cfg=dataset.cfg,
model=model,
max_new_tokens=max_new_tokens,
Expand All @@ -173,7 +169,7 @@ def evaluate_model(
prediction=np.array(prediction),
model=model,
)
for sm, prediction in zip(solved_mazes, predictions)
for sm, prediction in zip(maze_batch, predictions)
)

return score_counters
12 changes: 12 additions & 0 deletions maze_transformer/evaluation/maze_complexity_evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import typing

from maze_transformer.generation.lattice_maze import SolvedMaze
from maze_transformer.utils.utils import register_method

MAZE_COMPLEXITY_EVALS: dict[str, typing.Callable[[SolvedMaze], float]] = dict()


class MazeComplexityEvals:
@register_method(MAZE_COMPLEXITY_EVALS)
def solution_length(maze: SolvedMaze) -> float:
return len(maze.solution)
15 changes: 7 additions & 8 deletions maze_transformer/evaluation/path_evals.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from typing import Iterable, Optional, Protocol, TypeAlias
import typing

import numpy as np
from jaxtyping import Int

from maze_transformer.generation.constants import Coord, CoordArray, CoordTup
from maze_transformer.generation.lattice_maze import LatticeMaze
from maze_transformer.utils.utils import register_method

# pylint: disable=unused-argument
MazePath: TypeAlias = Int[np.ndarray, "node x_y_pos"]
MazePath = CoordArray


class PathEvalFunction(Protocol):
class PathEvalFunction(typing.Protocol):
def __call__(
self,
maze: Optional[LatticeMaze] = None,
solution: Optional[CoordArray] = None,
prediction: Optional[CoordArray] = None,
maze: LatticeMaze | None = None,
solution: CoordArray | None = None,
prediction: CoordArray | None = None,
) -> float:
...


def path_as_segments_iter(path: CoordArray) -> Iterable[tuple]:
def path_as_segments_iter(path: CoordArray) -> typing.Iterable[tuple]:
"""
Iterate over the segments of a path (ie each consecutive pair).
"""
Expand Down
92 changes: 66 additions & 26 deletions maze_transformer/generation/generators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import random
import warnings
from typing import Any, Callable

import numpy as np

from maze_transformer.generation.constants import CoordArray
from maze_transformer.generation.lattice_maze import (
NEIGHBORS_MASK,
ConnectionList,
Coord,
CoordTup,
LatticeMaze,
SolvedMaze,
)
Expand All @@ -18,9 +19,11 @@ class LatticeMazeGenerators:

@staticmethod
def gen_dfs(
grid_shape: Coord | CoordTup,
start_coord: Coord | None = None,
grid_shape: Coord,
lattice_dim: int = 2,
n_accessible_cells: int | None = None,
max_tree_depth: int | None = None,
start_coord: Coord | None = None,
) -> LatticeMaze:
"""generate a lattice maze using depth first search, iterative

Expand All @@ -35,28 +38,39 @@ def gen_dfs(
4. Mark the chosen cell as visited and push it to the stack
"""

grid_shape = np.array(grid_shape)

# initialize the maze with no connections
connection_list: np.ndarray = np.zeros(
(lattice_dim, grid_shape[0], grid_shape[1]), dtype=bool
)

# Default values if no constraints have been passed
grid_shape: Coord = np.array(grid_shape)
n_total_cells: int = np.prod(grid_shape)
if n_accessible_cells is None:
n_accessible_cells = n_total_cells
if max_tree_depth is None:
max_tree_depth = (
2 * n_total_cells
) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
if start_coord is None:
start_coord: Coord = (
random.randint(0, grid_shape[0] - 1),
random.randint(0, grid_shape[1] - 1),
start_coord: Coord = np.random.randint(
0,
np.maximum(grid_shape - 1, 1),
size=2,
)
else:
start_coord = np.array(start_coord)

# print(f"{grid_shape = } {start_coord = }")
# initialize the maze with no connections
connection_list: ConnectionList = np.zeros(
(lattice_dim, grid_shape[0], grid_shape[1]), dtype=np.bool_
)

# initialize the stack with the target coord
visited_cells: set[tuple[int, int]] = set()
visited_cells.add(tuple(start_coord))
stack: list[Coord] = [start_coord]

# loop until the stack is empty
while stack:
# initialize tree_depth_counter
current_tree_depth: int = 1

# loop until the stack is empty or n_connected_cells is reached
while stack and (len(visited_cells) < n_accessible_cells):
# get the current coord from the stack
current_coord: Coord = stack.pop()

Expand All @@ -73,7 +87,10 @@ def gen_dfs(
)
]

if unvisited_neighbors_deltas:
# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
if unvisited_neighbors_deltas and (
current_tree_depth <= max_tree_depth / 2
):
stack.append(current_coord)

# choose one of the unvisited neighbors
Expand All @@ -92,22 +109,24 @@ def gen_dfs(
visited_cells.add(tuple(chosen_neighbor))
stack.append(chosen_neighbor)

# Update current tree depth
current_tree_depth += 1
else:
current_tree_depth -= 1

return LatticeMaze(
connection_list=connection_list,
generation_meta=dict(
func_name="gen_dfs",
grid_shape=grid_shape,
start_coord=start_coord,
visited_cells=visited_cells,
n_accessible_cells=n_accessible_cells,
max_tree_depth=max_tree_depth,
fully_connected=(len(visited_cells) == n_accessible_cells),
),
)

@classmethod
def gen_dfs_with_solution(cls, grid_shape: Coord) -> SolvedMaze:
maze: LatticeMaze = cls.gen_dfs(grid_shape)
solution: CoordArray = np.array(maze.generate_random_path())

return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

@staticmethod
def gen_wilson(
grid_shape: Coord,
Expand Down Expand Up @@ -137,9 +156,9 @@ def neighbor(current: Coord, direction: int) -> Coord:

# A connection list only contains two elements: one boolean matrix indicating all the
# downwards connections in the maze, and one boolean matrix indicating the rightwards connections.
connection_list: np.ndarray = np.zeros((2, rows, cols), dtype=bool)
connection_list: np.ndarray = np.zeros((2, rows, cols), dtype=np.bool_)

connected = np.zeros(grid_shape, dtype=bool)
connected = np.zeros(grid_shape, dtype=np.bool_)
direction_matrix = np.zeros(grid_shape, dtype=int)

# Mark a random cell as connected
Expand Down Expand Up @@ -198,12 +217,33 @@ def neighbor(current: Coord, direction: int) -> Coord:
generation_meta=dict(
func_name="gen_wilson",
grid_shape=grid_shape,
fully_connected=True,
),
)

@classmethod
def gen_dfs_with_solution(cls, grid_shape: Coord):
warnings.warn(
"gen_dfs_with_solution is deprecated, use get_maze_with_solution instead",
DeprecationWarning,
)
return get_maze_with_solution("gen_dfs", grid_shape)


# TODO: use the thing @valedan wrote for the evals function to make this automatic?
GENERATORS_MAP: dict[str, Callable[[Coord, Any], "LatticeMaze"]] = {
"gen_dfs": LatticeMazeGenerators.gen_dfs,
"gen_wilson": LatticeMazeGenerators.gen_wilson,
}


def get_maze_with_solution(
gen_name: str,
grid_shape: Coord,
maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
if maze_ctor_kwargs is None:
maze_ctor_kwargs = dict()
maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs)
solution: CoordArray = np.array(maze.generate_random_path())
return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)
Loading