Slow Index Building #226

Closed
clayscode opened this issue Aug 15, 2023 · 7 comments · Fixed by #272
@clayscode

from typing import List
from enum import Enum
from pydantic import BaseModel, constr

import outlines.models as models
import outlines.text.generate as generate
import torch

class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


model = models.transformers("gpt2", device=torch.device('cuda'))
sequence = generate.json(model, Character)("Give me a character description")
print(sequence)
# {
#   "name": "ranbelt",
#   "age": 26,
#   "armor": "chainmail",
#   "weapon": "bow",
#   "strength": 5
# }

parsed = Character.model_validate_json(sequence)
print(parsed)

The above code snippet runs very slowly on my machine. I have a 7900 XTX, and before I added device=torch.device('cuda') it was defaulting to CPU-only inference. It seems to be running on my GPU now, but inference is still very slow (the above takes ~30 seconds to run). This could just be a ROCm thing, I'm not entirely sure. I've installed both PyTorch and TensorFlow with ROCm support. Any idea what might be going on?
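
For reference, a quick generic sanity check (not outlines-specific) that the ROCm build of PyTorch actually sees the GPU; on ROCm builds the torch.cuda API is backed by HIP, so these calls apply to the 7900 XTX as well:

import torch

print(torch.cuda.is_available())      # should be True if the GPU is usable
print(torch.version.hip)              # non-None on a ROCm build, None on a CUDA build
print(torch.cuda.get_device_name(0))  # should report the AMD GPU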

@sumitmamoria

@clayscode: I am attempting to do the same with an RTX 3090, but I am getting errors like "RuntimeError: CUDA error: device-side assert triggered". Did you have to do anything special to make it work with your GPU?
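
A generic PyTorch/CUDA debugging tip, not specific to outlines: "device-side assert triggered" errors are reported asynchronously, so forcing synchronous kernel launches usually makes the traceback point at the real call site. This must happen before the first CUDA operation in the process:

import os

# Make kernel launches synchronous so the failing call is the one reported.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"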

@rlouf
Member

rlouf commented Aug 15, 2023

Two possibilities:

1. The FSM takes a long time to build; we can address that more or less easily.
2. That's a KV cache thing, see #190. We've been laser-focused on guided generation for now. This only requires a small interface change.

It would be amazing if you could run line_profiler on the code and come back to me with the results.
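
For reference, one minimal way to collect these line-level timings, assuming the snippet above is saved as test.py (kernprof ships with line_profiler and injects the profile decorator into builtins at runtime):

# Wrap the snippet's entry point, then run:  kernprof -l -v test.py
@profile
def main():
    model = models.transformers("gpt2", device=torch.device("cuda"))
    sequence = generate.json(model, Character)("Give me a character description")
    print(sequence)


if __name__ == "__main__":
    main()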

@rlouf self-assigned this on Aug 15, 2023
@sumitmamoria

@rlouf: Using "model = models.transformers("gpt2", device=torch.device("cuda"))", I am able to run the second example given in the README, but the first example throws the error "RuntimeError: CUDA error: device-side assert triggered".

@clayscode
Author

Two possibilities:

1. The FSM takes a long time to build; we can address that more or less easily.
2. That's a KV cache thing, see #190. We've been laser-focused on guided generation for now. This only requires a small interface change.

It would be amazing if you could run line_profiler on the code and come back to me with the results.

Timer unit: 1e-06 s

Total time: 35.5477 s
File: test.py
Function: main at line 31

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    31                                           @profile
    32                                           def main():
    33         1    2222021.2 2222021.2      6.3      model = models.transformers("gpt2", device=torch.device('cuda'))
    34         1   33325550.6 33325550.6     93.7      sequence = generate.json(model, Character)("Give me a character description")
    35         1         22.4     22.4      0.0      print(sequence)
    36                                               # {
    37                                               #   "name": "ranbelt",
    38                                               #   "age": 26,
    39                                               #   "armor": "chainmail",
    40                                               #   "weapon": "bow",
    41                                               #   "strength": 5
    42                                               # }
    43                                           
    44         1         52.3     52.3      0.0      parsed = Character.model_validate_json(sequence)
    45         1         36.8     36.8      0.0      print(parsed)

@rlouf
Member

rlouf commented Aug 16, 2023

Thank you! Could you add outlines.text.generate.Sequence.step and outlines.text.parsing.map_partial_states_to_vocab to the list of functions tracked by line_profiler? That will probably give us enough information to conclude.

@rlouf
Member

rlouf commented Aug 16, 2023

I ran the profiling with the following code:

from enum import Enum
from pydantic import BaseModel, constr

import outlines
import outlines.models as models
import outlines.text.generate as generate

import line_profiler


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int

def fn():
    model = models.transformers("gpt2")
    sequence = generate.json(model, Character)("Give me a character description")
    print(sequence)


profile = line_profiler.LineProfiler()
profile.add_function(outlines.text.generate.sequence.Sequence.step)
profile.add_function(outlines.text.parsing.map_partial_states_to_vocab)
profile.add_function(outlines.text.parsing.find_partial_matches)
profile(fn)()
profile.print_stats()

And here are the results:

Profiling stats:

{
  "name": "ranbelt",
  "age": 26,
  "armor": "chainmail",
  "weapon": "bow",
  "strength": 5
}

Timer unit: 1e-09 s

Total time: 38.6415 s
File: /home/remil/org/test.py
Function: fn at line 35

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    35                                           def fn():
    36         1 7305418488.0 7305418488.0     18.9      model = models.transformers("gpt2")
    37         1 31336078953.0 31336078953.0     81.1      sequence = generate.json(model, Character)("Give me a character description")
    38         1      27796.0  27796.0      0.0      print(sequence)

Total time: 2.83256 s
File: /home/remil/projects/normal/outlines/outlines/text/generate/sequence.py
Function: step at line 43

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    43                                               def step(
    44                                                   self,
    45                                                   rng: torch.Generator,
    46                                                   num_prompt_tokens: int,
    47                                                   token_ids: torch.LongTensor,
    48                                                   attention_mask: torch.LongTensor,
    49                                                   samples: int = 1,
    50                                               ) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    51                                                   """Generate one or several tokens that complete the input sequence.
    52                                           
    53                                                   The sampling step consists in using a model to generate next-token
    54                                                   logits and then sample `samples`-many new tokens from a categorical
    55                                                   distribution parametrized by these logits.
    56                                           
    57                                                   Parameters
    58                                                   ----------
    59                                                   rng
    60                                                       NumPy random number Generator instance
    61                                                   num_prompt_tokens
    62                                                       The number of tokens in the prompt.
    63                                                   token_ids
    64                                                       The token ids passed as an input to the model, of shape `batch_shape
    65                                                       + (num_tokens,)`, where `num_tokens` is the sequences' length.
    66                                                   samples
    67                                                       The number of continuations to sample from the next-token probability
    68                                                       distribution.
    69                                           
    70                                                   Returns
    71                                                   -------
    72                                                   A tuple with an array of shape `new_batch_shape + (num_tokens+1,)`that
    73                                                   contains the completed sequences (input token ids and generated token
    74                                                   ids) and an array of shape `new_batch_shape + (vocab_size,)` that
    75                                                   contains the next token probabilities.
    76                                                   `new_batch_shape` is computed by removing dimensions of size one in
    77                                                   `(samples,) + batch_shape`.
    78                                           
    79                                                   """
    80        63      85042.0   1349.9      0.0          num_input_dims = token_ids.ndim
    81        63 2768696295.0 43947560.2     97.7          probs = self.model(token_ids, attention_mask)
    82        63   43774295.0 694830.1      1.5          probs = self.create_proposal(token_ids[:, num_prompt_tokens:], probs)
    83        63    4682866.0  74331.2      0.2          probs = torch.nn.functional.softmax(probs, dim=-1)
    84                                           
    85                                                   # Sample `samples`-many new tokens
    86        63   12802130.0 203208.4      0.5          next_token_ids = vectorized_random_choice(rng, probs, samples)
    87                                           
    88                                                   # Add the missing `num_tokens` and `num_sample` dimensions
    89        63     130407.0   2070.0      0.0          next_token_ids = torch.unsqueeze(next_token_ids, -1)
    90        63      80579.0   1279.0      0.0          token_ids = torch.unsqueeze(token_ids, 0)
    91                                           
    92                                                   # Expand the input `token_ids` array to be able to concatenate several
    93                                                   # samples.
    94        63      22136.0    351.4      0.0          if samples > 1:
    95                                                       repetitions = (samples,) + (1,) * num_input_dims
    96                                                       token_ids = torch.tile(token_ids, repetitions)
    97                                                       probs = torch.tile(probs, repetitions)
    98                                           
    99        63     820600.0  13025.4      0.0          token_ids = torch.concatenate([token_ids, next_token_ids], axis=-1)
   100                                           
   101                                                   # Merge sample and batch dimensions by removing dimensions of length
   102                                                   # 1. The shape of the resulting arrays is `new_batch_shape + (num_tokens,)`
   103                                                   # and `new_batch_shape + (vocab_size,)` respectively.
   104        63    1176611.0  18676.4      0.0          token_ids = torch.atleast_2d(token_ids.squeeze())
   105        63     277306.0   4401.7      0.0          probs = torch.atleast_2d(probs.squeeze())
   106                                           
   107        63      16207.0    257.3      0.0          return token_ids, probs

Total time: 23.5358 s
File: /home/remil/projects/normal/outlines/outlines/text/parsing.py
Function: find_partial_matches at line 272

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   272                                           def find_partial_matches(
   273                                               fsm: FSM, input_string: str, start_state: Optional[int] = None
   274                                           ) -> Set[Tuple[Optional[int], Tuple[int, ...]]]:
   275                                               """Find the states in the finite state machine `fsm` that accept `input_string`.
   276                                           
   277                                               This will consider all possible states in the finite state machine (FSM)
   278                                               that accept the beginning of `input_string` as starting points, unless a
   279                                               specific `start_state` is provided.
   280                                           
   281                                               Parameters
   282                                               ----------
   283                                               fsm
   284                                                   The finite state machine.
   285                                               input_string
   286                                                   The string for which we generate partial matches.
   287                                               start_state
   288                                                   A single fixed starting state to consider.  For example, if this value
   289                                                   is set to `fsm.initial`, it attempt to read `input_string` from the
   290                                                   beginning of the FSM/regular expression.
   291                                           
   292                                               Returns
   293                                               -------
   294                                               A set of tuples corresponding to each valid starting state in the FSM.  The
   295                                               first element of each tuple contains either ``None`` or an integer
   296                                               indicating the position in `input_string` at which the FSM terminated.  The
   297                                               second element is the tuple of states visited during execution of the FSM
   298                                               plus the next, unvisited transition state.
   299                                           
   300                                               """
   301     43345   23052608.0    531.8      0.1      if len(input_string) == 0 or input_string[0] not in fsm.alphabet:
   302      6974    1059851.0    152.0      0.0          return set()
   303                                           
   304     43345   17966036.0    414.5      0.1      trans_key = fsm.alphabet[input_string[0]]
   305                                           
   306                                               # TODO: We could probably reuse parts of the computed paths when computing
   307                                               # results for multiple starting points.
   308     43345   10812104.0    249.4      0.0      def _partial_match(
   309     43345   26852179.0    619.5      0.1          trans: Dict[int, int]
   310     43345  114723519.0   2646.8      0.5      ) -> Tuple[Optional[int], Optional[Tuple[int, ...]]]:
   311                                                   fsm_map = ChainMap({fsm.initial: trans}, fsm.map)
   312                                                   state = fsm.initial
   313                                                   accepted_states: Tuple[int, ...] = ()
   314                                           
   315                                                   for i, symbol in enumerate(input_string):
   316                                                       if anything_else in fsm.alphabet and symbol not in fsm.alphabet:
   317                                                           symbol = anything_else
   318                                           
   319                                                       trans_key = fsm.alphabet[symbol]
   320                                           
   321                                                       if not (state in fsm_map and trans_key in fsm_map[state]):
   322                                                           if state in fsm.finals:
   323                                                               i -= 1
   324                                                               break
   325                                                           return None, None
   326                                           
   327                                                       state = fsm_map[state][trans_key]
   328                                           
   329                                                       accepted_states += (state,)
   330                                           
   331                                                   terminated = state in fsm.finals
   332                                                   if not terminated and state == fsm.initial:
   333                                                       return None, None
   334                                           
   335                                                   return None if not terminated else i, accepted_states
   336                                           
   337     43345    6550933.0    151.1      0.0      res = set()
   338     43345    4028705.0     92.9      0.0      transition_maps = (
   339     43345    8301108.0    191.5      0.0          fsm.map if start_state is None else {start_state: fsm.map[start_state]}
   340                                               )
   341   6319380  788763081.0    124.8      3.4      for state, trans in transition_maps.items():
   342   4601713  609322267.0    132.4      2.6          if trans_key in trans:
   343   1717667 21599535414.0  12574.9     91.8              n_matched, path = _partial_match(trans)
   344   1255110  176131545.0    140.3      0.7              if path is not None:
   345    462557  144425285.0    312.2      0.6                  res.add((n_matched, (state,) + path))
   346                                           
   347     43345    4310811.0     99.5      0.0      return res

Total time: 23.5642 s
File: /home/remil/projects/normal/outlines/outlines/text/parsing.py
Function: map_partial_states_to_vocab at line 367

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   367                                           def map_partial_states_to_vocab(
   368                                               vocabulary: Iterable[str],
   369                                               terminals_to_fsms_map: Dict[str, FSM],
   370                                               partial_match_filter: Callable[
   371                                                   [str, Optional[int], Tuple[int, ...]], bool
   372                                               ] = lambda *args: True,
   373                                               final_state_string: Optional[str] = None,
   374                                           ) -> Tuple[
   375                                               DefaultDict[PartialParseState, Set[int]], Dict[str, DefaultDict[int, Set[int]]]
   376                                           ]:
   377                                               """Construct a map from partial parse states to subsets of `vocabulary`.
   378                                           
   379                                               The subsets of `vocabulary` consist of elements that are accepted by--or
   380                                               transition to--the corresponding partial parse states.
   381                                           
   382                                               Parameters
   383                                               ----------
   384                                               vocabulary
   385                                                   The vocabulary composed of strings.
   386                                               terminals_to_fsms_map
   387                                                   Terminal symbol names mapped to FSMs, as provided by `terminals_to_fsms`.
   388                                               partial_match_filter
   389                                                   A callable that determines which partial matches to keep.  The first
   390                                                   argument is the string being match, the rest are the unpacked partial
   391                                                   match return values of `find_partial_matches`.
   392                                               final_state_string
   393                                                   A string from `vocabulary` that is to be added to all the final states
   394                                                   in the FSM.
   395                                               """
   396                                           
   397         1       1143.0   1143.0      0.0      final_state_string_idx = None
   398                                           
   399                                               # Partial parse states to the subsets of the vocabulary that accept them
   400         1       2771.0   2771.0      0.0      pstate_to_vocab = defaultdict(set)
   401         1        185.0    185.0      0.0      possible_paths = {}
   402         1       1291.0   1291.0      0.0      for symbol_name, fsm in terminals_to_fsms_map.items():
   403         1        225.0    225.0      0.0          terminal_possible_paths = defaultdict(set)
   404     50257    7373056.0    146.7      0.0          for i, vocab_string in enumerate(vocabulary):
   405     50256    5918091.0    117.8      0.0              if vocab_string == final_state_string:
   406         1        123.0    123.0      0.0                  final_state_string_idx = i
   407                                           
   408    462495 23171425688.0  50100.9     98.3              for end_idx, state_seq in find_partial_matches(fsm, vocab_string):
   409    462482  150993591.0    326.5      0.6                  if partial_match_filter(vocab_string, end_idx, state_seq):
   410    462482  100647597.0    217.6      0.4                      terminal_possible_paths[state_seq[0]].add(state_seq[-1])
   411    462482  127835494.0    276.4      0.5                      pstate_to_vocab[(symbol_name, state_seq[0])].add(i)
   412                                           
   413         1        801.0    801.0      0.0          possible_paths[symbol_name] = terminal_possible_paths
   414                                           
   415         1        140.0    140.0      0.0      if final_state_string_idx is not None:
   416                                                   # Allow transitions to EOS from all terminals FSM states
   417         1        354.0    354.0      0.0          for symbol_name, fsm in terminals_to_fsms_map.items():
   418         1        426.0    426.0      0.0              for state in fsm.finals:
   419         1       1370.0   1370.0      0.0                  pstate_to_vocab[(symbol_name, state)].add(final_state_string_idx)
   420                                           
   421         1        151.0    151.0      0.0      return pstate_to_vocab, possible_paths

Unsurprisingly the bottleneck is the index building. For reference, the corresponding regex:

regex = outlines.text.json_schema.build_regex_from_schema(json.dumps(Character.model_json_schema()))
print(regex)
# \{
#  "name": ".{,10}",
#  "age": (0|[1-9][0-9]*),
#  "armor": ("leather"|"chainmail"|"plate"),
#  "weapon": ("sword"|"axe"|"mace"|"spear"|"bow"|"crossbow"),
#  "strength": (0|[1-9][0-9]*)
# \}
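
To see this in isolation, one can time only the regex-to-index step against the GPT-2 vocabulary. This is a sketch, not the exact code path outlines uses internally: accessing the tokenizer via transformers and using the raw get_vocab() strings are assumptions for illustration, and it reuses the Character model defined above:

import json
import time

import interegular
from transformers import AutoTokenizer

from outlines.text.json_schema import build_regex_from_schema
from outlines.text.parsing import map_partial_states_to_vocab

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocabulary = list(tokenizer.get_vocab().keys())  # ~50k token strings

regex = build_regex_from_schema(json.dumps(Character.model_json_schema()))
fsm = interegular.parse_pattern(regex).to_fsm()

start = time.perf_counter()
pstate_to_vocab, paths = map_partial_states_to_vocab(vocabulary, {"Character": fsm})
print(f"index build: {time.perf_counter() - start:.1f}s")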

We can also look at the corresponding FSM, which contains ~160 states:

fsm = interegular.parse_pattern(regex).to_fsm()
print(fsm)
FSM name final? \n " , 0 1 2 3 4 5 6 7 8 9 : a b c d e g h i l m n o p r s t w x { } anything_else
[full transition table elided: states 0-164, of which only state 164, reached after the closing brace, is final]

There is a very simple optimization for this: hard-code the tokens that correspond to the JSON structure and field names, and memoize map_partial_states_to_vocab so the index for a given regex is not computed more than once. We could do it right now, but there is a chance the results will be biased if we don't fix #161 first.
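
A rough sketch of what that memoization could look like; the cache and wrapper names below are hypothetical, not part of outlines, and the cache is keyed on the regex string (assuming a single, fixed vocabulary per process) because the vocabulary and FSM arguments are not hashable:

import interegular

from outlines.text.parsing import map_partial_states_to_vocab

# Hypothetical process-level cache: one entry per regex pattern.
_INDEX_CACHE = {}


def cached_index(regex: str, vocabulary, final_state_string=None):
    if regex not in _INDEX_CACHE:
        fsm = interegular.parse_pattern(regex).to_fsm()
        _INDEX_CACHE[regex] = map_partial_states_to_vocab(
            vocabulary, {"REGEX": fsm}, final_state_string=final_state_string
        )
    return _INDEX_CACHE[regex]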

@brandonwillard changed the title from "Slow Text Generation on AMD GPU" to "Slow Index Building" on Aug 17, 2023
@brandonwillard
Member

brandonwillard commented Aug 17, 2023

We should also clarify that the index construction only needs to occur once for a given regular expression and vocabulary. As @rlouf said, we haven't set up automatic caching for this, but we can.
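
In the meantime the reuse can be done by hand. If the index is built when the guided generator is constructed (an assumption about the current implementation, not something the profile above distinguishes), constructing it once and calling it repeatedly only pays the index-building cost the first time:

# Sketch: build the guided generator once, reuse it for every prompt.
generator = generate.json(model, Character)            # index built here, once
first = generator("Give me a character description")
second = generator("Give me another character")        # reuses the same index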
