From 0d68474cc0b4ce266988e41c6b06efa2ceb1b07f Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Sat, 21 Sep 2024 02:15:30 -0400 Subject: [PATCH] Implement AlignmentGuide --- outlines/fsm/guide.py | 242 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 242 insertions(+) diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index b7b121fe6..0d8992b06 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -90,6 +90,37 @@ def is_final_state(self, state: Any) -> bool: def copy(self) -> "Guide": ... + def accepts(self, token_ids: List[int], state=None) -> bool: + """ + Determine whether the sequence, `token_ids`, is accepted by the Guide. + `token_ids` doesn't need to complete the guide to be accepted. + """ + derived = self.derive(token_ids, state) + return derived is not None + + def derive(self, token_ids: List[int], state=None) -> Union["Guide", None]: + if state is None: + state = self.initial_state + for token_id in token_ids: + instruction = self.get_next_instruction(state) + + # determine if token_id allowed by instruction + if isinstance(instruction, Write): + raise NotImplementedError("TODO") + elif isinstance(instruction, Generate): + if ( + instruction.tokens is not None + and token_id not in instruction.tokens + ): + return None + else: + raise TypeError(f"Expected instruction, got {instruction}") + + # advance state + state = self.get_next_state(state, token_id) + + return state + class StopAtEOSGuide(Guide): """Guide to generate tokens until the EOS token has been generated.""" @@ -487,3 +518,214 @@ def must_terminate_state(self, state: CFGState) -> bool: def copy(self) -> "CFGGuide": """Create a copy of the Guide.""" return CFGGuide(self.cfg_string, self.tokenizer) + + +@cache() +def build_vocab_prefix_map(tokenizer: "Tokenizer") -> Dict[str, Set[Tuple[str, Tuple]]]: + """Build a map from token prefix to Set[Tuple[suffix, aligment_token_id, suffix_token_ids]]""" + + # precompute the token ids of all vocab suffixes + suffixes = list( + {tok[i:] for tok in tokenizer.vocabulary for i in range(1, len(tok))} + ) + encoded_suffixes, _ = tokenizer.encode(suffixes) + encoded_suffixes = [ + [tok for tok in seq_ids if tok != tokenizer.pad_token_id] + for seq_ids in encoded_suffixes.tolist() + ] + suffix_map = dict(zip(suffixes, map(tuple, encoded_suffixes))) + suffix_map[""] = tuple() + + # compute prefix-suffix map for all tokens, s.t. prefix + suffix = token + prefix_map = collections.defaultdict(set) + for token, token_id in tokenizer.vocabulary.items(): + for i in range(1, len(token) + 1): + prefix_map[token[:i]].add((token[i:], suffix_map[token[i:]])) + return prefix_map + + +AlignmentGuideState = collections.namedtuple( + "AlignmentGuideState", ["legal_path_map", "child_guide_state"] +) + + +class AlignmentGuide(Guide): + def __init__( + self, prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None + ): + """ + Initialize the AlignmentGuide with a prompt, tokenizer, and an optional child guide. + + Parameters + ---------- + prompt : str + The prompt text to be aligned with the generated tokens. + tokenizer : Tokenizer + Tokenizer used to align the prompt. + child_guide : Guide, optional + A guide to take control after alignment is complete. None -> Unconstrained after alignment + """ + self.prompt = prompt + self.tokenizer = tokenizer + self.child_guide = child_guide + + alignment_seqs, child_guide_ids = self._get_alignment_sequences( + prompt, tokenizer, child_guide + ) + alignment_prompt_ids, common_prompt_len = self._get_longest_common_prompt_ids( + alignment_seqs + ) + + self.alignment_prompt = self.tokenizer.decode( + [alignment_seqs[0, :common_prompt_len]] + )[0] + + # calculate map of alignment_prompt continuation tokens -> child_guide advancement tokens + legal_paths = [ + tuple([t for t in seq if t != tokenizer.pad_token_id]) + for seq in alignment_seqs[:, common_prompt_len:].tolist() + ] + legal_path_map = dict(zip(legal_paths, child_guide_ids)) + + self.initial_state = AlignmentGuideState( + legal_path_map=legal_path_map, child_guide_state=None + ) + + @staticmethod + def _get_alignment_sequences( + prompt: str, tokenizer: "Tokenizer", child_guide: Optional[Guide] = None + ): + """ + Calculate all possible sequences which are valid with a prompt + child_guide + E.g. prompt="hello wo", child guide accepts "rld" -> tokenization ["hello", "world"] is valid + + Returns tuple of (alignment_seqs, child_guide_ids) of same length + - alignment_seqs: + All token sequences which can represent `prompt` + start of generation. The last token + must represent the end of the prompt can extend beyond the prompt to start generation. + Sequences are only included if the start of generation portion is legal with child guide. + - child_guide_ids: + Token to send to the child guide to simulate the start of generation. In the example above + "world" is the last alignment seq token, therefore we must advance the state of the child + guide with the tokenization of "rld" in order to continue generation with the child guide. + """ + guide_accepts: Dict[ + Tuple[int], bool + ] = {} # cache of suffix acceptance for child_guide.accepts() + + # prompts with alignment tokens at end + aligned_prompt_completions: List[str] = [] + # tokens to feed child guide once alignment completes + child_guide_ids: List[Tuple] = [] + + # compute alignment seqs which are valid with prompt and child guide + for prefix, alignment_details in build_vocab_prefix_map(tokenizer).items(): + if prompt.endswith(prefix): + for suffix, suffix_ids in alignment_details: + if child_guide is None: + aligned_prompt_completions.append(prompt + suffix) + child_guide_ids.append(tuple()) + elif guide_accepts.setdefault( + suffix_ids, child_guide.accepts(suffix_ids) + ): + aligned_prompt_completions.append(prompt + suffix) + child_guide_ids.append(suffix_ids) + + alignment_seqs, _ = tokenizer.encode(aligned_prompt_completions) + return alignment_seqs, child_guide_ids + + @staticmethod + def _get_longest_common_prompt_ids(alignment_seqs): + """ + Among all candidate prompt alignment seqs, get the longest shared prefix and their length + """ + # get longest common prefix among alignment sequences, which will form our alignment prompt + common = ( + (alignment_seqs.unsqueeze(1) == alignment_seqs.unsqueeze(0)) + .all(0) + .cumprod(1) + ) + common_len = common.sum(1).max().item() + return alignment_seqs[0, :common_len], common_len + + def get_next_instruction(self, state: AlignmentGuideState) -> Instruction: + """ + Return the next set of valid tokens for generation based on the current state. + + If alignment hasn't completed: + tokens which continue one of the candidate alignment paths are legal + If alignment has completed: + get instruction from the child guide + """ + if state.legal_path_map is not None: + return Generate( + sorted({token_ids[0] for token_ids in state.legal_path_map.keys()}) + ) + elif self.child_guide is None: + return Generate(None) + else: + return self.child_guide.get_next_instruction(state.child_guide_state) + + def get_next_state( + self, state: AlignmentGuideState, token_id: int + ) -> AlignmentGuideState: + """ + Get AlignmentGuideState advanced by token ID. + + If alignment has completed: + get instruction from the child guide + If alignment hasn't completed: + Filter out alignment paths which don't start with token_id + Remove First token from remaining paths + If advancing the state completes alignment: + Advance the child_guide state + """ + if state.legal_path_map is None: + if self.child_guide is not None: + return AlignmentGuideState( + legal_path_map=None, + child_guide_state=self.child_guide.get_next_state( + state.child_guide_state, token_id + ), + ) + else: + return AlignmentGuideState(None, None) + else: + next_state_legal_path_map = { + key[1:]: value + for key, value in state.legal_path_map.items() + if key[0] == token_id + } + # if none remaining, advance the child guide + if not any(next_state_legal_path_map): + if self.child_guide is not None: + child_guide_advancement_ids = next( + iter(next_state_legal_path_map.values()) + ) + return AlignmentGuideState( + legal_path_map=None, + child_guide_state=self.child_guide.derive( + child_guide_advancement_ids, state.child_guide_state + ), + ) + else: + return AlignmentGuideState(None, None) + + # if paths remaining, return advanced legal_path_map + else: + return AlignmentGuideState( + legal_path_map=next_state_legal_path_map, + child_guide_state=state.child_guide_state, + ) + + def is_final_state(self, state: AlignmentGuideState) -> bool: + if state.legal_path_map is not None: + return False + elif self.child_guide is None: + return True + else: + return self.child_guide.is_final_state(state.child_guide_state) + + def copy(self): + """AlignmentGuide isn't mutated""" + return self