Merge pull request #5 from EleutherAI/release-cleanup-2
Clean up names, fn privacy, docstrings; add unit tests CI
luciaquirke authored Apr 5, 2024
2 parents 60a17e3 + 025a039 commit f80c224
Showing 14 changed files with 264 additions and 150 deletions.
27 changes: 0 additions & 27 deletions .github/workflows/python-package-conda.yml

This file was deleted.

29 changes: 29 additions & 0 deletions .github/workflows/python.yml
@@ -0,0 +1,29 @@
name: Python Package using Conda

on: [push]

jobs:
  build-linux:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
      - uses: actions/checkout@v3
      - name: Set up Conda
        uses: conda-incubator/setup-miniconda@v3
        with:
          python-version: "3.10"
          miniforge-version: latest
          use-mamba: true
          mamba-version: "*"
      - name: Test Python
        env:
          PYTHONPATH: /home/runner/work/tokengrams/tokengrams
        shell: bash -l {0}
        run: |
          mamba install -c conda-forge numpy pytest hypothesis maturin
          maturin develop
          maturin build
          python -m pip install --user ./target/wheels/tokengrams*.whl
          pytest
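
This job builds the extension with maturin, installs the built wheel, and runs pytest with hypothesis available. For illustration, a minimal sketch of the kind of property test this setup can run, assuming the `InMemoryIndex` bindings shown later in this diff (the repository's actual test suite may differ):

```python
# Hypothetical property test, not part of this commit. Assumes the
# InMemoryIndex API from src/in_memory_index.rs below.
from hypothesis import given, strategies as st
from tokengrams import InMemoryIndex

# Tokens are u16 on the Rust side, so stay within [0, 2**16).
u16_tokens = st.lists(
    st.integers(min_value=0, max_value=2**16 - 1), min_size=2, max_size=300
)

@given(u16_tokens)
def test_count_matches_brute_force(tokens):
    index = InMemoryIndex(tokens, False)  # (tokens, verbose)
    query = tokens[:2]
    # Count the query bigram as a contiguous subsequence by brute force.
    expected = sum(tokens[i : i + 2] == query for i in range(len(tokens) - 1))
    assert index.count(query) == expected
```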
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tokengrams"
version = "0.2.0"
version = "0.3.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
58 changes: 49 additions & 9 deletions README.md
@@ -1,38 +1,78 @@
# Tokengrams
-This library allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.
+Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.

Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to generate text or compute the perplexity of a given text.

The backend is written in Rust, and the Python bindings are generated using [PyO3](https://github.com/PyO3/pyo3).

# Installation
-Currently you need to build and install from source using `maturin`. We plan to release wheels on PyPI soon.

+```bash
+pip install tokengrams
+```

# Development

```bash
pip install maturin
maturin develop
```

# Usage

## Building an index
```python
from tokengrams import MemmapIndex

# Create a new index from an on-disk corpus called `document.bin` and save it to
-# `pile.idx`
+# `pile.idx`.
index = MemmapIndex.build(
"/mnt/ssd-1/pile_preshuffled/standard/document.bin",
"/mnt/ssd-1/nora/pile.idx",
"/data/document.bin",
"/pile.idx",
)

-# Get the count of "hello world" in the corpus
+# Verify index correctness
+print(index.is_sorted())
+
+# Get the count of "hello world" in the corpus.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
print(index.count(tokenizer.encode("hello world")))

# You can now load the index from disk later using __init__
index = MemmapIndex(
"/mnt/ssd-1/pile_preshuffled/standard/document.bin",
"/mnt/ssd-1/nora/pile.idx"
"/data/document.bin",
"/pile.idx"
)
```

## Using an index

```python
# Count how often each token in the corpus succeeds "hello world".
print(index.count_next(tokenizer.encode("hello world")))

# Parallelize over queries
print(index.batch_count_next(
[tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
))

# Autoregressively sample 10 tokens using 5-gram language statistics. Initial
# gram statistics are derived from the query, with lower-order statistics used
# until the sequence contains at least 5 tokens.
print(index.sample(tokenizer.encode("hello world"), n=5, k=10))

# Parallelize over sequence generations
print(index.batch_sample(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))

# Query whether the corpus contains "hello world"
print(index.contains(tokenizer.encode("hello world")))

# Get the positions of all occurrences of "hello world" in the corpus
print(index.positions(tokenizer.encode("hello world")))
```

# Support

The best way to get support is to open an issue on this repo or post in #inductive-biases in the [EleutherAI Discord server](https://discord.gg/eleutherai). If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
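
A note on `count_next`: its raw counts convert directly into an empirical next-token distribution. A minimal sketch, assuming `count_next` returns one count per vocabulary id (as the optional `vocab` parameter in the Rust bindings suggests); `numpy` and the normalization step are illustrative additions, not part of the README:

```python
import numpy as np

# Empirical distribution P(token | "hello world") from successor counts.
counts = np.array(index.count_next(tokenizer.encode("hello world")))
probs = counts / counts.sum()  # assumes the query occurs at least once
print(probs.argmax())  # id of the most frequent successor token
```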
10 changes: 10 additions & 0 deletions environment.yml
@@ -0,0 +1,10 @@
name: test
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - numpy
  - pytest
  - hypothesis
  - maturin
47 changes: 32 additions & 15 deletions src/in_memory_index.rs
@@ -7,6 +7,7 @@ use std::io::Read;
use crate::table::SuffixTable;
use crate::util::transmute_slice;

+/// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
#[pyclass]
pub struct InMemoryIndex {
    table: SuffixTable,
@@ -15,21 +16,25 @@ pub struct InMemoryIndex {
#[pymethods]
impl InMemoryIndex {
    #[new]
-    fn new(_py: Python, tokens: Vec<u16>, verbose: bool) -> Self {
+    pub fn new(_py: Python, tokens: Vec<u16>, verbose: bool) -> Self {
        InMemoryIndex {
            table: SuffixTable::new(tokens, verbose),
        }
    }

    #[staticmethod]
-    fn from_pretrained(path: String) -> PyResult<Self> {
+    pub fn from_pretrained(path: String) -> PyResult<Self> {
        // TODO: handle errors here
        let table: SuffixTable = deserialize(&std::fs::read(path)?).unwrap();
        Ok(InMemoryIndex { table })
    }

    #[staticmethod]
-    fn from_token_file(path: String, verbose: bool, token_limit: Option<usize>) -> PyResult<Self> {
+    pub fn from_token_file(
+        path: String,
+        verbose: bool,
+        token_limit: Option<usize>,
+    ) -> PyResult<Self> {
        let mut buffer = Vec::new();
        let mut file = File::open(&path)?;

@@ -46,37 +51,49 @@ impl InMemoryIndex {
        })
    }

-    fn contains(&self, query: Vec<u16>) -> bool {
+    pub fn contains(&self, query: Vec<u16>) -> bool {
        self.table.contains(&query)
    }

-    fn count(&self, query: Vec<u16>) -> usize {
+    pub fn count(&self, query: Vec<u16>) -> usize {
        self.table.positions(&query).len()
    }

    pub fn positions(&self, query: Vec<u16>) -> Vec<u64> {
        self.table.positions(&query).to_vec()
    }

-    fn batch_next_token_counts(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
-        self.table.batch_next_token_counts(&queries, vocab)
+    pub fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
+        self.table.count_next(&query, vocab)
    }

-    fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
-        self.table.sample(&query, n, k)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+    pub fn batch_count_next(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
+        self.table.batch_count_next(&queries, vocab)
    }

-    fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
-        self.table.batch_sample(&query, n, k, num_samples)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+    pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
+        self.table
+            .sample(&query, n, k)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
    }

-    fn is_sorted(&self) -> bool {
+    pub fn batch_sample(
+        &self,
+        query: Vec<u16>,
+        n: usize,
+        k: usize,
+        num_samples: usize,
+    ) -> Result<Vec<Vec<u16>>, PyErr> {
+        self.table
+            .batch_sample(&query, n, k, num_samples)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
+    }
+
+    pub fn is_sorted(&self) -> bool {
        self.table.is_sorted()
    }

-    fn save(&self, path: String) -> PyResult<()> {
+    pub fn save(&self, path: String) -> PyResult<()> {
        // TODO: handle errors here
        let bytes = serialize(&self.table).unwrap();
        std::fs::write(&path, bytes)?;
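
From Python, the renamed methods line up one-to-one with the bindings above. A minimal usage sketch; the token values and save path are illustrative:

```python
from tokengrams import InMemoryIndex

index = InMemoryIndex([1, 2, 3, 1, 2, 4], False)  # (tokens, verbose)
assert index.is_sorted() and index.contains([1, 2])
print(index.count([1, 2]))       # 2: the bigram occurs at positions 0 and 3
print(index.positions([1, 2]))   # those positions, in suffix-array order
print(index.count_next([1, 2]))  # count of each possible next token
index.save("/tmp/tiny.idx")
index = InMemoryIndex.from_pretrained("/tmp/tiny.idx")
```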
10 changes: 4 additions & 6 deletions src/lib.rs
@@ -1,8 +1,6 @@
pub mod mmap_slice;
-pub mod table;
-pub mod util;

pub use mmap_slice::MmapSlice;
+pub use in_memory_index::InMemoryIndex;
+pub use memmap_index::MemmapIndex;
pub use table::SuffixTable;

/// Python bindings
@@ -11,8 +9,8 @@ use pyo3::prelude::*;
mod in_memory_index;
mod memmap_index;
mod par_quicksort;
-use in_memory_index::InMemoryIndex;
-use memmap_index::MemmapIndex;
+mod table;
+mod util;

#[pymodule]
fn tokengrams(_py: Python, m: &PyModule) -> PyResult<()> {
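
The net effect of the restructuring: `table` and `util` become private modules, while the index types are re-exported at the crate root. Assuming the truncated `#[pymodule]` body still registers both index classes (as the module declarations above suggest), the public Python surface stays simply:

```python
from tokengrams import InMemoryIndex, MemmapIndex
```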
