Skip to content

Commit

Permalink
fix for tensorflow: the compiled version of generate(strip_prompt=Tru…
Browse files Browse the repository at this point in the history
…e) now works + code refactoring to make it more understandable
  • Loading branch information
martin-gorner committed Oct 9, 2024
1 parent 6ed9b65 commit 16f6109
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions keras_hub/src/models/causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,28 +342,27 @@ def generate(x):
def strip_prompt_fn(x, prompt):
# This function removes the prompt from the generated
# response, in a batch-friendly fashion.
y = {}
prompt_mask = prompt["padding_mask"]
seq_len = prompt_mask.shape[1]

# We need to shift every output sequence by the size of the prompt.
mask = prompt["padding_mask"]
rows = ops.expand_dims(ops.arange(mask.shape[0]), 1)
cols = ops.expand_dims(ops.arange(mask.shape[1]), 0)

shifts = -ops.sum(mask.astype(int), axis=1)
shifts %= mask.shape[1]
cols = cols - ops.expand_dims(shifts, axis=1)
shifts = -ops.sum(ops.cast(prompt_mask, "int"), axis=1) % seq_len
ix = ops.arange(seq_len, dtype="int")
ix = ops.expand_dims(ix, axis=0) - ops.expand_dims(shifts, axis=1)

# Indexing by [rows, cols] produces the desired shift (in fact a rollover).
# This produces the desired shift (in fact a rollover).
def roll_sequence(seq):
return ops.take_along_axis(seq, ix, axis=1)

y = {}
# The shifting rolls the content over so the prompt is at the end of
# the sequence and the generated text is at the begining. We mask
# it to retain the generated text only.
rolled_in_mask = mask[rows, cols]
rolled_out_mask = x["padding_mask"][rows, cols]
y["padding_mask"] = ops.logical_xor(rolled_in_mask, rolled_out_mask)

y["padding_mask"] = ops.logical_xor(
roll_sequence(prompt_mask), roll_sequence(x["padding_mask"])
)
# we assume the mask is enough and there is no need to zero-out the values
y["token_ids"] = x["token_ids"][rows, cols]
y["token_ids"] = roll_sequence(x["token_ids"])

return y

Expand Down

0 comments on commit 16f6109

Please sign in to comment.