added option strip_prompt to generate()
martin-gorner committed Oct 8, 2024
1 parent b9a2026 commit d0530f7
Showing 1 changed file with 38 additions and 5 deletions.
43 changes: 38 additions & 5 deletions keras_hub/src/models/causal_lm.py
@@ -270,10 +270,7 @@ def normalize(x):
        return normalize([x for x in outputs])

    def generate(
        self, inputs, max_length=None, stop_token_ids="auto", strip_prompt=False
    ):
        """Generate text given prompt `inputs`.
@@ -309,6 +306,9 @@
                specify a list of token ids the model should stop on. Note that
                each token id in the list is interpreted as its own stop token;
                multi-token stop sequences are not supported.
            strip_prompt: Optional. By default, `generate()` returns the full
                prompt followed by the completion generated by the model. If
                this option is set to `True`, only the newly generated text is
                returned.
        """
        # Setup our three main passes.
        # 1. Optionally preprocessing strings to dense integer tensors.
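For reference, here is a minimal usage sketch of the new flag (not part of the diff; the GPT-2 preset name is only an illustrative assumption):

    import keras_hub

    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Default: the returned string starts with the prompt itself.
    causal_lm.generate("Keras is", max_length=32)

    # With strip_prompt=True, only the newly generated text is returned.
    causal_lm.generate("Keras is", max_length=32, strip_prompt=True)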
@@ -339,6 +339,34 @@ def preprocess(x):
        def generate(x):
            return generate_function(x, stop_token_ids=stop_token_ids)

        def strip_prompt_fn(x, prompt):
            # This function removes the prompt from the generated
            # response, in a batch-friendly fashion.

            # We need to shift every output sequence left by the length of
            # the prompt.
            mask = prompt["padding_mask"]
            rows = ops.expand_dims(ops.arange(mask.shape[0]), 1)
            cols = ops.expand_dims(ops.arange(mask.shape[1]), 0)

            shifts = -ops.sum(ops.cast(mask, "int32"), axis=1)
            shifts %= mask.shape[1]
            # E.g. with seq_len=5 and a prompt of length 2, shifts is 3, so
            # position i reads from position (i + 2) % 5: a left roll by 2.
            cols = cols - ops.expand_dims(shifts, axis=1)

            # Indexing by [rows, cols] produces the desired shift (in fact
            # a rollover).

            y = {}
            # The shifting rolls the content over so the prompt is at the
            # end of the sequence and the generated text is at the
            # beginning. We mask it to retain the generated text only.
            rolled_in_mask = mask[rows, cols]
            rolled_out_mask = x["padding_mask"][rows, cols]
            y["padding_mask"] = ops.logical_xor(
                rolled_in_mask, rolled_out_mask
            )

            # We assume the mask is enough and there is no need to zero out
            # the token values.
            y["token_ids"] = x["token_ids"][rows, cols]

            return y

        def postprocess(x):
            return self.preprocessor.generate_postprocess(x)
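To make the rollover concrete, here is a toy re-creation of the strip_prompt_fn logic in plain numpy (an illustration with assumed toy data, not part of the commit). Rolling each row left by its prompt length moves the prompt to the end of the sequence, and XOR-ing the two rolled masks leaves a mask over the generated tokens only:

    import numpy as np

    # Batch of 2: prompts of lengths 2 and 3, total sequence length 5.
    prompt_mask = np.array([[1, 1, 0, 0, 0],
                            [1, 1, 1, 0, 0]], dtype=bool)
    output_mask = np.array([[1, 1, 1, 1, 0],
                            [1, 1, 1, 1, 1]], dtype=bool)
    token_ids = np.array([[10, 11, 20, 21, 0],
                          [30, 31, 32, 40, 41]])

    rows = np.arange(prompt_mask.shape[0])[:, None]
    cols = np.arange(prompt_mask.shape[1])[None, :]
    shifts = -prompt_mask.sum(axis=1) % prompt_mask.shape[1]
    cols = cols - shifts[:, None]  # negative indices wrap around: a roll

    gen_mask = prompt_mask[rows, cols] ^ output_mask[rows, cols]
    print(token_ids[rows, cols] * gen_mask)
    # [[20 21  0  0  0]
    #  [40 41  0  0  0]]   -> only the generated tokens remain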

@@ -347,7 +375,12 @@ def postprocess(x):

        if self.preprocessor is not None:
            inputs = [preprocess(x) for x in inputs]

        if strip_prompt:
            outputs = [strip_prompt_fn(generate(x), x) for x in inputs]
        else:
            outputs = [generate(x) for x in inputs]

        if self.preprocessor is not None:
            outputs = [postprocess(x) for x in outputs]

