From 7ffc337f54e90ab3ffae8e543281689b0105cc92 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Thu, 4 Apr 2024 00:33:58 +0000
Subject: [PATCH 1/6] Add unbatched count_next to API; rename
 batch_next_token_counts to batch_count_next; add code examples to README.md

---
 README.md | 52 +++++++++++++++++++++++++++++++++------
 src/in_memory_index.rs | 27 +++++++++++---------
 src/lib.rs | 10 +++-----
 src/memmap_index.rs | 23 ++++++++++-------
 src/table.rs | 32 ++++++++++++------------
 tokengrams/tokengrams.pyi | 10 ++++++--
 6 files changed, 102 insertions(+), 52 deletions(-)

diff --git a/README.md b/README.md
index 0c6e1a8..4b0c3ff 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,17 @@
 # Tokengrams
-This library allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.
+Tokengrams allows you to efficiently compute $n$-gram statistics for pre-tokenized text corpora used to train large language models. It does this not by explicitly pre-computing the $n$-gram counts for fixed $n$, but by creating a [suffix array](https://en.wikipedia.org/wiki/Suffix_array) index which allows you to efficiently compute the count of an $n$-gram on the fly for any $n$.
 
 Our code also allows you to turn your suffix array index into an efficient $n$-gram language model, which can be used to generate text or compute the perplexity of a given text.
 
 The backend is written in Rust, and the Python bindings are generated using [PyO3](https://github.com/PyO3/pyo3).
 
 # Installation
-Currently you need to build and install from source using `maturin`. We plan to release wheels on PyPI soon.
+
+```bash
+pip install tokengrams
+```
+
+# Development
 
 ```bash
 pip install maturin
 maturin develop
 ```
 
 # Usage
+
+## Building an index
 ```python
 from tokengrams import MemmapIndex
 
 # Create a new index from an on-disk corpus called `document.bin` and save it to
-# `pile.idx`
+# `pile.idx`.
 index = MemmapIndex.build(
-    "/mnt/ssd-1/pile_preshuffled/standard/document.bin",
-    "/mnt/ssd-1/nora/pile.idx",
+    "/data/document.bin",
+    "/pile.idx",
+    verbose=True,
 )
 
-# Get the count of "hello world" in the corpus
+# Verify index correctness
+print(index.is_sorted())
+
+# Get the count of "hello world" in the corpus.
 from transformers import AutoTokenizer
 
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
 print(index.count(tokenizer.encode("hello world")))
 
 # You can now load the index from disk later using __init__
 index = MemmapIndex(
-    "/mnt/ssd-1/pile_preshuffled/standard/document.bin",
-    "/mnt/ssd-1/nora/pile.idx"
+    "/data/document.bin",
+    "/pile.idx"
 )
+```
+
+## Using an index
+
+```python
+# Count how often each token in the corpus succeeds "hello world".
+print(index.count_next(tokenizer.encode("hello world")))
+
+# Parallelize over queries
+print(index.batch_count_next(
+    [tokenizer.encode("hello world"), tokenizer.encode("hello universe")]
+))
+
+# Autoregressively sample 10 tokens using 5-gram language statistics. Initial
+# gram statistics are derived from the query, with lower-order gram statistics used
+# until the sequence contains at least 5 tokens.
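+# The returned sequence consists of the original query tokens followed by the
+# k newly sampled tokens.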
+print(index.sample(tokenizer.encode("hello world"), n=5, k=10))
+
+# Parallelize over sequence generations
+print(index.batch_sample(tokenizer.encode("hello world"), n=5, k=10, num_samples=20))
+
+# Query whether the corpus contains "hello world"
+print(index.contains(tokenizer.encode("hello world")))
+
+# Get the positions of all occurrences of "hello world" in the corpus
+print(index.positions(tokenizer.encode("hello world")))
+```
\ No newline at end of file
diff --git a/src/in_memory_index.rs b/src/in_memory_index.rs
index ee94aec..c4d6282 100644
--- a/src/in_memory_index.rs
+++ b/src/in_memory_index.rs
@@ -7,6 +7,7 @@ use std::io::Read;
 use crate::table::SuffixTable;
 use crate::util::transmute_slice;
 
+/// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
 #[pyclass]
 pub struct InMemoryIndex {
     table: SuffixTable,
@@ -15,21 +16,21 @@ pub struct InMemoryIndex {
 #[pymethods]
 impl InMemoryIndex {
     #[new]
-    fn new(_py: Python, tokens: Vec<u16>, verbose: bool) -> Self {
+    pub fn new(_py: Python, tokens: Vec<u16>, verbose: bool) -> Self {
         InMemoryIndex {
             table: SuffixTable::new(tokens, verbose),
         }
     }
 
     #[staticmethod]
-    fn from_pretrained(path: String) -> PyResult<Self> {
+    pub fn from_pretrained(path: String) -> PyResult<Self> {
         // TODO: handle errors here
         let table: SuffixTable = deserialize(&std::fs::read(path)?).unwrap();
         Ok(InMemoryIndex { table })
     }
 
     #[staticmethod]
-    fn from_token_file(path: String, verbose: bool, token_limit: Option<usize>) -> PyResult<Self> {
+    pub fn from_token_file(path: String, verbose: bool, token_limit: Option<usize>) -> PyResult<Self> {
         let mut buffer = Vec::new();
         let mut file = File::open(&path)?;
 
@@ -46,11 +47,11 @@ impl InMemoryIndex {
         })
     }
 
-    fn contains(&self, query: Vec<u16>) -> bool {
+    pub fn contains(&self, query: Vec<u16>) -> bool {
         self.table.contains(&query)
     }
 
-    fn count(&self, query: Vec<u16>) -> usize {
+    pub fn count(&self, query: Vec<u16>) -> usize {
         self.table.positions(&query).len()
     }
 
@@ -58,25 +59,29 @@ impl InMemoryIndex {
         self.table.positions(&query).to_vec()
     }
 
-    fn batch_next_token_counts(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
-        self.table.batch_next_token_counts(&queries, vocab)
+    pub fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
+        self.table.count_next(&query, vocab)
     }
 
-    fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
+    pub fn batch_count_next(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
+        self.table.batch_count_next(&queries, vocab)
+    }
+
+    pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
         self.table.sample(&query, n, k)
             .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
+    pub fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
         self.table.batch_sample(&query, n, k, num_samples)
             .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    fn is_sorted(&self) -> bool {
+    pub fn is_sorted(&self) -> bool {
         self.table.is_sorted()
     }
 
-    fn save(&self, path: String) -> PyResult<()> {
+    pub fn save(&self, path: String) -> PyResult<()> {
         // TODO: handle errors here
         let bytes = serialize(&self.table).unwrap();
         std::fs::write(&path, bytes)?;
diff --git a/src/lib.rs b/src/lib.rs
index d8c97e7..6a34c40 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,18 +1,16 @@
 pub mod mmap_slice;
-pub mod table;
-pub mod util;
-
-pub use mmap_slice::MmapSlice;
+pub use in_memory_index::InMemoryIndex;
+pub use memmap_index::MemmapIndex;
 pub use table::SuffixTable;
 
 /// Python bindings
 use pyo3::prelude::*;
 
+mod table;
+mod util;
 mod in_memory_index;
 mod memmap_index;
 mod par_quicksort;
-use in_memory_index::InMemoryIndex;
-use memmap_index::MemmapIndex;
 
 #[pymodule]
 fn tokengrams(_py: Python, m: &PyModule) -> PyResult<()> {
diff --git a/src/memmap_index.rs b/src/memmap_index.rs
index ab0a011..500aaf5 100644
--- a/src/memmap_index.rs
+++ b/src/memmap_index.rs
@@ -7,6 +7,7 @@ use crate::mmap_slice::{MmapSlice, MmapSliceMut};
 use crate::table::SuffixTable;
 use crate::par_quicksort::par_sort_unstable_by_key;
 
+/// A memmap index exposes suffix table functionality over text corpora too large to fit in memory.
 #[pyclass]
 pub struct MemmapIndex {
     table: SuffixTable<MmapSlice<u16>, MmapSlice<u64>>,
@@ -15,7 +16,7 @@ pub struct MemmapIndex {
 #[pymethods]
 impl MemmapIndex {
     #[new]
-    fn new(_py: Python, text_path: String, table_path: String) -> PyResult<Self> {
+    pub fn new(_py: Python, text_path: String, table_path: String) -> PyResult<Self> {
         let text_file = File::open(&text_path)?;
         let table_file = File::open(&table_path)?;
 
@@ -28,7 +29,7 @@ impl MemmapIndex {
     }
 
     #[staticmethod]
-    fn build(text_path: String, table_path: String, verbose: bool) -> PyResult<Self> {
+    pub fn build(text_path: String, table_path: String, verbose: bool) -> PyResult<Self> {
         // Memory map the text as read-only
         let text_mmap = MmapSlice::new(&File::open(&text_path)?)?;
 
@@ -76,11 +77,11 @@ impl MemmapIndex {
         })
     }
 
-    fn contains(&self, query: Vec<u16>) -> bool {
+    pub fn contains(&self, query: Vec<u16>) -> bool {
         self.table.contains(&query)
     }
 
-    fn count(&self, query: Vec<u16>) -> usize {
+    pub fn count(&self, query: Vec<u16>) -> usize {
         self.table.positions(&query).len()
     }
 
@@ -88,21 +89,25 @@ impl MemmapIndex {
         self.table.positions(&query).to_vec()
     }
 
-    fn batch_next_token_counts(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
-        self.table.batch_next_token_counts(&queries, vocab)
+    pub fn count_next(&self, query: Vec<u16>, vocab: Option<u16>) -> Vec<usize> {
+        self.table.count_next(&query, vocab)
     }
 
-    fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
+    pub fn batch_count_next(&self, queries: Vec<Vec<u16>>, vocab: Option<u16>) -> Vec<Vec<usize>> {
+        self.table.batch_count_next(&queries, vocab)
+    }
+
+    pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
         self.table.sample(&query, n, k)
             .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
+    pub fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
         self.table.batch_sample(&query, n, k, num_samples)
             .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    fn is_sorted(&self) -> bool {
+    pub fn is_sorted(&self) -> bool {
         self.table.is_sorted()
     }
 }
diff --git a/src/table.rs b/src/table.rs
index c560033..e759315 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -8,7 +8,8 @@ use rand::thread_rng;
 use anyhow::Result;
 use crate::par_quicksort::par_sort_unstable_by_key;
 
-/// A suffix table is a sequence of lexicographically sorted suffixes.
+/// A suffix table is a sequence of lexicographically sorted suffixes.
+/// The table supports n-gram statistics computation and language modeling over text corpora.
 #[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
 pub struct SuffixTable<T = Box<[u16]>, U = Box<[u64]>> {
     text: T,
@@ -213,9 +214,8 @@ where
         }
     }
 
-    /// Returns an unordered list of counts of token values that succeed `query`.
-    /// Counts all tokens if query is empty.
-    fn next_token_counts(&self, query: &[u16], vocab: Option<u16>) -> Vec<usize> {
+    /// Returns an unordered list of token counts succeeding `query`. Counts all tokens if query is empty.
+    pub fn count_next(&self, query: &[u16], vocab: Option<u16>) -> Vec<usize> {
         let mut counts: Vec<usize> = vec![0usize; vocab.unwrap_or(u16::MAX) as usize + 1];
         let mut suffixed_query = query.to_vec();
         let (range_start, range_end) = self.boundaries(query);
@@ -231,10 +231,10 @@ where
     }
 
     /// Count the occurrences of each token that directly follows each query sequence.
-    pub fn batch_next_token_counts(&self, queries: &[Vec<u16>], vocab: Option<u16>) -> Vec<Vec<usize>> {
+    pub fn batch_count_next(&self, queries: &[Vec<u16>], vocab: Option<u16>) -> Vec<Vec<usize>> {
         queries.into_par_iter()
             .map(|query| {
-                self.next_token_counts(query, vocab)
+                self.count_next(query, vocab)
             })
             .collect()
     }
@@ -250,7 +250,7 @@ where
         let start = sequence.len().saturating_sub(n as usize - 1);
         let prev = &sequence[start..];
 
-        let counts: Vec<usize> = self.next_token_counts(prev, None);
+        let counts: Vec<usize> = self.count_next(prev, None);
         let dist = WeightedIndex::new(&counts)?;
         let sampled_index = dist.sample(&mut rng);
@@ -321,38 +321,38 @@ mod tests {
     }
 
     #[test]
-    fn next_token_counts_exists() {
+    fn count_next_exists() {
         let sa = sais("aaab");
 
         let query = utf16!("a");
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
 
-        assert_eq!(2, sa.next_token_counts(query, Option::None)[a_index]);
-        assert_eq!(1, sa.next_token_counts(query, Option::None)[b_index]);
+        assert_eq!(2, sa.count_next(query, Option::None)[a_index]);
+        assert_eq!(1, sa.count_next(query, Option::None)[b_index]);
     }
 
     #[test]
-    fn next_token_counts_empty_query() {
+    fn count_next_empty_query() {
         let sa = sais("aaab");
 
         let query = utf16!("");
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
 
-        assert_eq!(3, sa.next_token_counts(query, Option::None)[a_index]);
-        assert_eq!(1, sa.next_token_counts(query, Option::None)[b_index]);
+        assert_eq!(3, sa.count_next(query, Option::None)[a_index]);
+        assert_eq!(1, sa.count_next(query, Option::None)[b_index]);
     }
 
     #[test]
-    fn batch_next_token_counts_exists() {
+    fn batch_count_next_exists() {
         let sa = sais("aaab");
 
         let queries: Vec<Vec<u16>> = vec![vec![utf16!("a")[0]; 1]; 10_000];
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
 
-        assert_eq!(2, sa.batch_next_token_counts(&queries, Option::None)[0][a_index]);
-        assert_eq!(1, sa.batch_next_token_counts(&queries, Option::None)[0][b_index]);
+        assert_eq!(2, sa.batch_count_next(&queries, Option::None)[0][a_index]);
+        assert_eq!(1, sa.batch_count_next(&queries, Option::None)[0][b_index]);
     }
 }
\ No newline at end of file
diff --git a/tokengrams/tokengrams.pyi b/tokengrams/tokengrams.pyi
index 6258d10..49d2e01 100644
--- a/tokengrams/tokengrams.pyi
+++ b/tokengrams/tokengrams.pyi
@@ -17,7 +17,10 @@ class InMemoryIndex:
     def positions(self, query: list[int]) -> list[int]:
         """Returns an unordered list of positions where `query` starts in `text`."""
 
-    def batch_next_token_counts(self, queries: list[list[int]], vocab: int | None) -> list[list[int]]:
+    def count_next(self, query: list[int], vocab: int | None) -> list[int]:
+        """Count the occurrences of each token directly following `query`."""
+
+    def batch_count_next(self, queries: list[list[int]], vocab: int | None) -> list[list[int]]:
         """Count the occurrences of each token that directly follows each sequence in `queries`."""
 
     def sample(self, query: list[int], n: int, k: int) -> list[int]:
@@ -53,7 +56,10 @@ class MemmapIndex:
     def positions(self, query: list[int]) -> list[int]:
         """Returns an unordered list of positions where `query` starts in `text`."""
 
-    def batch_next_token_counts(self, queries: list[list[int]], vocab: int | None) -> list[list[int]]:
+    def count_next(self, query: list[int], vocab: int | None) -> list[int]:
+        """Count the occurrences of each token directly following `query`."""
+
+    def batch_count_next(self, queries: list[list[int]], vocab: int | None) -> list[list[int]]:
         """Count the occurrences of each token that directly follows each sequence in `queries`."""
 
     def sample(self, query: list[int], n: int, k: int) -> list[int]:

From c6a9fa95efb3ead0d169ddfc14356b34d7237a6f Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Thu, 4 Apr 2024 02:57:55 +0000
Subject: [PATCH 2/6] run clippy

---
 src/par_quicksort.rs | 8 ++++----
 src/table.rs | 25 ++++++++++++-------------
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/src/par_quicksort.rs b/src/par_quicksort.rs
index 3337522..51ade13 100644
--- a/src/par_quicksort.rs
+++ b/src/par_quicksort.rs
@@ -9,7 +9,6 @@ use std::marker::PhantomData;
 use std::mem::{self, MaybeUninit};
 use std::ptr;
 use indicatif::{ProgressBar, ProgressStyle};
-use rayon_core;
 
 pub fn par_sort_unstable_by_key<T, F, K>(data: &mut [T], f: F, verbose: bool)
 where
@@ -870,12 +869,13 @@ where
     // Limit the number of imbalanced partitions to `floor(log2(len)) + 1`.
     let limit = usize::BITS - v.len().leading_zeros();
     let pbar = if verbose {
-        ProgressBar::new((v.len() as f64 / 2000.0).ceil() as u64)
+        let p = ProgressBar::new((v.len() as f64 / 2000.0).ceil() as u64);
+        p.set_style(ProgressStyle::with_template("{elapsed} elapsed (estimated duration {duration}) {bar:80}")
+            .unwrap());
+        p
     } else {
         ProgressBar::hidden()
     };
-    pbar.set_style(ProgressStyle::with_template("{elapsed} elapsed (estimated duration {duration}) {bar:80}")
-        .unwrap());
     recurse(v, &is_less, None, limit, &pbar);
     pbar.finish();
 }
diff --git a/src/table.rs b/src/table.rs
index e759315..5e104df 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -102,7 +102,7 @@ where
     /// ```
     #[allow(dead_code)]
     pub fn contains(&self, query: &[u16]) -> bool {
-        query.len() > 0
+        !query.is_empty()
             && self
                 .table
                 .binary_search_by(|&sufi| {
@@ -141,8 +141,8 @@ where
     pub fn positions(&self, query: &[u16]) -> &[u64] {
         // We can quickly decide whether the query won't match at all if
         // it's outside the range of suffixes.
-        if self.text.len() == 0
-            || query.len() == 0
+        if self.text.is_empty()
+            || query.is_empty()
             || (query < self.suffix(0) && !self.suffix(0).starts_with(query))
             || query > self.suffix(self.len() - 1)
         {
@@ -173,8 +173,8 @@ where
 
     /// Returns the start and end indices of query matches in text
     fn boundaries(&self, query: &[u16]) -> (usize, usize) {
-        if self.text.len() == 0
-            || query.len() == 0
+        if self.text.is_empty()
+            || query.is_empty()
             || (query < self.suffix(0) && !self.suffix(0).starts_with(query))
             || query > self.suffix(self.len() - 1)
         {
@@ -193,9 +193,9 @@ where
 
     /// Returns an unordered list of positions where `query` starts in `text`, limiting the search to a
     /// specified range of the suffix table.
     fn range_positions(&self, query: &[u16], range_start: usize, range_end: usize) -> &[u64] {
-        if self.text.len() == 0
-            || query.len() == 0
-            || (query < self.suffix(0 + range_start) && !self.suffix(0 + range_start).starts_with(query))
+        if self.text.is_empty()
+            || query.is_empty()
+            || (query < self.suffix(range_start) && !self.suffix(range_start).starts_with(query))
             || query > self.suffix(std::cmp::max(0, range_end - 1))
         {
             return &[];
@@ -220,11 +220,10 @@ where
         let mut suffixed_query = query.to_vec();
         let (range_start, range_end) = self.boundaries(query);
 
-        for i in 0..counts.len() {
+        for (i, count) in counts.iter_mut().enumerate() {
             suffixed_query.push(i as u16);
-
             let positions = self.range_positions(&suffixed_query, range_start, range_end);
-            counts[i] = positions.len();
+            *count = positions.len();
             suffixed_query.pop();
         }
         counts
@@ -246,7 +246,7 @@ where
 
         for _ in 0..k {
             // look at the previous (n - 1) characters to predict the n-gram completion
-            let start = sequence.len().saturating_sub(n as usize - 1);
+            let start = sequence.len().saturating_sub(n - 1);
             let prev = &sequence[start..];
@@ -273,7 +272,7 @@ where
     /// Checks if the suffix table is lexicographically sorted. This is always true for valid suffix tables.
     pub fn is_sorted(&self) -> bool {
         self.table.windows(2).all(|pair| {
-            &self.text[pair[0] as usize..] <= &self.text[pair[1] as usize..]
+            self.text[pair[0] as usize..] <= self.text[pair[1] as usize..]
         })
     }
 }

From 54d0a5cb78767190f783c5988ba6756dcf825eef Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Thu, 4 Apr 2024 02:58:55 +0000
Subject: [PATCH 3/6] rustfmt lib

---
 src/in_memory_index.rs | 24 ++++++++++++-----
 src/lib.rs | 4 +--
 src/memmap_index.rs | 53 +++++++++++++++++++++++++-------------
 src/par_quicksort.rs | 19 ++++++++++----
 src/table.rs | 58 +++++++++++++++++++++++-------------------
 5 files changed, 102 insertions(+), 56 deletions(-)

diff --git a/src/in_memory_index.rs b/src/in_memory_index.rs
index c4d6282..819c4c3 100644
--- a/src/in_memory_index.rs
+++ b/src/in_memory_index.rs
@@ -30,7 +30,11 @@ impl InMemoryIndex {
     }
 
     #[staticmethod]
-    pub fn from_token_file(path: String, verbose: bool, token_limit: Option<usize>) -> PyResult<Self> {
+    pub fn from_token_file(
+        path: String,
+        verbose: bool,
+        token_limit: Option<usize>,
+    ) -> PyResult<Self> {
         let mut buffer = Vec::new();
         let mut file = File::open(&path)?;
 
@@ -68,13 +72,21 @@ impl InMemoryIndex {
     }
 
     pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
-        self.table.sample(&query, n, k)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+        self.table
+            .sample(&query, n, k)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    pub fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
-        self.table.batch_sample(&query, n, k, num_samples)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+    pub fn batch_sample(
+        &self,
+        query: Vec<u16>,
+        n: usize,
+        k: usize,
+        num_samples: usize,
+    ) -> Result<Vec<Vec<u16>>, PyErr> {
+        self.table
+            .batch_sample(&query, n, k, num_samples)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
     pub fn is_sorted(&self) -> bool {
diff --git a/src/lib.rs b/src/lib.rs
index 6a34c40..9a49b47 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,11 +6,11 @@ pub use table::SuffixTable;
 /// Python bindings
 use pyo3::prelude::*;
 
-mod table;
-mod util;
 mod in_memory_index;
 mod memmap_index;
 mod par_quicksort;
+mod table;
+mod util;
 
 #[pymodule]
 fn tokengrams(_py: Python, m: &PyModule) -> PyResult<()> {
diff --git a/src/memmap_index.rs b/src/memmap_index.rs
index 500aaf5..d83bbfa 100644
--- a/src/memmap_index.rs
+++ b/src/memmap_index.rs
@@ -4,8 +4,8 @@ use std::fs::{File, OpenOptions};
 use std::time::Instant;
 
 use crate::mmap_slice::{MmapSlice, MmapSliceMut};
-use crate::table::SuffixTable;
 use crate::par_quicksort::par_sort_unstable_by_key;
+use crate::table::SuffixTable;
 
 /// A memmap index exposes suffix table functionality over text corpora too large to fit in memory.
 #[pyclass]
@@ -48,7 +48,10 @@ impl MemmapIndex {
         let start = Instant::now();
         let mut table_mmap = MmapSliceMut::<u64>::new(&table_file)?;
-        table_mmap.iter_mut().enumerate().for_each(|(i, x)| *x = i as u64);
+        table_mmap
+            .iter_mut()
+            .enumerate()
+            .for_each(|(i, x)| *x = i as u64);
 
         assert_eq!(table_mmap.len(), text_mmap.len());
         println!("Time elapsed: {:?}", start.elapsed());
@@ -58,16 +61,24 @@ impl MemmapIndex {
         // available as well. These magic numbers were tuned on a server with 48 physical cores.
         // Empirically we start getting stack overflows between 5B and 10B tokens when using the
         // default stack size of 2MB. We scale the stack size as log2(n) * 8MB to avoid this.
-        let scale = (text_mmap.len() as f64) / 5e9; // 5B tokens
-        let stack_size = scale.log2().max(1.0) * 8e6; // 8MB
-
-        rayon::ThreadPoolBuilder::new().stack_size(stack_size as usize).build().unwrap().install(|| {
-            // Sort the indices by the suffixes they point to.
-            // The unstable algorithm is critical for avoiding out-of-memory errors, since it does
-            // not allocate any more memory than the input and output slices.
-            println!("Sorting indices...");
-            par_sort_unstable_by_key(table_mmap.as_slice_mut(), |&i| &text_mmap[i as usize..], verbose);
-        });
+        let scale = (text_mmap.len() as f64) / 5e9; // 5B tokens
+        let stack_size = scale.log2().max(1.0) * 8e6; // 8MB
+
+        rayon::ThreadPoolBuilder::new()
+            .stack_size(stack_size as usize)
+            .build()
+            .unwrap()
+            .install(|| {
+                // Sort the indices by the suffixes they point to.
+                // The unstable algorithm is critical for avoiding out-of-memory errors, since it does
+                // not allocate any more memory than the input and output slices.
+                println!("Sorting indices...");
+                par_sort_unstable_by_key(
+                    table_mmap.as_slice_mut(),
+                    |&i| &text_mmap[i as usize..],
+                    verbose,
+                );
+            });
         println!("Time elapsed: {:?}", start.elapsed());
 
         // Re-open the table as read-only
@@ -98,13 +109,21 @@ impl MemmapIndex {
     }
 
     pub fn sample(&self, query: Vec<u16>, n: usize, k: usize) -> Result<Vec<u16>, PyErr> {
-        self.table.sample(&query, n, k)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+        self.table
+            .sample(&query, n, k)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
-    pub fn batch_sample(&self, query: Vec<u16>, n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>, PyErr> {
-        self.table.batch_sample(&query, n, k, num_samples)
-            .map_err(|error| PyValueError::new_err(error.to_string()))
+    pub fn batch_sample(
+        &self,
+        query: Vec<u16>,
+        n: usize,
+        k: usize,
+        num_samples: usize,
+    ) -> Result<Vec<Vec<u16>>, PyErr> {
+        self.table
+            .batch_sample(&query, n, k, num_samples)
+            .map_err(|error| PyValueError::new_err(error.to_string()))
     }
 
     pub fn is_sorted(&self) -> bool {
         self.table.is_sorted()
     }
 }
diff --git a/src/par_quicksort.rs b/src/par_quicksort.rs
index 51ade13..5b6e12d 100644
--- a/src/par_quicksort.rs
+++ b/src/par_quicksort.rs
@@ -4,11 +4,11 @@
 //! The only difference from the original is that calls to `recurse` are executed in parallel using
 //! `rayon_core::join`.
+use indicatif::{ProgressBar, ProgressStyle};
 use std::cmp;
 use std::marker::PhantomData;
 use std::mem::{self, MaybeUninit};
 use std::ptr;
-use indicatif::{ProgressBar, ProgressStyle};
 
 pub fn par_sort_unstable_by_key<T, F, K>(data: &mut [T], f: F, verbose: bool)
 where
@@ -754,8 +754,13 @@ where
 ///
 /// `limit` is the number of allowed imbalanced partitions before switching to `heapsort`. If zero,
 /// this function will immediately switch to heapsort.
-fn recurse<'a, T, F>(mut v: &'a mut [T], is_less: &F, mut pred: Option<&'a mut T>, mut limit: u32, pbar: &ProgressBar)
-where
+fn recurse<'a, T, F>(
+    mut v: &'a mut [T],
+    is_less: &F,
+    mut pred: Option<&'a mut T>,
+    mut limit: u32,
+    pbar: &ProgressBar,
+) where
     T: Send,
     F: Fn(&T, &T) -> bool + Sync,
 {
@@ -870,8 +875,12 @@ where
     let limit = usize::BITS - v.len().leading_zeros();
     let pbar = if verbose {
         let p = ProgressBar::new((v.len() as f64 / 2000.0).ceil() as u64);
-        p.set_style(ProgressStyle::with_template("{elapsed} elapsed (estimated duration {duration}) {bar:80}")
-            .unwrap());
+        p.set_style(
+            ProgressStyle::with_template(
+                "{elapsed} elapsed (estimated duration {duration}) {bar:80}",
+            )
+            .unwrap(),
+        );
         p
     } else {
         ProgressBar::hidden()
diff --git a/src/table.rs b/src/table.rs
index 5e104df..4ccb2ed 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,14 +1,14 @@
 extern crate utf16_literal;
 
+use crate::par_quicksort::par_sort_unstable_by_key;
+use anyhow::Result;
+use rand::distributions::{Distribution, WeightedIndex};
+use rand::thread_rng;
 use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 use std::{fmt, ops::Deref, u64};
-use rand::distributions::{Distribution, WeightedIndex};
-use rand::thread_rng;
-use anyhow::Result;
-use crate::par_quicksort::par_sort_unstable_by_key;
 
-/// A suffix table is a sequence of lexicographically sorted suffixes.
+/// A suffix table is a sequence of lexicographically sorted suffixes.
 /// The table supports n-gram statistics computation and language modeling over text corpora.
 #[derive(Clone, Deserialize, Eq, PartialEq, Serialize)]
 pub struct SuffixTable<T = Box<[u16]>, U = Box<[u64]>> {
     text: T,
@@ -201,12 +201,14 @@ where
             return &[];
         }
 
-        let start = binary_search(&self.table[range_start..range_end], |&sufi| query <= &self.text[sufi as usize..]);
+        let start = binary_search(&self.table[range_start..range_end], |&sufi| {
+            query <= &self.text[sufi as usize..]
+        });
 
         let end = start
             + binary_search(&self.table[range_start + start..range_end], |&sufi| {
                 !self.text[sufi as usize..].starts_with(query)
             });
-
+
         if start > end {
             &[]
         } else {
@@ -231,13 +233,12 @@ where
     }
 
     /// Count the occurrences of each token that directly follows each query sequence.
     pub fn batch_count_next(&self, queries: &[Vec<u16>], vocab: Option<u16>) -> Vec<Vec<usize>> {
-        queries.into_par_iter()
-            .map(|query| {
-                self.count_next(query, vocab)
-            })
+        queries
+            .into_par_iter()
+            .map(|query| self.count_next(query, vocab))
             .collect()
     }
 
-    /// Autoregressively sample k characters from a conditional distribution based
+    /// Autoregressively sample k characters from a conditional distribution based
     /// on the previous (n - 1) characters (n-gram prefix) in the sequence.
     pub fn sample(&self, query: &[u16], n: usize, k: usize) -> Result<Vec<u16>> {
         let mut rng = thread_rng();
         let mut sequence = Vec::from(query);
 
@@ -248,7 +249,7 @@ where
             // look at the previous (n - 1) characters to predict the n-gram completion
             let start = sequence.len().saturating_sub(n - 1);
             let prev = &sequence[start..];
-
+
             let counts: Vec<usize> = self.count_next(prev, None);
             let dist = WeightedIndex::new(&counts)?;
             let sampled_index = dist.sample(&mut rng);
@@ -259,21 +260,26 @@ where
         Ok(sequence)
     }
 
-    /// Autoregressively samples num_samples of k characters each from conditional distributions based
+    /// Autoregressively samples num_samples of k characters each from conditional distributions based
     /// on the previous (n - 1) characters (n-gram prefix) in the sequence.
-    pub fn batch_sample(&self, query: &[u16], n: usize, k: usize, num_samples: usize) -> Result<Vec<Vec<u16>>> {
-        (0..num_samples).into_par_iter()
-            .map(|_| {
-                self.sample(query, n, k)
-            })
+    pub fn batch_sample(
+        &self,
+        query: &[u16],
+        n: usize,
+        k: usize,
+        num_samples: usize,
+    ) -> Result<Vec<Vec<u16>>> {
+        (0..num_samples)
+            .into_par_iter()
+            .map(|_| self.sample(query, n, k))
             .collect()
     }
 
     /// Checks if the suffix table is lexicographically sorted. This is always true for valid suffix tables.
     pub fn is_sorted(&self) -> bool {
-        self.table.windows(2).all(|pair| {
-            self.text[pair[0] as usize..] <= self.text[pair[1] as usize..]
-        })
+        self.table
+            .windows(2)
+            .all(|pair| self.text[pair[0] as usize..] <= self.text[pair[1] as usize..])
     }
 }
@@ -322,7 +328,7 @@ mod tests {
     #[test]
     fn count_next_exists() {
         let sa = sais("aaab");
-
+
         let query = utf16!("a");
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
@@ -334,7 +340,7 @@ mod tests {
     #[test]
     fn count_next_empty_query() {
         let sa = sais("aaab");
-
+
         let query = utf16!("");
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
@@ -346,7 +352,7 @@ mod tests {
     #[test]
     fn batch_count_next_exists() {
         let sa = sais("aaab");
-
+
         let queries: Vec<Vec<u16>> = vec![vec![utf16!("a")[0]; 1]; 10_000];
         let a_index = utf16!("a")[0] as usize;
         let b_index = utf16!("b")[0] as usize;
@@ -354,4 +360,4 @@ mod tests {
         assert_eq!(2, sa.batch_count_next(&queries, Option::None)[0][a_index]);
         assert_eq!(1, sa.batch_count_next(&queries, Option::None)[0][b_index]);
     }
-}
\ No newline at end of file
+}

From 05b20dd708168581a708d7ea982c6d53a7286189 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Thu, 4 Apr 2024 03:04:38 +0000
Subject: [PATCH 4/6] bump minor version

---
 Cargo.lock | 2 +-
 Cargo.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 419109c..52f1339 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -474,7 +474,7 @@ checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae"
 
 [[package]]
 name = "tokengrams"
-version = "0.2.0"
+version = "0.3.0"
 dependencies = [
  "anyhow",
  "bincode",
diff --git a/Cargo.toml b/Cargo.toml
index 82e56fb..a0dfc39 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "tokengrams"
-version = "0.2.0"
+version = "0.3.0"
 edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

From 676dc3f1f6f82bf2293e42090d094efdcacb4319 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Thu, 4 Apr 2024 03:16:45 +0000
Subject: [PATCH 5/6] Add python tests to CI

---
 .github/workflows/python-package-conda.yml | 27 --------------------
 .github/workflows/python.yml | 29 ++++++++++++++++++++++
 environment.yml | 10 ++++++++
 tokengrams/__init__.py | 3 +++
tokengrams/tests/test_gram_index.py | 8 +++--- tokengrams/tokengrams.pyi | 4 +-- 6 files changed, 48 insertions(+), 33 deletions(-) delete mode 100644 .github/workflows/python-package-conda.yml create mode 100644 .github/workflows/python.yml create mode 100644 environment.yml diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml deleted file mode 100644 index 708b24e..0000000 --- a/.github/workflows/python-package-conda.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Python Package using Conda - -on: [push] - -jobs: - build-linux: - runs-on: ubuntu-latest - strategy: - max-parallel: 5 - - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: '3.10' - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Install dependencies - run: | - conda install pytest hypothesis numpy maturin - - name: Test with pytest - run: | - maturin build - pytest diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml new file mode 100644 index 0000000..7427842 --- /dev/null +++ b/.github/workflows/python.yml @@ -0,0 +1,29 @@ +name: Python Package using Conda + +on: [push] + +jobs: + build-linux: + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + + steps: + - uses: actions/checkout@v3 + - name: Set up Conda + uses: conda-incubator/setup-miniconda@v3 + with: + python-version: "3.10" + miniforge-version: latest + use-mamba: true + mamba-version: "*" + - name: Test Python + env: + PYTHONPATH: /home/runner/work/tokengrams/tokengrams + shell: bash -l {0} + run: | + mamba install -c conda-forge numpy pytest hypothesis maturin + maturin develop + maturin build + python -m pip install --user ./target/wheels/tokengrams*.whl + pytest \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..9f62c6a --- /dev/null +++ b/environment.yml @@ -0,0 +1,10 @@ +name: test +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - numpy + - pytest + - hypothesis + - maturin \ No newline at end of file diff --git a/tokengrams/__init__.py b/tokengrams/__init__.py index 74ee74b..56dcc3e 100644 --- a/tokengrams/__init__.py +++ b/tokengrams/__init__.py @@ -1,3 +1,6 @@ +import os +print(os.getcwd()) + from .tokengrams import ( InMemoryIndex, MemmapIndex, diff --git a/tokengrams/tests/test_gram_index.py b/tokengrams/tests/test_gram_index.py index ba4d38d..68e675a 100644 --- a/tokengrams/tests/test_gram_index.py +++ b/tokengrams/tests/test_gram_index.py @@ -28,7 +28,7 @@ def check_gram_index(index: InMemoryIndex | MemmapIndex, tokens: list[int]): ) def test_gram_index(tokens: list[int]): # Construct index - index = InMemoryIndex(tokens) + index = InMemoryIndex(tokens, False) check_gram_index(index, tokens) # Save to disk and check that we can load it back @@ -36,11 +36,11 @@ def test_gram_index(tokens: list[int]): memmap = np.memmap(f, dtype=np.uint16, mode="w+", shape=(len(tokens),)) memmap[:] = tokens - index = InMemoryIndex.from_token_file(f.name, None) + index = InMemoryIndex.from_token_file(f.name, False, None) check_gram_index(index, tokens) with NamedTemporaryFile() as idx: - index = MemmapIndex.build(f.name, idx.name) + index = MemmapIndex.build(f.name, idx.name, False) check_gram_index(index, tokens) index = MemmapIndex(f.name, idx.name) @@ -48,5 +48,5 @@ def test_gram_index(tokens: list[int]): # Now check limited 
token loading
     for limit in range(1, len(tokens) + 1):
-        index = InMemoryIndex.from_token_file(f.name, limit)
+        index = InMemoryIndex.from_token_file(f.name, False, limit)
         check_gram_index(index, tokens[:limit])
diff --git a/tokengrams/tokengrams.pyi b/tokengrams/tokengrams.pyi
index 49d2e01..74c0f29 100644
--- a/tokengrams/tokengrams.pyi
+++ b/tokengrams/tokengrams.pyi
@@ -1,11 +1,11 @@
 class InMemoryIndex:
     """An n-gram index."""
 
-    def __init__(self, tokens: list[int]) -> None:
+    def __init__(self, tokens: list[int], verbose: bool) -> None:
         ...
 
     @staticmethod
-    def from_token_file(path: str, token_limit: int | None, verbose: bool) -> "InMemoryIndex":
+    def from_token_file(path: str, verbose: bool, token_limit: int | None) -> "InMemoryIndex":
         """Construct an `InMemoryIndex` from a file containing raw little-endian tokens."""
 
     def contains(self, query: list[int]) -> bool:

From 025a03919db3d0cc85d82482b4ca89343b2b67ba Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Fri, 5 Apr 2024 14:16:58 +1100
Subject: [PATCH 6/6] add support section to README

---
 README.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 4b0c3ff..ac04da7 100644
--- a/README.md
+++ b/README.md
@@ -71,4 +71,8 @@ print(index.contains(tokenizer.encode("hello world")))
 
 # Get the positions of all occurrences of "hello world" in the corpus
 print(index.positions(tokenizer.encode("hello world")))
-```
\ No newline at end of file
+```
+
+# Support
+
+The best way to get support is to open an issue on this repo or post in #inductive-biases in the [EleutherAI Discord server](https://discord.gg/eleutherai). If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
\ No newline at end of file
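
Putting the series together, the snippet below is a minimal, self-contained sketch of the renamed Python API. The toy corpus, the `vocab` bound of 3, and the expected counts are illustrative assumptions for this sketch; the method names and signatures follow the `tokengrams.pyi` stubs in the patches above.

```python
from tokengrams import InMemoryIndex

tokens = [1, 2, 1, 2, 1, 2, 1]        # toy pre-tokenized corpus (assumed data)
index = InMemoryIndex(tokens, False)  # verbose=False

assert index.is_sorted()
assert index.contains([1, 2])
assert index.count([1, 2]) == 3

# count_next returns one count per token id in [0, vocab]. The corpus has four
# 1s, and the first three are each followed by a 2, so counts[2] == 3.
counts = index.count_next([1], 3)
assert counts[2] == 3

# batch_count_next runs the same computation over many queries in parallel.
batch = index.batch_count_next([[1], [2]], 3)
assert batch[0][2] == 3 and batch[1][1] == 3

# sample() returns the query followed by k tokens drawn from the statistics of
# the previous (n - 1) tokens; this corpus makes the bigram continuation
# deterministic: 2 -> 1 -> 2 -> ...
assert index.sample([2], n=2, k=4) == [2, 1, 2, 1, 2]
```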