Skip to content

Commit

Permalink
update ndarray to jaxtyping equivalent
Browse files Browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Apr 8, 2023
1 parent 16dbb1f commit 4e4e27b
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 22 deletions.
Binary file added .DS_Store
Binary file not shown.
21 changes: 11 additions & 10 deletions maze_transformer/evaluation/plot_maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float
from jaxtyping import Float, Array, Bool
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from muutils.tensor_utils import NDArray

from maze_transformer.generation.lattice_maze import Coord, CoordArray, LatticeMaze

Expand Down Expand Up @@ -225,8 +224,9 @@ def show(self, dpi: int = 100, title: str = "") -> None:
self.plot(dpi=dpi, title=title)
plt.show()

def _rowcol_to_coord(self, point: Coord) -> NDArray:
"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
def _rowcol_to_coord(self, point: Coord) -> Float[Array, "2"]:
"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x
is the horizontal axis."""
point = np.array([point[1], point[0]])
return self.unit_length * (point + 0.5)

Expand Down Expand Up @@ -286,7 +286,7 @@ def _plot_maze(self) -> None:

self.ax.imshow(img, cmap=cmap, vmin=-1, vmax=1)

def _lattice_maze_to_img(self) -> NDArray["row col", bool]:
def _lattice_maze_to_img(self) -> Bool[Array, "row col"]:
"""
Build an image to visualise the maze.
Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
Expand Down Expand Up @@ -317,7 +317,7 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]:
connection_values = scaled_node_values

# Create background image (all pixels set to -1, walls everywhere)
img: NDArray["row col", float] = -np.ones(
img: Float[Array, "row col"] = -np.ones(
(
self.maze.grid_shape[0] * self.unit_length + 1,
self.maze.grid_shape[1] * self.unit_length + 1,
Expand Down Expand Up @@ -351,12 +351,13 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]:
return img

def _plot_path(self, path_format: PathFormat) -> None:
p_transformed = np.array(
p_transformed: Float[Array, "coord 2"] = np.array(
[self._rowcol_to_coord(coord) for coord in path_format.path]
)

if path_format.quiver_kwargs is not None:
x: NDArray = p_transformed[:, 0]
y: NDArray = p_transformed[:, 1]
x: Float[Array, "x"] = p_transformed[:, 0]
y: Float[Array, "y"] = p_transformed[:, 1]
self.ax.quiver(
x[:-1],
y[:-1],
Expand Down Expand Up @@ -392,7 +393,7 @@ def _plot_path(self, path_format: PathFormat) -> None:
ms=10,
)

def as_ascii(self, start=None, end=None):
def as_ascii(self, start=None, end=None) -> str:
"""
Returns an ASCII visualization of the maze.
Courtesy of ChatGPT
Expand Down
10 changes: 5 additions & 5 deletions maze_transformer/generation/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from jaxtyping import Int, Int8, Array
import numpy as np
from muutils.tensor_utils import NDArray

Coord = NDArray["x y", np.int8]
Coord = Int8[Array, "x y"]
CoordTup = tuple[int, int]
CoordArray = NDArray["coords", np.int8]
CoordArray = Int8[Array, "coords"]

SPECIAL_TOKENS: dict[str, str] = dict(
adj_list_start="<ADJLIST_START>",
Expand All @@ -19,7 +19,7 @@
padding="<PADDING>",
)

DIRECTIONS_MAP: NDArray["direction axes", int] = np.array(
DIRECTIONS_MAP: Int[Array, "direction axes"] = np.array(
[
[0, 1], # down
[0, -1], # up
Expand All @@ -29,7 +29,7 @@
)


NEIGHBORS_MASK: NDArray["coord point", int] = np.array(
NEIGHBORS_MASK: Int[Array, "coord point"] = np.array(
[
[0, 1], # down
[0, -1], # up
Expand Down
8 changes: 4 additions & 4 deletions maze_transformer/generation/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import math

import numpy as np
from muutils.tensor_utils import NDArray
from jaxtyping import Bool, Array


def bool_array_from_string(
string: str, shape: list[int], true_symbol: str = "T"
) -> NDArray:
) -> Bool[Array, "..."]:
"""Transform a string into an ndarray of bools.
Parameters
Expand All @@ -20,8 +20,8 @@ def bool_array_from_string(
Returns
-------
NDArray
A ndarray with dtype bool.
Bool[Array, "..."]
An ndarray array with dtype bool and an unknown shape.
Examples
--------
Expand Down
5 changes: 2 additions & 3 deletions maze_transformer/training/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from itertools import chain

# Avoid circular import from training/config.py
from typing import TYPE_CHECKING, Union # need Union as "a" | "b" doesn't work

import torch
from muutils.tensor_utils import ATensor, NDArray
from muutils.tensor_utils import ATensor
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils import BatchEncoding

Expand Down Expand Up @@ -140,7 +139,7 @@ def batch_decode(

def to_ascii(
self, sequence: list[int | str] | ATensor, start=None, end=None
) -> NDArray:
) -> str:
# Sequence should be a single maze (not batch)
if isinstance(sequence, list) and isinstance(sequence[0], str):
str_sequence = sequence # already decoded
Expand Down

0 comments on commit 4e4e27b

Please sign in to comment.