changes to how tokenizer hashes are handled
- the directory in which we look for the hash file can now be overridden with `set_tokenizer_hashes_path`.
  This is useful for situations where writing inside the package installation dir is not possible.
  I tried making it a relative path to `data/MazeTokenizerModular_hashes.npz`, but this broke so many things
  and is honestly a worse idea.
- `MazeTokenizerModular.__hash__` now calls `MazeTokenizerModular.hash_int()`,
  which is also used by `MazeTokenizerModular.hash_b64()` -- a more concise form for filenames.
  This is used in `tests/all_tokenizers/test_all_tokenizers.py`
- added an option for a more informative assert mode in `is_tested_tokenizer` (I broke some tests; this was useful for debugging). See the usage sketch after this list.
- `demo_mazetokenizermodular.ipynb` now also asserts that the number of loaded hashes
  equals the number of created tokenizers
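
Taken together, the new methods look roughly like this in use. This is a hedged sketch, not code from the repo: it assumes `MazeTokenizerModular()` is constructible with its defaults and that the import path matches the changed file below.

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

# assumed: the default-constructed tokenizer is one of the tested ones
tokenizer = MazeTokenizerModular()

# stable integer hash of the tokenizer's name; `__hash__` now delegates to this
print(tokenizer.hash_int())

# filename-safe base64 tag, shorter than the full integer, e.g. for cache files
print(f"tokenizer-{tokenizer.hash_b64()}.npz")

# returns a bool by default; with do_assert=True it instead raises an
# AssertionError naming exactly which check failed
tokenizer.is_tested_tokenizer(do_assert=True)
```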
mivanit committed Aug 29, 2024
1 parent c9c7635 commit 4fdfebb
Showing 3 changed files with 144 additions and 43 deletions.
90 changes: 73 additions & 17 deletions maze_dataset/tokenization/maze_tokenizer.py
@@ -1,6 +1,7 @@
 """turning a maze into text: `MazeTokenizerModular` and the legacy `TokenizationMode` enum and `MazeTokenizer` class"""
 
 import abc
+import base64
 import hashlib
 import random
 import warnings
@@ -1900,13 +1901,23 @@ class MazeTokenizerModular(SerializableDataclass):
         loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
     )
 
-    def __hash__(self):
-        "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
+    def hash_int(self) -> int:
         return int.from_bytes(
             hashlib.blake2b(self.name.encode("utf-8")).digest(),
             byteorder="big",
         )
 
+    def __hash__(self):
+        "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
+        return self.hash_int()
+
+    def hash_b64(self) -> str:
+        "filename-safe base64 encoding of the hash"
+        return base64.b64encode(
+            self.hash_int().to_bytes(16, byteorder="big"),
+            altchars=b"-_",
+        ).decode()
+
     # Information Querying Methods
 
     @cached_property
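
For reference, a standalone sketch of the same name → integer → base64 pipeline (my illustration, not the package's code). Note that `blake2b` defaults to a 64-byte digest, so this sketch keeps only the low 128 bits of the integer so that it fits the 16-byte buffer being encoded.

```python
import base64
import hashlib


def name_hash_b64(name: str) -> str:
    """Filename-safe base64 tag derived from a blake2b hash of `name`."""
    # blake2b's default digest is 64 bytes; keep only the low 128 bits
    # so the integer fits in the 16-byte buffer encoded below
    h_int: int = int.from_bytes(
        hashlib.blake2b(name.encode("utf-8")).digest(),
        byteorder="big",
    ) % (1 << 128)
    # altchars swaps `+/` for `-_`, making the tag safe in filenames
    return base64.b64encode(
        h_int.to_bytes(16, byteorder="big"),
        altchars=b"-_",
    ).decode()


print(name_hash_b64("MazeTokenizerModular(...)"))  # 24-char tag ending in '=='
```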
@@ -2008,21 +2019,34 @@ def is_legacy_equivalent(self) -> bool:
             ]
         )
 
-    def is_tested_tokenizer(self) -> bool:
-        """
-        Returns if the tokenizer is returned by `all_tokenizers._get_all_tokenizers`, the set of tested and reliable tokenizers.
+    def is_tested_tokenizer(self, do_assert: bool = False) -> bool:
+        """Returns if the tokenizer is returned by `all_tokenizers._get_all_tokenizers`, the set of tested and reliable tokenizers.
+
         Since evaluating `all_tokenizers._get_all_tokenizers` is expensive,
         instead checks for membership of `self`'s hash in `get_all_tokenizer_hashes()`.
+
+        if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
         """
         all_tokenizer_hashes: Int64[np.ndarray, "n_tokenizers"] = (
             get_all_tokenizer_hashes()
         )
         hash_index: int = np.searchsorted(all_tokenizer_hashes, hash(self))
-        return (
-            hash_index < len(all_tokenizer_hashes)
-            and all_tokenizer_hashes[hash_index] == hash(self)
-            and self.is_valid()
-        )
+
+        in_range: bool = hash_index < len(all_tokenizer_hashes)
+        hashes_match: bool = all_tokenizer_hashes[hash_index] == hash(self)
+        is_valid: bool = self.is_valid()
+
+        if do_assert:
+            assert (
+                in_range
+            ), f"{hash_index = } is invalid, must be at most {len(all_tokenizer_hashes) - 1}"
+            assert (
+                hashes_match
+            ), f"{all_tokenizer_hashes[hash_index] = } != {hash(self) = }"
+            assert is_valid, f"self.is_valid returns False"
+            return True
+        else:
+            return in_range and hashes_match and is_valid
 
     def is_AOTP(self) -> bool:
         return self.has_element(PromptSequencers.AOTP)
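
The membership test above relies on `np.searchsorted` against a sorted array of hashes; a minimal sketch of that pattern with stand-in data:

```python
import numpy as np

# stand-in for the sorted hash array loaded from the .npz file
all_hashes: np.ndarray = np.sort(np.array([3, 14, 159, 2653], dtype=np.int64))


def is_member(h: int) -> bool:
    """O(log n) membership check in a sorted array."""
    idx = int(np.searchsorted(all_hashes, h))
    # searchsorted returns an insertion point, so guard against
    # running off the end of the array before comparing for equality
    return idx < len(all_hashes) and all_hashes[idx] == h


assert is_member(159)
assert not is_member(42)
```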
@@ -2145,22 +2169,54 @@ def decode(
     return output
 
 
+_ALL_TOKENIZER_HASHES: Int64[np.ndarray, "n_tokenizers"]
+"private array of all tokenizer hashes"
+_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
+"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`"
+
+
+def set_tokenizer_hashes_path(path: Path):
+    """set path to tokenizer hashes, and reload the hashes if needed
+
+    the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
+    which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"` or in this file's directory.
+
+    However, this might not always work, so we provide a way to change this.
+    """
+    global _TOKENIZER_HASHES_PATH
+    global _ALL_TOKENIZER_HASHES
+
+    path = Path(path)
+    if path.is_dir():
+        path = path / "MazeTokenizerModular_hashes.npz"
+
+    if not path.is_file():
+        raise FileNotFoundError(f"could not find maze tokenizer hashes file at: {path}")
+
+    if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
+        # reload if they aren't equal
+        _TOKENIZER_HASHES_PATH = path
+        _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
+    else:
+        # always set to new path
+        _TOKENIZER_HASHES_PATH = path
+
+
 def _load_tokenizer_hashes() -> Int64[np.ndarray, "n_tokenizers"]:
     """Loads the sorted list of `all_tokenizers.ALL_TOKENIZERS` hashes from disk."""
+    global _TOKENIZER_HASHES_PATH
     try:
-        path: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
+        path: Path = _TOKENIZER_HASHES_PATH
         return np.load(path)["hashes"]
     except FileNotFoundError as e:
         raise FileNotFoundError(
-            "Tokenizers hashes cannot be loaded. To fix this:",
-            "\n- install the package with the `tokenizers` extra: `pip install maze-dataset[tokenizers]` (recommended)",
-            "\n- run `python -m maze-dataset.tokenization.save_hashes` (not recommended, might break depending on how `maze-dataset` is installed)",
+            "Tokenizers hashes cannot be loaded. To fix this, run",
+            "\n`python -m maze-dataset.tokenization.save_hashes` which will save the hashes to",
+            "\n`data/MazeTokenizerModular_hashes.npz`",
+            "relative to the current working directory -- this is where the code looks for them.",
         ) from e
 
 
-_ALL_TOKENIZER_HASHES: Int64[np.ndarray, "n_tokenizers"]
-
-
 def get_all_tokenizer_hashes() -> Int64[np.ndarray, "n_tokenizers"]:
     global _ALL_TOKENIZER_HASHES
     try:
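
Usage of the override, per `set_tokenizer_hashes_path` above. The paths here are hypothetical; the function raises `FileNotFoundError` unless the `.npz` file already exists at the resolved location.

```python
from pathlib import Path

from maze_dataset.tokenization.maze_tokenizer import set_tokenizer_hashes_path

# a directory argument gets the default filename appended, so these two
# calls resolve to the same (hypothetical) file:
set_tokenizer_hashes_path(Path("/opt/maze_hashes"))
set_tokenizer_hashes_path(Path("/opt/maze_hashes/MazeTokenizerModular_hashes.npz"))

# hashes are only re-read from disk when the resolved path actually changes,
# so the second call updates the path without reloading
```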
