diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 9562f95d14..87948439a8 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -18,11 +18,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.BeamSampler") class BeamSampler(Sampler): """Beam Sampler class. @@ -42,55 +39,17 @@ class BeamSampler(Sampler): {{call_args}} Examples: - Return only the beam with the highest accumulated probability. ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.BeamSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzeeeeeee'] - ``` + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - Return all beams and their probabilities. - ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 8, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - - print(beams.shape) - # >>> (1, 5, 8) - print(probs.shape) - # >>> (1, 5) - print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()]) - # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb'] + # Pass by name to compile. + causal_lm.compile(sampler="beam") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.BeamSampler(num_beams=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index bac65bcfbe..8b3d52d9a5 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -17,11 +17,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") class ContrastiveSampler(Sampler): """Contrastive Sampler class. @@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters to [0, 26). - int_lookup = {i: chr(i + ord("a")) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - hidden_size = 5 - index = 5 - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, hidden_size)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.ContrastiveSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=index, - hidden_states=np.ones([batch_size, index, hidden_size]), - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> "zzzzzeeeeeee" + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="contrastive") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.ContrastiveSampler(k=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 8e178b7468..7f93444bb6 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -15,11 +15,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.GreedySampler") class GreedySampler(Sampler): """Greedy sampler class. @@ -27,29 +24,18 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. - Call arguments: - {{call_args}} - Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="greedy") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.GreedySampler() + causal_lm.compile(sampler=keras_nlp.samplers.GreedySampler()) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index b922d29b2a..1ff39c9f9b 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.RandomSampler") class RandomSampler(Sampler): """Random Sampler class. @@ -37,24 +34,16 @@ class RandomSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, state, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, state + # Pass by name to compile. + causal_lm.compile(sampler="random") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.RandomSampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzcpnjqij'] + # Pass by object to compile. + sampler = keras_nlp.samplers.RandomSampler(temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e28fbe9d6e..2101c9277d 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -17,33 +17,8 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random -from keras_nlp.utils.python_utils import format_docstring - -call_args_docstring = """next: A function which takes in the - `prompt, cache, index` of the current generation loop, and outputs - a tuple `(logits, hidden_states, cache)` with `logits` being the - logits of next token, `hidden_states` being the representation of - the next token, and `cache` for next iteration. - prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This - tensor will be iteratively updated column by column with new sampled - values, starting at `index`. - cache: Optional. A tensor or nested structure of tensors that will be - updated by each call to `next`. This can be used to cache - computations from early iterations of the generative loop. - index: Optional. The first index of `prompt` to start sampling at. - Usually this is set as the length of the shortest non-padded - sequence in `prompt`. - mask: Optional. A 2D integer tensor with the same shape as `prompt`. - Locations which are `True` in the mask are never updated during - sampling. Usually used to mark all locations in the dense prompt - tensor which were present in a user input. - end_token_id: Optional. The token marking the end of the sequence. If - specified, sampling will stop as soon as all sequences in the prompt - produce a `end_token_id` in a location where `mask` is `False`. -""" - - -@format_docstring(call_args=call_args_docstring) + + @keras_nlp_export("keras_nlp.samplers.Sampler") class Sampler: """Base sampler class. @@ -57,35 +32,32 @@ class Sampler: {{call_args}} This base class can be extended to implement different auto-regressive - sampling methods. Subclasses can either: - - - Override the `get_next_token()` method, which computes the next token - based on a probability distribution over all possible vocab entries. - - Override `__call__`, if the sampling method needs additional information - beyond the next tokens probability distribution to sample a sequence. - - Please check available subclass samplers for examples. + sampling methods. To do so, override the `get_next_token()` method, which + computes the next token based on a probability distribution over all + possible vocab entries. Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - # return a uniform distribution over our alphabet. - logits = ops.ones((batch_size, vocab_size)) - return logits, None, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=ops.fill((batch_size, length,), char_lookup['z']), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Greedy search with some tokens forbidden. + class CustomSampler(keras_nlp.samplers.Sampler): + def __init__(self, forbidden_tokens, **kwargs): + super().__init__(**kwargs) + self.forbidden_tokens = forbidden_tokens + + def get_next_token(self, probs): + batch_size, vocab_size = keras.ops.shape(probs) + for id in self.forbidden_tokens: + update = keras.ops.zeros((batch_size, 1)) + probs = keras.ops.slice_update(probs, (0, id), update) + return keras.ops.argmax(probs, axis=-1) + + # 257 = "a" with a leading space, 262 = "the" with a leading space. + causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262])) + causal_lm.summary() + causal_lm.generate(["That's strange"]) ``` """ diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 3456694848..513dd738c7 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopKSampler") class TopKSampler(Sampler): """Top-K Sampler class. @@ -38,24 +35,16 @@ class TopKSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_k") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopKSampler(k=3)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzacbbcaa'] + # Pass by object to compile. + sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index a04b39aa2b..326f5797a6 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopPSampler") class TopPSampler(Sampler): """Top-P Sampler class. @@ -46,24 +43,16 @@ class TopPSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_p") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopPSampler(p=0.1)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzbabcccb'] + # Pass by object to compile. + sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """