From 4e4e27b5ba8004ea2f505289452ff03ed6c4c760 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Sat, 8 Apr 2023 23:19:00 +1000 Subject: [PATCH] update ndarray to jaxtyping equivalent --- .DS_Store | Bin 0 -> 6148 bytes maze_transformer/evaluation/plot_maze.py | 21 +++++++++++---------- maze_transformer/generation/constants.py | 10 +++++----- maze_transformer/generation/utils.py | 8 ++++---- maze_transformer/training/tokenizer.py | 5 ++--- 5 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5c1416ebae2b23859b7c2ad07c86e0de7332f688 GIT binary patch literal 6148 zcmeHK&5qMB82sGr(v*r-4pU>*3k9pK+?l>+i8Bd_%PeTcU9FBR$~9GcKaFlRKU2<5cI z|F*@SfjtHD4TY>I^2PNMzlokb`UL(6vq?~nDTP*|JVR|#(tj!G>)=Z$C*%VqMlQ`z z{QSGZU#4^O4g2oTqcF`nogbpISKYrP~&sJY22rc8^+a_kL&Ha#u&) zPRqS>``&t8cdp&I`QT(Ujb^d@jJH4td~;P?F?a^-%MLF-43jvL@msXAvVm2U@QR&2 zr8d?_(wL@2k1Nd1?|}oTPLUF@-$OlyzKtZYLQzpmynnvt_D`Ko=oqyFdID@i5tJ0X zgYt?R)_;t#LX4PgttRiEJlg@A-)8}Ko99{VU9c~C{yw>1o> zL=8e{DNsv=xnc-&a}ZjN_MYKyG-^2sH8T1!BMWmw5$56{Ttz3*(`a++fOVkgz@C0= z^7()K=lg$AWS^`9)`9=Z0Z|b=`|7u^EMil1erOGRe_J None: self.plot(dpi=dpi, title=title) plt.show() - def _rowcol_to_coord(self, point: Coord) -> NDArray: - """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.""" + def _rowcol_to_coord(self, point: Coord) -> Float[Array, "2"]: + """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x + is the horizontal axis.""" point = np.array([point[1], point[0]]) return self.unit_length * (point + 0.5) @@ -286,7 +286,7 @@ def _plot_maze(self) -> None: self.ax.imshow(img, cmap=cmap, vmin=-1, vmax=1) - def _lattice_maze_to_img(self) -> NDArray["row col", bool]: + def _lattice_maze_to_img(self) -> Bool[Array, "row col"]: """ Build an image to visualise the maze. Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul. @@ -317,7 +317,7 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]: connection_values = scaled_node_values # Create background image (all pixels set to -1, walls everywhere) - img: NDArray["row col", float] = -np.ones( + img: Float[Array, "row col"] = -np.ones( ( self.maze.grid_shape[0] * self.unit_length + 1, self.maze.grid_shape[1] * self.unit_length + 1, @@ -351,12 +351,13 @@ def _lattice_maze_to_img(self) -> NDArray["row col", bool]: return img def _plot_path(self, path_format: PathFormat) -> None: - p_transformed = np.array( + p_transformed: Float[Array, "coord 2"] = np.array( [self._rowcol_to_coord(coord) for coord in path_format.path] ) + if path_format.quiver_kwargs is not None: - x: NDArray = p_transformed[:, 0] - y: NDArray = p_transformed[:, 1] + x: Float[Array, "x"] = p_transformed[:, 0] + y: Float[Array, "y"] = p_transformed[:, 1] self.ax.quiver( x[:-1], y[:-1], @@ -392,7 +393,7 @@ def _plot_path(self, path_format: PathFormat) -> None: ms=10, ) - def as_ascii(self, start=None, end=None): + def as_ascii(self, start=None, end=None) -> str: """ Returns an ASCII visualization of the maze. Courtesy of ChatGPT diff --git a/maze_transformer/generation/constants.py b/maze_transformer/generation/constants.py index cd5f197c..36f2e9ab 100644 --- a/maze_transformer/generation/constants.py +++ b/maze_transformer/generation/constants.py @@ -1,9 +1,9 @@ +from jaxtyping import Int, Int8, Array import numpy as np -from muutils.tensor_utils import NDArray -Coord = NDArray["x y", np.int8] +Coord = Int8[Array, "x y"] CoordTup = tuple[int, int] -CoordArray = NDArray["coords", np.int8] +CoordArray = Int8[Array, "coords"] SPECIAL_TOKENS: dict[str, str] = dict( adj_list_start="", @@ -19,7 +19,7 @@ padding="", ) -DIRECTIONS_MAP: NDArray["direction axes", int] = np.array( +DIRECTIONS_MAP: Int[Array, "direction axes"] = np.array( [ [0, 1], # down [0, -1], # up @@ -29,7 +29,7 @@ ) -NEIGHBORS_MASK: NDArray["coord point", int] = np.array( +NEIGHBORS_MASK: Int[Array, "coord point"] = np.array( [ [0, 1], # down [0, -1], # up diff --git a/maze_transformer/generation/utils.py b/maze_transformer/generation/utils.py index c87cb3e5..8d8c0747 100644 --- a/maze_transformer/generation/utils.py +++ b/maze_transformer/generation/utils.py @@ -1,12 +1,12 @@ import math import numpy as np -from muutils.tensor_utils import NDArray +from jaxtyping import Bool, Array def bool_array_from_string( string: str, shape: list[int], true_symbol: str = "T" -) -> NDArray: +) -> Bool[Array, "..."]: """Transform a string into an ndarray of bools. Parameters @@ -20,8 +20,8 @@ def bool_array_from_string( Returns ------- - NDArray - A ndarray with dtype bool. + Bool[Array, "..."] + An ndarray array with dtype bool and an unknown shape. Examples -------- diff --git a/maze_transformer/training/tokenizer.py b/maze_transformer/training/tokenizer.py index 9f08f61e..9f721b30 100644 --- a/maze_transformer/training/tokenizer.py +++ b/maze_transformer/training/tokenizer.py @@ -1,10 +1,9 @@ from itertools import chain - # Avoid circular import from training/config.py from typing import TYPE_CHECKING, Union # need Union as "a" | "b" doesn't work import torch -from muutils.tensor_utils import ATensor, NDArray +from muutils.tensor_utils import ATensor from transformers import PreTrainedTokenizer from transformers.tokenization_utils import BatchEncoding @@ -140,7 +139,7 @@ def batch_decode( def to_ascii( self, sequence: list[int | str] | ATensor, start=None, end=None - ) -> NDArray: + ) -> str: # Sequence should be a single maze (not batch) if isinstance(sequence, list) and isinstance(sequence[0], str): str_sequence = sequence # already decoded