diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index 66b4388d0..40fd68c21 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -1,7 +1,10 @@ +from collections import defaultdict +from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Protocol, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Protocol, Tuple, Union import interegular +import torch from lark import Lark from outlines import grammars @@ -62,11 +65,16 @@ def get_next_state(self, state: int, token_id: int) -> int: def is_final_state(self, state: int) -> bool: ... + def align_prompt_tokens( + self, token_ids: torch.Tensor, attention_masks: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + ... + class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" - final_state = 1 + final_state = -1 start_state = 0 def __init__(self, tokenizer: "Tokenizer"): @@ -77,24 +85,52 @@ def __init__(self, tokenizer: "Tokenizer"): """ self.eos_token_id = tokenizer.eos_token_id - self.vocabulary = tokenizer.vocabulary.values() + self.vocabulary = tokenizer.vocabulary + self.tokenizer = tokenizer + self.states_to_token_maps = self.create_states_to_tokens_map() + + def create_states_to_tokens_map(self) -> Dict[int, Dict[int, int]]: + """Create the states_to_tokens_map. All tokens from the starting state lead + to itself, except for the eos_token that leads to the final state.""" + return { + self.start_state: { + token_id: self.start_state + if token_id != self.eos_token_id + else self.final_state + for token_id in self.vocabulary.values() + } + } + + def align_prompt_tokens( + self, token_ids: torch.Tensor, attention_masks: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update the states_to_token_maps and return the aligned prompt tokens and attention masks""" + ( + token_ids, + attention_masks, + self.states_to_token_maps, + ) = align_tokens_states_to_token_maps( + token_ids, attention_masks, self.vocabulary, self.states_to_token_maps + ) + return token_ids, attention_masks def get_next_instruction(self, state: int) -> Instruction: if self.is_final_state(state): return Write([self.eos_token_id]) - return Generate(list(self.vocabulary)) + + return Generate(list(self.states_to_token_maps[state].keys())) def get_next_state(self, state: int, token_id: int) -> int: - if token_id == self.eos_token_id or state == self.final_state: + if self.is_final_state(state): return self.final_state - return self.start_state + return self.states_to_token_maps[state][token_id] def is_final_state(self, state: int): return state == self.final_state def copy(self): - return self + return deepcopy(self) class RegexGuide(Guide): @@ -136,10 +172,23 @@ def create_states_mapping( ) = create_states_mapping( regex_string, tuple(sorted(tokenizer.vocabulary.items())) ) - self.vocabulary = list(tokenizer.vocabulary.values()) + self.vocabulary = tokenizer.vocabulary self.eos_token_id = tokenizer.eos_token_id self.final_states = fsm_finals | {-1} + def align_prompt_tokens( + self, token_ids: torch.Tensor, attention_masks: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update the states_to_token_maps and return the aligned prompt tokens and attention masks""" + ( + token_ids, + attention_masks, + self.states_to_token_maps, + ) = align_tokens_states_to_token_maps( + token_ids, attention_masks, self.vocabulary, self.states_to_token_maps + ) + return token_ids, attention_masks + def get_next_instruction(self, state: int) -> Instruction: 
"""Return the next instruction for guided generation. @@ -244,7 +293,7 @@ def is_final_state(self, state: int) -> bool: return state in self.final_states def copy(self): - return self + return deepcopy(self) class CFGGuide(Guide): @@ -281,6 +330,12 @@ def __init__(self, cfg_string: str, tokenizer): self.start_state = 0 self.final_state = -1 + def align_prompt_tokens( + self, token_ids: torch.Tensor, attention_masks: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Not applicable to this type of FSM""" + return token_ids, attention_masks + def get_next_instruction(self, state: int) -> Instruction: """Generate an instruction for the next step. @@ -416,3 +471,163 @@ def is_final_state(self, state: int) -> bool: def copy(self) -> "CFGGuide": """Create a copy of the FSM.""" return CFGGuide(self.cfg_string, self.tokenizer) + + +def align_tokens_states_to_token_maps( + token_ids: torch.Tensor, + attention_masks: torch.Tensor, + vocabulary: Dict[str, int], + states_to_token_maps: Dict[int, Dict[int, int]], +) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, Dict[int, int]]]: + """Apply token alignment to the provided prompt tokens and attention masks given the + states_to_token_maps of a FSM. Return the updated tokens/maps as well as the updated + states_to_token_maps""" + prompt_token_ids = token_ids.tolist() + crossing_tokens = find_crossing_tokens(prompt_token_ids, vocabulary) + valid_crossing_tokens = get_crossing_tokens_target_states( + states_to_token_maps, crossing_tokens, prompt_token_ids, vocabulary + ) + if not valid_crossing_tokens: + return token_ids, attention_masks, states_to_token_maps + ( + states_to_token_maps, + number_cropped_tokens, + ) = add_crossing_tokens_states_to_tokens_map( + states_to_token_maps, prompt_token_ids, valid_crossing_tokens + ) + return ( + token_ids[:-number_cropped_tokens], + attention_masks[:-number_cropped_tokens], + states_to_token_maps, + ) + + +def find_crossing_tokens( + token_ids: List[int], vocabulary: Dict[str, int] +) -> Dict[int, List[int]]: + """Find the tokens that could replace one or more tokens at the end of token_ids + while conserving the same intial text (and extending it by at least one character). + Return a dictionary with, for the indexes in the token_ids with matches, the associated crossing tokens. + """ + reversed_vocabulary = {value: key for key, value in vocabulary.items()} + len_token_ids = len(token_ids) + max_length_token_text = max(len(item) for item in vocabulary.keys()) + characters_considered = "" + crossing_tokens_map = {} + + for index, token_id in enumerate(reversed(token_ids)): + characters_considered = reversed_vocabulary[token_id] + characters_considered + if len(characters_considered) >= max_length_token_text: + break + crossing_token_ids = [ + token_id + for text, token_id in vocabulary.items() + if text.startswith(characters_considered) + and len(text) > len(characters_considered) + ] + if crossing_token_ids: + crossing_tokens_map[len_token_ids - index - 1] = crossing_token_ids + + return crossing_tokens_map + + +def get_crossing_tokens_target_states( + states_to_tokens_map: Dict[int, Dict[int, int]], + crossing_tokens: Dict[int, List[int]], + prompt_token_ids: List[int], + vocabulary: Dict[str, int], +) -> Dict[int, Dict[int, int]]: + """For each crossing token associated to an index, check that the characters after the boundary + match the states_to_tokens_map and find the state it would lead to. 
Return a dict with, for each + provided indexes, the associated valid tokens with the state they would lead to. + """ + reversed_vocabulary = {value: key for key, value in vocabulary.items()} + prompt_token_texts = [ + reversed_vocabulary[token_id] for token_id in prompt_token_ids + ] + + valid_crossing_tokens: Dict[int, Dict[int, int]] = defaultdict(dict) + for pos, tokens in crossing_tokens.items(): + for token in tokens: + is_valid = True + characters = reversed_vocabulary[token] + characters_before_border = "".join(prompt_token_texts[pos:]) + characters_after_border = characters[len(characters_before_border) :] + state = 0 + for char in characters_after_border: + char_token = vocabulary.get(char) + try: + state = states_to_tokens_map[state][char_token] # type: ignore + except KeyError: + is_valid = False + break + if is_valid: + valid_crossing_tokens[pos][token] = state + + return valid_crossing_tokens + + +def add_crossing_tokens_states_to_tokens_map( + states_to_tokens_map: Dict[int, Dict[int, int]], + prompt_token_ids: List[int], + crossing_tokens_map: Dict[int, Dict[int, int]], +) -> Tuple[Dict[int, Dict[int, int]], int]: + """Modify the states_to_tokens_map to account for the crossing tokens. This operation modifies + the starting state of the fsm as we would include some characters at the end of the prompt in + the states_to_tokens_map. + Attention! the starting state of the states_to_tokens_map provided must be 0. + Return the updated states_to_tokens_map and the number of cropped tokens/additional states + """ + if not crossing_tokens_map: + return states_to_tokens_map, 0 + first_crossing_token_pos = min( + [key for key, value in crossing_tokens_map.items() if value] + ) + number_additional_states = len(prompt_token_ids) - first_crossing_token_pos + highest_state = max( + max(states_to_tokens_map.keys()), + max(max(items.values()) for items in states_to_tokens_map.values()), + ) + + for i in range(number_additional_states): + # add the tokens that was originally part of the prompt + if i == number_additional_states - 1: + states_to_tokens_map[highest_state + 1 + i] = { + prompt_token_ids[first_crossing_token_pos + i]: 0 + } + else: + states_to_tokens_map[highest_state + 1 + i] = { + prompt_token_ids[first_crossing_token_pos + i]: highest_state + 2 + i + } + # add the crossing tokens + crossing_tokens = crossing_tokens_map.get(first_crossing_token_pos + i) + if crossing_tokens: + for token, target_state in crossing_tokens.items(): + states_to_tokens_map[highest_state + 1 + i][token] = target_state + + # set the id of our new initial state to 0 + states_to_tokens_map = swap_state_ids_states_to_tokens_map( + states_to_tokens_map, highest_state + 1, 0 + ) + return states_to_tokens_map, number_additional_states + + +def swap_state_ids_states_to_tokens_map( + states_to_tokens_map: Dict[int, Dict[int, int]], + first_state_id: int, + second_state_id: int, +) -> Dict[int, Dict[int, int]]: + """Swap the id of two states of the states_to_tokens_map while conserving all transitions""" + first_state_transitions = states_to_tokens_map.pop(first_state_id) + second_state_transitions = states_to_tokens_map.pop(second_state_id) + states_to_tokens_map[first_state_id] = second_state_transitions + states_to_tokens_map[second_state_id] = first_state_transitions + + for transitions in states_to_tokens_map.values(): + for token, target_state_id in list(transitions.items()): + if target_state_id == first_state_id: + transitions[token] = second_state_id + elif target_state_id == second_state_id: + 
transitions[token] = first_state_id + + return states_to_tokens_map diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 97b9a981b..f8e463429 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -1,7 +1,8 @@ -from typing import Iterator, List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch +from outlines.fsm.guide import Guide from outlines.generate.generator import sequence_generator @@ -20,6 +21,53 @@ def __init__( self.device = device self.num_samples = sampler.samples + def align_prompt_tokens( + self, + prompt_token_ids: torch.Tensor, + attention_masks: torch.Tensor, + fsms: List[Guide], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Implement token alignment for each fsm. Return the updated tokens_ids and attention_masks""" + aligned_prompts, aligned_masks = zip( + *[ + fsm.align_prompt_tokens(prompt, mask) + for prompt, mask, fsm in zip(prompt_token_ids, attention_masks, fsms) + ] + ) + # We have to pad some of the prompts if they are not all of the same length after this operation + max_length_aligned_prompt = max(prompt.shape[0] for prompt in aligned_prompts) + padded_aligned_prompts = [ + torch.cat( + [ + torch.full( + (max_length_aligned_prompt - prompt.shape[0],), + 0, + device=prompt_token_ids.device, + dtype=prompt.dtype, + ), + prompt, + ] + ) + for prompt in aligned_prompts + ] + padded_aligned_masks = [ + torch.cat( + [ + torch.full( + (max_length_aligned_prompt - mask.shape[0],), + 0, + device=prompt_token_ids.device, + dtype=mask.dtype, + ), + mask, + ] + ) + for mask in aligned_masks + ] + aligned_prompt_token_ids = torch.stack(padded_aligned_prompts) + aligned_attention_masks = torch.stack(padded_aligned_masks) + return aligned_prompt_token_ids, aligned_attention_masks + def get_generated_token_ids( self, prompt_token_ids: torch.Tensor, @@ -47,6 +95,19 @@ def get_generated_token_ids( return token_ids + def get_generated_sequences( + self, + prompt_token_ids: List[torch.Tensor], + token_ids: List[torch.Tensor], + ) -> List[str]: + """Give the text sequences generated""" + sequences = self.tokenizer.decode(token_ids) + prompt_sequences = self.tokenizer.decode(prompt_token_ids) + return [ + seq[len(prompt_seq) :] + for seq, prompt_seq in zip(sequences, prompt_sequences) + ] + def is_stop_sequence_found( self, generated_sequences: List[str], stop_sequences: List[str] ) -> bool: @@ -175,10 +236,16 @@ def __call__( num_samples = self.num_samples batch_size = len(prompts) - prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) - attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) fsm_states = [0 for _ in range(batch_size * num_samples)] fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)] + + prompt_token_ids = torch.repeat_interleave(prompt_token_ids, num_samples, dim=0) + attention_masks = torch.repeat_interleave(attention_masks, num_samples, dim=0) + + aligned_prompt_token_ids, aligned_attention_masks = self.align_prompt_tokens( + prompt_token_ids, attention_masks, fsms + ) + weights = torch.zeros( (batch_size * num_samples), dtype=torch.float, device=self.device ) @@ -187,9 +254,9 @@ def __call__( self.model, self.sampler, fsms, - prompt_token_ids, + aligned_prompt_token_ids, weights, - attention_masks, + aligned_attention_masks, fsm_states, rng=rng, ) @@ -204,17 +271,20 @@ def __call__( ) if max_tokens and len(generated_token_ids[0]) >= max_tokens: break - if stop_sequences and self.is_stop_sequence_found( - 
self.tokenizer.decode(generated_token_ids), stop_sequences - ): - break + if stop_sequences: + generated_sequences = self.get_generated_sequences( + prompt_token_ids, token_ids + ) + if self.is_stop_sequence_found( + generated_sequences, stop_sequences + ): + break except StopIteration: break token_ids = last_state.token_ids - generated_token_ids = self.get_generated_token_ids(prompt_token_ids, token_ids) - generated = self.tokenizer.decode(generated_token_ids) + generated = self.get_generated_sequences(prompt_token_ids, token_ids) stripped = [ self.strip_stop_sequences(sequence, stop_sequences) for sequence in generated diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 4be5259d9..be8e4e161 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -1,6 +1,18 @@ import pytest - -from outlines.fsm.guide import CFGGuide, Generate, RegexGuide, StopAtEOSGuide, Write +import torch + +from outlines.fsm.guide import ( + CFGGuide, + Generate, + RegexGuide, + StopAtEOSGuide, + Write, + add_crossing_tokens_states_to_tokens_map, + align_tokens_states_to_token_maps, + find_crossing_tokens, + get_crossing_tokens_target_states, + swap_state_ids_states_to_tokens_map, +) def test_stop_at_eos(): @@ -24,6 +36,21 @@ class MockTokenizer: assert fsm.is_final_state(fsm.final_state) is True +def test_stop_at_eos_align_prompt_tokens(): + class MockTokenizer: + vocabulary = {"a": 1, "ab": 2, "b": 3, "eos": 4} + eos_token_id = 4 + + fsm = StopAtEOSGuide(MockTokenizer()) + + token_ids, attention_masks = fsm.align_prompt_tokens( + torch.tensor([1]), torch.tensor([1]) + ) + assert torch.equal(token_ids, torch.tensor([])) + assert torch.equal(attention_masks, torch.tensor([])) + assert fsm.states_to_token_maps == {0: {1: 1, 2: 1}, 1: {1: 1, 2: 1, 3: 1, 4: -1}} + + def test_regex_vocabulary_error(): class MockTokenizer: vocabulary = {"a": 1} @@ -67,6 +94,27 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(state) is True +def test_regex_align_prompt_tokens(): + class MockTokenizer: + vocabulary = {"1": 1, "2": 2, "12": 3, "eos": 4} + special_tokens = {"eos"} + eos_token_id = 4 + + def convert_token_to_string(self, token): + return token + + regex_str = "[1-9]" + tokenizer = MockTokenizer() + fsm = RegexGuide(regex_str, tokenizer) + + token_ids, attention_masks = fsm.align_prompt_tokens( + torch.tensor([1, 1]), torch.tensor([1, 1]) + ) + assert torch.equal(token_ids, torch.tensor([1])) + assert torch.equal(attention_masks, torch.tensor([1])) + assert fsm.states_to_token_maps == {0: {1: 2, 3: 1}, 2: {1: 1, 2: 1}} + + def test_regex_final_state(): """Make sure that the FSM stays in the final state as we keep generating""" @@ -388,3 +436,214 @@ def decode(self, token_ids): state = fsm.get_next_state(state=state, token_id=4) assert fsm.generation == "(aa)" assert fsm.is_final_state(state) + + +@pytest.mark.parametrize( + "token_ids,vocabulary,expected_output", + [ + # Several possible crossing tokens for the last prompt token + ([1, 2], {"a": 1, "ab": 2, "abc": 3, "abcd": 4}, {1: [3, 4]}), + # Several possible crossing tokens for the one before last prompt token + ([1, 2, 3], {"a": 1, "b": 2, "c": 3, "bcd": 4, "bcde": 5}, {1: [4, 5]}), + # Several possible crossing tokens for several different tokens of the prompt + ( + [1, 2, 3], + {"a": 1, "b": 2, "c": 3, "cd": 4, "cde": 5, "bcd": 6, "bcde": 7}, + {1: [6, 7], 2: [4, 5]}, + ), + # No crossing token found + ([1, 2], {"a": 1, "b": 2, "c": 3, "cd": 4}, {}), + ], +) +def test_find_crossing_tokens(token_ids, vocabulary, 
expected_output): + assert find_crossing_tokens(token_ids, vocabulary) == expected_output + + +@pytest.mark.parametrize( + "states_to_tokens_map,crossing_tokens,prompt_token_ids,vocabulary,expected_output", + [ + # Only some of the crossing tokens are valid, several different target states + ( + { + 0: {8: 1, 10: 1, 11: -1}, + 1: {10: -1}, + }, + {1: [6, 7], 2: [4, 5]}, + [1, 2, 3], + { + "a": 1, + "b": 2, + "c": 3, + "cd": 4, + "cde": 5, + "bcd": 6, + "bcdf": 7, + "d": 8, + "e": 9, + "f": 10, + "df": 11, + }, + {1: {6: 1, 7: -1}, 2: {4: 1}}, + ), + # No valid crossing tokens + ( + { + 0: {9: 1}, + 1: {8: 2, 11: -1}, + 2: {10: -1}, + }, + {1: [6, 7], 2: [4, 5]}, + [1, 2, 3], + { + "a": 1, + "b": 2, + "c": 3, + "cd": 4, + "cde": 5, + "bcd": 6, + "bcdf": 7, + "d": 8, + "e": 9, + "f": 10, + "df": 11, + }, + {}, + ), + ], +) +def test_get_crossing_tokens_target_states( + states_to_tokens_map, crossing_tokens, prompt_token_ids, vocabulary, expected_output +): + assert ( + get_crossing_tokens_target_states( + states_to_tokens_map, crossing_tokens, prompt_token_ids, vocabulary + ) + == expected_output + ) + + +@pytest.mark.parametrize( + "states_to_tokens_map,first_state_id,second_state_id,expected_output", + [ + ( + { + 0: {10: 1, 11: 2, 12: -1}, + 1: {12: 2, 14: -1}, + 2: {15: 2, 16: 0, 17: -1}, + 3: {18: 0, 19: 1, 20: 2}, + }, + 0, + 3, + { + 3: {10: 1, 11: 2, 12: -1}, + 1: {12: 2, 14: -1}, + 2: {15: 2, 16: 3, 17: -1}, + 0: {18: 3, 19: 1, 20: 2}, + }, + ) + ], +) +def test_swap_state_ids_states_to_tokens_map( + states_to_tokens_map, first_state_id, second_state_id, expected_output +): + assert ( + swap_state_ids_states_to_tokens_map( + states_to_tokens_map, first_state_id, second_state_id + ) + == expected_output + ) + + +def test_swap_state_ids_states_to_tokens_map_key_error(): + with pytest.raises(KeyError): + swap_state_ids_states_to_tokens_map({0: {1: 1}, 1: {2: -1}}, 0, 2) + + +@pytest.mark.parametrize( + "states_to_tokens_map,prompt_token_ids,crossing_tokens_map,expected_output", + [ + # Add several new states to states_to_tokens_map + ( + { + 0: {10: 1, 11: 2, 12: -1}, + 1: {12: 2, 14: -1}, + 2: {15: 2, 16: 0, 17: -1}, + 3: {18: 0, 19: 1, 20: 2}, + }, + [6, 7, 8], + { + 1: {20: 1, 21: 2}, + 2: {22: 1, 23: 3}, + }, + ( + { + 4: {10: 1, 11: 2, 12: -1}, + 1: {12: 2, 14: -1}, + 2: {15: 2, 16: 4, 17: -1}, + 3: {18: 4, 19: 1, 20: 2}, + 0: {7: 5, 20: 1, 21: 2}, + 5: {8: 4, 22: 1, 23: 3}, + }, + 2, + ), + ), + # No crossing tokens, unchanged states_to_tokens_map + ({0: {1: -1, 2: -1}}, [5, 6, 7, 8], {}, ({0: {1: -1, 2: -1}}, 0)), + ], +) +def test_add_crossing_tokens_states_to_tokens_map( + states_to_tokens_map, prompt_token_ids, crossing_tokens_map, expected_output +): + assert ( + add_crossing_tokens_states_to_tokens_map( + states_to_tokens_map, prompt_token_ids, crossing_tokens_map + ) + == expected_output + ) + + +@pytest.mark.parametrize( + "token_ids,attention_masks,vocabulary,states_to_token_maps,expected_output", + [ + ( + torch.tensor([1, 2, 3]), + torch.tensor([1, 1, 1]), + { + "a": 1, + "b": 2, + "c": 3, + "cd": 4, + "cde": 5, + "bcd": 6, + "bcdf": 7, + "d": 8, + "e": 9, + "f": 10, + "df": 11, + }, + { + 0: {8: 1, 10: 1, 11: -1}, + 1: {10: -1}, + }, + ( + torch.tensor([1]), + torch.tensor([1]), + { + 2: {8: 1, 10: 1, 11: -1}, + 1: {10: -1}, + 0: {2: 3, 6: 1, 7: -1}, + 3: {3: 2, 4: 1}, + }, + ), + ) + ], +) +def test_align_tokens_states_to_token_maps( + token_ids, attention_masks, vocabulary, states_to_token_maps, expected_output +): + assert ( + align_tokens_states_to_token_maps( + 
token_ids, attention_masks, vocabulary, states_to_token_maps + ) + == expected_output + ) diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py index 5a2edf8dc..7ea104d40 100644 --- a/tests/generate/test_generator.py +++ b/tests/generate/test_generator.py @@ -21,6 +21,9 @@ def test_sequence_generator_class(): class MockFSM: first_state = 0 + def align_prompt_tokens(self, token_ids, attention_masks): + return token_ids, attention_masks + def get_next_state(self, state, next_token_ids): return 4 @@ -39,7 +42,7 @@ def encode(self, _): return torch.tensor([[0, 1, 2, 3]]), torch.tensor([[1, 1, 1, 1]]) def decode(self, tokens): - return ["testx"[i] for i in tokens] + return ["".join(["testx"[int(i)] for i in tokens[0]])] class MockModel: def __init__(self): @@ -77,6 +80,9 @@ def __call__(self, biased_logits, *_): def test_sequence_generator_1d_single_iteration(): class MockFSM: + def align_prompt_tokens(self, token_ids, attention_masks): + return token_ids, attention_masks + def get_next_state(self, state, next_token_ids): return 0 @@ -132,6 +138,9 @@ def sampler(biased_logits, *_): def test_sequence_generator_1d_several_iterations(): class MockFSM: + def align_prompt_tokens(self, token_ids, attention_masks): + return token_ids, attention_masks + def get_next_state(self, state, next_token_ids): return state + 1 @@ -194,6 +203,9 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_single_iteration(): class MockFSM: + def align_prompt_tokens(self, token_ids, attention_masks): + return token_ids, attention_masks + def get_next_state(self, state, next_token_ids): return 0 @@ -260,6 +272,9 @@ def sampler(biased_logits, *_): def test_sequence_generator_2d_several_iterations(): class MockFSM: + def align_prompt_tokens(self, token_ids, attention_masks): + return token_ids, attention_masks + def get_next_state(self, state, next_token_ids): return state + 1
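As a companion to the guide.py changes above, here is a minimal sketch of what the new `StopAtEOSGuide.states_to_token_maps` and `align_prompt_tokens` do, using the same toy tokenizer as `test_stop_at_eos_align_prompt_tokens`. It assumes a build of outlines that includes this patch; `ToyTokenizer` is an illustrative stand-in for the tokenizer interface, not the real class.

```python
import torch

from outlines.fsm.guide import StopAtEOSGuide


class ToyTokenizer:
    # Same toy vocabulary as in test_stop_at_eos_align_prompt_tokens.
    vocabulary = {"a": 1, "ab": 2, "b": 3, "eos": 4}
    eos_token_id = 4


guide = StopAtEOSGuide(ToyTokenizer())

# Before alignment: every token loops on the start state, except EOS which ends generation.
assert guide.states_to_token_maps == {0: {1: 0, 2: 0, 3: 0, 4: -1}}

# Aligning a prompt that ends with "a" (token 1) removes that token from the prompt and
# adds a state from which both "a" and "ab" are allowed, so "ab" can heal the boundary.
token_ids, masks = guide.align_prompt_tokens(torch.tensor([1]), torch.tensor([1]))
assert token_ids.shape[0] == 0 and masks.shape[0] == 0
assert guide.states_to_token_maps == {0: {1: 1, 2: 1}, 1: {1: 1, 2: 1, 3: 1, 4: -1}}
```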
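The heart of the alignment is crossing-token detection. A small sketch of `find_crossing_tokens` and `get_crossing_tokens_target_states` on the vocabulary used in the parametrized tests above; it assumes this patch is applied, and the expected values are the ones asserted in those tests.

```python
from outlines.fsm.guide import find_crossing_tokens, get_crossing_tokens_target_states

# Prompt "abc" is tokenized as [1, 2, 3]; the FSM allows "d" or "f" followed by "f",
# or "df" directly (state -1 is final).
vocabulary = {
    "a": 1, "b": 2, "c": 3, "cd": 4, "cde": 5, "bcd": 6,
    "bcdf": 7, "d": 8, "e": 9, "f": 10, "df": 11,
}
states_to_tokens_map = {0: {8: 1, 10: 1, 11: -1}, 1: {10: -1}}
prompt_token_ids = [1, 2, 3]

# Tokens whose text starts inside the prompt and extends past its end:
# "bcd"/"bcdf" could replace the suffix "bc", "cd"/"cde" could replace the suffix "c".
crossing = find_crossing_tokens(prompt_token_ids, vocabulary)
assert crossing == {1: [6, 7], 2: [4, 5]}

# Keep only the crossing tokens whose post-boundary characters are accepted by the FSM,
# together with the state they lead to ("cde" is rejected because "e" is not allowed).
valid = get_crossing_tokens_target_states(
    states_to_tokens_map, crossing, prompt_token_ids, vocabulary
)
assert valid == {1: {6: 1, 7: -1}, 2: {4: 1}}
```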
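Putting it together, `align_tokens_states_to_token_maps` crops the prompt tokens involved in a crossing and re-roots the FSM so those characters can be re-generated as part of (possibly larger) tokens. The sketch below reuses the data of `test_align_tokens_states_to_token_maps` above and, again, assumes this patch is applied.

```python
import torch

from outlines.fsm.guide import align_tokens_states_to_token_maps

vocabulary = {
    "a": 1, "b": 2, "c": 3, "cd": 4, "cde": 5, "bcd": 6,
    "bcdf": 7, "d": 8, "e": 9, "f": 10, "df": 11,
}
states_to_token_maps = {0: {8: 1, 10: 1, 11: -1}, 1: {10: -1}}

token_ids, attention_masks, new_map = align_tokens_states_to_token_maps(
    torch.tensor([1, 2, 3]), torch.tensor([1, 1, 1]), vocabulary, states_to_token_maps
)

# The prompt suffix "bc" (tokens 2 and 3) is cropped: only "a" (token 1) stays in the prompt.
assert token_ids.tolist() == [1] and attention_masks.tolist() == [1]

# The cropped tokens become new FSM states (0 and 3 below), from which either the original
# token or a valid crossing token ("bcd", "bcdf", "cd") can be generated.
assert new_map == {
    2: {8: 1, 10: 1, 11: -1},
    1: {10: -1},
    0: {2: 3, 6: 1, 7: -1},
    3: {3: 2, 4: 1},
}
```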
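On the generator side, `SequenceGenerator.align_prompt_tokens` has to left-pad the per-sequence results before stacking them, since alignment can crop a different number of tokens from each prompt in the batch. Below is a standalone sketch of just that padding step; the function name is illustrative, and pad token id 0 is the same assumption the patch makes.

```python
from typing import List, Tuple

import torch


def left_pad_and_stack(
    prompts: List[torch.Tensor], masks: List[torch.Tensor], pad_token_id: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Left-pad 1-D prompt/mask tensors to a common length and stack them into a batch."""
    max_length = max(prompt.shape[0] for prompt in prompts)
    padded_prompts = [
        torch.cat(
            [
                torch.full((max_length - prompt.shape[0],), pad_token_id, dtype=prompt.dtype),
                prompt,
            ]
        )
        for prompt in prompts
    ]
    padded_masks = [
        # Padded positions get attention mask 0 so the model ignores them.
        torch.cat([torch.zeros(max_length - mask.shape[0], dtype=mask.dtype), mask])
        for mask in masks
    ]
    return torch.stack(padded_prompts), torch.stack(padded_masks)


token_ids, attention_masks = left_pad_and_stack(
    [torch.tensor([5, 6, 7]), torch.tensor([8, 9])],
    [torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
)
assert token_ids.tolist() == [[5, 6, 7], [0, 8, 9]]
assert attention_masks.tolist() == [[1, 1, 1], [0, 1, 1]]
```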
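Finally, `get_generated_sequences` strips the decoded prompt from each decoded sequence by character count rather than by token count, which keeps the result correct even when alignment changed how the prompt boundary is tokenized. A minimal string-level sketch of the same idea (`strip_prompts` is a hypothetical standalone helper, not the method added here):

```python
from typing import List


def strip_prompts(prompt_texts: List[str], full_texts: List[str]) -> List[str]:
    """Return only the generated part of each decoded sequence."""
    return [text[len(prompt) :] for prompt, text in zip(prompt_texts, full_texts)]


assert strip_prompts(["Say: "], ["Say: hello<eos>"]) == ["hello<eos>"]
```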