adding option strip_prompt to generate() #1913

Merged: 9 commits, Oct 16, 2024
38 changes: 37 additions & 1 deletion keras_hub/src/models/causal_lm.py
@@ -274,6 +274,7 @@ def generate(
inputs,
max_length=None,
stop_token_ids="auto",
strip_prompt=False,
):
"""Generate text given prompt `inputs`.

@@ -309,6 +310,9 @@
specify a list of token ids the model should stop on. Note that each
token id is interpreted as a separate stop token; multi-token stop
sequences are not supported.
strip_prompt: Optional. By default, generate() returns the full prompt
followed by the completion generated by the model. If this option
is set to True, only the newly generated text is returned.
"""
# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
@@ -339,6 +343,33 @@ def preprocess(x):
def generate(x):
return generate_function(x, stop_token_ids=stop_token_ids)

def strip_prompt_function(x, prompt):
# This function removes the prompt from the generated
# response, in a batch-friendly fashion.
y = {}
prompt_mask = prompt["padding_mask"]
seq_len = prompt_mask.shape[1]

# We need to shift every output sequence by the size of the prompt.
shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len
ix = ops.arange(seq_len, dtype="int")
ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1)

# This produces the desired shift (in fact a rollover).
def roll_sequence(seq):
return ops.take_along_axis(seq, ix, axis=1)

# The shifting rolls the content over so the prompt is at the end of
# the sequence and the generated text is at the beginning. We mask
# it to retain the generated text only.
y["padding_mask"] = ops.logical_xor(
roll_sequence(prompt_mask), roll_sequence(x["padding_mask"])
)
# We assume the mask is enough and there is no need to zero out the values.
y["token_ids"] = roll_sequence(x["token_ids"])

return y

def postprocess(x):
return self.preprocessor.generate_postprocess(x)
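
A self-contained sketch of the roll-and-mask trick above, using NumPy and toy values: each row is rolled left by its prompt length so the completion lands at the start of the sequence, and XOR-ing the two rolled padding masks keeps only the completion unmasked.

import numpy as np

seq_len = 8
# Toy batch of one: 3 prompt tokens, 3 generated tokens, then padding.
prompt_mask = np.array([[1, 1, 1, 0, 0, 0, 0, 0]], dtype=bool)
output_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0]], dtype=bool)
token_ids = np.array([[11, 12, 13, 21, 22, 23, 0, 0]])

# Shift each row left by its prompt length, mirroring the take_along_axis
# rollover in strip_prompt_function.
shifts = -prompt_mask.sum(axis=1) % seq_len  # [5], i.e. roll left by 3
ix = np.arange(seq_len)[None, :] - shifts[:, None]  # negative indices wrap

rolled_prompt = np.take_along_axis(prompt_mask, ix, axis=1)
rolled_output = np.take_along_axis(output_mask, ix, axis=1)
rolled_ids = np.take_along_axis(token_ids, ix, axis=1)

# Only the generated tokens remain unmasked.
new_mask = np.logical_xor(rolled_prompt, rolled_output)
print(rolled_ids)  # [[21 22 23  0  0 11 12 13]]
print(new_mask)    # [[ True  True  True False False False False False]]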

@@ -347,7 +378,12 @@ def postprocess(x):

if self.preprocessor is not None:
inputs = [preprocess(x) for x in inputs]
outputs = [generate(x) for x in inputs]

if strip_prompt:
outputs = [strip_prompt_function(generate(x), x) for x in inputs]
else:
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]
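
A minimal usage sketch of the new flag, assuming a Llama 3 preset purely for illustration; the preset name and prompt string are not part of this PR.

import keras_hub

# Preset name is an assumption for illustration; any CausalLM behaves the same way.
causal_lm = keras_hub.models.Llama3CausalLM.from_preset("llama3_8b_en")

prompt = "The airplane at the airport"

# Default: the returned string starts with the prompt itself.
full_text = causal_lm.generate(prompt, max_length=64)

# With strip_prompt=True, only the newly generated continuation is returned.
completion = causal_lm.generate(prompt, max_length=64, strip_prompt=True)
assert not completion.startswith(prompt)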

6 changes: 6 additions & 0 deletions keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -69,6 +69,12 @@ def test_generate(self):
prompt_ids["padding_mask"][:, :5],
)

def test_generate_strip_prompt(self):
causal_lm = Llama3CausalLM(**self.init_kwargs)
prompt = " airplane at airport"
output = causal_lm.generate(prompt, strip_prompt=True)
self.assertFalse(output.startswith(prompt))

def test_early_stopping(self):
causal_lm = Llama3CausalLM(**self.init_kwargs)
call_with_cache = causal_lm.call_with_cache
20 changes: 18 additions & 2 deletions keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
@@ -1,5 +1,4 @@
import keras
import pytest
import tensorflow as tf

from keras_hub.src.tests.test_case import TestCase
@@ -15,7 +14,6 @@
)


@pytest.mark.large
class BytePairTokenizerTest(TestCase):
def setUp(self):
super().setUp()
@@ -111,6 +109,24 @@ def test_whitespace_split(self):
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])

# This is important for Llama3, which uses the \n\n sequence in chat
# templates: \n\n must be tokenized as a single token.
input_data = "Hello\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 31414])

input_data = "Hello\n\n\n\nHello"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140, 31414])

input_data = "Hello\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140])

input_data = "Hello\n\n\n\n"
encoded = self.tokenizer(input_data)
self.assertAllEqual(encoded, [31414, 50140, 50140])

def test_special_whitespace(self):
input_data = "\xa0 \xa0 \x3000 s"
encoded = self.tokenizer(input_data)
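
As background for the comment in the tests above: the Llama 3 prompt format ends each message header with a "\n\n" separator, so that separator must survive as a single token for templated prompts to tokenize consistently. A sketch of such a prompt, with the format taken from the public Llama 3 documentation rather than from this PR:

# Llama 3 style chat prompt; the "\n\n" after each header is the separator
# that the assertions above protect.
prompt = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "What is the capital of France?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)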