Implement prompt token alignment in FSMLogitsProcessor
RobinPicard committed Jul 19, 2024
1 parent 91c7b3d commit b3f415e
Showing 2 changed files with 159 additions and 15 deletions.
125 changes: 117 additions & 8 deletions outlines/generate/api.py
@@ -1,16 +1,16 @@
import datetime
from dataclasses import dataclass
from typing import TYPE_CHECKING, Iterator, List, Optional, Union
from typing import Iterator, List, Optional, Sequence, Union

import torch

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
import torch

FormattedOutput = Union[
str, int, float, bool, datetime.date, datetime.time, datetime.datetime
]
TotalCompletionsType = Optional[Union[list[str], str]]


class SequenceGenerator:
@@ -461,6 +461,47 @@ def prepare_generation_parameters(

return generation_params

def strip_completions(
self,
completions,
prompts: Union[str, List[str]],
aligned_prompts: Union[str, List[str]],
):
"""Remove characters generated through token alignment from the completions.
As token alignment makes the model re-generate some of the characters at
the end of the prompt, we want to remove those from the beginning of the
completions to only return the characters after the end of the user prompts.
Parameters
----------
completions
Text generated by the model
prompts
The original prompts provided by the user
aligned_prompts
The prompts of the user after token alignment (what's given to the model)
Returns
-------
The stripped completions
"""
if isinstance(prompts, str):
if isinstance(completions, str):
return completions[len(prompts) - len(aligned_prompts) :]

return [
self.strip_completions(completion, prompts, aligned_prompts)
for completion in completions
]

return [
self.strip_completions(completion, prompt, aligned_prompt)
for completion, prompt, aligned_prompt in zip(
completions, prompts, aligned_prompts
)
]
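
For intuition, the slice completions[len(prompts) - len(aligned_prompts):] drops exactly the characters that token alignment removed from the end of the prompt and that the model then re-generated. A minimal standalone sketch, with made-up prompt and completion strings:

# Hypothetical values for illustration only.
prompt = "The capital of France is "      # original user prompt
aligned_prompt = "The capital of France"  # prompt after token alignment
completion = " is Paris."                 # the model re-generates " is " before continuing

offset = len(prompt) - len(aligned_prompt)  # 4 re-generated characters
stripped = completion[offset:]
assert stripped == "Paris."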

def format_sequence(self, sequence: str) -> FormattedOutput:
"""Translate the generated sequence to another type.
@@ -500,15 +541,24 @@ def format(sequences):
max_tokens, stop_at, seed
)

aligned_prompts = self.logits_processor.align_prompts(prompts)

completions = self.model.generate(
prompts,
aligned_prompts,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)

return format(completions)
stripped_completions = self.strip_completions(
completions, prompts, aligned_prompts
)

return format(stripped_completions)
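
When the user passes a list of prompts and the sampler draws several samples per prompt, strip_completions recurses over the nested lists so that every sample is stripped against its own prompt. A made-up example of the shapes involved:

# Hypothetical shapes: two prompts, two samples per prompt.
prompts = ["Answer: ", "Result: "]
aligned_prompts = ["Answer:", "Result:"]               # trailing space trimmed by alignment
raw_completions = [[" yes", " no"], [" ok", " fine"]]   # each sample re-generates the space

offsets = [len(p) - len(a) for p, a in zip(prompts, aligned_prompts)]  # [1, 1]
stripped = [
    [sample[offset:] for sample in samples]
    for samples, offset in zip(raw_completions, offsets)
]
assert stripped == [["yes", "no"], ["ok", "fine"]]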

def stream(
self,
@@ -519,13 +569,72 @@ def stream(
**model_specific_params,
):
"""Return a text generator from a prompt or a list of prompts."""

def add_chunks_to_completions(
text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
total_completions: Optional[
Union[str, List[str], List[List[str]], Sequence[str]]
],
):
"""Append each of the text chunks at the end of the corresponding completions"""
if isinstance(text_chunks, str):
if isinstance(total_completions, str):
return total_completions + text_chunks
return text_chunks

if total_completions:
return [
add_chunks_to_completions(text_chunk, total_completion)
for text_chunk, total_completion in zip(
text_chunks, total_completions
)
]

return [
add_chunks_to_completions(text_chunk, None)
for text_chunk in text_chunks
]

def strip_text_chunks(
text_chunks: Union[str, List[str], List[List[str]], Sequence[str]],
stripped_completions: Union[str, List[str], List[List[str]], Sequence[str]],
):
"""Get the stripped text_chunks from the stripped_completions."""
if isinstance(text_chunks, str):
return (
stripped_completions[-len(text_chunks) :]
if len(text_chunks) > 0
else ""
)

return [
strip_text_chunks(text_chunk, stripped_completion)
for text_chunk, stripped_completion in zip(
text_chunks, stripped_completions
)
]

generation_params = self.prepare_generation_parameters(
max_tokens, stop_at, seed
)
return self.model.stream(

aligned_prompts = self.logits_processor.align_prompts(prompts)

total_completions: TotalCompletionsType = None

for text_chunks in self.model.stream(
prompts,
generation_params,
self.logits_processor,
self.sampling_params,
**model_specific_params,
)
):
total_completions = add_chunks_to_completions(
text_chunks, total_completions
)

stripped_completions = self.strip_completions(
total_completions, prompts, aligned_prompts
)

yield strip_text_chunks(text_chunks, stripped_completions)
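
In streaming mode the characters re-generated through token alignment arrive inside the first chunks, so the generator keeps a running total of the completions, strips it against the original prompts, and yields only the portion of the stripped text that corresponds to the latest chunk. A simplified single-prompt walk-through with invented chunks:

# Hypothetical stream for a single prompt.
prompt = "The capital of France is "
aligned_prompt = "The capital of France"
chunks = [" is", " Par", "is."]              # raw chunks produced by the model

total = ""
offset = len(prompt) - len(aligned_prompt)   # 4 re-generated characters
for chunk in chunks:
    total += chunk
    stripped_total = total[offset:]
    # yield at most len(chunk) characters from the end of the stripped completion
    visible = stripped_total[-len(chunk):] if chunk else ""
    print(repr(visible))
# prints '', 'Par', 'is.' -> the visible stream adds up to "Paris."
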
49 changes: 42 additions & 7 deletions outlines/processors/structured.py
@@ -61,8 +61,9 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
The finite state machine which is used to bias the logits.
"""
self.tokenizer = tokenizer
self._fsm_states: Dict[int, int] = {}
self._fsm_states: List[Dict[int, int]] = []
self.fsm: Guide = fsm
self._seq_fsms: List[Guide] = []
self._is_first_token = True
self._seq_start_idx: Optional[int] = None

@@ -83,33 +84,67 @@ def process_logits(
torch.Tensor
The biased logits.
"""
samples = int(len(input_ids) / len(self._seq_fsms))
sequence_states: List[int] = [] # vector of states corresponding to `input_ids`

if self._is_first_token:
self._is_first_token = False
self._seq_start_idx = len(input_ids[0])

self._fsm_states = {hash(tuple([])): 0}
self._fsm_states = [
{hash(tuple([])): 0} for _ in range(len(self._seq_fsms))
]
sequence_states = [0] * len(input_ids)

else:
for seq_ids in input_ids:
for i, seq_ids in enumerate(input_ids):
prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
prev_state = self._fsm_states[prev_state_key]
prev_state = self._fsm_states[i // samples][prev_state_key]

curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])
curr_state = self._seq_fsms[i // samples].get_next_state(
prev_state, seq_ids[-1]
)

self._fsm_states[curr_state_key] = curr_state
self._fsm_states[i // samples][curr_state_key] = curr_state
sequence_states.append(curr_state)

mask = torch.full_like(logits, -math.inf)
for i, fsm_state in enumerate(sequence_states):
allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
allowed_tokens = (
self._seq_fsms[i // samples].get_next_instruction(fsm_state).tokens
)
mask[i, allowed_tokens] = logits[i, allowed_tokens]

return mask
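
Because all samples drawn from the same prompt share one FSM, the index i // samples maps each row of input_ids back to the FSM of the prompt it was generated from: rows are laid out prompt by prompt, with `samples` consecutive rows per prompt. A small sanity check with invented sizes:

# Hypothetical layout: 2 prompts, 3 samples each -> 6 rows of input_ids.
num_fsms = 2
num_rows = 6
samples = num_rows // num_fsms             # 3 samples per prompt
fsm_index = [i // samples for i in range(num_rows)]
assert fsm_index == [0, 0, 0, 1, 1, 1]     # rows 0-2 use FSM 0, rows 3-5 use FSM 1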

def align_prompts(self, prompts: Union[str, List[str]]) -> Union[str, List[str]]:
"""Create a distinct fsm for each prompt. Apply prompt alignment to each of them.
If applicable, prompt alignment shortens the user prompt and updates the fsm accordingly.
Parameters
----------
prompts
The text prompts previded by the user
Returns
-------
The initial text prompts after application of prompt alignment
"""
is_input_str = isinstance(prompts, str)
if isinstance(prompts, str):
prompts = [prompts]

self._seq_fsms = [self.fsm.copy() for _ in range(len(prompts))]
aligned_prompts = [
fsm.align_prompt_tokens(prompt, self.tokenizer)
for fsm, prompt in zip(self._seq_fsms, prompts)
]

if is_input_str:
return aligned_prompts[0]
return aligned_prompts
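
The single-string case is handled by wrapping the prompt in a list, aligning, then unwrapping, so _seq_fsms always ends up holding one Guide per prompt even when only one prompt was given. A standalone sketch of the same pattern, where the made-up align_one function stands in for Guide.align_prompt_tokens (here it simply trims a trailing space; the real alignment depends on the FSM and the tokenizer vocabulary):

from typing import List, Union

def align_one(prompt: str) -> str:
    # Stand-in for fsm.align_prompt_tokens(prompt, tokenizer).
    return prompt.rstrip(" ")

def align_prompts(prompts: Union[str, List[str]]) -> Union[str, List[str]]:
    is_input_str = isinstance(prompts, str)
    if is_input_str:
        prompts = [prompts]
    aligned = [align_one(p) for p in prompts]  # one FSM copy per prompt in the real code
    return aligned[0] if is_input_str else aligned

assert align_prompts("Answer: ") == "Answer:"
assert align_prompts(["A: ", "B: "]) == ["A:", "B:"]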

def copy(self) -> "FSMLogitsProcessor":
"""Return a copy of the logits processor."""
return FSMLogitsProcessor(tokenizer=self.tokenizer, fsm=self.fsm.copy())
