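Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora. Rather than precomputing counts for one fixed $n$, it builds a suffix array index from which the count of an $n$-gram of any length can be computed on the fly.
Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to sample text or to compute smoothed next-token probabilities.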
The backend is written in Rust, and the Python bindings are generated using PyO3.
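To illustrate the idea behind the index, here is a toy pure-Python sketch of counting an $n$-gram with a suffix array: sort every suffix start position by the suffix it begins, then binary-search for the contiguous block of suffixes that start with the query. The helper names (build_suffix_array, count_ngram) are hypothetical, and this is not Tokengrams' Rust implementation, which builds and stores the array far more efficiently.

```python
def build_suffix_array(tokens):
    """Return suffix start positions sorted by the suffix each one begins.

    Naive construction for illustration only: it copies every suffix, so it is
    far too slow and memory-hungry for a real corpus.
    """
    return sorted(range(len(tokens)), key=lambda i: tokens[i:])


def _bound(tokens, sa, query, strict):
    # Binary search for the first suffix whose length-len(query) prefix is
    # >= query (strict=False) or > query (strict=True). Truncating sorted
    # suffixes keeps them sorted, so this predicate is monotone over `sa`.
    k = len(query)
    lo, hi = 0, len(sa)
    while lo < hi:
        mid = (lo + hi) // 2
        prefix = tokens[sa[mid]:sa[mid] + k]
        if prefix < query or (strict and prefix == query):
            lo = mid + 1
        else:
            hi = mid
    return lo


def count_ngram(tokens, sa, query):
    # Suffixes that start with `query` form one contiguous block of the sorted
    # array, so the n-gram count is the width of that block.
    query = list(query)
    return _bound(tokens, sa, query, strict=True) - _bound(tokens, sa, query, strict=False)


corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]  # token IDs as a Python list
sa = build_suffix_array(corpus)
print(count_ngram(corpus, sa, [1, 2]))     # 3
print(count_ngram(corpus, sa, [1, 2, 3]))  # 2
```

In this sketch a count costs two binary searches, i.e. $O(|q| \log N)$ token comparisons for a query of length $|q|$ in a corpus of $N$ tokens, however often the $n$-gram occurs.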
pip install tokengrams
Use a dataset of u16 or u32 tokens, or prepare one from a HuggingFace dataset.
# Get pre-tokenized dataset
from huggingface_hub import HfApi, hf_hub_download
hf_hub_download(
    repo_id="EleutherAI/pile-standard-pythia-preshuffled",
    repo_type="dataset",
    filename="document-00000-of-00020.bin",
    local_dir="."
)
# Tokenize HF dataset
from tokengrams import tokenize_hf_dataset
from datasets import load_dataset
from transformers import AutoTokenizer
tokenize_hf_dataset(
    dataset=load_dataset("EleutherAI/lambada_openai", "en"),
    tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-160m"),
    output_path="lambada.bin",
    text_key="text",
    append_eod=True,
    workers=1,
)
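tokenize_hf_dataset writes the corpus file for you. If you already have token IDs from another source, the sketch below writes them out, assuming the on-disk format is a flat, headerless array of native-endian unsigned integers (u16 or u32), as in the Pythia .bin shards; please verify this against your own data before building an index. write_token_file and read_token_file are hypothetical helpers built on the standard-library array module.

```python
import os
import tempfile
from array import array

# Assumed mapping from token width to array typecode ('H' = unsigned short,
# 'I' = unsigned int); widths are checked at runtime because C types vary.
_TYPECODES = {"u16": "H", "u32": "I"}
_WIDTHS = {"u16": 2, "u32": 4}


def write_token_file(tokens, path, dtype="u16"):
    """Write token IDs as a flat, headerless array of native-endian unsigned ints."""
    arr = array(_TYPECODES[dtype], tokens)  # OverflowError if an ID does not fit
    if arr.itemsize != _WIDTHS[dtype]:
        raise RuntimeError("C type width differs from the expected token width")
    with open(path, "wb") as f:
        arr.tofile(f)


def read_token_file(path, dtype="u16"):
    """Read a file written by write_token_file back into a list of ints."""
    arr = array(_TYPECODES[dtype])
    with open(path, "rb") as f:
        arr.frombytes(f.read())
    return arr.tolist()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.bin")
    write_token_file([0, 1, 2, 3, 4], path)
    print(os.path.getsize(path))  # 10 (5 tokens x 2 bytes)
    print(read_token_file(path))  # [0, 1, 2, 3, 4]
```

Passing an ID that does not fit the chosen width makes array raise OverflowError, which catches a vocabulary/width mismatch before anything is written.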
from tokengrams import MemmapIndex
# Create a new index from an on-disk corpus of u16 tokens and save it to a .idx file.
# Set verbose=True to show a progress bar while the index is sorted.
index = MemmapIndex.build(
    "document-00000-of-00020.bin",
    "document-00000-of-00020.idx",
    vocab=2**16,
    verbose=True
)
# True for any valid index.
print(index.is_sorted())
# Get the count of "hello world" in the corpus.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
print(index.count(tokenizer.encode("hello world")))
# You can now load the index from disk later using __init__
index = MemmapIndex(
    "document-00000-of-00020.bin",
    "document-00000-of-00020.idx",
    vocab=2**16
)
# Count how often each token in the corpus succeeds "hello world".
print(index.count_next(tokenizer.encode("hello world")))
print(index.batch_count_next(
    [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
))
# Get smoothed probabilities for query continuations
print(index.smoothed_probs(tokenizer.encode("hello world")))
print(index.batch_smoothed_probs(
    [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
))
# Autoregressively sample 20 sequences of 10 tokens each using 5-gram language
# statistics. Initial n-gram statistics are derived from the query, with
# lower-order n-gram statistics used until the sequence contains at least 5 tokens.
print(index.sample_unsmoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
print(index.sample_smoothed(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
# Query whether the corpus contains "hello world"
print(index.contains(tokenizer.encode("hello world")))
# Get all n-grams beginning with "hello world" in the corpus
print(index.positions(tokenizer.encode("hello world")))
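To make the count_next output concrete, here is a naive pure-Python reference on a toy corpus. It assumes, following the comment above, that the result holds one count per token ID in the vocabulary; count_next_naive and unsmoothed_probs are hypothetical names. It uses a linear scan rather than a suffix array, and unsmoothed_probs shows only the plain relative-frequency estimate, not the smoothing that smoothed_probs applies.

```python
def count_next_naive(tokens, query, vocab):
    """counts[t] = how often token t immediately follows an occurrence of `query`."""
    query = list(query)
    k = len(query)
    counts = [0] * vocab
    # Stop k tokens early: an occurrence at the very end has no successor.
    for i in range(len(tokens) - k):
        if tokens[i:i + k] == query:
            counts[tokens[i + k]] += 1
    return counts


def unsmoothed_probs(counts):
    # Maximum-likelihood estimate: relative frequency of each continuation.
    total = sum(counts)
    return [c / total for c in counts] if total else [0.0] * len(counts)


corpus = [1, 2, 3, 1, 2, 4, 1, 2, 3]
counts = count_next_naive(corpus, [1, 2], vocab=5)
print(counts)  # [0, 0, 0, 2, 1]
print(unsmoothed_probs(counts))
```

With an empty query every position matches, so the counts reduce to unigram frequencies.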
Corpora small enough to fit in memory can use an InMemoryIndex:
from tokengrams import InMemoryIndex
tokens = [0, 1, 2, 3, 4]
index = InMemoryIndex(tokens, vocab=5)
Larger corpora must use a MemmapIndex.
Some systems struggle to memory-map extremely large tables (e.g., 40 billion tokens), which can cause unexpected bus errors. To prevent this, split the corpus into shards and use a ShardedMemmapIndex, which sorts and queries the table shard by shard:
from tokengrams import ShardedMemmapIndex
from huggingface_hub import HfApi, hf_hub_download
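files = [
    file for file in HfApi().list_repo_files("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset")
    if file.endswith('.bin')
]
index_paths = []
for file in files:
    hf_hub_download("EleutherAI/pile-standard-pythia-preshuffled", repo_type="dataset", filename=file, local_dir=".")
    # Pair each token file with its index path. removesuffix (Python 3.9+) drops only the
    # extension, whereas rstrip(".bin") would strip any trailing ".", "b", "i" or "n" characters.
    index_paths.append((file, f'{file.removesuffix(".bin")}.idx'))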
index = ShardedMemmapIndex.build(index_paths, vocab=2**16, verbose=True)
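One way to produce shards from an existing token file is to cut it into consecutive pieces with a fixed number of tokens each. The sketch below assumes a flat u16 file (2 bytes per token) and returns (tokens path, index path) pairs in the same shape as index_paths above, so the result could be passed to ShardedMemmapIndex.build; split_token_file and its shard naming scheme are hypothetical.

```python
import os
import tempfile


def split_token_file(path, tokens_per_shard, token_bytes=2):
    """Split a flat token file into consecutive shard files.

    Returns (tokens_path, index_path) pairs. Each full shard holds exactly
    tokens_per_shard tokens, so no token is cut in half; the last shard
    holds the remainder.
    """
    chunk_bytes = tokens_per_shard * token_bytes
    stem = path[: -len(".bin")] if path.endswith(".bin") else path
    pairs = []
    with open(path, "rb") as src:
        shard = 0
        while True:
            data = src.read(chunk_bytes)  # fewer bytes only at end of file
            if not data:
                break
            shard_path = f"{stem}-shard{shard:05d}.bin"
            with open(shard_path, "wb") as dst:
                dst.write(data)
            pairs.append((shard_path, f"{stem}-shard{shard:05d}.idx"))
            shard += 1
    return pairs


with tempfile.TemporaryDirectory() as tmp:
    src_path = os.path.join(tmp, "corpus.bin")
    with open(src_path, "wb") as f:
        f.write(bytes(20))  # 10 u16 tokens, all zero
    pairs = split_token_file(src_path, tokens_per_shard=4)
    print([os.path.getsize(b) for b, _ in pairs])  # [8, 8, 4]
```

In this naive split, an $n$-gram that straddles a shard boundary is cut in two, so neither shard sees it as a whole.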
Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of $2^{32}$. In practice, however, the vocabulary size is limited by the length of the largest vector of word-sized integers the machine can allocate in memory.
Corpora with vocabulary sizes smaller than $2^{16}$ must use u16 tokens.
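The rule above, and the arithmetic behind the practical vocabulary limit, can be sketched as follows. The sketch treats a vocabulary of exactly $2^{16}$ IDs (0 to 65535) as fitting in u16, matching the vocab=2**16 examples above, and assumes the limiting vector stores one 8-byte word per token ID; token_type and vocab_vector_bytes are hypothetical helpers.

```python
def token_type(vocab_size):
    """Narrowest token type that can hold IDs 0 .. vocab_size - 1."""
    if vocab_size <= 2**16:
        return "u16"
    if vocab_size <= 2**32:
        return "u32"
    raise ValueError("vocabulary exceeds the 2**32 maximum")


def vocab_vector_bytes(vocab_size, word_bytes=8):
    # One machine word per token ID; 8 bytes assumes a 64-bit target.
    return vocab_size * word_bytes


print(token_type(50_000))                 # u16
print(token_type(100_000))                # u32
print(vocab_vector_bytes(2**32) / 2**30)  # 32.0 (GiB)
```

At the $2^{32}$ maximum, one such vector alone needs $2^{32} \times 8 = 2^{35}$ bytes, i.e. 32 GiB.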
Index build times for in-memory corpora scale inversely with the number of available CPU threads, whereas if the index reads from or writes to a file it is likely to be IO bound.
The time complexities of count_next(query) and sample_unsmoothed(query) are $O(n \log n)$, where $n$ is approximately the number of completions for the query. The time complexity of sample_smoothed(query) is $O(m n \log n)$, where $m$ is the $n$-gram order.
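The sampling behaviour described in the usage example (condition on the last n - 1 tokens, falling back to lower-order statistics while the sequence is shorter) can be sketched naively as follows. This is an illustration under our own assumptions, not the library's algorithm: sample_unsmoothed_naive and next_counts are hypothetical, the scan is linear in the corpus size rather than using the suffix array, and backing off when a context has no observed continuation is our choice; Tokengrams may handle that case differently.

```python
import random


def next_counts(tokens, context, vocab):
    # counts[t] = occurrences of `context` immediately followed by token t
    k = len(context)
    counts = [0] * vocab
    for i in range(len(tokens) - k):
        if tokens[i:i + k] == context:
            counts[tokens[i + k]] += 1
    return counts


def sample_unsmoothed_naive(tokens, query, n, k, vocab, rng):
    """Extend `query` by k tokens, conditioning on at most the last n - 1 tokens."""
    seq = list(query)
    for _ in range(k):
        # While seq is shorter than n - 1 tokens, the whole sequence is the context.
        context = seq[-(n - 1):] if n > 1 else []
        counts = next_counts(tokens, context, vocab)
        # Assumption: back off to shorter contexts when this one has no observed
        # continuation; the empty context yields unigram counts.
        while not any(counts) and context:
            context = context[1:]
            counts = next_counts(tokens, context, vocab)
        seq.append(rng.choices(range(vocab), weights=counts)[0])
    return seq


corpus = [0, 1, 2, 3, 4]
print(sample_unsmoothed_naive(corpus, [0, 1], n=3, k=3, vocab=5, rng=random.Random(0)))
# [0, 1, 2, 3, 4]
```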
Build and test the Rust library:
cargo build
cargo test
Develop Python bindings:
pip install maturin
maturin develop
pytest
The best way to get support is to open an issue on this repo or post in #interp-across-time in the EleutherAI Discord server. If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!