From b3f415e5b7ca6e1f2523e3bdd17ca3cb932ef047 Mon Sep 17 00:00:00 2001
From: Robin Picard
Date: Fri, 19 Jul 2024 15:24:50 +0200
Subject: [PATCH] Implement prompt token alignment in FSMLogitsProcessor

---
 outlines/generate/api.py          | 124 ++++++++++++++++++++++++++++--
 outlines/processors/structured.py |  49 ++++++++++--
 2 files changed, 157 insertions(+), 16 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 4104e3080..9fda34f0d 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -1,16 +1,16 @@
 import datetime
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Iterator, List, Optional, Union
+from typing import Iterator, List, Optional, Sequence, Union
+
+import torch
 
 from outlines.generate.generator import sequence_generator
 from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler
 
-if TYPE_CHECKING:
-    import torch
-
 FormattedOutput = Union[
     str, int, float, bool, datetime.date, datetime.time, datetime.datetime
 ]
+TotalCompletionsType = Optional[Union[List[str], str]]
 
 
 class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(
 
         return generation_params
 
+    def strip_completions(
+        self,
+        completions,
+        prompts: Union[str, List[str]],
+        aligned_prompts: Union[str, List[str]],
+    ):
+        """Remove characters generated through token alignment from the completions.
+
+        As token alignment makes the model re-generate some of the characters at
+        the end of the prompt, we want to remove those from the beginning of the
+        completions to only return the characters after the end of the user prompts.
+
+        Parameters
+        ----------
+        completions
+            Text generated by the model
+        prompts
+            The original prompts provided by the user
+        aligned_prompts
+            The prompts of the user after token alignment (what's given to the model)
+
+        Returns
+        -------
+        The stripped completions
+        """
+        if isinstance(prompts, str):
+            if isinstance(completions, str):
+                return completions[len(prompts) - len(aligned_prompts) :]
+
+            return [
+                self.strip_completions(completion, prompts, aligned_prompts)
+                for completion in completions
+            ]
+
+        return [
+            self.strip_completions(completion, prompt, aligned_prompt)
+            for completion, prompt, aligned_prompt in zip(
+                completions, prompts, aligned_prompts
+            )
+        ]
+
     def format_sequence(self, sequence: str) -> FormattedOutput:
         """Translate the generated sequence to another type.
 
@@ -500,15 +541,21 @@ def format(sequences):
             max_tokens, stop_at, seed
         )
 
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
         completions = self.model.generate(
-            prompts,
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
         )
 
-        return format(completions)
+        stripped_completions = self.strip_completions(
+            completions, prompts, aligned_prompts
+        )
+
+        return format(stripped_completions)
 
     def stream(
         self,
@@ -519,13 +566,72 @@
         **model_specific_params,
     ):
         """Return a text generator from a prompt or a list of prompts."""
+
+        def add_chunks_to_completions(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            total_completions: Optional[
+                Union[str, List[str], List[List[str]], Sequence[str]]
+            ],
+        ):
+            """Append each of the text chunks to the end of the corresponding completion."""
+            if isinstance(text_chunks, str):
+                if isinstance(total_completions, str):
+                    return total_completions + text_chunks
+                return text_chunks
+
+            if total_completions:
+                return [
+                    add_chunks_to_completions(text_chunk, total_completion)
+                    for text_chunk, total_completion in zip(
+                        text_chunks, total_completions
+                    )
+                ]
+
+            return [
+                add_chunks_to_completions(text_chunk, None)
+                for text_chunk in text_chunks
+            ]
+
+        def strip_text_chunks(
+            text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
+            stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
+        ):
+            """Get the stripped text_chunks from the stripped_completions."""
+            if isinstance(text_chunks, str):
+                return (
+                    stripped_completions[-len(text_chunks) :]
+                    if len(text_chunks) > 0
+                    else ""
+                )
+
+            return [
+                strip_text_chunks(text_chunk, stripped_completion)
+                for text_chunk, stripped_completion in zip(
+                    text_chunks, stripped_completions
+                )
+            ]
+
         generation_params = self.prepare_generation_parameters(
             max_tokens, stop_at, seed
         )
-        return self.model.stream(
-            prompts,
+
+        aligned_prompts = self.logits_processor.align_prompts(prompts)
+
+        total_completions: TotalCompletionsType = None
+
+        for text_chunks in self.model.stream(
+            aligned_prompts,
             generation_params,
             self.logits_processor,
             self.sampling_params,
             **model_specific_params,
-        )
+        ):
+            total_completions = add_chunks_to_completions(
+                text_chunks, total_completions
+            )
+
+            stripped_completions = self.strip_completions(
+                total_completions, prompts, aligned_prompts
+            )
+
+            yield strip_text_chunks(text_chunks, stripped_completions)
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index bf50d8813..1b5208c27 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -61,8 +61,9 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
             The finite state machine which is used to bias the logits.
         """
         self.tokenizer = tokenizer
-        self._fsm_states: Dict[int, int] = {}
+        self._fsm_states: List[Dict[int, int]] = []
         self.fsm: Guide = fsm
+        self._seq_fsms: List[Guide] = []
         self._is_first_token = True
         self._seq_start_idx: Optional[int] = None
 
@@ -83,33 +84,67 @@ def process_logits(
         torch.Tensor
             The biased logits.
""" + samples = int(len(input_ids) / len(self._seq_fsms)) sequence_states: List[int] = [] # vector of states corresponding to `input_ids` if self._is_first_token: self._is_first_token = False self._seq_start_idx = len(input_ids[0]) - self._fsm_states = {hash(tuple([])): 0} + self._fsm_states = [ + {hash(tuple([])): 0} for _ in range(len(self._seq_fsms)) + ] sequence_states = [0] * len(input_ids) else: - for seq_ids in input_ids: + for i, seq_ids in enumerate(input_ids): prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1])) - prev_state = self._fsm_states[prev_state_key] + prev_state = self._fsm_states[i // samples][prev_state_key] curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :])) - curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1]) + curr_state = self._seq_fsms[i // samples].get_next_state( + prev_state, seq_ids[-1] + ) - self._fsm_states[curr_state_key] = curr_state + self._fsm_states[i // samples][curr_state_key] = curr_state sequence_states.append(curr_state) mask = torch.full_like(logits, -math.inf) for i, fsm_state in enumerate(sequence_states): - allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens + allowed_tokens = ( + self._seq_fsms[i // samples].get_next_instruction(fsm_state).tokens + ) mask[i, allowed_tokens] = logits[i, allowed_tokens] return mask + def align_prompts(self, prompts: Union[str, List[str]]) -> Union[str, List[str]]: + """Create a distinct fsm for each prompt. Apply prompt alignment to each of them. + If applicable, prompt alignment shortens the user prompt and updates the fsm accordingly. + + Parameters + ---------- + prompts + The text prompts previded by the user + + Returns + ------- + The initial text prompts after application of prompt alignment + """ + is_input_str = isinstance(prompts, str) + if isinstance(prompts, str): + prompts = [prompts] + + self._seq_fsms = [self.fsm.copy() for _ in range(len(prompts))] + aligned_prompts = [ + fsm.align_prompt_tokens(prompt, self.tokenizer) + for fsm, prompt in zip(self._seq_fsms, prompts) + ] + + if is_input_str: + return aligned_prompts[0] + return aligned_prompts + def copy(self) -> "FSMLogitsProcessor": """Return a copy of the logits processor.""" return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())