diff --git a/outlines/text/fsm.py b/outlines/text/fsm.py new file mode 100644 index 000000000..62d353904 --- /dev/null +++ b/outlines/text/fsm.py @@ -0,0 +1,532 @@ +from itertools import chain +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Set, Tuple + +import numba +from interegular.fsm import FSM, Alphabet, anything_else +from joblib import Parallel, delayed +from numba.experimental import structref +from numba.typed.typedobjectutils import _nonoptional + +if TYPE_CHECKING: + from outlines.models.tokenizer import Tokenizer + + +class BetterAlphabet(Alphabet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.has_anything_else = anything_else in self._symbol_mapping + if self.has_anything_else: + self.anything_value = self._symbol_mapping[anything_else] + else: + self.anything_value = None + + def __getitem__(self, item): + return self._symbol_mapping.get(item, self.anything_value) + + def copy(self): + return BetterAlphabet(self._symbol_mapping.copy()) + + +class BetterFSM(FSM): + flat_transition_map: Dict[Tuple[int, int], int] + trans_key_to_states: Dict[int, List[int]] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(self.alphabet, BetterAlphabet): + self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping) + + flat_transition_map = {} + trans_key_to_states = {} + for from_state, trans_map in self.map.items(): + for trans_key, to_state in trans_map.items(): + flat_transition_map[(from_state, trans_key)] = to_state + trans_key_to_states.setdefault(trans_key, set()).add(from_state) + + self.__dict__["trans_key_to_states"] = trans_key_to_states + self.__dict__["flat_transition_map"] = flat_transition_map + self.__dict__["_fsm_info"] = None + + def copy(self): + return BetterFSM( + alphabet=self.alphabet.copy(), + states=self.states.copy(), + initial=self.initial, + finals=self.finals.copy(), + map=self.map.copy(), + __no_validation__=True, + ) + + @property + def fsm_info(self): + if self._fsm_info is None: + trans_key_to_states = numba.typed.Dict.empty( + numba.int64, numba.types.ListType(numba.int64) + ) + for trans_key, states in self.trans_key_to_states.items(): + new_states = numba.typed.List.empty_list(numba.int64) + for state in states: + new_states.append(numba.int64(state)) + trans_key_to_states[numba.int64(trans_key)] = new_states + + flat_transition_map = numba.typed.Dict.empty( + numba.types.UniTuple(numba.int64, 2), numba.int64 + ) + for trans_key, state in self.flat_transition_map.items(): + flat_transition_map[ + (numba.int64(trans_key[0]), numba.int64(trans_key[1])) + ] = numba.int64(state) + + alphabet_symbol_map = numba.typed.Dict.empty( + numba.types.string, numba.int64 + ) + for symbol, trans_key in self.alphabet._symbol_mapping.items(): + if symbol is not anything_else: + alphabet_symbol_map[symbol] = numba.int64(trans_key) + + initial = numba.int64(self.initial) + + finals = numba.typed.List.empty_list(numba.int64) + for final in self.finals: + finals.append(numba.int64(final)) + + anything_value = numba.int64(self.alphabet.anything_value) + + self.__dict__["_fsm_info"] = FSMInfo( + initial, + finals, + flat_transition_map, + trans_key_to_states, + anything_value, + alphabet_symbol_map, + ) + + return self._fsm_info + + +spec = [ + ("initial", numba.int64), + ("finals", numba.types.Set(numba.int64)), + ( + "transitions", + numba.types.DictType(numba.types.UniTuple(numba.int64, 2), numba.int64), + ), + ( + "trans_key_to_states", + 
numba.types.DictType(numba.int64, numba.types.ListType(numba.int64)), + ), + ("alphabet_anything_value", numba.optional(numba.int64)), + ("alphabet_symbol_mapping", numba.types.DictType(numba.types.string, numba.int64)), +] + + +@structref.register +class FSMInfoType(numba.types.StructRef): + def preprocess_fields(self, fields): + return tuple((name, numba.types.unliteral(typ)) for name, typ in fields) + + +class FSMInfo(structref.StructRefProxy): + def __new__( + cls, + initial, + finals, + transitions, + trans_key_to_states, + alphabet_anything_value, + alphabet_symbol_mapping, + ): + return structref.StructRefProxy.__new__( + cls, + initial, + finals, + transitions, + trans_key_to_states, + alphabet_anything_value, + alphabet_symbol_mapping, + ) + + @property + def initial(self): + return FSMInfo_get_initial(self) + + @property + def finals(self): + return FSMInfo_get_finals(self) + + @property + def transitions(self): + return FSMInfo_get_transitions(self) + + @property + def trans_key_to_states(self): + return FSMInfo_get_trans_key_to_states(self) + + @property + def alphabet_anything_value(self): + return FSMInfo_get_alphabet_anything_value(self) + + @property + def alphabet_symbol_mapping(self): + return FSMInfo_get_alphabet_symbol_mapping(self) + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_initial(self): + return self.initial + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_finals(self): + return self.finals + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_transitions(self): + return self.transitions + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_trans_key_to_states(self): + return self.trans_key_to_states + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_alphabet_anything_value(self): + return self.alphabet_anything_value + + +@numba.njit(nogil=True, inline="always") +def FSMInfo_get_alphabet_symbol_mapping(self): + return self.alphabet_symbol_mapping + + +structref.define_proxy(FSMInfo, FSMInfoType, [name for name, _ in spec]) +FSMInfo_type = FSMInfoType(fields=spec) + + +def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]: + """Construct an equivalent FSM with deterministic state labels.""" + old_to_new_trans_keys = { + trans_key: i + for i, (trans_key, _) in enumerate( + sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1])) + ) + } + + new_symbol_mapping = { + symbol: old_to_new_trans_keys[trans_key] + for symbol, trans_key in fsm.alphabet._symbol_mapping.items() + } + + new_alphabet = BetterAlphabet(new_symbol_mapping) + + new_map = { + from_state: { + old_to_new_trans_keys[trans_key]: to_state + for trans_key, to_state in trans_map.items() + } + for from_state, trans_map in fsm.map.items() + } + + old_to_new_states = {} + old_to_new_states[fsm.initial] = 0 + + i = 0 + seen = {fsm.initial} + old_state_queue = [fsm.initial] + while old_state_queue: + old_state = old_state_queue.pop(-1) + transitions = new_map[old_state] + sorted_transitions = sorted(transitions.items(), key=lambda v: v[0]) + for _, old_state in sorted_transitions: + if old_state not in seen: + old_state_queue.append(old_state) + seen.add(old_state) + if old_state not in old_to_new_states: + i += 1 + old_to_new_states[old_state] = i + + new_map = dict( + sorted( + ( + ( + old_to_new_states[from_state], + dict( + sorted( + ( + (trans_key, old_to_new_states[to_state]) + for trans_key, to_state in trans_map.items() + ), + key=lambda v: v[0], + ) + ), + ) + for from_state, trans_map in new_map.items() + ), 
+ key=lambda v: v[0], + ) + ) + + new_initial = 0 + new_finals = frozenset( + sorted(old_to_new_states[old_state] for old_state in fsm.finals) + ) + new_states = frozenset(sorted(new_map.keys())) + + new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map) + + return new_fsm, old_to_new_states + + +@numba.njit(nogil=True, cache=True) +def walk_fsm( + fsm_info: BetterFSM, + input_string: str, + start_state: int, + full_match: bool = True, +) -> List[int]: + state = fsm_info.initial + accepted_states: List[int] = numba.typed.List.empty_list(numba.int64) + last_final_idx = -1 + + # Apparently `fsm.alphabet.get` is incredibly slow, so we need to reproduce + # it here with the following: + alphabet_symbol_mapping = fsm_info.alphabet_symbol_mapping + anything_value = fsm_info.alphabet_anything_value + + for i, symbol in enumerate(input_string): + # Again, this is the logic from `fsm.alphabet.get` + trans_key = alphabet_symbol_mapping.get(symbol, anything_value) + + if state == fsm_info.initial: + new_state = fsm_info.transitions.get((start_state, trans_key)) + else: + new_state = fsm_info.transitions.get((state, trans_key)) + + if new_state is None: + if full_match: + if state in fsm_info.finals: + break + elif last_final_idx > -1: + accepted_states = accepted_states[: last_final_idx + 1] + break + + return numba.typed.List.empty_list(numba.int64) + + state = new_state + + if state in fsm_info.finals: + last_final_idx = i + + accepted_states.append(_nonoptional(state)) + + terminated = state in fsm_info.finals + if not terminated and state == fsm_info.initial: + return numba.typed.List.empty_list(numba.int64) + + return accepted_states + + +# TODO FIXME: Can't cache this due to https://github.com/numba/numba/issues/9177 +@numba.njit(nogil=True) +def find_partial_matches( + fsm_info: FSMInfo, + input_string: str, + full_match: bool = True, +) -> Generator[Tuple[int, List[int]], None, None]: + """Find the states in the finite state machine `fsm_info` that accept `input_string`. + + This will consider all possible states in the finite state machine (FSM) + that accept the beginning of `input_string` as starting points, unless a + specific `start_state` is provided. + + Parameters + ---------- + fsm_info + The finite state machine. + input_string + The string for which we generate partial matches. + full_match + Matches must cover the entire string. + + Returns + ------- + A set of tuples corresponding to each valid starting state in the FSM. The + first element of each tuple contains an integer indicating the position in + `input_string` at which the FSM stopped. The second element is the tuple + of states visited during execution of the FSM plus the next, unvisited + transition state. 
+ + """ + + if len(input_string) == 0: + return + + trans_key = fsm_info.alphabet_symbol_mapping.get( + input_string[0], fsm_info.alphabet_anything_value + ) + + for state in fsm_info.trans_key_to_states.get( + trans_key, numba.typed.List.empty_list(numba.int64) # type: ignore + ): + path = walk_fsm(fsm_info, input_string, state, full_match=full_match) + if path: + path.insert(0, state) + res = (len(path) - 2, path) + yield res + + +@numba.njit(nogil=True, cache=True) +def process_token_string( + fsm_info: FSMInfo, + token: str, + token_idx: int, + final_state_string: Optional[str] = None, +) -> Set[Tuple[int, int]]: + res = set() + vocab_string_len = len(token) + + for end_idx, state_seq in find_partial_matches(fsm_info, token): + if end_idx is not None and end_idx < vocab_string_len - 1: + continue + + res.add((state_seq[0], token_idx)) + + if token == final_state_string: + # Allow transitions to EOS from all terminals FSM states + for state in fsm_info.finals: + res.add((state, token_idx)) + + return res + + +def create_fsm_index( + fsm_info: FSMInfo, + vocabulary: Dict[str, int], + final_state_string: Optional[str] = None, + n_jobs=-1, +) -> Dict[int, Set[int]]: + """Construct a map from FSM states to subsets of `vocabulary`. + + The subsets of `vocabulary` consist of elements that are accepted by--or + transition to--the corresponding partial parse states. + + Parameters + ---------- + fsm + The finite-state machine. + vocabulary + The vocabulary composed of token strings mapped to token IDs. + final_state_string + A string from `vocabulary` that is to be added to all the final states + in the FSM (e.g. ``""``). + """ + + results = Parallel(backend="threading", n_jobs=n_jobs, return_as="generator")( + delayed(process_token_string)(fsm_info, token, token_idx, final_state_string) + for token, token_idx in vocabulary.items() + ) + + states_to_token_subsets: Dict[int, Set[int]] = {} + + for fsm_state, token_idx in chain.from_iterable(results): + states_to_token_subsets.setdefault(fsm_state, set()).add(token_idx) + + return states_to_token_subsets + + +@numba.njit(cache=True, nogil=True) +def state_scan_tokens( + fsm_info: FSMInfo, vocabulary: Dict[str, List[int]], start_state: int +) -> Tuple[Set[int], Set[int]]: + next_states = set() + all_token_ids = set() + + for token, token_ids in vocabulary.items(): + state_seq = walk_fsm(fsm_info, token, start_state) + + if state_seq is not None and len(state_seq) < len(token): + continue + + all_token_ids.update(token_ids) + next_states.add(state_seq[-1]) + + return all_token_ids, next_states + + +def create_fsm_index_end_to_end( + fsm_info: FSMInfo, + vocabulary: Dict[str, List[int]], +) -> Tuple[Dict[int, Set[int]], bool]: + """Create an FSM state-to-vocabulary map/index through end-to-end token parsing.""" + + # TODO: Consider using a `List` of `Set`s instead; that way we can JIT this + # code, too. + states_to_token_subsets: Dict[int, Set[int]] = {} + seen: Set[int] = set() + next_states = {fsm_info.initial} + + while next_states: + start_state = next_states.pop() + + token_ids, next_next_states = state_scan_tokens( + fsm_info, vocabulary, start_state + ) + + if token_ids: + states_to_token_subsets.setdefault(start_state, set()).update(token_ids) + + next_states.update(next_next_states - seen) + + seen.add(start_state) + + return states_to_token_subsets, any( + final_state in seen for final_state in fsm_info.finals + ) + + +# TODO: Cache these? 
+def reduced_vocabulary(tokenizer: "Tokenizer"): + """Create a map from decoded vocabulary tokens to lists of equivalent token ids.""" + vocabulary = numba.typed.Dict.empty( + numba.types.string, numba.types.ListType(numba.int64) + ) + for token, token_idx in tokenizer.vocabulary.items(): + vocabulary.setdefault( + tokenizer.convert_token_to_string(token), + numba.typed.List.empty_list(numba.int64), + ).append(numba.int64(token_idx)) + + return vocabulary + + +def create_fsm_index_tokenizer( + fsm: BetterFSM, + tokenizer: "Tokenizer", +) -> Tuple[Dict[int, Set[int]], bool]: + """Construct an FMS index from a tokenizer. + + This uses the end-to-end approach of `create_fsm_index_end_to_end`. + + .. warning:: + + `fsm` needs to be deterministically ordered so that the caching makes sense. + + """ + vocabulary = reduced_vocabulary(tokenizer) + + states_to_token_subsets, reaches_a_final = create_fsm_index_end_to_end( + fsm.fsm_info, vocabulary + ) + + # Allow transitions to EOS from all terminals FSM states that are + # reachable + # TODO: Do we really need this anymore? + for state in fsm.fsm_info.finals: + subset = states_to_token_subsets.get(state) + if subset is not None: + subset.add(tokenizer.eos_token_id) + + return states_to_token_subsets, reaches_a_final diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py index 76a0a4c38..7422947e9 100644 --- a/outlines/text/generate/regex.py +++ b/outlines/text/generate/regex.py @@ -1,4 +1,3 @@ -import collections import math from json import dumps from typing import List, Optional, Tuple, Union @@ -7,9 +6,13 @@ import torch from pydantic import BaseModel +from outlines.text.fsm import ( + create_fsm_index_tokenizer, + make_deterministic_fsm, + walk_fsm, +) from outlines.text.generate.continuation import Continuation from outlines.text.json_schema import build_regex_from_schema -from outlines.text.parsing import find_partial_matches, map_partial_states_to_vocab class Regex(Continuation): @@ -28,48 +31,27 @@ class Regex(Continuation): def __init__(self, model, regex_string: str, max_tokens: Optional[int]): super().__init__(model, max_tokens) - vocabulary = model.tokenizer.vocabulary - sorted_vocabulary = [ - model.tokenizer.convert_token_to_string(k) - for k, v in sorted(vocabulary.items(), key=lambda kv: kv[1]) - ] - regex_pattern = interegular.parse_pattern(regex_string) - self.regex_fsm = regex_pattern.to_fsm().reduce() - - def partial_match_filter(string, end_idx, state_seq): - if end_idx is not None and end_idx < len(string) - 1: - return False - return True - - pstate_to_vocab, paths = map_partial_states_to_vocab( - list(sorted_vocabulary), - {"REGEX": self.regex_fsm}, - partial_match_filter, - final_state_string=model.tokenizer.eos_token, + self.regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + states_to_token_ids, reaches_a_final = create_fsm_index_tokenizer( + self.regex_fsm, model.tokenizer ) + self.states_to_token_ids = { + state: list(token_ids) for state, token_ids in states_to_token_ids.items() + } # Check whether a terminal path (from the initial state of the FSM to # one of its terminal states) exists, raise an exception otherwise. 
- traversed_states = set() - queue = collections.deque([self.regex_fsm.initial]) - while queue: - symbol = queue.popleft() - for prev_state in paths["REGEX"][symbol]: - if prev_state not in traversed_states: - traversed_states.add(prev_state) - queue.append(prev_state) - - if traversed_states.intersection(self.regex_fsm.finals) == set(): + if not reaches_a_final: raise ValueError( "The vocabulary does not allow us to build a sequence that matches the input regex" ) - self.pstate_to_vocab = {k: list(v) for k, v in pstate_to_vocab.items()} - # These tuples are comprised of the FSM name, last FSM state, and + # These tuples are comprised of the last FSM state and the # number of processed tokens. # When an EOS is observed, the last FSM state becomes `-1`. - self.pstates: List[Tuple[str, int, int]] = [] + self.last_token_states: List[Tuple[int, int]] = [] def create_proposal( self, generated_token_ids: torch.LongTensor, logits: torch.DoubleTensor @@ -85,17 +67,16 @@ def create_proposal( """ - if len(self.pstates) == 0: - self.pstates = [ - ("REGEX", self.regex_fsm.initial, 0) - for _ in range(generated_token_ids.shape[0]) + if len(self.last_token_states) == 0: + self.last_token_states = [ + (self.regex_fsm.initial, 0) for _ in range(generated_token_ids.shape[0]) ] if generated_token_ids.shape[-1] > 0: - new_pstates = [] - for token_seq, (_, last_fsm_state, last_token_idx) in zip( + new_last_token_states = [] + for token_seq, (last_fsm_state, last_token_idx) in zip( generated_token_ids, - self.pstates, + self.last_token_states, ): # Get the tokens we haven't already processed, readable_tokens = token_seq[last_token_idx:] @@ -109,36 +90,39 @@ def create_proposal( # getting/sampling any more non-EOS tokens. assert last_fsm_state > -1 + # TODO: Let's not re-decode the same tokens every time + # around sequence = self.model.tokenizer.decode(readable_tokens) - ((_, state_seq),) = find_partial_matches( - self.regex_fsm, + # TODO: This is unnecessary; use the last state from the + # index (once we add those to the index) + state_seq = walk_fsm( + self.regex_fsm.fsm_info, "".join(sequence), start_state=last_fsm_state, ) - pstate = ( - "REGEX", + last_token_state = ( state_seq[-1], last_token_idx + len(sequence), ) else: - pstate = ("REGEX", -1, last_token_idx) + last_token_state = (-1, last_token_idx) - new_pstates.append(pstate) + new_last_token_states.append(last_token_state) - self.pstates = new_pstates + self.last_token_states = new_last_token_states masks = [] - for pstate in self.pstates: + for last_token_state in self.last_token_states: mask = torch.full( (len(self.model.tokenizer.vocabulary),), -math.inf, device=self.device ) - if pstate[1] > -1: - next_support = self.pstate_to_vocab[pstate[:2]] - else: - next_support = [self.model.tokenizer.eos_token_id] + next_support = self.states_to_token_ids.get( + last_token_state[0], self.model.tokenizer.eos_token_id + ) + # TODO: Cache these masks based on `last_token_state[0]` mask[next_support] = 0 masks.append(mask.unsqueeze(0)) diff --git a/pyproject.toml b/pyproject.toml index f0eb4b42c..f17964548 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "tenacity", "torch", "accelerate", + "numba", + "joblib", ] dynamic = ["version"] @@ -85,7 +87,7 @@ exclude=["examples"] module = [ "diffusers", "jinja2", - "joblib", + "joblib.*", "openai", "numpy.*", "perscache.*", @@ -101,6 +103,7 @@ module = [ "lark.*", "regex.*", "interegular.*", + "numba.*", ] ignore_missing_imports = true diff --git 
a/tests/text/generate/test_regex.py b/tests/text/generate/test_regex.py
index 6bdac68a9..4c9c65c9c 100644
--- a/tests/text/generate/test_regex.py
+++ b/tests/text/generate/test_regex.py
@@ -14,6 +14,12 @@ class Tokenizer:
     vocabulary = {"<EOS>": 0, "-": 1, "1": 2, "0.": 3, "431": 4, "a": 5, "A": 6}
     tokens = list(vocabulary.keys())

+    def encode(self, tokens):
+        if not isinstance(tokens, (tuple, list)):
+            tokens = [tokens]
+
+        return [self.vocabulary[token] for token in tokens]
+
     def decode(self, token_ids):
         decoded = []
         for i in range(token_ids.shape[0]):
diff --git a/tests/text/test_fsm.py b/tests/text/test_fsm.py
new file mode 100644
index 000000000..e206a8f5b
--- /dev/null
+++ b/tests/text/test_fsm.py
@@ -0,0 +1,197 @@
+import interegular
+import numba
+import pytest
+
+from outlines.models.transformers import TransformersTokenizer
+from outlines.text.fsm import (
+    create_fsm_index,
+    create_fsm_index_end_to_end,
+    create_fsm_index_tokenizer,
+    find_partial_matches,
+    make_deterministic_fsm,
+    walk_fsm,
+)
+
+
+def test_partial_match():
+    name_pattern = interegular.parse_pattern(r"[^\W\d]\w*")
+    name_fsm, _ = make_deterministic_fsm(name_pattern.to_fsm().reduce())
+    assert name_fsm.initial == 0
+
+    name_fsm = name_fsm.fsm_info
+
+    def_pattern = interegular.parse_pattern("def")
+    def_fsm, _ = make_deterministic_fsm(def_pattern.to_fsm().reduce())
+    assert def_fsm.initial == 0
+
+    def_fsm = def_fsm.fsm_info
+
+    def to_python(res):
+        return {(x, tuple(y)) for x, y in res}
+
+    res = to_python(find_partial_matches(def_fsm, "def"))
+    assert res == {(2, (0, 1, 2, 3))}
+    res = to_python(find_partial_matches(def_fsm, "de"))
+    assert res == {(1, (0, 1, 2))}
+    res = to_python(find_partial_matches(def_fsm, "d"))
+    assert res == {(0, (0, 1))}
+    res = to_python(find_partial_matches(def_fsm, ""))
+    assert res == set()
+    res = to_python(find_partial_matches(def_fsm, "df"))
+    assert res == set()
+    res = to_python(find_partial_matches(def_fsm, "ef"))
+    assert res == {(1, (1, 2, 3))}
+    res = to_python(find_partial_matches(def_fsm, "e"))
+    assert res == {(0, (1, 2))}
+    res = to_python(find_partial_matches(def_fsm, "f"))
+    assert res == {(0, (2, 3))}
+    res = to_python(find_partial_matches(def_fsm, "ef foo"))
+    assert res == {(1, (1, 2, 3))}
+
+    # This string has a `DEF` token in it, but should ultimately not lex one
+    res = to_python(find_partial_matches(def_fsm, "defb"))
+    assert res == {(2, (0, 1, 2, 3))}
+
+    # `NAME` can have multiple start states for this input
+    res = to_python(find_partial_matches(name_fsm, "d"))
+    assert res == {(0, (0, 1)), (0, (1, 1))}
+    # Not this case
+    res = to_python(find_partial_matches(name_fsm, "1d"))
+    assert res == {(1, (1, 1, 1))}
+
+    res = to_python(find_partial_matches(name_fsm, "blah"))
+    assert res == {
+        (3, (0, 1, 1, 1, 1)),
+        (3, (1, 1, 1, 1, 1)),
+    }
+
+    float_pattern = interegular.parse_pattern(
+        r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))"
+    )
+    float_fsm, _ = make_deterministic_fsm(float_pattern.to_fsm().reduce())
+    assert 5 in float_fsm.finals
+    assert 2 not in float_fsm.finals
+
+    float_fsm = float_fsm.fsm_info
+
+    res = to_python(find_partial_matches(float_fsm, "."))
+    assert res == {(0, (3, 5)), (0, (4, 5)), (0, (0, 2))}
+
+    joins_fsm, _ = make_deterministic_fsm(
+        interegular.parse_pattern(r"(JOIN LEFT|JOIN)").to_fsm().reduce()
+    )
+
+    joins_fsm = joins_fsm.fsm_info
+
+    res = to_python(find_partial_matches(joins_fsm, "JOIN BLAH", full_match=False))
+    assert res == {(3, (0, 1, 2, 3, 4))}
+
+    res = to_python(find_partial_matches(joins_fsm, "JOIN L",
full_match=False)) + assert res == {(5, (0, 1, 2, 3, 4, 5, 6))} + + res = to_python(find_partial_matches(joins_fsm, "JOI", full_match=False)) + assert res == {(2, (0, 1, 2, 3))} + + regex_pattern = interegular.parse_pattern("0|[1-9][2-9]*") + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + # State `1` has no transitions + assert not regex_fsm.map[1] + # This should fail, because state `1` reads nothing + res = to_python(walk_fsm(regex_fsm.fsm_info, "0", 1)) + assert res == set() + + res = to_python(find_partial_matches(regex_fsm.fsm_info, "0", 1)) + assert res == {(0, (0, 1))} + + +def test_create_fsm_index(): + regex_str = "0|[1-9][0-9]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + vocabulary = {"blah": 0, "1a": 1, "2": 2, "0": 3, "": 4} + + res = create_fsm_index(regex_fsm.fsm_info, vocabulary) + + assert res == {0: {2, 3}, 2: {2, 3}} + + res = create_fsm_index(regex_fsm.fsm_info, vocabulary, "") + + assert res == {0: {2, 3}, 1: {4}, 2: {2, 3, 4}} + + +def test_create_fsm_index_end_to_end(): + regex_str = "0|[1-9][0-9]*" + + regex_pattern = interegular.parse_pattern(regex_str) + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce()) + + vocabulary = { + "blah": numba.typed.List([0]), + "1a": numba.typed.List([1]), + "2": numba.typed.List([2]), + "0": numba.typed.List([3]), + "": numba.typed.List([4]), + } + + vocabulary_nb = numba.typed.Dict.empty( + numba.types.string, numba.types.ListType(numba.int64) + ) + vocabulary_nb.update(vocabulary) + + res, reaches_a_final = create_fsm_index_end_to_end( + regex_fsm.fsm_info, vocabulary_nb + ) + + assert reaches_a_final + assert res == {0: {2, 3}, 2: {2, 3}} + + +def test_create_fsm_index_tokenizer(): + # The combined regular expressions of a lexer state in a Python grammar + regex_str = "(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + tokenizer = TransformersTokenizer("gpt2") + + res, reaches_a_final = create_fsm_index_tokenizer(regex_fsm, tokenizer) + + assert reaches_a_final + assert len(res) / num_fsm_states > 0.94 + + +@pytest.mark.skip(reason="Only for local profiling") +def test_regex_index_performance(): + from line_profiler import LineProfiler # type: ignore [import] + + regex_str = 
"(?:(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|[0-9](?:(?:_)?[0-9])*)(?:J|j)|(?:[0-9](?:(?:_)?[0-9])*(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*|(?:[0-9](?:(?:_)?[0-9])*\\.(?:[0-9](?:(?:_)?[0-9])*)?|\\.[0-9](?:(?:_)?[0-9])*)(?:(?:e|E)(?:(?:\\+|\\-))?[0-9](?:(?:_)?[0-9])*)?)|0(?:x|X)(?:(?:_)?(?:[0-9]|[a-f]|[A-F]))+|0(?:b|B)(?:(?:_)?[0-1])+|0(?:o|O)(?:(?:_)?[0-7])+|(?:(?i:([ubf]?r?|r[ubf])('([^\\\\']|.)*?'))|(?i:([ubf]?r?|r[ubf])(\"([^\\\"]|.)*?\")))|(?:(?:\r?\n[\t ]*|#[^\n]*))+|[1-9](?:(?:_)?[0-9])*|\\\\[\t \x0c]*\r?\n|continue|nonlocal|assert|global|import|lambda|return|async|await|break|class|False|match|raise|while|yield|case|from|None|pass|True|with|def|del|for|not|try|if|[^\\W\\d]\\w*|#[^\n]*|[\t \x0c]+|\\.\\.\\.|@|\\{|\\(|\\[|\\-|\\+|\\*|\\~" + + regex_pattern = interegular.parse_pattern(regex_str) + # Not reduced, so that there are many states + regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm()) + + num_fsm_states = len(regex_fsm.states) + assert num_fsm_states == 220 + + tokenizer = TransformersTokenizer("gpt2") + + # Pre-compile Numba functions + res, _ = create_fsm_index_tokenizer(regex_fsm, tokenizer) + assert len(res) > 1 + + profiler = LineProfiler(create_fsm_index_end_to_end) + + profiler.runctx( + r"""create_fsm_index_tokenizer(regex_fsm, tokenizer)""", + globals(), + locals(), + ) + profiler.dump_stats("line-profiler-create_fsm_index.pkl") + profiler.print_stats(output_unit=1e-3, summarize=True)