Skip to content

Commit

Permalink
more docstring improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Aug 28, 2024
1 parent 8ce0c82 commit f4acb07
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 5 deletions.
19 changes: 19 additions & 0 deletions maze_dataset/dataset/rasterized.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze
this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset
see their paper:
```bibtex
@misc{schwarzschild2021learn,
title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
year={2021},
eprint={2106.04537},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
"""

import typing

import numpy as np
Expand Down
6 changes: 6 additions & 0 deletions maze_dataset/generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`
`DEFAULT_GENERATORS` is a list of generator name, generator kwargs pairs used in tests and demos
"""

from maze_dataset.generation.generators import (
GENERATORS_MAP,
LatticeMazeGenerators,
Expand Down
16 changes: 16 additions & 0 deletions maze_dataset/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
"""utilities for plotting mazes and printing tokens
- any `LatticeMaze` or `SolvedMaze` comes with a `as_pixels()` method that returns
a 2D numpy array of pixel values, but this is somewhat limited
- `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way
- `print_tokens` contains utilities for printing tokens, colored by their type, position, or some custom weights (i.e. attention weights)
"""

from maze_dataset.plotting.plot_dataset import plot_dataset_mazes, print_dataset_mazes
from maze_dataset.plotting.plot_maze import DEFAULT_FORMATS, MazePlot, PathFormat
from maze_dataset.plotting.print_tokens import (
color_maze_tokens_AOTP,
color_tokens_cmap,
color_tokens_rgb,
)

__all__ = [
# submodules
Expand All @@ -13,4 +26,7 @@
"DEFAULT_FORMATS",
"MazePlot",
"PathFormat",
"color_tokens_cmap",
"color_maze_tokens_AOTP",
"color_tokens_rgb",
]
5 changes: 5 additions & 0 deletions maze_dataset/plotting/plot_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""`plot_dataset_mazes` will plot several mazes using `as_pixels`
`print_dataset_mazes` will use `as_ascii` to print several mazes
"""

import matplotlib.pyplot as plt # type: ignore[import]

from maze_dataset.dataset.maze_dataset import MazeDataset
Expand Down
2 changes: 2 additions & 0 deletions maze_dataset/plotting/plot_maze.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""

from __future__ import annotations # for type hinting self as return value

import warnings
Expand Down
3 changes: 3 additions & 0 deletions maze_dataset/plotting/plot_tokens.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds"

from typing import Any, Sequence

import matplotlib.pyplot as plt
Expand All @@ -21,6 +23,7 @@ def plot_colored_text(
fig_width_scale: float = 0.25,
char_min: int = 4,
):
"hacky function to plot tokens on a matplotlib axis with colored backgrounds" ""
assert len(tokens) == len(
weights
), f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
Expand Down
27 changes: 26 additions & 1 deletion maze_dataset/plotting/print_tokens.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""Functions to print tokens with colors in different formats
you can color the tokens by their:
- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP`
- custom weights (i.e. attention weights) using `color_tokens_cmap`
- entirely custom colors using `color_tokens_rgb`
and the output can be in different formats, specified by `FormatType` (html, latex, terminal)
"""

import html
import textwrap
from typing import Literal, Sequence
Expand All @@ -12,26 +24,31 @@
from maze_dataset.token_utils import tokens_between

RGBArray = UInt8[np.ndarray, "n 3"]
"1D array of RGB values"

FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens"

TEMPLATES: dict[FormatType, str] = {
"html": '<span style="color: black; background-color: rgb({clr})">&nbsp{tok}&nbsp</span>',
"latex": "\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}",
"terminal": "\033[30m\033[48;2;{clr}m{tok}\033[0m",
}
"templates of printing tokens in different formats"

_COLOR_JOIN: dict[FormatType, str] = {
"html": ",",
"latex": ",",
"terminal": ";",
}
"joiner for colors in different formats"


def _escape_tok(
tok: str,
fmt: FormatType,
) -> str:
"escape token based on format"
if fmt == "html":
return html.escape(tok)
elif fmt == "latex":
Expand All @@ -50,7 +67,8 @@ def color_tokens_rgb(
clr_join: str | None = None,
max_length: int | None = None,
) -> str:
"""
"""color tokens from a list with an RGB color array
tokens will not be escaped if `fmt` is None
# Parameters:
Expand Down Expand Up @@ -104,6 +122,7 @@ def color_tokens_cmap(
template: str | None = None,
labels: bool = False,
):
"color tokens given a list of weights and a colormap"
assert len(tokens) == len(weights), f"{len(tokens)} != {len(weights)}"
weights = np.array(weights)
# normalize weights to [0, 1]
Expand Down Expand Up @@ -150,11 +169,17 @@ def color_tokens_cmap(
(SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (234, 209, 220), # red
(SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (207, 226, 243), # blue
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"


def color_maze_tokens_AOTP(
tokens: list[str], fmt: FormatType = "html", template: str | None = None, **kwargs
) -> str:
"""color tokens assuming AOTP format
i.e: adjaceny list, origin, target, path
"""
output: list[str] = [
" ".join(
tokens_between(
Expand Down
12 changes: 10 additions & 2 deletions maze_dataset/tokenization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""turning a maze into text
- `MazeTokenizerModular` is the new recommended way to do this as of 1.0.0
- legacy `TokenizationMode` enum and `MazeTokenizer` class for supporting existing code
- a whole lot of helper classes and functions
"""

from maze_dataset.tokenization.maze_tokenizer import (
AdjListTokenizers,
CoordTokenizers,
Expand All @@ -21,8 +29,6 @@
"all_tokenizers",
"maze_tokenizer",
"save_hashes",
# imports
"MazeTokenizer",
# modular maze tokenization components
"TokenizationMode",
"_TokenizerElement",
Expand All @@ -40,4 +46,6 @@
# helpers
"coord_str_to_tuple",
"get_tokens_up_to_path_start",
# old tokenizer
"MazeTokenizer",
]
8 changes: 6 additions & 2 deletions maze_dataset/tokenization/maze_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""MazeTokenizerModular and the legacy TokenizationMode enum, MazeTokenizer class"""
"""turning a maze into text: `MazeTokenizerModular` and the legacy `TokenizationMode` enum and `MazeTokenizer` class"""

import abc
import hashlib
Expand Down Expand Up @@ -138,7 +138,11 @@ def get_tokens_up_to_path_start(
properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True
)
class MazeTokenizer(SerializableDataclass):
"""LEGACY: Tokenizer for mazes
"""LEGACY Tokenizer for mazes
> [!CAUTION]
> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
> for use, but will remain for compatibility with existing code.
# Parameters:
- `tokenization_mode: TokenizationMode`
Expand Down

0 comments on commit f4acb07

Please sign in to comment.