From 91ff01d2784da3271095817c9a31b3477d413a92 Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:16:59 +0200
Subject: [PATCH 01/70] Fix BP special tokens tokenization

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py   | 66 +++++++++++--------
 .../tokenizers/byte_pair_tokenizer_test.py    | 11 ++++
 2 files changed, 50 insertions(+), 27 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 902af812e9..6ca95963c7 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -63,17 +63,10 @@
 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


-def create_alts_for_unsplittable_tokens(unsplittable_tokens):
-    # Create alternates for all special tokens that will be not split during
-    # tokenization.
-    alts = []
-    prefix = "Ĵ"
-    # Trim out splitters.
-    replace_pattern = r"'|\s+|[^\p{L}\p{N}]+"
-    for token in unsplittable_tokens:
-        token = re.sub(replace_pattern, "", token)
-        alts.append(prefix + token)
-    return alts
+def get_unsplittable_tokens_pattern(unsplittable_tokens):
+    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+        return None
+    return r"|".join([re.escape(token) for token in unsplittable_tokens])


 def bytes_to_unicode():
@@ -108,7 +101,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result


-def split_strings_for_bpe(inputs, unsplittable_tokens=None):
+def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -120,24 +113,38 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None):
     inputs = tf.strings.regex_replace(
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )
-    if unsplittable_tokens:
-        alts = create_alts_for_unsplittable_tokens(unsplittable_tokens)
-        for token, alt in zip(unsplittable_tokens, alts):
-            escaped_token = re.escape(token)
-            inputs = tf_text.regex_split(inputs, escaped_token, escaped_token)
-            inputs = tf.strings.regex_replace(inputs, escaped_token, alt)
-    raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
+
+    if unsplittable_tokens_pattern is not None:
+        # First split the unsplittable tokens from the input.
+        raw_tokens = tf_text.regex_split(
+            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+        )
+        split_pattern_1_with_unsplittable_tokens = r"|".join(
+            [unsplittable_tokens_pattern, SPLIT_PATTERN_1]
+        )
+        # Then split using both `unsplittable_tokens_pattern` and
+        # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
+        # affecting the unsplittable tokens.
+        # We split unsplittable tokens first then apply this split instead of
+        # applying this split directly, because otherwise we will not split
+        # unsplittable tokens from inputs properly, because of this pattern
+        # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
+        # e.g., [" <s>"] will be [" <s>"] instead of [" ", "<s>"]
+        raw_tokens = tf_text.regex_split(
+            raw_tokens,
+            split_pattern_1_with_unsplittable_tokens,
+            split_pattern_1_with_unsplittable_tokens,
+        )
+        raw_tokens = raw_tokens.merge_dims(-2, -1)
+    else:
+        raw_tokens = tf_text.regex_split(
+            inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1
+        )
+
    # Second pass splits out the last whitespace char or "६".
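+    # (The "६" sentinel inserted above stands in for the unsupported
+    # lookahead and is stripped again by `remove_strings_from_inputs`
+    # at the end of this function.)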
    raw_tokens = tf_text.regex_split(
        raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
    )
-    if unsplittable_tokens:
-        # Replace special tokens alternate with originals.
-        for token, alt in zip(unsplittable_tokens, alts):
-            escaped_alt = re.escape(alt)
-            raw_tokens = tf.strings.regex_replace(
-                raw_tokens, escaped_alt, token
-            )
     while raw_tokens.shape.rank > 2:
         raw_tokens = raw_tokens.merge_dims(1, 2)
     return remove_strings_from_inputs(raw_tokens, "६")
@@ -298,6 +305,9 @@ def __init__(
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
         self.unsplittable_tokens = unsplittable_tokens
+        self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
+            unsplittable_tokens
+        )

         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -544,7 +554,9 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)

-        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+        raw_tokens = split_strings_for_bpe(
+            inputs, self._unsplittable_tokens_pattern
+        )

         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 00f8f9b87f..9752966a17 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -80,6 +80,17 @@ def test_tokenize_with_special_tokens(self):
         output = tokenizer("sp")
         self.assertAllEqual(output, [0])

+        vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
+        merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+        merges += ["Ġ f", "o x", "Ġf ox"]
+        tokenizer = BytePairTokenizer(
+            vocabulary=vocab,
+            merges=merges,
+            unsplittable_tokens=["<s>", "</s>"],
+        )
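+        # With the fix, the unsplittable tokens should map straight to
+        # their own vocab ids (0 and 1) even with no whitespace around
+        # them.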
+        output = tokenizer("<s>a quick fox</s>")
+        self.assertAllEqual(output, [0, 2, 3, 4, 1])
+
     def test_tokenize_prefix_space(self):
         input_data = ["brown.", "black."]
         tokenizer = BytePairTokenizer(

From 6c26d848ba4ac364fa3c2d92bc6c1f20af024da1 Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:29:40 +0200
Subject: [PATCH 02/70] Add test to Bart

---
 keras_nlp/models/bart/bart_tokenizer_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
index 5a0015357b..afe3faced5 100644
--- a/keras_nlp/models/bart/bart_tokenizer_test.py
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -38,9 +38,9 @@ def test_tokenizer_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             # TODO: </s> should not get tokenized as <s>
-            expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]],
+            expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
             expected_detokenize_output=[
-                "<s> airplane at airport<s><pad>",
+                "<s> airplane at airport</s><pad>",
                 " airplane airport",
             ],
         )

From b4769ea0e5ecb5fd69bad3055a56001d03224b20 Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:32:25 +0200
Subject: [PATCH 03/70] Add test to Bloom

---
 keras_nlp/models/bloom/bloom_tokenizer_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/models/bloom/bloom_tokenizer_test.py b/keras_nlp/models/bloom/bloom_tokenizer_test.py
index 9ae9c0cc00..2fbebdf1ab 100644
--- a/keras_nlp/models/bloom/bloom_tokenizer_test.py
+++ b/keras_nlp/models/bloom/bloom_tokenizer_test.py
@@ -28,8 +28,8 @@ def setUp(self):
         self.merges += ["Ġai r", "Ġa i", "pla ne"]
         self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
         self.input_data = [
-            "<s>airplane at airport<pad>",
-            "<s> airplane airport<pad>",
+            "<s>airplane at airport</s><pad>",
+            "<s> airplane airport</s><pad>",
         ]

     def test_tokenizer_basics(self):
@@ -37,7 +37,7 @@ def test_tokenizer_basics(self):
             cls=BloomTokenizer,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]],
+            expected_output=[[6, 1, 3, 4, 2, 5, 7, 8], [6, 2, 3, 2, 5, 7, 8]],
         )

     def test_errors_missing_special_tokens(self):

From 814836bbb3a67357e527bf0fe3cfc7e728c0f102 Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:32:50 +0200
Subject: [PATCH 04/70] Remove ToDo comment

---
 keras_nlp/models/bart/bart_tokenizer_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py
index afe3faced5..b18629939c 100644
--- a/keras_nlp/models/bart/bart_tokenizer_test.py
+++ b/keras_nlp/models/bart/bart_tokenizer_test.py
@@ -37,7 +37,6 @@ def test_tokenizer_basics(self):
             cls=BartTokenizer,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            # TODO: </s> should not get tokenized as <s>
             expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
             expected_detokenize_output=[
                 "<s> airplane at airport</s><pad>",

From f451351cb96ec157c28e6bfd405e21d26935358f Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:41:41 +0200
Subject: [PATCH 05/70] Add tests for Roberta

---
 keras_nlp/models/roberta/roberta_tokenizer_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py
index 3b2305608d..2c40290c97 100644
--- a/keras_nlp/models/roberta/roberta_tokenizer_test.py
+++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py
@@ -28,7 +28,7 @@ def setUp(self):
         self.merges += ["Ġai r", "Ġa i", "pla ne"]
         self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
         self.input_data = [
-            "<s> airplane at airport</s><pad>",
+            "<s> airplane at airport<mask></s><pad>",
             " airplane airport",
         ]

@@ -38,9 +38,9 @@ def test_tokenizer_basics(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
             # TODO: </s> should not get tokenized as <s>
-            expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]],
+            expected_output=[[0, 4, 5, 6, 4, 7, 8, 2, 1], [4, 5, 4, 7]],
             expected_detokenize_output=[
-                "<s> airplane at airport<s><pad>",
+                "<s> airplane at airport<mask></s><pad>",
                 " airplane airport",
             ],
         )

From 464555c50193319a51a319dc1b8a6c5ba989218f Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Tue, 20 Feb 2024 15:44:17 +0200
Subject: [PATCH 06/70] Remove roberta todo comment

---
 keras_nlp/models/roberta/roberta_tokenizer_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py
index 2c40290c97..572bc03151 100644
--- a/keras_nlp/models/roberta/roberta_tokenizer_test.py
+++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py
@@ -37,7 +37,6 @@ def test_tokenizer_basics(self):
             cls=RobertaTokenizer,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            # TODO: </s> should not get tokenized as <s>
             expected_output=[[0, 4, 5, 6, 4, 7, 8, 2, 1], [4, 5, 4, 7]],
             expected_detokenize_output=[
                 "<s> airplane at airport<mask></s><pad>",

From ab7b48a21f4bb89ea1e321509b279a2cae5a430d Mon Sep 17 00:00:00 2001
From: abuelnasr0
Date: Wed, 21 Feb 2024 19:57:02 +0200
Subject: [PATCH 07/70] Fix split comment
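
The comment gave the wrong "before" picture: `SPLIT_PATTERN_1` does not
leave a special token intact; its ` ?[^\s\p{L}\p{N}{special_spaces}]+`
alternative cuts right through it. Roughly, with "<s>" as the example
token:

    tf_text.regex_split([" <s>"], SPLIT_PATTERN_1, SPLIT_PATTERN_1)
    # -> [" <", "s", ">"], not [" ", "<s>"]

which is why the unsplittable tokens have to be split out first.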
---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 6ca95963c7..2ac8832a76 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -129,7 +129,7 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
         # applying this split directly, because otherwise we will not split
         # unsplittable tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
-        # e.g., [" <s>"] will be [" <s>"] instead of [" ", "<s>"]
+        # e.g., [" <s>"] will be [" <", "s", ">"] instead of [" ", "<s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
             split_pattern_1_with_unsplittable_tokens,
             split_pattern_1_with_unsplittable_tokens,
         )

From 996fc4867239ae1303ff2a97e177319018844c47 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 20 Feb 2024 15:56:13 -0800
Subject: [PATCH 08/70] Update our sampler documentation to reflect usage (#1444)

We will update our samplers in the near future to push the backend
specific compilation details out:
https://github.com/keras-team/keras-nlp/pull/1425

Also in general, we want our documentation to reflect the main usage of
our classes, which is using them with Seq2SeqLM and CausalLM classes.

So with that in mind, this updates our sampler docs to show the
practical usage of the sampling classes with our modeling classes. For
the base class, we show the main use case of overriding the
`get_next_token()` function.
---
 keras_nlp/samplers/beam_sampler.py        | 59 +++---------------
 keras_nlp/samplers/contrastive_sampler.py | 35 +++--------
 keras_nlp/samplers/greedy_sampler.py      | 34 +++-------
 keras_nlp/samplers/random_sampler.py      | 27 +++-----
 keras_nlp/samplers/sampler.py             | 76 +++++++----------------
 keras_nlp/samplers/top_k_sampler.py       | 27 +++-----
 keras_nlp/samplers/top_p_sampler.py       | 27 +++-----
 7 files changed, 77 insertions(+), 208 deletions(-)

diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py
index 9562f95d14..87948439a8 100644
--- a/keras_nlp/samplers/beam_sampler.py
+++ b/keras_nlp/samplers/beam_sampler.py
@@ -18,11 +18,8 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import ops
 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import call_args_docstring
-from keras_nlp.utils.python_utils import format_docstring


-@format_docstring(call_args=call_args_docstring)
 @keras_nlp_export("keras_nlp.samplers.BeamSampler")
 class BeamSampler(Sampler):
     """Beam Sampler class.
@@ -42,55 +39,17 @@ class BeamSampler(Sampler):
     {{call_args}}

     Examples:
-    Return only the beam with the highest accumulated probability.
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
- int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 8, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - - print(beams.shape) - # >>> (1, 5, 8) - print(probs.shape) - # >>> (1, 5) - print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()]) - # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb'] + # Pass by name to compile. + causal_lm.compile(sampler="beam") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.BeamSampler(num_beams=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index bac65bcfbe..8b3d52d9a5 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -17,11 +17,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") class ContrastiveSampler(Sampler): """Contrastive Sampler class. @@ -44,28 +41,16 @@ class ContrastiveSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters to [0, 26). - int_lookup = {i: chr(i + ord("a")) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - hidden_size = 5 - index = 5 - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, hidden_size)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.ContrastiveSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=index, - hidden_states=np.ones([batch_size, index, hidden_size]), - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> "zzzzzeeeeeee" + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="contrastive") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.ContrastiveSampler(k=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 8e178b7468..ee8a6ecc2d 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -15,11 +15,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.GreedySampler") class GreedySampler(Sampler): """Greedy sampler class. @@ -27,29 +24,18 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. - Call arguments: - {{call_args}} - Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="greedy") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.GreedySampler() + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index b922d29b2a..1ff39c9f9b 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.RandomSampler") class RandomSampler(Sampler): """Random Sampler class. @@ -37,24 +34,16 @@ class RandomSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, state, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, state + # Pass by name to compile. 
+ causal_lm.compile(sampler="random") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.RandomSampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzcpnjqij'] + # Pass by object to compile. + sampler = keras_nlp.samplers.RandomSampler(temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e28fbe9d6e..2101c9277d 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -17,33 +17,8 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random -from keras_nlp.utils.python_utils import format_docstring - -call_args_docstring = """next: A function which takes in the - `prompt, cache, index` of the current generation loop, and outputs - a tuple `(logits, hidden_states, cache)` with `logits` being the - logits of next token, `hidden_states` being the representation of - the next token, and `cache` for next iteration. - prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This - tensor will be iteratively updated column by column with new sampled - values, starting at `index`. - cache: Optional. A tensor or nested structure of tensors that will be - updated by each call to `next`. This can be used to cache - computations from early iterations of the generative loop. - index: Optional. The first index of `prompt` to start sampling at. - Usually this is set as the length of the shortest non-padded - sequence in `prompt`. - mask: Optional. A 2D integer tensor with the same shape as `prompt`. - Locations which are `True` in the mask are never updated during - sampling. Usually used to mark all locations in the dense prompt - tensor which were present in a user input. - end_token_id: Optional. The token marking the end of the sequence. If - specified, sampling will stop as soon as all sequences in the prompt - produce a `end_token_id` in a location where `mask` is `False`. -""" - - -@format_docstring(call_args=call_args_docstring) + + @keras_nlp_export("keras_nlp.samplers.Sampler") class Sampler: """Base sampler class. @@ -57,35 +32,32 @@ class Sampler: {{call_args}} This base class can be extended to implement different auto-regressive - sampling methods. Subclasses can either: - - - Override the `get_next_token()` method, which computes the next token - based on a probability distribution over all possible vocab entries. - - Override `__call__`, if the sampling method needs additional information - beyond the next tokens probability distribution to sample a sequence. - - Please check available subclass samplers for examples. + sampling methods. To do so, override the `get_next_token()` method, which + computes the next token based on a probability distribution over all + possible vocab entries. Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - # return a uniform distribution over our alphabet. 
- logits = ops.ones((batch_size, vocab_size)) - return logits, None, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=ops.fill((batch_size, length,), char_lookup['z']), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Greedy search with some tokens forbidden. + class CustomSampler(keras_nlp.samplers.Sampler): + def __init__(self, forbidden_tokens, **kwargs): + super().__init__(**kwargs) + self.forbidden_tokens = forbidden_tokens + + def get_next_token(self, probs): + batch_size, vocab_size = keras.ops.shape(probs) + for id in self.forbidden_tokens: + update = keras.ops.zeros((batch_size, 1)) + probs = keras.ops.slice_update(probs, (0, id), update) + return keras.ops.argmax(probs, axis=-1) + + # 257 = "a" with a leading space, 262 = "the" with a leading space. + causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262])) + causal_lm.summary() + causal_lm.generate(["That's strange"]) ``` """ diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 3456694848..513dd738c7 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopKSampler") class TopKSampler(Sampler): """Top-K Sampler class. @@ -38,24 +35,16 @@ class TopKSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_k") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopKSampler(k=3)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzacbbcaa'] + # Pass by object to compile. + sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index a04b39aa2b..326f5797a6 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopPSampler") class TopPSampler(Sampler): """Top-P Sampler class. 
@@ -46,24 +43,16 @@ class TopPSampler(Sampler):

     Examples:
     ```python
-    # Use a simple alphabet of lowercase characters with ids in range [0, 25].
-    int_lookup = {i: chr(i + ord('a')) for i in range(26)}
-    char_lookup = {v: k for k, v in int_lookup.items()}
-    batch_size, length, vocab_size = 1, 12, len(int_lookup)
+    causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

-    def next(prompt, cache, index):
-        hidden_states = np.ones((batch_size, 10))
-        # A uniform distribution over our alphabet.
-        logits = np.ones((batch_size, vocab_size))
-        return logits, hidden_states, cache
+    # Pass by name to compile.
+    causal_lm.compile(sampler="top_p")
+    causal_lm.generate(["Keras is a"])

-    output = keras_nlp.samplers.TopPSampler(p=0.1)(
-        next=next,
-        prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"),
-        index=5,
-    )
-    print(["".join([int_lookup[i] for i in s]) for s in output.numpy()])
-    # >>> ['zzzzzbabcccb']
+    # Pass by object to compile.
+    sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000)
+    causal_lm.compile(sampler=sampler)
+    causal_lm.generate(["Keras is a"])
     ```
     """

From 9da7400508467a71aefc0bf37da9a36790a53464 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 20 Feb 2024 20:04:32 -0800
Subject: [PATCH 09/70] Add Gemma model (#1448)

The Keras implementation of the Gemma model was the effort of a number
of contributors:

- Initial architecture: Gabriel Rasskin, Francois Chollet, Matt Watson
- Model parallelism: Qianli Scott Zhu
- Model export for inference: Neel Kovelamudi
- Lora implementation: Francois Chollet, Samaneh Saadat
- Benchmarking: Haifeng Jin
- Interpretability extensions: Ryan Mullins
- Testing infrastructure: Ramesh Sampath

Many more helped with documentation and Kaggle integration.
Co-authored-by: Francois Chollet Co-authored-by: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Co-authored-by: Qianli Scott Zhu Co-authored-by: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Co-authored-by: Samaneh Saadat Co-authored-by: Haifeng Jin <5476582+haifeng-jin@users.noreply.github.com> Co-authored-by: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Co-authored-by: Ryan Mullins --- .../modeling/transformer_layer_utils.py | 9 +- keras_nlp/models/__init__.py | 7 + keras_nlp/models/backbone.py | 77 +++ keras_nlp/models/bart/bart_seq_2_seq_lm.py | 1 + keras_nlp/models/gemma/__init__.py | 13 + keras_nlp/models/gemma/gemma_attention.py | 197 ++++++++ keras_nlp/models/gemma/gemma_backbone.py | 267 +++++++++++ keras_nlp/models/gemma/gemma_backbone_test.py | 128 +++++ keras_nlp/models/gemma/gemma_causal_lm.py | 441 ++++++++++++++++++ .../gemma/gemma_causal_lm_preprocessor.py | 173 +++++++ .../gemma_causal_lm_preprocessor_test.py | 92 ++++ .../models/gemma/gemma_causal_lm_test.py | 245 ++++++++++ keras_nlp/models/gemma/gemma_decoder_block.py | 189 ++++++++ keras_nlp/models/gemma/gemma_lora_test.py | 102 ++++ keras_nlp/models/gemma/gemma_preprocessor.py | 199 ++++++++ .../models/gemma/gemma_preprocessor_test.py | 74 +++ keras_nlp/models/gemma/gemma_presets.py | 66 +++ keras_nlp/models/gemma/gemma_tokenizer.py | 108 +++++ .../models/gemma/gemma_tokenizer_test.py | 67 +++ keras_nlp/models/gemma/rms_normalization.py | 40 ++ keras_nlp/models/generative_task.py | 17 +- keras_nlp/models/gpt2/gpt2_causal_lm.py | 1 + .../models/gpt_neo_x/gpt_neo_x_causal_lm.py | 1 + keras_nlp/models/opt/opt_causal_lm.py | 1 + keras_nlp/models/t5/t5_transformer_layer.py | 3 +- keras_nlp/samplers/beam_sampler.py | 2 + keras_nlp/samplers/contrastive_sampler.py | 2 + keras_nlp/samplers/sampler.py | 58 ++- .../tests/test_data/gemma_test_vocab.spm | Bin 0 -> 237805 bytes .../tokenizers/sentence_piece_tokenizer.py | 2 + keras_nlp/utils/preset_utils.py | 8 + tools/gemma/export_gemma_to_hf.py | 328 +++++++++++++ tools/gemma/export_gemma_to_torch_xla.py | 322 +++++++++++++ tools/gemma/run_gemma_xla.py | 287 ++++++++++++ .../create_gemma_test_proto.py | 36 ++ 35 files changed, 3538 insertions(+), 25 deletions(-) create mode 100644 keras_nlp/models/gemma/__init__.py create mode 100644 keras_nlp/models/gemma/gemma_attention.py create mode 100644 keras_nlp/models/gemma/gemma_backbone.py create mode 100644 keras_nlp/models/gemma/gemma_backbone_test.py create mode 100644 keras_nlp/models/gemma/gemma_causal_lm.py create mode 100644 keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py create mode 100644 keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py create mode 100644 keras_nlp/models/gemma/gemma_causal_lm_test.py create mode 100644 keras_nlp/models/gemma/gemma_decoder_block.py create mode 100644 keras_nlp/models/gemma/gemma_lora_test.py create mode 100644 keras_nlp/models/gemma/gemma_preprocessor.py create mode 100644 keras_nlp/models/gemma/gemma_preprocessor_test.py create mode 100644 keras_nlp/models/gemma/gemma_presets.py create mode 100644 keras_nlp/models/gemma/gemma_tokenizer.py create mode 100644 keras_nlp/models/gemma/gemma_tokenizer_test.py create mode 100644 keras_nlp/models/gemma/rms_normalization.py create mode 100644 keras_nlp/tests/test_data/gemma_test_vocab.spm create mode 100644 tools/gemma/export_gemma_to_hf.py create mode 100644 tools/gemma/export_gemma_to_torch_xla.py create mode 100644 tools/gemma/run_gemma_xla.py create mode 100644 
tools/sentencepiece_testing/create_gemma_test_proto.py

diff --git a/keras_nlp/layers/modeling/transformer_layer_utils.py b/keras_nlp/layers/modeling/transformer_layer_utils.py
index 863da59a36..f375bf1b9d 100644
--- a/keras_nlp/layers/modeling/transformer_layer_utils.py
+++ b/keras_nlp/layers/modeling/transformer_layer_utils.py
@@ -55,9 +55,12 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
     `(batch_size, output_length, input_length)` that can be passed to an
     attention layer.
     """
-    i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index
-    j = ops.arange(input_length)
-    mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0)
+    i = ops.arange(output_length, dtype="float32")
+    i = i + ops.cast(cache_index, "float32")
+    i = ops.expand_dims(i, axis=1)
+    j = ops.arange(input_length, dtype="float32")
+    mask = ops.expand_dims(i >= j, axis=0)
+
     return ops.broadcast_to(mask, (batch_size, output_length, input_length))

diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 8fd6a70ac0..cdd50670f3 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -75,6 +75,13 @@
 )
 from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
 from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
+from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
+from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM
+from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
+    GemmaCausalLMPreprocessor,
+)
+from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor
+from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer
 from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index 6fccf6013a..9c8cdaa60e 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -152,3 +152,80 @@ def from_preset(calling_cls, *args, **kwargs):
             example_preset_name=next(iter(cls.presets), ""),
             preset_names='", "'.join(cls.presets),
         )(cls.from_preset.__func__)
+
+    def enable_lora(self, rank):
+        """Enable Lora on the backbone.
+
+        Calling this method will freeze all weights on the backbone,
+        while enabling Lora on the query & value `EinsumDense` layers
+        of the attention layers.
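+
+        A minimal usage sketch (the preset name here is illustrative):
+
+        ```python
+        backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")
+        backbone.enable_lora(rank=4)
+        # Only the lora kernels of the query/value projections remain
+        # trainable; they can then be saved with
+        # `save_lora_weights("model.lora.h5")`.
+        ```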
" + f"Received: filepath={filepath}" + ) + + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w") + lora_store = store.make("lora") + lora_store["rank"] = self._lora_rank + # We cannot identify layers by name since names are non-unique, + # so we identify them by index in the topologically sorted list + # of layers that have weights. + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + # We only lora the einsumdense layers, + # so the factored weights are always named `kernel` + layer = all_layers[layer_index] + inner_store = store.make(f"lora/{layer_index}") + inner_store["lora_kernel_a"] = layer.lora_kernel_a + inner_store["lora_kernel_b"] = layer.lora_kernel_b + store.close() + + def load_lora_weights(self, filepath): + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r") + lora_store = store.get("lora") + rank = int(lora_store["rank"][()]) + + if not getattr(self, "_lora_enabled_layers", []): + self.enable_lora(rank) + else: + if self._lora_rank != rank: + raise ValueError( + f"The Lora rank expected by file '{filepath}' " + f"is rank={rank}, but the model was called with " + f"`.enable_lora(rank={self._lora_rank})`. " + "Both ranks must match." + ) + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + layer = all_layers[layer_index] + lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"] + lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"] + layer.lora_kernel_a.assign(lora_kernel_a) + layer.lora_kernel_b.assign(lora_kernel_b) + store.close() diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index c17eafdb02..c530555b3d 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -479,6 +479,7 @@ def repeat_tensor(x): mask=decoder_padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/gemma/__init__.py b/keras_nlp/models/gemma/__init__.py new file mode 100644 index 0000000000..ba0c2545e4 --- /dev/null +++ b/keras_nlp/models/gemma/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py new file mode 100644 index 0000000000..80c2ac6a63 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -0,0 +1,197 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.utils.keras_utils import clone_initializer
+
+
+class CachedGemmaAttention(keras.layers.Layer):
+    """A cached grouped query attention layer."""
+
+    def __init__(
+        self,
+        head_dim,
+        num_query_heads,
+        num_key_value_heads,
+        kernel_initializer="glorot_uniform",
+        dropout=0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+
+    def build(self, inputs_shape):
+        self.hidden_dim = inputs_shape[-1]
+
+        self.query_dense = keras.layers.EinsumDense(
+            "btd,ndh->btnh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            "bsd,kdh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="btnh,nhd->btd",
+            output_shape=(None, self.hidden_dim),
+            kernel_initializer=self._kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self.head_dim)
+        )
+        self.softmax = keras.layers.Softmax(dtype="float32")
+        self.built = True
+
+    def _apply_rope(self, x, positions):
+        """Rope rotate q or k."""
+        # TODO: refactor to use RotaryEmbedding layer?
+        max_wavelength = 10000
+        x_shape = ops.shape(x)
+        freq_exponents = (2.0 / x_shape[-1]) * ops.cast(
+            ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype
+        )
+        timescale = max_wavelength**freq_exponents
+        radians = positions[..., None] / timescale[None, None, :]
+        radians = radians[..., None, :]
+        sin, cos = ops.sin(radians), ops.cos(radians)
+        x1, x2 = ops.split(x, 2, axis=-1)
+        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
+        # compilation on jax. We should be able to remove this once the
+        # following PR is in all jax releases we care about:
+        # https://github.com/openxla/xla/pull/7875
+        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
+        return ops.reshape(output, x_shape)
+
+    def _compute_attention(
+        self,
+        q,
+        k,
+        v,
+        attention_mask,
+        training=False,
+    ):
+        query_normalization = 1 / np.sqrt(self.head_dim)
+
+        q *= ops.cast(query_normalization, dtype=q.dtype)
+        q_shape = ops.shape(q)
+        q = ops.reshape(
+            q,
+            (
+                *q_shape[:-2],
+                self.num_key_value_heads,
+                self.num_query_heads // self.num_key_value_heads,
+                q_shape[-1],
+            ),
+        )
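+        # Grouped-query attention: the query heads are folded into
+        # (batch, time, num_key_value_heads, group_size, head_dim) so that
+        # each key/value head serves a whole group of query heads in the
+        # einsums below.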
+        b, q_len, _, _, h = ops.shape(q)
+
+        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
+        attention_mask = attention_mask[:, None, None, :, :]
+        orig_dtype = attention_logits.dtype
+        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
+        attention_softmax = ops.cast(attention_softmax, orig_dtype)
+
+        if self.dropout:
+            attention_softmax = self.dropout_layer(
+                attention_softmax, training=training
+            )
+
+        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
+        return ops.reshape(results, (b, q_len, self.num_query_heads, h))
+
+    def call(
+        self,
+        x,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=0,
+        training=False,
+    ):
+        seq_len = ops.shape(x)[1]
+        start_index = cache_update_index
+        positions = ops.cast(
+            ops.arange(seq_len, dtype="float32"), self.compute_dtype
+        )
+        positions = positions + ops.cast(start_index, self.compute_dtype)
+        query = self.query_dense(x)
+        query = self._apply_rope(query, positions)
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            key_update = self.key_dense(x)
+            key_update = self._apply_rope(key_update, positions)
+            value_update = self.value_dense(x)
+            start = [0, cache_update_index, 0, 0]
+            key = ops.slice_update(key_cache, start, key_update)
+            value = ops.slice_update(value_cache, start, value_update)
+            cache = ops.stack((key, value), axis=1)
+        else:
+            key = self.key_dense(x)
+            key = self._apply_rope(key, positions)
+            value = self.value_dense(x)
+
+        attention_vec = self._compute_attention(
+            query, key, value, attention_mask, training=training
+        )
+
+        # Wipe attn vec if there are no attended tokens.
+        no_attended_tokens = ops.all(
+            ops.equal(attention_mask, 0), axis=-1, keepdims=True
+        )[..., None]
+        attention_vec = ops.where(
+            no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
+        )
+
+        attention_output = self.output_dense(attention_vec)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
new file mode 100644
index 0000000000..e5814940aa
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -0,0 +1,267 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import config
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.gemma.gemma_decoder_block import GemmaDecoderBlock
+from keras_nlp.models.gemma.gemma_presets import backbone_presets
+from keras_nlp.models.gemma.rms_normalization import RMSNormalization
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.GemmaBackbone")
+class GemmaBackbone(Backbone):
+    """Gemma core network with hyperparameters.
+
+    This backbone implements the base Transformer network for the Gemma model.
+    It includes the embedding lookups and transformer layers. This backbone
+    will output the final hidden states for each token, not generative
+    predictions over the vocabulary space. For a higher-level object for text
+    generation, see `keras_nlp.models.GemmaCausalLM`.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Gemma model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset` constructor.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of heads for the query projections in
+            the attention layer.
+        num_key_value_heads: int. The number of heads for the key and value
+            projections in the attention layer.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each transformer layer.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for each transformer.
+        head_dim: int. The size of each attention head.
+        layer_norm_epsilon: float. The epsilon value used for every layer norm
+            in the transformer model.
+        dropout: float. Dropout probability for the Transformer encoder.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will always
+            be done in float32 precision regardless of dtype.
+
+    Example usage:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Gemma decoder.
+    model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")
+    model(input_data)
+
+    # Randomly initialized Gemma decoder with custom config.
+    model = keras_nlp.models.GemmaBackbone(
+        vocabulary_size=50257,
+        num_layers=12,
+        num_query_heads=12,
+        num_key_value_heads=1,
+        hidden_dim=768,
+        intermediate_dim=3072,
+        head_dim=64,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_query_heads,
+        num_key_value_heads,
+        hidden_dim,
+        intermediate_dim,
+        head_dim,
+        layer_norm_epsilon=1e-6,
+        dropout=0,
+        dtype=None,
+        **kwargs,
+    ):
+        if not config.keras_3():
+            raise ValueError(
+                "`GemmaBackbone` requires Keras 3. Run `pip install -U keras` "
+                "to upgrade your Keras version, or see https://keras.io/getting_started/ "
+                "for more info on Keras versions and installation."
+            )
+
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=True,
+            embeddings_initializer=keras.initializers.VarianceScaling(
+                scale=1.0,
+                mode="fan_in",
+                distribution="untruncated_normal",
+                seed=None,
+            ),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = GemmaDecoderBlock(
+                intermediate_dim=intermediate_dim,
+                hidden_dim=hidden_dim,
+                num_query_heads=num_query_heads,
+                head_dim=head_dim,
+                num_key_value_heads=num_key_value_heads,
+                dropout=dropout,
+                dtype=dtype,
+                name=f"decoder_block_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.layer_norm = RMSNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="final_normalization",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="float32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="float32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
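+        # Gemma scales the token embeddings by sqrt(hidden_dim) before
+        # the first decoder block.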
+ ) + + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=True, + embeddings_initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + seed=None, + ), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = GemmaDecoderBlock( + intermediate_dim=intermediate_dim, + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + head_dim=head_dim, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=dtype, + name=f"decoder_block_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = RMSNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_normalization", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="float32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="float32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @staticmethod + def get_layout_map(device_mesh, model_parallel_dim_name="model"): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the gemma + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Sample usage: + ``` + # Feel free to change the mesh shape to balance data and model parallel + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), axis_names=('batch', 'model'), + devices=keras.distribution.list_devices()) + layout_map = GemmaBackbone.get_layout_map( + mesh, model_parallel_dim_name="model") + + distribution = keras.distribution.ModelParallel( + mesh, layout_map, batch_dim_name='batch') + with distribution.scope(): + gemma_model = keras_nlp.models.GemmaCausalLM.from_preset() + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + of all the model weights. 
+ """ + # The weight path and shape of the Gemma backbone is like below (for 2G) + # token_embedding/embeddings, (256128, 2048), 524550144 + # repeat block for decoder + # ... + # decoder_block_17/pre_attention_norm/scale, (2048,), 2048 + # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304 + # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048 + # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432 + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.Device`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + model_dim = model_parallel_dim_name + # The sharding is partition for the hidden_dim of the model. + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (None, model_dim) + layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ( + None, + model_dim, + None, + ) + layout_map["decoder_block.*attention_output.*kernel"] = ( + None, + None, + model_dim, + ) + layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None) + layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim) + + return layout_map diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py new file mode 100644 index 0000000000..c66d318fd5 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -0,0 +1,128 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 256128, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 4, + "hidden_dim": 128, + "intermediate_dim": 256, + "head_dim": 128, + "layer_norm_epsilon": 1e-6, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 128), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaBackbone, + preset="gemma_2b_en", + input_data={ + "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + }, + expected_output_shape=(1, 5, 2048), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [1.073359, 0.262374, 0.170238, 0.605402, 2.336161] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaBackbone.presets: + self.run_preset_test( + cls=GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = GemmaBackbone(**self.init_kwargs) + self.assertEqual(model.count_params(), 33407616) + self.assertEqual(len(model.layers), 6) + + def test_distribution(self): + if keras.backend.backend() != "jax": + return + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. 
+ return
+ device_mesh = keras.distribution.DeviceMesh(
+ shape=(1, len(devices)),
+ axis_names=("batch", "model"),
+ devices=devices,
+ )
+
+ layout_map = GemmaBackbone.get_layout_map(device_mesh)
+ distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
+ with distribution.scope():
+ model = GemmaBackbone(**self.init_kwargs)
+
+ for w in model.weights:
+ if "token_embedding/embeddings" in w.path:
+ self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+ if "attention/query/kernel" in w.path:
+ self.assertEqual(
+ tuple(w.value.sharding.spec), (None, "model", None)
+ )
+ if "attention/key/kernel" in w.path:
+ self.assertEqual(
+ tuple(w.value.sharding.spec), (None, "model", None)
+ )
+ if "attention/value/kernel" in w.path:
+ self.assertEqual(
+ tuple(w.value.sharding.spec), (None, "model", None)
+ )
+ if "attention/attention_output/kernel" in w.path:
+ self.assertEqual(
+ tuple(w.value.sharding.spec), (None, None, "model")
+ )
+ if "ffw_gating/kernel" in w.path:
+ self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+ if "ffw_gating_2/kernel" in w.path:
+ self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+ if "ffw_linear/kernel" in w.path:
+ self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py
new file mode 100644
index 0000000000..45c7c6abe0
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_causal_lm.py
@@ -0,0 +1,441 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
+from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
+ GemmaCausalLMPreprocessor,
+)
+from keras_nlp.models.gemma.gemma_presets import backbone_presets
+from keras_nlp.models.generative_task import GenerativeTask
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.GemmaCausalLM")
+class GemmaCausalLM(GenerativeTask):
+ """An end-to-end Gemma model for causal language modeling.
+
+ A causal language model (LM) predicts the next token based on previous
+ tokens. This task setup can be used to train the model unsupervised on
+ plain text input, or to autoregressively generate plain text similar to
+ the data used for training. This task can be used for pre-training or
+ fine-tuning a Gemma model, simply by calling `fit()`.
+
+ This model has a `generate()` method, which generates text based on a
+ prompt. The generation strategy used is controlled by an additional
+ `sampler` argument on `compile()`. You can recompile the model with
+ different `keras_nlp.samplers` objects to control the generation. By
+ default, `"greedy"` sampling will be used.
+
+ This model can optionally be configured with a `preprocessor` layer, in
+ which case it will automatically apply preprocessing to string inputs during
+ `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
+ when creating the model with `from_preset()`.
+
+ Args:
+ backbone: A `keras_nlp.models.GemmaBackbone` instance.
+ preprocessor: A `keras_nlp.models.GemmaCausalLMPreprocessor` or `None`.
+ If `None`, this model will not apply preprocessing, and inputs
+ should be preprocessed before calling the model.
+
+ Examples:
+
+ Use `generate()` to do text generation.
+ ```python
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+ gemma_lm.generate("I want to say", max_length=30)
+
+ # Generate with batched prompts.
+ gemma_lm.generate(["This is a", "Where are you"], max_length=30)
+ ```
+
+ Compile the `generate()` function with a custom sampler.
+ ```python
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+ gemma_lm.compile(sampler="top_k")
+ gemma_lm.generate("I want to say", max_length=30)
+
+ gemma_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
+ gemma_lm.generate("I want to say", max_length=30)
+ ```
+
+ Use `generate()` without preprocessing.
+ ```python
+ prompt = {
+ # Token ids for "<bos> Keras is".
+ "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
+ # Use `"padding_mask"` to indicate values that should not be overridden.
+ "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
+ }
+
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+ "gemma_2b_en",
+ preprocessor=None,
+ )
+ gemma_lm.generate(prompt)
+ ```
+
+ Call `fit()` on a single batch.
+ ```python
+ features = ["The quick brown fox jumped.", "I forgot my homework."]
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+ gemma_lm.fit(x=features, batch_size=2)
+ ```
+
+ Call `fit()` without preprocessing.
+ ```python
+ x = {
+ # Token ids for "<bos> Keras is deep learning library"
+ "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2),
+ "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2),
+ }
+ y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2)
+ sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2)
+
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+ "gemma_2b_en",
+ preprocessor=None,
+ )
+ gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
+ ```
+
+ Custom backbone and vocabulary.
+ ```python
+ tokenizer = keras_nlp.models.GemmaTokenizer(
+ proto="proto.spm",
+ )
+ preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(
+ tokenizer=tokenizer,
+ sequence_length=128,
+ )
+ backbone = keras_nlp.models.GemmaBackbone(
+ vocabulary_size=30552,
+ num_layers=4,
+ num_query_heads=4,
+ num_key_value_heads=4,
+ hidden_dim=256,
+ intermediate_dim=512,
+ head_dim=64,
+ )
+ gemma_lm = keras_nlp.models.GemmaCausalLM(
+ backbone=backbone,
+ preprocessor=preprocessor,
+ )
+ gemma_lm.fit(x=features, batch_size=2)
+ ```
+ """
+
+ def __init__(
+ self,
+ backbone,
+ preprocessor=None,
+ **kwargs,
+ ):
+ # === Layers ===
+ self.backbone = backbone
+ self.preprocessor = preprocessor
+
+ # === Functional Model ===
+ inputs = backbone.input
+ hidden_states = backbone(inputs)
+ outputs = backbone.token_embedding(hidden_states, reverse=True)
+ super().__init__(
+ inputs=inputs,
+ outputs=outputs,
+ **kwargs,
+ )
+
+ # === Default compilation ===
+ self.compile(
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ optimizer=keras.optimizers.Adam(2e-5),
+ metrics=[keras.metrics.SparseCategoricalAccuracy()],
+ sampler="greedy",
+ jit_compile=True,
+ )
+
+ @classproperty
+ def presets(cls):
+ return copy.deepcopy(backbone_presets)
+
+ @classproperty
+ def backbone_cls(cls):
+ return GemmaBackbone
+
+ @classproperty
+ def preprocessor_cls(cls):
+ return GemmaCausalLMPreprocessor
+
+ def call_with_cache(
+ self,
+ token_ids,
+ cache,
+ cache_update_index,
+ ):
+ """Forward pass of `GemmaCausalLM` with cache.
+
+ `call_with_cache` adds an additional forward pass for the model for
+ autoregressive inference. Unlike calling the model directly, this method
+ allows caching previous key/value Tensors in the multi-head attention
+ layer, and avoids recomputing the outputs of seen tokens.
+
+ Args:
+ token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+ cache: a dense float Tensor, the cache of key and value.
+ cache_update_index: int, or int Tensor. The index of current inputs in the
+ whole sequence.
+
+ Returns:
+ A (logits, hidden_states, cache) tuple, where `logits` is the
+ language model logits for the input token_ids, `hidden_states` is
+ the final hidden representation of the input tokens, and `cache` is
+ the decoding cache.
+ """
+ x = self.backbone.token_embedding(token_ids)
+ x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
+ # Each decoder layer has a cache; we update them separately.
+ caches = []
+ for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+ current_cache = cache[:, i, ...]
+ x, next_cache = transformer_layer(
+ x,
+ cache=current_cache,
+ cache_update_index=cache_update_index,
+ )
+ caches.append(next_cache)
+ cache = ops.stack(caches, axis=1)
+ hidden_states = x = self.backbone.layer_norm(x)
+ logits = self.backbone.token_embedding(x, reverse=True)
+ return logits, hidden_states, cache
+
+ def _build_cache(self, token_ids):
+ """Build an empty cache for use with `call_with_cache()`."""
+ batch_size = ops.shape(token_ids)[0]
+ max_length = ops.shape(token_ids)[1]
+ num_layers = self.backbone.num_layers
+ num_heads = self.backbone.num_key_value_heads
+ head_dim = self.backbone.head_dim
+ shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+ cache = ops.zeros(shape, dtype=self.compute_dtype)
+ # Seed the cache.
+ _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+ return hidden_states, cache
+
+ def generate_step(
+ self,
+ inputs,
+ end_token_id=None,
+ ):
+ """A compilable generation function for a single batch of inputs.
+
+ This function represents the inner, XLA-compilable, generation function
+ for a single batch of inputs. Inputs should have the same structure as
+ model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+ Args:
+ inputs: A dictionary with two keys `"token_ids"` and
+ `"padding_mask"` and batched tensor values.
+ end_token_id: The id of the end token to stop on. If all
+ sequences have produced a new `end_token_id`, generation
+ will stop.
+ """
+ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+ # Create and seed cache with a single forward pass.
+ hidden_states, cache = self._build_cache(token_ids)
+ # Compute the lengths of all user-provided token ids.
+ row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+ # Start at the first index that has no user-provided id.
+ index = ops.min(row_lengths)
+
+ def next(prompt, cache, index):
+ # The cache index is the index of our previous token.
+ cache_update_index = index - 1
+ batch_size = ops.shape(prompt)[0]
+ prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+ logits, hidden_states, cache = self.call_with_cache(
+ prompt,
+ cache,
+ cache_update_index,
+ )
+ return (
+ ops.squeeze(logits, axis=1),
+ ops.squeeze(hidden_states, axis=1),
+ cache,
+ )
+
+ token_ids = self._sampler(
+ next=next,
+ prompt=token_ids,
+ cache=cache,
+ index=index,
+ mask=padding_mask,
+ end_token_id=end_token_id,
+ hidden_states=hidden_states,
+ model=self,
+ )
+
+ # Compute an output padding mask with the token ids we updated.
+ if end_token_id is not None:
+ # Build a mask of `end_token_id` locations not in the original
+ # prompt (not in locations where `padding_mask` is True).
+ end_locations = ops.logical_and(
+ ops.equal(token_ids, end_token_id),
+ ops.logical_not(padding_mask),
+ )
+ end_locations = ops.cast(end_locations, "int32")
+ # Use cumsum to get ones in all locations after end_locations.
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+ overflow = cumsum - end_locations
+ # Our padding mask is the inverse of these overflow locations.
+ padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+ else:
+ # Without early stopping, all locations will have been updated.
+ padding_mask = ops.ones_like(token_ids, dtype="bool")
+ return {
+ "token_ids": token_ids,
+ "padding_mask": padding_mask,
+ }
+
+ def score(
+ self,
+ token_ids,
+ padding_mask=None,
+ scoring_mode="logits",
+ layer_intercept_fn=None,
+ target_ids=None,
+ ):
+ """Score a generation represented by the provided token ids.
+
+ Args:
+ token_ids: A [batch_size, num_tokens] tensor containing tokens
+ to score. Typically, this tensor captures the output from a call
+ to `GemmaCausalLM.generate()`, i.e., tokens for both the input
+ text and the model-generated text.
+ padding_mask: A [batch_size, num_tokens] tensor indicating the
+ tokens that should be preserved during generation. This is an
+ artifact required by the GemmaBackbone and isn't influential on
+ the computation of this function. If omitted, this function uses
+ `keras.ops.ones()` to create a tensor of the appropriate shape.
+ scoring_mode: The type of scores to return, either "logits" or
+ "loss"; both are computed per input token.
+ layer_intercept_fn: An optional function for augmenting activations
+ with additional computation, for example, as part of
+ interpretability research. This function will be passed the
+ activations as its first parameter and a numeric index
+ associated with that backbone layer. This index is *not* an
+ index into `self.backbone.layers`. The index -1 accompanies the
+ embeddings returned by calling `self.backbone.token_embedding()`
+ on `token_ids` in the forward direction. All subsequent indexes
+ will be 0-based indices for the activations returned by each of
+ the transformer layers in the backbone. This function must
+ return a [batch_size, num_tokens, hidden_dims] tensor
+ that can be passed as an input to the next layer in the model.
+ target_ids: A [batch_size, num_tokens] tensor containing the
+ predicted tokens against which the loss should be computed. If a
+ span of tokens is provided (sequential truthy values along
+ axis=1 in the tensor), the loss will be computed as the
+ aggregate across those tokens.
+
+ Raises:
+ ValueError: If an unsupported scoring_mode is provided, or if the
+ target_ids are not provided when using scoring_mode="loss".
+
+ Returns:
+ The per-token scores as a tensor of size
+ [batch_size, num_tokens, vocab_size] in "logits" mode, or
+ [batch_size, num_tokens] in "loss" mode.
+
+ Examples:
+
+ Compute gradients between embeddings and loss scores with TensorFlow:
+ ```python
+ gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+ "gemma_2b_en"
+ )
+ generations = gemma_lm.generate(
+ ["This is a", "Where are you"],
+ max_length=30
+ )
+ preprocessed = gemma_lm.preprocessor.generate_preprocess(generations)
+ generation_ids = preprocessed["token_ids"]
+ padding_mask = preprocessed["padding_mask"]
+ target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
+
+ embeddings = None
+ with tf.GradientTape(watch_accessed_variables=True) as tape:
+ def layer_intercept_fn(x, i):
+ if i == -1:
+ nonlocal embeddings, tape
+ embeddings = x
+ tape.watch(embeddings)
+ return x
+
+ losses = gemma_lm.score(
+ token_ids=generation_ids,
+ padding_mask=padding_mask,
+ scoring_mode="loss",
+ layer_intercept_fn=layer_intercept_fn,
+ target_ids=target_ids,
+ )
+
+ grads = tape.gradient(losses, embeddings)
+ ```
+ """
+ if scoring_mode not in ("logits", "loss"):
+ raise ValueError(
+ "Unsupported scoring_mode. Must be one of 'logits' or 'loss'."
+ )
+
+ if scoring_mode == "loss" and target_ids is None:
+ raise ValueError(
+ "Cannot compute loss without targets. Please provide target "
+ "token ids via the target_ids parameter."
+ )
+
+ batch_shape = ops.shape(token_ids)[:2]
+ assert len(batch_shape) == 2
+
+ if padding_mask is None:
+ padding_mask = ops.ones(shape=batch_shape)
+
+ if layer_intercept_fn is None:
+
+ def default_layer_intercept_fn(x, unused_i):
+ return x
+
+ layer_intercept_fn = default_layer_intercept_fn
+
+ token_embeddings = self.backbone.token_embedding(token_ids)
+ x = layer_intercept_fn(token_embeddings, -1)
+
+ # Scale the (possibly intercepted) embeddings, so the output of
+ # `layer_intercept_fn` is what feeds the first decoder layer.
+ x = x * ops.cast(
+ ops.sqrt(self.backbone.hidden_dim), dtype=self.compute_dtype
+ )
+ for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+ x = transformer_layer(x, padding_mask=padding_mask)
+ x = layer_intercept_fn(x, i)
+ x = self.backbone.layer_norm(x)
+ logits = self.backbone.token_embedding(x, reverse=True)
+
+ if scoring_mode == "logits":
+ return logits
+
+ per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
+ from_logits=True, reduction="none"
+ )
+ per_token_loss = per_token_loss_fn(target_ids, logits)
+ return per_token_loss
diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..20c66edff3
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py
@@ -0,0 +1,173 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor
+from keras_nlp.utils.keras_utils import (
+ convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.GemmaCausalLMPreprocessor")
+class GemmaCausalLMPreprocessor(GemmaPreprocessor):
+ """Gemma Causal LM preprocessor.
+
+ This preprocessing layer is meant for use with
+ `keras_nlp.models.GemmaCausalLM`. By default, it will take in batches of
+ strings, and return outputs in a `(x, y, sample_weight)` format, where the
+ `y` label is the next token id in the `x` sequence.
+
+ For use with generation, the layer also exposes two methods
+ `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+ is attached to a `keras_nlp.models.GemmaCausalLM` instance, these methods
+ will be called implicitly in `generate()`. They can also be called
+ standalone (e.g. to precompute preprocessing inputs for generation in a
+ separate process).
+
+ Args:
+ tokenizer: A `keras_nlp.models.GemmaTokenizer` instance.
+ sequence_length: The length of the packed inputs.
+ add_start_token: If `True`, the preprocessor will prepend the tokenizer
+ start token to each input sequence.
+ add_end_token: If `True`, the preprocessor will append the tokenizer
+ end token to each input sequence.
+
+ Call arguments:
+ x: A string, `tf.Tensor` or list of python strings.
+ y: Label data. Should always be `None` as the layer generates labels.
+ sample_weight: Label weights.
Should always be `None` as the layer
+ generates label weights.
+ sequence_length: Pass to override the configured `sequence_length` of
+ the layer.
+
+ Examples:
+ ```python
+ # Load the preprocessor from a preset.
+ preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset(
+ "gemma_2b_en"
+ )
+
+ # Tokenize and pack a single sentence.
+ preprocessor("The quick brown fox jumped.")
+
+ # Tokenize a batch of sentences.
+ preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+ # Apply tokenization to a `tf.data.Dataset`.
+ features = tf.constant(["The quick brown fox.", "Call me Ishmael."])
+ ds = tf.data.Dataset.from_tensor_slices(features)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Prepare tokens for generation (no end token).
+ preprocessor.generate_preprocess(["The quick brown fox jumped."])
+
+ # Map generation outputs back to strings.
+ preprocessor.generate_postprocess({
+ 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]),
+ 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]),
+ })
+ ```
+ """
+
+ def call(
+ self,
+ x,
+ y=None,
+ sample_weight=None,
+ sequence_length=None,
+ ):
+ if y is not None or sample_weight is not None:
+ logging.warning(
+ "`GemmaCausalLMPreprocessor` generates `y` and `sample_weight` "
+ "based on your input data, but your data already contains `y` "
+ "or `sample_weight`. Your `y` and `sample_weight` will be "
+ "ignored."
+ )
+ sequence_length = sequence_length or self.sequence_length
+
+ x = convert_inputs_to_list_of_tensor_segments(x)[0]
+ x = self.tokenizer(x)
+ # Pad with one extra token to account for the truncation below.
+ token_ids, padding_mask = self.packer(
+ x,
+ sequence_length=sequence_length + 1,
+ add_start_value=self.add_start_token,
+ add_end_value=self.add_end_token,
+ )
+ # The last token does not have a next token, so we truncate it out.
+ x = {
+ "token_ids": token_ids[..., :-1],
+ "padding_mask": padding_mask[..., :-1],
+ }
+ # Target `y` will be the next token.
+ y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+ return pack_x_y_sample_weight(x, y, sample_weight)
+
+ def generate_preprocess(
+ self,
+ x,
+ sequence_length=None,
+ ):
+ """Convert strings to integer token input for generation.
+
+ Similar to calling the layer for training, this method takes in strings
+ or tensor strings, tokenizes and packs the input, and computes a padding
+ mask that marks which positions hold real input rather than padding.
+
+ Unlike calling the layer for training, this method does not compute
+ labels and will never append a `tokenizer.end_token_id` to the end of
+ the sequence (as generation is expected to continue at the end of the
+ input prompt).
+ """
+ if not self.built:
+ self.build(None)
+
+ x = convert_inputs_to_list_of_tensor_segments(x)[0]
+ x = self.tokenizer(x)
+ token_ids, padding_mask = self.packer(
+ x, sequence_length=sequence_length, add_end_value=False
+ )
+ return {
+ "token_ids": token_ids,
+ "padding_mask": padding_mask,
+ }
+
+ def generate_postprocess(
+ self,
+ x,
+ ):
+ """Convert integer token output to strings for generation.
+
+ This method reverses `generate_preprocess()` by first removing all
+ padding and start/end tokens, and then converting the integer sequence
+ back to a string.
+ """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + mask = ops.convert_to_numpy(padding_mask) + # Also strip any special tokens during detokenization (e.g. the start + # and end markers). In the future we could make this configurable. + mask = mask & (token_ids != self.tokenizer.start_token_id) + mask = mask & (token_ids != self.tokenizer.pad_token_id) + mask = mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..121621da85 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py @@ -0,0 +1,92 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = GemmaCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[9, 5, 7, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 5, 7, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 4, 9, 5, 7, 2, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLMPreprocessor.presets: + self.run_preset_test( + cls=GemmaCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py new file mode 100644 index 0000000000..0e1d7a14f8 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -0,0 +1,245 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
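The `(x, y, sample_weight)` triplets asserted in the preprocessor tests above are a shift-by-one view of the packed token ids. A toy numpy sketch of the same transform (ids are illustrative: 1=bos, 2=eos, 0=pad):

```python
import numpy as np

# One extra position is packed so that truncating below still yields
# a full-length model input.
token_ids = np.array([[1, 4, 9, 5, 7, 2, 0, 0, 0]])
padding_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0, 0]])

x = {
    "token_ids": token_ids[..., :-1],
    "padding_mask": padding_mask[..., :-1],
}
y = token_ids[..., 1:]  # Labels: the next token at every position.
sample_weight = padding_mask[..., 1:]  # Zero loss on padded positions.
```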
+
+import os
+from unittest.mock import patch
+
+import keras
+import pytest
+
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
+from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM
+from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
+ GemmaCausalLMPreprocessor,
+)
+from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+@pytest.mark.keras_3_only
+class GemmaCausalLMTest(TestCase):
+ def setUp(self):
+ self.tokenizer = GemmaTokenizer(
+ proto=os.path.join(
+ self.get_test_data_dir(), "gemma_test_vocab.spm"
+ ),
+ )
+ self.preprocessor = GemmaCausalLMPreprocessor(
+ self.tokenizer,
+ sequence_length=8,
+ )
+ self.backbone = GemmaBackbone(
+ vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ num_layers=2,
+ num_query_heads=2,
+ num_key_value_heads=1,
+ hidden_dim=4,
+ intermediate_dim=8,
+ head_dim=2,
+ )
+ self.init_kwargs = {
+ "preprocessor": self.preprocessor,
+ "backbone": self.backbone,
+ }
+ self.train_data = (["the quick brown fox", "the quick brown fox"],)
+ self.input_data = self.preprocessor(*self.train_data)[0]
+
+ def test_causal_lm_basics(self):
+ self.run_task_test(
+ cls=GemmaCausalLM,
+ init_kwargs=self.init_kwargs,
+ train_data=self.train_data,
+ expected_output_shape=(2, 8, 11),
+ )
+
+ def test_generate(self):
+ causal_lm = GemmaCausalLM(**self.init_kwargs)
+ # String input.
+ prompt = "the quick brown fox"
+ output = causal_lm.generate(prompt)
+ self.assertTrue(prompt in output)
+ # Int tensor input.
+ prompt_ids = self.preprocessor.generate_preprocess([prompt])
+ causal_lm.preprocessor = None
+ outputs = causal_lm.generate(prompt_ids)
+ # Assert prompt is in output in token id space.
+ self.assertAllEqual(
+ outputs["token_ids"][:, :4],
+ prompt_ids["token_ids"][:, :4],
+ )
+ self.assertAllEqual(
+ outputs["padding_mask"][:, :4],
+ prompt_ids["padding_mask"][:, :4],
+ )
+
+ def test_generate_with_bfloat16(self):
+ original_floatx = keras.config.floatx()
+ keras.config.set_floatx("bfloat16")
+ try:
+ causal_lm = GemmaCausalLM(**self.init_kwargs)
+ # String input.
+ prompt = "the quick brown fox"
+ output = causal_lm.generate(prompt)
+ self.assertTrue(prompt in output)
+ # Int tensor input.
+ prompt_ids = self.preprocessor.generate_preprocess([prompt])
+ causal_lm.preprocessor = None
+ outputs = causal_lm.generate(prompt_ids)
+ # Assert prompt is in output in token id space.
+ self.assertAllEqual(
+ outputs["token_ids"][:, :4],
+ prompt_ids["token_ids"][:, :4],
+ )
+ self.assertAllEqual(
+ outputs["padding_mask"][:, :4],
+ prompt_ids["padding_mask"][:, :4],
+ )
+ finally:
+ # Restore floatx to the original value to prevent impact on other
+ # tests even if there is an exception.
+ keras.config.set_floatx(original_floatx) + + def test_early_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLM.presets: + self.run_preset_test( + cls=GemmaCausalLM, + preset=preset, + input_data=self.input_data, + ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = keras.ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 4) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. 
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Set up a custom intercept function that extracts the embeddings into
+ # a variable from the embeddings layer and otherwise asserts on shapes.
+ embedded_prompts = None
+
+ def layer_intercept_fn_for_testing(x, i):
+ if i == -1:
+ nonlocal embedded_prompts
+ embedded_prompts = x
+ else:
+ nonlocal expected_embedded_shape
+ self.assertEqual(ops.shape(x), expected_embedded_shape)
+ return x
+
+ # Get the scores.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ layer_intercept_fn=layer_intercept_fn_for_testing,
+ )
+
+ # Assert shapes for info exfiltrated into the parent context.
+ self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
+ self.assertEqual(ops.shape(scores), expected_score_shape)
diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py
new file mode 100644
index 0000000000..0a91655fc4
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_decoder_block.py
@@ -0,0 +1,189 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
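The decoder block added below merges a padding mask with a causal mask in `_compute_attention_mask`. A standalone numpy sketch of the equivalent math (function names here are illustrative, not the library's API):

```python
import numpy as np

def causal_mask(batch_size, input_length, output_length, cache_index=0):
    # Query position i (offset by the cache index) may attend to keys j <= i.
    i = cache_index + np.arange(output_length)[:, None]
    j = np.arange(input_length)[None, :]
    return np.broadcast_to(
        (j <= i).astype("int32"), (batch_size, output_length, input_length)
    )

def merged_mask(padding_mask, cache_index=0):
    batch_size, length = padding_mask.shape
    decoder_mask = padding_mask[:, None, :]  # Broadcast over query positions.
    causal = causal_mask(batch_size, length, length, cache_index)
    return np.minimum(decoder_mask, causal)

# A 4-token row where the last position is padding:
print(merged_mask(np.array([[1, 1, 1, 0]])))
```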
+from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.gemma.gemma_attention import CachedGemmaAttention +from keras_nlp.models.gemma.rms_normalization import RMSNormalization + + +class GemmaDecoderBlock(keras.layers.Layer): + def __init__( + self, + hidden_dim, + intermediate_dim, + head_dim, + num_query_heads, + num_key_value_heads, + layer_norm_epsilon=1e-6, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + self.pre_attention_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_norm", + ) + + self.attention = CachedGemmaAttention( + head_dim=head_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=self.dtype_policy, + name="attention", + ) + + if self.dropout > 0: + self.attention_dropout = keras.layers.Dropout(rate=dropout) + self.feedforward_dropout = keras.layers.Dropout(rate=dropout) + + self.pre_ffw_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_ffw_norm", + ) + + self.gating_ffw = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating", + ) + + self.gating_ffw_2 = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating_2", + ) + + self.ffw_linear = keras.layers.EinsumDense( + equation="btf,fd->btd", + output_shape=(None, self.hidden_dim), + dtype=self.dtype_policy, + name="ffw_linear", + ) + + def build(self, input_shape): + self.pre_attention_norm.build(input_shape) + self.attention.build(input_shape) + + shape = input_shape + self.pre_ffw_norm.build(shape) + self.gating_ffw.build(shape) + self.gating_ffw_2.build(shape) + + shape = self.gating_ffw.compute_output_shape(shape) + self.ffw_linear.build(shape) + self.built = True + + def compute_output_shape(self, input_shape): + # Isometric + return input_shape + + def _compute_attention_mask( + self, x, padding_mask, cache, cache_update_index + ): + decoder_mask = merge_padding_and_attention_mask( + inputs=x, padding_mask=padding_mask, attention_mask=None + ) + batch_size = ops.shape(x)[0] + input_length = output_length = ops.shape(x)[1] + if cache is not None: + input_length = ops.shape(cache)[2] + + causal_mask = compute_causal_mask( + batch_size=batch_size, + input_length=input_length, + output_length=output_length, + cache_index=cache_update_index, + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def call( + self, + x, + padding_mask=None, + cache=None, + cache_update_index=0, + ): + normalized_x = self.pre_attention_norm(x) + attention_mask = self._compute_attention_mask( + normalized_x, padding_mask, cache, cache_update_index + ) + if cache is not None: + attention, new_cache = self.attention( + normalized_x, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + ) + else: + 
attention = self.attention( + normalized_x, + attention_mask=attention_mask, + ) + + if self.dropout: + attention = self.attention_dropout(attention) + + attention_x = x + attention + normalized_x = self.pre_ffw_norm(attention_x) + + x1 = self.gating_ffw(normalized_x) + x2 = self.gating_ffw_2(normalized_x) + x = keras.activations.gelu(x1, approximate=True) * x2 + x = self.ffw_linear(x) + + x = x + attention_x + + if cache is not None: + return x, new_cache + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/models/gemma/gemma_lora_test.py b/keras_nlp/models/gemma/gemma_lora_test.py new file mode 100644 index 0000000000..1cbbdfa67f --- /dev/null +++ b/keras_nlp/models/gemma/gemma_lora_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaLoraTest(TestCase): + def setUp(self): + self._init_kwargs = { + "vocabulary_size": 50, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 2, + "hidden_dim": 32, + "intermediate_dim": 16, + "head_dim": 16, + "layer_norm_epsilon": 1e-6, + } + + def test_lora_fine_tuning(self): + # Set up backbone and preprocessor. + backbone = GemmaBackbone(**self._init_kwargs) + backbone.enable_lora(4) + # 4 layers, 2 weights per layer + self.assertLen(backbone.trainable_weights, 4 * 2) + self.assertLen(backbone.non_trainable_weights, 20) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + + # Test fine-tuning + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + # Test saving and reloading. 
+ temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + backbone.save_weights(temp_filepath) + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(temp_filepath) + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + def test_lora_saving_and_reloading(self): + backbone = GemmaBackbone(**self._init_kwargs) + initial_model_filepath = os.path.join( + self.get_temp_dir(), "base.weights.h5" + ) + backbone.save_weights(initial_model_filepath) + + backbone.enable_lora(4) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + lora_filepath = os.path.join(self.get_temp_dir(), "lora_model.lora.h5") + backbone.save_lora_weights(lora_filepath) + + # New backbone with same initial weights + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(initial_model_filepath) + new_backbone.enable_lora(4) + new_backbone.load_lora_weights(lora_filepath) + + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + # Test exceptions + backbone = GemmaBackbone(**self._init_kwargs) + with self.assertRaisesRegex(ValueError, "no lora-enabled layers"): + backbone.save_lora_weights(lora_filepath) + backbone.enable_lora(5) + with self.assertRaisesRegex(ValueError, "ranks must match"): + backbone.load_lora_weights(lora_filepath) + with self.assertRaisesRegex(ValueError, "filename must end in"): + backbone.save_lora_weights("bad_filepath") diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py new file mode 100644 index 0000000000..8fc3beb48c --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -0,0 +1,199 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaPreprocessor") +class GemmaPreprocessor(Preprocessor): + """Gemma preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.GemmaBackbone`. 
+
+ This layer can be used directly with `tf.data.Dataset.map` to preprocess
+ string data in the `(x, y, sample_weight)` format used by
+ `keras.Model.fit`.
+
+ The call method of this layer accepts three arguments, `x`, `y`, and
+ `sample_weight`. `x` can be a python string or tensor representing a single
+ segment, a list of python strings representing a batch of single segments,
+ or a list of tensors representing multiple segments to be packed together.
+ `y` and `sample_weight` are both optional, can have any format, and will be
+ passed through unaltered.
+
+ `GemmaPreprocessor` expects the input to have only one segment, as Gemma is
+ mainly used for generation tasks. For tasks with multi-segment inputs,
+ combine them into a single string before passing them to the preprocessor
+ layer.
+
+ Args:
+ tokenizer: A `keras_nlp.models.GemmaTokenizer` instance.
+ sequence_length: The length of the packed inputs.
+ add_start_token: If `True`, the preprocessor will prepend the tokenizer
+ start token to each input sequence.
+ add_end_token: If `True`, the preprocessor will append the tokenizer
+ end token to each input sequence.
+
+ Call arguments:
+ x: A string, `tf.Tensor` or list of python strings.
+ y: Any label data. Will be passed through unaltered.
+ sample_weight: Any label weight data. Will be passed through unaltered.
+ sequence_length: Pass to override the configured `sequence_length` of
+ the layer.
+
+ Examples:
+
+ Directly calling the layer on data.
+ ```python
+ preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset(
+ "gemma_2b_en"
+ )
+
+ # Tokenize and pack a single sentence.
+ preprocessor("The quick brown fox jumped.")
+
+ # Tokenize a batch of sentences.
+ preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+ # Custom vocabulary.
+ bytes_io = io.BytesIO()
+ ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+ sentencepiece.SentencePieceTrainer.train(
+ sentence_iterator=ds.as_numpy_iterator(),
+ model_writer=bytes_io,
+ vocab_size=8,
+ model_type="WORD",
+ pad_id=0,
+ bos_id=1,
+ eos_id=2,
+ unk_id=3,
+ pad_piece="<pad>",
+ bos_piece="<bos>",
+ eos_piece="<eos>",
+ unk_piece="<unk>",
+ )
+ tokenizer = keras_nlp.models.GemmaTokenizer(
+ proto=bytes_io.getvalue(),
+ )
+ preprocessor = keras_nlp.models.GemmaPreprocessor(tokenizer=tokenizer)
+ preprocessor("The quick brown fox jumped.")
+ ```
+
+ Apply preprocessing to a `tf.data.Dataset`.
+ ```python
+ preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset(
+ "gemma_2b_en"
+ )
+
+ text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+ label = tf.constant([1, 1])
+
+ # Map labeled single sentences.
+ ds = tf.data.Dataset.from_tensor_slices((text, label))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map unlabeled single sentences.
+ ds = tf.data.Dataset.from_tensor_slices(text)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+ ```
+ """
+
+ def __init__(
+ self,
+ tokenizer,
+ sequence_length=8192,
+ add_start_token=True,
+ add_end_token=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.tokenizer = tokenizer
+ self.sequence_length = sequence_length
+ self.add_start_token = add_start_token
+ self.add_end_token = add_end_token
+
+ def build(self, input_shape):
+ # Defer packer creation to `build()` so that we can be sure tokenizer
+ # assets have loaded when restoring a saved model.
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "GemmaPreprocessor requires each input to contain only " + f"one segment, but received {len(x)}. If you are using Gemma " + "for a multi-segment classification task, please combine your " + "input into a single string." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def tokenizer_cls(cls): + return GemmaTokenizer diff --git a/keras_nlp/models/gemma/gemma_preprocessor_test.py b/keras_nlp/models/gemma/gemma_preprocessor_test.py new file mode 100644 index 0000000000..f54a509979 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor_test.py @@ -0,0 +1,74 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
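The expectations in the tests that follow all reduce to the `StartEndPacker` contract used in `build()` above: reserve room for the start/end tokens, truncate the content, pad to length, and emit a parallel mask. A toy mirror of that contract, not the real layer (ids are illustrative: 0=pad, 1=bos, 2=eos):

```python
import numpy as np

def start_end_pack(ids, sequence_length, bos=1, eos=2, pad=0,
                   add_start=True, add_end=True):
    # Truncate the content so the start/end tokens always survive.
    budget = sequence_length - int(add_start) - int(add_end)
    packed = ([bos] if add_start else []) + list(ids)[:budget]
    packed += [eos] if add_end else []
    mask = [1] * len(packed) + [0] * (sequence_length - len(packed))
    packed += [pad] * (sequence_length - len(packed))
    return np.array(packed), np.array(mask)

# Matches `test_preprocessor_basics`: "the quick brown fox" -> [4, 9, 5, 7].
print(start_end_pack([4, 9, 5, 7], sequence_length=8))
# -> [1 4 9 5 7 2 0 0], [1 1 1 1 1 1 0 0]
# Matches `test_sequence_length_override`:
print(start_end_pack([4, 9, 5, 7], sequence_length=4))
# -> [1 4 9 2], [1 1 1 1]
```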
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = GemmaPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = GemmaPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 2]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaPreprocessor.presets: + self.run_preset_test( + cls=GemmaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_presets.py b/keras_nlp/models/gemma/gemma_presets.py new file mode 100644 index 0000000000..f63fef17fa --- /dev/null +++ b/keras_nlp/models/gemma/gemma_presets.py @@ -0,0 +1,66 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "gemma_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/1", + }, + "gemma_instruct_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/1", + }, + "gemma_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). 
" + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/1", + }, + "gemma_instruct_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). " + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py new file mode 100644 index 0000000000..6a4bb76ea0 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -0,0 +1,108 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.GemmaTokenizer") +class GemmaTokenizer(SentencePieceTokenizer): + """Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("gemma_2b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        self.start_token = "<bos>"
+        self.end_token = "<eos>"
+        self.pad_token = "<pad>"
+
+        super().__init__(proto=proto, **kwargs)
+
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.end_token, self.pad_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = self.token_to_id(self.pad_token)
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
+            self.pad_token_id = None
+
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/gemma/gemma_tokenizer_test.py b/keras_nlp/models/gemma/gemma_tokenizer_test.py
new file mode 100644
index 0000000000..1c617dd937
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_tokenizer_test.py
@@ -0,0 +1,67 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_gemma_test_proto.py + "proto": os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ) + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[4, 9, 5, 7], [4, 6, 8, 10]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + GemmaTokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaTokenizer, + preset="gemma_2b_en", + input_data=["The quick brown fox."], + expected_output=[[651, 4320, 8426, 25341, 235265]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaTokenizer.presets: + self.run_preset_test( + cls=GemmaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/rms_normalization.py b/keras_nlp/models/gemma/rms_normalization.py new file mode 100644 index 0000000000..ce9bdaf880 --- /dev/null +++ b/keras_nlp/models/gemma/rms_normalization.py @@ -0,0 +1,40 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class RMSNormalization(keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(input_shape[-1],), + initializer="zeros", + ) + self.built = True + + def call(self, x): + # Always compute normalization in float32. 
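+        # For reference, the lines below compute
+        #   y = x / sqrt(mean(x**2, axis=-1) + 1e-6) * (1 + scale)
+        # i.e. RMSNorm with a zero-initialized scale offset.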
+ x = ops.cast(x, "float32") + scale = ops.cast(self.scale, "float32") + var = ops.mean(ops.square(x), axis=-1, keepdims=True) + normed_inputs = x * ops.reciprocal(ops.sqrt(var + 1e-06)) + normed_inputs = normed_inputs * (1 + scale) + return ops.cast(normed_inputs, self.compute_dtype) diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 9a461926e4..598217d964 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -101,12 +101,7 @@ def compiled_generate_function(inputs, end_token_id, state): for v in self._sampler.variables: new_v = scope.get_current_value(v) sampler_variables.append(new_v if new_v is not None else v) - state = ( - sampler_variables, - trainable_variables, - non_trainable_variables, - ) - return outputs, state + return outputs, sampler_variables def wrapped_generate_function( inputs, @@ -115,18 +110,20 @@ def wrapped_generate_function( # Create an explicit tuple of all variable state. state = ( self._sampler.variables, - self.trainable_variables, - self.non_trainable_variables, + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], ) inputs = tree.map_structure(ops.convert_to_tensor, inputs) - outputs, state = compiled_generate_function( + outputs, sampler_variables = compiled_generate_function( inputs, end_token_id, state, ) # Only assign the sampler variables (random seeds), as other # model variables should never be updated in generation. - for ref_v, v in zip(self._sampler.variables, state[0]): + for ref_v, v in zip(self._sampler.variables, sampler_variables): ref_v.assign(v) return outputs diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index e154c88bb1..b0bd529da4 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -298,6 +298,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index bef32017ea..b1df4a6706 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -188,6 +188,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 9715bc6b75..2ca8ee07b4 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -294,6 +294,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. 
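The `model=self` argument added to the samplers above pairs with the
`Sampler.run_loop` rewrite further below: on the JAX backend, generation runs
inside a compiled `ops.while_loop`, so every variable the loop body reads has
to be threaded through the loop state explicitly instead of being read or
assigned in place. A minimal sketch of that pattern, assuming the Keras 3
`StatelessScope` API (`run_stateless_loop` and its argument names are
illustrative, not part of the patch):

    import keras
    from keras import ops

    def run_stateless_loop(cond, body, variables, loop_vars, maximum_iterations):
        def stateless_cond(state, *loop_vars):
            return cond(*loop_vars)

        def stateless_body(state, *loop_vars):
            # Map each variable to its current in-loop value and run the body
            # functionally, without mutating any real variable.
            mapping = zip(variables, state)
            with keras.StatelessScope(state_mapping=mapping) as scope:
                loop_vars = body(*loop_vars)
            # Collect any values the body updated (e.g. random seed states).
            state = []
            for v in variables:
                new_v = scope.get_current_value(v)
                state.append(new_v if new_v is not None else v)
            return state, *loop_vars

        # Variables enter the loop as plain tensors in the explicit state.
        state = [ops.convert_to_tensor(v) for v in variables]
        state, *loop_vars = ops.while_loop(
            cond=stateless_cond,
            body=stateless_body,
            loop_vars=(state, *loop_vars),
            maximum_iterations=maximum_iterations,
        )
        return state, loop_vars

Only the sampler's own variables (its random seed state) are assigned back
after the loop; as the `generative_task.py` hunk above notes, model variables
are never updated during generation.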
diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 697af20899..ddff7a164a 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -131,8 +131,7 @@ def call( shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = ops.cast(attention_mask, "int32") - attention_mask = causal_mask & attention_mask + attention_mask = causal_mask & ops.cast(attention_mask, "bool") x = hidden_states # Intermediate result. diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 87948439a8..297ec203de 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -72,6 +72,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): batch_size, max_length = ops.shape(prompt)[0], ops.shape(prompt)[1] index = ops.cast(index, "int32") @@ -167,6 +168,7 @@ def gather_beams(x): body=body, loop_vars=(prompt, cache, index, log_probs), maximum_iterations=(max_length - index), + model=model, ) all_prompts = unflatten_beams(prompt) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 8b3d52d9a5..4259167c8c 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -73,6 +73,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): if hidden_states is None: raise ValueError( @@ -209,6 +210,7 @@ def gather_best_token(beams): body=body, loop_vars=(prompt, cache, index, logits, hidden_states), maximum_iterations=(max_length - index), + model=model, ) return prompt diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 2101c9277d..3ecf16ac28 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -92,6 +92,7 @@ def __call__( mask=None, end_token_id=None, hidden_states=None, + model=None, ): max_length = ops.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. 
@@ -133,6 +134,7 @@ def body(prompt, cache, index): body, loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), + model=model, ) return prompt @@ -147,32 +149,68 @@ def compute_probabilities(self, logits): probs = keras.activations.softmax(logits / self.temperature) return ops.cast(probs, logits_dtype) - def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None): + def run_loop( + self, cond, body, model=None, loop_vars=None, maximum_iterations=None + ): """Run ops.while_loops with a `StatelessScope` if necessary.""" if config.backend() == "jax": + import itertools + + if model: + model_trainable_variables = model.trainable_variables + model_non_trainable_variables = model.non_trainable_variables + else: + model_trainable_variables = [] + model_non_trainable_variables = [] - def stateless_cond(variables, *loop_vars): + def stateless_cond(state, *loop_vars): return cond(*loop_vars) - def stateless_body(variables, *loop_vars): - mapping = zip(self.variables, variables) + def stateless_body(state, *loop_vars): + ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.variables, sampler_variables), + zip(model_trainable_variables, trainable_variables), + zip(model_non_trainable_variables, non_trainable_variables), + ) with keras.StatelessScope(state_mapping=mapping) as scope: loop_vars = body(*loop_vars) - variables = [] + sampler_variables = [] for v in self.variables: new_v = scope.get_current_value(v) - variables.append(new_v if new_v is not None else v) - return variables, *loop_vars + sampler_variables.append(new_v if new_v is not None else v) + state = ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) + return state, *loop_vars variables = [ops.convert_to_tensor(v) for v in self.variables] - variables, *loop_vars = ops.while_loop( + trainable_variables = [ + ops.convert_to_tensor(v) for v in model_trainable_variables + ] + non_trainable_variables = [ + ops.convert_to_tensor(v) for v in model_non_trainable_variables + ] + state = ( + variables, + trainable_variables, + non_trainable_variables, + ) + state, *loop_vars = ops.while_loop( cond=stateless_cond, body=stateless_body, - loop_vars=(variables, *loop_vars), + loop_vars=(state, *loop_vars), maximum_iterations=maximum_iterations, ) - [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)] + for ref_v, v in zip(self.variables, state[0]): + ref_v.assign(v) else: loop_vars = ops.while_loop( cond=cond, diff --git a/keras_nlp/tests/test_data/gemma_test_vocab.spm b/keras_nlp/tests/test_data/gemma_test_vocab.spm new file mode 100644 index 0000000000000000000000000000000000000000..a049c032c296607c2c96febc6e7d2934884b85c5 GIT binary patch literal 237805 zcmZU+3s{ubdg%Yo3>OVI-G)t!B3d`cSYwSfR;_@s#;w*+<91VH4aZnx)e<*hDhM9>q)>vn1JceU9hNDwrja57G7>@B6{+{2$B;BWf zp6B8HvK*6S~Kg`>%8N{Xg!V|H+;ItDNtzYafey zEb3q0HSFKLGx3REcljSncx(dBzy5V@!LJhENsRpHdyk#lf4zVABacTn{7vo~kDRfO zMRLQQeZL<%FMrRDoZt8Ob#DHzULE)N$um#rX2yJ^)qSGXe5%!cu9bh&My9C@{V$}- z`2Rtw3>WJj73@jVNAZtM_|@M1+kW*oIXixG-ybjCS`nivujn-p_)aUmks?R_0R2n! 
zbPej1v0EpaQnH}Yi0+7pwn8MbRwSZYB)UN)>X^vHHjxQUA~AN6$1EaoeIl`5kpzfG zpL|+m(n*miL6M)F6-gWvdHjM%(y+)+heW1cqcuY!({77Ac|#<5RAl-ck(4o!8TUoD zFP$hk#y?_a#DYX6Q$s&K1HbL~1-_Rg(nKT!XkCHuFDPFA=!(y%I!L(N=Wm!5gETdg^YUN4t z_?9L@Dh72@ep@G16Gg1t`!xQC(f=4(u}~+~=z-5gwyHSc)RUs$K;7tf5?35?ri+v$ z>m)EqD?!RtJeTnMbz(L|$vML6yT}7vL^p8FLYrPj4}7JQ;450mBCb~2Z7Xs+?1B*$ zb`SA<5h2BsHF6X8Ui|MOb<;#20}s&M)KTDr2niA{FhL`IxUJ-|jC(>dPp-Z>M@Gum zYX2p6sn(3Ysahi*q!aW*8tR5LKE}eHe^pzL-xT5vf4fU-m>ey|cQo?#6J^?pEb8@| zMy7Gi6w;6p+;NoofXD&zWI*a`H4<2=k>KMRIebzhMfm&4i#bsvr;aAcKOBEtP8*Zt z;}?_U%&jE(NBvKvhV&{4caDDaPgj4+y+4(u)cN^=nexxar^>C-{t5DMrihFCTj(bqWLJ8Wn25KmR-`qVJXMIC#J{+W zwiy($YoaBHyPCdPd`u(fh$9#qNgs(61AeAbjg)&ea*#C7Q4d|@^)P+odGhq{&MK?V8RM;xl}AI7gZPUIqCs@92g-k|=tc9*!# zbLls!#FwR!O8i%U*nn+ z+3Lbx)qSMvK?c4f{73XR^mEF;R40#97Q;22+~FS64W0b6s!NMbE^>VgKRfQrD(o?A zc~~chz102B$lFX%_WTUC%UEinFX*$V+btsZxUYor9i|`L*Ox=X_&8(t}*5d#uWQXo!rFF%$TH`fi0uoC7u@Cdg{TDEYgKuK3}As zGPKmjG({2j@*xYE?Vm6CrF@`zGI7)GTPFVl|X+_kXYKVSQ9HHxb?K1fw7~7 zd{xoLhi{A2qK{C9O7c~{H$f`6_b1%bqwY%;c~P|k_4Sv?X_R*g%!Apm2x34#momXa z?wLX!jR%<9kcLe9P7~t~{^l|Ax0H5ao}I0(Q6IrKwe+PpnZxy$iNE!1f|!NKJejvJH?+|w!Jwcl zda&nd=(4#TCMCO z|F@|7gTz&dZl?}Z86ObFfS>sSZH@my2$V;NA)mOxsIIw)UzqoQsTCdl_bvR4*vZ?- z7~E4J@M*jRo~Aq!DZ!T`B}`iX3j*)OGtY52W7zu;2BT%+Bg z1d8Dj_ZpFE4X}pzD=f?*L)72dXjy@KGi*@lKC6*$39BEbzHW&8T#c7wX~U)}#dk@{@H?xf#n$isGY`ksEB$P&UOQ&$~?QEQV0=!;<+ z{w6rXy66|AS)5G$9bo>R8qHY9ct(FR&87aTqh*x*y|stFG!!M}$cjSJn7Et^k$;b$yl&q02QOSuo+q6|CIHom->1;TK)xp1MMGpB3Xhn zwW9V-|BhdIqv_dR$3VYlE{}z zLyf0_S4j8ODEWrCQs5F{tLJLSJog4uxSVOVn|GpM$#~3 zr^sPLikMcWP=4BNH*-T~iWn>@@)T*6v~iuVhcC>KdALok6y_$b9|P*C#EZKWy&^^K z5uaLT{gyIUaQz`}HK$Ibe9?^E`Z(HajD0-1c_>B92WE(wwoBl;eQ}DkY@Z=5r1v5B z8wuNvG~@qAq-!UBV;*fiD@D4}XGj-mzQ^_JlqZq<`ytRxdv{KkP55_Fo-N2m{8K1@ z5-fl@pzovnP&|uzWX)!zUY6irR)HfQ_3QZ?6y zqojJ2@Dd@xTZ{`kDGzn^7u-`oxI*GvtL}e-x?F}$;#UNLmx(J?WFhUT=3lkuH!su4 zv$&T)3t?Ke5HDv0<=B#WqfQ+7y+T|id*h^~L8RG^Ke8*1aLkXNbFIaNKLk{oj30)< zX!%E_bN#;|*HHa`gS1zOyoS^V2}3&uo>%kx3^9&sdC+9 zTC^09X9M%M9?}U{e4q0V&I8mMSkD=P>I-F|C^^KvM*3nE;ooA;uRO;3;Q;ep1LG`p zSG9$;!`^6lNZbKEN&m8h!gftP4{s{F&8wcoH!LK4Dnrlb!H?ND9 z8e{@)vo~6jkkiPqLZrx7(#9=?^pD?V=V7+4BL*V(;Z#p)e zO&hZoSN7k_eKs(mmt)%%E{(k9)kxLcL>VBynX3KYz=RhAEOO(IknC|EJ z8AmDiQ{4Y_g#06Z4#w}+LdGY`(Y{9`S(Krh^@3|I{bQm=B8mIoqD24ia3WG8&UB3o z5Qle6#5crVfH-;udD2o+wHt6I;g-+arQTu_{@Xv)r()g12uESEURg_|D zUq*|+NhcFX?|-QBP^rN$W9hrGQn8Y`CX0O!ec~-Bb+N7oQ!QtJ2AxEZ_h@(w-g}yJ zTiT|~6FIswPTmo%9H*TdEE;*gf^iz&JEoCi$iPhI zUA6yz4*4>YN(*up8yMgmO(mpZPa2mXxy0AcJ*AN%1?bzrS4!WQ%UM8bg5=;Xo|Pb_ zb49vIrosxo5A=s%9E_FH8}V|K#xV`YiwD1BJV!W%JO@3nigf+R z#{HayA*-Jt;Unyo4Dqc0>Yis#p&r_Lg~K82{5ESQ!fzn>O=Rho1SzTD zOoB!=td3`26E9)>E$n?tw^J8iKQ6Oba`b)7nH@$QWY0LhpE^k&!QcOO+zEk_O;BNZRQR0|1k@M7_O5mk<_WAUS$Kx31AESS8j>X0( za5V{VFO8AlHxtJD*H-eN-^iQ_*JqNix9LB$RTw?+Iq}?LO|JX}f6@%{Tq($zLhzTw zznipQV|~f>SFyFD#~52aS}uC*-7rL7Ag{q!3HJu!-o(#NUxw11v|$7KM*QGth;fOv zz8V{9PP32a{+q;EKI>6`ysOe@4W^#a1iOWPPkvc5G*UOE6FHAYJ|K(^Mj!^Jz%)pJ zXmCjs2;*PFt$}YgX5jMK<7PRXX*%SzXWK`rb$B*Q}=939D>t z1=s2bH*SB#+xY$c>+x^-{{Bts{u|Wa9r`ZuIh4Q3KWydNDs_D?*Qrz0S4u1znWFMf zc~$yJ=rJ&daDh*#k1w=xf-+xJ=`JH*a2b6BF2NP$PZ?BsuA|RW@t?*$NSmmA8tWiE zx&gmZ?kzdXI7a*~RgcPF<&AU~;isd$OzM98RG*H)U2}}`_c5ocHr+|sbnJ;RHQAgI zD?h>oU#Fk%kCdaDG^yiS5d>L-R;qBI+NmB{0?SEffkOGGKFU+zTgv${Hi+%@k}m`E zy8a06kEtWq8+dvQ6oMTfh`%U{2RF zN2vNM4kD?eVl~Im&-7b>xn8w@C2@@BpZ!HK+#aZfDqGe@cOjbi#r$+Z=dK-eTmfhF8CmpmRsu2OFDy$1Qm z_~|L9VI^h7oy&C-G6$(o#eeP!sj22#p~%yOJ&OGYKBp}?_t{67nk@3B`YHMI(55B$ zsW49B7$5hW9_3y2V}zK(E}RXpDO&y zsgIS&LQws;9;wE+CS+L%8>al3>KgfIQ4$+1r_IXe(ythQHYq=(%BLN@bUu5M56I7J 
z0cA>9Z>VoPGo?yc3x{CZSuodRK|U3pB9f_neiF#T-4(~;IE!dnbaC^saA4u2O_jmNuH{>UVyu5D*0rb_zn6I_U`4x z`5w=)k3FrETKwKuw)BFQ`XK)A>+d}k&-eN5Q@Zjez5#U8DDNeFrjs`2(-GW5l$qC# z^p}C{o&M-z>T{vU&~B~#PuxSqeUx&4iC#O4aRz(%%Sr6je#+J;N%Ws1RUGN)uFrMS z&E8qXG4zsFUcpb@JM>H9ApIq{ukPnugT41aB6Wkjs@KX9{0$$_?zgn^;U}D1zsz;Q zH>&VPtvde@Gx~?$YUS!ToZAE_`=>mE#LszEC(Yc`LLS~>o%y%;`A&-DA?uKDAd8?3 zl>Gz-R)~)Udo}KXH)6$$-v3Rkczzixy~y{JKlJp*%11~S51dYXz;ntEVx^DW|V2)=5vgPTnD(F66On_WGHOLD<;) zdpH9?9*4$Kjr^9qo$0t%e0Gs{Eh2SZ+W%S3-3~92K<`rKx24ips}VQ-s)4Y_M>%up z(n!Y}oU@Vc$8ap2@f%rB-+6zc$bUm0f~#w_>K!_{yY35kpGS}y`qy* zq=WYIVIOtW`vk6CwP__5S@tb^pj(u$TPu&_b`W33+q6j<;c2f)_|=i-)z2A6j5>Lj zIQ$YT4SPk7ZznmfchR4mFKVTh@Y7YCjE7gL7bpE*&96`4_rvw4(PzTXK&4rBnDntB zl~yfw|1LK74A;89qF#wp<#(wH`-+3)?1aeoijIG>X={GNli=!ZSe zVsF&v^SIZ7lj|>hFSjF)Q)dmZ3H`4?JwxA(+y}n_12nFqKO?K)HFy^q;0g@F*YL0K zA(XHdGThh53G_br2+qJigBA4MkNQn$0`*J2?%~HfQgVd0et;el`Z|4J~j=aGbD-p8!PkA;E+h7Omf(Gg;e0aAUZ_-E(?p*lYG4|-l0yqdo&{Dw~1!;n| zb*#~mwl>yy$Qq~vXDWRc*$8InVZY~I#`)1g=1$yJ=)2E4U^46VshnTncB*jT0evTP zPdW1udI|HC@t2I1FGt8J+(w><>xY?7=d(Ux9P-h(&f$L^OxRM{@d$A;_g=((34D~- zH$76^tO;C0I=PIWhdH?a4$l@}x~MT;<=l`pijQkv;vE1V9p8fOHpl7Y3gL}(#1Z6m zFms=cdE5>igmFy8{oypN{fj5H?WAi><^C6@if`?ca+B~rPJp|wKB;wmmaKI@GeaAu z&BFMHKH$7_>66+K#;ktfJoMG+THiO*wSJkTt&X^;dHc)%t-1O<&sQg8XzvpK9z1}D z;ImB^9oz9}h=Bx{0!c6p44r$aM9$}?|3ZpCFObri?~CzmpO_|`$435DDqdTnDZl!( zrt*nHnzufmqB(NaqB;EaUumkoyQ%s6m#Q?KgGth**)1oI_1NgK3hMu~NGl-Q7V_Us);qQp_kdA=Sye?liK zNM{wSf%UKfHbF1v_r1mVi4sO12`mfeZisJyGHzUpsK`0$&^P!T@-|Ox!v6=fXZHfaWdK z0saTkts2TUKT3+wOTckES~|cEHfSY{w)q-q$BwM&oF^gexI2&z?ApXVl~4sWPzS|c z+5q&>+rW9co&IIvol(ZtfqeQ`A-00_-s z#)B553+V(3whl_9tE=z%V7nqg5Dd9ZKERnW9%F6Q2*7`Klm6s{m>cW`5>GCCv=1BwupO(xdujJqa`#G zo?#515G`wnb3JT;O|TiZ!Zz3e;rF6t7cxib<@D(at>mKb1E2m$DL}gMbHPD$&raG7 ziqK2I&ly?E$Ta_00GurkSV}AN~r-kvy&iI3LdKrJvUBvIMrSD}kehend9O9V=3t$m6C$r|E z{gpu)A^HN5n!_Hp3m3d9w2d6;2JAMYK zN2=!)z1VXf^iQOH;m7eA-4C6EJmVeaTo~!xL;W)zyTFa!vz_k`df~Ga8CeRk)8rg+ zpNETZ8Lq$xT!)+BOPeTRWY{=O6y_$a+{In?RfL$>oZUljE~Ne70lHO{jWRw&*Ks`J z7^8229c(Mf6Z7OHu0?Y#1`@zs!FY5eN~WND7$^F1C!tRR|0v@TZZG4=05S!?*)Rv@ z!2(zWOCW4amgUG0>%tYtk?$gS8I5{NezbmBga#w^A|2jk!x z!mNjJ`Wr3&0^De_S;@JVaAP1T&qQ!EA_b<_{yua9vdmnlMIEc3cx81<{0eKKVXNdI$vIN}M zc&89q39dWj%S*nf%PQP8Fg|Z%n*+7fA>)5*GUZQwv_`QaZIs(iIXcL%gS^&pPdzk( z8CsyYk2QH0=kCbf+w{Lu`ri@yA9A1p`$zh4`;nc**A++ksaNMU`XBA?Lb{=65c{`8 zij_DW-~)A+YzaJ@OJnT13$EIds8_Tefh1kl4NI9kMyQ7se$aCQ4 zIs19!MR0{!Z!?E`4(Q|(?#tj?s1tuWYkzFg3$3Zx-(vb`I`)UO%_r_^=6`U&748{< z>u?intlz`PyKoPBkLg%{VdECk*-O0n#Eb3wkbc~qcXZNqpLZKhG6q0*7h{8$a{w3n z&o-SrAdZKiTY(Kg3?#r52)8~dK3>cY9p&9U+#}POZ{z5{dGueHhM##|Rzo~{hj=leSAwH7UOJA&(?;=PLymvvG?DtirmMIf zeqy53AnTwWM$+P>5gB?ZR?NsU&WcO}oJpfwp_O>rh^PH*yjVF%wQ)ddKS>(k7>ySr z&n+B;b3zYzzz?V34D_B%kUr=K4-9}8eBd8VkWSt^=vo~wCpb`cBD=HW#YOqt&;zX- zNq-CJ6VBQ|`jpR(-hp&@B7NfSK{_MjG%(3ivb_%T^49$0~U75J)&7X~z1@xmJX)`OWn z@CKxXIeQawGc+I2%2wnyuu{%;uxC<6aFE9iXdR^d!(5xsek+wZF`c;)*@4@E+`+xO zAO~_`9~^`tD1qKm#@`CY-v;cfmiF(W{lRyD{Rec?AG)rw{tB`Fy3hD~gYkD5`y6Ea z^)mN=jcpOn$m8)+iTq(-S%qF~=WJh?*Kb8i4emOqheim$5-(zK2w~X27XW$$-nR5z|=iwq;0ylmx=HAQbo+jon=GQCeBj9Jx z;IlJ#!GMeR>i04p^M2b8{m=KZPOcNK>|6TVYka?j9)`Pc4<5kC^AYk889L1V-4Z3b zm5jX*0}0TKpM|)mpj(*>)qW)jeHu8(OUHcXNU)K|6#Sf=Gt5TL0r$FSnTK2eF5buJ z2M;X5y##!VSr;th{9rZfgmm`*#L>na*^acbp0PC$_YuZgyo7@z0K^8osf<7+1Ne}VZIP7E^ta;=*%u0GZuZLB{ir`gNMIS77MWIL2ipgwut>nJ2|dF1f`dDTS8Ho`dV^i$+6aJMlBAalWmfB#X^y+tw&8J@(q zCCJ$@2j;;72+@BQAxEC)-AJT5yVXNjPoG`JdolQ}f@th-4KmbOBxh9mpS>yTaW8=l zP}WO-<9vP-`exV)+h7Omg7C~YB?sxOD3n~}J}7{L;675w{-IEc&^@aUNk5dJo4}9m z18>$L)*^@aR;5rX@vnj!sDtLrLa{(Sy0!F>v_m7h864;xV6QzSHc;b_bK#*M#~=5d 
zA~Et@qJ^-oS)?5-63>OAN8{*c>Rhgfak+?fLJ{kPB5A!=ByFQb(mqxsR;2BAk=PfL zM*1OfAgwA*a6%7wzz?V34D=2ilD-h>fCmPyQMMbD@z^2pLuck8zRNx&C%_5aqld(G z|B$!|(*tJmbB=h}XR{H0KA$@G5>{eKz#A9Sne zD-Z(-Fn*4Q9N))HX8$`G`=5&a&%*v;V32uvm~jm05Au9)oqER^``?27Lw6?jzYzO} z9`JF#WX>xbe+HrE|K?JyOe6jjuvBB071$+|@m$vw%bb1WkzLyGv%_q{89(8@qd+A4 z8=cI-?LxY5>tr7K0$2q7d)XiCWZh86_=NOZ4 zvR;?^rZ?sN`)^9|_SdAObdOvl?n_WMIbTd*QL-PDj%u=^C3Zo!64Zd1I=a&^#e7&CjH}aH<%22I+(G?MKa7(PNT=*YzR`Y_ z@5j+~&tn&Zw97Ei)|W)X|F+^ z2NOpD9E2h$0TWb0_&DRiC)fpg4LBF`%pK`+@k|NX27RLYEand#xC+(wgV%^KWFrR&4A^TsX1NScOZ2>DdzzIFz0YCJPvi}`pkBaoH zWd9A`b?m?M*ndN3BKzND_PA*cg9M|C{gyAmSg9q>semwuho;@(I zhdKt|Fm*i0`frf(TPQzqIh!ayx(nUiK>4YEb8@`sRx`)KJ!~=t*_=-KApzaGoxD-r zDd!hzf&N(Z7n_x3+g>A3{!XG8bE@Tem!agW~gW%k&lOm)#dw1vaZWX!- zDnXsWdtiY6@1_6yZZQ9b_=b>cHBblj&p1luv}20G9k(21>& zpYLUIzPFn5y_K9VtRsDFs|R{#aXtb4;DG_~f)D(lXN|7j1ytXmsOJx1>cm01PVi0Q zdBi6?S40neN_!$t!5KIQ=b@dvSz-MA{T0ToldKEQvQA(PWFO@Bj0j_1P2WXchGxdA zE65SBWbvL5@+O4gF5H6$@DRdA&U$PS68e-r@^kKaobPv$F^~XLz)3$#LQaDem<{eR z=6{%nuJ$DTV7wY33vm0JXm>At6b7I*lKDOP(H_$}|Iwb>p3eNfnE8D!dsOmK%y@6O z#@GaNh+`3yO~((5+l=hxOC`Rc|LLnxX5q+ z5#~463%hXVfG?kQMav6N2;@u&N#|Hneu}Z|L#%BPdGQS2e-M4GZgk`25eNlW4ng> zZf>LeO*}Ioek<{}pN*0lbl-0F|Kzb_A$dgFPDe>S{?0qRYqa~(`G7l{^&oy`aAnc< z>ll}y1-BJ^>Gb1SwEZ&rz+&RV&k07xNS}VX^q_}0=k_2+gfdL{Bk`l30yW>9L7s#2 za1k!SWiT_RUO`%hv@(Lc4$T*|auXQ_>uGEl?0wjV3maiQt|NVNt~V~ZMNx-!{= zfRnw+iTTX^sjLB(G51eq4S?Q?y|$6hcCcO}?J@F@$oXF~brZ+$(yhVXAOWU85=?^> zm<_$Gv->7e&xH4I-@roZ6~_1fqm2K|2|w=tX`AlV*k&gCe`F7BQT9PH?|1MFmbk*a zGdT~r02V3t$5CQ@ZMuZ`K4A&&kxtH|kZOOhLb;!(PfcZ?g1!dUL+P!jnOjMN^}q() zp3~%QjBg;&H-rByc{@oSkpqM?4<$=9{##)i?11LP8DbetmR;!9NS;meJU$0K7aS9L z{;lB|HPS{n_0F#W)V^>Z*9+hvn0QuK_9pWIdWmvFB~(ET)ImM8T2iFV&NFtTm1piY zq#bt$(visT5)j4=Eno!)IH3nT&>P3||EWCxpT+Zk7+_zc_BK8aYWycTf6e1J0^kHV zSM&UTCC~pedH$cy^M8I%z{xuXe&RU=XW$&T@pHj>bdQVo7vLiLCGhiX-&ZYZffE}#%9&US`fQj3%C1!Zn}_bfzWV8-3(yyVAGdFaGXm^w0J#Le<*)+GH|Wbq z%YF9W=xfl|!v<&{#6B4hY>AJ?i1G73-1_a*JE;D*iEvxNC=sk}Sd*ideH0<47uk=a z?}8l2g?&%}t#<6+i~YM){M3IU^`A`rBbEKD=N|_NR|F+sf=Z}@8t5f#-+XK=75fi9 zvMKK{?S_5%v9r!X>OYJ6&!+y7-I>&XI`zMp`scpT6Emcact)1aka}bzn4tx%-~cE1 zUWt$%q#>Vq0>W2Q#`km{{QNM!{&yFLklL z!QGq5zBP+^KAU+SIe^=n&px)0c|MPQ>uL6_LH4mo=P>(Lqzktj*>eED8_}|Wcou>2 zBiikA+9pI8{FmT2UQP|?@8#GX=Q)1L?ZfRws(M*Y7_*(+`THFyr<;#kzepbNfQir)FQoaA zz3c%DyqApbVPDV>m(VYRAKeFD7=YH5*gyGbCtvDa8QVJYRmj?dayt6d`Xf%RaNh`A z2jgq8auXSbyU>e0_hHxlpx*Txz}~%y)c<7aA3D!+{s|`->zvSy?gBUAdh)UVH`x2X z$3BeT7YgzF0{4jX0XziVpJP)H0|_t%Zc+X)-(-Zp!ZvWLxmS&W=2G6-M_MYhl7jzi zXm+tLd5-rEK8_VDGOWa#{{=Jds-|U08;b+FZ1Gx*D*&pX1bHRe% zK9_y{;&|DIy8t>!&v79^Y^jtpk>}rD%71}yV4chOzl`x8I?x^P`}^mE++PGG;M8ig!ad_)H!?S7VHmB+*Y=R{f%LN$RFOnsU!~d%uc=YTZQgi7cVu) zvZahSe2?f~NL=K@JBzrH9^zK-`n1qL>Ii3^8ZY(8Mx{dwG~eg_-#e@)(5={fJ2=pt z;Mh*t4p7EC%8Km4FZ=~Ze@y%N={d$6S>ce^nu^S{`WTfU&5RiW3Phj#_hVm{ug?{d58Tk z>t=(CaSSdJr~3wBLd3zhah`?5J(qVB@xKiIMEr4kXOR~23VvqNDdU{R1kHBdS;y}> zSS;VqtNE;JaEGB|FXh@uIg$3A%)g5%|1!cEC_mW14jtei-n-m$4<5in(7izag&0VH zUdlWE{GYV`f9+5E{`mY~l@-RH|9dGvarcl{=S0><#FGTmAO+kS%8&nSbWbAXhdJo; zz)u_b@+f~K<%b3MnKOBZ40#XxYF^D=5q&vW@NcKAY7Mpm_bM1a*GZ?GD=8-!*dwcN ze2nFMGs~X2tTRnad=KQ~JD~N1-vFCnGi-%zumg5M4vc)w83{7<40}~clYQs~a1e^1 z1WKPx=iO=64-oF;%ol?eXCRdwwT$qr1SJ&G{c=L+@VB|2T{62M_0i1Nf`6`|=W?bWBEV?NLGOWdbmBu)7(kft6`{Dh<%jj2N z1o~NLcwhj$-~&Hg=h{sO!(F%s58xq$zh(ciJ554qyd#rIJe2j8n1vH-Ul(>g%C<6wi<B0mY-gC~ zX?*i#KRsDGhInSUY?8F@oFHwQSZN>PdC$Z+d28VxNoCU_sUq$g`1?Ne$$uh;lmA$b zWd1SFo&SrJw9TXcFh;x*A!SVq#FY3)Vx0PiQd<3o^8WS=siz&vK4BjFLMQda5qdd8 zj#pqqk|FQx`4g!*utXa14?mG1?`AHSy1o@Wvt2H=*?%f+4e8QakuKK0bZNJz(Ein;Ld@LT0#m;C3U;v(m&*Jg49coP*~1X<~u%=+>2K%#CSs5&aT4`qDVJ 
zOye9gjjj=Tx#JIFBdE_hBe27$MaG04UFByFmmBi`fM19%AG zml=!IT*z6H`ra~pIOEa$|6PXY{zBP*hQuJ%`Co!^&t%;{s*@?`ChDo|wFpT&bSrzDc1S^=4USayV2lrT#s^!vPU=WE8vi*k4;H{8SOUwTmAR%(%_B#cM@pGD znD6Y!AHEOl9Akc8$^IIg&<(D2?628-_dxF~_UF)_%Km&A`}4W%&lB07C$m2vmn}Lil2Eq@4+E=Kr{L-WDZ!!dwVwd z4)HEO?tRdK?qF|kyTkKoXzikWUcve)iaDDyx2I6dnm18czI)2!)UWk1Dl zDw>3EREVnxdM~j48)E$zV*LjL)vW)Rw|up%|08Q{)5x5x&{>ApLL(&cHc14;SGQTn6>-yZQz-%y;p{ z2l!tE_>F+^6V|u%@0;i@q#GWs|Ji>&A4?gSe=YQVFE-G`HSpKc7bjECw6A)X(XR4q zp&jsRUA;`&@|H=M`0m0zcmNMUw}CzjG2pvRKgy&pT%b=}qpxtr<2gycTg~rn*8g_a{}$?}kNtle`~L>^|FxWJuonpNO=aja&H#81 zaD=td9O9e@3t$l#`CY{&$mOsCR>2ww^V@~tFL{4vomST4-T<3mD>UP0fo& zKrZwSvi~1u|Bv+CVE=!g{Xh87{m?ng{(mLEak!ejes+R%BVC#7|B*e!>r7>DPCNy0 z5Q@N!pKCGSAE0|miF+#hAM{G_R}i;>G>`+;2~vfhc{_V%WF0i4*CQLjl1I6+Dc4E< z*8uJo=s2Aqjv)Jg@@q%7USR(}M7U+_m%%m{yMPXCSbhI%DL%N5tf7gGUzxj#d6gd0nGau16u<0|n&w=^| z_=oR*J>;!_7JX|feT@A14-kGbb39|ndBR@DE-oT3LGumv|BOYK(XH6J`mN9_=p*1* zNZGbf#*LJ9yGE|#cN4;J7u31iJ>&>``3Fe`Ikyy*` z#Slilub7K8ALDngkOk0DNIzOhUGJpMX|wix`V@W#xmE-vV1i1hg7BC8<_oe8>S1J3 zyfh+1{2rkhsh|J-c}pT??BtzC6(9bl$N5H!`QL%A%H|?HC%WhRy6I6m_?x(ozUv*P zuAr5t)$;UZjuUgq>Z=s(N23j;%p!{7@t zE}!Q7oi#w$0oDiWnCqdNu&z|bW!3~eU}l`YOgvX$1g=AKE#KcUUf)EwdO1UP@!S9C zcfrB@-NCxS-oP9V_waiF4?(w){t4>0zF78&9%67Oz!Wfkz`XMz`z7@8?-Q<4{?n9y zkn)4|8s#6O{K`f^*~&D+roe2N1M^@3EP^G_i#_*Yxs*}N;CcOKByHzobB0eutrm(lkYl13@x09fYp z{09H6pzMAdatD|rSqCBQIruH1uB<`&BIM%U2gZ+?kL4RFK>z*vJBYgo zO2Eh7xDr_fHBbj-SNZOr_lClJM_-S-5zNp6R&an5dcXsrk2n`Xjy%o(s6w8Bb8sFm z!X>y2S6~EMImd5+4VeP7VGhiL1+WN~K$vyD z@fOejU#E}Y9(j>7Pvk0C1M5LOGu?pP1e;+isBb4M_-{l1;rs0!=)TtJ@}IxoZe7Ut z=g*v@a^N<3BV8Z( zp);BCU!(l^JCWUlb0OWhdyvjye&c|6&cS)O2tWLOg7GEBlv%`0Ieky@?3QbO+`eU` zv5>ToKlDFyE$<@G|17QXw~puH?-&To;Q1M8*{OpIk}^a^#OD;*LSf8J8_@+eU$nA81p;0E12I6?8A_K z(7%`U*-q9Rg{(c0KHPq!dbXlJ#ygj!rQT6>B749CerS!0l(s_ZFOzq~nwXz0%+IV* z#-C3xS9g}OkE>>1*T%k%HB2|xU0in~dk*j$DeKq=LjM-d=V1Wd3qSlmk(z%`kxtoD zJkNWYwKuxEfjHnCx~GDBiT6DEMevVtFO2{G!4aNI;Afu78U}dZL z5hEhTigdbl_vzC|JAL}s={_e)rmT`F>sn=%DCz3BG9_+Di4rBoeY}q$lOOxtKkgsT zdcEH7^EvO&&-ahl>-`H4(Em&76bwcQ+V%f-ppz}0MHjsrxy$l^@AGKA_Aib}GcC;) z&pLzFMcTi}eMTq%G{tDV|T=_#44;o z?Lq!6@$2XtunFbZhAOv1_LM0zULOIaTt?W6d3_@zXG6;2S>v}Y>xnnDWr#|_;`sNQBw@l`Rk4|L&Pu2h5Pyat!?y~LC$}iA%UjO?>=SRcN z67@8@{2;r0Ye{ILr#ebPnRCp*Ow7hy%*R43M*D95hy8p6hxs4SHIHv#2_L~M{s$a= z+ujOn{bQH;8psymOf&xj+7Rd8FO|-6ti&p`y3Wk7C&LwC^@qISmdiN&W z?U*j|HF>Qt|6X2-`|s4g=y&z@2FKTUehqUx*9GR^>zj<@hnj>_o=GE_bnGU_m17&K zumig=&U4>G);UK#8b(M5jr8O)?+4L8@z}rMn@ZniUwHo+@Ac96l0WI6x+&iyjt|`D z90za+)i{FKzn>tZJs;Qfi{k@Q!fCW2i#$#seX}^6Bv0cE&Y^W{ame%(hYR%VJne>A z9|)J|S5Rm-2**CKF0y)Jap)Bv{Rh{{KGeQY9BRfChg zM>)`}j5tQ79oO=b_p`+NS-}3AV@+`BOu`gQLm6hE?=x(QA3Ygj-+#{+ORO(xjL@4; zhM9hQE!LZx<$KH)o{QX%)xR&AuS?I6t#kM?=#`&yzuz*4$#;=eMs(eGePpi3^_@3X zeD4$Cp9ee<4qh$}Z%!W)4on;pzIVGg?7#7$@ZISj3j5AK6~5CvEbJZdk?_XVVPQ{Z zMELf^kA>Y+KNh}K{B+p0?CEgLb5FCMm&S9e8yAwx(W+ivNv=Y6uX+^S3-}0-M-FTJ zwhn*Pzpzz3nbD_QXYF0TZ$kAm1421jL)JcLd>Xw9JFp9Tunz}t2-P@(1X4(&x_&_T z!P%m4_@a4^E9_r9W>nZVc~ofgTNagj^~GKOi2kh+VLSe@cv#rc{E?9N+X?*BTTg{u z=ZA!EzWt%FbJSDe+slf>p7q7np|sAc-@o(LAafCm!W%n>=-(O?zW0LhaH#y32eA*l zus80Ojn-ftQxslPUL1Z){o#Mr{eByMaFIPf?0)!R=6*P0(azd_zmfNU6YBQf)jqz> zM|(RYn{V=!|H`-dD?Zwrp}FTqINI-*%GO_mWA}gQ+j@tO;hk_!nip^hSI~=VxQ;&D z!X4D@EDH5|i$VhuHASJZyC@_*o2I)(A$7Az`}?=yp5q>%|3Y;E2BQQ+(Y|qj`5yz! 
z{}>RmWS2h2ZZapFzhQ0i{Y9Y}$7bO{_p#t>PNo~IW|R} zEzl*8bdz)~+AlewI4(a*H8WXnFNe8WAWef_)2 zzju{??|T35djIcw|L?N@-wpZO@9NLKuYG<$977r{3*PmOzN`FuSNZp@IpMx_=~XtW zAK8!zdh6QzA+yo>>DdMMLnqSoHstq9d$)7Ec;7p`AF|@AN8Jy3@&sxp-w!9r)2Ny5 zzMgVlI3s)xP3PYYDc9K8&wZlq?FXUWeKjC4Mm&;eLJAlB=Mt`<7uRqdeYl18o$lZD zbnbQkN%x;|Z5^($-u>ff&%5&IyZq1Y_3pdIr?`LD(CQl69G717Al#A8N%`v@`2ekl z9pk_KzwG{>kv0a>2P6NKw4ZYhbWMD>FfOfn)C2oQJbXvB&Qs#Jp{Tj>ZfIEcZb+cw zYs!56#Z51{uLbUNz5Aq(bX@Oe-VLQ>&nx#I&i{MDzR|+tFdh?7@d^F^pC}5G=zU*! zH%uX?A+Ft1M&>>*AAfE@m_cuqA2a-9GwE4%bZ56dD*9aH7x`}FwQjOY-CjRR9vsI0 zf97F7RUF&XDBq;iUrq9I?rZvf{cj-_V=0znC03#1u5a+Byf;T4B)eabuC(LYBRP8W zAbC(ePs@kL`nj%)^59uxgM?%MeDueykxuWg%wHil;QjW)CgE~y!w&4i9`t=m+lj1x zhb=)qn*Vo*Ua(L0vt8(M&SHb|;7HU_@V~jQ!hPp8CwNvK*9UmCB=r6w`d=CZx-)}wp8|>@b>UQ;h zmh2Mlp7*eBkNvVom$1*ru+OKmr=;65f^Ch~%k1;><`3+*?uqpJFH*N)Flw94ALuZD zfF6DFas8j6^x;VPUtD`2DL$@6Fj8FFb(NB1(aPsAjvSBZhbYYdTVag9@FW!G1qgR- z)=#ls|NnFDf3kB-cmJ-hFfRW^=RlnQH^qOZp$s!H6SFZF^U?m!pSb_sd<1*>xp(qU z6uW=ddd#(^$(H->pKKLwBjeb`h0i=WErv1MSS>$m7Cvh78fBk=j`9J4f8y?U9(RMHN|2!wZ zImbC%!WH!58m^-c?Ve)?{-}Mey`8_R{jZ(hyhQt-OiQn2vb;G}`=8$CeQfbQw$^xm zN$q{(0B%Y14(`SD)$O@YS%1*{zX!s-_73jT{rvM84a*Z}tsz_!f@auUxvN;>Thf#$y5|VG8O7m1u*Ous_-7 zXrw36RL>^F(TnWs8|>@5Z0XzV>%Hu2wtFkuG9{s%UEi_H_lGRH(2X4Os9sY1yZNu9 z%u#pF+Skp0=R+u?C-$>_cbg+XpNZ6cw(m_g@Ljg>)+fVkaq+yBzpzemmonY8%oR5u znXCK)8~Fw17`q_67`e&(0?+Uh%;G0_(YOG%_)_u9u@b9L`>gYiVgJ(;uBD-y{ZHS3 zRL^_ockV|1jZNapu?$asKyZ#`oqtA-z|&Tu3nvN-7NY!m~SS2%T(mWJu9Ebu>&*dvymSq|KZVmj=AD`{@Pp@viGyTQ*tquVmVe~6?*w4*O2S5 z0TsWtR>9AGi}ieYh(5jKUfij%; z0=4v$lAO|G9)K=*2Z$ zM;~sXUAw!3-?fv>^2c09m;?2>UZmJ+fPIquPOK9n9nV^ zC;R~Yzame@_Fv-%!a#b}6Y~Bytp!eRlMhPBp-9hCwv!`Kim_-Fm%%uC_F3QGGrnE= z1jI2bIp1S9;#~FIzxNN5#P|J1TZf#6vKZC|m_f$6W+vHd{LkzdmM`YUu>St}$Nd#dN=x#m&3)BGF1A41J`Z67lB$sZvT zXj!aM_#}sT*0I7sFxn+L|1;!c!VE69j*)a;TG=T9v+}hA94Kv^Y6{0PmqoM z%)cLG{{2+*?{UpB{g)_nF&HHnis2ZE_EYBHpEv*htoirW-tW3={rwx(-`{WjeH?wp z{QKw3zn50peESx9#<^OZtIc?V^!CzFDxI+yhw*5A-7%Ozk7;)Hl!i(4DafCawzzI_ zUCuL2T=ntNP)5!`?c1ecCOI25^oAPul`IW&h3BJ*p4#a?>)of{>h#~&qv5Xpe>BpQ zXqxW5dDrhB|6B1@`#VZwF{=1iD&MsJ-C6Cwi}F~H_BH-!{O=t3e}VkJME)lW^S^Iu z|DiDcS0DUR=`F`ftU~Kc#t1Ca=S$D-miFt~^7IWTjQmxW54#Cq~jsF&5*{DlV=8G@f4Qe|*L|!}LkWYqR9=sQeVnn*UewrB_)esEl6m^`g)|%lk*?BJUqv^ls!fDj!f7|66bDLAPto z7<<6($)MG5as2NL>CD7z%*A{x#A5XQ>*&J#hn}zaMudBfJzq|)#44;o9RItH+<;9e z$D{GT+vths+#BCY6@3R%!`NL&j$wyA%lE?et!Mi-EMWhxVE=ArTatx&a?8{S{<8~v zun&dyt9D~t|Mq}z!Tx0fciw%^{%7}gAAFzv+r$3FF{JOaf7!elwr^Y~zwM&9%f_xt zry56)K>RIpoBd7CJ|k^ZYKK+v5$5?Na$}r>>~igC$3!1OTRaxEhn1n^3Doc-JX%-a zr0{7pt&r|&_qo=6k_G#Jy5nAV|7es}VgDm~3TOQ194_DzuAmp!a2@SC+5hf2_HATq zTvNAeTkHL!F#q0jX@1-N$9qM~Dffm}$F*g|^+8Q% z8H{|f`$sprFj!pmi+>hQ$}c7Kxc0(OvVy<2W`Xwu_T6!tH(hIM8L;d_#w2#%#@=kJ;`k71#_q;hg^nQ$w;3M?U zxR;MvlhgmkVLW>O)!H9qkNE|Y$SLUil;eK>OqfQGYjG#%J@oaZR`?$K$&Z74vx}T> zfi}f2i$j^?W?&|AzbOv0$+?)1!af4p3%O-{erUhJAJ49MwBA_U3pI0|UFiPW$fNA# zfM9J331unz}t2)iy{4{`mU z+|NovkNxVa#r1CSzO-|;YP-CKBjWb2E)Cy(t27)qH8Ol};F8w`PRsA=&toH zeMh@r`e^u8^YyT2_3-fRw}*$lXNQJ2Uiz>#DnA^)^Y+lNZNbm&+y3)VC9MRic8&^_ zyFV7TPyLuRFh_(PPd#mqKH~=cmd38_!@@u9*Y`E`L*biyM}>Mp9p&o4hnDF9j5O~e`2yYaE@NxuPFRryKi=XQTXBR zqHtuyU)h7m*q=dv73!Y-D|ypAtbXQQ_4{;a9QE0ddiK-$Xg;lM`HcRr&&Wrg49D(& zhJV+by0f1Q7o>R!SI~=VxQ^Jyj$@8qyT~UhT=x|J2ERswF$0MO{2P1)$+i3&>-jfU zmxNo6xrc}Qk63ex-v6ueA_k*fKYqtNV+4>z*9v0;R_iw>^P`M^7{xzwoPPvqw7kVX z@;3hn*|yP`i1Q_(L^?w;93#xE=`BUNUvXGYX2?}!g)*nr{KVX+iozPl%TRnS08ok!}J3{vS+PDicg*4hw@p1jCWF9AQ5`Dk4F9mr9>D$&ABn$ly-z*Bb zuM~yY2byuem&9i`ySDY#W~28aKh?W&zugNSmPOT*2ZWX4FW?%kqjtABrO)s$&=cLp z-es%>M8AX7UHz}O^}p`b|B5ZjnXPz@ot@rkEpx}jv2$?`p$GK-%j6Yw&S9&+$Yx){ 
zc7K5lk30tZtps=EtD)p@j6^BMqR>83M>f2q9z&!0tk6D*^*A5G(W~!`KP2Nm04?wE z9iVlQ@rV2q?d<#xbRvs5_MqD}#IXmlf1>(0ABgje#{^75E&oH!S?mAM6N8-F`KQs# zkW%M1A&EwH^!wL8o8kA|ciHTiNuQ0%U_Skijo+uw$3iT|QY=TE`nrC#bGV;B<$qY| zm{nMVb=ZJSD91Lmvne{*7KQa^PO*R4w7<9hSu^{igZ=S3`vWc7B^k8(KHGesX?^rn z(%FGs*n?K}ZwCA5kG`J+^h1bwI)_+?bs@(lieoCO{dNSk@_K?yp@!ZtQP?+^7Owc2 zc7yy7*VS(vAs-_551);G;VqB$^Ne#+liKS|&uL$fh5GKk^7=Smv9?M+rc2i11ikXB zp6{2e0YpEIGw9{7J4g2XkpG74`-;3kUWvcq8m^-cx6u1d{aR$tXN$r;@&Wq)@4j19 ze4KxsEWuC=N8if>!$`6e>F2ajK0iH-rH{u1w2I4M5Cg-B>X?+FyW#-ch_g^$UETpHs*TrOJ zg!enhToC$ltVCx&_6fSsjU3|mq`K|ThkA3`8<2SGdFA8tAxSnJe?Fw{KOa{4?;5Pb z25dq(wxJ5`^UBnJW$M2&^rutQqoQ2T;8`Q#qd%yV6{To?8UA3zg5g+?S%=>MAM`#^ekSvcgFjQQoQFP4RB zdgbTX4f?h^g|kas-yGLB-Ss`|`o@%nx;gHDf%71-Lb{u!`?~w@a2_0yMgl3M(S|JY zIDz(M@*g^pMHjk}Lmo%pcK_$yKjK;kE!HB))VqJrq7Bua_etrT#u=PL?V#tazxaH( zKugc&KBBVRQ;%YBl)UwS_D{Itv(yPoA|tsUW* zn7@{jE3pbOf36{8|7YJBnQtx_=k+W1@uu-l)($Rr%r;bE z2XN zrrDTzdRAMa6W8biksrY(a$N=c-*v`(Uwu`%`GPhnYMrCcako%&TfKVGeW9@b)yw+T z)U~N4?sK#Iq}NHKu>Rmn?`vOdbpOIB+>y>bJV5`ixdseI35Mc-zW=8DKZ5;_V_V%n zT9832+7R~+94?)aD8*Q`-gV4PYyQx)m)$?c(p$}m*FIQGF2!=J#47w={%7NM z9h5ia%N+8V|2+S1SWq?u_}K42Za;%6TS8^})SX7@^er}QUEA9=*Q zzsc|0zy4AvC%2*YzJ3T)(GwSy0cZ6U#QE=C^8RLJ8{!<}#^dHQI%XI4U?2L9`hI?F zOb`7KdSBt&c!O`F-8`J1nr|)c2ogvkjW$$#UcLV1;!yC9O?E%i^+$-$;{*!(|9z!6 zoTi__Ib6UcTtP3c;X3+oD@yi$MPsq~KfXiZUhRW>!B7lGTw6G< zF*lN4=zknzP5^x@^85!mbn6r9dQScSoblz8^>I&Udys|k|J(H|BGvI^IO$yD{BJxa zAg=we6|rw&l5qPX{rf2Nvu|XRzQi^~ZiIez9IY|_zux$NX{DQu|6j|tMO>f14K@1K zbN`^MmsTx%c?wyE>ZCq@GVVnX_asOq+2mO0dS;4Gh---UbBFcog)9F?JwHU*PmgO0 zW)_u%`Sk31*M^1k#mKXNbLe(V*9!LUTKA6zB+$6q*a9>mg{A(p94oO3Yp@O*unFy+ zQ^!vBKbbwq{-<}xd&;o?`FD@rbpJSp^jq%#xVZ&rMH{N`J{iiTvkg_)f!e2xKN#_3 z*hP=;qyc;A`;Z#o+K_ZiT*I*Ny$$m|M)Ak`{Sc~A@rE`CDnDmmn4cSK5dVub+K@#a zb+5C3cgz2KwbMJa%g<|Ux@~NLV@~2U&LGakKSy4`CA7;&9r95pvgkthDf#oJ z{E4GG*}pHU|L3Ud$(AMRf3j7$jf~@0u1KdB*Ki%JFKRz)7xdAyTcthA9=G&6$P4H8 zJI8M4BJYW-US!T8+5i9W?E2|1B?qHM+0anzzTWB|N`!|Z?)jGzZX}c123z=nw%*V$ zF2D4(Tfe|DJwJ{%oIZ_zv#zml?w3VjEZO%&2;&OB^Sh9}Uo)>Gu%G_MIAq~;x3&zK zNwV*2wE5}NP=?MIv=7lGZ+6e|Zpi#y{(ryC#B9vPd@RIb^!?0ryj&D|{>^tN-21e3 zF3FWxg*B+2$Ol4Je8MM^c^V7`#Gf_=(v6$vhE%AKjI#* z2{h7^h-1v+TD`;l=e4&VjunVw2C9U2q4Hx7#|N~Z<^R5@j#7_h(Y0IMbWk0&mmmIh zdzV9B_a>Tpe zM;^c-RL3|ZP~A!40#DPD<1j|66)J4!o6s^>wIWL5_SGt&qi;UW&D9} zqLD41B%6FADe{{CynlSab>TkTLheibLi<+z-}=2fx2tdVv;SXLws$MzGx|!jNsmrv z|MwXGgER{L-|y)EW?!};J3#)rXng#q;`~Q>wpjiW_h@`i+%xB%b3H(X-}|rhe6Gr$ z%7%gT!n&C+m`_75K@&YSOa7WCe_^P&xHso;awKXG>sKJhB90HLueS!nVdazSZE)Sm z%jQ5LbyHjIzWeyp(_x%r$72E}VG5?93>};KAQ9Kz?nGRBy9;@G-0Q4)uV>(yq~~}B z^vn{^0Byppm?4dsn2ouZkA+x_rFg&nv|M-<)}ZP)p4ls-vcT* z6Uwm-RoH=D*n@pIfVd9gA+j1rkce^r;=ai=+G1FmSu)NK%#(5bzY{U6Tt6AZU;DF% z>#xT8_*wUO-@4!8AB`_LLqCU#mq+$6Y~4Z-tEEBAp>POJ}-t$gZdSZ<#sJ8ORd6vhE`|H1wO2TQFjH#(F$ZY;)OJZi<&U;;hi z-Wy&R9VXGIASJ#@I4Qo-GkSmhJI(K9n1PBG&lve%^b50vE5D%Z{0tkh%Ua;V^O0+j z|92XHv{(K|EUO#UM@cl%Q*W6c?zqKRise{|Rak>{D6Bt!T>eKEUHHBA|Gn>K9DA0p z4=vNR4WIel_|xoa{wZvb)+Rh&pZ_xd6ym)3a`D?xg$mzd>qhy1t-dSa26atbV`>+D z4+{H+Zr6_3uN{GX;;J2gfINg+=~a_QP&3&YVynl6Z9D%Mz4~`ah2OaM3Kz4p(;o~wTz(thiEge#stUnKXM~n{Nb8QEdVOhWDaRNOb|8RKoruzNk9}S1bl!kvkUK$SW9~s_! 
zZlrzT^rLkT5BsMM58rJb9`?OHG<>IiXxNK4*8gSLGj)RbZxh1q<_Y0jlgEc$$-fA_ ze_tB@?-2O!tjRBropg>ZH?=E}s~Z?j3!lL`#P{F#snT$PehEE4l+S*`|M!ah6}pV& z7uSnxsCbhNOy*wpAMzINpzlNce?_CiJ^BOme@Py|V3c4ehGQhQOxIo?qPVIaZ>zyGS2RQCLOK z>W}QyAGwCU4*6l)9#geJCTfdpEeadNRZlMpo5*t1-Yp8-$STy(8_pMLcNB#k!n@G4 z*|^4e+97Y5XSd$?!pn%^lY=aeT&T1i#qcA`L_0&<7dtPu6lF)#8qdC!U^&uYNrker^z#@p*M{3UT+Tw z=Y%hy>8$s9*?ZRSkc{uXLm!-WK;o3PI+;`^G~&_vAD8^+3VLx3*U^VtxP!v@e>Q&f zJr>sgIQY;9QLq!lH80QyiRfE6hUj~2@%^{G!v-PK)7cQxd4T@^r#yt{d&qRK8R*gH z5a;$){saHP_w{2>W#>=izu)dW80whnjnPT)&Cvpf8^--8ulsszxhS>=?m;rw9qqi z*r#jRr)b~GzC|ap=t4Jg$YY?i=VKujW1r`^lw6LLsJp0t84au1*DKhp?D-_wgwz7| z?{;(OpYsfcg~IsPRsK^eu4b?C_w+;$`_~x!b@UBLz2m>)lDDOWP2$qqjsK7K|I^+n zyT=5Z?or%)_Tv^|EiBWk95*# zLl(8-Y8;!VC;Y#`|4-0QB3@J6n=#qXwfgO}xHCA13n;WlCi|Y0`C!unh{ALhL_GTZFG_41Llauw{4VeAj#xJGCfnbvj~ z?zfRB#aOiNl|SX>arA~8^4D!+{pb_WB%GSZ{veZAjcpcJt<5oooQB%7`Vz<)sCi1i z0r}dP_r5>hU|*3o+8f3?i=T_$f48p?8OQ%GCYNG4R$>(@zpCBxguMXhaV?N_WM=cj z{?!fiO(;jFwssa>cYPmcy$j^A&2QE6VHLRpwcYlCAorjq6UPfb9M8W`_yC$lXcwY! zqIM#ANL(B@P)+uJIX@|U`?v3HU{$NOuDddzPeVLmPQ1DOJ zs~^;zkN2lPZ2g>f)Gz2frG9x!{c>6Ta$Nm_{B-;7N$dUgX;wIo6S(GksZ(FnqhS_5 zG#VGF-|%Q0T3u58mr?(9tN+MG;Uqh~NqZ&5W z2m0|3sGGag%du|GJ*$5oN4N4X;Fvl#y@G#1{hL|FzkoPCAU&CFE}dRn!*vww|7VoD z^hfspE&3hAHbu@e>TZ5$|KAf=z1CO#S?05EmiSSUwy$j&!Knw-*`fJvy>qHcTI{u=bk-@KwfPTzK;{31L9Gcg-;F&_&N z$Db@Fm*U&Lz2#)9{1?ZctfXf*%irtme@$P5{8agSvi!Y3{zmm={wwhtunFa;-OYdh zj5&w&gnMkrMB7vTK9)MRR~|?1NA?gEzX$v9{_^jD@F7&A@?YFfP*%`WNTUr| z;g=C2Ppqq~pHyN~ql8x?MxV4d2IA-$#${;~n;) zvLl6&j_+Y7R1}qmIDeq`-~KF=#<))xg|TGc8^-4N7UCXhajeQXapN%|{`O1xd*`d6 zWACflIIo6m&8vJ`uPUEjRX)9{e0nt;UGZvYUj1q~w(-@FUi)fjS@bI3`m3SUac!Ab zL*0T`L;bQ>Ljw|ML=sI%VUqJ4aqm+i^x>F>GR#1oH0#lT1R8gj`~J)M-^=;m%fn2^ z%*I^I$3iT|QY=UN_Hyn2a_#?e{`YeJ_j3NPa{jOKkY7^n`!DBzFXw+R=YKEfe=p~M zFXw+R=YJP>w>+$r&MK_IIrP@LW)MPL8dk*UxhtJ+WE5=biMto9=qQH`UkjRMltH z3I4kW`)~k

tVPe}0*kx%bF^WyEgnUo!4xaY9`6{>|Yec^b9#o5LCM9BSwd!?f*3+5cAf z5}M9y+h5ks7@)0BUJ=)eYq*Zb<1?qKi-m6?u0eEUH`KV!Zr4R7h2z}hSkK13 z^C^y-hBC~+Ow7hyv=CyRT| znmy84h{afnN8^W<(_@>x6RYTJ@Q3sF-@pI0GAOKbOy4idLtmHgm)`Rs<-)(=-(O`< zl!r|*?u)O6a4%E`) z{yn?sd$11&@M!;@L-cI7wgal^M-cl2a$D7(WY=Ex-?RMRXqc*ffXDmqE^!U)(}e$| zkVYG_$m0Y~qWvB9KRWNK|ItP7MsA7vU;X~LebbMvJ4*Y;y=ByqYt-m2lwy* z|LORni@v|ZzCWbVg1E+b>u%p4s@dcH*T_>Cj1ttc&1>`l4W%bm$Q$zSaQaB3xOu`gQ!(;ywA9b1V49rIEv-)bu%D}Gvi1h*KaXkJ)GBcfhFxlET z^rgsu^MmpiTLE3#Esy&*w0oBOz4sga!b&pw&sLFZunrrr3FX*^D(pa9%X}BP2X%Ag zf8|>P67}w>!@Z%2o{$wlK7Fhq^SgfAEX7jg1AK<`;au7uRqd zeYl0X1?=NRY*4mpLcexnlI^*d4XWQdb({ULn|+I8NaL}6+swX2ySR>n>|_6lc6gV5 z`EEQ~pD#V=$#BPc?%@IYe_bA$%tju-w?NOX_g@UAmmt4@?*ZK_9*(siDz18teG18u zs9pPHC?&_D<^|V1(RHKZhuVs0ddhX9ag6&wU5)#5PYq{lJ9bE~@7MAsxd%D^oZgSfgX;61E%rJQ zcL>#}`WN;6XNy9%AN%4Z>+6V1AV0`^7WcUSa)tX4pGF(9Xcd=P&(}adfs;6m>euBR zvSz>ePxN#23jLxN$cF98ha1X=yD<+bALyy8#ufYR3UZ$x5GuYnAoS9&;X3+o3wLl2 z5Ago+GyT^pvp!~i&sUzb|0k4VgT<9#D2Ai3KJ(WGgpu@8jKw(gz4Bz}tMmP=_bm!f zz$8pT>qOt!uzq10Jv++xLocJxK)#1I@Zm0TL{nl_$ z{d86RbYC0ctojKl9Bt-5=-@w?#g<0PBKEcMf31l7>ptH9XR7)ch5dh?Rd0?^KO?UH zztnk_VJkvP;Gp-%H>)^() zL)yEr2XPOdnu&iF_R-^;=C4U=sy!f}F zTH3iSPlta}*6v#IWcX&s--ey^*QWlDunkAzwV-nC--W9Ee-|nS{g1GH_1}f>@xL5U zucZ8zMjLv*Wqrd%`q!TL^YHCKCB`m&Fl5E|{nba!4Rc?3WAg{Y-pPYQ{mzeuy6qng ziGCkaR(~`k*M2xOWRd^((P%t@Ta8JxofH23?YI`}WmP5Ce3*gO9vWN!aOXzBNfQ1SIq;ga7A^LvN+ zen!a`?*EFoUR1T~-~a5WaGluFSQ4-{QnOB zi*TIdUumy^rSw+MKBGOeoF3;Kc48%c74jXv&5I9xkX^TZf6tBz*WAxSzpcSK)b1Y@ zYIctb8|Vr10UNN1UXIj8{ondklTVEfjg!qa7nfdL8mh=0XkAenc9DCK5#OnQHH&@1 z2atP4zr8;Fu8I2bQFmDXJQ|QdBa&!B3Wxls8b^>o3Td<pEFmw=IB!1Aqd!4UtQZ{{ zP?-PsTXpAZ_rJ{juXq1A>6rBX(cv_C2Cd40bL0hNnn#DuS;_@m625}mbmvfpbWQaR zlp}R7jt=#4j0NJk8aGROzqZ-y&b3whzfrr;|E}RW`fv+(a1ZTC?GJ5=&ehIC7WSVe zbHaJ@=qc@g9IN-O)M$Tps~Zk$f1quz@k1{?t^eZb(0TmnkVO}|kwYFwU-$fQY|PUk zZ60XLsHc^$Pb*)a<{x-E)GhP=(0~LQFM5AyI_3QxFAWb|WB+yX3I?MDLopm9QHrsM zHh%TIQDGdtX5*+(yY~ZOJbeNt;qmfhVS9N+#lzxDiD{YLh_(yy==$5#HX zI39n7xWf7aKP(C}>9a8x^RWs6Uwm-RoH=Dcqg{|e{Da7FOCj-gwuPCu^Q%OuFd^ET7P*;{Legnvg8re#(rP24K?fe?RVOb)t@IE^QiyeB>gnb;6IxGSNDSYch18){*TrlSn+4!oc~_HC0s!- zuHicR(7uTMkIt9W@95g7en)P_!}i6|xA_0x=Km+t7qw4LX`i6gv2CcnY|S<4+`&CO zKDtgfWrL0o&5g?`FHnfmn)x>Xp(1B$L0U;8)xmfDVT;b%)m^{M*BwJA3ERW|9?xH z?y|NWnL`}weN-9JjAOIpe?0a-(c5C&F#E?yXFe8UFBQtnAA0^{I&+QoB8v}4X9Z!U-k1HVv}$=nl8#$Xhah6gWifbzG0i+;~IoJ z$f`Q+`>r8j7kv-*;Q$Vy8b^>o3Td<pvUkKbM4i;vS&?diP`O&p>i;)W2bkOmZlOBaS~F zNj@5XTuL8{aft6QyTm&qyR`>$WE_87CEc3M=1W_{`VM`(V<%t|s?E)B+-pAgLG!`M z)U(Fti=T!D;lxh;iC3-DCp-f)F&nMoGPj=$bLrW`?00?N^XUtbuVKHlv%8zw?_}LD z_V+XFRCaKJY@E*iCYyv)WWhdsul=*w|CeGpD)hyyB->q62Rf(nFQALwjof|x{I|_N zSi=66&N1o5y#`yPpBbb7o!&M=ziCZLSS6jxPkI(#w0GrxzDWo+#|dXkLCq-?RTB1a~<`r zt+4;XRQK=NlMCFx>yG0i4*1U@RO1K|NFj|jwBPpq$={vGq6=|vm$=4b9*_55807xZ z(xd)IE85gg)yd+Jl};Wfa1yoWt-sHXI89I7H+JK8aX3RihtxCj29lS38`!c${#6FP zwu1dNkL~>?+uC@T))D&393SWJTp#zZvP>yY=!Va{vvpdwcoyF?k0qWmT>R&R4{AB)jb!;<^ zvB}epZDB8FjDKrI8>*jU|4L^M_Td0(#l?N?4$ftZ?j_4_& zTk3!9(hM2*h-xD%zyH|~*Il0=t!iWXCXqEs`zsiuH-$b8xldYqg^wYD_{}Cq`G(0< zv3V{p7$4x68HjDpnPmE~@iXLH;P% z|6wwRxc_k6KYN4oY(hDi42Q%D+44yO0{;-k)*}WaAj^mzTUB6vk)z{QwT3 z8gU+R#mjt`Z}baCgnRV?CCC&izwEjHwSCFyZOEdc!kl071Wuw)-$3r=&xRiL{AuCd zuUU_nJckRog!EZshRI%B!*#^Iow)vYA3fXc+rTaQ9pu+b6OZdhx^81a0F zbIltkJ|B|vo)1k(%`xxZf8Ia7=brEb^#6O`Jlf4`@3{PY=p?hV%6!vh{QKtD)AKKs z>3=hS8ONlZM$1j}m(gl|d)p0d$KCccmrkYfwCYF3T+(aLK5zZ+=fhBX;=XIcaQaB3 zCc8GoH3}M0DlUDnEQ}?`p|z$gj3*}`lPn9JFO{j^%EBb!Dag6DytM$jR=ds>Wnr4Q zGR(kC^nTSk3uI529_EsLznvcDlMAsJOR*d)u?p3m+hhO1TJyxk)sAuf&zdJrPq^0x z_gV1?&mF1buHW-Y_PG9co)2~IwH^&fToori{?;Vj6bk+S6(MYrRynqz3Ole1?VCOS 
zex5(F=t8&k0&)|dH_q{S<*(D-d$(<5-FEqFul$8Xjc-~TGI`!N zepdd*9_QGH1IT~k|LJ##ti}-}(7r(aS|neumcLfWUz_D~GAEoTk7n4v(mb}6{fm~h z@+Vr+hB*H(C7m?d5cdIUUBLdBYZEUQi4xmMfNZ}~_@+IX~n%?WXF{@+RRG-~C| zGvqnc_+Nu`64Uwcg)gD$rg8&~j!UAh-uiKt`Hc40EcVv|_SX{SFZqY-KlXjZ9&geq%zyalfN+Oi z@n!W3+4tk;!vnI1?7z{uKW&WyWynB!+_!x&*}C}yp@bZY%zFK_WO@ZVfGkD953$U3 zU3OjPUFSvDsoWpyw{fU8zWbVUjHge)B-9$$U863TLXUf~HxwIZOfN&+lc!0268CCv zBxi`5iP@No_xEqi7tXz`PnRt8f9ua#OkaxSSc$q0{hH(S|Jg zzEu?RWY16eMabS?@|QfTzKHgW{;nK_zpkLP=i9(^JW`Uo!2 zFCng77}qc;`2W`%YalMIeD5W%A+9TVo$Nz~URW2%w|7hU4sxDN9))?RWSu-xAM*(k zXw3Nj<=4XcSV{h`Uk(ZP{O& zeo6hj@jd&1-nPuRd1;J9DXMR4OOrJd^}k^peLN;0jzLHabA8$z$>#U!fAZ0|tHOMA z$5sA{-TYN$#D|82X~G+P+hybo%*1TW#e6Kpe^&l2cFa;N$4aci8mz+xw6B-{H+#-x zR$0*{kHtBNxl{80arr;<2j$-ab+h;JX#7>&XQAQ?d<;YMHA%1HyF)@bxeZm=fnC^x zeTcrzE$I8WIm_e$afeWi+>eHY?Jf2T`;fWQ+P=NvqoIc%D&e=@FKZ8cR-U4#Q5l{H zReTKTgT-Nra2vA7qm})a$rOhZG48Il3vb%vjea`*K7wtnU%A^cT{t7I+W7u+zA>Akw7CJ?{7!H?mvCFg*&*12k5`a zw}^IY*>%h_&L3HHtu?*?IplHl_#dplC!9em+K_&yBn*~LZcD#VLLQNJYmZ~Z4;9W1 zV8df1eK_*;9J=`}x=`4E_XY3cZ|q~BFS7Di#@GLwz8}Yq#W+-aT^}tu0X;v{-%s{_ z!E-08Un&XH$TG~pOw=ype-J;Lp4jf$U@m<=QuMesLmY$Ah=t96g`(oi_VUvFq|*SP#B9tn=Fj#QF1^$a2K-@rC*G>XU84aSVAJ z`xD0iH7YBTICOYO_~(ru3gy{DZ!CL0>{ zlaw`ic1n(ZNZd(fS@acjDvMsbZ2n=tY34?-S$||a%~NHeYDAg&httCLCI3DAqjuAd zXC{Se|2u+z>X;mM4Vx0aIeAjpxpk8M)G1-l#X%wW;}Id@_wO_tFEpd`*t=|S z_}+k$a6o%J<+n82z&#P()K1FMd%oa1CF9)w6XZ#p#u=PL^9#fEVGR%I`s@9U_4`G? zOut|DYkBH=zrJri87}y}{w;lI{hsDy)qlD>_Pq=Zjn?=~T^t^o1`G{{ZyKZExav~_ z!w-)0OZ59#_~D9y;fQ&TwRcAve`b9C5`O*#?7P+MyN&$%XriZxt$ar8dv z+(LyphIh!?yV@T&2ZVd{gzIYLxqPUN6)Y2r{G8ErsrI19*?fOlb%I8I`C+px|YqZd%gSE=ss|C znEMjfw8V3Y&+SFmL~p1u=Yp*3{$Qw2elU!5&QgrUIE=>xOu`ft_E*sEX?RZiXO{NQ z9PJU9!k7FTmqep}qsp%u zU-5)}>im}8&i7BoJ;dYQ;T!0iP>%noebBf-IkH51W0`UVb?WzeafR`C^Rz!2t)uJz zRoH>J{{1d;5B8x;K8|@j=IyxW*B|-=TXt)2;L&)aqx5DRLmCI9aR}8of&}8afhn?A zosjz*{hstT^!>YMOy+R{CsF;nGUU03ZMxVua9UjLPV*OVhMpMUUT}_n2`PFL9{K+& zKkoVccvQIJw_aSsb@WBd`@Kc>ylj0!viBPU!#(l=`fv8VBdu&1OqO6MhNHDv8>2&? 
z6FrVeiG9$e^s&g_*6u|2i|PPW&$C~=^v3yZJSL!at^RNQ{FCU3y`w|J>!ZUI`ZS~_ z>)SY!Evw&a&${=%-;M}nj+ud(n2pLG$+zC&T>5-0M9-(54vWd&pXfhM>i>S* zIWFq|J}b?eqxsfHhoftiuMhpVj|=S^xhH{r^|>|M%$s zM-F)$HDAAZ-so@)X|&wb|BqI*p-z8({Z8k1FNwIO0httT>Tv$d=&;E-%CQYq*nwTx zgMDcCpN^z{{+iJt8=tARU3|8M^9lG$E=U+X`dsyH-@WI#8b8o8i9a}?hm?s^`KcVuA z<`qqrmS@`-*DyexyjQP2Z&;yyy;;5fQfy!QMv+1VUuRxgCvXy{aR%pb0e_g^y3Wgw z{Bb!l|F-u3S?ymOLmGeRe^1{~KS}2ba$AZ*FWEYg{onI}aE+cF#{S23dLN?CE{AS; zs0+8mRkPRckoQo#-daFp|5x3A&F{*Qfx?6Fs9&M+1^0=%i|&7n{#PW>h$NbjLW%zj z#c;&-&q%TqV=*4>JGKAi!A`_>R#&(8zg`~O%+4Z@UShyF#*rJa3FU}&Xh!@tdbVEOa+r-w-+{b3 zB!})Ab&a~Er)6N+CBFCXi^3jqi~BB&=Nx98WZ^>o|FUwHehAe#g1#X|ArZsUPmyV~ zA&a<%e~tJ&J+4XE;CfEbPa-vn{fp!?{14BuoyEm*hiAxhXx%(8Tp%wYqrah({U68u zT@mg@+&d_*t}ToMN7xrKu4&V@~ zk^A(J(EgVE&&GN@{`;oqio*Es{mOseMDrZ^pG>P8TF8uWT%)h;zWi@~{}E~Rs_PSE z3Td<*cc;O!oZNTK42n497^6Vl2jCJSL#;t3$#hatf;F6o+YK8D?N6YB%eDe!)Hu^uqpo zTiL(#`AEIa{uNG=ajm{s|D`AEH+`@EYt`pp@l!rW$7YNH=se3WfI>O?(^xLEw^r~I zu>HF?TK`Xfe*F!7{QB|}Xk<$#(S#HhOJ^yTV6Bv|s!-c({y*Dx2R$K;2JE8mL28U^ zo9Y^;yMKOwed6M}4+qFYXuYdVM;<}uxbL!F{#QOEgi|Q2`>%T+ z^#9s2h`kgkha&Bv6dNSSrCcP)7YWj_!3GI79D)Q15^RuQFE-d9!GvdJ;P@2`LSp3EkS|}5+q2Fa0yQ45ac46!$A&mkk0qDHid7i_ug~Af4uYg zJZnGC+Ru8PXRXirte-XV|7GhlXyb`>)HqMC)V?;u_diAcTEb?#rVTk1QOaNPSZr@Q z>bQM^_3xeADR1V`b=LTji^i9bMfs{_9Q*J0L*h)6hn=4p#y=Qo{d@9+I4a)OW=~dc zQcosp-jx4awb|1vKUt-ny5F{eBl*lO`{b zJt!!PO8t-eB4W7|=OV{B$yfY#4ZXODiZbgGjZu%F$GOL~?(;6a5B)YNf13P%jsMU4 zf3h?`aF;dE{5BASF%-kmqaSMo+1;qzC41Sw(d1Z^{O8M_-*x@B@tm(~rlsD)w5R(Y+{l>+StwZTxWwWv)jk` z_iSJ53)jG- zYF}jkXB!tlPafbKFZyWMN#Bjsj`#cj>-oz14L;cazt`{kQH{8NL;(lsNtF5@-qLgK3_nI+|gKVwwUiW#gx4dV1W|w#KvUUqRSjVWDo;1wUaxStn z)JMpL$UUiyODNkOQO1!)Kbh;^oB$4y5(YqCmM4n-VA?>qMr$H@~og`V%;OPnFQzwnpD zIr0K7#c*6f`to~;54L|?6P6X0x%FP6m!5m6FVT7ay~Is=#lLtbFZU%1!t!hS5?!zL zCF<_Im#81}UZP>bdx<33IODxU6WQ!IMHZgbCg|GI_4s`D6 zOXSe?rnocyqqQ$lZ0bw2EO{@{dbuxg=z3ow-QAZsOwZ7>=lc?EsIKWt3>42`48?HN z3ab&v2zv4v_a@#^^wCI-b8o_$rnuK>?3fM9;j8J% zd-oCzx9%m@($^uS9^TB(H=VtgXvD@l@;BaiF@!DO`DEDOoQmJc>%T1vo9UHbV86d& zEDe1twqqxDqoA$6PF(efW8_Np`(4KW?UVn>lyYOQ^Y)_}2T@_HY?4eNjdtz-9f)%P zbLi5hpGSdSM2mdgsth=EOnYCG{>KjOfymO^Y9Njria3hnIDy_jXul)7^`D&~d%mS^ zL2guU-lPvVj)}SCIDN+)MY8gn`hC8pFP)xUqK<=W^xSOa24AI@eiKDyN&$J#v1^|C zhOq9>`_8TP+4HZ~3~}6teigoN48&jz#c<#KaB>7jq4!_($W8-EI2)pTfQ5t`^ zi*0+JjT8RQFWgD&uep~b}YG?Xl)%CT53jy^qWtHL(iLAllW9PeCx@u-+xu( zAl|5##}4!>jWzqZeuwJ*9|>_!x&wK05Ra+b$8-5{KWEo#thxBk#X;fV=|@7XeyRGz zmN`_e90C|C!jb z?SCfz@wIr zKN7p|^bbE;^dE^`_^I}by+_oeTzeG9aRP6vBb_4an&KLO`WQ~?f4Znofo!^`kKvYn zm66sBbj~?kz$NtH3a+6S?Q8VYtk?hah<+P%xh9{`|AgWU{WP!Yr@gR{sCIftB!Y{ocCav^fhDsQm)VjVh<@Q=}jJPIhHu1WttU$}uUoa7rfl1=>N=;Nla!L^%Fg{|1A z+%1()d41Q8+fVBMKg-WYjxXQE=g;rb|G$HukCxf|ZyZ|0|Hff@23hB|-5MOKUok&b zJaKKjz2tt>>U*ydUQJJK(>8!o{o`Be&h%z9p%Dk2lXm|}GKDz)KTWnF)2&?~Bi(s8 zE~3Ep$9Z`V$Nvv`fBwTW;*j3b{tFAFe~I*KGpMNJXS?n=PT&;I;2bWXo!#%?&v&A< zzGH^{U&H>#{~=qZNWZp%Lyx#G9A^9D-Vj-|on`;)2Zc-G>A@9TLuvef&7jarPsX+Z z;{$HeZzJ^#yTNun9REMa`_R@K?crU&_n}{<_w@tc@ayu*AIri($Ad8x!!ZJ*FdC(K z8Sb&P4&@l{WrB8*efCKtGmbm+-mB-_UgJH}3lqH8VcxTMoF6kdj1|XtOvGeN#dMTo zCQ9o%cn5XIJm-s^>psub^KGp0o$V4nNd03N+lb_gY~vznLo+?~g8mQd63-IP9Lz;U zds&!IF2rKA-_!o_yyuS`y71xg*V;cG(f)B#9sHQ~kBjOCWY%#T+5KDo|6|7f|B3(q zsCE9uxeP0?605Kpg-?}*b)+hA==~yoEpIq}W4ZP(_Po^o^@?@f_8F@oyb76J>K<=u zmqBHndf4y1k0pE;>56-|#W9rC%F6-HU5l;Qj@m`~-<6L$=}Bc|19sE*BDG1`yhjvs>y@M(v$H##6E;hHa)J(n|9oWg5#p&u8cYWS$CQLeO>zknY_pTlMmY$ zmblh6MI6O(oWLoZ!8tr!|K}b*_7?kpo&9IO3oo+&WXl-#e**hYrf0DKWX5rpEUo|Z z%7@qgxggG3VKul!PhMsJyZLAIf;QEZ<7UTAWF!Chitsq^u{6KHxcsd5e@$2~GTLT3 zU-Vs!G5*f+Z4~Im3BHp_z7w=GX&*kOefWsB034=gPKxJ-{E7C2{EbfJ@Ynmd-4$ma 
z`cM^kQ|JmsC&`#)qc_-?a3*gul(P1S^pMNG5??KABH<`1V&*rYWwpa<%hBK z0Fc*h~i2gLD?y?^>>m?L~H=A&Q^ zSH&;+(DZaS{TJ##{Q0Ghmm$ketsqz8!FZi~M!S`IeUXfN9IWzNcl58xp0D}NzpX!= z{aovK9X4Pys<0K?u@k$|`(6Ei@<6@3&>&4Qt&OLp4b6z-pZ7X%KdNyMNu-cQ8`@v? ze$e@d_k)M?|LDd3-f!mp@qc7oEAa4Z-Y>HBwwJX3|I)Yhj=F!lc7WGEdcQqu%KiFx zPW(k2Md9}&!wK>fdcW>j{=xWAdhGi=M_#}s^dOFbj(eY6p+A_vQ28a#MSpwB`;Tov z+`G{DjP!5O|Gh&$F_QYD8}&;!oz?$+T0ikI^V`L74ZXOD+qjE9^!vW_zoh>ioyegJ zc@$7Yi@xqw|8Z!WzHaic{_+f&b=*d#pL{wD5YIpi#!zGzI7c7#aC%N3b*Db>5%f_g zIxZ~I|E*8EiySSiddbsaEIA&vL-ePTlM%--G>9kJ-#CJI)OWN$HII}o$BksuF#X@s zTrbTHNWLrlyl^zrQ%9cu(E2}9#aU@>i>mKfTZLYZ|8)JI7yb|ZKX>GL_cv2KvoHs9 zk(I8D^v$Q|-t=v4vJU`#F^Y~0j^iA^t{vKSg)Kw%8v8hqHBUYrR+6g_pT$D&X*E50 zL?8FN-a8rh7>IL=;~e9)ep`nP*o-PvJYj8rayxcnH}+ybs*ywrg>T>YNlNoyUNhGv zl-=L|?mK*KzxO_Bj0M@nxE@`+SVHvDvpVmjH#H8a?C{U_s#!Sf9n3q(_xn5p0C(Tf}D%_ zScvKul%M2MEW--a3ae4Bt)wUA@rK# z_SdDSmYcWcnzdMm4ftUH&t}K%m)~C>_zqi*uB+@i;uwJ%6%C-A;P$8FdKfl*a#kSGhmN`P2NzB>zS3c22eb+Dq<7 z96L}=9z<#XJHN#-^hw7lG_RMgm!y-7edu-9%|B32X%Jr$jXQ*+87ZV)(}o<1IEv#q zfm3MTBK_OkBN_K9?4rlLfePw3rT2eB`f&*9W72;_+K@#X(jCfi@tng2Ttaq;bDm@q z=s7mK6IbZh5ZClCAn%;6G5SfK*H0k)CTeGYv{V=0C@)rb+Z%D1{d9!Iw8lLTn^Eyu zbJ66R_^umJ__g)#T^rYt=p!@d&42wz`3YI}uX3aD^qip>_N1GL1ImqOJ~C zB#+`aPM|dZZ;$pB`h)p@XXqF3aQ)DI$|}_L=NIt{8%FX!hw+Q}oK0xvBd4&!HE*2c z?F4^Va{3&U<{^Tl{;zaDD+lZ5}`EMgBLRwiOTk&nNkH$PVJ? zo!9@(2kQ~v6wsLw>yFgt3enMIFt%;u7n-}QdX0Vz&XU*Kot$9Uh!(c_w=x8?Qu&RdAZs3_AG z^rkgG=nvMvTShN_dRkaPHf)g}ulr{2_;&GN{kxv8YTNz=-<7Re>AY1~jkQ>Z4cLq- zY(?*<}aoIPB z96FIl7m6sLm91-0zNE*<^T>>2gN4Ph2U(QH|0w4gZb{#u*e|M%)L-2QasFS*f25H^ z5l3+xCvXa-^?zn~rq?~&%bxK)&zdYaE|M)1)c>aN@0^n+4?F(g`ac_m$1#6#|I=Q3 z4OF%mbN@T-p5pG+Hc;CCPusvb$8kQy1v2_Om*TPZg1A3Z5B&X>{^KH_lb&mJ zZ}r|c{Vq!D|2WPcaIa*Wus$@rpp0Q38p);?*$OhnZuEOgzQRCMpSBhP8P}$(WiyA; zhvS3ugGV?Xh0%!p?_ z7Up0s=3^liqkTf`w>HMXc!u2jV;b_tR}@g(XZ)r1`_?7;v7MJ5qu+Xlerqy|wqe4g zWvO_UVFgwqwqeBm|5nj+d)VJy)*qy=MUh=C++lxju)pLwVbyDlVIVi7mQAf9x1zNF z-)YbFqOlB)ccMA(xz>2jWK$-#e^~$jmb$s`CP_B>Zkl{2&Ay8ixzQMnO(%UH$JDRC z?z{S~c1rP7HTqWcO|=`((Q#h-jq`|OJ-T)n+acYBUDE%i^uOr)Tj2X!;`<}-&p*{p zbi>*RNUt%6M|{;dh$ONZ`DD2~NYCxK--eW?x1l&izEV!)H)%s6bHd^tltr?_xR=@+ z)&!>?N8E$5K|5FStZ_b$PoddyN?%nY*>ucr+wA*?dc?IzQ+Z>Q(X_=rfblo&T}hmA z?Kxb)CG_A5jB^g1c(4y(@uFvhI0vr_=??q;wT94oz_VzQ<|WdOjOTM0*Tm6_o4Acl zQ`r4M+A`^H&F8ro9O?d$f5!d!t>?4vHgE;`#?^9bE|rIkH_AixBhQ8H ziBE+A!dGGyO8d_XUrkTG{9I^Q{#;l~Ux(C^=hR=G3r#OQ7aINc&o6y0>|gV_@b;0< zhJBf7;iolItXX63*K3~%dtRLse*DaYuzSm-@S|6r4ZD^<8-DoOr^C*g3E>CxJ{5M* zTL(QJTKYd8(&i-{TJX5Gktf38%oAaQ|JsZyY{ho$#LrKE-Z-s9_}Q9i`oBIO4!moe zvGCXj^K*H=W@MT5Apb|a4LKC?wzBFdc^oHj3TLp} zw|0)afJ>;`@tpP<@8J4#A=&Mjqv@)5aQZpzQ|9;2Gru2)kjCNHp7RXN?|;Soe($aQ z8S4X}bDH%5&^6Bb0K=>gFvR)*gPsc;)m1kkj_2#~e+7N@-7Pc1HG0n%tXuFo_MLtc zy?-hXx5>Mx{0TqUeD1WgReY>0^a<ax=K=n!Czbp$iM`nakjz^=GK9(GhDkr=B+C)^5-Os!FYys|ggfdAnCZ7!sGDH@A>$27>kl<4momm5@Fl@&NVP1%a<5OokN7FW+_ zp9@RL?#AcBGP3t|?VjXHtU}y>d^Oqqx_n6X{^7Z>j@*FFs6sVgXDeB=#(pT|PSnzO zlY5c8tSm;;MfEE@=>MzdU0TBaFJPzC0~We=KdNys#)A~nXhX^W z-^50u8-2VYVj z{ubN9R+s$$UDCOp{dM2;I&ViUq&PQon$)AiXp)bZ#EW-*^@IPa}!Aklntj1bo z@A!`!{BnAZf8L1=^vx)~%e>z1WWic{Pc@ zULT?7m&X4*X8k^E^~8C48()3#HmKk#1< z{K4matHK}pzcb$Ve>d@exAA|s@P9KO4+mX0z;y#L7$0o^8|t`yAOF9J|69ZV=kIs% z{qsxs|Fil3&-4FJ@_*0rfBCbe{qMw?$@BloHrG{?F%J(H*9eS4?I8Z|BOec=>B%Ad zUwUc&i!oTK*WBBy?lHmt6+YHE<1rEG%-}GY%(#cFdzk9pr#haFPGkLZ=+d|U@c!Im z-OGGXG>9{aIH#Zq%}Al#wKFjbb1)b4u@H;V?m2aAlYTPiS#;6!(pQkK`|W?ye?$6_ zKF0q)BK^ps4b^w_`HE*5R$wJ+he-dE#!}FeQ`{R?)7K)^-@T#9IgMjJ{vrFc&hHzr z85Qb$1yp{)xHM~`Zxyy3JFy#ku^)BY?d#?LBk~w>DB0O}M?5CaIlhTzWkJeu 
zBN^8jyelm3zgtnJFJAlK#&`MS_wKj%7alXe-0y`?mxX>iUMv>jW9~?%Ly^kBCK#s>mOvY49M>%F<7Up0s=3^liBewr-Df>HP^n4?Z z;~D{_^<&t{SiUS1mY$@YpRD+`Iep}6WXGvv@V(a3zcLNSa9i@GV#M`uq z4I@*+(h+^~ZDfYbkws)L`}W=cQF<8uh%ZM+`A{^IH9^@-v7Ja|9S6U`%a^8pb5>sq1ab6z%>Ig z7(?+^^mEA(7=`xh>c8@Mr*b5>N&T0erx)__wX&f_-fYDo=cUy@4=YpR-hf%ODQn`| z)1$>R7UMAywXb>qzKhB9xb;K;eC z_>?hlWYsr~M|$h2u#mnOOR)?quoAIvt~CDOw7xROtC4koYsq!U9g)Uc(l*4iJ1=cy z;mJhU;J3}F!d7g@PV7dnKK;F9H$P!N*<&5QYVshGNTJleqm3u#=d|O(SDy+MU$wTs zw%cA~g1fcxbpOTs|2+T9^*QGhQTY@7`G0yURKM(5>t{VG>^M#!&Yg^9-zj?X5zlXi z@{@iJDc?feJ237sAJ;&u+c7fKTkpT&z{rq9( zy}tCuG5YW6&16cPH{IiH+(jSy?erYvg#qM1q%V&OgUO*7juFUS_a7KV&rOlGXN)tT zk45oSY16Ns_n%$<U8 zBhK}Dxca;{oea)ePZNJRQy1j*t7oc)7m6{ zr`=On8@Az(3d@;92_wv^*V# z`fWHyU=&7UEGjC>Lfl_)JUyd*F*_^~CekNkDjtp>yCv=SJOi?LU7IUgQ=eyJGVI^G zYzo`egl2l`h%v^lEyqmE!W_)Sd@RJn{dWhkyV~r#kVgSUw7kjw;m|95p5@x#*K2<# zvxxie?)EIIrUT?){Uh(N91&W%ygK+%Naz8l0>A zz8Y~2>b2xLRDOZq|Md_y&^MzBTd^HG@o@j24*!G3dg({AXOO~f=j_FPRO29$NFj}O z&%6VrdG=)2`TKRNg8EgFZ}xD1q&L}r9G)ltA&WM|{)aa4Jm`No9>dqO;cK*M z&`+Vbi`_xqbLzqwVX^<=9C-n;|KSqZgHr#4_~W|56^*Vle(}M+v5ikkC+aS$|6gY# z*yF#MUpu2;xF()n+(h9UkA~aiUG$;fF6md-??C5y_5aiA$8S~F zKNRcw^5S7K<2a7dX>;DD1=;{+`)xT}PLKT#1H?NJS$fI;|Cjr13!QI@n?Dfi{UtxZ zarD^-J9j9CV+7*d-5QLdC*4~EzhpFhEK-h}9XFAU?ziW6Y-}U{{cScD%3RQJlEy{3QYF!!OL?o~t=J+E$dNO-!*Jcc3i5wi5QedaM_%wwR(y`kb75q10I z<6XX6dZ~TuRqq$8{Ksmn#X4-jW)xc74_WsN{}&BN&baSWHcjAT@-5?B`-RTgitX5m z-Pnu$s7AYf{tmu&C!ag_3jcZy|61RAfh_X1TllxF2l&58<1jMy{NEk?U!>^=#dF2? zlq6HgHaQ1rdTu%2+I_V}9mQAp*23~@_}Ym1zWO@9DAxbzwO8Z3d-(`8^akw(u@C*Y z;}dA6$9)VQ%=52%N&3apu-RlgAc9$$gajd zr~EN_4W<0QP5O69|32xF-oM#@*MD3Q&o%VoCbDa!|FSU;^xQn@pZ)Q0m)?is9&zvR zA1_G1@P0q`9$Kx#Kn_Ihq>qQenF$~ zW3>Ot?~P=W<7P4?pN(?OXpF^pOvGdqzN3wuY~LpT^Tj*yaDDAH?8|!gg>1Ph|KpIn zp1vXfpO?px?U4U7@_v)|mSO+%>c0o9B`(ek8Q1cgNss#n%pyCV5kI<+AE$g7q6|Rw zb#woJQ5LowUKQpzZ!T(u)nGn7`HHlml6|d;?MKR>W;89AR^bbs(-X|6|LV7_WivP| zcDxkJumT6_`-Pv~GZ*AYzwq-x>LJ>od-?V~%A}QkUxn4EC@TvEtfj|2zt)i(uo*qS z)y_zEe?=XK+>V{tjSu!O>~*{!)!4pcMCg6n{Ey!&n||vZ3hVh--z1qveQ?0#~% zz1W6^x4It-+r~T={%OKvVf!6pt;Bs0|2XaOuyxntq2lTjVN3rfLgnQrLe-7OLz~}n zcw^F2Ve_!3!rS#_q3HN1j^hMQ;isp|tZP*keq?=*U6(%=evI9jkA*XSJBJ@GdNl0B zfgK@Sa9o{uJ^Xx+{=1IX!#g`(4+r;rH`I2o5A~Z~3v~y+8v_-5J+x}M?xqktk>l)mOE=8FlikddXV)I&47a9r2@!AD%~?Lt9+Vukc&#CiecGG0XJo z?8`toNdwp+r?oRw?>;G=rqD>IHo%_ahJ_xxV z)i{U`mH*n={?qk;)%iYH|4WJ^g*4jG{TcObe3;Y5-C;f44mP_;=#68Ii$UCPC)%(_mD&bdf4ZWy6 zyFS!p)`y$)-R#Rb27^}=zrc2`u!x`O187Z_pzag~Z5YjljW`jE9{r_(x+m~zz9TPT$&N2RfhX1GMNBaL^{+|z2eQ!e; z?SIB%JSL*{neSP@|9fFFJz3+wF_k_YsZG+qL;5q{(--u;P%dmHW?>HIVm=mPF-q$# zX_t!il%$8Ki)?J9*#eIQa$|vF0quOa zjxEN(Y}0p2cJ1Tekp;)*8-|wGjQ_zQ=cO+j^TPMbkXf{y73TZdES@TC#dc)(_~d zC7af`{#Rw8O?VE4U#k0(M{yh{@WJPI%JCUg7>|5`yo5N1r-!Wkl6HZwSl5SsEr#PJ zZsRWc&~J}CV1B^>vil3x%_IlMuup9WLt~ioVmLVhqc9p3-|8R6lH)NElhOOchA@?! 
[Patches 04-09 of this series survive only as base85-encoded GIT binary patch data and are unrecoverable here. The readable text resumes partway through the patch that adds `tools/gemma/export_gemma_to_hf.py`, inside its `convert_checkpoints()` function.]

+    if preset is not None:
+        print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...")
+        keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset)
+    else:
+        hf_id, keras_preset = SIZE_MAP[size.lower()]
+        print(f"\n-> Loading Keras weights from file `{weights_file}`...")
+        keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(
+            keras_preset
+        )
+        keras_nlp_model.load_weights(weights_file)
+
+    print(f"\n-> Loading HuggingFace Gemma `{size.upper()}` model...")
+    hf_model = transformers.GemmaForCausalLM(CONFIG_MAPPING[size.lower()])
+
+    print("\n✅ Model loading complete.")
+    print("\n-> Converting weights from KerasNLP Gemma to HuggingFace Gemma...")
+
+    # Token embedding (with vocab size difference handling)
+    keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0]
+    hf_vocab_size = hf_model.model.embed_tokens.weight.shape[0]
+    keras_nlp_vocab_size = keras_embedding.value.shape[0]
+    if hf_vocab_size < keras_nlp_vocab_size:
+        diff = keras_nlp_vocab_size - hf_vocab_size
+        update_state_dict(
+            hf_model.model.embed_tokens,
+            "weight",
+            keras_embedding.value[:-diff, :],
+        )
+    else:
+        update_state_dict(
+            hf_model.model.embed_tokens,
+            "weight",
+            keras_embedding.value,
+        )
+
+    # Decoder blocks
+    for i in range(keras_nlp_model.backbone.num_layers):
+        decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}")
+
+        # Pre-attention norm
+        update_state_dict(
+            hf_model.model.layers[i].input_layernorm,
+            "weight",
+            decoder_block.pre_attention_norm.weights[0].value,
+        )
+
+        # Attention
+        query_target_shape = hf_model.model.layers[
+            i
+        ].self_attn.q_proj.weight.shape
+        query_tensor = decoder_block.attention.query_dense.weights[0].value
+        query_tensor = query_tensor.transpose(1, 2).reshape(query_target_shape)
+        update_state_dict(
+            hf_model.model.layers[i].self_attn.q_proj, "weight", query_tensor
+        )
+
+        key_target_shape = hf_model.model.layers[
+            i
+        ].self_attn.k_proj.weight.shape
+        key_tensor = decoder_block.attention.key_dense.weights[0].value
+        key_tensor = key_tensor.transpose(1, 2).reshape(key_target_shape)
+        update_state_dict(
+            hf_model.model.layers[i].self_attn.k_proj, "weight", key_tensor
+        )
+
+        value_target_shape = hf_model.model.layers[
+            i
+        ].self_attn.v_proj.weight.shape
+        value_tensor = decoder_block.attention.value_dense.weights[0].value
+        value_tensor = value_tensor.transpose(1, 2).reshape(value_target_shape)
+        update_state_dict(
+            hf_model.model.layers[i].self_attn.v_proj, "weight", value_tensor
+        )
+
+        out_target_shape = hf_model.model.layers[
+            i
+        ].self_attn.o_proj.weight.shape
+        keras_out_tensor = decoder_block.attention.output_dense.weights[0].value
+        out_tensor = keras_out_tensor.reshape(
+            (out_target_shape[1], out_target_shape[0])  # Transpose target size
+        ).transpose(0, 1)
+
+        update_state_dict(
+            hf_model.model.layers[i].self_attn.o_proj, "weight", out_tensor
"weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + hf_model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + hf_model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + hf_model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + update_state_dict( + hf_model.model.layers[i].mlp.down_proj, + "weight", + decoder_block.ffw_linear.weights[0].value.transpose(0, 1), + ) + + # Final norm + update_state_dict( + hf_model.model.norm, + "weight", + keras_nlp_model.backbone.layers[-1].weights[0].value, + ) + + print("\n✅ Weights converted successfully.") + print(f"\n-> Saving HuggingFace model to `{output_dir}`...") + + # Save model to HF Transformers format + os.makedirs(output_dir, exist_ok=True) + hf_model.save_pretrained(output_dir) + + print(f"\n✅ Saving complete. Model saved at `{output_dir}`.") + + # Tokenizer + + if not vocab_path: + tokenizer_preset = preset or SIZE_MAP[size.lower()] + print( + "\n-> Loading KerasNLP Gemma tokenizer with " + f"preset `{tokenizer_preset}`..." + ) + keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset( + tokenizer_preset + ) + # Save tokenizer state + keras_nlp_tokenizer.save_assets(output_dir) + vocab_path = os.path.join(output_dir, "vocabulary.spm") + print("\n✅ Tokenizer loading complete.") + + hf_tokenizer = transformers.GemmaTokenizer(vocab_path) + + print(f"\n-> Saving HuggingFace Gemma tokenizer to `{output_dir}`...") + # Save tokenizer to HF Transformers format + hf_tokenizer.save_pretrained(output_dir) + + print(f"\n✅ Saving complete. Tokenizer saved at `{output_dir}`.") + + +def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None: + """Updates the state dict for a weight given a tensor.""" + assert ( + tensor.shape == layer.state_dict()[weight_name].shape + ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}" + layer.state_dict()[weight_name].copy_(tensor) + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.weights_file: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a Keras weights file (`.weights.h5`) and model size" + " (`2b` or `7b`) to `--weights_file` and `--size`, respectively." + ) + if FLAGS.weights_file: + if FLAGS.preset: + raise ValueError( + "Both `--preset` and `--weights_file` flags cannot be supplied " + "at the same time. Either supply a valid Keras preset to " + "`--preset`or supply a Keras `.weights.h5` file and " + "model size (`2b` or `7b`) to `--weights_file` and `--size`, " + "respectively." + ) + if not str(FLAGS.weights_file).endswith(".weights.h5"): + raise ValueError( + "Please pass a valid Keras weights file ending in `.weights.h5`." + ) + if not FLAGS.size: + raise ValueError( + "The `size` flag must be passed if a weights file is passed. " + "Please pass the appropriate size (`2b` or `7b`) for your " + "model to the `--size` flag." + ) + if FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." 
+            )
+
+
+def main(_):
+    flag_error_handler()
+    convert_checkpoints(
+        FLAGS.preset,
+        FLAGS.weights_file,
+        FLAGS.size,
+        FLAGS.output_dir,
+        FLAGS.vocab_path,
+    )
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("size")
+    app.run(main)
diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py
new file mode 100644
index 0000000000..005eac272d
--- /dev/null
+++ b/tools/gemma/export_gemma_to_torch_xla.py
@@ -0,0 +1,322 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import os
+
+import gemma
+import torch
+import torch_xla.core.xla_model as xm
+from absl import app
+from absl import flags
+from gemma import model_xla as gemma_model
+
+# The Keras backend must be set before `keras_nlp` is imported; setting it
+# afterwards has no effect.
+os.environ["KERAS_BACKEND"] = "torch"
+
+import keras_nlp
+
+"""
+Sample usage:
+
+For converting a Keras model to PyTorch format using a custom or fine-tuned
+checkpoint from Keras, make sure to pass the path for the Keras weights file
+(ending in `.weights.h5`) and the model size (`2b` or `7b`) to `--weights_file`
+and `--size`, respectively.
+
+Optionally, you can specify the output path for the converted model at
+`--output_file`. (This defaults to `gemma.ckpt`.)
+```
+python tools/gemma/export_gemma_to_torch_xla.py \
+    --weights_file fine_tuned_imdb.weights.h5 \
+    --size 2b \
+    --output_file fine_tuned_imdb.ckpt
+```
+
+For converting a Keras model to PyTorch format from a preset,
+simply pass the Keras preset name to `--preset`.
+```
+python tools/gemma/export_gemma_to_torch_xla.py \
+    --preset gemma_2b_en \
+    --output_file path/to/keras_torch_model.ckpt
+```
+"""
+
+
+PRESET_MAP = {
+    "gemma_2b_en": gemma.config.get_config_for_2b(),
+    "gemma_instruct_2b_en": gemma.config.get_config_for_2b(),
+    "gemma_7b_en": gemma.config.get_config_for_7b(),
+    "gemma_instruct_7b_en": gemma.config.get_config_for_7b(),
+}
+
+SIZE_MAP = {
+    "2b": (gemma.config.get_config_for_2b(), "gemma_2b_en"),
+    "7b": (gemma.config.get_config_for_7b(), "gemma_7b_en"),
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset",
+    None,
+    f'Must be one of {",".join(PRESET_MAP.keys())}'
+    " Alternatively, a Keras weights file (`.weights.h5`) can be passed"
+    " to --weights_file flag.",
+)
+flags.DEFINE_string(
+    "weights_file",
+    None,
+    "A Keras weights file (`.weights.h5`)."
+    " Alternatively, a model preset can be passed to --preset flag.",
+)
+flags.DEFINE_string(
+    "size",
+    None,
+    "Size of model. Must be passed if `weights_file` is passed. "
+    "This should be either `2b` or `7b`.",
+)
+flags.DEFINE_string(
+    "output_file",
+    "gemma.ckpt",
+    "An output file for the converted PyTorch checkpoint. Default: `gemma.ckpt`",
+)
+flags.DEFINE_string(
+    "vocab_dir",
+    "gemma_tokenizer",
+    "A directory in which the vocabulary for the tokenizer will be stored.",
+)
+flags.DEFINE_string(
+    "dtype",
+    "float32",
+    "Set the precision of the converted checkpoint. 
Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def _reconcile_attention_dims(qkv, target_shape): + return torch.cat(qkv).reshape(tuple(target_shape)) + + +def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir): + device = xm.xla_device() + + if preset is not None: + print( + f"\n-> Loading PyTorch Gemma model config for preset `{preset}`..." + ) + model = gemma_model.GemmaForCausalLM( + PRESET_MAP[preset], world_size=1, rank=0, device=device + ) + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + print(f"\n-> Loading PyTorch Gemma model config for `{size}` model...") + config, size_preset = SIZE_MAP[size.lower()] + model = gemma_model.GemmaForCausalLM( + config, world_size=1, rank=0, device=device + ) + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + size_preset + ) + keras_nlp_model.load_weights(weights_file) + + print("\n✅ Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to PyTorch Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + torch_vocab_size = model.embedder.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if torch_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - torch_vocab_size + update_state_dict( + model.embedder, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + model.embedder, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + # Pre-attention norm + update_state_dict( + model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + qkv = ( + decoder_block.attention.query_dense.weights[0].value.transpose( + 1, 2 + ), + decoder_block.attention.key_dense.weights[0].value.transpose(1, 2), + decoder_block.attention.value_dense.weights[0].value.transpose( + 1, 2 + ), + ) + qkv_target_shape = model.model.layers[i].self_attn.qkv_proj.weight.shape + combined_tensor = _reconcile_attention_dims(qkv, qkv_target_shape) + + update_state_dict( + model.model.layers[i].self_attn.qkv_proj, "weight", combined_tensor + ) + + out_target_shape = model.model.layers[i].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + model.model.layers[i].mlp.gate_proj, + "weight", + decoder_block.gating_ffw.weights[0].value.transpose(0, 1), + ) + update_state_dict( + model.model.layers[i].mlp.up_proj, + "weight", + decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1), + ) + update_state_dict( + 
model.model.layers[i].mlp.down_proj,
+            "weight",
+            decoder_block.ffw_linear.weights[0].value.transpose(0, 1),
+        )
+
+    # Final norm
+    update_state_dict(
+        model.model.norm,
+        "weight",
+        keras_nlp_model.backbone.layers[-1].weights[0].value,
+    )
+
+    print("\n✅ Weights converted successfully.")
+    print(f"\n-> Saving PyTorch model checkpoint to `{output_file}`...")
+
+    # Save model checkpoint
+    torch.save({"model_state_dict": model.state_dict()}, output_file)
+
+    print(
+        f"\n✅ Saving complete. Model checkpoint available at `{output_file}`."
+    )
+
+    if preset is not None:
+        # Tokenizer
+        print(
+            f"\n-> Loading KerasNLP Gemma tokenizer with preset `{preset}`..."
+        )
+        keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(
+            preset
+        )
+        print("\n✅ Model loading complete.")
+        print(f"\n-> Saving tokenizer state to directory `{vocab_dir}`...")
+
+        # Save tokenizer state
+        os.makedirs(vocab_dir, exist_ok=True)
+        keras_nlp_tokenizer.save_assets(vocab_dir)
+
+        print(
+            "\n✅ Saving complete. Tokenizer state "
+            f"available at `{vocab_dir}/vocabulary.spm`."
+        )
+
+
+def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None:
+    """Updates the state dict for a weight given a tensor."""
+    assert (
+        tensor.shape == layer.state_dict()[weight_name].shape
+    ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}"
+    layer.state_dict()[weight_name].copy_(tensor)
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.weights_file:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a Keras weights file (`.weights.h5`) and model size"
+            " (`2b` or `7b`) to `--weights_file` and `--size`, respectively."
+        )
+    if FLAGS.weights_file:
+        if FLAGS.preset:
+            raise ValueError(
+                "Both `--preset` and `--weights_file` flags cannot be supplied "
+                "at the same time. Either supply a valid Keras preset to "
+                "`--preset` or supply a Keras `.weights.h5` file and "
+                "model size (`2b` or `7b`) to `--weights_file` and `--size`, "
+                "respectively."
+            )
+        if not str(FLAGS.weights_file).endswith(".weights.h5"):
+            raise ValueError(
+                "Please pass a valid Keras weights file ending in `.weights.h5`."
+            )
+        if not FLAGS.size:
+            raise ValueError(
+                "The `size` flag must be passed if a weights file is passed. "
+                "Please pass the appropriate size (`2b` or `7b`) for your "
+                "model to the `--size` flag."
+            )
+        if FLAGS.size.lower() not in ["2b", "7b"]:
+            raise ValueError(
+                "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+                "for your model to the `--size` flag."
+            )
+    if FLAGS.dtype:
+        # Use a `None` default so an unknown name reaches the friendly
+        # `ValueError` below instead of raising `AttributeError`.
+        dtype = getattr(torch, FLAGS.dtype, None)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(
+                "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. "
+                "`float32`, `float16`, etc.) to the `--dtype` flag."
+            )
+
+
+def main(_):
+    flag_error_handler()
+    with _set_default_tensor_type(getattr(torch, FLAGS.dtype)):
+        convert_checkpoints(
+            FLAGS.preset,
+            FLAGS.weights_file,
+            FLAGS.size,
+            FLAGS.output_file,
+            FLAGS.vocab_dir,
+        )
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py
new file mode 100644
index 0000000000..9fa50cbd2b
--- /dev/null
+++ b/tools/gemma/run_gemma_xla.py
@@ -0,0 +1,287 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import os +import random +import sys +from typing import List + +import gemma.xla_model_parallel as xla_model_parallel +import numpy as np +import torch +import torch.multiprocessing +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +from absl import app +from absl import flags +from gemma.config import GemmaConfig +from gemma.config import get_config_for_2b +from gemma.config import get_config_for_7b +from gemma.model_xla import GemmaForCausalLM +from gemma.tokenizer import Tokenizer + +PAD_TOKEN_ID = -1 + +FILE_PATH = "gemma.ckpt" +TOKENIZER_DIR = "gemma_tokenizer" + +PRESET_MAP = { + "gemma_2b_en": get_config_for_2b(), + "gemma_instruct_2b_en": get_config_for_2b(), + "gemma_7b_en": get_config_for_7b(), + "gemma_instruct_7b_en": get_config_for_7b(), +} + +SIZE_MAP = { + "2b": get_config_for_2b(), + "7b": get_config_for_7b(), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `preset` is not passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "checkpoint_file", + "gemma.ckpt", + "A PyTorch checkpoint file containing the converted weights.", +) +flags.DEFINE_string( + "vocab_file", + "gemma_tokenizer/vocabulary.spm", + "The file containing the vocabulary for the tokenizer.", +) +flags.DEFINE_string( + "prompt", + "The capital of France is", + "A test prompt for verifying functionality of the PyTorch Gemma model.", +) + +# This is a modified version of `run_xla.py` script in the Hex-LLM Gemma repo +# to ensure proper functionality after porting checkpoints from Keras. 
+ + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def generate( + i: int, + model_config: GemmaConfig, + checkpoint_file: str, + vocab_file: str, + prompts: List[str], + output_lens: List[int], + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], +): + # Set seed from config + seed = model_config.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + device = xm.xla_device() + xm.set_rng_state(seed, device) + + rank = xla_model_parallel.get_model_parallel_rank() + world_size = xla_model_parallel.get_model_parallel_world_size() + if rank > 0: + sys.stdout = open(os.devnull, "w") + + # Load model with ported weights and place on device + with _set_default_tensor_type(model_config.get_dtype()): + model = GemmaForCausalLM(model_config, world_size, rank, device) + model.load_weights(checkpoint_file) + model = model.to(device).eval() + + # Create tokenizer with saved Keras tokenizer state + tokenizer = Tokenizer(vocab_file) + + prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts] + min_prompt_len = min(len(p) for p in prompt_tokens) + + batch_size = len(prompts) + assert batch_size == len(temperatures) + assert batch_size == len(top_ps) + assert batch_size == len(top_ks) + max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)]) + assert max_seq_len <= model_config.max_position_embeddings + if model_config.num_key_value_heads < world_size: + assert world_size % model_config.num_key_value_heads == 0 + n_local_heads = 1 + else: + assert model_config.num_key_value_heads % world_size == 0 + n_local_heads = model_config.num_key_value_heads // world_size + + # build KV caches + kv_caches = [] + for _ in range(model_config.num_hidden_layers): + k_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + v_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full( + (batch_size, max_seq_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + input_token_ids_tensor = torch.full( + (batch_size, min_prompt_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + for i, p in enumerate(prompt_tokens): + token_ids_tensor[i, : len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len] + ) + token_ids_tensor = token_ids_tensor.to(device) + prompt_mask_tensor = token_ids_tensor != PAD_TOKEN_ID + input_token_ids_tensor = input_token_ids_tensor.to(device) + input_positions_tensor = torch.arange( + 0, min_prompt_len, dtype=torch.int64 + ).to(device) + mask_tensor = torch.full( + (1, 1, max_seq_len, max_seq_len), -2.3819763e38 + ).to(torch.float) + mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) + temperatures_tensor = torch.FloatTensor(temperatures).to(device) + top_ps_tensor = torch.FloatTensor(top_ps).to(device) + top_ks_tensor = torch.LongTensor(top_ks).to(device) + output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) + xm.mark_step() + + # Prefill up to min_prompt_len 
tokens, then treat other prefill as decode and ignore output. + for i in range(max_seq_len - min_prompt_len): + next_token_ids = model( + input_token_ids=input_token_ids_tensor, + input_positions=input_positions_tensor, + kv_write_indices=None, + kv_caches=kv_caches, + mask=curr_mask_tensor, + output_positions=output_positions_tensor, + temperatures=temperatures_tensor, + top_ps=top_ps_tensor, + top_ks=top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index + ).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select(1, output_index).squeeze( + dim=1 + ) + output_token_ids = torch.where( + curr_prompt_mask, curr_token_ids, next_token_ids + ).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_positions_tensor = output_index + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) + output_index = output_index + 1 + xm.mark_step() + + # Detokenization. + token_ids = token_ids_tensor.tolist() + results = [] + for i, tokens in enumerate(token_ids): + trimmed_output = tokens[ + len(prompt_tokens[i]) : len(prompt_tokens[i]) + output_lens[i] + ] + if tokenizer.eos_id in trimmed_output: + eos_index = trimmed_output.index(tokenizer.eos_id) + trimmed_output = trimmed_output[:eos_index] + results.append(tokenizer.decode(trimmed_output)) + + for prompt, result in zip(prompts, results): + print("======================================") + print(f"PROMPT: {prompt}") + print(f"RESULT: {result}") + print("======================================") + + +def flag_error_handler(): + if not FLAGS.preset and not FLAGS.size: + raise ValueError( + "Please pass either a valid Keras preset to `--preset`" + " or supply a model size (`2b` or `7b`) to `--size`." + ) + if FLAGS.size and FLAGS.size.lower() not in ["2b", "7b"]: + raise ValueError( + "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " + "for your model to the `--size` flag." + ) + + +def main(_): + flag_error_handler() + if FLAGS.preset: + model_config = PRESET_MAP[FLAGS.preset] + else: + model_config = SIZE_MAP[FLAGS.size.lower()] + prompts = [ + FLAGS.prompt, + ] + n = len(prompts) + output_lengths = [10] * n + temperatures = [0.95] * n + top_ps = [1.0] * n + top_ks = [100] * n + xmp.spawn( + generate, + args=( + model_config, + FLAGS.checkpoint_file, + FLAGS.vocab_file, + prompts, + output_lengths, + temperatures, + top_ps, + top_ks, + ), + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/sentencepiece_testing/create_gemma_test_proto.py b/tools/sentencepiece_testing/create_gemma_test_proto.py new file mode 100644 index 0000000000..c3ce418a4b --- /dev/null +++ b/tools/sentencepiece_testing/create_gemma_test_proto.py @@ -0,0 +1,36 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from tools.sentencepiece_testing.utils import train_sentencepiece + + +def main(): + train_sentencepiece( + ["the quick brown fox", "the earth is round"], + "gemma_test_vocab.spm", + vocab_size=11, + model_type="WORD", + pad_id=0, + bos_id=1, + eos_id=2, + unk_id=3, + pad_piece="", + bos_piece="", + eos_piece="", + unk_piece="", + ) + + +if __name__ == "__main__": + main() From f75d8cb128f22b4840a16b68155f408d83a8ee8d Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Wed, 21 Feb 2024 17:08:30 -0800 Subject: [PATCH 10/70] Update to the newest version of Gemma on Kaggle (#1454) Includes some small cleanups for the Kaggle assets. --- keras_nlp/models/gemma/gemma_presets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_nlp/models/gemma/gemma_presets.py b/keras_nlp/models/gemma/gemma_presets.py index f63fef17fa..72360f72dc 100644 --- a/keras_nlp/models/gemma/gemma_presets.py +++ b/keras_nlp/models/gemma/gemma_presets.py @@ -25,7 +25,7 @@ "path": "gemma", "model_card": "https://www.kaggle.com/models/google/gemma", }, - "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/1", + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/2", }, "gemma_instruct_2b_en": { "metadata": { @@ -37,7 +37,7 @@ "path": "gemma", "model_card": "https://www.kaggle.com/models/google/gemma", }, - "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/1", + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/2", }, "gemma_7b_en": { "metadata": { @@ -49,7 +49,7 @@ "path": "gemma", "model_card": "https://www.kaggle.com/models/google/gemma", }, - "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/1", + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/2", }, "gemma_instruct_7b_en": { "metadata": { @@ -61,6 +61,6 @@ "path": "gemma", "model_card": "https://www.kaggle.com/models/google/gemma", }, - "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/1", + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/2", }, } From cd5e33c255cafd04b77f666d35cc1e0bfdb60903 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Thu, 22 Feb 2024 15:30:02 -0800 Subject: [PATCH 11/70] Add dtype arg to Gemma HF conversion script (#1452) --- tools/gemma/export_gemma_to_hf.py | 37 ++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py index 31e3f3c69b..6f1fdf24d2 100644 --- a/tools/gemma/export_gemma_to_hf.py +++ b/tools/gemma/export_gemma_to_hf.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import contextlib import os import torch @@ -116,6 +116,19 @@ "A path containing the vocabulary (must be a `.spm` file or equivalent). " "If not passed, the vocabulary of the preset will be used.", ) +flags.DEFINE_string( + "dtype", + "float32", + "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): @@ -310,17 +323,25 @@ def flag_error_handler(): "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) " "for your model to the `--size` flag." 
         )
+    if FLAGS.dtype:
+        # Use a `None` default so an unknown name reaches the friendly
+        # `ValueError` below instead of raising `AttributeError`.
+        dtype = getattr(torch, FLAGS.dtype, None)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(
+                "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. "
+                "`float32`, `float16`, etc.) to the `--dtype` flag."
+            )
 
 
 def main(_):
     flag_error_handler()
-    convert_checkpoints(
-        FLAGS.preset,
-        FLAGS.weights_file,
-        FLAGS.size,
-        FLAGS.output_dir,
-        FLAGS.vocab_path,
-    )
+    with _set_default_tensor_type(getattr(torch, FLAGS.dtype)):
+        convert_checkpoints(
+            FLAGS.preset,
+            FLAGS.weights_file,
+            FLAGS.size,
+            FLAGS.output_dir,
+            FLAGS.vocab_path,
+        )
 
 
 if __name__ == "__main__":

From e2624a1427c0f0c5790072d67f738bf6903a36db Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Thu, 22 Feb 2024 17:50:57 -0800
Subject: [PATCH 12/70] Fix gemma testing import (#1462)

---
 keras_nlp/models/gemma/gemma_causal_lm_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py
index 0e1d7a14f8..517c5f4e3a 100644
--- a/keras_nlp/models/gemma/gemma_causal_lm_test.py
+++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py
@@ -15,9 +15,9 @@
 import os
 from unittest.mock import patch
 
-import keras
 import pytest
 
+from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
 from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM

From 4a0adf2b479b95a15beb3ec5ae36c8ce5413e74b Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com>
Date: Mon, 26 Feb 2024 16:30:04 -0800
Subject: [PATCH 13/70] Add docstring for PyTorch conversion script install
 instructions (#1471)

* Add docstring for conversion script install instructions

* Add docstring to verification script

* Change wording
---
 tools/gemma/export_gemma_to_torch_xla.py | 22 +++++++++
 tools/gemma/run_gemma_xla.py             | 59 ++++++++++++++++++++++--
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py
index 005eac272d..08d4b3ac98 100644
--- a/tools/gemma/export_gemma_to_torch_xla.py
+++ b/tools/gemma/export_gemma_to_torch_xla.py
@@ -12,6 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+Prior to running this conversion script, please install the PyTorch
+implementation of Gemma and `torch_xla`:
+
+`pip install git+https://github.com/google/gemma_pytorch.git`
+`pip install torch_xla`
+
+Please also ensure that your installed versions of `torch_xla` and `torch` are
+compatible.
+"""
+
 import contextlib
 import os
 
@@ -50,6 +61,17 @@
     --preset gemma_2b_en \
     --output_file path/to/keras_torch_model.ckpt
 ```
+
+Following this usage, you can run the verification script to confirm
+functionality of the converted checkpoint:
+
+```
+python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \
+    --size 2b \
+    --checkpoint_file fine_tuned_imdb.ckpt \
+    --vocab_file gemma_tokenizer/vocabulary.spm \
+    --prompt "Inception is about"
+```
 """
 
 
diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py
index 9fa50cbd2b..f212154c99 100644
--- a/tools/gemma/run_gemma_xla.py
+++ b/tools/gemma/run_gemma_xla.py
@@ -11,6 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+ +""" +This is a modified version of `run_xla.py` script in the PyTorch Gemma repo +to ensure proper functionality after porting checkpoints from Keras. Please +run `export_gemma_to_torch_xla.py` prior to running this verification script. + +As with the conversion script, ensure that `torch_xla` and the PyTorch +implementation of Gemma are properly installed: + +`pip install git+https://github.com/google/gemma_pytorch.git` +`pip install torch_xla` + +Note that this verification script can take several minutes to run. +""" + import contextlib import os import random @@ -31,6 +46,47 @@ from gemma.model_xla import GemmaForCausalLM from gemma.tokenizer import Tokenizer +""" +Sample usage: + +Run the verification script supplying your model size, converted checkpoint file, +vocabulary file, and test prompt. + +``` +python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \ + --size 2b \ + --checkpoint_file fine_tuned_imdb.ckpt \ + --vocab_file gemma_tokenizer/vocabulary.spm \ + --prompt "Three Billboards" +``` + +After a delay (a couple minutes if running on CPU), this should produce: +``` +====================================== +PROMPT: Three Billboards +RESULT: Outside Ebbing, Missouri is a film in the tradition of Hollywood westerns +====================================== +``` + +If running from a preset, instead provide your converted checkpoint file and +the associated preset name: + +``` +python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \ + --preset gemma_2b_en \ + --checkpoint_file gemma_2b.ckpt \ + --prompt "California is the largest" +``` + +After a delay (a couple minutes if running on CPU), this should produce: +``` +====================================== +PROMPT: California is the largest +RESULT: producer of strawberries in the world, and is a +====================================== +``` +""" + PAD_TOKEN_ID = -1 FILE_PATH = "gemma.ckpt" @@ -74,9 +130,6 @@ "A test prompt for verifying functionality of the PyTorch Gemma model.", ) -# This is a modified version of `run_xla.py` script in the Hex-LLM Gemma repo -# to ensure proper functionality after porting checkpoints from Keras. - @contextlib.contextmanager def _set_default_tensor_type(dtype: torch.dtype): From 6c642c80ae7da0f7dae59d5437865a7c0088aebc Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:40:06 -0800 Subject: [PATCH 14/70] Add an annotation to tests that need kaggle auth (#1470) We can skip these by default, for users who have not yet set them up. We will need to set them up for CI, see https://github.com/keras-team/keras-nlp/pull/1459 --- keras_nlp/conftest.py | 18 ++++++++++++++++++ keras_nlp/models/gemma/gemma_backbone_test.py | 2 ++ .../gemma/gemma_causal_lm_preprocessor_test.py | 1 + keras_nlp/models/gemma/gemma_causal_lm_test.py | 2 ++ .../models/gemma/gemma_preprocessor_test.py | 1 + keras_nlp/models/gemma/gemma_tokenizer_test.py | 2 ++ 6 files changed, 26 insertions(+) diff --git a/keras_nlp/conftest.py b/keras_nlp/conftest.py index b876a7a0a8..d66cad5d42 100644 --- a/keras_nlp/conftest.py +++ b/keras_nlp/conftest.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os + import pytest import tensorflow as tf @@ -83,6 +85,10 @@ def pytest_configure(config): "markers", "keras_3_only: mark test as a keras 3 only test", ) + config.addinivalue_line( + "markers", + "kaggle_key_required: mark test needing a kaggle key", + ) def pytest_collection_modifyitems(config, items): @@ -107,6 +113,16 @@ def pytest_collection_modifyitems(config, items): not backend_config.keras_3(), reason="tests only run on with multi-backend keras", ) + found_kaggle_key = all( + [ + os.environ.get("KAGGLE_USERNAME", None), + os.environ.get("KAGGLE_KEY", None), + ] + ) + kaggle_key_required = pytest.mark.skipif( + not found_kaggle_key, + reason="tests only run with a kaggle api key", + ) for item in items: if "large" in item.keywords: item.add_marker(skip_large) @@ -116,6 +132,8 @@ def pytest_collection_modifyitems(config, items): item.add_marker(tf_only) if "keras_3_only" in item.keywords: item.add_marker(keras_3_only) + if "kaggle_key_required" in item.keywords: + item.add_marker(kaggle_key_required) # Disable traceback filtering for quicker debugging of tests failures. diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py index c66d318fd5..add2ac1995 100644 --- a/keras_nlp/models/gemma/gemma_backbone_test.py +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -53,6 +53,7 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.kaggle_key_required @pytest.mark.large def test_smallest_preset(self): self.run_preset_test( @@ -69,6 +70,7 @@ def test_smallest_preset(self): ), ) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): for preset in GemmaBackbone.presets: diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py index 121621da85..c3305afe91 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py @@ -82,6 +82,7 @@ def test_generate_postprocess(self): x = preprocessor.generate_postprocess(input_data) self.assertAllEqual(x, "the quick brown fox") + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): for preset in GemmaCausalLMPreprocessor.presets: diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py index 517c5f4e3a..5ed1ce015c 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_test.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -142,6 +142,7 @@ def test_generate_compilation(self): causal_lm.compile(sampler="greedy") self.assertIsNone(causal_lm.generate_function) + @pytest.mark.kaggle_key_required @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( @@ -150,6 +151,7 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): for preset in GemmaCausalLM.presets: diff --git a/keras_nlp/models/gemma/gemma_preprocessor_test.py b/keras_nlp/models/gemma/gemma_preprocessor_test.py index f54a509979..0cb427af03 100644 --- a/keras_nlp/models/gemma/gemma_preprocessor_test.py +++ b/keras_nlp/models/gemma/gemma_preprocessor_test.py @@ -64,6 +64,7 @@ def test_sequence_length_override(self): x = preprocessor(input_data, sequence_length=4) self.assertAllEqual(x["token_ids"], [1, 4, 9, 2]) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): for preset in 
GemmaPreprocessor.presets: diff --git a/keras_nlp/models/gemma/gemma_tokenizer_test.py b/keras_nlp/models/gemma/gemma_tokenizer_test.py index 1c617dd937..65569c8174 100644 --- a/keras_nlp/models/gemma/gemma_tokenizer_test.py +++ b/keras_nlp/models/gemma/gemma_tokenizer_test.py @@ -48,6 +48,7 @@ def test_errors_missing_special_tokens(self): ) ) + @pytest.mark.kaggle_key_required @pytest.mark.large def test_smallest_preset(self): self.run_preset_test( @@ -57,6 +58,7 @@ def test_smallest_preset(self): expected_output=[[651, 4320, 8426, 25341, 235265]], ) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self): for preset in GemmaTokenizer.presets: From 4ba3ca7a06c671428d7acb07bbd78642edaedda0 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Mon, 26 Feb 2024 19:59:59 -0700 Subject: [PATCH 15/70] Fix Mistral memory consumption with JAX and default dtype bug (#1460) --- keras_nlp/models/mistral/mistral_causal_lm.py | 1 + keras_nlp/models/mistral/mistral_presets.py | 4 +- .../convert_mistral_checkpoints.py | 190 +++++++----------- 3 files changed, 72 insertions(+), 123 deletions(-) diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py index 3296bb9495..e9bd4e5616 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm.py +++ b/keras_nlp/models/mistral/mistral_causal_lm.py @@ -190,6 +190,7 @@ def next(prompt, cache, index): mask=padding_mask, end_token_id=end_token_id, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. diff --git a/keras_nlp/models/mistral/mistral_presets.py b/keras_nlp/models/mistral/mistral_presets.py index 82a2ec44f6..7fb4b4e0a6 100644 --- a/keras_nlp/models/mistral/mistral_presets.py +++ b/keras_nlp/models/mistral/mistral_presets.py @@ -23,7 +23,7 @@ "path": "mistral", "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", }, - "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3", + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6", }, "mistral_instruct_7b_en": { "metadata": { @@ -33,6 +33,6 @@ "path": "mistral", "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", }, - "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3", + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6", }, } diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py index 8e10089efd..7e13b9dd7a 100644 --- a/tools/checkpoint_conversion/convert_mistral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import datetime import gc -import json import os -import pathlib +import shutil +import tempfile import traceback -import keras import numpy as np import requests from absl import app @@ -27,10 +25,10 @@ from transformers import AutoTokenizer from transformers import MistralForCausalLM -import keras_nlp from keras_nlp.models import MistralBackbone from keras_nlp.models import MistralCausalLMPreprocessor from keras_nlp.models import MistralTokenizer +from keras_nlp.utils.preset_utils import save_to_preset PRESET_MAP = { "mistral_7b_en": "mistralai/Mistral-7B-v0.1", @@ -227,124 +225,74 @@ def main(_): preset = FLAGS.preset hf_preset = PRESET_MAP[preset] - # === Create the save directories === - model_dir = pathlib.Path(__file__).parent / f"{preset}" - tokenizer_dir = model_dir / "assets" / "tokenizer" - if not model_dir.exists(): - os.makedirs(model_dir) - if not tokenizer_dir.exists(): - os.makedirs(tokenizer_dir) + # === Create the temporary save directories === + temp_dir = tempfile.mkdtemp() - # === Load the Huggingface model === - hf_model = MistralForCausalLM.from_pretrained(hf_preset) - hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) - hf_model.eval() - print("\n-> Huggingface model and tokenizer loaded") - - # === Load the KerasNLP model === - keras_nlp_config = dict( - vocabulary_size=hf_model.config.vocab_size, - hidden_dim=hf_model.config.hidden_size, - num_layers=hf_model.config.num_hidden_layers, - num_query_heads=hf_model.config.num_attention_heads, - num_key_value_heads=hf_model.config.num_key_value_heads, - intermediate_dim=hf_model.config.intermediate_size, - sliding_window=hf_model.config.sliding_window, - layer_norm_epsilon=hf_model.config.rms_norm_eps, - rope_max_wavelength=hf_model.config.rope_theta, - dtype="float32", - ) - keras_nlp_model = MistralBackbone(**keras_nlp_config) - - # === Download the tokenizer from Huggingface model card === - spm_path = ( - f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" - ) - response = requests.get(spm_path) - if not response.ok: - raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") - tokenizer_path = tokenizer_dir / "vocabulary.spm" - with open(tokenizer_path, "wb") as tokenizer_file: - tokenizer_file.write(response.content) - keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute())) - print("\n-> Keras 3 model and tokenizer loaded.") - - # === Port the weights === - convert_checkpoints(keras_nlp_model, hf_model) - print("\n-> Weight transfer done.") - - # === Check that the models and tokenizers outputs match === - test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) - test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) - print("\n-> Tests passed!") - - # === Save the model weights in float32 format === - keras_nlp_model.save_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - print("\n-> Saved the model weights in float16") - - del keras_nlp_model, hf_model - gc.collect() - - keras_nlp_config["dtype"] = "float16" - - # === Save the weights again in float16 === - keras_nlp_model = MistralBackbone(**keras_nlp_config) - keras_nlp_model.load_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - keras_nlp_model.save_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - print("-> Saved the model weights in float16") - - # === Save the model config === - keras_nlp_config["dtype"] = "bfloat16" - model_config = { - "module": "keras_nlp.src.models.mistral.mistral_backbone", - "class_name": "MistralBackbone", - "config": 
{**keras_nlp_config}, - "registered_name": "keras_nlp>MistralBackbone", - "assets": [], - "weights": "model.weights.h5", - } - model_config_json = json.dumps(model_config) - with open(model_dir / "config.json", "w") as model_config_file: - model_config_file.write(model_config_json) - print("\n-> Saved model config") - - # === Save the tokenizer config === - tokenizer_config = { - "module": "keras_nlp.src.models.mistral.Mistral_tokenizer", - "class_name": "MistralTokenizer", - "config": { - "name": "mistral_tokenizer", - "trainable": True, - "dtype": "int32", - "proto": None, - "sequence_length": None, - }, - "registered_name": "keras_nlp>MistralTokenizer", - "assets": ["assets/tokenizer/vocabulary.spm"], - "weights": None, - } - tokenizer_config_json = json.dumps(tokenizer_config) - with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file: - tokenizer_config_file.write(tokenizer_config_json) - print("\n-> Saved tokenizer config") + try: + # === Load the Huggingface model === + hf_model = MistralForCausalLM.from_pretrained(hf_preset) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) + hf_model.eval() + print("\n-> Huggingface model and tokenizer loaded") + + # === Load the KerasNLP model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + sliding_window=hf_model.config.sliding_window, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = MistralBackbone(**backbone_kwargs) - # === Save metadata === - metadata_config = { - "keras_version": keras.__version__, - "keras_nlp_version": keras_nlp.__version__, - "parameter_count": keras_nlp_model.count_params(), - "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"), - } - metadata_config_json = json.dumps(metadata_config) - with open(model_dir / "metadata.json", "w") as metadata_config_file: - metadata_config_file.write(metadata_config_json) - print("\n-> Saved metadata") + # === Download the tokenizer from Huggingface model card === + spm_path = ( + f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" + ) + response = requests.get(spm_path) + if not response.ok: + raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") + tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") + with open(tokenizer_path, "wb") as tokenizer_file: + tokenizer_file.write(response.content) + keras_nlp_tokenizer = MistralTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") + + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") + + # === Save the model weights in float32 format === + keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + + del keras_nlp_model, hf_model + gc.collect() + + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_nlp_model = MistralBackbone(**backbone_kwargs) + 
keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_nlp_model, preset) + print("\n-> Saved the model preset in float16") + + # === Save the tokenizer === + save_to_preset( + keras_nlp_tokenizer, preset, config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) if __name__ == "__main__": From 5d22424b40a31158973399eba9c5510f3d06830e Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Tue, 27 Feb 2024 14:19:14 -0800 Subject: [PATCH 16/70] Bump the master version to 0.9 (#1473) 0.8 is out! We can consider our master branch an 0.9 preview. --- keras_nlp/version_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/version_utils.py b/keras_nlp/version_utils.py index 4d6a8186d4..5735e45aaa 100644 --- a/keras_nlp/version_utils.py +++ b/keras_nlp/version_utils.py @@ -15,7 +15,7 @@ from keras_nlp.api_export import keras_nlp_export # Unique source of truth for the version number. -__version__ = "0.8.0" +__version__ = "0.9.0" @keras_nlp_export("keras_nlp.version") From 3db86d147e6cb0fc7293626c8fc77bc9edab77ed Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Wed, 28 Feb 2024 11:39:51 -0800 Subject: [PATCH 17/70] Pin to TF 2.16 RC0 (#1478) --- requirements-jax-cuda.txt | 4 ++-- requirements-tensorflow-cuda.txt | 4 ++-- requirements-torch-cuda.txt | 4 ++-- requirements.txt | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 903a603352..10d07dffce 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release +tensorflow-text==2.16.0rc0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index be95915996..7cc2e705e6 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. -tf-nightly[and-cuda]==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow[and-cuda]==2.16.0rc0 # Pin to rc until TF 2.16 release +tensorflow-text==2.16.0rc0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 7ea2981478..1bbe6a2e76 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release +tensorflow-text==2.16.0rc0 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index b226229d15..e7cc934b17 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. 
+tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release +tensorflow-text==2.16.0rc0 # Torch. --extra-index-url https://download.pytorch.org/whl/cpu From 414b4f4c2025db65bbc3d9aa1363c9a4094a9dde Mon Sep 17 00:00:00 2001 From: Chris Sauer Date: Wed, 28 Feb 2024 11:40:39 -0800 Subject: [PATCH 18/70] Fix gemma rms_normalization's use of epsilon (#1472) Hi wonderful Keras folks, I was browsing the new Gemma source and noticed that the RMSNorm code didn't use the epsilon parameter it takes in. This fixes that. While we're here, I'm curious what drove the 1+scale multiplier (instead of just initializing scale to 1). Would love to learn if you're down to share. Thanks, Chris (ex-Googler) --- keras_nlp/models/gemma/rms_normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/gemma/rms_normalization.py b/keras_nlp/models/gemma/rms_normalization.py index ce9bdaf880..c3e4296020 100644 --- a/keras_nlp/models/gemma/rms_normalization.py +++ b/keras_nlp/models/gemma/rms_normalization.py @@ -35,6 +35,6 @@ def call(self, x): x = ops.cast(x, "float32") scale = ops.cast(self.scale, "float32") var = ops.mean(ops.square(x), axis=-1, keepdims=True) - normed_inputs = x * ops.reciprocal(ops.sqrt(var + 1e-06)) + normed_inputs = x * ops.reciprocal(ops.sqrt(var + self.epsilon)) normed_inputs = normed_inputs * (1 + scale) return ops.cast(normed_inputs, self.compute_dtype) From 8590c22dddd06b8dd02483410d751f599698a250 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Thu, 29 Feb 2024 17:17:34 -0800 Subject: [PATCH 19/70] Add `FalconBackbone` (#1475) * Add Falcon backbone. * Add docstring. * Add dtype. * Add checkpoint conversion script. * Fix tests. * Random fixes. * Add cache. * Cast cumsum to int32. * Make sublayers public. * Address backbone comments. * Update attention computation to use einsum. * Falcon only works with Keras3. * Fix tests. * Remove falcon_causal_lm file. * Remove commented/unused codes. --- keras_nlp/models/falcon/__init__.py | 13 + keras_nlp/models/falcon/falcon_attention.py | 156 +++++++++++ keras_nlp/models/falcon/falcon_backbone.py | 160 +++++++++++ .../models/falcon/falcon_backbone_test.py | 49 ++++ .../falcon/falcon_transformer_decoder.py | 254 ++++++++++++++++++ .../convert_falcon_checkpoints.py | 238 ++++++++++++++++ 6 files changed, 870 insertions(+) create mode 100644 keras_nlp/models/falcon/__init__.py create mode 100644 keras_nlp/models/falcon/falcon_attention.py create mode 100644 keras_nlp/models/falcon/falcon_backbone.py create mode 100644 keras_nlp/models/falcon/falcon_backbone_test.py create mode 100644 keras_nlp/models/falcon/falcon_transformer_decoder.py create mode 100644 tools/checkpoint_conversion/convert_falcon_checkpoints.py diff --git a/keras_nlp/models/falcon/__init__.py b/keras_nlp/models/falcon/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/models/falcon/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/keras_nlp/models/falcon/falcon_attention.py b/keras_nlp/models/falcon/falcon_attention.py new file mode 100644 index 0000000000..0358ade54b --- /dev/null +++ b/keras_nlp/models/falcon/falcon_attention.py @@ -0,0 +1,156 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class FalconAttention(keras.layers.Layer): + def __init__( + self, + num_heads, + attention_dropout_rate, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.attention_dropout_rate = attention_dropout_rate + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # m = model dim + # n = num attention heads + # h = head dim + # k = key/value length + + batch_size, seq_length, hidden_dim = inputs_shape + + self.head_dim = hidden_dim // self.num_heads + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.query_dense = keras.layers.EinsumDense( + equation="bqm,mnh->bqnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="query_dense", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="key_dense", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="value_dense", + ) + self.value_dense.build(inputs_shape) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.output_dense = keras.layers.Dense( + hidden_dim, + dtype=self.dtype_policy, + name="output_dense", + ) + self.output_dense.build(inputs_shape) + + self.softmax = keras.layers.Softmax(dtype="float32", name="softmax") + + self.built = True + + def call( + self, + inputs, + alibi, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + batch_size, seq_length, hidden_dim = ops.shape(inputs) + + query = self.query_dense(inputs) + key = self.key_dense(inputs) + value = self.value_dense(inputs) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. 
Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
+
+        attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
+        attention_scores = ops.add(attention_scores, alibi)
+        attention_scores = (
+            attention_scores * self.inv_norm_factor
+        )  # [batch_size, num_heads, query_length, kv_length]
+        attention_scores = self.softmax(
+            attention_scores, ops.expand_dims(attention_mask, 1)
+        )
+        attention_scores = self.attention_dropout(attention_scores)
+        attention_output = ops.einsum(
+            "bnqk,bknh->bqnh", attention_scores, value
+        )
+        attention_output = ops.reshape(
+            attention_output,
+            [batch_size, seq_length, self.num_heads * self.head_dim],
+        )  # [batch_size, query_length, hidden_dim]
+
+        attention_output = self.output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "attention_dropout_rate": self.attention_dropout_rate,
+            }
+        )
+        return config
diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py
new file mode 100644
index 0000000000..4951189fe0
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_backbone.py
@@ -0,0 +1,160 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.falcon.falcon_transformer_decoder import (
+    FalconTransformerDecoder,
+)
+
+
+@keras_nlp_export("keras_nlp.models.FalconBackbone")
+class FalconBackbone(Backbone):
+    """The Falcon core architecture.
+
+    This network implements a Transformer-based decoder-only network,
+    [Falcon](https://arxiv.org/abs/2306.01116).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_attention_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The dimensionality of the embeddings and hidden states.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the MLP network of each transformer.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers in
+            the transformer decoder.
+        attention_dropout_rate: float. Dropout probability for the attention.
+        feedforward_dropout_rate: float. Dropout probability for the feedforward.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+ + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Falcon decoder. + # TODO: Update the preset. + model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset") + model(input_data) + + # Randomly initialized Falcon decoder with a custom config. + model = keras_nlp.models.FalconBackbone( + vocabulary_size=10, + num_layers=2, + num_attention_heads=2, + hidden_dim=32, + intermediate_dim=32*4, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype="float32", + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_attention_heads, + hidden_dim, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + dtype=dtype, + name="token_embedding", + ) + + self.transformer_layers = [] + for i in range(num_layers): + layer = FalconTransformerDecoder( + num_attention_heads=num_attention_heads, + intermediate_dim=intermediate_dim, + attention_dropout_rate=attention_dropout_rate, + feedforward_dropout_rate=feedforward_dropout_rate, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + self.final_layernorm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_layernorm", + ) + + # === Functional Model === + token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed Tokens. + x = self.token_embedding(token_ids) + + # Apply successive transformer decoder blocks. + for transformer_layer in self.transformer_layers: + x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask) + sequence_output = self.final_layernorm(x) + + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_nlp/models/falcon/falcon_backbone_test.py b/keras_nlp/models/falcon/falcon_backbone_test.py new file mode 100644 index 0000000000..140ce7e7bf --- /dev/null +++ b/keras_nlp/models/falcon/falcon_backbone_test.py @@ -0,0 +1,49 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.tests.test_case import TestCase + + +class FalconBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_attention_heads": 8, + "hidden_dim": 16, + "intermediate_dim": 32, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_transformer_decoder.py b/keras_nlp/models/falcon/falcon_transformer_decoder.py new file mode 100644 index 0000000000..3b29cedd7e --- /dev/null +++ b/keras_nlp/models/falcon/falcon_transformer_decoder.py @@ -0,0 +1,254 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.falcon.falcon_attention import FalconAttention + + +class FalconTransformerDecoder(keras.layers.Layer): + def __init__( + self, + num_attention_heads, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_attention_heads = num_attention_heads + self.intermediate_dim = intermediate_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + + def build(self, decoder_sequence_shape): + self.hidden_dim = decoder_sequence_shape[-1] + self.input_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="input_layernorm", + ) + self.input_layernorm.build(decoder_sequence_shape) + + # Attention layers. 
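+        # Falcon-RW uses ALiBi rather than positional embeddings:
+        # `_build_alibi_tensor` below gives each head a fixed slope, and the
+        # resulting per-position bias is added to the attention scores inside
+        # `FalconAttention` before the softmax.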
+ self.key_dim = self.hidden_dim // self.num_attention_heads + self.attention_layer = FalconAttention( + num_heads=self.num_attention_heads, + attention_dropout_rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention", + ) + self.attention_layer.build( + decoder_sequence_shape, + ) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.post_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(decoder_sequence_shape) + + # Feedforward layers. + # TODO: use_bias should be an argument to the transformer to support + # other sizes of models, e.g. 7B, that don't use bias. + self.dense_h_to_4h = keras.layers.Dense( + self.intermediate_dim, + activation=keras.activations.gelu, + use_bias=True, + dtype=self.dtype_policy, + name="dense_h_to_4h", + ) + self.dense_h_to_4h.build(decoder_sequence_shape) + + self.dense_4h_to_h = keras.layers.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype_policy, + name="dense_4h_to_h", + ) + self.dense_4h_to_h.build( + ( + decoder_sequence_shape[0], + decoder_sequence_shape[1], + self.intermediate_dim, + ) + ) + + self.feedforward_dropout = keras.layers.Dropout( + rate=self.feedforward_dropout_rate, + dtype=self.dtype_policy, + name="feedforward_dropout", + ) + + self.built = True + + def call( + self, + inputs, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + training=None, + ): + attention_mask = self._compute_attention_mask( + decoder_sequence=inputs, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + + residual = inputs + + x = self.input_layernorm(inputs) + + alibi = self._build_alibi_tensor( + self.num_attention_heads, decoder_padding_mask + ) + + # Attention block. 
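+        # When `attention_cache` is passed, `FalconAttention` returns an
+        # `(output, cache)` tuple (keys and values stacked along axis 1);
+        # otherwise it returns the output tensor alone, so both cases are
+        # unpacked below.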
+ attention_output = self.attention_layer( + inputs=x, + alibi=alibi, + attention_mask=attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + + if attention_cache is None: + x = attention_output + else: + x, attention_cache = attention_output + + x = self.attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self.post_attention_layernorm(x) + + x = self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + + x = self.feedforward_dropout(x, training=training) + + x = x + residual + + if attention_cache is not None: + return x, attention_cache + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_attention_heads": self.num_attention_heads, + "intermediate_dim": self.intermediate_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + } + ) + return config + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def _compute_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + attention_cache=None, + attention_cache_update_index=None, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + causal_mask = compute_causal_mask( + batch_size, + input_length, + output_length, + ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ), + ) + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def _build_alibi_tensor(self, num_heads, attention_mask): + batch_size, seq_length = attention_mask.shape + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), + dtype=self.compute_dtype, + ) # num_heads + arange_tensor = ( + ( + ops.cast(ops.cumsum(attention_mask, axis=-1) - 1, dtype="int32") + * attention_mask + ) + )[:, None, :] + alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype) + alibi = ops.expand_dims( + alibi, 0 + ) # [None, batch_size, num_heads, seq_length] + return ops.transpose(alibi, [1, 2, 0, 3]) + + def _get_slopes(self, num_heads): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + ) diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py new file mode 100644 index 0000000000..90a06503dc --- /dev/null +++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py @@ -0,0 +1,238 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile + +import keras +import numpy as np +import tensorflow as tf +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone + +keras.config.disable_traceback_filtering() + + +def convert_checkpoints(hf_model): + hf_config = hf_model.config.to_dict() + cfg = {} + cfg["vocabulary_size"] = hf_config["vocab_size"] + cfg["num_layers"] = hf_config["num_hidden_layers"] + cfg["num_attention_heads"] = hf_config["num_attention_heads"] + cfg["hidden_dim"] = hf_config["hidden_size"] + cfg["intermediate_dim"] = 4 * cfg["hidden_dim"] + cfg["feedforward_dropout_rate"] = hf_config["hidden_dropout"] + cfg["attention_dropout_rate"] = hf_config["attention_dropout"] + + keras_model = FalconBackbone(**cfg) + + hf_wts = hf_model.state_dict() + + # transformer.word_embeddings.weight + keras_model.get_layer("token_embedding").embeddings.assign( + hf_wts["transformer.word_embeddings.weight"] + ) + + for i in range(keras_model.num_layers): + # split key query value + fused_qkv = ( + hf_wts[f"transformer.h.{i}.self_attention.query_key_value.weight"] + .numpy() + .T + ) + seq_length, _ = fused_qkv.shape + head_dim = cfg["hidden_dim"] // cfg["num_attention_heads"] + fused_qkv = fused_qkv.reshape( + seq_length, cfg["num_attention_heads"], 3, head_dim + ) + query, key, value = ( + fused_qkv[..., 0, :], + fused_qkv[..., 1, :], + fused_qkv[..., 2, :], + ) + + fused_bias = hf_wts[ + f"transformer.h.{i}.self_attention.query_key_value.bias" + ].numpy() + fused_bias = fused_bias.reshape(cfg["num_attention_heads"], 3, head_dim) + query_bias, key_bias, value_bias = ( + fused_bias[..., 0, :], + fused_bias[..., 1, :], + fused_bias[..., 2, :], + ) + + # TODO: check if bias is true before assigning bias. 
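+        # The HF checkpoint fuses the query, key and value projections into
+        # a single [hidden_dim, 3 * hidden_dim] kernel (after the transpose
+        # above); the reshape unpacked it into per-head (query, key, value)
+        # slices that are assigned to the separate EinsumDense layers below.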
+ # transformer.h.0.self_attention.query_key_value.weight + # transformer.h.0.self_attention.query_key_value.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.query_dense.kernel.assign(query) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.query_dense.bias.assign(query_bias) + + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.key_dense.kernel.assign(key) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.key_dense.bias.assign(key_bias) + + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.value_dense.kernel.assign(value) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.value_dense.bias.assign(value_bias) + + # transformer.h.0.self_attention.dense.weight + # transformer.h.0.self_attention.dense.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.output_dense.kernel.assign( + hf_wts[f"transformer.h.{i}.self_attention.dense.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.output_dense.bias.assign( + hf_wts[f"transformer.h.{i}.self_attention.dense.bias"].numpy() + ) + + # transformer.h.0.mlp.dense_h_to_4h.weight + # transformer.h.0.mlp.dense_h_to_4h.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_h_to_4h.kernel.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_h_to_4h.bias.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"].numpy() + ) + + # transformer.h.0.mlp.dense_4h_to_h.weight + # transformer.h.0.mlp.dense_4h_to_h.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_4h_to_h.kernel.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_4h_to_h.bias.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"].numpy() + ) + + # transformer.h.0.input_layernorm.weight + # transformer.h.0.input_layernorm.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).input_layernorm.gamma.assign( + hf_wts[f"transformer.h.{i}.input_layernorm.weight"] + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).input_layernorm.beta.assign( + hf_wts[f"transformer.h.{i}.input_layernorm.bias"] + ) + + # transformer.h.0.post_attention_layernorm.weight + # transformer.h.0.post_attention_layernorm.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).post_attention_layernorm.gamma.assign( + hf_wts[f"transformer.h.{i}.post_attention_layernorm.weight"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).post_attention_layernorm.beta.assign( + hf_wts[f"transformer.h.{i}.post_attention_layernorm.bias"].numpy() + ) + + # transformer.ln_f.weight + # transformer.ln_f.bias + keras_model.get_layer("final_layernorm").gamma.assign( + hf_wts["transformer.ln_f.weight"].numpy() + ) + keras_model.get_layer("final_layernorm").beta.assign( + hf_wts["transformer.ln_f.bias"].numpy() + ) + + # TODO: Assign lm_head weights for CausalLM. + # # lm_head.weight + # keras_model.get_layer("lm_head").kernel.assign( + # hf_wts["lm_head.weight"].T.numpy() + # ) + + # Save the model. 
+ print("Save KerasNLP model weights.") + temp_dir = tempfile.mkdtemp() + keras_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + + return keras_model + + +def check_output(keras_model, hf_model, hf_model_name): + sample_text = ["I am so happy today!"] + hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + hf_tokenizer.pad_token = hf_tokenizer.eos_token + hf_sample_input = hf_tokenizer( + sample_text, padding="max_length", return_tensors="pt" + ) + sample_input = { + "token_ids": tf.constant(hf_sample_input["input_ids"].numpy()), + "padding_mask": tf.constant(hf_sample_input["attention_mask"].numpy()), + } + print("token_ids: ", sample_input["token_ids"][0, :7]) + print("padding_mask", sample_input["padding_mask"][0, :7]) + + keras_output = keras_model.predict(sample_input) + + activation = {} + + def get_activation(name): + def hook(hf_model, input, output): + activation[name] = output[0].detach() + + return hook + + hf_model.transformer.register_forward_hook( + get_activation("transformer.ln_f") + ) + hf_model(**hf_sample_input) + hf_output = activation["transformer.ln_f"] + print("Keras shape: ", keras_output.shape) + print("HF shape: ", hf_output.shape) + + print("KerasNLP output:", keras_output[0, 1, :5]) + print("HF output:", hf_output[0, 1, :5]) + print( + "Difference:", + np.mean( + abs(keras_output[:, :6, :] - hf_output.detach().numpy()[:, :6, :]) + ), + ) + + +def main(): + hf_model_name = "tiiuae/falcon-rw-1b" + hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name) + keras_model = convert_checkpoints(hf_model) + check_output(keras_model, hf_model, hf_model_name) + + +if __name__ == "__main__": + main() From c739f81111af639440a59eec20d85ae653367baf Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:25:20 -0800 Subject: [PATCH 20/70] CI - Add kaggle creds to pull model (#1459) * CI - Add kaggle creds to pull model * add kaggle env variables * Kaggle env: * Kaggle env: * Kaggle env: * Kaggle env: * Update Build script for Kokoro * Add Kaggle env var * set gemma preset to extra_large * Change Gemma small preset to bfloat16 * Change Gemma small preset to xlarge --- .kokoro/github/ubuntu/gpu/build.sh | 14 ++++++++++++++ .kokoro/github/ubuntu/gpu/jax/continuous.cfg | 18 ++++++++++++++++++ .kokoro/github/ubuntu/gpu/jax/presubmit.cfg | 18 ++++++++++++++++++ .../github/ubuntu/gpu/keras2/continuous.cfg | 18 ++++++++++++++++++ .kokoro/github/ubuntu/gpu/keras2/presubmit.cfg | 18 ++++++++++++++++++ .../ubuntu/gpu/tensorflow/continuous.cfg | 18 ++++++++++++++++++ .../github/ubuntu/gpu/tensorflow/presubmit.cfg | 18 ++++++++++++++++++ .kokoro/github/ubuntu/gpu/torch/continuous.cfg | 18 ++++++++++++++++++ .kokoro/github/ubuntu/gpu/torch/presubmit.cfg | 18 ++++++++++++++++++ keras_nlp/models/gemma/gemma_backbone_test.py | 3 ++- 10 files changed, 160 insertions(+), 1 deletion(-) diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 2017b77c82..87cd206495 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -1,4 +1,18 @@ set -e + +export KAGGLE_KEY="$(cat ${KOKORO_KEYSTORE_DIR}/73361_keras_kaggle_secret_key)" +export KAGGLE_USERNAME="$(cat ${KOKORO_KEYSTORE_DIR}/73361_keras_kaggle_username)" + +if [[ -z "${KAGGLE_KEY}" ]]; then + echo "KAGGLE_KEY is NOT set" + exit 1 +fi + +if [[ -z "${KAGGLE_USERNAME}" ]]; then + echo "KAGGLE_USERNAME is NOT set" + exit 1 +fi + set -x cd "${KOKORO_ROOT}/" diff --git 
a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg index 1b9ffb605a..63351a9f40 100644 --- a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "jax" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg index 1b9ffb605a..63351a9f40 100644 --- a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "jax" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg index 7e971ac96d..03fa92222a 100644 --- a/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "1" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg index 7e971ac96d..03fa92222a 100644 --- a/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "1" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg index b85ee6f4eb..c707593fcb 100644 --- a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "tensorflow" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg index b85ee6f4eb..c707593fcb 100644 --- a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg +++ 
b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "tensorflow" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg index 5d25106b3f..345159e802 100644 --- a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "torch" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg index 5d25106b3f..345159e802 100644 --- a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "torch" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py index add2ac1995..855d49658b 100644 --- a/keras_nlp/models/gemma/gemma_backbone_test.py +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -54,8 +54,9 @@ def test_saved_model(self): ) @pytest.mark.kaggle_key_required - @pytest.mark.large + @pytest.mark.extra_large def test_smallest_preset(self): + # TODO: Fails with OOM on current GPU CI self.run_preset_test( cls=GemmaBackbone, preset="gemma_2b_en", From 134f8b788b2a6a1bb30c31d777cf257225754a42 Mon Sep 17 00:00:00 2001 From: TheCrazyT Date: Tue, 5 Mar 2024 00:20:06 +0100 Subject: [PATCH 21/70] Update reversible_embedding.py (#1484) --- keras_nlp/layers/modeling/reversible_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index baa5fb7027..d115217687 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -73,7 +73,7 @@ class ReversibleEmbedding(keras.layers.Embedding): # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`. hidden_states = embedding(token_ids) # Project hidden states to shape `(batch_size, seq_length, vocab_size)`. 
-    logits = embedding(hidden_state, reverse=True)
+    logits = embedding(hidden_states, reverse=True)
     ```
 
     References:
From c1b6b549af0fad3813aad42e1c8ea5629356f086 Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 5 Mar 2024 09:33:50 -0800
Subject: [PATCH 22/70] doc fix for contrastive sampler (#1488)

Fixes https://github.com/keras-team/keras-nlp/issues/1481
---
 keras_nlp/samplers/contrastive_sampler.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py
index 4259167c8c..36d10690d7 100644
--- a/keras_nlp/samplers/contrastive_sampler.py
+++ b/keras_nlp/samplers/contrastive_sampler.py
@@ -34,7 +34,6 @@ class ContrastiveSampler(Sampler):
         alpha: float, the weight of minus max similarity in joint score
             computation. The larger the value of `alpha`, the score relies
             more on the similarity than the token probability.
-        seed: int. The random seed. Defaults to `None`.
 
     Call arguments:
         {{call_args}}
From f3eda3cd2df1c7f160e6c4c0ea1c3f99907b2dbd Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Tue, 5 Mar 2024 09:34:11 -0800
Subject: [PATCH 23/70] Remove broken link to masking and padding guide
 (#1487)

Fixes https://github.com/keras-team/keras-nlp/issues/1446
---
 keras_nlp/layers/modeling/transformer_decoder.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py
index 15c245768c..d06a1948f5 100644
--- a/keras_nlp/layers/modeling/transformer_decoder.py
+++ b/keras_nlp/layers/modeling/transformer_decoder.py
@@ -34,12 +34,9 @@ class TransformerDecoder(keras.layers.Layer):
     paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users
     can instantiate multiple instances of this class to stack up a decoder.
 
-    By default, this layer will apply a causal mask to the decoder attention layer.
-    This layer will correctly compute an attention mask from an implicit
-    Keras padding mask (for example, by passing `mask_zero=True` to a
-    `keras.layers.Embedding` layer). See the Masking and Padding
-    [guide](https://keras.io/guides/understanding_masking_and_padding/)
-    for more details.
+    By default, this layer will apply a causal mask to the decoder attention
+    layer. You can also pass padding or attention masks directly to the layer
+    during call, e.g. with `decoder_padding_mask` or `decoder_attention_mask`.
 
     This layer can be called with either one or two inputs. The number of
     inputs must be consistent across all calls. The options are as follows:
From 7f692ca960dbf0f63a5f924a4c145d4a943fd58f Mon Sep 17 00:00:00 2001
From: Samaneh Saadat <samanehsaadat@google.com>
Date: Tue, 5 Mar 2024 15:29:15 -0800
Subject: [PATCH 24/70] Fix a typo.
(#1489) --- keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py | 4 ++-- keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py | 4 ++-- keras_nlp/models/generative_task.py | 2 +- keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py | 4 ++-- .../models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py | 4 ++-- keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py | 4 ++-- keras_nlp/models/opt/opt_causal_lm_preprocessor.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py index 01f3c88d30..60491ed931 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py index 20c66edff3..04a067be82 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -124,7 +124,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -152,7 +152,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 598217d964..99c447e8ef 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -137,7 +137,7 @@ def _normalize_generate_inputs( ): """Normalize user input to the generate function. - This function coverts all inputs to tensors, adds a batch dimension if + This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). """ diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 97d0b42d97..3278b18a4f 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. 
Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py index 92ff9bbb03..71b3d8ce04 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py @@ -99,7 +99,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -127,7 +127,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py index 893036cd58..624c37c9a1 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py index 1895854e41..0a9ab86b00 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py @@ -132,7 +132,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -160,7 +160,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. 
This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence From 8851624ec32c89e5e3c679e0f9be7f3c58a875b4 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Tue, 5 Mar 2024 19:14:57 -0800 Subject: [PATCH 25/70] Fix dtype accessors of tasks/backbones (#1486) * Fix dtype accessors of tasks/backbones * Address comments, minor fixes --- keras_nlp/models/albert/albert_backbone.py | 1 + keras_nlp/models/backbone.py | 11 ++++++++++- keras_nlp/models/bart/bart_backbone.py | 1 + keras_nlp/models/bert/bert_backbone.py | 1 + keras_nlp/models/bloom/bloom_backbone.py | 1 + keras_nlp/models/deberta_v3/deberta_v3_backbone.py | 1 + keras_nlp/models/distil_bert/distil_bert_backbone.py | 1 + keras_nlp/models/electra/electra_backbone.py | 1 + keras_nlp/models/f_net/f_net_backbone.py | 1 + keras_nlp/models/falcon/falcon_backbone.py | 1 + keras_nlp/models/gemma/gemma_backbone.py | 1 + keras_nlp/models/gpt2/gpt2_backbone.py | 1 + keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py | 1 + keras_nlp/models/llama/llama_backbone.py | 1 + keras_nlp/models/mistral/mistral_backbone.py | 1 + keras_nlp/models/opt/opt_backbone.py | 1 + keras_nlp/models/roberta/roberta_backbone.py | 1 + keras_nlp/models/t5/t5_backbone.py | 1 + keras_nlp/models/task.py | 10 ++++++++-- keras_nlp/models/whisper/whisper_backbone.py | 1 + keras_nlp/models/xlnet/xlnet_backbone.py | 1 + keras_nlp/tests/test_case.py | 8 ++++---- 22 files changed, 41 insertions(+), 7 deletions(-) diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 1e342e791c..09053ff893 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -230,6 +230,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 9c8cdaa60e..867616da69 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -28,6 +28,15 @@ def __init__(self, *args, dtype=None, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True + if dtype is not None: + # Keras 2 and Keras 3 handle setting policy differently. + if config.keras_3(): + if isinstance(dtype, keras.DTypePolicy): + self.dtype_policy = dtype + else: + self.dtype_policy = keras.DTypePolicy(dtype) + else: + self._set_dtype_policy(dtype) def __dir__(self): if config.keras_3(): @@ -67,7 +76,7 @@ def token_embedding(self): This layer embeds integer token ids to the hidden dim of the model. 
""" - return self._token_embedding + return getattr(self, "_token_embedding", None) @token_embedding.setter def token_embedding(self, value): diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 803d5a2a9f..f100133d25 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -232,6 +232,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 2248260da7..320dc1c2ee 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -196,6 +196,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 5737dcc889..5c6f81ca5b 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -149,6 +149,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index e7bd8ca20a..9063b11df5 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -178,6 +178,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index 1ae0840ea8..73634b4216 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -159,6 +159,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 13be2d8eb8..f4f2a23b69 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -202,6 +202,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index 309f312a17..ab056c84c7 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -206,6 +206,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py index 4951189fe0..5a3a0fccda 100644 --- a/keras_nlp/models/falcon/falcon_backbone.py +++ b/keras_nlp/models/falcon/falcon_backbone.py @@ -130,6 +130,7 @@ def __init__( "padding_mask": padding_mask, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py index e5814940aa..c829aa948f 100644 --- a/keras_nlp/models/gemma/gemma_backbone.py +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -157,6 +157,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index d93b2199b0..b7d2b10acf 100644 --- 
a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -170,6 +170,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py index 1955ed5801..415fa56af2 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py @@ -137,6 +137,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index cc628ad7a5..733d9ef434 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -127,6 +127,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py index 3e2cfae148..52de945760 100644 --- a/keras_nlp/models/mistral/mistral_backbone.py +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -166,6 +166,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index 0b98a6c64e..16fe4a0218 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -146,6 +146,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 1ab61eeeb7..09fe753762 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -156,6 +156,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index cf747c503c..862c4766f4 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -224,6 +224,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 783cc0b41b..0656d2194e 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -36,6 +36,12 @@ def __init__(self, *args, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True + if self.backbone is not None: + # Keras 2 and Keras 3 handle setting policy differently. 
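+            # (Keras 3 exposes a public, settable `dtype_policy` property,
+            # while Keras 2 only provides the private `_set_dtype_policy`
+            # helper, hence the branch below.)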
+ if config.keras_3(): + self.dtype_policy = self._backbone.dtype_policy + else: + self._set_dtype_policy(self._backbone.dtype_policy) def __dir__(self): if config.keras_3(): @@ -128,7 +134,7 @@ def __setattr__(self, name, value): @property def backbone(self): """A `keras.Model` instance providing the backbone sub-model.""" - return self._backbone + return getattr(self, "_backbone", None) @backbone.setter def backbone(self, value): @@ -137,7 +143,7 @@ def backbone(self, value): @property def preprocessor(self): """A `keras.layers.Layer` instance used to preprocess inputs.""" - return self._preprocessor + return getattr(self, "_preprocessor", None) @preprocessor.setter def preprocessor(self, value): diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index c66a61d4e5..a2b685544e 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -274,6 +274,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 0d660bead9..45be1f74e7 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -184,6 +184,7 @@ def __init__( "segment_ids": segment_id_input, }, outputs=output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 0541ae6451..6b88757c64 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -329,10 +329,10 @@ def run_precision_test(self, cls, init_kwargs, input_data): for weight in layer.weights: if is_float_dtype(weight.dtype): self.assertDTypeEqual(weight, policy.variable_dtype) - for sublayer in layer._flatten_layers(include_self=False): - if isinstance( - sublayer, (keras.layers.Softmax, keras.layers.InputLayer) - ): + for sublayer in layer._flatten_layers(): + if isinstance(sublayer, keras.layers.Softmax): + continue + if isinstance(sublayer, keras.layers.InputLayer): continue self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, sublayer.variable_dtype) From f92d4f896480acfd1c71997fba45f580595e0c2c Mon Sep 17 00:00:00 2001 From: Shivam Mishra <124146945+shmishra99@users.noreply.github.com> Date: Thu, 7 Mar 2024 00:58:30 +0530 Subject: [PATCH 26/70] Auto-labels 'gemma' on 'gemma' issues/PRs. (#1490) * Auto-labels 'gemma' on 'gemma' issues/PRs. * make labeling logic generic. * Update labeler.yaml with formatting fixes * Update labeler.js with formatting fixes * Update script reference. --------- Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com> --- .github/workflows/labeler.yaml | 42 ++++++++++++++++++++++++ .github/workflows/scripts/labeler.js | 49 ++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+) create mode 100644 .github/workflows/labeler.yaml create mode 100644 .github/workflows/scripts/labeler.js diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml new file mode 100644 index 0000000000..6832d9acb0 --- /dev/null +++ b/.github/workflows/labeler.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This workflow automatically identifies issues and pull requests (PRs)
+# related to Gemma. It searches for the keyword "Gemma" (case-insensitive)
+# in both the title and description of the issue/PR. If a match is found,
+# the workflow adds the label 'Gemma' to the issue/PR.
+
+name: 'Labeler'
+on:
+  issues:
+    types: [edited, opened]
+  pull_request_target:
+    types: [opened, edited]
+
+permissions:
+  contents: read
+  issues: write
+  pull-requests: write
+
+jobs:
+  welcome:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/github-script@v7
+        with:
+          script: |
+            const script = require('./\.github/workflows/scripts/labeler.js')
+            script({github, context})
diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js
new file mode 100644
index 0000000000..7240113cc3
--- /dev/null
+++ b/.github/workflows/scripts/labeler.js
@@ -0,0 +1,49 @@
+/*
+Copyright 2024 Google LLC. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+
+/**
+ * Invoked from the labeler.yaml file to add the label 'Gemma' to issues
+ * and PRs that have the gemma keyword present in the title or description.
+ * @param {!Object.} github contains predefined functions.
+ *                  context Information about the workflow run.
+ */
+
+module.exports = async ({ github, context }) => {
+    const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title
+    const issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body
+    const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number
+    const keyword_label = {
+        gemma: 'Gemma'
+    }
+    const labelsToAdd = []
+    console.log(issue_title, issue_description, issue_number)
+
+    for (const [keyword, label] of Object.entries(keyword_label)) {
+        if (issue_title.toLowerCase().indexOf(keyword) !== -1 || issue_description.toLowerCase().indexOf(keyword) !== -1) {
+            console.log(`'${keyword}' keyword is present inside the title or description.
+            Pushing label '${label}' to the list.`)
+            labelsToAdd.push(label)
+        }
+    }
+    if (labelsToAdd.length > 0) {
+        console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`)
+        github.rest.issues.addLabels({
+            owner: context.repo.owner,
+            repo: context.repo.repo,
+            issue_number: context.issue.number,
+            labels: labelsToAdd
+        })
+    }
+};

From 3cacebde615ca411af7cd95b7f32a70504321ae0 Mon Sep 17 00:00:00 2001
From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com>
Date: Wed, 6 Mar 2024 22:51:18 +0200
Subject: [PATCH 27/70] Add BloomCausalLM (#1467)

* Initial commit for BloomCausalLM

* Avoid adding a start token into token ids

* Revert "Avoid adding a start token into token ids"

This reverts commit 57ce4c203b40bdfc8c8361987df25349029cecfc.

* Tie embedding weights

* Export BloomCausalLM

* Add bloomz to the preset map

* Fix some preset names

* Revert "Fix some preset names"

This reverts commit 1b5949ffae58fef20b2a931638a2813230b4f96a.

* Add doc example

* Format the code

* Alibi bias small fixes

* Add tests

* Make float16 dtype test Keras 3 only

* Update model version

* Edit examples

* Remove max_sequence_length argument

* Try to fix keras2 error

* Update hf_model download function

* Make 1b models easier to copy

* Optimize conversion script

* Save checkpoints in float16

* Add test for mixed_float16

* Try to reproduce the keras2 error

* Revert "Try to reproduce the keras2 error"

This reverts commit 7cc8671097277c3c2503c0fe6d98facabe87ecf6.

* Revert "Make 1b models easier to copy"

This reverts commit 4d701d3331dd1c8db079b2e1f1fd2a974a72f244.

* Show how to couple dtype_policy between backbone, CausalLM, and backbone layers when the dtype arg is passed to the backbone

* Revert "Show how to couple dtype_policy between backbone, CausalLM, and backbone layers when the dtype arg is passed to the backbone"

This reverts commit 77c576bf2002d4c3889bcc83444b870edcad3e7e.
* Add validate_only flag to conversion script * Change preset version * set cache dtype to self.compute_dtype * Minor fix --- keras_nlp/layers/modeling/alibi_bias.py | 4 +- keras_nlp/layers/modeling/alibi_bias_test.py | 4 - keras_nlp/models/__init__.py | 5 + keras_nlp/models/bloom/bloom_backbone.py | 7 - keras_nlp/models/bloom/bloom_backbone_test.py | 1 - keras_nlp/models/bloom/bloom_causal_lm.py | 318 ++++++++++++++++++ .../bloom/bloom_causal_lm_preprocessor.py | 4 +- .../models/bloom/bloom_causal_lm_test.py | 188 +++++++++++ keras_nlp/models/bloom/bloom_decoder.py | 34 +- keras_nlp/models/bloom/bloom_preprocessor.py | 4 +- keras_nlp/models/bloom/bloom_presets.py | 2 +- keras_nlp/models/bloom/bloom_tokenizer.py | 18 +- .../convert_bloom_checkpoints.py | 222 ++++++++---- 13 files changed, 693 insertions(+), 118 deletions(-) create mode 100644 keras_nlp/models/bloom/bloom_causal_lm.py create mode 100644 keras_nlp/models/bloom/bloom_causal_lm_test.py diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 8a66ad05af..fdc956ae15 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -94,7 +94,9 @@ def _get_alibi_bias(self, num_heads, key_length): ) slopes = ops.expand_dims(slopes, 1) - seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) + seq_range = ops.expand_dims( + ops.arange(1 - key_length, 1, dtype="int32"), 0 + ) seq_range = ops.cast(seq_range, dtype=self.compute_dtype) alibi_bias = ops.multiply(slopes, seq_range) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 120c7622b1..69a48f29e1 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -99,7 +99,6 @@ def test_correct_output(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias() output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -127,7 +126,6 @@ def test_correct_output_num_heads_not_power_of_two(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias() output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -162,7 +160,6 @@ def test_correct_output_alibi_bias_max(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -187,7 +184,6 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( input_tensor = ops.zeros(input_shape) layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index cdd50670f3..692b51c4da 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -36,6 +36,11 @@ from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor from keras_nlp.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer from 
keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 5c6f81ca5b..9b7c65a399 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -53,8 +53,6 @@ class BloomBackbone(Backbone): dropout: float. Dropout probability for the Transformer decoder. layer_norm_epsilon: float. Epsilon for the layer normalization layers in the transformer decoder. - max_sequence_length: int. The maximum sequence length that this decoder - can consume. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at @@ -80,7 +78,6 @@ class BloomBackbone(Backbone): intermediate_dim=32*4, dropout=0.0, layer_norm_epsilon=1e-5, - max_sequence_length=128, ) model(input_data) ``` @@ -96,7 +93,6 @@ def __init__( intermediate_dim, dropout=0.0, layer_norm_epsilon=1e-5, - max_sequence_length=2048, dtype=None, **kwargs, ): @@ -105,7 +101,6 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=_bloom_kernel_initializer(stddev=0.02), - tie_weights=False, dtype=dtype, name="token_embedding", ) @@ -161,7 +156,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.dropout = dropout self.layer_norm_epsilon = layer_norm_epsilon - self.max_sequence_length = max_sequence_length def get_config(self): config = super().get_config() @@ -174,7 +168,6 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, "layer_norm_epsilon": self.layer_norm_epsilon, - "max_sequence_length": self.max_sequence_length, } ) return config diff --git a/keras_nlp/models/bloom/bloom_backbone_test.py b/keras_nlp/models/bloom/bloom_backbone_test.py index 83732e4945..47ff7ec4cc 100644 --- a/keras_nlp/models/bloom/bloom_backbone_test.py +++ b/keras_nlp/models/bloom/bloom_backbone_test.py @@ -27,7 +27,6 @@ def setUp(self): "num_heads": 4, "hidden_dim": 8, "intermediate_dim": 32, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py new file mode 100644 index 0000000000..31eae30c6b --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -0,0 +1,318 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.BloomCausalLM") +class BloomCausalLM(GenerativeTask): + """An end-to-end BLOOM model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a BLOOM model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_nlp.models.BloomBackbone` instance. + preprocessor: A `keras_nlp.models.BloomCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + bloom_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.compile(sampler="top_k") + bloom_lm.generate("I want to say", max_length=30) + + bloom_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) + bloom_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for " Keras is". + "token_ids": np.array([[1, 46, 15762, 632, 3, 3, 3, 3, 3]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0]] * 2), + } + + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset( + "bloom_560m_multi", + preprocessor=None, + ) + bloom_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. 
+ ```python + x = { + # Token ids for " Keras is deep learning library" + "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + } + y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset( + "bloom_560m_multi", + preprocessor=None, + ) + bloom_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary. + ```python + features = [ + " airplane at airport", + " airplane airport", + ] + vocab = ["", "", "", ""] + vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + merges += ["Ġai r", "Ġa i", "pla ne"] + tokenizer = keras_nlp.models.BloomTokenizer(vocabulary=vocab, merges=merges) + preprocessor = keras_nlp.models.BloomCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_nlp.models.BloomBackbone( + vocabulary_size=tokenizer.vocabulary_size(), + num_layers=4, + num_heads=4, + hidden_dim=32, + intermediate_dim=128, + ) + bloom_lm = keras_nlp.models.BloomCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + bloom_lm.fit(x=features, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + sampler="greedy", + jit_compile=True, + ) + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def backbone_cls(cls): + return BloomBackbone + + @classproperty + def preprocessor_cls(cls): + return BloomCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `BloomCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + x = self.backbone.embeddings_layer_norm(x) + # Each decoder layer has a cache; we update them separately. + caches = [] + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + current_cache = cache[:, i, ...] 
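+            # `current_cache` is this layer's (key, value) slice, shaped
+            # (batch_size, 2, max_length, num_heads, head_dim).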
+ x, next_cache = transformer_layer( + x, + cache=current_cache, + cache_update_index=cache_update_index, + ) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_heads = self.backbone.num_heads + head_dim = self.backbone.hidden_dim // num_heads + shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + end_token_id=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + end_token_id: The id of the end token to stop on. If all + sequences have produced a new `end_token_id`, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + end_token_id=end_token_id, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if end_token_id is not None: + # Build a mask of `end_token_id` locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = ops.logical_and( + ops.equal(token_ids, end_token_id), + ops.logical_not(padding_mask), + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. 
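+            # Every position holds either a prompt or a generated token,
+            # so the mask is simply all ones.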
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py index 60491ed931..b56e1a3ef0 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -175,8 +175,8 @@ def generate_postprocess( # end markers). In the future we could make this configurable. padding_mask = ( padding_mask - & (token_ids != self.tokenizer.eos_token_id) - & (token_ids != self.tokenizer.bos_token_id) + & (token_ids != self.tokenizer.start_token_id) + & (token_ids != self.tokenizer.end_token_id) ) token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/bloom/bloom_causal_lm_test.py b/keras_nlp/models/bloom/bloom_causal_lm_test.py new file mode 100644 index 0000000000..70af6a2302 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm_test.py @@ -0,0 +1,188 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.tests.test_case import TestCase + + +class BloomCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["", "", "", ""] + self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = BloomTokenizer( + vocabulary=self.vocab, merges=self.merges + ) + self.preprocessor = BloomCausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + self.backbone = BloomBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_heads=2, + hidden_dim=4, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ( + [ + " airplane at airport", + " airplane airport", + ], + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + vocabulary_size = self.tokenizer.vocabulary_size() + self.run_task_test( + cls=BloomCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, vocabulary_size), + ) + + def test_generate(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + # String input. 
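+        # With a preprocessor attached, `generate()` accepts raw strings.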
+ prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_bfloat16(self): + backbone = BloomBackbone.from_config( + {**self.backbone.get_config(), "dtype": "bfloat16"} + ) + causal_lm = BloomCausalLM( + backbone=backbone, preprocessor=self.preprocessor + ) + # String input. + prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_mixed_float16(self): + backbone = BloomBackbone.from_config( + {**self.backbone.get_config(), "dtype": "mixed_float16"} + ) + causal_lm = BloomCausalLM( + backbone=backbone, preprocessor=self.preprocessor + ) + # String input. + prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_early_stopping(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["airplane at", "airplane"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate("airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=BloomCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomCausalLM.presets: + self.run_preset_test( + cls=BloomCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index a0c62a2541..c6478e7270 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -110,8 +110,8 @@ def call( decoder_sequence, decoder_padding_mask=None, decoder_attention_mask=None, - attention_cache=None, - attention_cache_update_index=None, + cache=None, + cache_update_index=None, use_causal_mask=True, ): self_attention_mask = self._compute_attention_mask( @@ -119,8 +119,8 @@ def call( decoder_padding_mask=decoder_padding_mask, decoder_attention_mask=decoder_attention_mask, use_causal_mask=use_causal_mask, - attention_cache=attention_cache, - attention_cache_update_index=attention_cache_update_index, + cache=cache, + cache_update_index=cache_update_index, ) residual = decoder_sequence @@ -129,14 +129,14 @@ def call( attention_output = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, - cache=attention_cache, - cache_update_index=attention_cache_update_index, + cache=cache, + cache_update_index=cache_update_index, ) - if attention_cache is None: + if cache is None: x = attention_output else: - x, attention_cache = attention_output + x, cache = attention_output x = x + residual residual = x @@ -147,8 +147,8 @@ def call( x = self._dropout_layer(x) x = x + residual - if attention_cache is not None: - return x, attention_cache + if cache is not None: + return x, cache else: return x @@ -158,8 +158,8 @@ def _compute_attention_mask( decoder_padding_mask, decoder_attention_mask, use_causal_mask, - attention_cache, - attention_cache_update_index, + cache, + cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -167,18 +167,14 @@ def _compute_attention_mask( if use_causal_mask: batch_size = ops.shape(decoder_sequence)[0] input_length = output_length = ops.shape(decoder_sequence)[1] - if attention_cache is not None: - input_length = ops.shape(attention_cache)[2] + if cache is not None: + input_length = ops.shape(cache)[2] causal_mask = compute_causal_mask( batch_size, input_length, output_length, - ( - 0 - if attention_cache_update_index is None - else attention_cache_update_index - ), + (0 if cache_update_index is None else cache_update_index), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index 734c9f4bf8..8eb693cb50 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -126,8 +126,8 @@ def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer # assets have loaded when restoring a saved model. 
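        # The packer prepends the start token, appends the end token, pads
        # with the pad token to `sequence_length`, and returns a padding mask.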
self.packer = StartEndPacker( - start_value=self.tokenizer.bos_token_id, - end_value=self.tokenizer.eos_token_id, + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, pad_value=self.tokenizer.pad_token_id, sequence_length=self.sequence_length, return_padding_mask=True, diff --git a/keras_nlp/models/bloom/bloom_presets.py b/keras_nlp/models/bloom/bloom_presets.py index d3e9c780c0..7d24c04aa5 100644 --- a/keras_nlp/models/bloom/bloom_presets.py +++ b/keras_nlp/models/bloom/bloom_presets.py @@ -25,6 +25,6 @@ "path": "bloom", "model_card": "https://huggingface.co/bigscience/bloom", }, - "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/1", + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/3", }, } diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index cc3fcc2fc3..0d7f74b163 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -74,16 +74,16 @@ def __init__( merges=None, **kwargs, ): - self.bos_token = "" - self.eos_token = "" + self.start_token = "" + self.end_token = "" self.pad_token = "" super().__init__( vocabulary=vocabulary, merges=merges, unsplittable_tokens=[ - self.bos_token, - self.eos_token, + self.start_token, + self.end_token, self.pad_token, ], **kwargs, @@ -94,7 +94,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges): if vocabulary is not None: # Check for necessary special tokens. - for token in [self.bos_token, self.eos_token, self.pad_token]: + for token in [self.start_token, self.end_token, self.pad_token]: if token not in self.get_vocabulary(): raise ValueError( f"Cannot find token `'{token}'` in the provided " @@ -102,12 +102,12 @@ def set_vocabulary_and_merges(self, vocabulary, merges): "your `vocabulary` or use a pretrained `vocabulary` name." 
) - self.bos_token_id = self.token_to_id(self.bos_token) - self.eos_token_id = self.token_to_id(self.eos_token) + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) self.pad_token_id = self.token_to_id(self.pad_token) else: - self.bos_token_id = None - self.eos_token_id = None + self.start_token_id = None + self.end_token_id = None self.pad_token_id = None @classproperty diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py index 38acd099cf..d8a36b3912 100644 --- a/tools/checkpoint_conversion/convert_bloom_checkpoints.py +++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py @@ -15,19 +15,16 @@ import json import os -os.environ["KERAS_BACKEND"] = "torch" -os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import huggingface_hub +import numpy as np +import transformers +from absl import app +from absl import flags -import huggingface_hub # noqa: E402 -import numpy as np # noqa: E402 -import torch # noqa: E402 -import transformers # noqa: E402 -from absl import app # noqa: E402 -from absl import flags # noqa: E402 - -import keras_nlp # noqa: E402 -from keras_nlp.models import BloomBackbone # noqa: E402 -from keras_nlp.models import BloomTokenizer # noqa: E402 +import keras_nlp +from keras_nlp.models import BloomBackbone +from keras_nlp.models import BloomPreprocessor +from keras_nlp.models import BloomTokenizer FLAGS = flags.FLAGS @@ -37,23 +34,47 @@ "bloom_1.7b_multi": "bigscience/bloom-1b7", "bloom_3b_multi": "bigscience/bloom-3b", "bloom_7b_multi": "bigscience/bloom-7b1", - "bloom_176b_multi": "bigscience/bloom", + "bloom_multi": "bigscience/bloom", + # Multitask finetuned on xP3 (Crosslingual Public Pool of Prompts) https://huggingface.co/datasets/bigscience/xP3 + # xP3 is a mixture of 13 training tasks in 46 languages with English prompts + "bloomz_560m_multi": "bigscience/bloomz-560m", + "bloomz_1.1b_multi": "bigscience/bloomz-1b1", + "bloomz_1.7b_multi": "bigscience/bloomz-1b7", + "bloomz_3b_multi": "bigscience/bloomz-3b", + "bloomz_7b_multi": "bigscience/bloomz-7b1", + "bloomz_multi": "bigscience/bloomz", + # Multitask finetuned on xP3mt + # (Crosslingual Public Pool of Prompts machine-translated) https://huggingface.co/datasets/bigscience/xP3 + # xP3mt is Mixture of 13 training tasks in 46 languages with prompts in 20 + # languages (machine-translated from English) + "bloomz_7b_mt": "bigscience/bloomz-7b1-mt", + "bloomz_mt": "bigscience/bloomz-mt", + # Multitask finetuned on P3 (Public Pool of Prompts) https://huggingface.co/datasets/Muennighoff/P3 + # xP3 is a mixture of 8 training tasks with English-only prompts + "bloomz_7b_p3": "bigscience/bloomz-7b1-p3", + "bloomz_p3": "bigscience/bloomz-p3", } EXTRACT_DIR = "./model" flags.DEFINE_string( - "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' + "preset", None, f'Must be one of {", ".join(PRESET_MAP.keys())}' ) flags.mark_flag_as_required("preset") +flags.DEFINE_boolean( + "validate_only", + False, + "To validate the output of a preset that has been already uploaded. " + "No weights conversion will happen.", +) def download_hf_model(hf_model_name): hf_model_dir = huggingface_hub.snapshot_download( repo_id=hf_model_name, allow_patterns=["*.json", "*.bin"], - ignore_patterns=["onnx/*"], + ignore_patterns=["*/*"], local_dir=EXTRACT_DIR, ) @@ -99,51 +120,61 @@ def convert_weights(keras_model, hf_model): # assign huggingface weights to the keras model. # Embedding layer. 
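    # Torch tensors are detached to NumPy before assignment, so the weights
    # load regardless of which Keras backend is in use.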
keras_model.get_layer("token_embedding").embeddings.assign( - hf_wts["word_embeddings.weight"] + hf_wts["word_embeddings.weight"].detach().numpy() ) # LayerNorm. keras_model.get_layer("token_embedding_layernorm").gamma.assign( - hf_wts["word_embeddings_layernorm.weight"] + hf_wts["word_embeddings_layernorm.weight"].detach().numpy() ) keras_model.get_layer("token_embedding_layernorm").beta.assign( - hf_wts["word_embeddings_layernorm.bias"] + hf_wts["word_embeddings_layernorm.bias"].detach().numpy() ) - keras_model.get_layer("final_layernorm").gamma.assign(hf_wts["ln_f.weight"]) - keras_model.get_layer("final_layernorm").beta.assign(hf_wts["ln_f.bias"]) + keras_model.get_layer("final_layernorm").gamma.assign( + hf_wts["ln_f.weight"].detach().numpy() + ) + keras_model.get_layer("final_layernorm").beta.assign( + hf_wts["ln_f.bias"].detach().numpy() + ) # Decoder layers. for i in range(num_layers): decoder_layer = keras_model.get_layer(f"transformer_layer_{i}") # LayrNorm. decoder_layer._pre_attention_layernorm.gamma.assign( - hf_wts[f"h.{i}.input_layernorm.weight"] + hf_wts[f"h.{i}.input_layernorm.weight"].detach().numpy() ) decoder_layer._pre_attention_layernorm.beta.assign( - hf_wts[f"h.{i}.input_layernorm.bias"] + hf_wts[f"h.{i}.input_layernorm.bias"].detach().numpy() ) decoder_layer._post_attention_layernorm.gamma.assign( - hf_wts[f"h.{i}.post_attention_layernorm.weight"] + hf_wts[f"h.{i}.post_attention_layernorm.weight"].detach().numpy() ) decoder_layer._post_attention_layernorm.beta.assign( - hf_wts[f"h.{i}.post_attention_layernorm.bias"] + hf_wts[f"h.{i}.post_attention_layernorm.bias"].detach().numpy() ) # Attention layer. attention_layer = decoder_layer._self_attention_layer - fused_qkv_kernal = hf_wts[ - f"h.{i}.self_attention.query_key_value.weight" - ].T - fused_qkv_kernal = fused_qkv_kernal.view( + fused_qkv_kernal = ( + hf_wts[f"h.{i}.self_attention.query_key_value.weight"] + .T.detach() + .numpy() + ) + fused_qkv_kernal = fused_qkv_kernal.reshape( hidden_dim, num_heads, 3, head_dim ) query_kernal = fused_qkv_kernal[..., 0, :] key_kernal = fused_qkv_kernal[..., 1, :] value_kernl = fused_qkv_kernal[..., 2, :] - fused_qkv_bais = hf_wts[f"h.{i}.self_attention.query_key_value.bias"] - fused_qkv_bais = fused_qkv_bais.view(num_heads, 3, head_dim) + fused_qkv_bais = ( + hf_wts[f"h.{i}.self_attention.query_key_value.bias"] + .detach() + .numpy() + ) + fused_qkv_bais = fused_qkv_bais.reshape(num_heads, 3, head_dim) query_bais = fused_qkv_bais[:, 0, :] key_bais = fused_qkv_bais[:, 1, :] value_bais = fused_qkv_bais[:, 2, :] @@ -156,24 +187,24 @@ def convert_weights(keras_model, hf_model): attention_layer._value_dense.bias.assign(value_bais) attention_layer._output_dense.kernel.assign( - hf_wts[f"h.{i}.self_attention.dense.weight"].T + hf_wts[f"h.{i}.self_attention.dense.weight"].T.detach().numpy() ) attention_layer._output_dense.bias.assign( - hf_wts[f"h.{i}.self_attention.dense.bias"] + hf_wts[f"h.{i}.self_attention.dense.bias"].detach().numpy() ) # mlp. 
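        # HF stores dense kernels as (out_features, in_features); the `.T`
        # transposes them to the Keras (in, out) layout.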
decoder_layer._mlp_intermediate_dense.kernel.assign( - hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T + hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T.detach().numpy() ) decoder_layer._mlp_intermediate_dense.bias.assign( - hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"] + hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"].detach().numpy() ) decoder_layer._mlp_output_dense.kernel.assign( - hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T + hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T.detach().numpy() ) decoder_layer._mlp_output_dense.bias.assign( - hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"] + hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"].detach().numpy() ) @@ -185,20 +216,21 @@ def validate_output( ): input_str = ["the quick brown fox ran, galloped and jumped."] - # KerasNLP - token_ids = torch.tensor(keras_tokenizer(input_str)) - padding_mask = token_ids != 3 - keras_model_input = { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - keras_model_outputs = keras_model.predict(keras_model_input) - + # HuggingFace hf_model_input = hf_tokenizer(input_str, return_tensors="pt") - hf_model_outputs = hf_model(**hf_model_input).last_hidden_state hf_model_outputs = hf_model_outputs.detach().numpy() + # KerasNLP + preprocessor = BloomPreprocessor( + tokenizer=keras_tokenizer, + sequence_length=hf_model_outputs.shape[1], + add_end_token=False, + add_start_token=False, + ) + keras_model_input = preprocessor(input_str) + keras_model_outputs = keras_model.predict(keras_model_input) + # Comparing the outputs. print("🔶 KerasNLP output:", keras_model_outputs[0, 0, :10]) print("🔶 HF output:", hf_model_outputs[0, 0, :10]) @@ -207,41 +239,87 @@ def validate_output( def main(_): preset = FLAGS.preset - assert ( preset in PRESET_MAP.keys() - ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}' + ), f'Invalid preset {preset}. 
Must be one of {", ".join(PRESET_MAP.keys())}' + + validate_only = FLAGS.validate_only + + if not validate_only: + print(f"✅ Coverting {preset}") - print(f"✅ Coverting {preset}") + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("✅ Huggingface model downloaded from hub") - hf_model_name = PRESET_MAP[preset] - hf_model_dir = download_hf_model(hf_model_name) - print("✅ Huggingface model downloaded from hub") + hf_model = transformers.BloomModel.from_pretrained( + hf_model_dir, + ) + hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained( + hf_model_dir + ) + print("✅ Huggingface model loaded") - hf_model = transformers.BloomModel.from_pretrained(hf_model_dir) - hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained(hf_model_dir) - print("✅ Huggingface model loaded") + keras_model = convert_model(hf_model) + keras_tokenizer = convert_tokenizer(hf_model_dir) + print("✅ Keras model loaded") - keras_model = convert_model(hf_model) - keras_tokenizer = convert_tokenizer(hf_model_dir) - print("✅ Keras model loaded") + convert_weights(keras_model, hf_model) + print("✅ Weights converted") - convert_weights(keras_model, hf_model) - print("✅ Weights converted") + validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, + ) + print("✅ Numerics validated") - validate_output( - hf_model, - keras_model, - hf_tokenizer, - keras_tokenizer, - ) - print("✅ Numerics validated") + # Delete huggingface model + del hf_model + del hf_tokenizer - keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) - keras_nlp.src.utils.preset_utils.save_to_preset( - keras_tokenizer, preset, config_filename="tokenizer.json" - ) - print("✅ Preset saved") + # Save float32 keras preset + keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) + + # Delete float32 Keras model + del keras_model + + # Load The model in float16 percision + preset_path = os.path.join(os.getcwd(), preset) + keras_model = BloomBackbone.from_preset(preset_path, dtype="float16") + + # Save float16 keras model + keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) + keras_nlp.src.utils.preset_utils.save_to_preset( + keras_tokenizer, preset, config_filename="tokenizer.json" + ) + + print("✅ Preset saved") + else: + print(f"✅ Validating {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("✅ Huggingface model downloaded from hub") + + hf_model = transformers.BloomModel.from_pretrained( + hf_model_dir, + ) + hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained( + hf_model_dir + ) + + keras_model = BloomBackbone.from_preset(preset) + keras_tokenizer = BloomTokenizer.from_preset(preset) + + validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, + ) + print("✅ Numerics validated") if __name__ == "__main__": From 536e1ba25d85aed07daa28c869dd7b4209c14ebd Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:37:34 -0800 Subject: [PATCH 28/70] Remove the bert jupyter conversion notebooks (#1492) This is not how we do conversion anymore, and it really clutters the library. 
We can refer to these from our git history if we need them
---
 .../bert_base_cased.ipynb                    | 1073 --------------
 .../bert_base_multi_cased.ipynb              | 1084 --------------
 .../bert_base_uncased.ipynb                  | 1073 --------------
 .../checkpoint_conversion/bert_base_zh.ipynb | 1084 --------------
 .../bert_large_cased_en.ipynb                | 1280 -----------------
 .../bert_large_uncased_en.ipynb              | 1280 -----------------
 .../bert_medium_uncased_en.ipynb             |  971 -------------
 .../bert_small_uncased_en.ipynb              |  895 ------------
 .../bert_tiny_uncased_en.ipynb               |  858 -----------
 9 files changed, 9598 deletions(-)
 delete mode 100644 tools/checkpoint_conversion/bert_base_cased.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_base_multi_cased.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_base_uncased.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_base_zh.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_large_cased_en.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_large_uncased_en.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_medium_uncased_en.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_small_uncased_en.ipynb
 delete mode 100644 tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb

diff --git a/tools/checkpoint_conversion/bert_base_cased.ipynb b/tools/checkpoint_conversion/bert_base_cased.ipynb
deleted file mode 100644
index a7fc9db3ef..0000000000
--- a/tools/checkpoint_conversion/bert_base_cased.ipynb
+++ /dev/null
@@ -1,1073 +0,0 @@
[Deleted notebook body elided: 1,073 lines of Colab JSON — an "Open in Colab" badge, pip install progress logs, and repeated checkpoint variable listings.]
"encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "220a2d92-27b0-41d7-a081-d39aea11d89a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 22268928 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 108,310,272\n", 
- "Trainable params: 108,310,272\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 
4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\"encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "3627782d-92d2-4299-8a5e-55a74402d01e" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " 
type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Swp2Y0OvYIId", - "outputId": "f413f078-4d4a-42ee-a048-bcf77e655f4b" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 27 - } - ], - "source": [ - "keras_nlp_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "b46aab26-c8c2-467f-e8f8-6c3fa992d5b2" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 28 - } - ], - "source": [ - "mg_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "4447e234-bf17-46f2-f09f-86493a61e1cb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "42e126a0-92af-4779-f652-47b3396d38a9" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "f506140a-217b-4f45-b331-793e47fc5bae" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "210441" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "07debeae-6f54-4201-8d15-8cfa992fb121" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "f30ac6fac11115322e3d0c61e87e98b2 bert_base_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "b96360bc-47d6-4611-e727-f861d6038726" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_cased/model.h5\n", - "433474808/433474808 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "f1aa7c12-6681-4ae2-c689-f4b7ddfd1701" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 23 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, 0:10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OKfvONXmzX_L", - "outputId": "ece23e28-78b2-415a-faba-be5a3825812f" - }, - "execution_count": 26, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 26 - } - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "checkpoint convert model garden -> keras-nlp Bert cased", - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_multi_cased.ipynb b/tools/checkpoint_conversion/bert_base_multi_cased.ipynb deleted file mode 100644 index c87bbdbda0..0000000000 --- a/tools/checkpoint_conversion/bert_base_multi_cased.ipynb +++ /dev/null @@ -1,1084 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "52ac4b95-b8c9-4c2d-b4ca-eb57505d2cae" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.3 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 42.7 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 49.4 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 69.8 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 48.8 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 57.7 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 72.3 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.2 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 10.6 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 63.0 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 64.7 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 63.4 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 100.5 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 73.7 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-base-chinese tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_SUFFIX = \"multi_cased\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 119547" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JdXFWsMVEf-x", - "outputId": "043339cd-3d95-4083-c915-f9632effc319" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\n", - "660198400/660198400 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{MODEL_SUFFIX}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LHYiSsvYtfEU", - "outputId": "c185b640-d358-4d25-c249-7ac5ec4234df" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "d1f58eaa-a601-4bf1-bd1f-0b5ad1a28a28" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [119547, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-10/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "937d8db5-4567-4fdc-e4c8-3a14c1f50b36" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 91812096 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 
177,853,440\n", - "Trainable params: 177,853,440\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 
4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"The झटपट brown लोमड़ी.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "9cdd88a8-efc8-43bb-c469-9496578f46d6" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - 
"execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"input_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9gF51d-CbwUh", - "outputId": "39dc40e7-7e9d-4a70-a55e-2ecc8d3e91ce" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "48a0565d-30bb-4db0-c658-a4a4af7dc599" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "12d8cb41-85f9-44a5-e990-9e497ac45fff" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "32cb37c4-344e-4b34-dab9-907f4f6b7216" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "3289713c-5ba2-4c9a-d3fa-fefa52a6223a" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "764415" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "d394313b-ccf0-4db5-b270-aee6f27d1f06" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "b0631cec0a1f2513c6cfd75ba29c33aa bert_base_multi_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "d66aca01-c1aa-404a-d2a1-090644f842bd" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_multi_cased/model.h5\n", - "711647480/711647480 [==============================] - 8s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_SUFFIX)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "330f724f-29a4-4e90-d1b8-669501575d80" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", 
- " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "57dd61f6-5415-47a6-af9b-01d082f6b745" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "UqSREee-sLmP" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_uncased.ipynb b/tools/checkpoint_conversion/bert_base_uncased.ipynb deleted file mode 100644 index 282998b76e..0000000000 --- a/tools/checkpoint_conversion/bert_base_uncased.ipynb +++ /dev/null @@ -1,1073 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "6fa43e61-f3ca-449b-a6ea-971bae8f5020" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.0 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 46.1 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 57.2 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 66.9 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 56.9 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 26.7 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 9.2 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 70.6 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 59.9 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 52.8 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 75.9 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.0 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 68.0 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 76.8 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/jbischof/keras-nlp.git@bert_ckpt tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "source": [ - "TOKEN_TYPE = \"uncased\"\n", - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_NAME = MODEL_TYPE + \"_\" + TOKEN_TYPE\n", - "VOCAB_SIZE = 30522" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "JdXFWsMVEf-x", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "fe21cd4d-7f08-4d49-ef07-1660e3e76fa4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\n", - "405351189/405351189 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{TOKEN_TYPE}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "LHYiSsvYtfEU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "77580043-09a7-4950-99bb-473d71614365" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = extract_dir + \"vocab.txt\"\n", - "checkpoint_path = extract_dir + \"bert_model.ckpt\"\n", - "config_path = extract_dir + \"bert_config.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "67ca9d60-c88e-4ad1-c93a-a06eeaf2e631" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [30522, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "next_sentence..pooler_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "next_sentence..pooler_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "e130445e-1676-4a21-b332-b373814aef14" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 23440896 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 109,482,240\n", 
- "Trainable params: 109,482,240\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " 
]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"next_sentence..pooler_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\"next_sentence..pooler_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "72dfdb08-dbbb-4f8a-8a0a-2d881659edeb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", 
- "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Swp2Y0OvYIId", - "outputId": "8a981dbb-eb5b-432c-ee6d-52419e962cfd" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "2cb14435-68bc-48e3-b462-80842686d0e9" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "c7fc2d2d-919f-443e-85dd-d7139d01b173" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "0b8e4fbc-c356-4f66-d01c-d43fc1a326fb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "76d51be8-d503-4807-c78e-0b94a64ecb6c" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "9312bab4-4889-42b4-a7d6-b7e59c45dfcc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "94d5ec8bea7816f0c29904c959284343 bert_base_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "c92995ea-0e01-4289-f45b-3a52900ac7e7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_uncased/model.h5\n", - "438162680/438162680 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "fff1c4e5-a41c-42d4-df2d-303d4bc8a0ec" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 23 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, 0:10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "6b0bdfbe-27df-483c-a761-2ef34d8e5417" - }, - "execution_count": 24, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 24 - } - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "checkpoint convert model garden -> keras-nlp Bert uncased", - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_zh.ipynb b/tools/checkpoint_conversion/bert_base_zh.ipynb deleted file mode 100644 index 7dbf1dac8a..0000000000 --- a/tools/checkpoint_conversion/bert_base_zh.ipynb +++ /dev/null @@ -1,1084 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "0e8b9004-fe3f-4fb7-b905-cebb361887ab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.2 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 47.5 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 43.2 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 51.1 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 72.3 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 54.9 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 71.1 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.2 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 60.0 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 68.5 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 67.5 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 75.7 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 71.3 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 12.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-base-chinese tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_SUFFIX = \"chinese\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 21128" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JdXFWsMVEf-x", - "outputId": "e47ef71a-11c2-4c53-f398-ca3de4da77f9" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\n", - "379546341/379546341 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{MODEL_SUFFIX}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LHYiSsvYtfEU", - "outputId": "47518515-60ce-49f3-e6ac-39273b1d4616" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "59d95000-83a8-48f8-870a-51c2dbb5db6c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [21128, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-10/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "6b7cd86d-33af-4864-a3c7-63acf53fecc2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 16226304 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 
102,267,648\n", - "Trainable params: 102,267,648\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 
4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"敏捷的棕色狐狸\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "5c8144f8-4a07-4a6d-8c06-6bca7e8625cf" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 11 - } - ], - "source": [ 
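The loop above assigns each model-garden checkpoint tensor to the matching KerasNLP layer by name. A cheap sanity check after the assignments is to compare total parameter counts between the checkpoint and the freshly built model. A sketch, assuming the `checkpoint_path` and `model` defined above; bookkeeping entries such as `_CHECKPOINTABLE_OBJECT_GRAPH` carry no model weights and are skipped:

```python
import numpy as np
import tensorflow as tf

# Compare parameter counts as a quick conversion sanity check. Bookkeeping
# checkpoint entries (no "VARIABLE_VALUE" suffix) hold no model weights.
ckpt_params = sum(
    int(np.prod(shape))
    for name, shape in tf.train.list_variables(checkpoint_path)
    if "VARIABLE_VALUE" in name
)
print("checkpoint params:", ckpt_params)    # expect 102,267,648
print("model params:     ", model.count_params())
```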
- "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"input_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9gF51d-CbwUh", - "outputId": "dc0fdef7-1f6c-4e7d-ac84-71066108846d" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "893835bc-6438-45e5-c326-0ba631e1a4d6" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "e8094ce2-2cf5-4c6a-8dc3-474f20b9b58e" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "3282584c-38a9-486c-a3bf-1589dc17083f" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "8da86cef-d78f-4979-db6b-5fe8c9428c19" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "75770" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "438c793e-bbc8-4c46-808c-a6bb69ef68e6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "79afa421e386076e62ab42dad555ab0c bert_base_chinese.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "5db37007-d7eb-4431-b182-8304425d8f71" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_zh/model.h5\n", - "409304312/409304312 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=\"zh\")" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "ba0b05d4-f1f2-41d0-ab93-3fa5fe7ea2e5" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
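One small wrinkle in the save-artifacts cells of these notebooks: `vocab.txt` is written through a bare `open()`/`write()` pair and the handle is never closed, so the write is only flushed when the interpreter cleans the object up. A context manager makes the flush-and-close explicit; a sketch under the same names as the cell above (the printed count corresponds to the cell output, e.g. 75770 characters for the Chinese vocab):

```python
import tensorflow as tf

# Write the vocab with an explicit flush-and-close instead of relying on
# interpreter cleanup to close a bare file handle.
vocab_info = tf.io.gfile.GFile(vocab_path).read()
with open("vocab.txt", "w") as f:
    n_written = f.write(vocab_info)  # number of characters written
print(n_written)  # 75770 for the bert_base_chinese vocab
```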
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "0f3528b5-262f-4736-e377-bd450ee586de" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "ae9jwK4xqZQf" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_large_cased_en.ipynb b/tools/checkpoint_conversion/bert_large_cased_en.ipynb deleted file mode 100644 index 0c164a5b5b..0000000000 --- a/tools/checkpoint_conversion/bert_large_cased_en.ipynb +++ /dev/null @@ -1,1280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "fcd8f38b-e213-456f-b49f-397b21076120" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.8 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 46.2 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 58.6 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 59.1 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 64.3 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 74.3 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 62.7 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.0 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 60.3 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 11.6 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 62.9 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 68.5 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 71.4 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 69.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-large-vars tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_large\"\n", - "MODEL_SUFFIX = \"cased\"\n", - "MODEL_SPEC_STR = \"L-24_H-1024_A-16\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 28996\n", - "NUM_LAYERS = 24\n", - "NUM_ATTN_HEADS = 16\n", - "EMBEDDING_SIZE = 1024" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2018_10_18/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "534dc0a1-97bb-412a-bc20-e02af3e88afc" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip\n", - "1242178883/1242178883 [==============================] - 17s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "774aec39-bd51-46c5-ccb4-3abc99acee5b" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_large_cased\n", - " creating: cased_L-24_H-1024_A-16/\n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.meta \n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: cased_L-24_H-1024_A-16/vocab.txt \n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.index \n", - " inflating: cased_L-24_H-1024_A-16/bert_config.json \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "86b04c44-0640-46fc-8e61-1706e3778dfc" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [1024]\n", - 
"bert/embeddings/LayerNorm/gamma [1024]\n", - "bert/embeddings/position_embeddings [512, 1024]\n", - "bert/embeddings/token_type_embeddings [2, 1024]\n", - "bert/embeddings/word_embeddings [28996, 1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/key/bias [1024]\n", - "bert/encoder/layer_0/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/query/bias [1024]\n", - "bert/encoder/layer_0/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/value/bias [1024]\n", - "bert/encoder/layer_0/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_0/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/output/dense/bias [1024]\n", - "bert/encoder/layer_0/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/key/bias [1024]\n", - "bert/encoder/layer_1/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/query/bias [1024]\n", - "bert/encoder/layer_1/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/value/bias [1024]\n", - "bert/encoder/layer_1/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_1/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/output/dense/bias [1024]\n", - "bert/encoder/layer_1/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_10/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/key/bias [1024]\n", - "bert/encoder/layer_10/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/query/bias [1024]\n", - "bert/encoder/layer_10/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/value/bias [1024]\n", - "bert/encoder/layer_10/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_10/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_10/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_10/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/output/dense/bias [1024]\n", - "bert/encoder/layer_10/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_11/attention/output/dense/kernel [1024, 1024]\n", - 
"bert/encoder/layer_11/attention/self/key/bias [1024]\n", - "bert/encoder/layer_11/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/query/bias [1024]\n", - "bert/encoder/layer_11/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/value/bias [1024]\n", - "bert/encoder/layer_11/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_11/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_11/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_11/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/output/dense/bias [1024]\n", - "bert/encoder/layer_11/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_12/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/key/bias [1024]\n", - "bert/encoder/layer_12/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/query/bias [1024]\n", - "bert/encoder/layer_12/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/value/bias [1024]\n", - "bert/encoder/layer_12/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_12/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_12/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_12/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/output/dense/bias [1024]\n", - "bert/encoder/layer_12/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_13/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/key/bias [1024]\n", - "bert/encoder/layer_13/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/query/bias [1024]\n", - "bert/encoder/layer_13/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/value/bias [1024]\n", - "bert/encoder/layer_13/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_13/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_13/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_13/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/output/dense/bias [1024]\n", - "bert/encoder/layer_13/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_14/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/key/bias [1024]\n", - "bert/encoder/layer_14/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/query/bias [1024]\n", - "bert/encoder/layer_14/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/value/bias [1024]\n", - "bert/encoder/layer_14/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_14/intermediate/dense/bias [4096]\n", - 
"bert/encoder/layer_14/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_14/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/output/dense/bias [1024]\n", - "bert/encoder/layer_14/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_15/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/key/bias [1024]\n", - "bert/encoder/layer_15/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/query/bias [1024]\n", - "bert/encoder/layer_15/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/value/bias [1024]\n", - "bert/encoder/layer_15/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_15/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_15/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_15/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/output/dense/bias [1024]\n", - "bert/encoder/layer_15/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_16/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/key/bias [1024]\n", - "bert/encoder/layer_16/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/query/bias [1024]\n", - "bert/encoder/layer_16/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/value/bias [1024]\n", - "bert/encoder/layer_16/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_16/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_16/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_16/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/output/dense/bias [1024]\n", - "bert/encoder/layer_16/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_17/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/key/bias [1024]\n", - "bert/encoder/layer_17/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/query/bias [1024]\n", - "bert/encoder/layer_17/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/value/bias [1024]\n", - "bert/encoder/layer_17/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_17/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_17/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_17/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/output/dense/bias [1024]\n", - "bert/encoder/layer_17/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_18/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_18/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/key/bias [1024]\n", - "bert/encoder/layer_18/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/query/bias [1024]\n", - "bert/encoder/layer_18/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/value/bias [1024]\n", - "bert/encoder/layer_18/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_18/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_18/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_18/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/output/dense/bias [1024]\n", - "bert/encoder/layer_18/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_19/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/key/bias [1024]\n", - "bert/encoder/layer_19/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/query/bias [1024]\n", - "bert/encoder/layer_19/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/value/bias [1024]\n", - "bert/encoder/layer_19/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_19/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_19/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_19/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/output/dense/bias [1024]\n", - "bert/encoder/layer_19/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/key/bias [1024]\n", - "bert/encoder/layer_2/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/query/bias [1024]\n", - "bert/encoder/layer_2/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/value/bias [1024]\n", - "bert/encoder/layer_2/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_2/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/output/dense/bias [1024]\n", - "bert/encoder/layer_2/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_20/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/key/bias [1024]\n", - "bert/encoder/layer_20/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/query/bias [1024]\n", - "bert/encoder/layer_20/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/value/bias [1024]\n", - 
"bert/encoder/layer_20/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_20/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_20/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_20/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/output/dense/bias [1024]\n", - "bert/encoder/layer_20/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_21/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/key/bias [1024]\n", - "bert/encoder/layer_21/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/query/bias [1024]\n", - "bert/encoder/layer_21/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/value/bias [1024]\n", - "bert/encoder/layer_21/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_21/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_21/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_21/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/output/dense/bias [1024]\n", - "bert/encoder/layer_21/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_22/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/key/bias [1024]\n", - "bert/encoder/layer_22/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/query/bias [1024]\n", - "bert/encoder/layer_22/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/value/bias [1024]\n", - "bert/encoder/layer_22/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_22/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_22/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_22/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/output/dense/bias [1024]\n", - "bert/encoder/layer_22/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_23/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/key/bias [1024]\n", - "bert/encoder/layer_23/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/query/bias [1024]\n", - "bert/encoder/layer_23/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/value/bias [1024]\n", - "bert/encoder/layer_23/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_23/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_23/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_23/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/output/dense/bias [1024]\n", - "bert/encoder/layer_23/output/dense/kernel [4096, 1024]\n", - 
"bert/encoder/layer_3/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/key/bias [1024]\n", - "bert/encoder/layer_3/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/query/bias [1024]\n", - "bert/encoder/layer_3/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/value/bias [1024]\n", - "bert/encoder/layer_3/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_3/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/output/dense/bias [1024]\n", - "bert/encoder/layer_3/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/key/bias [1024]\n", - "bert/encoder/layer_4/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/query/bias [1024]\n", - "bert/encoder/layer_4/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/value/bias [1024]\n", - "bert/encoder/layer_4/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_4/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/output/dense/bias [1024]\n", - "bert/encoder/layer_4/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/key/bias [1024]\n", - "bert/encoder/layer_5/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/query/bias [1024]\n", - "bert/encoder/layer_5/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/value/bias [1024]\n", - "bert/encoder/layer_5/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_5/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/output/dense/bias [1024]\n", - "bert/encoder/layer_5/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_6/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/key/bias [1024]\n", - "bert/encoder/layer_6/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/query/bias [1024]\n", - "bert/encoder/layer_6/attention/self/query/kernel [1024, 1024]\n", - 
"bert/encoder/layer_6/attention/self/value/bias [1024]\n", - "bert/encoder/layer_6/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_6/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/output/dense/bias [1024]\n", - "bert/encoder/layer_6/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/key/bias [1024]\n", - "bert/encoder/layer_7/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/query/bias [1024]\n", - "bert/encoder/layer_7/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/value/bias [1024]\n", - "bert/encoder/layer_7/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_7/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/output/dense/bias [1024]\n", - "bert/encoder/layer_7/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_8/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/key/bias [1024]\n", - "bert/encoder/layer_8/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/query/bias [1024]\n", - "bert/encoder/layer_8/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/value/bias [1024]\n", - "bert/encoder/layer_8/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_8/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_8/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_8/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/output/dense/bias [1024]\n", - "bert/encoder/layer_8/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_9/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/key/bias [1024]\n", - "bert/encoder/layer_9/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/query/bias [1024]\n", - "bert/encoder/layer_9/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/value/bias [1024]\n", - "bert/encoder/layer_9/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_9/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_9/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_9/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/output/dense/bias [1024]\n", - "bert/encoder/layer_9/output/dense/kernel [4096, 1024]\n", - "bert/pooler/dense/bias [1024]\n", - 
"bert/pooler/dense/kernel [1024, 1024]\n", - "cls/predictions/output_bias [28996]\n", - "cls/predictions/transform/LayerNorm/beta [1024]\n", - "cls/predictions/transform/LayerNorm/gamma [1024]\n", - "cls/predictions/transform/dense/bias [1024]\n", - "cls/predictions/transform/dense/kernel [1024, 1024]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 1024]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertLarge model with KerasNLP." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "07380744-32ea-4200-effe-ef697bcc8dca", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 1024) 29691904 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 1024) 524288 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 1024) 2048 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 1024) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 1024) 2048 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 1024) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 1024) 12596224 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 1024) 12596224 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 1024) 12596224 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 1024) 12596224 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 1024) 12596224 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 1024) 12596224 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 1024) 12596224 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 1024) 12596224 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 1024) 12596224 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 1024) 12596224 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 1024) 12596224 
['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 1024) 12596224 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_12 (Transfor (None, None, 1024) 12596224 ['transformer_layer_11[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_13 (Transfor (None, None, 1024) 12596224 ['transformer_layer_12[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_14 (Transfor (None, None, 1024) 12596224 ['transformer_layer_13[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_15 (Transfor (None, None, 1024) 12596224 ['transformer_layer_14[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_16 (Transfor (None, None, 1024) 12596224 ['transformer_layer_15[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_17 (Transfor (None, None, 1024) 12596224 ['transformer_layer_16[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_18 (Transfor (None, None, 1024) 12596224 ['transformer_layer_17[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_19 (Transfor (None, None, 1024) 12596224 ['transformer_layer_18[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_20 (Transfor (None, None, 1024) 12596224 ['transformer_layer_19[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_21 (Transfor (None, None, 1024) 12596224 ['transformer_layer_20[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_22 (Transfor (None, None, 1024) 12596224 ['transformer_layer_21[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_23 (Transfor (None, None, 1024) 12596224 ['transformer_layer_22[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 1024) 0 ['transformer_layer_23[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 1024) 1049600 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 333,579,264\n", - "Trainable params: 333,579,264\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " 
weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_layer_norm.beta.assign(\n", - "        weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - "    )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - "    weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load Bert Large from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"\n", - "\n", - "### TF Hub statement:\n", - "\"The weights of this model are those released by the original BERT authors.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - "    \"https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - "    preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - "    f\"https://tfhub.dev/tensorflow/bert_en_cased_{MODEL_SPEC_STR}/4\",\n", - "    trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"]  # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"]  # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - "    tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - "        vocabulary=vocab_path, lowercase=False\n", - "    )\n", - "    packer = keras_nlp.layers.MultiSegmentPacker(\n", - "        sequence_length=model.max_sequence_length,\n", - "        start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - "        end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - "    )\n", - "    return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"The quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - "    {\n", - "        \"token_ids\": token_ids,\n", - "        \"segment_ids\": segment_ids,\n", - "        \"padding_mask\": token_ids != 0,\n", - "    }\n", - ")\n", - "\n", - "orig_pooled_output, orig_sequence_output = embedding_model(\n", - "    tf.constant([\"The quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], orig_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "598bad82-1f9a-4869-e9e4-9ac38157cd89", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(<tf.Tensor: shape=(10,), dtype=float32, numpy=...>, <tf.Tensor: shape=(10,), dtype=float32, numpy=...>)" -
] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a35bbd8a-7f87-4c6a-fd63-7ba227e4cdf0", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(<tf.Tensor: shape=(), dtype=float32, numpy=...>,\n", - " <tf.Tensor: shape=(), dtype=float32, numpy=...>)" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - "    tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - orig_pooled_output),\n", - "    tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - orig_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertLarge checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "144ded0b-512a-4f3a-9552-17ec709dd11c", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(<tf.Tensor: shape=(), dtype=float32, numpy=...>,\n", - " <tf.Tensor: shape=(), dtype=float32, numpy=...>)" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - "    {\n", - "        \"token_ids\": token_ids,\n", - "        \"segment_ids\": segment_ids,\n", - "        \"padding_mask\": token_ids != 0,\n", - "    }\n", - ")\n", - "\n", - "(\n", - "    tf.reduce_mean(\n", - "        keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - "    ),\n", - "    tf.reduce_mean(\n", - "        keras_nlp_output[\"sequence_output\"]\n", - "        - keras_nlp_output2[\"sequence_output\"]\n", - "    ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "eb9adff4-7aad-44a3-f6aa-49f83199e00f", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "210441" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "d16ef506-b1d0-450e-cae8-240c53896ede", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "8b8ab82290bbf4f8db87d4f100648890  bert_large_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": 
"stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertLarge(weights=\"cased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": 21, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_large_uncased_en.ipynb b/tools/checkpoint_conversion/bert_large_uncased_en.ipynb deleted file mode 100644 index a80069e199..0000000000 --- a/tools/checkpoint_conversion/bert_large_uncased_en.ipynb +++ /dev/null @@ -1,1280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "c5e29c2b-f3ba-46eb-c604-74545a36204f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.6 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 64.7 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 63.0 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 48.9 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 72.6 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 61.3 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 63.9 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 60.8 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 69.3 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 1.7 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 56.2 MB/s \n", - "\u001b[K |████████████████████████████████| 116 
kB 70.4 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 71.4 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 10.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-large-vars tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_large\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-24_H-1024_A-16\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 24\n", - "NUM_ATTN_HEADS = 16\n", - "EMBEDDING_SIZE = 1024" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "41332d42-8e39-408f-fcc0-803cef5ccfb7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip\n", - "1247797031/1247797031 [==============================] - 23s 0us/step\n" - ] - } - ], - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2018_10_18/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "3496afcb-a342-449e-b5b4-975238cec430" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_large_uncased\n", - " creating: uncased_L-24_H-1024_A-16/\n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.meta \n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-24_H-1024_A-16/vocab.txt \n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.index \n", - " inflating: uncased_L-24_H-1024_A-16/bert_config.json \n" - ] - } - ], - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "OGij7IQU4rJL" - }, - "outputs": [], - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": 
"ba1605fe-4503-49dc-ebc5-d89148527b5d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [1024]\n", - "bert/embeddings/LayerNorm/gamma [1024]\n", - "bert/embeddings/position_embeddings [512, 1024]\n", - "bert/embeddings/token_type_embeddings [2, 1024]\n", - "bert/embeddings/word_embeddings [30522, 1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/key/bias [1024]\n", - "bert/encoder/layer_0/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/query/bias [1024]\n", - "bert/encoder/layer_0/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/value/bias [1024]\n", - "bert/encoder/layer_0/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_0/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/output/dense/bias [1024]\n", - "bert/encoder/layer_0/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/key/bias [1024]\n", - "bert/encoder/layer_1/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/query/bias [1024]\n", - "bert/encoder/layer_1/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/value/bias [1024]\n", - "bert/encoder/layer_1/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_1/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/output/dense/bias [1024]\n", - "bert/encoder/layer_1/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_10/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/key/bias [1024]\n", - "bert/encoder/layer_10/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/query/bias [1024]\n", - "bert/encoder/layer_10/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/value/bias [1024]\n", - "bert/encoder/layer_10/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_10/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_10/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_10/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/output/dense/bias [1024]\n", - "bert/encoder/layer_10/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/gamma 
[1024]\n", - "bert/encoder/layer_11/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_11/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/key/bias [1024]\n", - "bert/encoder/layer_11/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/query/bias [1024]\n", - "bert/encoder/layer_11/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/value/bias [1024]\n", - "bert/encoder/layer_11/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_11/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_11/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_11/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/output/dense/bias [1024]\n", - "bert/encoder/layer_11/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_12/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/key/bias [1024]\n", - "bert/encoder/layer_12/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/query/bias [1024]\n", - "bert/encoder/layer_12/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/value/bias [1024]\n", - "bert/encoder/layer_12/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_12/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_12/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_12/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/output/dense/bias [1024]\n", - "bert/encoder/layer_12/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_13/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/key/bias [1024]\n", - "bert/encoder/layer_13/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/query/bias [1024]\n", - "bert/encoder/layer_13/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/value/bias [1024]\n", - "bert/encoder/layer_13/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_13/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_13/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_13/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/output/dense/bias [1024]\n", - "bert/encoder/layer_13/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_14/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/key/bias [1024]\n", - "bert/encoder/layer_14/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/query/bias [1024]\n", - "bert/encoder/layer_14/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/value/bias [1024]\n", - 
"bert/encoder/layer_14/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_14/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_14/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_14/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/output/dense/bias [1024]\n", - "bert/encoder/layer_14/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_15/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/key/bias [1024]\n", - "bert/encoder/layer_15/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/query/bias [1024]\n", - "bert/encoder/layer_15/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/value/bias [1024]\n", - "bert/encoder/layer_15/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_15/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_15/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_15/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/output/dense/bias [1024]\n", - "bert/encoder/layer_15/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_16/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/key/bias [1024]\n", - "bert/encoder/layer_16/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/query/bias [1024]\n", - "bert/encoder/layer_16/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/value/bias [1024]\n", - "bert/encoder/layer_16/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_16/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_16/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_16/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/output/dense/bias [1024]\n", - "bert/encoder/layer_16/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_17/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/key/bias [1024]\n", - "bert/encoder/layer_17/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/query/bias [1024]\n", - "bert/encoder/layer_17/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/value/bias [1024]\n", - "bert/encoder/layer_17/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_17/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_17/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_17/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/output/dense/bias [1024]\n", - "bert/encoder/layer_17/output/dense/kernel [4096, 1024]\n", - 
"bert/encoder/layer_18/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_18/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/key/bias [1024]\n", - "bert/encoder/layer_18/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/query/bias [1024]\n", - "bert/encoder/layer_18/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/value/bias [1024]\n", - "bert/encoder/layer_18/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_18/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_18/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_18/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/output/dense/bias [1024]\n", - "bert/encoder/layer_18/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_19/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/key/bias [1024]\n", - "bert/encoder/layer_19/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/query/bias [1024]\n", - "bert/encoder/layer_19/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/value/bias [1024]\n", - "bert/encoder/layer_19/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_19/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_19/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_19/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/output/dense/bias [1024]\n", - "bert/encoder/layer_19/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/key/bias [1024]\n", - "bert/encoder/layer_2/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/query/bias [1024]\n", - "bert/encoder/layer_2/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/value/bias [1024]\n", - "bert/encoder/layer_2/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_2/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/output/dense/bias [1024]\n", - "bert/encoder/layer_2/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_20/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/key/bias [1024]\n", - "bert/encoder/layer_20/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/query/bias [1024]\n", - 
"bert/encoder/layer_20/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/value/bias [1024]\n", - "bert/encoder/layer_20/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_20/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_20/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_20/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/output/dense/bias [1024]\n", - "bert/encoder/layer_20/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_21/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/key/bias [1024]\n", - "bert/encoder/layer_21/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/query/bias [1024]\n", - "bert/encoder/layer_21/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/value/bias [1024]\n", - "bert/encoder/layer_21/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_21/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_21/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_21/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/output/dense/bias [1024]\n", - "bert/encoder/layer_21/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_22/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/key/bias [1024]\n", - "bert/encoder/layer_22/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/query/bias [1024]\n", - "bert/encoder/layer_22/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/value/bias [1024]\n", - "bert/encoder/layer_22/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_22/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_22/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_22/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/output/dense/bias [1024]\n", - "bert/encoder/layer_22/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_23/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/key/bias [1024]\n", - "bert/encoder/layer_23/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/query/bias [1024]\n", - "bert/encoder/layer_23/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/value/bias [1024]\n", - "bert/encoder/layer_23/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_23/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_23/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_23/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_23/output/dense/bias [1024]\n", - "bert/encoder/layer_23/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/key/bias [1024]\n", - "bert/encoder/layer_3/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/query/bias [1024]\n", - "bert/encoder/layer_3/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/value/bias [1024]\n", - "bert/encoder/layer_3/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_3/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/output/dense/bias [1024]\n", - "bert/encoder/layer_3/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/key/bias [1024]\n", - "bert/encoder/layer_4/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/query/bias [1024]\n", - "bert/encoder/layer_4/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/value/bias [1024]\n", - "bert/encoder/layer_4/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_4/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/output/dense/bias [1024]\n", - "bert/encoder/layer_4/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/key/bias [1024]\n", - "bert/encoder/layer_5/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/query/bias [1024]\n", - "bert/encoder/layer_5/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/value/bias [1024]\n", - "bert/encoder/layer_5/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_5/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/output/dense/bias [1024]\n", - "bert/encoder/layer_5/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_6/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/key/bias [1024]\n", - "bert/encoder/layer_6/attention/self/key/kernel [1024, 1024]\n", - 
"bert/encoder/layer_6/attention/self/query/bias [1024]\n", - "bert/encoder/layer_6/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/value/bias [1024]\n", - "bert/encoder/layer_6/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_6/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/output/dense/bias [1024]\n", - "bert/encoder/layer_6/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/key/bias [1024]\n", - "bert/encoder/layer_7/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/query/bias [1024]\n", - "bert/encoder/layer_7/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/value/bias [1024]\n", - "bert/encoder/layer_7/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_7/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/output/dense/bias [1024]\n", - "bert/encoder/layer_7/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_8/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/key/bias [1024]\n", - "bert/encoder/layer_8/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/query/bias [1024]\n", - "bert/encoder/layer_8/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/value/bias [1024]\n", - "bert/encoder/layer_8/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_8/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_8/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_8/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/output/dense/bias [1024]\n", - "bert/encoder/layer_8/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_9/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/key/bias [1024]\n", - "bert/encoder/layer_9/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/query/bias [1024]\n", - "bert/encoder/layer_9/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/value/bias [1024]\n", - "bert/encoder/layer_9/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_9/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_9/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_9/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_9/output/dense/bias [1024]\n", - "bert/encoder/layer_9/output/dense/kernel [4096, 1024]\n", - "bert/pooler/dense/bias [1024]\n", - "bert/pooler/dense/kernel [1024, 1024]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [1024]\n", - "cls/predictions/transform/LayerNorm/gamma [1024]\n", - "cls/predictions/transform/dense/bias [1024]\n", - "cls/predictions/transform/dense/kernel [1024, 1024]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 1024]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertLarge model with KerasNLP." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "g1kp1M9b6hdU", - "outputId": "8eb94045-3b29-400e-92e7-329cbc1250d7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 1024) 31254528 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 1024) 524288 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 1024) 2048 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 1024) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 1024) 2048 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 1024) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 1024) 12596224 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 1024) 12596224 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 1024) 12596224 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 1024) 12596224 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 1024) 12596224 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 1024) 12596224 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 1024) 12596224 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 1024) 12596224 ['transformer_layer_6[0][0]', \n", - " erEncoder) 
'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 1024) 12596224 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 1024) 12596224 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 1024) 12596224 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 1024) 12596224 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_12 (Transfor (None, None, 1024) 12596224 ['transformer_layer_11[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_13 (Transfor (None, None, 1024) 12596224 ['transformer_layer_12[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_14 (Transfor (None, None, 1024) 12596224 ['transformer_layer_13[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_15 (Transfor (None, None, 1024) 12596224 ['transformer_layer_14[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_16 (Transfor (None, None, 1024) 12596224 ['transformer_layer_15[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_17 (Transfor (None, None, 1024) 12596224 ['transformer_layer_16[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_18 (Transfor (None, None, 1024) 12596224 ['transformer_layer_17[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_19 (Transfor (None, None, 1024) 12596224 ['transformer_layer_18[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_20 (Transfor (None, None, 1024) 12596224 ['transformer_layer_19[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_21 (Transfor (None, None, 1024) 12596224 ['transformer_layer_20[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_22 (Transfor (None, None, 1024) 12596224 ['transformer_layer_21[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_23 (Transfor (None, None, 1024) 12596224 ['transformer_layer_22[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 1024) 0 ['transformer_layer_23[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 1024) 1049600 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 335,141,888\n", - "Trainable params: 335,141,888\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - 
" weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ByCEoIyn-_Ld" - }, - "source": [ - "## Load Bert Large from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"\n", - "\n", - "### TF Hub statement:\n", - "\"The weights of this model are those released by the original BERT authors.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "outputs": [], - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/bert_en_uncased_{MODEL_SPEC_STR}/4\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "w4OhasjS9Ozn" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "orig_pooled_output, orig_sequence_output = embedding_model(\n", - " tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", 
- "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HzUii8Tp9qth", - "outputId": "5c7585f3-831e-4da2-abdf-07a968f86856" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(, )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], orig_pooled_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "II0akvof9qth", - "outputId": "0f4d085a-7835-41da-80de-4282fe953f6b" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - orig_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - orig_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertLarge checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OD0B0UxN-QdB", - "outputId": "89ffaadc-32ee-4a16-bfc0-6f61f163bc73" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q0K9JAY5-QdD", - "outputId": "1a4a04e2-ffda-4f68-d396-1490faa1c5e2" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-jVECpzp-QdD", - "outputId": "9b05e1b1-adcb-4073-950b-1ab0ab579a92" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cc5cacc9565ef400ee4376105f40ddae bert_large_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload 
model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "b3061cdc-a831-44ae-a6a5-a999b1bf2544" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_uncased/model.h5\n", - "1341009960/1341009960 [==============================] - 46s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertLarge(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ycvzjqdZYjNo", - "outputId": "42366183-9d87-4409-a853-de745c38babb" - }, - "execution_count": 20, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "7ff6fcf4-64fb-4f58-c8b3-1bbc05376e85" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KcwejBTMXsIc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb b/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb deleted file mode 100644 index 01e5169a63..0000000000 --- a/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb +++ /dev/null @@ -1,971 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "86338ead-40b4-4149-f30f-8f2b28d3426d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.1 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 52.2 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 49.3 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 55.9 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 68.5 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 54.1 MB/s \n", - "\u001b[K 
|████████████████████████████████| 352 kB 69.5 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 54.6 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 8.9 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 66.4 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.1 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 56.3 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 70.2 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 58.8 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_medium\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-8_H-512_A-8\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 8\n", - "NUM_ATTN_HEADS = 8\n", - "EMBEDDING_SIZE = 512" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "9acfd0df-e230-4698-9178-08c7e1854fa4" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-8_H-512_A-8.zip\n", - "154608092/154608092 [==============================] - 1s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "4d6a2bda-7482-42a1-aff5-8a28a6796085" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_medium_uncased\n", - " inflating: uncased_L-8_H-512_A-8/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-8_H-512_A-8/bert_config.json \n", - " inflating: uncased_L-8_H-512_A-8/vocab.txt \n", - " inflating: uncased_L-8_H-512_A-8/bert_model.ckpt.index \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, 
\"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "c4cc8f57-f252-49fe-a9b6-de96c2162813" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [512]\n", - "bert/embeddings/LayerNorm/gamma [512]\n", - "bert/embeddings/position_embeddings [512, 512]\n", - "bert/embeddings/token_type_embeddings [2, 512]\n", - "bert/embeddings/word_embeddings [30522, 512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/attention/output/dense/bias [512]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/key/bias [512]\n", - "bert/encoder/layer_0/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/query/bias [512]\n", - "bert/encoder/layer_0/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/value/bias [512]\n", - "bert/encoder/layer_0/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_0/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/output/dense/bias [512]\n", - "bert/encoder/layer_0/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/attention/output/dense/bias [512]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/key/bias [512]\n", - "bert/encoder/layer_1/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/query/bias [512]\n", - "bert/encoder/layer_1/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/value/bias [512]\n", - "bert/encoder/layer_1/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_1/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/output/dense/bias [512]\n", - "bert/encoder/layer_1/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/attention/output/dense/bias [512]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/key/bias [512]\n", - "bert/encoder/layer_2/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/query/bias [512]\n", - "bert/encoder/layer_2/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/value/bias [512]\n", - "bert/encoder/layer_2/attention/self/value/kernel [512, 512]\n", - 
"bert/encoder/layer_2/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/output/dense/bias [512]\n", - "bert/encoder/layer_2/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/attention/output/dense/bias [512]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/key/bias [512]\n", - "bert/encoder/layer_3/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/query/bias [512]\n", - "bert/encoder/layer_3/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/value/bias [512]\n", - "bert/encoder/layer_3/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_3/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/output/dense/bias [512]\n", - "bert/encoder/layer_3/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_4/attention/output/dense/bias [512]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/key/bias [512]\n", - "bert/encoder/layer_4/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/query/bias [512]\n", - "bert/encoder/layer_4/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/value/bias [512]\n", - "bert/encoder/layer_4/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_4/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_4/output/dense/bias [512]\n", - "bert/encoder/layer_4/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_5/attention/output/dense/bias [512]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/key/bias [512]\n", - "bert/encoder/layer_5/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/query/bias [512]\n", - "bert/encoder/layer_5/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/value/bias [512]\n", - "bert/encoder/layer_5/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_5/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_5/output/dense/bias [512]\n", - "bert/encoder/layer_5/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_6/attention/output/dense/bias [512]\n", - 
"bert/encoder/layer_6/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/key/bias [512]\n", - "bert/encoder/layer_6/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/query/bias [512]\n", - "bert/encoder/layer_6/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/value/bias [512]\n", - "bert/encoder/layer_6/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_6/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_6/output/dense/bias [512]\n", - "bert/encoder/layer_6/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_7/attention/output/dense/bias [512]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/key/bias [512]\n", - "bert/encoder/layer_7/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/query/bias [512]\n", - "bert/encoder/layer_7/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/value/bias [512]\n", - "bert/encoder/layer_7/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_7/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_7/output/dense/bias [512]\n", - "bert/encoder/layer_7/output/dense/kernel [2048, 512]\n", - "bert/pooler/dense/bias [512]\n", - "bert/pooler/dense/kernel [512, 512]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [512]\n", - "cls/predictions/transform/LayerNorm/gamma [512]\n", - "cls/predictions/transform/dense/bias [512]\n", - "cls/predictions/transform/dense/kernel [512, 512]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 512]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertMedium model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "c4c615a2-2997-4682-a8fa-97da40289e23", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 512) 15627264 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 512) 262144 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 512) 1024 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 512) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 512) 1024 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 512) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 512) 3152384 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 512) 3152384 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 512) 3152384 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 512) 3152384 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 512) 3152384 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 512) 3152384 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 512) 3152384 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 512) 3152384 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 512) 0 ['transformer_layer_7[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 512) 262656 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 41,373,184\n", - "Trainable params: 41,373,184\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertMedium(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - 
"model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load Bert Medium from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " 
tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "06d8485d-84f0-47ea-eab9-7d229da581ce", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(, )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1fbc2ea7-577f-491d-97d1-2becf294eb94", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertMedium checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertMedium(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0c3ae2e3-f23e-46a3-a7bc-2b9e29e72a21", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "68d4f6b4-dd02-423c-ca45-c18ede23a3dc", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0e9970fb-cb39-4255-c1bb-65d891a8b845", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bb990e1184ec6b6185450c73833cd661 bert_medium_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertMedium(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_small_uncased_en.ipynb b/tools/checkpoint_conversion/bert_small_uncased_en.ipynb deleted file mode 100644 index ded288b2d3..0000000000 --- a/tools/checkpoint_conversion/bert_small_uncased_en.ipynb +++ /dev/null @@ -1,895 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "33a180e3-462f-4f6a-b641-0104f18f96de" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.8 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 48.2 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 48.8 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 52.6 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 64.8 MB/s \n", - 
"\u001b[K |████████████████████████████████| 1.6 MB 56.7 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 68.5 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 38.1 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 74.8 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 10.1 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.2 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 50.7 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 56.0 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 69.5 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_small\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-4_H-512_A-8\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 4\n", - "NUM_ATTN_HEADS = 8\n", - "EMBEDDING_SIZE = 512" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "37545bbe-70bc-4824-f96e-3eda6b99e709" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-4_H-512_A-8.zip\n", - "107814641/107814641 [==============================] - 1s 0us/step\n" - ] - } - ], - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "8ca6d690-8f73-45e7-a0e9-16b5901a4297" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Archive: bert_small_uncased\n", - " inflating: uncased_L-4_H-512_A-8/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-4_H-512_A-8/bert_config.json \n", - " inflating: uncased_L-4_H-512_A-8/vocab.txt \n", - " inflating: uncased_L-4_H-512_A-8/bert_model.ckpt.index \n" - ] - } - ], - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "OGij7IQU4rJL" - }, - "outputs": [], - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = 
os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "405b2156-e025-4848-8aee-2589c4354311" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bert/embeddings/LayerNorm/beta [512]\n", - "bert/embeddings/LayerNorm/gamma [512]\n", - "bert/embeddings/position_embeddings [512, 512]\n", - "bert/embeddings/token_type_embeddings [2, 512]\n", - "bert/embeddings/word_embeddings [30522, 512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/attention/output/dense/bias [512]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/key/bias [512]\n", - "bert/encoder/layer_0/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/query/bias [512]\n", - "bert/encoder/layer_0/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/value/bias [512]\n", - "bert/encoder/layer_0/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_0/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/output/dense/bias [512]\n", - "bert/encoder/layer_0/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/attention/output/dense/bias [512]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/key/bias [512]\n", - "bert/encoder/layer_1/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/query/bias [512]\n", - "bert/encoder/layer_1/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/value/bias [512]\n", - "bert/encoder/layer_1/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_1/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/output/dense/bias [512]\n", - "bert/encoder/layer_1/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/attention/output/dense/bias [512]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/key/bias [512]\n", - "bert/encoder/layer_2/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/query/bias [512]\n", - "bert/encoder/layer_2/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/value/bias [512]\n", - "bert/encoder/layer_2/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_2/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [512]\n", - 
"bert/encoder/layer_2/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/output/dense/bias [512]\n", - "bert/encoder/layer_2/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/attention/output/dense/bias [512]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/key/bias [512]\n", - "bert/encoder/layer_3/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/query/bias [512]\n", - "bert/encoder/layer_3/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/value/bias [512]\n", - "bert/encoder/layer_3/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_3/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/output/dense/bias [512]\n", - "bert/encoder/layer_3/output/dense/kernel [2048, 512]\n", - "bert/pooler/dense/bias [512]\n", - "bert/pooler/dense/kernel [512, 512]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [512]\n", - "cls/predictions/transform/LayerNorm/gamma [512]\n", - "cls/predictions/transform/dense/bias [512]\n", - "cls/predictions/transform/dense/kernel [512, 512]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 512]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertSmall model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "g1kp1M9b6hdU", - "outputId": "24c6056a-ef5a-426c-b172-f5aeec699fad" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 512) 15627264 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 512) 262144 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 512) 1024 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 512) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 512) 1024 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 512) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 512) 3152384 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 512) 3152384 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 512) 3152384 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 512) 3152384 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 512) 0 ['transformer_layer_3[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 512) 262656 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 28,763,648\n", - "Trainable params: 28,763,648\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertSmall(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " 
weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ByCEoIyn-_Ld" - }, - "source": [ - "## Load Bert Small from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "outputs": [], - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "iAubWsWj9qtg" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HzUii8Tp9qth", - "outputId": "97bd5f3c-8440-4f86-f87d-2940370463e7" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(, )" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 
14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "II0akvof9qth", - "outputId": "e7b49897-abb3-45da-9ad8-45a62b22a37d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertSmall checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertSmall(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OD0B0UxN-QdB", - "outputId": "6d4feea6-bc13-4fc2-a9b0-33e929ff9d0c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q0K9JAY5-QdD", - "outputId": "697099b1-de1e-4393-a272-12144d29d155" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "228209" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-jVECpzp-QdD", - "outputId": "7e75dcb5-b985-4b44-801c-e8d63d34f035" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "08632c9479b034f342ba2c2b7afba5f7 bert_small_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading data from 
https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertSmall(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb b/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb deleted file mode 100644 index d1fb1119cd..0000000000 --- a/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb +++ /dev/null @@ -1,858 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "5107d4a7-7205-448d-8989-9d81d01b7195" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |████████████████████████████████| 511.7 MB 6.7 kB/s \n", - "\u001b[K |████████████████████████████████| 2.1 MB 47.0 MB/s \n", - "\u001b[K |████████████████████████████████| 4.6 MB 49.1 MB/s \n", - "\u001b[K |████████████████████████████████| 5.8 MB 58.4 MB/s \n", - "\u001b[K |████████████████████████████████| 438 kB 69.4 MB/s \n", - "\u001b[K |████████████████████████████████| 1.6 MB 58.6 MB/s \n", - "\u001b[K |████████████████████████████████| 99 kB 11.7 MB/s \n", - "\u001b[K |████████████████████████████████| 636 kB 69.5 MB/s \n", - "\u001b[K |████████████████████████████████| 1.3 MB 56.7 MB/s \n", - "\u001b[K |████████████████████████████████| 352 kB 74.5 MB/s \n", - "\u001b[K |████████████████████████████████| 43 kB 2.3 MB/s \n", - "\u001b[K |████████████████████████████████| 116 kB 68.3 MB/s \n", - "\u001b[K |████████████████████████████████| 1.1 MB 
61.3 MB/s \n", - "\u001b[K |████████████████████████████████| 238 kB 67.1 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_tiny\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-2_H-128_A-2\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 2\n", - "NUM_ATTN_HEADS = 2\n", - "EMBEDDING_SIZE = 128" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "8e952e27-282d-440e-ba72-b526fe586b75" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip\n", - "16529104/16529104 [==============================] - 0s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "15e3842f-e312-4e79-829c-7fc136e29f60" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_tiny_uncased\n", - " inflating: uncased_L-2_H-128_A-2/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-2_H-128_A-2/bert_config.json \n", - " inflating: uncased_L-2_H-128_A-2/vocab.txt \n", - " inflating: uncased_L-2_H-128_A-2/bert_model.ckpt.index \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": 
"RC6DqSfo4iPR", - "outputId": "01e19274-afa7-45ae-f4f8-2c46538ed97c" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [128]\n", - "bert/embeddings/LayerNorm/gamma [128]\n", - "bert/embeddings/position_embeddings [512, 128]\n", - "bert/embeddings/token_type_embeddings [2, 128]\n", - "bert/embeddings/word_embeddings [30522, 128]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_0/attention/output/dense/bias [128]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/key/bias [128]\n", - "bert/encoder/layer_0/attention/self/key/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/query/bias [128]\n", - "bert/encoder/layer_0/attention/self/query/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/value/bias [128]\n", - "bert/encoder/layer_0/attention/self/value/kernel [128, 128]\n", - "bert/encoder/layer_0/intermediate/dense/bias [512]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [128, 512]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_0/output/dense/bias [128]\n", - "bert/encoder/layer_0/output/dense/kernel [512, 128]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_1/attention/output/dense/bias [128]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/key/bias [128]\n", - "bert/encoder/layer_1/attention/self/key/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/query/bias [128]\n", - "bert/encoder/layer_1/attention/self/query/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/value/bias [128]\n", - "bert/encoder/layer_1/attention/self/value/kernel [128, 128]\n", - "bert/encoder/layer_1/intermediate/dense/bias [512]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [128, 512]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_1/output/dense/bias [128]\n", - "bert/encoder/layer_1/output/dense/kernel [512, 128]\n", - "bert/pooler/dense/bias [128]\n", - "bert/pooler/dense/kernel [128, 128]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [128]\n", - "cls/predictions/transform/LayerNorm/gamma [128]\n", - "cls/predictions/transform/dense/bias [128]\n", - "cls/predictions/transform/dense/kernel [128, 128]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 128]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertTiny model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "89fd51e3-8fa8-4045-de21-ec90a4d515dd", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 128) 3906816 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 128) 65536 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 128) 256 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 128) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 128) 256 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 128) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 128) 198272 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 128) 198272 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 128) 0 ['transformer_layer_1[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 128) 16512 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 4,385,920\n", - "Trainable params: 4,385,920\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertTiny(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, 
-1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", 
- "source": [ - "## Load Bert Tiny from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "aea2b522-a267-4f9e-ffcc-e5160d4ad04d", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a669a92c-cb29-4673-e1cb-5f8ad7ab2c23", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertTiny checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertTiny(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a83056b5-9673-4e88-b81b-a92209c2305f", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0dc1fcf7-abf6-472b-acb3-5cafc8cf85cf", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "07cae6fa-240b-473b-9b10-3734b9da0593", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "c2b29fcbf8f814a0812e4ab89ef5c068 bert_tiny_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertTiny(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": 
"https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file From 7e1362f2d9abc665c2bac7a20432b03a5a214475 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Thu, 7 Mar 2024 16:18:37 -0800 Subject: [PATCH 29/70] Add `FalconTokenizer` (#1485) * Add FalconTokenizer. * Update checkpoint conversion script. * Address reviews. --- keras_nlp/models/__init__.py | 2 + keras_nlp/models/falcon/falcon_presets.py | 30 ++ keras_nlp/models/falcon/falcon_tokenizer.py | 117 +++++++ .../models/falcon/falcon_tokenizer_test.py | 62 ++++ .../convert_falcon_checkpoints.py | 303 +++++++++++------- 5 files changed, 397 insertions(+), 117 deletions(-) create mode 100644 keras_nlp/models/falcon/falcon_presets.py create mode 100644 keras_nlp/models/falcon/falcon_tokenizer.py create mode 100644 keras_nlp/models/falcon/falcon_tokenizer_test.py diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 692b51c4da..1abfc0dc84 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -80,6 +80,8 @@ ) from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( diff --git a/keras_nlp/models/falcon/falcon_presets.py b/keras_nlp/models/falcon/falcon_presets.py new file mode 100644 index 0000000000..b0bb6aa54e --- /dev/null +++ b/keras_nlp/models/falcon/falcon_presets.py @@ -0,0 +1,30 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon model preset configurations.""" + +backbone_presets = { + "falcon_refinedweb_1b_en": { + "metadata": { + "description": ( + "24-layer Falcon model (Falcon with 1B parameters), trained on " + "350B tokens of RefinedWeb dataset." + ), + "params": 1311625216, + "official_name": "Falcon", + "path": "falcon", + "model_card": "https://huggingface.co/tiiuae/falcon-rw-1b", + }, + "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1", + }, +} diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py new file mode 100644 index 0000000000..3201d27a63 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -0,0 +1,117 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.falcon.falcon_presets import backbone_presets +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.FalconTokenizer") +class FalconTokenizer(BytePairTokenizer): + """Falcon tokenizer based on BytePairTokenizer. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by Falcon + models and provides a `from_preset()` method to automatically download + a matching vocabulary for a Falcon preset. + + This tokenizer does not provide truncation or padding of inputs. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rule. If it is a string, + it should be the file path to merge rules. The merge rule file + should have one merge rule per line. Every merge rule contains + merge entities separated by a space. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.FalconTokenizer.from_preset("falcon_refinedweb_1b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
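+    # The vocabulary and merges below are toy values for illustration only;
+    # real usage would load the full files from a preset.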
+ vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges) + tokenizer("a quick fox.") + ``` + """ + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Falcon uses the same start as end token, i.e., "<|endoftext|>". + self.end_token = self.start_token = "<|endoftext|>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[self.end_token], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + if self.end_token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{self.end_token}'` in the provided " + f"`vocabulary`. Please provide `'{self.end_token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." + ) + + self.end_token_id = self.token_to_id(self.end_token) + self.start_token_id = self.end_token_id + self.pad_token_id = 0 + else: + self.end_token_id = None + self.start_token_id = None + self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/models/falcon/falcon_tokenizer_test.py b/keras_nlp/models/falcon/falcon_tokenizer_test.py new file mode 100644 index 0000000000..735bcac4b6 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_tokenizer_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest
+
+from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class FalconTokenizerTest(TestCase):
+    def setUp(self):
+        self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
+        self.vocab += ["<|endoftext|>"]
+        self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
+        self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
+        self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
+        self.merges += ["Ġai r", "Ġa i", "pla ne"]
+        self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
+        self.input_data = [
+            " airplane at airport<|endoftext|>",
+            " airplane airport",
+        ]
+
+    def test_tokenizer_basics(self):
+        self.run_preprocessing_layer_test(
+            cls=FalconTokenizer,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+        )
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            FalconTokenizer(vocabulary=["a", "b", "c"], merges=[])
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=FalconTokenizer,
+            preset="falcon_refinedweb_1b_en",
+            input_data=["The quick brown fox."],
+            expected_output=[[464, 2068, 7586, 21831, 13]],
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in FalconTokenizer.presets:
+            self.run_preset_test(
+                cls=FalconTokenizer,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
index 90a06503dc..fdbdffd670 100644
--- a/tools/checkpoint_conversion/convert_falcon_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py
@@ -11,51 +11,110 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Falcon weight conversion script.
+
+To run, install the CPU-only development environment and Hugging Face libraries:
+```
+pip install -r requirements.txt
+pip install transformers huggingface-cli
+```
+
+Log in to Hugging Face:
+```
+huggingface-cli login
+```
+
+Finally run this script to convert, validate and upload weights.
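+The `--preset` flag must match one of the keys in `PRESET_MAP` below.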
+``` +python tools/checkpoint_conversion/convert_falcon_checkpoints.py \ + --preset falcon_refinedweb_1b_en +``` +""" + +import json import os -import tempfile -import keras -import numpy as np -import tensorflow as tf -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" -from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +import absl # noqa: E402 +import huggingface_hub # noqa: E402 +import numpy as np # noqa: E402 +import torch # noqa: E402 +import transformers # noqa: E402 -keras.config.disable_traceback_filtering() +import keras_nlp # noqa: E402 +PRESET_MAP = { + "falcon_refinedweb_1b_en": "tiiuae/falcon-rw-1b", +} -def convert_checkpoints(hf_model): +EXTRACT_DIR = "./model" + +FLAGS = absl.flags.FLAGS +absl.flags.DEFINE_string( + "preset", + "falcon_refinedweb_1b_en", + f'Must be one of {",".join(PRESET_MAP.keys())}.', +) + + +def download_hf_model(hf_model_name): + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.bin"], + ignore_patterns=["onnx/*"], + local_dir=EXTRACT_DIR, + ) + + return hf_model_dir + + +def convert_model(hf_model): hf_config = hf_model.config.to_dict() - cfg = {} - cfg["vocabulary_size"] = hf_config["vocab_size"] - cfg["num_layers"] = hf_config["num_hidden_layers"] - cfg["num_attention_heads"] = hf_config["num_attention_heads"] - cfg["hidden_dim"] = hf_config["hidden_size"] - cfg["intermediate_dim"] = 4 * cfg["hidden_dim"] - cfg["feedforward_dropout_rate"] = hf_config["hidden_dropout"] - cfg["attention_dropout_rate"] = hf_config["attention_dropout"] + kwargs = {} + kwargs["vocabulary_size"] = hf_config["vocab_size"] + kwargs["num_layers"] = hf_config["num_hidden_layers"] + kwargs["num_attention_heads"] = hf_config["num_attention_heads"] + kwargs["hidden_dim"] = hf_config["hidden_size"] + kwargs["intermediate_dim"] = 4 * kwargs["hidden_dim"] + kwargs["feedforward_dropout_rate"] = hf_config["hidden_dropout"] + kwargs["attention_dropout_rate"] = hf_config["attention_dropout"] - keras_model = FalconBackbone(**cfg) + return keras_nlp.models.FalconBackbone(**kwargs) + +def convert_tokenizer(hf_model_dir): + tokenizer_file_path = os.path.join(hf_model_dir, "tokenizer.json") + with open(tokenizer_file_path) as tokenizer_file: + hf_tokenizer = json.load(tokenizer_file) + + vocab = hf_tokenizer["model"]["vocab"] + merges = hf_tokenizer["model"]["merges"] + return keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges) + + +def convert_weights(keras_model, hf_model): + hf_model.eval() hf_wts = hf_model.state_dict() - # transformer.word_embeddings.weight + # token_embedding. keras_model.get_layer("token_embedding").embeddings.assign( - hf_wts["transformer.word_embeddings.weight"] + hf_wts["word_embeddings.weight"] ) - for i in range(keras_model.num_layers): - # split key query value + for ilayer in range(keras_model.num_layers): + # Split key query value. 
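+        # The HF checkpoint stores query, key and value as a single fused
+        # projection. After the transpose, axis 0 is the input hidden dim
+        # (named `seq_length` below) and the remaining weights reshape to
+        # (num_attention_heads, 3, head_dim); the axis of size 3 indexes
+        # (query, key, value), which the slices below pull apart.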
fused_qkv = ( - hf_wts[f"transformer.h.{i}.self_attention.query_key_value.weight"] + hf_wts[f"h.{ilayer}.self_attention.query_key_value.weight"] .numpy() .T ) seq_length, _ = fused_qkv.shape - head_dim = cfg["hidden_dim"] // cfg["num_attention_heads"] + head_dim = keras_model.hidden_dim // keras_model.num_attention_heads fused_qkv = fused_qkv.reshape( - seq_length, cfg["num_attention_heads"], 3, head_dim + seq_length, keras_model.num_attention_heads, 3, head_dim ) query, key, value = ( fused_qkv[..., 0, :], @@ -64,9 +123,11 @@ def convert_checkpoints(hf_model): ) fused_bias = hf_wts[ - f"transformer.h.{i}.self_attention.query_key_value.bias" + f"h.{ilayer}.self_attention.query_key_value.bias" ].numpy() - fused_bias = fused_bias.reshape(cfg["num_attention_heads"], 3, head_dim) + fused_bias = fused_bias.reshape( + keras_model.num_attention_heads, 3, head_dim + ) query_bias, key_bias, value_bias = ( fused_bias[..., 0, :], fused_bias[..., 1, :], @@ -74,132 +135,118 @@ def convert_checkpoints(hf_model): ) # TODO: check if bias is true before assigning bias. - # transformer.h.0.self_attention.query_key_value.weight - # transformer.h.0.self_attention.query_key_value.bias + # Attention/query. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.query_dense.kernel.assign(query) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.query_dense.bias.assign(query_bias) + # Attention/key. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.key_dense.kernel.assign(key) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.key_dense.bias.assign(key_bias) + # Attention/value. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.value_dense.kernel.assign(value) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.value_dense.bias.assign(value_bias) - # transformer.h.0.self_attention.dense.weight - # transformer.h.0.self_attention.dense.bias + # Attention/dense. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.output_dense.kernel.assign( - hf_wts[f"transformer.h.{i}.self_attention.dense.weight"].T.numpy() + hf_wts[f"h.{ilayer}.self_attention.dense.weight"].T.numpy() ) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).attention_layer.output_dense.bias.assign( - hf_wts[f"transformer.h.{i}.self_attention.dense.bias"].numpy() + hf_wts[f"h.{ilayer}.self_attention.dense.bias"].numpy() ) - # transformer.h.0.mlp.dense_h_to_4h.weight - # transformer.h.0.mlp.dense_h_to_4h.bias + # MLP/dense_h_to_4h. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).dense_h_to_4h.kernel.assign( - hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"].T.numpy() + hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.weight"].T.numpy() ) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).dense_h_to_4h.bias.assign( - hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"].numpy() + hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.bias"].numpy() ) - # transformer.h.0.mlp.dense_4h_to_h.weight - # transformer.h.0.mlp.dense_4h_to_h.bias + # MLP/dense_4h_to_h. 
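+        # Together with dense_h_to_4h above, this forms the feedforward
+        # block with the 4x expansion assumed by
+        # `intermediate_dim = 4 * hidden_dim` in `convert_model`.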
keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).dense_4h_to_h.kernel.assign( - hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"].T.numpy() + hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.weight"].T.numpy() ) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).dense_4h_to_h.bias.assign( - hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"].numpy() + hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.bias"].numpy() ) - # transformer.h.0.input_layernorm.weight - # transformer.h.0.input_layernorm.bias + # input_layernorm. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).input_layernorm.gamma.assign( - hf_wts[f"transformer.h.{i}.input_layernorm.weight"] + hf_wts[f"h.{ilayer}.input_layernorm.weight"] ) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).input_layernorm.beta.assign( - hf_wts[f"transformer.h.{i}.input_layernorm.bias"] + hf_wts[f"h.{ilayer}.input_layernorm.bias"] ) - # transformer.h.0.post_attention_layernorm.weight - # transformer.h.0.post_attention_layernorm.bias + # post_attention_layernorm. keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).post_attention_layernorm.gamma.assign( - hf_wts[f"transformer.h.{i}.post_attention_layernorm.weight"].numpy() + hf_wts[f"h.{ilayer}.post_attention_layernorm.weight"].numpy() ) keras_model.get_layer( - f"transformer_layer_{i}" + f"transformer_layer_{ilayer}" ).post_attention_layernorm.beta.assign( - hf_wts[f"transformer.h.{i}.post_attention_layernorm.bias"].numpy() + hf_wts[f"h.{ilayer}.post_attention_layernorm.bias"].numpy() ) - # transformer.ln_f.weight - # transformer.ln_f.bias + # final_layernorm. keras_model.get_layer("final_layernorm").gamma.assign( - hf_wts["transformer.ln_f.weight"].numpy() + hf_wts["ln_f.weight"].numpy() ) keras_model.get_layer("final_layernorm").beta.assign( - hf_wts["transformer.ln_f.bias"].numpy() + hf_wts["ln_f.bias"].numpy() ) - # TODO: Assign lm_head weights for CausalLM. - # # lm_head.weight - # keras_model.get_layer("lm_head").kernel.assign( - # hf_wts["lm_head.weight"].T.numpy() - # ) - - # Save the model. - print("Save KerasNLP model weights.") - temp_dir = tempfile.mkdtemp() - keras_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) - - return keras_model - -def check_output(keras_model, hf_model, hf_model_name): - sample_text = ["I am so happy today!"] - hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name) - hf_tokenizer.pad_token = hf_tokenizer.eos_token - hf_sample_input = hf_tokenizer( - sample_text, padding="max_length", return_tensors="pt" - ) - sample_input = { - "token_ids": tf.constant(hf_sample_input["input_ids"].numpy()), - "padding_mask": tf.constant(hf_sample_input["attention_mask"].numpy()), +def validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, +): + input_str = ["the quick brown fox ran, galloped and jumped."] + + # KerasNLP model. + token_ids = torch.tensor(keras_tokenizer(input_str)) + padding_mask = token_ids != 3 + keras_model_input = { + "token_ids": token_ids, + "padding_mask": padding_mask, } - print("token_ids: ", sample_input["token_ids"][0, :7]) - print("padding_mask", sample_input["padding_mask"][0, :7]) + keras_model_outputs = keras_model.predict(keras_model_input) - keras_output = keras_model.predict(sample_input) + # HuggingFace model. 
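+    # A forward hook (registered below) captures the final hidden states
+    # after `ln_f`, so both models are compared at the same point in the
+    # network rather than through a task head.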
+    hf_model_input = hf_tokenizer(input_str, return_tensors="pt")
 
     activation = {}
 
@@ -209,30 +256,52 @@ def hook(hf_model, input, output):
 
         return hook
 
-    hf_model.transformer.register_forward_hook(
-        get_activation("transformer.ln_f")
-    )
-    hf_model(**hf_sample_input)
-    hf_output = activation["transformer.ln_f"]
-    print("Keras shape: ", keras_output.shape)
-    print("HF shape: ", hf_output.shape)
-
-    print("KerasNLP output:", keras_output[0, 1, :5])
-    print("HF output:", hf_output[0, 1, :5])
-    print(
-        "Difference:",
-        np.mean(
-            abs(keras_output[:, :6, :] - hf_output.detach().numpy()[:, :6, :])
-        ),
-    )
+    hf_model.register_forward_hook(get_activation("ln_f"))
+    hf_model(**hf_model_input)
+    hf_model_outputs = activation["ln_f"].detach().numpy()
+
+    # Comparing the outputs.
+    print("🔶 KerasNLP tokens ids:", keras_model_input["token_ids"])
+    print("🔶 HF tokens ids:", hf_model_input["input_ids"])
+    print("🔶 KerasNLP output:", keras_model_outputs[0, 1, :10])
+    print("🔶 HF output:", hf_model_outputs[0, 1, :10])
+    print(
+        "🔶 Difference:",
+        np.mean(np.abs(keras_model_outputs - hf_model_outputs)),
+    )
+
+
+def main(_):
+    preset = FLAGS.preset
+    print(f"✅ Converting {preset}")
+
+    hf_model_name = PRESET_MAP[preset]
+    hf_model_dir = download_hf_model(hf_model_name)
+    print("✅ Huggingface model downloaded from hub")
 
-def main():
-    hf_model_name = "tiiuae/falcon-rw-1b"
-    hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name)
-    keras_model = convert_checkpoints(hf_model)
-    check_output(keras_model, hf_model, hf_model_name)
+    hf_model = transformers.FalconModel.from_pretrained(hf_model_dir)
+    # Falcon uses GPT2 tokenizer.
+    hf_tokenizer = transformers.GPT2TokenizerFast.from_pretrained(hf_model_dir)
+    print("✅ Huggingface model loaded")
+
+    keras_model = convert_model(hf_model)
+    keras_tokenizer = convert_tokenizer(hf_model_dir)
+    print("✅ Keras model loaded")
+
+    convert_weights(keras_model, hf_model)
+    print("✅ Weights converted")
+
+    validate_output(
+        hf_model,
+        keras_model,
+        hf_tokenizer,
+        keras_tokenizer,
+    )
+    print("✅ Numerics validated")
+
+    keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+    keras_nlp.src.utils.preset_utils.save_to_preset(
+        keras_tokenizer, preset, config_filename="tokenizer.json"
+    )
+    print("✅ Preset saved")
 
 
 if __name__ == "__main__":
-    main()
+    absl.app.run(main)

From 184822437d21b358fa2e7193bc755ca2f7211ea7 Mon Sep 17 00:00:00 2001
From: Samaneh Saadat
Date: Thu, 7 Mar 2024 17:21:08 -0800
Subject: [PATCH 30/70] Add Falcon Preprocessor. (#1498)

---
 .../falcon/falcon_causal_lm_preprocessor.py   | 178 ++++++++++++++++
 .../falcon_causal_lm_preprocessor_test.py     |  94 +++++++++
 .../models/falcon/falcon_preprocessor.py      | 195 ++++++++++++++++++
 .../models/falcon/falcon_preprocessor_test.py |  80 +++++++
 4 files changed, 547 insertions(+)
 create mode 100644 keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py
 create mode 100644 keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py
 create mode 100644 keras_nlp/models/falcon/falcon_preprocessor.py
 create mode 100644 keras_nlp/models/falcon/falcon_preprocessor_test.py

diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..61afb9b5a7
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py
@@ -0,0 +1,178 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.FalconCausalLMPreprocessor")
+class FalconCausalLMPreprocessor(FalconPreprocessor):
+    """Falcon Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.FalconCausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_nlp.models.FalconCausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.FalconTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset(
+        "falcon_refinedweb_1b_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+ ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`FalconCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + padding_mask = ops.convert_to_numpy(padding_mask) + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..5e812259e2 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.falcon.falcon_causal_lm_preprocessor import ( + FalconCausalLMPreprocessor, +) +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FalconCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=FalconCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights. + ), + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = FalconCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [6, 1, 3, 4, 2, 5, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in FalconCausalLMPreprocessor.presets: + self.run_preset_test( + cls=FalconCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_preprocessor.py b/keras_nlp/models/falcon/falcon_preprocessor.py new file mode 100644 index 0000000000..b37d641467 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_preprocessor.py @@ -0,0 +1,195 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.falcon.falcon_presets import backbone_presets
+from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.FalconPreprocessor")
+class FalconPreprocessor(Preprocessor):
+    """Falcon preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do 2 things:
+
+    - Tokenize the inputs using the `tokenizer`.
+    - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
+        be passed directly to a `keras_nlp.models.FalconBackbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    The call method of this layer accepts three arguments, `x`, `y`, and
+    `sample_weight`. `x` can be a python string or tensor representing a single
+    segment, a list of python strings representing a batch of single segments,
+    or a list of tensors representing multiple segments to be packed together.
+    `y` and `sample_weight` are both optional, can have any format, and will be
+    passed through unaltered.
+
+    `FalconPreprocessor` forces the input to have only one segment, as Falcon is
+    mainly used for generation tasks. For tasks having multi-segment inputs
+    like "glue/mnli", please use a model designed for classification purposes
+    such as BERT or RoBERTa.
+
+    Args:
+        tokenizer: A `keras_nlp.models.FalconTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.FalconPreprocessor.from_preset("falcon_refinedweb_1b_en")
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of single sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Custom vocabulary.
+    features = ["a quick fox.", "a fox quick."]
+    vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    merges += ["Ġ f", "o x", "Ġf ox"]
+    tokenizer = keras_nlp.models.FalconTokenizer(
+        vocabulary=vocab,
+        merges=merges,
+    )
+    preprocessor = keras_nlp.models.FalconPreprocessor(tokenizer=tokenizer)
+    preprocessor("The quick brown fox jumped.")
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.FalconPreprocessor.from_preset("falcon_refinedweb_1b_en")
+
+    text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((text, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(text)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=2048,
+        add_start_token=True,
+        add_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        x = convert_inputs_to_list_of_tensor_segments(x)
+        if len(x) != 1:
+            raise ValueError(
+                "Falcon requires each input feature to contain only "
+                f"one segment, but received {len(x)}. If you are using Falcon "
+                "for a multi-segment classification task, please refer to "
+                "classification models like BERT or RoBERTa."
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) + + @classproperty + def tokenizer_cls(cls): + return FalconTokenizer diff --git a/keras_nlp/models/falcon/falcon_preprocessor_test.py b/keras_nlp/models/falcon/falcon_preprocessor_test.py new file mode 100644 index 0000000000..7676062287 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_preprocessor_test.py @@ -0,0 +1,80 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FalconPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=FalconPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = FalconPreprocessor( + tokenizer=FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ), + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "airplane at airport" + preprocessor = FalconPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 6]) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in FalconPreprocessor.presets: + self.run_preset_test( + cls=FalconPreprocessor, + preset=preset, + input_data=self.input_data, + ) From 49c243b520225ef6b249d527ba47c693605cde51 Mon Sep 17 00:00:00 2001 From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com> Date: Fri, 8 Mar 2024 04:44:12 +0200 Subject: [PATCH 31/70] Rename 176B presets (#1496) --- tools/checkpoint_conversion/convert_bloom_checkpoints.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py index d8a36b3912..a9e833d1b0 100644 --- a/tools/checkpoint_conversion/convert_bloom_checkpoints.py +++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py @@ -34,7 +34,7 @@ "bloom_1.7b_multi": "bigscience/bloom-1b7", "bloom_3b_multi": "bigscience/bloom-3b", "bloom_7b_multi": "bigscience/bloom-7b1", - "bloom_multi": "bigscience/bloom", + "bloom_176b_multi": "bigscience/bloom", # Multitask finetuned on xP3 (Crosslingual Public Pool of Prompts) https://huggingface.co/datasets/bigscience/xP3 # xP3 is a mixture of 13 training tasks in 46 languages with English prompts "bloomz_560m_multi": "bigscience/bloomz-560m", @@ -42,17 +42,17 @@ "bloomz_1.7b_multi": "bigscience/bloomz-1b7", "bloomz_3b_multi": "bigscience/bloomz-3b", "bloomz_7b_multi": "bigscience/bloomz-7b1", - "bloomz_multi": "bigscience/bloomz", + "bloomz_176b_multi": "bigscience/bloomz", # Multitask finetuned on xP3mt # (Crosslingual Public Pool of Prompts machine-translated) https://huggingface.co/datasets/bigscience/xP3 # xP3mt is Mixture of 13 training tasks in 46 languages with 
prompts in 20 # languages (machine-translated from English) "bloomz_7b_mt": "bigscience/bloomz-7b1-mt", - "bloomz_mt": "bigscience/bloomz-mt", + "bloomz_176b_mt": "bigscience/bloomz-mt", # Multitask finetuned on P3 (Public Pool of Prompts) https://huggingface.co/datasets/Muennighoff/P3 # xP3 is a mixture of 8 training tasks with English-only prompts "bloomz_7b_p3": "bigscience/bloomz-7b1-p3", - "bloomz_p3": "bigscience/bloomz-p3", + "bloomz_176b_p3": "bigscience/bloomz-p3", } EXTRACT_DIR = "./model" From 865034da7cf3f5b9a0fdde7d457e1def982beb2f Mon Sep 17 00:00:00 2001 From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:16:19 +0200 Subject: [PATCH 32/70] Add bloom presets (#1501) --- keras_nlp/models/bloom/bloom_presets.py | 99 ++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 4 deletions(-) diff --git a/keras_nlp/models/bloom/bloom_presets.py b/keras_nlp/models/bloom/bloom_presets.py index 7d24c04aa5..134de5173d 100644 --- a/keras_nlp/models/bloom/bloom_presets.py +++ b/keras_nlp/models/bloom/bloom_presets.py @@ -17,14 +17,105 @@ "bloom_560m_multi": { "metadata": { "description": ( - "24-layer Bloom model. trained on 45 natural languages and " - "12 programming languages." + "24-layer Bloom model with hidden dimension of 1024. " + "trained on 45 natural languages and 12 programming languages." ), - "params": 816115712, + "params": 559214592, "official_name": "BLOOM", "path": "bloom", - "model_card": "https://huggingface.co/bigscience/bloom", + "model_card": "https://huggingface.co/bigscience/bloom-560m", }, "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/3", }, + "bloom_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 1065314304, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/1", + }, + "bloom_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 1722408960, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/1", + }, + "bloom_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 3002557440, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/1", + }, + "bloomz_560m_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1024. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 559214592, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-560m", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/1", + }, + "bloomz_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "finetuned on crosslingual task mixture (xP3) dataset." 
+ ), + "params": 1065314304, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/1", + }, + "bloomz_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 1722408960, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/1", + }, + "bloomz_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 3002557440, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/1", + }, } From 786aa946d41754be4f16a6c739f86c03b3f8d016 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 11 Mar 2024 11:26:31 -0700 Subject: [PATCH 33/70] Create workflow for auto assignment of issues and for stale issues (#1495) * Create auto-assignment.js * Create auto-assignment.yml * Create stale-issue-pr.yml * Minor changes to auto_labler --- .github/workflows/auto-assignment.yml | 21 ++++++++ .github/workflows/scripts/auto-assignment.js | 43 +++++++++++++++++ .github/workflows/scripts/labeler.js | 10 ++-- .github/workflows/stale-issue-pr.yml | 50 ++++++++++++++++++++ 4 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/auto-assignment.yml create mode 100644 .github/workflows/scripts/auto-assignment.js create mode 100644 .github/workflows/stale-issue-pr.yml diff --git a/.github/workflows/auto-assignment.yml b/.github/workflows/auto-assignment.yml new file mode 100644 index 0000000000..de72da8ba2 --- /dev/null +++ b/.github/workflows/auto-assignment.yml @@ -0,0 +1,21 @@ +name: auto-assignment +on: + issues: + types: + - opened + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/github-script@v7 + with: + script: | + const script = require('./\.github/workflows/scripts/auto-assignment.js') + script({github, context}) diff --git a/.github/workflows/scripts/auto-assignment.js b/.github/workflows/scripts/auto-assignment.js new file mode 100644 index 0000000000..176b305f39 --- /dev/null +++ b/.github/workflows/scripts/auto-assignment.js @@ -0,0 +1,43 @@ +/** Automatically assign issues and PRs to users in the `assigneesList` + * on a rotating basis. + + @param {!object} + GitHub objects can call GitHub APIs using their built-in library functions. + The context object contains issue and PR details. +*/ + +module.exports = async ({ github, context }) => { + let issueNumber; + let assigneesList; + // Is this an issue? If so, assign the issue number. Otherwise, assign the PR number. + if (context.payload.issue) { + //assignee List for issues. + assigneesList = ["SuryanarayanaY", "sachinprasadhs"]; + issueNumber = context.payload.issue.number; + } else { + //assignee List for PRs. 
+    assigneesList = ["mattdangerw"];
+    issueNumber = context.payload.number;
+  }
+  console.log("assignee list", assigneesList);
+  console.log("entered auto assignment for this issue: ", issueNumber);
+  if (!assigneesList.length) {
+    console.log("No assignees found for this repo.");
+    return;
+  }
+  let noOfAssignees = assigneesList.length;
+  let selection = issueNumber % noOfAssignees;
+  let assigneeForIssue = assigneesList[selection];
+
+  console.log(
+    "issue Number = ",
+    issueNumber + " , assigning to: ",
+    assigneeForIssue
+  );
+  return github.rest.issues.addAssignees({
+    issue_number: context.issue.number,
+    owner: context.repo.owner,
+    repo: context.repo.repo,
+    assignees: [assigneeForIssue],
+  });
+};
diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js
index 7240113cc3..aa4178645e 100644
--- a/.github/workflows/scripts/labeler.js
+++ b/.github/workflows/scripts/labeler.js
@@ -23,16 +23,20 @@ You may obtain a copy of the License at
 module.exports = async ({ github, context }) => {
     const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title
-    const issue_discription = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body
+    let issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body
     const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number
     const keyword_label = {
         gemma:'Gemma'
     }
     const labelsToAdd = []
-    console.log(issue_title,issue_discription,issue_number)
+    console.log(issue_title,issue_description,issue_number)
+    if (issue_description==null)
+    {
+      issue_description = ''
+    }
     for(const [keyword, label] of Object.entries(keyword_label)){
-      if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_discription.toLowerCase().indexOf(keyword) !=-1 ){
+      if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){
         console.log(`'${keyword}'keyword is present inside the title or description. Pushing label '${label}' to row.`)
         labelsToAdd.push(label)
       }
diff --git a/.github/workflows/stale-issue-pr.yml b/.github/workflows/stale-issue-pr.yml
new file mode 100644
index 0000000000..034fb4c266
--- /dev/null
+++ b/.github/workflows/stale-issue-pr.yml
@@ -0,0 +1,50 @@
+name: Close inactive issues
+on:
+  schedule:
+    - cron: "30 1 * * *"
+jobs:
+  close-issues:
+    runs-on: ubuntu-latest
+    permissions:
+      issues: write
+      pull-requests: write
+    steps:
+      - name: Awaiting response issues
+        uses: actions/stale@v9
+        with:
+          days-before-issue-stale: 14
+          days-before-issue-close: 14
+          stale-issue-label: "stale"
+          # reason for closed the issue default value is not_planned
+          close-issue-reason: completed
+          only-labels: "stat:awaiting response from contributor"
+          stale-issue-message: >
+            This issue is stale because it has been open for 14 days with no activity.
+            It will be closed if no further activity occurs. Thank you.
+          # List of labels to remove when issues/PRs unstale.
+          labels-to-remove-when-unstale: "stat:awaiting response from contributor"
+          close-issue-message: >
+            This issue was closed because it has been inactive for 28 days.
+            Please reopen if you'd like to work on this further.
+          days-before-pr-stale: 14
+          days-before-pr-close: 14
+          stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you."
+ close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further." + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Contribution issues + uses: actions/stale@v9 + with: + days-before-issue-stale: 180 + days-before-issue-close: 365 + stale-issue-label: "stale" + # reason for closed the issue default value is not_planned + close-issue-reason: not_planned + any-of-labels: "stat:contributions welcome,good first issue" + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:contributions welcome,good first issue" + stale-issue-message: > + This issue is stale because it has been open for 180 days with no activity. + It will be closed if no further activity occurs. Thank you. + close-issue-message: > + This issue was closed because it has been inactive for more than 1 year. + repo-token: ${{ secrets.GITHUB_TOKEN }} From 8698f846758dd7fda9905c60a4ba4f2fa13b0b8c Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:33:05 -0500 Subject: [PATCH 34/70] Update requirements to TF 2.16 GA (#1503) --- requirements-jax-cuda.txt | 4 ++-- requirements-tensorflow-cuda.txt | 4 ++-- requirements-torch-cuda.txt | 4 ++-- requirements.txt | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 10d07dffce..2d53a76c87 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 7cc2e705e6..14f1441924 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. -tensorflow[and-cuda]==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 1bbe6a2e76..89362bb846 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index e7cc934b17..f1e0b31956 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch. --extra-index-url https://download.pytorch.org/whl/cpu From 29a87cbfd1bb770a566ca3c88ff322c1677ddd3e Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:26:09 +0000 Subject: [PATCH 35/70] Expose Task and Backbone (#1506) These are already exposed on KerasCV, and I think it is time to also expose these in KerasNLP. 
This will give us a class to document common model functionality to all backbones such as `enable_lora` and `token_embedding` on keras.io. It can also open up a path for writing a custom architecture outside the library itself. --- keras_nlp/models/__init__.py | 2 ++ keras_nlp/models/backbone.py | 3 ++- keras_nlp/models/task.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 1abfc0dc84..033a9dc874 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -20,6 +20,7 @@ ) from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.models.backbone import Backbone from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM @@ -130,6 +131,7 @@ from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.models.task import Task from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 867616da69..bfdc8207ad 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class @@ -20,7 +21,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Backbone") class Backbone(keras.Model): def __init__(self, *args, dtype=None, **kwargs): super().__init__(*args, **kwargs) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 0656d2194e..9957f6546f 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -16,6 +16,7 @@ from rich import markup from rich import table as rich_table +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.keras_utils import print_msg @@ -26,7 +27,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Task") class Task(PipelineModel): """Base class for Task models.""" From 7e3dfc8fabe3082ba31593d086951e0ae5ac5a84 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:29:18 +0000 Subject: [PATCH 36/70] Clean up and add our gemma conversion script (#1493) * Clean up and add our gemma conversion script From flax -> keras. Useful to have as reference. 
* Fix comments * Convert to bfloat16 weights * Review comment --- .../convert_gemma_checkpoints.py | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 tools/checkpoint_conversion/convert_gemma_checkpoints.py diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py new file mode 100644 index 0000000000..ed81e023d4 --- /dev/null +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -0,0 +1,224 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert Gemma flax checkpoints to the Keras format. + +Setup: +pip install -r requirements.txt +pip install git+https://github.com/google-deepmind/gemma.git +python pip_build.py --install + +Usage: +cd tools/checkpoint_conversion +python convert_gemma_checkpoints.py --preset gemma_2b_en +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" +# No GPU for conversion, makes memory management easier. +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import kagglehub # noqa: E402 +import keras # noqa: E402 +import numpy as np # noqa: E402 +import sentencepiece # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 +from gemma import params as params_lib # noqa: E402 +from gemma import sampler as sampler_lib # noqa: E402 +from gemma import transformer as transformer_lib # noqa: E402 + +import keras_nlp # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "gemma_2b_en": "google/gemma/flax/2b", + "gemma_7b_en": "google/gemma/flax/7b", + "gemma_instruct_2b_en": "google/gemma/flax/2b-it", + "gemma_instruct_7b_en": "google/gemma/flax/7b-it", +} + + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) + + +def download_flax_model(handle): + return kagglehub.model_download(handle) + + +def convert_model(flax_config, vocab_size): + return keras_nlp.models.GemmaBackbone( + vocabulary_size=vocab_size, + num_layers=flax_config.num_layers, + num_query_heads=flax_config.num_heads, + num_key_value_heads=flax_config.num_kv_heads, + hidden_dim=flax_config.embed_dim, + intermediate_dim=flax_config.hidden_dim * 2, + head_dim=flax_config.head_dim, + ) + + +def convert_tokenizer(proto_path): + return keras_nlp.models.GemmaTokenizer(proto=proto_path) + + +def convert_weights(keras_model, flax_config, flax_params): + # Chomp the embedding weights. Upstream pads for TPU efficiency, but this + # leads to weird gotchas (you need to disregard part of your output logits). 
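+    # For example, the checkpoint table can be (padded_vocab, hidden_dim)
+    # with padded_vocab larger than the tokenizer's vocabulary; only the
+    # first `keras_model.vocabulary_size` rows are copied over.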
+    embeddings = flax_params["transformer"]["embedder"]["input_embedding"]
+    embeddings = np.asarray(embeddings[: keras_model.vocabulary_size, :])
+    keras_model.get_layer("token_embedding").set_weights([embeddings])
+    keras_model.get_layer("final_normalization").set_weights(
+        [np.asarray(flax_params["transformer"]["final_norm"]["scale"])]
+    )
+    for i in range(flax_config.num_layers):
+        flax_layer_name = f"layer_{i}"
+        keras_block = keras_model.get_layer(f"decoder_block_{i}")
+
+        flax_block = flax_params["transformer"][flax_layer_name]
+        keras_block.pre_attention_norm.set_weights(
+            [flax_block["pre_attention_norm"]["scale"]]
+        )
+        keras_block.pre_ffw_norm.set_weights(
+            [flax_block["pre_ffw_norm"]["scale"]]
+        )
+
+        keras_block.gating_ffw.set_weights(
+            [flax_block["mlp"]["gating_einsum"][0]]
+        )
+        keras_block.gating_ffw_2.set_weights(
+            [flax_block["mlp"]["gating_einsum"][1]]
+        )
+        keras_block.ffw_linear.set_weights([flax_block["mlp"]["linear"]])
+
+        attn_block = flax_block["attn"]
+        if flax_config.num_heads != flax_config.num_kv_heads:
+            # MQA.
+            keras_block.attention.query_dense.kernel.assign(
+                np.asarray(attn_block["q_einsum"]["w"][:, :, :])
+            )
+            keras_block.attention.key_dense.kernel.assign(
+                np.asarray(attn_block["kv_einsum"]["w"][0, :, :, :])
+            )
+            keras_block.attention.value_dense.kernel.assign(
+                np.asarray(attn_block["kv_einsum"]["w"][1, :, :, :])
+            )
+        else:
+            # MHA.
+            keras_block.attention.query_dense.kernel.assign(
+                np.asarray(attn_block["qkv_einsum"]["w"][0, :, :, :])
+            )
+            keras_block.attention.key_dense.kernel.assign(
+                np.asarray(attn_block["qkv_einsum"]["w"][1, :, :, :])
+            )
+            keras_block.attention.value_dense.kernel.assign(
+                np.asarray(attn_block["qkv_einsum"]["w"][2, :, :, :])
+            )
+        keras_block.attention.output_dense.kernel.assign(
+            flax_block["attn"]["attn_vec_einsum"]["w"]
+        )
+
+
+def validate_output(
+    keras_model,
+    keras_tokenizer,
+    flax_params,
+    flax_tokenizer,
+):
+    input_str = "What is Keras?"
+    length = 32
+
+    # KerasNLP
+    preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(keras_tokenizer)
+    gemma_lm = keras_nlp.models.GemmaCausalLM(
+        backbone=keras_model,
+        preprocessor=preprocessor,
+    )
+    keras_output = gemma_lm.generate([input_str], max_length=length)
+    keras_output = keras_output[0]
+
+    # Flax
+    transformer_config = transformer_lib.TransformerConfig.from_params(
+        flax_params,
+        cache_size=length,
+    )
+    transformer = transformer_lib.Transformer(transformer_config)
+    sampler = sampler_lib.Sampler(
+        transformer=transformer,
+        vocab=flax_tokenizer,
+        params=flax_params["transformer"],
+    )
+    flax_output = sampler(
+        input_strings=[input_str],
+        total_generation_steps=length - 5,  # Length of "What is Keras?"
+    )
+    flax_output = input_str + flax_output.text[0]
+
+    # Comparing the outputs.
+    print("🔶 KerasNLP output:", keras_output)
+    print("🔶 Flax output:", flax_output)
+
+
+def main(_):
+    preset = FLAGS.preset
+
+    assert (
+        preset in PRESET_MAP.keys()
+    ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}'
+
+    print(f"🏃 Converting {preset}")
+
+    # Currently all flax weights are bfloat16 (and have much faster download
+    # times for it). We follow suit with Keras weights.
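+    # (Setting floatx before building the Keras model means its variables
+    # are created directly in bfloat16.)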
+    keras.config.set_floatx("bfloat16")
+
+    handle = PRESET_MAP[preset]
+    flax_dir = download_flax_model(handle)
+    proto_path = flax_dir + "/tokenizer.model"
+    print("✅ Flax model downloaded from kaggle")
+
+    variant = handle.split("/")[-1]
+    flax_tokenizer = sentencepiece.SentencePieceProcessor()
+    flax_tokenizer.Load(proto_path)
+    flax_params = params_lib.load_and_format_params(flax_dir + "/" + variant)
+    flax_config = transformer_lib.TransformerConfig.from_params(flax_params)
+    print("✅ Flax model loaded")
+
+    keras_tokenizer = convert_tokenizer(proto_path)
+    vocab_size = keras_tokenizer.vocabulary_size()
+    keras_model = convert_model(flax_config, vocab_size)
+    print("✅ Keras model loaded")
+
+    convert_weights(keras_model, flax_config, flax_params)
+    print("✅ Weights converted")
+
+    validate_output(keras_model, keras_tokenizer, flax_params, flax_tokenizer)
+    print("✅ Output validated")
+
+    keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+    keras_nlp.src.utils.preset_utils.save_to_preset(
+        keras_tokenizer, preset, config_filename="tokenizer.json"
+    )
+    print(f"🏁 Preset saved to ./{preset}")
+
+
+if __name__ == "__main__":
+    app.run(main)

From 8c941139b350e50065713fb176a97f3ef941ea5b Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Mon, 11 Mar 2024 19:32:19 -0500
Subject: [PATCH 37/70] Don't auto-update JAX GPU (#1507)

* Don't auto-update JAX GPU

* Ignore jax GPU updates
---
 .github/dependabot.yml           | 3 +++
 requirements-jax-cuda.txt        | 2 +-
 requirements-tensorflow-cuda.txt | 2 +-
 requirements-torch-cuda.txt      | 2 +-
 requirements.txt                 | 2 +-
 5 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 0df37b1230..eb7a6ac0c5 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -21,3 +21,6 @@ updates:
       python:
         patterns:
           - "*"
+    ignore:
+      # ignore all updates for JAX GPU due to cuda version issue
+      - dependency-name: "jax[cuda12_pip]"
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index 2d53a76c87..2ded131217 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -1,6 +1,6 @@
 # Tensorflow cpu-only version.
 tensorflow-cpu~=2.16.1 # Pin to TF 2.16
-tensorflow-text~=2.16.0
+tensorflow-text~=2.16.1
 
 # Torch cpu-only version.
 --extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt
index 14f1441924..5426beb5a3 100644
--- a/requirements-tensorflow-cuda.txt
+++ b/requirements-tensorflow-cuda.txt
@@ -1,6 +1,6 @@
 # Tensorflow with cuda support.
 tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16
-tensorflow-text~=2.16.0
+tensorflow-text~=2.16.1
 
 # Torch cpu-only version.
 --extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt
index 89362bb846..43dc4c5ef5 100644
--- a/requirements-torch-cuda.txt
+++ b/requirements-torch-cuda.txt
@@ -1,6 +1,6 @@
 # Tensorflow cpu-only version.
 tensorflow-cpu~=2.16.1 # Pin to TF 2.16
-tensorflow-text~=2.16.0
+tensorflow-text~=2.16.1
 
 # Torch with cuda support.
 --extra-index-url https://download.pytorch.org/whl/cu121
diff --git a/requirements.txt b/requirements.txt
index f1e0b31956..8578a4199b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 # Tensorflow.
 tensorflow-cpu~=2.16.1 # Pin to TF 2.16
-tensorflow-text~=2.16.0
+tensorflow-text~=2.16.1
 
 # Torch.
--extra-index-url https://download.pytorch.org/whl/cpu From 81dd7b5e98a58109dc7d47cc70420abcb5db1714 Mon Sep 17 00:00:00 2001 From: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:20:35 -0400 Subject: [PATCH 38/70] Keep rope at float32 precision (#1497) * Keep rope at float32 precision * Carry out all of RoPE in float32 * Formatting * Cleanup * Do not cast x --- keras_nlp/models/gemma/gemma_attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py index 80c2ac6a63..e01c1f8ce4 100644 --- a/keras_nlp/models/gemma/gemma_attention.py +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -94,13 +94,14 @@ def _apply_rope(self, x, positions): # TODO: refactor to use RotaryEmbedding layer? max_wavelength = 10000 x_shape = ops.shape(x) - freq_exponents = (2.0 / x_shape[-1]) * ops.cast( - ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype + freq_exponents = (2.0 / x_shape[-1]) * ops.arange( + x_shape[-1] // 2, dtype="float32" ) timescale = max_wavelength**freq_exponents radians = positions[..., None] / timescale[None, None, :] radians = radians[..., None, :] - sin, cos = ops.sin(radians), ops.cos(radians) + sin = ops.cast(ops.sin(radians), self.compute_dtype) + cos = ops.cast(ops.cos(radians), self.compute_dtype) x1, x2 = ops.split(x, 2, axis=-1) # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA # compilation on jax. We should be able to remove this once the @@ -156,10 +157,9 @@ def call( ): seq_len = ops.shape(x)[1] start_index = cache_update_index - positions = ops.cast( - ops.arange(seq_len, dtype="float32"), self.compute_dtype - ) - positions = positions + ops.cast(start_index, self.compute_dtype) + positions = ops.arange(seq_len, dtype="float32") + + positions = positions + ops.cast(start_index, "float32") query = self.query_dense(x) query = self._apply_rope(query, positions) From 0b0305a2ec4e71fee9d20d74129d268992cdefe9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:11:56 -0700 Subject: [PATCH 39/70] Bump the python group with 2 updates (#1509) Bumps the python group with 2 updates: torch and torchvision. Updates `torch` from 2.1.2 to 2.2.1+cu121 Updates `torchvision` from 0.16.2 to 0.17.1+cu121 --- updated-dependencies: - dependency-name: torch dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-torch-cuda.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 43dc4c5ef5..050dd85b1c 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -4,8 +4,8 @@ tensorflow-text~=2.16.1 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 -torch==2.1.2 -torchvision==0.16.2 +torch==2.2.1+cu121 +torchvision==0.17.1+cu121 # Jax cpu-only version. 
jax[cpu]

From f29aff8230c1d4f7780b0e1ffda8bb3ebda16948 Mon Sep 17 00:00:00 2001
From: Tirth Patel
Date: Wed, 13 Mar 2024 14:06:01 -0700
Subject: [PATCH 40/70] Fixes for the LLaMA backbone + add dropout (#1499)

* Fixes for the LLaMA backbone + add dropout

* Address review comments

  CachedLlamaAttention -> LlamaAttention and make parameter state public
  in the attention layer

* Remove self._hidden_dim and self._head_dim
---
 keras_nlp/models/llama/llama_attention.py     | 120 +++++++++---------
 keras_nlp/models/llama/llama_backbone.py      | 108 ++++++++++------
 keras_nlp/models/llama/llama_backbone_test.py |   1 -
 keras_nlp/models/llama/llama_decoder.py       |  75 +++++++----
 4 files changed, 182 insertions(+), 122 deletions(-)

diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py
index 529e73b009..33ffcef209 100644
--- a/keras_nlp/models/llama/llama_attention.py
+++ b/keras_nlp/models/llama/llama_attention.py
@@ -18,34 +18,33 @@

 class LlamaAttention(keras.layers.Layer):
-    """Grouped query attention for Llama models"""
+    """A cached grouped query attention layer."""

     def __init__(
         self,
         num_query_heads,
         num_key_value_heads,
+        rope_max_wavelength=10000,
         rope_scaling_factor=1.0,
         kernel_initializer="glorot_uniform",
-        rope_max_wavelength=10000,
-        max_sequence_length=512,
+        dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout

         self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength

-        self.kernel_initializer = keras.initializers.get(kernel_initializer)
-        self.max_sequence_length = max_sequence_length
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )

         self.rope_scaling_factor = rope_scaling_factor
-        self.rope_max_wavelength = rope_max_wavelength

     def build(self, inputs_shape):
-        self.hidden_dim = inputs_shape[-1]
-        self.attn_head_size = self.hidden_dim // self.num_query_heads
-
         # Einsum variables:
         # b = batch size
         # q = query length
         # k = key/value length
         # m = model dim
         # u = num query heads
         # v = num key/value heads
         # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+
         self._query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.num_query_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
         )
         self._query_dense.build(inputs_shape)
+
         self._key_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="key",
         )
@@ -73,8 +81,12 @@ def build(self, inputs_shape):
         self._value_dense = keras.layers.EinsumDense(
             equation="bkm,mvh->bkvh",
-            output_shape=(None, self.num_key_value_heads, self.attn_head_size),
-            kernel_initializer=clone_initializer(self.kernel_initializer),
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
             dtype=self.dtype_policy,
             name="value",
         )
@@ -86,21 
+98,28 @@ def build(self, inputs_shape): name="attention_softmax", ) + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + self._output_dense = keras.layers.EinsumDense( - equation="bqm,mh->bqh", - output_shape=(None, self.hidden_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="attention_output", ) - self._output_dense.build(inputs_shape) + self._output_dense.build((None, None, self.num_query_heads, head_dim)) - self._rotary_embedding_layer = RotaryEmbedding( + self.rotary_embedding_layer = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, scaling_factor=self.rope_scaling_factor, dtype=self.dtype_policy, ) - self._rotary_embedding_layer.build(inputs_shape) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" self.built = True @@ -110,6 +129,7 @@ def call( attention_mask=None, cache=None, cache_update_index=None, + training=None, ): query = self._query_dense(hidden_states) @@ -136,75 +156,61 @@ def call( key = self._key_dense(hidden_states) value = self._value_dense(hidden_states) - query = self._rotary_embedding_layer(query) - key = self._rotary_embedding_layer(key) + query = self.rotary_embedding_layer(query) + key = self.rotary_embedding_layer(key) - key = ops.tile(key, [1, 1, self.num_key_value_groups, 1]) - value = ops.tile(value, [1, 1, self.num_key_value_groups, 1]) + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - attention_output, attention_scores = self._compute_attention( + attention_output = self._compute_attention( query, key, value, attention_mask ) - attention_output_shape = ops.shape(attention_output) - - attention_output = ops.reshape( - attention_output, - [ - attention_output_shape[0], - attention_output_shape[1], - self.hidden_dim, - ], + attention_output = self._dropout_layer( + attention_output, training=training ) attention_output = self._output_dense(attention_output) if cache is not None: - return (attention_output, cache) + return attention_output, cache return attention_output def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - mask_expansion_axis = -3 - for _ in range( - len(attention_scores.shape) - len(attention_mask.shape) - ): - attention_mask = ops.expand_dims( - attention_mask, axis=mask_expansion_axis - ) - return self._softmax(attention_scores, attention_mask) + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): - attention_scores = ops.einsum("aecd,abcd->acbe", key, query) - - norm_factor = ops.sqrt( - ops.convert_to_tensor(self.attn_head_size, self.compute_dtype) - ) + attention_scores = ops.einsum(self._dot_product_equation, query, key) - attention_scores /= norm_factor + attention_scores = attention_scores / self._norm_factor attention_scores = self._masked_softmax( attention_scores, attention_mask ) attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( - "acbe,aecd->abcd", attention_scores, value + self._combine_equation, attention_scores, value ) - return attention_output, 
attention_scores
+        return attention_output

     def get_config(self):
         config = super().get_config()
         config.update(
             {
                 "num_query_heads": self.num_query_heads,
-                "hidden_dim": self.hidden_dim,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),
-                "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
-                "num_key_value_heads": self.num_key_value_heads,
-                "max_sequence_length": self.max_sequence_length,
+                "dropout": self.dropout,
             }
         )
         return config
diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py
index 733d9ef434..b5383d528a 100644
--- a/keras_nlp/models/llama/llama_backbone.py
+++ b/keras_nlp/models/llama/llama_backbone.py
@@ -11,14 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# import copy
+
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
 from keras_nlp.models.backbone import Backbone
-from keras_nlp.models.llama.llama_decoder import LlamaDecoder
+
+# from keras_nlp.models.llama.llama_presets import backbone_presets
+from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder
 from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm

+# from keras_nlp.utils.python_utils import classproperty
+

 def _llama_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
@@ -27,41 +34,64 @@ def _llama_kernel_initializer(stddev=0.02):
 @keras_nlp_export("keras_nlp.models.LlamaBackbone")
 class LlamaBackbone(Backbone):
     """
-    LLaMA core network with hyperparameters.
+    The Llama Transformer core architecture with hyperparameters.

     This network implements a Transformer-based decoder network,
-    LLaMA, as described in ["LLaMA: Open Foundation and Fine-Tuned Language Models"](https://arxiv.org/abs/2302.13971).
+    Llama, as described in
+    ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971).
+    It includes the embedding lookups and transformer layers.

     The default constructor gives a fully customizable, randomly initialized
-    LLaMA model with any number of layers, heads, and embedding
-    dimensions. This backbone also supports LLaMA2 checkpoints.
+    Llama model with any number of layers, heads, and embedding
+    dimensions. To load preset architectures and weights, use the `from_preset`
+    constructor.

     Args:
-        vocabulary_size: int. The size of the token vocabulary.
-        num_layers: int. The number of transformer layers.
-        num_query_heads: int. The number of attention heads for each transformer.
-            The hidden size must be divisible by the number of attention heads.
-        hidden_dim: int. The size of the transformer encoding and pooler layers.
-        intermediate_dim: int. The output dimension of the first Dense layer in
-            a two-layer feedforward network for each transformer.
-        num_key_value_heads: int. This is the number of key_value heads that
-            should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads,
-            the model will use Multi Head Attention (MHA), if num_key_value_heads=1
-            the model will use Multi Query Attention (MQA)
-        rope_scaling_factor: float. The scaling factor for calculation of rotary
-            embedding
-        rope_max_wavelength: int. The maximum angular wavelength of the
-            sine/cosine curves, for rotary embeddings.
-        layer_norm_epsilon: float. a value added to the denominator for
-            numerical stability.
-        max_sequence_length: int. The maximum sequence length that this encoder
-            can consume. If `None`, `max_sequence_length` uses the value from
-            sequence length. This determines the variable shape for positional
-            embeddings.
+        vocabulary_size (int): The size of the token vocabulary.
+        num_layers (int): The number of transformer layers.
+        num_query_heads (int): The number of query attention heads for
+            each transformer.
+        hidden_dim (int): The size of the transformer encoding and pooling layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in a
+            three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads for
+            each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of the
+            sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for calculation
+            of rotary embeddings. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+            layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.
+
+    Examples:
+
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Llama decoder.
+    model = keras_nlp.models.LlamaBackbone.from_preset("llama7b_base_en")
+    model(input_data)
+
+    # Randomly initialized Llama decoder with custom config.
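+    # (Toy hyperparameters for illustration; a real Llama config is much
+    # larger, e.g. 32 layers and hidden_dim=4096 for the 7B variant.)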
+ model = keras_nlp.models.LlamaBackbone( + vocabulary_size=10, + hidden_dim=512, + num_layers=2, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=1024, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` """ def __init__( @@ -72,10 +102,10 @@ def __init__( hidden_dim, intermediate_dim, num_key_value_heads, - rope_scaling_factor=1.0, rope_max_wavelength=10000, - layer_norm_epsilon=1e-5, - max_sequence_length=4096, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, dtype=None, **kwargs, ): @@ -83,31 +113,31 @@ def __init__( self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, - embeddings_initializer=_llama_kernel_initializer(stddev=0.01), tie_weights=False, + embeddings_initializer=_llama_kernel_initializer(stddev=0.01), dtype=dtype, name="token_embedding", ) self.transformer_layers = [] for i in range(num_layers): - layer = LlamaDecoder( + layer = LlamaTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, - rope_scaling_factor=rope_scaling_factor, - max_sequence_length=max_sequence_length, rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, kernel_initializer=_llama_kernel_initializer(stddev=0.02), + dropout=dropout, dtype=dtype, name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) self.layer_norm = LlamaLayerNorm( - dtype=dtype, epsilon=layer_norm_epsilon, - name="layer_norm", + dtype=dtype, + name="sequence_output_layernorm", ) # === Functional Model === @@ -140,8 +170,8 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.num_key_value_heads = num_key_value_heads self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout def get_config(self): config = super().get_config() @@ -155,8 +185,12 @@ def get_config(self): "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, } ) return config + + # @classproperty + # def presets(cls): + # return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index efff972c6b..56d8c44bd3 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -28,7 +28,6 @@ def setUp(self): "num_key_value_heads": 2, "hidden_dim": 8, "intermediate_dim": 8, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 3b9d6906b8..1ef247c575 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -24,20 +24,20 @@ from keras_nlp.utils.keras_utils import clone_initializer -class LlamaDecoder(keras.layers.Layer): - """Llama decoder block.""" +class LlamaTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Llama backbone.""" def __init__( self, intermediate_dim, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, - activation="relu", + activation="silu", layer_norm_epsilon=1e-5, 
kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) @@ -48,37 +48,50 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length + self.dropout = dropout + self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.supports_masking = True + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape self.hidden_dim = decoder_sequence_shape[-1] - # Self attention layers. + # Self attention layer. self._self_attention_layer = LlamaAttention( num_query_heads=self.num_query_heads, num_key_value_heads=self.num_key_value_heads, rope_max_wavelength=self.rope_max_wavelength, - max_sequence_length=self.max_sequence_length, rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, dtype=self.dtype_policy, + name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="self_attention_layernorm", ) self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) # Feedforward layers. self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -86,23 +99,30 @@ def build(self, decoder_sequence_shape): self.intermediate_dim, activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_gate_dense", ) self._feedforward_gate_dense.build(decoder_sequence_shape) self._feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_output_dense", ) - intermediate_shape = list(decoder_sequence_shape) - intermediate_shape[-1] = self.intermediate_dim - self._feedforward_output_dense.build(tuple(intermediate_shape)) + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) self._feedforward_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="feedforward_layernorm", ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -115,6 +135,7 @@ def call( decoder_attention_mask=None, self_attention_cache=None, self_attention_cache_update_index=None, + training=None, ): self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, @@ -125,10 +146,9 @@ def call( ) residual = decoder_sequence - x = self._self_attention_layernorm( - decoder_sequence, - ) + x = self._self_attention_layernorm(decoder_sequence) + # Self attention block. 
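+        # (When a cache is passed, the layer below returns an
+        # `(output, cache)` tuple, which is unpacked a few lines down so
+        # generation can reuse keys/values from earlier positions.)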
x = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, @@ -139,6 +159,8 @@ def call( if self_attention_cache is not None: x, self_attention_cache = x + x = self._self_attention_dropout(x, training=training) + x = x + residual residual = x @@ -152,7 +174,7 @@ def call( decoder_output = x + residual if self_attention_cache is not None: - return (decoder_output, self_attention_cache) + return decoder_output, self_attention_cache return decoder_output def _compute_self_attention_mask( @@ -160,8 +182,8 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask, - self_attention_cache=None, - self_attention_cache_update_index=None, + self_attention_cache, + self_attention_cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -174,16 +196,16 @@ def _compute_self_attention_mask( if self_attention_cache is not None: input_length = ops.shape(self_attention_cache)[2] + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + causal_mask = compute_causal_mask( - batch_size, - input_length, - output_length, - ( - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index - ), + batch_size, input_length, output_length, cache_update_index ) + return ( ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None @@ -198,17 +220,16 @@ def get_config(self): config.update( { "intermediate_dim": self.intermediate_dim, - "hidden_dim": self.hidden_dim, "num_query_heads": self.num_query_heads, "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), + "dropout": self.dropout, } ) return config From 34d2099935eea6ecddb432288fcb130e5c8bae19 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 13 Mar 2024 15:54:46 -0700 Subject: [PATCH 41/70] Add `LlamaPreprocessor` and `LlamaCausalLMPreprocessor` (#1511) * Add a preprocessor for the Llama backbone * Add causal lm preprocessor for the Llama backbone --- .../llama/llama_causal_lm_preprocessor.py | 185 +++++++++++++++++ .../llama_causal_lm_preprocessor_test.py | 90 +++++++++ keras_nlp/models/llama/llama_preprocessor.py | 191 ++++++++++++++++++ .../models/llama/llama_preprocessor_test.py | 57 ++++++ 4 files changed, 523 insertions(+) create mode 100644 keras_nlp/models/llama/llama_causal_lm_preprocessor.py create mode 100644 keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py create mode 100644 keras_nlp/models/llama/llama_preprocessor.py create mode 100644 keras_nlp/models/llama/llama_preprocessor_test.py diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py new file mode 100644 index 0000000000..a221185582 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py @@ -0,0 +1,185 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
+class LlamaCausalLMPreprocessor(LlamaPreprocessor):
+    """Llama Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
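+    # No labels are needed; `y` and `sample_weight` are derived from the
+    # tokens themselves.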
+ ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`LlamaCausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. 
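+        # For example, ids [1, 3, 8, 4, 6, 0, 0] with 1/2 as the start/end
+        # ids and 0 as padding would detokenize to just the raw text.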
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        padding_mask = padding_mask & (
+            token_ids != self.tokenizer.start_token_id
+        )
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..aa4d155c8c
--- /dev/null
+++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
@@ -0,0 +1,90 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+
+from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
+    LlamaCausalLMPreprocessor,
+)
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class LlamaCausalLMPreprocessorTest(TestCase):
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer(
+            # Generated using create_llama_test_proto.py
+            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = (["the quick brown fox"],)
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=LlamaCausalLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
+                },
+                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Generated labels (next tokens).
+                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Generated sample weights.
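+                # (In the test vocab, 1 is presumably the start id and 0
+                # the pad value.)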
+            ),
+        )
+
+    def test_no_start_end_token(self):
+        input_data = ["the quick brown fox"] * 4
+
+        preprocessor = LlamaCausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=False,
+            add_end_token=False,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "the quick brown fox"
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
+        }
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_postprocess(input_data)
+        self.assertAllEqual(x, "the quick brown fox")
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in LlamaCausalLMPreprocessor.presets:
+            self.run_preset_test(
+                cls=LlamaCausalLMPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py
new file mode 100644
index 0000000000..580557f50d
--- /dev/null
+++ b/keras_nlp/models/llama/llama_preprocessor.py
@@ -0,0 +1,191 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.LlamaPreprocessor")
+class LlamaPreprocessor(Preprocessor):
+    """A Llama preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`
+       with the appropriate tokens.
+    3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`
+       that can be passed directly to `keras_nlp.models.LlamaBackbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be converted
+            to tensors. For multiple sequences, pass tensors directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of single sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Preprocess a batch of sentence pairs.
+    # When handling multiple sequences, always convert to tensors first!
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    preprocessor((first, second))
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices((first, second))
+
+    # Watch out for tf.data's default unpacking of tuples here!
+    # Best to invoke the `preprocessor` directly in this case.
+    ds = ds.map(
+        lambda first, second: preprocessor(x=(first, second)),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=1024,
+        add_start_token=True,
+        add_end_token=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+        self.sequence_length = sequence_length
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
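+        # `StartEndPacker` adds the configured start/end ids, then pads or
+        # truncates to `sequence_length` and returns a boolean padding mask.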
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "Llama requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using Llama" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + + @classproperty + def tokenizer_cls(cls): + return LlamaTokenizer diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py new file mode 100644 index 0000000000..6807886812 --- /dev/null +++ b/keras_nlp/models/llama/llama_preprocessor_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["the quick brown fox"], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=LlamaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. 
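+                # (Unlike the causal LM preprocessor above, labels and
+                # weights here really are passed through unchanged.)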
+            ),
+        )
+
+    def test_errors_for_2d_list_input(self):
+        preprocessor = LlamaPreprocessor(**self.init_kwargs)
+        ambiguous_input = [["one", "two"], ["three", "four"]]
+        with self.assertRaises(ValueError):
+            preprocessor(ambiguous_input)

From 0ef44ff35e4cddc4a7915df09ce84e91d7172a36 Mon Sep 17 00:00:00 2001
From: Tirth Patel
Date: Thu, 14 Mar 2024 11:40:21 -0700
Subject: [PATCH 42/70] Always run the rotary embedding layer in float32
 (#1508)

* Always run the rotary embedding layer in float32

* Fix the int32 issue with TensorFlow

* Only run sin/cos embedding compute step in float32

* Avoid start_index from downcasting automatically

* Use stack instead of concatenate
---
 keras_nlp/layers/modeling/rotary_embedding.py | 47 ++++++++++++-------
 keras_nlp/models/gemma/gemma_attention.py     | 43 +++++++----------
 2 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py
index 45f77ce494..b494d559bd 100644
--- a/keras_nlp/layers/modeling/rotary_embedding.py
+++ b/keras_nlp/layers/modeling/rotary_embedding.py
@@ -85,30 +85,42 @@ def __init__(
         self.built = True

     def call(self, inputs, start_index=0):
+        inputs = ops.moveaxis(
+            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
+        )
         cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
-        return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        return ops.moveaxis(
+            output, (-1, 1), (self.feature_axis, self.sequence_axis)
+        )

     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
-        half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
+        x1, x2 = ops.split(tensor, 2, axis=-1)
+        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
+        # compilation on jax. We should be able to remove this once the
+        # following PR is in all jax releases we care about:
+        # https://github.com/openxla/xla/pull/7875
+        half_rot_tensor = ops.stack((-x2, x1), axis=-2)
+        half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

     def _compute_cos_sin_embedding(self, inputs, start_index=0):
-        def get_axis(axis):
-            return axis if axis > 0 else len(inputs.shape) + axis
+        start_index = ops.cast(start_index, dtype="float32")

-        feature_axis = get_axis(self.feature_axis)
-        sequence_axis = get_axis(self.sequence_axis)
+        feature_axis = len(inputs.shape) - 1
+        sequence_axis = 1

         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)

-        seq_len = ops.shape(inputs)[self.sequence_axis]
-        tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
+        seq_len = ops.shape(inputs)[sequence_axis]
+        tensor = ops.arange(seq_len, dtype="float32") + start_index

-        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
         freq = ops.einsum("i,j->ij", tensor, inverse_freq)
-        embedding = ops.concatenate((freq, freq), axis=-1)
+        embedding = ops.stack((freq, freq), axis=-2)
+        embedding = ops.reshape(
+            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
+        )

         # Reshape the embedding to be broadcastable with input shape.
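         # (The stack-and-reshape above is numerically identical to
         # `ops.concatenate((freq, freq), axis=-1)`; it just avoids the XLA
         # issue referenced in `_apply_rotary_pos_emb`.)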
if feature_axis < sequence_axis: @@ -117,17 +129,16 @@ def get_axis(axis): if axis != sequence_axis and axis != feature_axis: embedding = ops.expand_dims(embedding, axis) - return ops.cos(embedding), ops.sin(embedding) + cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype) + sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype) + return cos_emb, sin_emb def _get_inverse_freq(self, rotary_dim): - freq_range = ops.arange(0, rotary_dim, 2) - freq_range = ops.cast(freq_range, self.compute_dtype) - freq_range = freq_range / ops.cast( - self.scaling_factor, self.compute_dtype - ) + freq_range = ops.arange(0, rotary_dim, 2, dtype="float32") + freq_range = freq_range / ops.cast(self.scaling_factor, "float32") inverse_freq = 1.0 / ( self.max_wavelength - ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) + ** (freq_range / ops.cast(rotary_dim, "float32")) ) return inverse_freq diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py index e01c1f8ce4..4b391264a2 100644 --- a/keras_nlp/models/gemma/gemma_attention.py +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -15,6 +15,7 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops +from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding from keras_nlp.utils.keras_utils import clone_initializer @@ -87,28 +88,23 @@ def build(self, inputs_shape): (None, None, self.num_query_heads, self.head_dim) ) self.softmax = keras.layers.Softmax(dtype="float32") + + self.rope_layer = RotaryEmbedding( + max_wavelength=10_000.0, dtype=self.dtype_policy + ) + self.built = True - def _apply_rope(self, x, positions): + def _apply_rope(self, x, start_index): """Rope rotate q or k.""" - # TODO: refactor to use RotaryEmbedding layer? - max_wavelength = 10000 - x_shape = ops.shape(x) - freq_exponents = (2.0 / x_shape[-1]) * ops.arange( - x_shape[-1] // 2, dtype="float32" + x = self.rope_layer(x, start_index=start_index) + # Gemma uses a different layout for positional embeddings. + # The transformation below ensures the embeddings are numerically + # equivalent to the original gemma implementation. + x = ops.reshape( + ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x) ) - timescale = max_wavelength**freq_exponents - radians = positions[..., None] / timescale[None, None, :] - radians = radians[..., None, :] - sin = ops.cast(ops.sin(radians), self.compute_dtype) - cos = ops.cast(ops.cos(radians), self.compute_dtype) - x1, x2 = ops.split(x, 2, axis=-1) - # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA - # compilation on jax. We should be able to remove this once the - # following PR is in all jax releases we care about: - # https://github.com/openxla/xla/pull/7875 - output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) - return ops.reshape(output, x_shape) + return x def _compute_attention( self, @@ -155,19 +151,14 @@ def call( cache_update_index=0, training=False, ): - seq_len = ops.shape(x)[1] - start_index = cache_update_index - positions = ops.arange(seq_len, dtype="float32") - - positions = positions + ops.cast(start_index, "float32") query = self.query_dense(x) - query = self._apply_rope(query, positions) + query = self._apply_rope(query, cache_update_index) if cache is not None: key_cache = cache[:, 0, ...] value_cache = cache[:, 1, ...] 
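             # The cache is laid out as (batch, 2, seq_len, num_heads,
             # head_dim); index 0 on axis 1 holds keys, index 1 holds values.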
             key_update = self.key_dense(x)
-            key_update = self._apply_rope(key_update, positions)
+            key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
             start = [0, cache_update_index, 0, 0]
             key = ops.slice_update(key_cache, start, key_update)
@@ -175,7 +166,7 @@ def call(
             cache = ops.stack((key, value), axis=1)
         else:
             key = self.key_dense(x)
-            key = self._apply_rope(key, positions)
+            key = self._apply_rope(key, cache_update_index)
             value = self.value_dense(x)

         attention_vec = self._compute_attention(

From 1cc8df590af92baa71f7f3186346977490fcfd31 Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Thu, 14 Mar 2024 18:36:57 -0500
Subject: [PATCH 43/70] Remove install of Python 3.9 (#1514)

---
 .kokoro/github/ubuntu/gpu/build.sh | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 87cd206495..b8d47dbe9c 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -14,11 +14,8 @@ if [[ -z "${KAGGLE_USERNAME}" ]]; then
 fi

 set -x
-
 cd "${KOKORO_ROOT}/"

-sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
-
 PYTHON_BINARY="/usr/bin/python3.9"

 "${PYTHON_BINARY}" -m venv venv

From db855bc3252bc6cbcdaaebfd85ac6912b77add65 Mon Sep 17 00:00:00 2001
From: Qianli Scott Zhu
Date: Thu, 14 Mar 2024 16:56:33 -0700
Subject: [PATCH 44/70] Update gemma_backbone.py for sharding config. (#1491)

* Update gemma_backbone.py for sharding config.

* Update unit test and fix format.

* Update sharding spec for gemma based on gemma training.
---
 keras_nlp/models/gemma/gemma_backbone.py      | 31 ++++++++++++++-----
 keras_nlp/models/gemma/gemma_backbone_test.py | 24 +++++++++-----
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
index c829aa948f..06f5b0f601 100644
--- a/keras_nlp/models/gemma/gemma_backbone.py
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)

     @staticmethod
-    def get_layout_map(device_mesh, model_parallel_dim_name="model"):
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
         """Get a `keras.distribution.LayoutMap` for model parallel distribution.

         The returned `LayoutMap` contains the sharding spec for the gemma
@@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 distribution.
             model_parallel_dim_name: The axis name of the device mesh, where
                 the weights should be partitioned on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partitioned on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec of
             all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 f"{model_parallel_dim_name} is not found in the "
                 f"device_mesh.axis_names. {device_mesh.axis_names=}"
             )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_names=}"
+            )
+        # Note that it is possible to further configure the mesh to be 3D,
+        # e.g. (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
         model_dim = model_parallel_dim_name

-        # The sharding is partition for the hidden_dim of the model.
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
-        layout_map["token_embedding/embeddings"] = (None, model_dim)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
         layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
-            None,
             model_dim,
+            data_dim,
             None,
         )
         layout_map["decoder_block.*attention_output.*kernel"] = (
-            None,
-            None,
             model_dim,
+            None,
+            data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
+        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)

         return layout_map
diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py
index 855d49658b..7b02de2b7a 100644
--- a/keras_nlp/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,34 @@ def test_distribution(self):

         for w in model.weights:
             if "token_embedding/embeddings" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
             if "attention/query/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/key/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/value/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/attention_output/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, None, "model")
+                    tuple(w.value.sharding.spec), ("model", None, "batch")
                 )
             if "ffw_gating/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_gating_2/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_linear" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )

From d1031df1cedb77ffd72db5958fd8ffd5a9af00fc Mon Sep 17 00:00:00 2001
From: Sachin Prasad
Date: Tue, 19 Mar 2024 17:43:10 -0700
Subject: [PATCH 45/70] Unify docstring style

---
 STYLE_GUIDE.md                                              | 2 +-
 examples/bert_pretraining/bert_create_pretraining_data.py   | 2 +-
 examples/tools/split_sentences.py                           | 2 +-
 examples/tools/train_word_piece_vocab.py                    | 2 +-
 keras_nlp/layers/modeling/alibi_bias.py                     | 2 +-
 keras_nlp/layers/modeling/f_net_encoder.py                  | 2 +-
 keras_nlp/layers/modeling/masked_lm_head.py                 | 2 +-
 keras_nlp/layers/modeling/position_embedding.py             | 2 +-
 keras_nlp/layers/modeling/reversible_embedding.py           | 2 +-
 keras_nlp/layers/modeling/rotary_embedding.py               | 2 +-
 keras_nlp/layers/modeling/sine_position_encoding.py         | 2 +-
 keras_nlp/layers/modeling/token_and_position_embedding.py   | 2 +-
 keras_nlp/layers/modeling/transformer_decoder.py            | 2 +-
 keras_nlp/layers/modeling/transformer_encoder.py            | 2 +-
 keras_nlp/layers/preprocessing/masked_lm_mask_generator.py  | 2 +-
keras_nlp/layers/preprocessing/multi_segment_packer.py | 2 +- keras_nlp/layers/preprocessing/random_deletion.py | 2 +- keras_nlp/layers/preprocessing/random_swap.py | 2 +- keras_nlp/layers/preprocessing/start_end_packer.py | 2 +- keras_nlp/metrics/edit_distance.py | 2 +- keras_nlp/metrics/perplexity.py | 2 +- keras_nlp/metrics/rouge_l.py | 2 +- keras_nlp/metrics/rouge_n.py | 2 +- keras_nlp/models/albert/albert_backbone.py | 2 +- keras_nlp/models/albert/albert_classifier.py | 2 +- keras_nlp/models/albert/albert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/albert/albert_preprocessor.py | 2 +- keras_nlp/models/albert/albert_tokenizer.py | 2 +- keras_nlp/models/backbone.py | 2 +- keras_nlp/models/bart/bart_backbone.py | 2 +- keras_nlp/models/bart/bart_preprocessor.py | 2 +- keras_nlp/models/bart/bart_seq_2_seq_lm.py | 2 +- keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py | 2 +- keras_nlp/models/bart/bart_tokenizer.py | 2 +- keras_nlp/models/bert/bert_backbone.py | 2 +- keras_nlp/models/bert/bert_classifier.py | 2 +- keras_nlp/models/bert/bert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/bert/bert_preprocessor.py | 2 +- keras_nlp/models/bert/bert_tokenizer.py | 2 +- keras_nlp/models/bloom/bloom_backbone.py | 2 +- keras_nlp/models/bloom/bloom_causal_lm.py | 2 +- keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py | 2 +- keras_nlp/models/bloom/bloom_preprocessor.py | 2 +- keras_nlp/models/bloom/bloom_tokenizer.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_classifier.py | 2 +- .../models/deberta_v3/deberta_v3_masked_lm_preprocessor.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py | 2 +- keras_nlp/models/distil_bert/distil_bert_backbone.py | 2 +- keras_nlp/models/distil_bert/distil_bert_classifier.py | 2 +- .../distil_bert/distil_bert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/distil_bert/distil_bert_preprocessor.py | 2 +- keras_nlp/models/distil_bert/distil_bert_tokenizer.py | 2 +- keras_nlp/models/electra/electra_backbone.py | 2 +- keras_nlp/models/electra/electra_tokenizer.py | 2 +- keras_nlp/models/f_net/f_net_backbone.py | 2 +- keras_nlp/models/f_net/f_net_classifier.py | 2 +- keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py | 2 +- keras_nlp/models/f_net/f_net_preprocessor.py | 2 +- keras_nlp/models/f_net/f_net_tokenizer.py | 2 +- keras_nlp/models/falcon/falcon_backbone.py | 2 +- keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py | 2 +- keras_nlp/models/falcon/falcon_preprocessor.py | 2 +- keras_nlp/models/falcon/falcon_tokenizer.py | 2 +- keras_nlp/models/gemma/gemma_causal_lm.py | 4 ++-- keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py | 2 +- keras_nlp/models/gemma/gemma_preprocessor.py | 2 +- keras_nlp/models/gemma/gemma_tokenizer.py | 2 +- keras_nlp/models/gpt2/gpt2_causal_lm.py | 2 +- keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py | 2 +- keras_nlp/models/gpt2/gpt2_preprocessor.py | 2 +- keras_nlp/models/gpt2/gpt2_tokenizer.py | 2 +- keras_nlp/models/llama/llama_backbone.py | 2 +- keras_nlp/models/llama/llama_causal_lm_preprocessor.py | 2 +- keras_nlp/models/llama/llama_preprocessor.py | 2 +- keras_nlp/models/llama/llama_tokenizer.py | 2 +- keras_nlp/models/mistral/mistral_backbone.py | 2 +- keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py | 2 +- keras_nlp/models/mistral/mistral_preprocessor.py | 2 +- keras_nlp/models/mistral/mistral_tokenizer.py | 2 +- keras_nlp/models/opt/opt_backbone.py | 2 +- keras_nlp/models/opt/opt_causal_lm.py | 2 +- 
keras_nlp/models/opt/opt_causal_lm_preprocessor.py | 2 +- keras_nlp/models/opt/opt_preprocessor.py | 2 +- keras_nlp/models/opt/opt_tokenizer.py | 2 +- keras_nlp/models/preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_backbone.py | 2 +- keras_nlp/models/roberta/roberta_classifier.py | 2 +- keras_nlp/models/roberta/roberta_masked_lm.py | 2 +- keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_tokenizer.py | 2 +- keras_nlp/models/t5/t5_tokenizer.py | 2 +- keras_nlp/models/task.py | 2 +- keras_nlp/models/whisper/whisper_audio_feature_extractor.py | 4 ++-- keras_nlp/models/whisper/whisper_backbone.py | 2 +- keras_nlp/models/whisper/whisper_preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py | 2 +- .../xlm_roberta/xlm_roberta_masked_lm_preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py | 2 +- keras_nlp/models/xlnet/xlnet_backbone.py | 2 +- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/contrastive_sampler.py | 2 +- keras_nlp/samplers/greedy_sampler.py | 2 +- keras_nlp/samplers/random_sampler.py | 2 +- keras_nlp/samplers/sampler.py | 2 +- keras_nlp/samplers/top_k_sampler.py | 2 +- keras_nlp/samplers/top_p_sampler.py | 2 +- keras_nlp/tokenizers/byte_pair_tokenizer.py | 6 +++--- keras_nlp/tokenizers/byte_tokenizer.py | 2 +- keras_nlp/tokenizers/sentence_piece_tokenizer.py | 4 ++-- keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py | 2 +- keras_nlp/tokenizers/tokenizer.py | 2 +- keras_nlp/tokenizers/unicode_codepoint_tokenizer.py | 2 +- keras_nlp/tokenizers/word_piece_tokenizer.py | 4 ++-- keras_nlp/tokenizers/word_piece_tokenizer_trainer.py | 2 +- pip_build.py | 2 +- tools/checkpoint_conversion/convert_gemma_checkpoints.py | 2 +- tools/count_preset_params.py | 2 +- tools/gemma/export_gemma_to_hf.py | 2 +- tools/gemma/export_gemma_to_torch_xla.py | 2 +- tools/gemma/run_gemma_xla.py | 2 +- 124 files changed, 130 insertions(+), 130 deletions(-) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 3db287de99..5d3466df69 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -116,7 +116,7 @@ class PositionEmbedding(keras.layers.Layer): Args: sequence_length: The maximum length of the dynamic sequence. - Examples: + Example usage: Direct call. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/examples/bert_pretraining/bert_create_pretraining_data.py b/examples/bert_pretraining/bert_create_pretraining_data.py index f7dcb54426..9e70d906de 100644 --- a/examples/bert_pretraining/bert_create_pretraining_data.py +++ b/examples/bert_pretraining/bert_create_pretraining_data.py @@ -27,7 +27,7 @@ This script is adapted from the original BERT respository: https://github.com/google-research/bert/blob/master/create_pretraining_data.py -Usage: +Example usage: python create_pretraining_data.py \ --input_files ~/datasets/bert-sentence-split-data/shard_0.txt \ --output_directory ~/datasets/bert-pretraining-data/shard_0.txt \ diff --git a/examples/tools/split_sentences.py b/examples/tools/split_sentences.py index 7606d0a070..ff14d714e2 100644 --- a/examples/tools/split_sentences.py +++ b/examples/tools/split_sentences.py @@ -21,7 +21,7 @@ This script will run muliprocessed, and the number of concurrent process and output file shards can be controlled with `--num_jobs` and `--num_shards`. 
-Usage: +Example usage: python examples/tools/create_sentence_split_data.py \ --input_files ~/datasets/wikipedia,~/datasets/bookscorpus \ --output_directory ~/datasets/bert-sentence-split-data diff --git a/examples/tools/train_word_piece_vocab.py b/examples/tools/train_word_piece_vocab.py index a9689aaf7f..a4a4489f1b 100644 --- a/examples/tools/train_word_piece_vocab.py +++ b/examples/tools/train_word_piece_vocab.py @@ -15,7 +15,7 @@ This script will create wordpiece vocabularies suitable for pretraining BERT. -Usage: +Example usage: python examples/tools/train_word_piece_vocabulary.py \ --input_files ~/datasets/bert-sentence-split-data/ \ --output_file vocab.txt diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index fdc956ae15..d1d9b97e70 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -40,7 +40,7 @@ class AlibiBias(keras.layers.Layer): multi-head attention layer of the transformer to add alibi bias to it. With shape `(batch_size, num_heads, query_length, key_length)`. - Examples: + Example usage: ```python query_length = 10 key_length = 10 diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index a5370d960e..d08e36da70 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -50,7 +50,7 @@ class FNetEncoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Examples: + Example usage: ```python # Create a single FNet encoder layer. diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py index eacee7e8c0..383d7a2293 100644 --- a/keras_nlp/layers/modeling/masked_lm_head.py +++ b/keras_nlp/layers/modeling/masked_lm_head.py @@ -60,7 +60,7 @@ class MaskedLMHead(keras.layers.Layer): The bias initializer for the dense and multiheaded attention layers. Defaults to `"zeros"`. - Examples: + Example usage: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/position_embedding.py b/keras_nlp/layers/modeling/position_embedding.py index 6f9a44c29f..11086fbea2 100644 --- a/keras_nlp/layers/modeling/position_embedding.py +++ b/keras_nlp/layers/modeling/position_embedding.py @@ -43,7 +43,7 @@ class PositionEmbedding(keras.layers.Layer): compute the position embedding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example usage: Called directly on input. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index d115217687..574e6f9962 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -59,7 +59,7 @@ class ReversibleEmbedding(keras.layers.Embedding): from `output_dim` to `input_dim`, instead of a normal embedding call. Default to `False`. - Examples: + Example usage: ```python batch_size = 16 vocab_size = 100 diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index b494d559bd..3129cf5e30 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -47,7 +47,7 @@ class RotaryEmbedding(keras.layers.Layer): compute the rotary embedding from. 
This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example usage: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/sine_position_encoding.py b/keras_nlp/layers/modeling/sine_position_encoding.py index 6e96a77e2c..da2f138ca4 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding.py +++ b/keras_nlp/layers/modeling/sine_position_encoding.py @@ -42,7 +42,7 @@ class SinePositionEncoding(keras.layers.Layer): compute the encoding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example usage: ```python # create a simple embedding layer with sinusoidal positional encoding seq_len = 100 diff --git a/keras_nlp/layers/modeling/token_and_position_embedding.py b/keras_nlp/layers/modeling/token_and_position_embedding.py index bb7107f96f..d215cad45e 100644 --- a/keras_nlp/layers/modeling/token_and_position_embedding.py +++ b/keras_nlp/layers/modeling/token_and_position_embedding.py @@ -44,7 +44,7 @@ class TokenAndPositionEmbedding(keras.layers.Layer): used in the vocabulary (input_dim should equal size of vocabulary + 1). - Examples: + Example usage: ```python inputs = np.ones(shape=(1, 50), dtype="int32") embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding( diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index d06a1948f5..0473ce1025 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -72,7 +72,7 @@ class TransformerDecoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Examples: + Example usage: ```python # Create a single transformer decoder layer. decoder = keras_nlp.layers.TransformerDecoder( diff --git a/keras_nlp/layers/modeling/transformer_encoder.py b/keras_nlp/layers/modeling/transformer_encoder.py index 32cdd35547..07538bf401 100644 --- a/keras_nlp/layers/modeling/transformer_encoder.py +++ b/keras_nlp/layers/modeling/transformer_encoder.py @@ -61,7 +61,7 @@ class TransformerEncoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Examples: + Example usage: ```python # Create a single transformer encoder layer. diff --git a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py index 74b2fd9811..f7b1413ae9 100644 --- a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py +++ b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py @@ -82,7 +82,7 @@ class MaskedLMMaskGenerator(PreprocessingLayer): 1 means the corresponding position in `mask_positions` is an actual mask, 0 means it is a pad. - Examples: + Example usage: Basic usage. ```python diff --git a/keras_nlp/layers/preprocessing/multi_segment_packer.py b/keras_nlp/layers/preprocessing/multi_segment_packer.py index 638d6f2b91..9caed450e5 100644 --- a/keras_nlp/layers/preprocessing/multi_segment_packer.py +++ b/keras_nlp/layers/preprocessing/multi_segment_packer.py @@ -86,7 +86,7 @@ class MultiSegmentPacker(PreprocessingLayer): sequence. The second is an integer tensor of the same shape, containing the segment ids. 
- Examples: + Example usage: *Pack a single input for classification.* >>> seq1 = [1, 2, 3, 4] diff --git a/keras_nlp/layers/preprocessing/random_deletion.py b/keras_nlp/layers/preprocessing/random_deletion.py index 061290ba56..fac8b958d4 100644 --- a/keras_nlp/layers/preprocessing/random_deletion.py +++ b/keras_nlp/layers/preprocessing/random_deletion.py @@ -58,7 +58,7 @@ class RandomDeletion(PreprocessingLayer): tracable--it can be any python function. seed: A seed for the random number generator. - Examples: + Example usage: Word level usage. >>> keras.utils.set_random_seed(1337) diff --git a/keras_nlp/layers/preprocessing/random_swap.py b/keras_nlp/layers/preprocessing/random_swap.py index 27873f0fe8..2fb20f9aa8 100644 --- a/keras_nlp/layers/preprocessing/random_swap.py +++ b/keras_nlp/layers/preprocessing/random_swap.py @@ -60,7 +60,7 @@ class RandomSwap(PreprocessingLayer): seed: A seed for the random number generator. - Examples: + Example usage: Word level usage. >>> keras.utils.set_random_seed(1337) diff --git a/keras_nlp/layers/preprocessing/start_end_packer.py b/keras_nlp/layers/preprocessing/start_end_packer.py index be3466a506..4755a9076e 100644 --- a/keras_nlp/layers/preprocessing/start_end_packer.py +++ b/keras_nlp/layers/preprocessing/start_end_packer.py @@ -59,7 +59,7 @@ class StartEndPacker(PreprocessingLayer): add_end_value: Pass `False` to not append an end value for this input. - Examples: + Example usage: Unbatched input (int). >>> inputs = [5, 6, 7] diff --git a/keras_nlp/metrics/edit_distance.py b/keras_nlp/metrics/edit_distance.py index 263ff8290b..662022d477 100644 --- a/keras_nlp/metrics/edit_distance.py +++ b/keras_nlp/metrics/edit_distance.py @@ -53,7 +53,7 @@ class EditDistance(keras.metrics.Metric): References: - [Morris et al.](https://www.researchgate.net/publication/221478089) - Examples: + Example usage: Various Input Types. diff --git a/keras_nlp/metrics/perplexity.py b/keras_nlp/metrics/perplexity.py index 4a7e626bc9..b82ed84846 100644 --- a/keras_nlp/metrics/perplexity.py +++ b/keras_nlp/metrics/perplexity.py @@ -40,7 +40,7 @@ class Perplexity(keras.metrics.Metric): name: string. Name of the metric instance. **kwargs: Other keyword arguments. - Examples: + Example usage: 1. Calculate perplexity by calling update_state() and result(). 1.1. `sample_weight`, and `mask_token_id` are not provided. diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py index 82fcdedade..329c0dccf2 100644 --- a/keras_nlp/metrics/rouge_l.py +++ b/keras_nlp/metrics/rouge_l.py @@ -40,7 +40,7 @@ class RougeL(RougeBase): References: - [Lin et al., 2004](https://aclanthology.org/W04-1013/) - Examples: + Example usage: 1. Python string. >>> rouge_l = keras_nlp.metrics.RougeL() diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py index 8b135d1dd0..2d6b1770dd 100644 --- a/keras_nlp/metrics/rouge_n.py +++ b/keras_nlp/metrics/rouge_n.py @@ -42,7 +42,7 @@ class RougeN(RougeBase): References: - [Lin et al., 2004](https://aclanthology.org/W04-1013/) - Examples: + Example usage: 1. Python string. >>> rouge_n = keras_nlp.metrics.RougeN(order=2) diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 09053ff893..8a34dcc4c5 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -77,7 +77,7 @@ class AlbertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. 
- Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 32a4e0847d..224ae51540 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -54,7 +54,7 @@ class AlbertClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index 89cf134465..a3c846cfb2 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -69,7 +69,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index 19f4bd9a7b..cec26b079c 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -72,7 +72,7 @@ class AlbertPreprocessor(Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Examples: + Example usage: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset( diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index 44aed44cf5..04fdfa3b08 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -46,7 +46,7 @@ class AlbertTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index bfdc8207ad..475f89389f 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -115,7 +115,7 @@ def from_preset( load_weights: Whether to load pre-trained weights into model. Defaults to `True`. - Examples: + Example usage: ```python # Load architecture and weights from preset model = keras_nlp.models.{{model_name}}.from_preset( diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index f100133d25..2d509f1df4 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -65,7 +65,7 @@ class BartBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index 3310b1e532..33effe28af 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -52,7 +52,7 @@ class BartPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. 
```python diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index c530555b3d..5ab4f4264d 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -59,7 +59,7 @@ class BartSeq2SeqLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Use `generate()` to do text generation, given an input context. ```python diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index 1c72e6e935..e365f896dc 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -58,7 +58,7 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor): generates label weights by shifting the padding mask one step to the left. - Examples: + Example usage: Directly calling the layer on data ```python diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py index 17fb237b88..39a2c7d232 100644 --- a/keras_nlp/models/bart/bart_tokenizer.py +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -48,7 +48,7 @@ class BartTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 320dc1c2ee..7681622a65 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -66,7 +66,7 @@ class BertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 09d2b8810c..e6a4b050c3 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -55,7 +55,7 @@ class BertClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py index cdc61fbac3..75d8e5e6d3 100644 --- a/keras_nlp/models/bert/bert_masked_lm_preprocessor.py +++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py @@ -72,7 +72,7 @@ class BertMaskedLMPreprocessor(BertPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index 02f5a45985..7064761c49 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -69,7 +69,7 @@ class BertPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. 
```python diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py index 1b634fe9b3..06e1cc2672 100644 --- a/keras_nlp/models/bert/bert_tokenizer.py +++ b/keras_nlp/models/bert/bert_tokenizer.py @@ -50,7 +50,7 @@ class BertTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Examples: + Example usage: ```python # Unbatched input. tokenizer = keras_nlp.models.BertTokenizer.from_preset( diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 9b7c65a399..1bf18b5c25 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -58,7 +58,7 @@ class BloomBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py index 31eae30c6b..d259852243 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm.py +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -53,7 +53,7 @@ class BloomCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py index b56e1a3ef0..d8f503eaa7 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class BloomCausalLMPreprocessor(BloomPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.BloomCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index 8eb693cb50..8807caff40 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -62,7 +62,7 @@ class BloomPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index 0d7f74b163..e2f100ba82 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -46,7 +46,7 @@ class BloomTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index d6eea63601..ff973eb3c1 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -64,7 +64,7 @@ class DebertaV3Classifier(Task): dropout: float. Dropout probability applied to the pooled output. For the second dropout layer, `backbone.dropout` is used. - Examples: + Example usage: Raw string data. 
```python diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py index 519b0b4fca..2f2bca3b17 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py @@ -71,7 +71,7 @@ class DebertaV3MaskedLMPreprocessor(DebertaV3Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Examples: + Example usage: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 88fa08fd70..6c8902a86e 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -71,7 +71,7 @@ class DebertaV3Preprocessor(Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Examples: + Example usage: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.DebertaV3Preprocessor.from_preset( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py index e66c373e65..09da891b99 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py @@ -53,7 +53,7 @@ class DebertaV3Tokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index 73634b4216..af05dbed73 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -67,7 +67,7 @@ class DistilBertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index e82aaf2781..02391fa1b9 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -61,7 +61,7 @@ class DistilBertClassifier(Task): dropout: float. The dropout probability value, applied after the first dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py index f1360f58b7..cdd54954dd 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py @@ -74,7 +74,7 @@ class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Examples: + Example usage: Directly calling the layer on data. 
```python diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py index 63f4e3637b..6ff3275a99 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py @@ -68,7 +68,7 @@ class DistilBertPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py index 4a18398a1e..ba738fe76c 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py @@ -47,7 +47,7 @@ class DistilBertTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index f4f2a23b69..4fbfdf580b 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -63,7 +63,7 @@ class ElectraBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index acd665c2a3..1aaf74ed09 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -37,7 +37,7 @@ class ElectraTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Examples: + Example usage: ```python # Custom Vocabulary. vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index ab056c84c7..cd5fb604c2 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -71,7 +71,7 @@ class FNetBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index 512182d2cd..b42535c357 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -55,7 +55,7 @@ class FNetClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py index 51b4a4d1e7..6161d62520 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py @@ -68,7 +68,7 @@ class FNetMaskedLMPreprocessor(FNetPreprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Examples: + Example usage: Directly calling the layer on data. 
```python diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index b4cb5836bb..06520acea0 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -66,7 +66,7 @@ class FNetPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the from_preset(). ```python diff --git a/keras_nlp/models/f_net/f_net_tokenizer.py b/keras_nlp/models/f_net/f_net_tokenizer.py index ae3f569b1d..8f785165c5 100644 --- a/keras_nlp/models/f_net/f_net_tokenizer.py +++ b/keras_nlp/models/f_net/f_net_tokenizer.py @@ -46,7 +46,7 @@ class FNetTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. tokenizer = keras_nlp.models.FNetTokenizer.from_preset( diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py index 5a3a0fccda..5412d1ec97 100644 --- a/keras_nlp/models/falcon/falcon_backbone.py +++ b/keras_nlp/models/falcon/falcon_backbone.py @@ -44,7 +44,7 @@ class FalconBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py index 61afb9b5a7..3930b04b90 100644 --- a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py +++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class FalconCausalLMPreprocessor(FalconPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/falcon/falcon_preprocessor.py b/keras_nlp/models/falcon/falcon_preprocessor.py index b37d641467..e5bec96350 100644 --- a/keras_nlp/models/falcon/falcon_preprocessor.py +++ b/keras_nlp/models/falcon/falcon_preprocessor.py @@ -67,7 +67,7 @@ class FalconPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py index 3201d27a63..6c60b4b544 100644 --- a/keras_nlp/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -46,7 +46,7 @@ class FalconTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 45c7c6abe0..17794c144e 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -53,7 +53,7 @@ class GemmaCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Use `generate()` to do text generation. 
```python @@ -359,7 +359,7 @@ def score( [batch_size, num_tokens, vocab_size] in "logits" mode, or [batch_size, num_tokens] in "loss" mode. - Examples: + Example usage: Compute gradients between embeddings and loss scores with TensorFlow: ```python diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py index 04a067be82..6ef97c29bd 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class GemmaCausalLMPreprocessor(GemmaPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py index 8fc3beb48c..e4e7eeb51b 100644 --- a/keras_nlp/models/gemma/gemma_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -67,7 +67,7 @@ class GemmaPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py index 6a4bb76ea0..0a28661359 100644 --- a/keras_nlp/models/gemma/gemma_tokenizer.py +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -41,7 +41,7 @@ class GemmaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index b0bd529da4..e49bd9d4b6 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -58,7 +58,7 @@ class GPT2CausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 3278b18a4f..6ccb03c146 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 82be34776f..0e64066a39 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -67,7 +67,7 @@ class GPT2Preprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 15b35bed87..4cc49237a1 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -46,7 +46,7 @@ class GPT2Tokenizer(BytePairTokenizer): should have one merge rule per line. 
Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index b5383d528a..4776fb2e01 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -67,7 +67,7 @@ class LlamaBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py index a221185582..0aaffc5b64 100644 --- a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class LlamaCausalLMPreprocessor(LlamaPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py index 580557f50d..fe373d8d76 100644 --- a/keras_nlp/models/llama/llama_preprocessor.py +++ b/keras_nlp/models/llama/llama_preprocessor.py @@ -56,7 +56,7 @@ class LlamaPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the from_preset(). ```python diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py index 7acdf8687c..5e3a15cc28 100644 --- a/keras_nlp/models/llama/llama_tokenizer.py +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -41,7 +41,7 @@ class LlamaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. tokenizer = keras_nlp.models.LlamaTokenizer.from_preset( diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py index 52de945760..28f264c444 100644 --- a/keras_nlp/models/mistral/mistral_backbone.py +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -73,7 +73,7 @@ class MistralBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py index 624c37c9a1..b56fbb40b9 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class MistralCausalLMPreprocessor(MistralPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: ```python # Load the preprocessor from a preset. 
preprocessor = keras_nlp.models.MistralCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index 38dc6da5b6..d53f23b138 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -59,7 +59,7 @@ class MistralPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the from_preset(). ```python diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index 59a00d302f..60e4fc32d2 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -45,7 +45,7 @@ class MistralTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python # Unbatched input. tokenizer = keras_nlp.models.MistralTokenizer.from_preset( diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index 16fe4a0218..acce5a04ab 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -62,7 +62,7 @@ class OPTBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 2ca8ee07b4..2e67c40670 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -58,7 +58,7 @@ class OPTCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py index 0a9ab86b00..387a23b7c5 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py @@ -57,7 +57,7 @@ class OPTCausalLMPreprocessor(OPTPreprocessor): return_labels: If `True`, the output `"token_ids"` will be offset by one and returned as labels. If `False` only features will be returned. - Examples: + Example usage: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.OPTCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index 8f52bb67e6..8ac57da934 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -67,7 +67,7 @@ class OPTPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py index 4fb62ee73a..52ae0302b6 100644 --- a/keras_nlp/models/opt/opt_tokenizer.py +++ b/keras_nlp/models/opt/opt_tokenizer.py @@ -45,7 +45,7 @@ class OPTTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. 
tokenizer = keras_nlp.models.OPTTokenizer.from_preset( diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 16a65e57c2..18eafda5e6 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -75,7 +75,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example usage: ```python # Load a preprocessor layer from a preset. preprocessor = keras_nlp.models.{{preprocessor_name}}.from_preset( diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 09fe753762..ca63c8c4fb 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -66,7 +66,7 @@ class RobertaBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 887bc657d4..3b7dd43fd8 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -56,7 +56,7 @@ class RobertaClassifier(Task): dropout: float. The dropout probability value, applied to the pooled output, and after the first dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index bf96189860..19462df13c 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -53,7 +53,7 @@ class RobertaMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py index c69c300dc8..b4df1a0b06 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py @@ -74,7 +74,7 @@ class RobertaMaskedLMPreprocessor(RobertaPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 57a421590f..2683397b5f 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -69,7 +69,7 @@ class RobertaPreprocessor(Preprocessor): sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py index 0cfabff754..068935a4d8 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/models/roberta/roberta_tokenizer.py @@ -47,7 +47,7 @@ class RobertaTokenizer(BytePairTokenizer): path. the file should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Examples: + Example usage: ```python # Unbatched input. 
tokenizer = keras_nlp.models.RobertaTokenizer.from_preset( diff --git a/keras_nlp/models/t5/t5_tokenizer.py b/keras_nlp/models/t5/t5_tokenizer.py index b5dee49b85..d7084ba298 100644 --- a/keras_nlp/models/t5/t5_tokenizer.py +++ b/keras_nlp/models/t5/t5_tokenizer.py @@ -41,7 +41,7 @@ class T5Tokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python bytes_io = io.BytesIO() diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 9957f6546f..310d1c8585 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -199,7 +199,7 @@ def from_preset( load_weights: Whether to load pre-trained weights into model. Defaults to `True`. - Examples: + Example usage: ```python # Load architecture and weights from preset model = {{model_task_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index e41519bbc9..c9d29146a2 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -51,7 +51,7 @@ class WhisperAudioFeatureExtractor(PreprocessingLayer): seconds. The input audio tensor will be padded/trimmed to `max_audio_length * sampling_rate`. Defaults to `30`. - Examples: + Example usage: ```python audio_tensor = tf.ones((8000,), dtype="float32") @@ -281,7 +281,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example usage: ```python # Load a preset tokenizer. audio_feature_extractor = WhisperAudioFeatureExtractor.from_preset( diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index a2b685544e..3daac7c4fa 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -80,7 +80,7 @@ class WhisperBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example usage: ```python input_data = { diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index c21705a481..3e4d8081cf 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -70,7 +70,7 @@ class WhisperPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py index c74a0fd6fc..a86c802210 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py @@ -57,7 +57,7 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. 
- Examples: + Example usage: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index fcd8bfe9b8..ca104b9135 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -58,7 +58,7 @@ class XLMRobertaClassifier(Task): dropout: float. The dropout probability value, applied to the pooled output, and after the first dense layer. - Examples: + Example usage: Raw string data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py index a26905e9e3..5fb0ef8aa3 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py @@ -76,7 +76,7 @@ class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index c94f5f2421..2be4e6b9af 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -71,7 +71,7 @@ class XLMRobertaPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Examples: + Example usage: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py index 576f30bca1..3fb6f9dc24 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -50,7 +50,7 @@ class XLMRobertaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Examples: + Example usage: ```python tokenizer = keras_nlp.models.XLMRobertaTokenizer.from_preset( "xlm_roberta_base_multi", diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 45be1f74e7..6d971979e7 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -65,7 +65,7 @@ class XLNetBackbone(Backbone): padding_mask: Mask to avoid performing attention on padding token indices of shape `[batch_size, sequence_length]`. 
- Examples: + Example usage: ```python import numpy as np from keras_nlp.models import XLNetBackbone diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 297ec203de..5d5a2031f9 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -38,7 +38,7 @@ class BeamSampler(Sampler): Call arguments: {{call_args}} - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 36d10690d7..3f8bba2eda 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -38,7 +38,7 @@ class ContrastiveSampler(Sampler): Call arguments: {{call_args}} - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index ee8a6ecc2d..b3d989feee 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -24,7 +24,7 @@ class GreedySampler(Sampler): This sampler implements greedy search, i.e., always picking the token with the largest probability as the next token. - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index 1ff39c9f9b..fdef439455 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -32,7 +32,7 @@ class RandomSampler(Sampler): Call arguments: {{call_args}} - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 3ecf16ac28..7c9618a408 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -36,7 +36,7 @@ class Sampler: computes the next token based on a probability distribution over all possible vocab entries. - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 513dd738c7..df83102ff3 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -33,7 +33,7 @@ class TopKSampler(Sampler): Call arguments: {{call_args}} - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 326f5797a6..3585dfcb60 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -41,7 +41,7 @@ class TopPSampler(Sampler): Call arguments: {{call_args}} - Examples: + Example usage: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 2ac8832a76..6261caa83d 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -156,7 +156,7 @@ class BytePairTokenizerCache(tf.Module): The cache key is a string tensor or python strings, and the value is split tokens joined by whitespace. 
For example, "dragonfly" => "dragon fly" - Examples: + Example usage: ``` cache = BytePairTokenizerCache() cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"]) @@ -252,7 +252,7 @@ class BytePairTokenizer(tokenizer.Tokenizer): contain splittable characters such as punctuation. Special tokens must still be included in `vocabulary`. Defaults to `None`. - Examples: + Example usage: Tokenize >>> vocab = {"butter": 1, "fly": 2} @@ -665,7 +665,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example usage: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py index 3aefc4a01d..95bae533b5 100644 --- a/keras_nlp/tokenizers/byte_tokenizer.py +++ b/keras_nlp/tokenizers/byte_tokenizer.py @@ -76,7 +76,7 @@ class ByteTokenizer(tokenizer.Tokenizer): https://www.tensorflow.org/api_docs/python/tf/strings/unicode_transcode). (U+FFFD) is `65533`. Defaults to `65533`. - Examples: + Example usage: Basic usage. >>> tokenizer = keras_nlp.tokenizers.ByteTokenizer() diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index 64e169939c..93bded8cd4 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -68,7 +68,7 @@ class SentencePieceTokenizer(tokenizer.Tokenizer): References: - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226) - Examples: + Example usage: From bytes. ```python @@ -275,7 +275,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example usage: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py index af9f2624b2..64b738480f 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py @@ -56,7 +56,7 @@ def compute_sentence_piece_proto( A `bytes` object with a serialized SentencePiece proto or `None` if proto_output_file if provided. - Examples: + Example usage: Basic Usage (from Dataset). >>> inputs = tf.data.Dataset.from_tensor_slices(["Drifting Along"]) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 7da1e9d7b1..e767f9749d 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -40,7 +40,7 @@ class Tokenizer(PreprocessingLayer): "vocab free" tokenizers, such as a whitespace splitter show below, these methods do not apply and can be skipped. - Examples: + Example usage: ```python class WhitespaceSplitterTokenizer(keras_nlp.tokenizers.Tokenizer): diff --git a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py index 5fe8f0144d..825fa944f4 100644 --- a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py @@ -79,7 +79,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer): Effectively this will make the `vocabulary_size - 1` id the the OOV value. - Examples: + Example usage: Basic Usage. 
>>> inputs = "Unicode Tokenizer" diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 75f956899f..0f5b265794 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -230,7 +230,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer): - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/) - [Song et al., 2020](https://arxiv.org/abs/2012.15524) - Examples: + Example usage: Ragged outputs. >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."] @@ -480,7 +480,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example usage: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_trainer.py b/keras_nlp/tokenizers/word_piece_tokenizer_trainer.py index dc90075a5c..52b4f1f848 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer_trainer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer_trainer.py @@ -75,7 +75,7 @@ def compute_word_piece_vocabulary( Returns: Returns a list of vocabulary terms. - Examples: + Example usage: Basic Usage (from Dataset). >>> inputs = tf.data.Dataset.from_tensor_slices(["bat sat pat mat rat"]) diff --git a/pip_build.py b/pip_build.py index 0c83cbb436..8f74a931ab 100644 --- a/pip_build.py +++ b/pip_build.py @@ -13,7 +13,7 @@ # limitations under the License. """Script to create (and optionally install) a `.whl` archive for KerasNLP. -Usage: +Example usage: 1. Create a `.whl` file in `dist/`: diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py index ed81e023d4..a2e7acc4b8 100644 --- a/tools/checkpoint_conversion/convert_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -19,7 +19,7 @@ pip install git+https://github.com/google-deepmind/gemma.git python pip_build.py --install -Usage: +Example usage: cd tools/checkpoint_conversion python convert_gemma_checkpoints.py --preset gemma_2b_en """ diff --git a/tools/count_preset_params.py b/tools/count_preset_params.py index 3edcc6d09d..d069b8fe3c 100644 --- a/tools/count_preset_params.py +++ b/tools/count_preset_params.py @@ -14,7 +14,7 @@ """ Small utility script to count parameters in our preset checkpoints. 
-Usage: +Example usage: python tools/count_preset_params.py python tools/count_preset_params.py --model BertBackbone python tools/count_preset_params.py --preset bert_base_multi diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py index 6f1fdf24d2..0bf76aeeda 100644 --- a/tools/gemma/export_gemma_to_hf.py +++ b/tools/gemma/export_gemma_to_hf.py @@ -25,7 +25,7 @@ os.environ["KERAS_BACKEND"] = "torch" """ -Sample usage: +Example usage: For converting a keras model to HuggingFace format using a custom or fine-tuned checkpoint from Keras, make sure to pass the path for the Keras weights file diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py index 08d4b3ac98..d2b8e4aa86 100644 --- a/tools/gemma/export_gemma_to_torch_xla.py +++ b/tools/gemma/export_gemma_to_torch_xla.py @@ -38,7 +38,7 @@ os.environ["KERAS_BACKEND"] = "torch" """ -Sample usage: +Example usage: For converting a Keras model to PyTorch format using a custom or fine-tuned checkpoint from Keras, make sure to pass the path for the Keras weights file diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py index f212154c99..e71d727b95 100644 --- a/tools/gemma/run_gemma_xla.py +++ b/tools/gemma/run_gemma_xla.py @@ -47,7 +47,7 @@ from gemma.tokenizer import Tokenizer """ -Sample usage: +Example usage: Run the verification script supplying your model size, converted checkpoint file, vocabulary file, and test prompt. From 2acb4c9052fa68d1040d090cf5e364b04d7eb3de Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Tue, 19 Mar 2024 18:12:42 -0700 Subject: [PATCH 46/70] Revert "Unify docstring style" This reverts commit 97c3413d36cdb731ad4ce456ad6d899e1d2b37fb. --- STYLE_GUIDE.md | 2 +- examples/bert_pretraining/bert_create_pretraining_data.py | 2 +- examples/tools/split_sentences.py | 2 +- examples/tools/train_word_piece_vocab.py | 2 +- keras_nlp/layers/modeling/alibi_bias.py | 2 +- keras_nlp/layers/modeling/f_net_encoder.py | 2 +- keras_nlp/layers/modeling/masked_lm_head.py | 2 +- keras_nlp/layers/modeling/position_embedding.py | 2 +- keras_nlp/layers/modeling/reversible_embedding.py | 2 +- keras_nlp/layers/modeling/rotary_embedding.py | 2 +- keras_nlp/layers/modeling/sine_position_encoding.py | 2 +- keras_nlp/layers/modeling/token_and_position_embedding.py | 2 +- keras_nlp/layers/modeling/transformer_decoder.py | 2 +- keras_nlp/layers/modeling/transformer_encoder.py | 2 +- keras_nlp/layers/preprocessing/masked_lm_mask_generator.py | 2 +- keras_nlp/layers/preprocessing/multi_segment_packer.py | 2 +- keras_nlp/layers/preprocessing/random_deletion.py | 2 +- keras_nlp/layers/preprocessing/random_swap.py | 2 +- keras_nlp/layers/preprocessing/start_end_packer.py | 2 +- keras_nlp/metrics/edit_distance.py | 2 +- keras_nlp/metrics/perplexity.py | 2 +- keras_nlp/metrics/rouge_l.py | 2 +- keras_nlp/metrics/rouge_n.py | 2 +- keras_nlp/models/albert/albert_backbone.py | 2 +- keras_nlp/models/albert/albert_classifier.py | 2 +- keras_nlp/models/albert/albert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/albert/albert_preprocessor.py | 2 +- keras_nlp/models/albert/albert_tokenizer.py | 2 +- keras_nlp/models/backbone.py | 2 +- keras_nlp/models/bart/bart_backbone.py | 2 +- keras_nlp/models/bart/bart_preprocessor.py | 2 +- keras_nlp/models/bart/bart_seq_2_seq_lm.py | 2 +- keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py | 2 +- keras_nlp/models/bart/bart_tokenizer.py | 2 +- keras_nlp/models/bert/bert_backbone.py | 2 +- keras_nlp/models/bert/bert_classifier.py | 
2 +- keras_nlp/models/bert/bert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/bert/bert_preprocessor.py | 2 +- keras_nlp/models/bert/bert_tokenizer.py | 2 +- keras_nlp/models/bloom/bloom_backbone.py | 2 +- keras_nlp/models/bloom/bloom_causal_lm.py | 2 +- keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py | 2 +- keras_nlp/models/bloom/bloom_preprocessor.py | 2 +- keras_nlp/models/bloom/bloom_tokenizer.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_classifier.py | 2 +- .../models/deberta_v3/deberta_v3_masked_lm_preprocessor.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py | 2 +- keras_nlp/models/distil_bert/distil_bert_backbone.py | 2 +- keras_nlp/models/distil_bert/distil_bert_classifier.py | 2 +- .../distil_bert/distil_bert_masked_lm_preprocessor.py | 2 +- keras_nlp/models/distil_bert/distil_bert_preprocessor.py | 2 +- keras_nlp/models/distil_bert/distil_bert_tokenizer.py | 2 +- keras_nlp/models/electra/electra_backbone.py | 2 +- keras_nlp/models/electra/electra_tokenizer.py | 2 +- keras_nlp/models/f_net/f_net_backbone.py | 2 +- keras_nlp/models/f_net/f_net_classifier.py | 2 +- keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py | 2 +- keras_nlp/models/f_net/f_net_preprocessor.py | 2 +- keras_nlp/models/f_net/f_net_tokenizer.py | 2 +- keras_nlp/models/falcon/falcon_backbone.py | 2 +- keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py | 2 +- keras_nlp/models/falcon/falcon_preprocessor.py | 2 +- keras_nlp/models/falcon/falcon_tokenizer.py | 2 +- keras_nlp/models/gemma/gemma_causal_lm.py | 4 ++-- keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py | 2 +- keras_nlp/models/gemma/gemma_preprocessor.py | 2 +- keras_nlp/models/gemma/gemma_tokenizer.py | 2 +- keras_nlp/models/gpt2/gpt2_causal_lm.py | 2 +- keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py | 2 +- keras_nlp/models/gpt2/gpt2_preprocessor.py | 2 +- keras_nlp/models/gpt2/gpt2_tokenizer.py | 2 +- keras_nlp/models/llama/llama_backbone.py | 2 +- keras_nlp/models/llama/llama_causal_lm_preprocessor.py | 2 +- keras_nlp/models/llama/llama_preprocessor.py | 2 +- keras_nlp/models/llama/llama_tokenizer.py | 2 +- keras_nlp/models/mistral/mistral_backbone.py | 2 +- keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py | 2 +- keras_nlp/models/mistral/mistral_preprocessor.py | 2 +- keras_nlp/models/mistral/mistral_tokenizer.py | 2 +- keras_nlp/models/opt/opt_backbone.py | 2 +- keras_nlp/models/opt/opt_causal_lm.py | 2 +- keras_nlp/models/opt/opt_causal_lm_preprocessor.py | 2 +- keras_nlp/models/opt/opt_preprocessor.py | 2 +- keras_nlp/models/opt/opt_tokenizer.py | 2 +- keras_nlp/models/preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_backbone.py | 2 +- keras_nlp/models/roberta/roberta_classifier.py | 2 +- keras_nlp/models/roberta/roberta_masked_lm.py | 2 +- keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_preprocessor.py | 2 +- keras_nlp/models/roberta/roberta_tokenizer.py | 2 +- keras_nlp/models/t5/t5_tokenizer.py | 2 +- keras_nlp/models/task.py | 2 +- keras_nlp/models/whisper/whisper_audio_feature_extractor.py | 4 ++-- keras_nlp/models/whisper/whisper_backbone.py | 2 +- keras_nlp/models/whisper/whisper_preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py | 2 +- .../xlm_roberta/xlm_roberta_masked_lm_preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py | 2 +- 
keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py | 2 +- keras_nlp/models/xlnet/xlnet_backbone.py | 2 +- keras_nlp/samplers/beam_sampler.py | 2 +- keras_nlp/samplers/contrastive_sampler.py | 2 +- keras_nlp/samplers/greedy_sampler.py | 2 +- keras_nlp/samplers/random_sampler.py | 2 +- keras_nlp/samplers/sampler.py | 2 +- keras_nlp/samplers/top_k_sampler.py | 2 +- keras_nlp/samplers/top_p_sampler.py | 2 +- keras_nlp/tokenizers/byte_pair_tokenizer.py | 6 +++--- keras_nlp/tokenizers/byte_tokenizer.py | 2 +- keras_nlp/tokenizers/sentence_piece_tokenizer.py | 4 ++-- keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py | 2 +- keras_nlp/tokenizers/tokenizer.py | 2 +- keras_nlp/tokenizers/unicode_codepoint_tokenizer.py | 2 +- keras_nlp/tokenizers/word_piece_tokenizer.py | 4 ++-- keras_nlp/tokenizers/word_piece_tokenizer_trainer.py | 2 +- pip_build.py | 2 +- tools/checkpoint_conversion/convert_gemma_checkpoints.py | 2 +- tools/count_preset_params.py | 2 +- tools/gemma/export_gemma_to_hf.py | 2 +- tools/gemma/export_gemma_to_torch_xla.py | 2 +- tools/gemma/run_gemma_xla.py | 2 +- 124 files changed, 130 insertions(+), 130 deletions(-) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 5d3466df69..3db287de99 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -116,7 +116,7 @@ class PositionEmbedding(keras.layers.Layer): Args: sequence_length: The maximum length of the dynamic sequence. - Example usage: + Examples: Direct call. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/examples/bert_pretraining/bert_create_pretraining_data.py b/examples/bert_pretraining/bert_create_pretraining_data.py index 9e70d906de..f7dcb54426 100644 --- a/examples/bert_pretraining/bert_create_pretraining_data.py +++ b/examples/bert_pretraining/bert_create_pretraining_data.py @@ -27,7 +27,7 @@ This script is adapted from the original BERT repository: https://github.com/google-research/bert/blob/master/create_pretraining_data.py -Example usage: +Usage: python create_pretraining_data.py \ --input_files ~/datasets/bert-sentence-split-data/shard_0.txt \ --output_directory ~/datasets/bert-pretraining-data/shard_0.txt \ diff --git a/examples/tools/split_sentences.py b/examples/tools/split_sentences.py index ff14d714e2..7606d0a070 100644 --- a/examples/tools/split_sentences.py +++ b/examples/tools/split_sentences.py @@ -21,7 +21,7 @@ This script will run multiprocessed, and the number of concurrent processes and output file shards can be controlled with `--num_jobs` and `--num_shards`. -Example usage: +Usage: python examples/tools/create_sentence_split_data.py \ --input_files ~/datasets/wikipedia,~/datasets/bookscorpus \ --output_directory ~/datasets/bert-sentence-split-data diff --git a/examples/tools/train_word_piece_vocab.py b/examples/tools/train_word_piece_vocab.py index a4a4489f1b..a9689aaf7f 100644 --- a/examples/tools/train_word_piece_vocab.py +++ b/examples/tools/train_word_piece_vocab.py @@ -15,7 +15,7 @@ This script will create wordpiece vocabularies suitable for pretraining BERT. -Example usage: +Usage: python examples/tools/train_word_piece_vocabulary.py \ --input_files ~/datasets/bert-sentence-split-data/ \ --output_file vocab.txt diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index d1d9b97e70..fdc956ae15 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -40,7 +40,7 @@ class AlibiBias(keras.layers.Layer): multi-head attention layer of the transformer to add alibi bias to it. 
With shape `(batch_size, num_heads, query_length, key_length)`. - Example usage: + Examples: ```python query_length = 10 key_length = 10 diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index d08e36da70..a5370d960e 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -50,7 +50,7 @@ class FNetEncoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Example usage: + Examples: ```python # Create a single FNet encoder layer. diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py index 383d7a2293..eacee7e8c0 100644 --- a/keras_nlp/layers/modeling/masked_lm_head.py +++ b/keras_nlp/layers/modeling/masked_lm_head.py @@ -60,7 +60,7 @@ class MaskedLMHead(keras.layers.Layer): The bias initializer for the dense and multiheaded attention layers. Defaults to `"zeros"`. - Example usage: + Examples: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/position_embedding.py b/keras_nlp/layers/modeling/position_embedding.py index 11086fbea2..6f9a44c29f 100644 --- a/keras_nlp/layers/modeling/position_embedding.py +++ b/keras_nlp/layers/modeling/position_embedding.py @@ -43,7 +43,7 @@ class PositionEmbedding(keras.layers.Layer): compute the position embedding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Example usage: + Examples: Called directly on input. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index 574e6f9962..d115217687 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -59,7 +59,7 @@ class ReversibleEmbedding(keras.layers.Embedding): from `output_dim` to `input_dim`, instead of a normal embedding call. Defaults to `False`. - Example usage: + Examples: ```python batch_size = 16 vocab_size = 100 diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index 3129cf5e30..b494d559bd 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -47,7 +47,7 @@ class RotaryEmbedding(keras.layers.Layer): compute the rotary embedding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Example usage: + Examples: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/sine_position_encoding.py b/keras_nlp/layers/modeling/sine_position_encoding.py index da2f138ca4..6e96a77e2c 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding.py +++ b/keras_nlp/layers/modeling/sine_position_encoding.py @@ -42,7 +42,7 @@ class SinePositionEncoding(keras.layers.Layer): compute the encoding from. This is useful during cached decoding, where each position is predicted separately in a loop. 
- Example usage: + Examples: ```python # create a simple embedding layer with sinusoidal positional encoding seq_len = 100 diff --git a/keras_nlp/layers/modeling/token_and_position_embedding.py b/keras_nlp/layers/modeling/token_and_position_embedding.py index d215cad45e..bb7107f96f 100644 --- a/keras_nlp/layers/modeling/token_and_position_embedding.py +++ b/keras_nlp/layers/modeling/token_and_position_embedding.py @@ -44,7 +44,7 @@ class TokenAndPositionEmbedding(keras.layers.Layer): used in the vocabulary (input_dim should equal size of vocabulary + 1). - Example usage: + Examples: ```python inputs = np.ones(shape=(1, 50), dtype="int32") embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding( diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 0473ce1025..d06a1948f5 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -72,7 +72,7 @@ class TransformerDecoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Example usage: + Examples: ```python # Create a single transformer decoder layer. decoder = keras_nlp.layers.TransformerDecoder( diff --git a/keras_nlp/layers/modeling/transformer_encoder.py b/keras_nlp/layers/modeling/transformer_encoder.py index 07538bf401..32cdd35547 100644 --- a/keras_nlp/layers/modeling/transformer_encoder.py +++ b/keras_nlp/layers/modeling/transformer_encoder.py @@ -61,7 +61,7 @@ class TransformerEncoder(keras.layers.Layer): name: string. The name of the layer. Defaults to `None`. **kwargs: other keyword arguments. - Example usage: + Examples: ```python # Create a single transformer encoder layer. diff --git a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py index f7b1413ae9..74b2fd9811 100644 --- a/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py +++ b/keras_nlp/layers/preprocessing/masked_lm_mask_generator.py @@ -82,7 +82,7 @@ class MaskedLMMaskGenerator(PreprocessingLayer): 1 means the corresponding position in `mask_positions` is an actual mask, 0 means it is a pad. - Example usage: + Examples: Basic usage. ```python diff --git a/keras_nlp/layers/preprocessing/multi_segment_packer.py b/keras_nlp/layers/preprocessing/multi_segment_packer.py index 9caed450e5..638d6f2b91 100644 --- a/keras_nlp/layers/preprocessing/multi_segment_packer.py +++ b/keras_nlp/layers/preprocessing/multi_segment_packer.py @@ -86,7 +86,7 @@ class MultiSegmentPacker(PreprocessingLayer): sequence. The second is an integer tensor of the same shape, containing the segment ids. - Example usage: + Examples: *Pack a single input for classification.* >>> seq1 = [1, 2, 3, 4] diff --git a/keras_nlp/layers/preprocessing/random_deletion.py b/keras_nlp/layers/preprocessing/random_deletion.py index fac8b958d4..061290ba56 100644 --- a/keras_nlp/layers/preprocessing/random_deletion.py +++ b/keras_nlp/layers/preprocessing/random_deletion.py @@ -58,7 +58,7 @@ class RandomDeletion(PreprocessingLayer): traceable--it can be any python function. seed: A seed for the random number generator. - Example usage: + Examples: Word level usage. 
>>> keras.utils.set_random_seed(1337) diff --git a/keras_nlp/layers/preprocessing/random_swap.py b/keras_nlp/layers/preprocessing/random_swap.py index 2fb20f9aa8..27873f0fe8 100644 --- a/keras_nlp/layers/preprocessing/random_swap.py +++ b/keras_nlp/layers/preprocessing/random_swap.py @@ -60,7 +60,7 @@ class RandomSwap(PreprocessingLayer): seed: A seed for the random number generator. - Example usage: + Examples: Word level usage. >>> keras.utils.set_random_seed(1337) diff --git a/keras_nlp/layers/preprocessing/start_end_packer.py b/keras_nlp/layers/preprocessing/start_end_packer.py index 4755a9076e..be3466a506 100644 --- a/keras_nlp/layers/preprocessing/start_end_packer.py +++ b/keras_nlp/layers/preprocessing/start_end_packer.py @@ -59,7 +59,7 @@ class StartEndPacker(PreprocessingLayer): add_end_value: Pass `False` to not append an end value for this input. - Example usage: + Examples: Unbatched input (int). >>> inputs = [5, 6, 7] diff --git a/keras_nlp/metrics/edit_distance.py b/keras_nlp/metrics/edit_distance.py index 662022d477..263ff8290b 100644 --- a/keras_nlp/metrics/edit_distance.py +++ b/keras_nlp/metrics/edit_distance.py @@ -53,7 +53,7 @@ class EditDistance(keras.metrics.Metric): References: - [Morris et al.](https://www.researchgate.net/publication/221478089) - Example usage: + Examples: Various Input Types. diff --git a/keras_nlp/metrics/perplexity.py b/keras_nlp/metrics/perplexity.py index b82ed84846..4a7e626bc9 100644 --- a/keras_nlp/metrics/perplexity.py +++ b/keras_nlp/metrics/perplexity.py @@ -40,7 +40,7 @@ class Perplexity(keras.metrics.Metric): name: string. Name of the metric instance. **kwargs: Other keyword arguments. - Example usage: + Examples: 1. Calculate perplexity by calling update_state() and result(). 1.1. `sample_weight`, and `mask_token_id` are not provided. diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py index 329c0dccf2..82fcdedade 100644 --- a/keras_nlp/metrics/rouge_l.py +++ b/keras_nlp/metrics/rouge_l.py @@ -40,7 +40,7 @@ class RougeL(RougeBase): References: - [Lin et al., 2004](https://aclanthology.org/W04-1013/) - Example usage: + Examples: 1. Python string. >>> rouge_l = keras_nlp.metrics.RougeL() diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py index 2d6b1770dd..8b135d1dd0 100644 --- a/keras_nlp/metrics/rouge_n.py +++ b/keras_nlp/metrics/rouge_n.py @@ -42,7 +42,7 @@ class RougeN(RougeBase): References: - [Lin et al., 2004](https://aclanthology.org/W04-1013/) - Example usage: + Examples: 1. Python string. >>> rouge_n = keras_nlp.metrics.RougeN(order=2) diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 8a34dcc4c5..09053ff893 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -77,7 +77,7 @@ class AlbertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 224ae51540..32a4e0847d 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -54,7 +54,7 @@ class AlbertClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Example usage: + Examples: Raw string data. 
```python diff --git a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py index a3c846cfb2..89cf134465 100644 --- a/keras_nlp/models/albert/albert_masked_lm_preprocessor.py +++ b/keras_nlp/models/albert/albert_masked_lm_preprocessor.py @@ -69,7 +69,7 @@ class AlbertMaskedLMPreprocessor(AlbertPreprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index cec26b079c..19f4bd9a7b 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -72,7 +72,7 @@ class AlbertPreprocessor(Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Example usage: + Examples: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.AlbertPreprocessor.from_preset( diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index 04fdfa3b08..44aed44cf5 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -46,7 +46,7 @@ class AlbertTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 475f89389f..bfdc8207ad 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -115,7 +115,7 @@ def from_preset( load_weights: Whether to load pre-trained weights into model. Defaults to `True`. - Example usage: + Examples: ```python # Load architecture and weights from preset model = keras_nlp.models.{{model_name}}.from_preset( diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 2d509f1df4..f100133d25 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -65,7 +65,7 @@ class BartBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index 33effe28af..3310b1e532 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -52,7 +52,7 @@ class BartPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index 5ab4f4264d..c530555b3d 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -59,7 +59,7 @@ class BartSeq2SeqLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Use `generate()` to do text generation, given an input context. 
```python diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index e365f896dc..1c72e6e935 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -58,7 +58,7 @@ class BartSeq2SeqLMPreprocessor(BartPreprocessor): generates label weights by shifting the padding mask one step to the left. - Example usage: + Examples: Directly calling the layer on data ```python diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py index 39a2c7d232..17fb237b88 100644 --- a/keras_nlp/models/bart/bart_tokenizer.py +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -48,7 +48,7 @@ class BartTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 7681622a65..320dc1c2ee 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -66,7 +66,7 @@ class BertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index e6a4b050c3..09d2b8810c 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -55,7 +55,7 @@ class BertClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/bert/bert_masked_lm_preprocessor.py b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py index 75d8e5e6d3..cdc61fbac3 100644 --- a/keras_nlp/models/bert/bert_masked_lm_preprocessor.py +++ b/keras_nlp/models/bert/bert_masked_lm_preprocessor.py @@ -72,7 +72,7 @@ class BertMaskedLMPreprocessor(BertPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index 7064761c49..02f5a45985 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -69,7 +69,7 @@ class BertPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py index 06e1cc2672..1b634fe9b3 100644 --- a/keras_nlp/models/bert/bert_tokenizer.py +++ b/keras_nlp/models/bert/bert_tokenizer.py @@ -50,7 +50,7 @@ class BertTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Example usage: + Examples: ```python # Unbatched input. 
tokenizer = keras_nlp.models.BertTokenizer.from_preset( diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 1bf18b5c25..9b7c65a399 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -58,7 +58,7 @@ class BloomBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py index d259852243..31eae30c6b 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm.py +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -53,7 +53,7 @@ class BloomCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py index d8f503eaa7..b56e1a3ef0 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class BloomCausalLMPreprocessor(BloomPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.BloomCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index 8807caff40..8eb693cb50 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -62,7 +62,7 @@ class BloomPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index e2f100ba82..0d7f74b163 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -46,7 +46,7 @@ class BloomTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index ff973eb3c1..d6eea63601 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -64,7 +64,7 @@ class DebertaV3Classifier(Task): dropout: float. Dropout probability applied to the pooled output. For the second dropout layer, `backbone.dropout` is used. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py index 2f2bca3b17..519b0b4fca 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py @@ -71,7 +71,7 @@ class DebertaV3MaskedLMPreprocessor(DebertaV3Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. 
- Example usage: + Examples: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.DebertaV3MaskedLMPreprocessor.from_preset( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 6c8902a86e..88fa08fd70 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -71,7 +71,7 @@ class DebertaV3Preprocessor(Preprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Example usage: + Examples: Directly calling the layer on data. ```python preprocessor = keras_nlp.models.DebertaV3Preprocessor.from_preset( diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py index 09da891b99..e66c373e65 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py @@ -53,7 +53,7 @@ class DebertaV3Tokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index af05dbed73..73634b4216 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -67,7 +67,7 @@ class DistilBertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index 02391fa1b9..e82aaf2781 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -61,7 +61,7 @@ class DistilBertClassifier(Task): dropout: float. The dropout probability value, applied after the first dense layer. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py index cdd54954dd..f1360f58b7 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor.py @@ -74,7 +74,7 @@ class DistilBertMaskedLMPreprocessor(DistilBertPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py index 6ff3275a99..63f4e3637b 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py @@ -68,7 +68,7 @@ class DistilBertPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. 
```python diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py index ba738fe76c..4a18398a1e 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py @@ -47,7 +47,7 @@ class DistilBertTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 4fbfdf580b..f4f2a23b69 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -63,7 +63,7 @@ class ElectraBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index 1aaf74ed09..acd665c2a3 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -37,7 +37,7 @@ class ElectraTokenizer(WordPieceTokenizer): lowercase: If `True`, the input text will be first lowered before tokenization. - Example usage: + Examples: ```python # Custom Vocabulary. vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index cd5fb604c2..ab056c84c7 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -71,7 +71,7 @@ class FNetBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index b42535c357..512182d2cd 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -55,7 +55,7 @@ class FNetClassifier(Task): dropout: float. The dropout probability value, applied after the dense layer. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py index 6161d62520..51b4a4d1e7 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_preprocessor.py @@ -68,7 +68,7 @@ class FNetMaskedLMPreprocessor(FNetPreprocessor): left-to-right manner and fills up the buckets until we run out of budget. It supports an arbitrary number of segments. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index 06520acea0..b4cb5836bb 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -66,7 +66,7 @@ class FNetPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the from_preset(). 
```python diff --git a/keras_nlp/models/f_net/f_net_tokenizer.py b/keras_nlp/models/f_net/f_net_tokenizer.py index 8f785165c5..ae3f569b1d 100644 --- a/keras_nlp/models/f_net/f_net_tokenizer.py +++ b/keras_nlp/models/f_net/f_net_tokenizer.py @@ -46,7 +46,7 @@ class FNetTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. tokenizer = keras_nlp.models.FNetTokenizer.from_preset( diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py index 5412d1ec97..5a3a0fccda 100644 --- a/keras_nlp/models/falcon/falcon_backbone.py +++ b/keras_nlp/models/falcon/falcon_backbone.py @@ -44,7 +44,7 @@ class FalconBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py index 3930b04b90..61afb9b5a7 100644 --- a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py +++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class FalconCausalLMPreprocessor(FalconPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/falcon/falcon_preprocessor.py b/keras_nlp/models/falcon/falcon_preprocessor.py index e5bec96350..b37d641467 100644 --- a/keras_nlp/models/falcon/falcon_preprocessor.py +++ b/keras_nlp/models/falcon/falcon_preprocessor.py @@ -67,7 +67,7 @@ class FalconPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py index 6c60b4b544..3201d27a63 100644 --- a/keras_nlp/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -46,7 +46,7 @@ class FalconTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 17794c144e..45c7c6abe0 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -53,7 +53,7 @@ class GemmaCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Use `generate()` to do text generation. ```python @@ -359,7 +359,7 @@ def score( [batch_size, num_tokens, vocab_size] in "logits" mode, or [batch_size, num_tokens] in "loss" mode. 
- Example usage: + Examples: Compute gradients between embeddings and loss scores with TensorFlow: ```python diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py index 6ef97c29bd..04a067be82 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class GemmaCausalLMPreprocessor(GemmaPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py index e4e7eeb51b..8fc3beb48c 100644 --- a/keras_nlp/models/gemma/gemma_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -67,7 +67,7 @@ class GemmaPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py index 0a28661359..6a4bb76ea0 100644 --- a/keras_nlp/models/gemma/gemma_tokenizer.py +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -41,7 +41,7 @@ class GemmaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index e49bd9d4b6..b0bd529da4 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -58,7 +58,7 @@ class GPT2CausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 6ccb03c146..3278b18a4f 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 0e64066a39..82be34776f 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -67,7 +67,7 @@ class GPT2Preprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 4cc49237a1..15b35bed87 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -46,7 +46,7 @@ class GPT2Tokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. 
diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index 4776fb2e01..b5383d528a 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -67,7 +67,7 @@ class LlamaBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py index 0aaffc5b64..a221185582 100644 --- a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class LlamaCausalLMPreprocessor(LlamaPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py index fe373d8d76..580557f50d 100644 --- a/keras_nlp/models/llama/llama_preprocessor.py +++ b/keras_nlp/models/llama/llama_preprocessor.py @@ -56,7 +56,7 @@ class LlamaPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the from_preset(). ```python diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py index 5e3a15cc28..7acdf8687c 100644 --- a/keras_nlp/models/llama/llama_tokenizer.py +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -41,7 +41,7 @@ class LlamaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. tokenizer = keras_nlp.models.LlamaTokenizer.from_preset( diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py index 28f264c444..52de945760 100644 --- a/keras_nlp/models/mistral/mistral_backbone.py +++ b/keras_nlp/models/mistral/mistral_backbone.py @@ -73,7 +73,7 @@ class MistralBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py index b56fbb40b9..624c37c9a1 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py @@ -56,7 +56,7 @@ class MistralCausalLMPreprocessor(MistralPreprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.MistralCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index d53f23b138..38dc6da5b6 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -59,7 +59,7 @@ class MistralPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the from_preset(). 
```python diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index 60e4fc32d2..59a00d302f 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -45,7 +45,7 @@ class MistralTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python # Unbatched input. tokenizer = keras_nlp.models.MistralTokenizer.from_preset( diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index acce5a04ab..16fe4a0218 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -62,7 +62,7 @@ class OPTBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 2e67c40670..2ca8ee07b4 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -58,7 +58,7 @@ class OPTCausalLM(GenerativeTask): If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Use `generate()` to do text generation. ```python diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py index 387a23b7c5..0a9ab86b00 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py @@ -57,7 +57,7 @@ class OPTCausalLMPreprocessor(OPTPreprocessor): return_labels: If `True`, the output `"token_ids"` will be offset by one and returned as labels. If `False` only features will be returned. - Example usage: + Examples: ```python # Load the preprocessor from a preset. preprocessor = keras_nlp.models.OPTCausalLMPreprocessor.from_preset( diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index 8ac57da934..8f52bb67e6 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -67,7 +67,7 @@ class OPTPreprocessor(Preprocessor): sequence_length: Pass to override the configured `sequence_length` of the layer. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py index 52ae0302b6..4fb62ee73a 100644 --- a/keras_nlp/models/opt/opt_tokenizer.py +++ b/keras_nlp/models/opt/opt_tokenizer.py @@ -45,7 +45,7 @@ class OPTTokenizer(BytePairTokenizer): should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. tokenizer = keras_nlp.models.OPTTokenizer.from_preset( diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 18eafda5e6..16a65e57c2 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -75,7 +75,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Example usage: + Examples: ```python # Load a preprocessor layer from a preset. 
preprocessor = keras_nlp.models.{{preprocessor_name}}.from_preset( diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index ca63c8c4fb..09fe753762 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -66,7 +66,7 @@ class RobertaBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 3b7dd43fd8..887bc657d4 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -56,7 +56,7 @@ class RobertaClassifier(Task): dropout: float. The dropout probability value, applied to the pooled output, and after the first dense layer. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index 19462df13c..bf96189860 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -53,7 +53,7 @@ class RobertaMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py index b4df1a0b06..c69c300dc8 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_preprocessor.py @@ -74,7 +74,7 @@ class RobertaMaskedLMPreprocessor(RobertaPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 2683397b5f..57a421590f 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -69,7 +69,7 @@ class RobertaPreprocessor(Preprocessor): sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py index 068935a4d8..0cfabff754 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/models/roberta/roberta_tokenizer.py @@ -47,7 +47,7 @@ class RobertaTokenizer(BytePairTokenizer): path. the file should have one merge rule per line. Every merge rule contains merge entities separated by a space. - Example usage: + Examples: ```python # Unbatched input. tokenizer = keras_nlp.models.RobertaTokenizer.from_preset( diff --git a/keras_nlp/models/t5/t5_tokenizer.py b/keras_nlp/models/t5/t5_tokenizer.py index d7084ba298..b5dee49b85 100644 --- a/keras_nlp/models/t5/t5_tokenizer.py +++ b/keras_nlp/models/t5/t5_tokenizer.py @@ -41,7 +41,7 @@ class T5Tokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. 
- Example usage: + Examples: ```python bytes_io = io.BytesIO() diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 310d1c8585..9957f6546f 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -199,7 +199,7 @@ def from_preset( load_weights: Whether to load pre-trained weights into model. Defaults to `True`. - Example usage: + Examples: ```python # Load architecture and weights from preset model = {{model_task_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index c9d29146a2..e41519bbc9 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -51,7 +51,7 @@ class WhisperAudioFeatureExtractor(PreprocessingLayer): seconds. The input audio tensor will be padded/trimmed to `max_audio_length * sampling_rate`. Defaults to `30`. - Example usage: + Examples: ```python audio_tensor = tf.ones((8000,), dtype="float32") @@ -281,7 +281,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Example usage: + Examples: ```python # Load a preset tokenizer. audio_feature_extractor = WhisperAudioFeatureExtractor.from_preset( diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index 3daac7c4fa..a2b685544e 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -80,7 +80,7 @@ class WhisperBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index 3e4d8081cf..c21705a481 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -70,7 +70,7 @@ class WhisperPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py index a86c802210..c74a0fd6fc 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py @@ -57,7 +57,7 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Example usage: + Examples: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index ca104b9135..fcd8bfe9b8 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -58,7 +58,7 @@ class XLMRobertaClassifier(Task): dropout: float. The dropout probability value, applied to the pooled output, and after the first dense layer. - Example usage: + Examples: Raw string data. 
```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py index 5fb0ef8aa3..a26905e9e3 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py @@ -76,7 +76,7 @@ class XLMRobertaMaskedLMPreprocessor(XLMRobertaPreprocessor): sample_weight: Label weights. Should always be `None` as the layer generates label weights. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index 2be4e6b9af..c94f5f2421 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -71,7 +71,7 @@ class XLMRobertaPreprocessor(Preprocessor): y: Any label data. Will be passed through unaltered. sample_weight: Any label weight data. Will be passed through unaltered. - Example usage: + Examples: Directly calling the layer on data. ```python diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py index 3fb6f9dc24..576f30bca1 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -50,7 +50,7 @@ class XLMRobertaTokenizer(SentencePieceTokenizer): [SentencePiece repository](https://github.com/google/sentencepiece) for more details on the format. - Example usage: + Examples: ```python tokenizer = keras_nlp.models.XLMRobertaTokenizer.from_preset( "xlm_roberta_base_multi", diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 6d971979e7..45be1f74e7 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -65,7 +65,7 @@ class XLNetBackbone(Backbone): padding_mask: Mask to avoid performing attention on padding token indices of shape `[batch_size, sequence_length]`. - Example usage: + Examples: ```python import numpy as np from keras_nlp.models import XLNetBackbone diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 5d5a2031f9..297ec203de 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -38,7 +38,7 @@ class BeamSampler(Sampler): Call arguments: {{call_args}} - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 3f8bba2eda..36d10690d7 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -38,7 +38,7 @@ class ContrastiveSampler(Sampler): Call arguments: {{call_args}} - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index b3d989feee..ee8a6ecc2d 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -24,7 +24,7 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. 
- Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index fdef439455..1ff39c9f9b 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -32,7 +32,7 @@ class RandomSampler(Sampler): Call arguments: {{call_args}} - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 7c9618a408..3ecf16ac28 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -36,7 +36,7 @@ class Sampler: computes the next token based on a probability distribution over all possible vocab entries. - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index df83102ff3..513dd738c7 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -33,7 +33,7 @@ class TopKSampler(Sampler): Call arguments: {{call_args}} - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index 3585dfcb60..326f5797a6 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -41,7 +41,7 @@ class TopPSampler(Sampler): Call arguments: {{call_args}} - Example usage: + Examples: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 6261caa83d..2ac8832a76 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -156,7 +156,7 @@ class BytePairTokenizerCache(tf.Module): The cache key is string tensor or python strings, and the value is split tokens joined by whitespace. For example, "dragonfly" => "dragon fly" - Example usage: + Examples: ``` cache = BytePairTokenizerCache() cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"]) @@ -252,7 +252,7 @@ class BytePairTokenizer(tokenizer.Tokenizer): contain splittable characters such as punctuation. Special tokens must still be included in `vocabulary`. Defaults to `None`. - Example usage: + Examples: Tokenize >>> vocab = {"butter": 1, "fly": 2} @@ -665,7 +665,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Example usage: + Examples: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py index 95bae533b5..3aefc4a01d 100644 --- a/keras_nlp/tokenizers/byte_tokenizer.py +++ b/keras_nlp/tokenizers/byte_tokenizer.py @@ -76,7 +76,7 @@ class ByteTokenizer(tokenizer.Tokenizer): https://www.tensorflow.org/api_docs/python/tf/strings/unicode_transcode). (U+FFFD) is `65533`. Defaults to `65533`. - Example usage: + Examples: Basic usage. 
>>> tokenizer = keras_nlp.tokenizers.ByteTokenizer() diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index 93bded8cd4..64e169939c 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -68,7 +68,7 @@ class SentencePieceTokenizer(tokenizer.Tokenizer): References: - [Kudo and Richardson, 2018](https://arxiv.org/abs/1808.06226) - Example usage: + Examples: From bytes. ```python @@ -275,7 +275,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Example usage: + Examples: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py index 64b738480f..af9f2624b2 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer_trainer.py @@ -56,7 +56,7 @@ def compute_sentence_piece_proto( A `bytes` object with a serialized SentencePiece proto or `None` if proto_output_file is provided. - Example usage: + Examples: Basic Usage (from Dataset). >>> inputs = tf.data.Dataset.from_tensor_slices(["Drifting Along"]) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index e767f9749d..7da1e9d7b1 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -40,7 +40,7 @@ class Tokenizer(PreprocessingLayer): "vocab free" tokenizers, such as a whitespace splitter shown below, these methods do not apply and can be skipped. - Example usage: + Examples: ```python class WhitespaceSplitterTokenizer(keras_nlp.tokenizers.Tokenizer): diff --git a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py index 825fa944f4..5fe8f0144d 100644 --- a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py @@ -79,7 +79,7 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer): Effectively this will make the `vocabulary_size - 1` id the OOV value. - Example usage: + Examples: Basic Usage. >>> inputs = "Unicode Tokenizer" diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 0f5b265794..75f956899f 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -230,7 +230,7 @@ class WordPieceTokenizer(tokenizer.Tokenizer): - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/) - [Song et al., 2020](https://arxiv.org/abs/2012.15524) - Example usage: + Examples: Ragged outputs. >>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."] @@ -480,7 +480,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Example usage: + Examples: ```python # Load a preset tokenizer.
>>> inputs = tf.data.Dataset.from_tensor_slices(["bat sat pat mat rat"]) diff --git a/pip_build.py b/pip_build.py index 8f74a931ab..0c83cbb436 100644 --- a/pip_build.py +++ b/pip_build.py @@ -13,7 +13,7 @@ # limitations under the License. """Script to create (and optionally install) a `.whl` archive for KerasNLP. -Example usage: +Usage: 1. Create a `.whl` file in `dist/`: diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py index a2e7acc4b8..ed81e023d4 100644 --- a/tools/checkpoint_conversion/convert_gemma_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -19,7 +19,7 @@ pip install git+https://github.com/google-deepmind/gemma.git python pip_build.py --install -Example usage: +Usage: cd tools/checkpoint_conversion python convert_gemma_checkpoints.py --preset gemma_2b_en """ diff --git a/tools/count_preset_params.py b/tools/count_preset_params.py index d069b8fe3c..3edcc6d09d 100644 --- a/tools/count_preset_params.py +++ b/tools/count_preset_params.py @@ -14,7 +14,7 @@ """ Small utility script to count parameters in our preset checkpoints. -Example usage: +Usage: python tools/count_preset_params.py python tools/count_preset_params.py --model BertBackbone python tools/count_preset_params.py --preset bert_base_multi diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py index 0bf76aeeda..6f1fdf24d2 100644 --- a/tools/gemma/export_gemma_to_hf.py +++ b/tools/gemma/export_gemma_to_hf.py @@ -25,7 +25,7 @@ os.environ["KERAS_BACKEND"] = "torch" """ -Example usage: +Sample usage: For converting a keras model to HuggingFace format using a custom or fine-tuned checkpoint from Keras, make sure to pass the path for the Keras weights file diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py index d2b8e4aa86..08d4b3ac98 100644 --- a/tools/gemma/export_gemma_to_torch_xla.py +++ b/tools/gemma/export_gemma_to_torch_xla.py @@ -38,7 +38,7 @@ os.environ["KERAS_BACKEND"] = "torch" """ -Example usage: +Sample usage: For converting a Keras model to PyTorch format using a custom or fine-tuned checkpoint from Keras, make sure to pass the path for the Keras weights file diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py index e71d727b95..f212154c99 100644 --- a/tools/gemma/run_gemma_xla.py +++ b/tools/gemma/run_gemma_xla.py @@ -47,7 +47,7 @@ from gemma.tokenizer import Tokenizer """ -Example usage: +Sample usage: Run the verification script supplying your model size, converted checkpoint file, vocabulary file, and test prompt. 
From 898329fe33df1a92d3c01394af4c5685afff7891 Mon Sep 17 00:00:00 2001 From: mykolaskrynnyk <45297092+mykolaskrynnyk@users.noreply.github.com> Date: Wed, 20 Mar 2024 19:34:25 +0100 Subject: [PATCH 47/70] Docs/modelling layers (#1502) * Docs(layers): add a description for `tie_weights` argument * Refactor(layers): make `name` an explicit argument for Transformer layers * Refactor(layers): remove explicit usage of `name` in `__init__` calls * Docs(layers): remove references to `name` and consistently document `**kwargs` --- keras_nlp/layers/modeling/alibi_bias.py | 3 +++ keras_nlp/layers/modeling/f_net_encoder.py | 7 +++---- keras_nlp/layers/modeling/masked_lm_head.py | 2 ++ keras_nlp/layers/modeling/position_embedding.py | 2 ++ keras_nlp/layers/modeling/reversible_embedding.py | 2 ++ keras_nlp/layers/modeling/rotary_embedding.py | 2 ++ keras_nlp/layers/modeling/sine_position_encoding.py | 2 ++ keras_nlp/layers/modeling/token_and_position_embedding.py | 5 +++++ keras_nlp/layers/modeling/transformer_decoder.py | 4 ++-- keras_nlp/layers/modeling/transformer_encoder.py | 4 ++-- 10 files changed, 25 insertions(+), 8 deletions(-) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index fdc956ae15..cc72be3f8c 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -35,6 +35,9 @@ class AlibiBias(keras.layers.Layer): each head. The heads' slopes are a geometric sequence that starts at `2**(-alibi_bias_max/num_heads)` and uses that same value as its ratio. Defaults to 8. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. + Call arguments: attention_scores: The result of multiplying the query and the key of the multi-head attention layer of the transformer to add alibi bias to diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index a5370d960e..0732dee34c 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -47,8 +47,8 @@ class FNetEncoder(keras.layers.Layer): bias_initializer: "string" or `keras.initializers` initializer. The bias initializer for the dense layers. Defaults to `"zeros"`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Examples: @@ -79,10 +79,9 @@ def __init__( layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", bias_initializer="zeros", - name=None, **kwargs ): - super().__init__(name=name, **kwargs) + super().__init__(**kwargs) self.intermediate_dim = intermediate_dim self.dropout = dropout self.activation = keras.activations.get(activation) diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py index eacee7e8c0..d51f0eb50b 100644 --- a/keras_nlp/layers/modeling/masked_lm_head.py +++ b/keras_nlp/layers/modeling/masked_lm_head.py @@ -59,6 +59,8 @@ class MaskedLMHead(keras.layers.Layer): bias_initializer: string or `keras.initializers` initializer. The bias initializer for the dense and multiheaded attention layers. Defaults to `"zeros"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc.
Examples: diff --git a/keras_nlp/layers/modeling/position_embedding.py b/keras_nlp/layers/modeling/position_embedding.py index 6f9a44c29f..34597cb114 100644 --- a/keras_nlp/layers/modeling/position_embedding.py +++ b/keras_nlp/layers/modeling/position_embedding.py @@ -33,6 +33,8 @@ class PositionEmbedding(keras.layers.Layer): initializer: The initializer to use for the embedding weights. Defaults to `"glorot_uniform"`. seq_axis: The axis of the input tensor where we add the embeddings. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to compute an embedding for, with shape diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index d115217687..9266b6d28d 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -52,6 +52,8 @@ class ReversibleEmbedding(keras.layers.Embedding): reverse_dtype: The dtype for the reverse projection computation. For stability, it is usually best to use full precision even when working with half or mixed precision training. + **kwargs: other keyword arguments passed to `keras.layers.Embedding`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to the layer. diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index b494d559bd..1442548ea8 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -38,6 +38,8 @@ class RotaryEmbedding(keras.layers.Layer): scaling_factor: float. The scaling factor used to scale frequency range. sequence_axis: int. Sequence axis in the input tensor. feature_axis: int. Feature axis in the input tensor. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to apply the embedding to. This can have diff --git a/keras_nlp/layers/modeling/sine_position_encoding.py b/keras_nlp/layers/modeling/sine_position_encoding.py index 6e96a77e2c..5ab874c11d 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding.py +++ b/keras_nlp/layers/modeling/sine_position_encoding.py @@ -34,6 +34,8 @@ class SinePositionEncoding(keras.layers.Layer): max_wavelength: The maximum angular wavelength of the sine/cosine curves, as described in Attention is All You Need. Defaults to `10000`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to compute an embedding for, with shape diff --git a/keras_nlp/layers/modeling/token_and_position_embedding.py b/keras_nlp/layers/modeling/token_and_position_embedding.py index bb7107f96f..6266963bfd 100644 --- a/keras_nlp/layers/modeling/token_and_position_embedding.py +++ b/keras_nlp/layers/modeling/token_and_position_embedding.py @@ -33,6 +33,9 @@ class TokenAndPositionEmbedding(keras.layers.Layer): vocabulary_size: The size of the vocabulary. sequence_length: The maximum length of input sequence embedding_dim: The output dimension of the embedding layer + tie_weights: Boolean, whether or not the matrix for embedding and + the matrix for the `reverse` projection should share the same + weights. 
embeddings_initializer: The initializer to use for the Embedding Layers mask_zero: Boolean, whether or not the input value 0 is a special @@ -43,6 +46,8 @@ If `mask_zero` is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1). + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Examples: ```python diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index d06a1948f5..0de35da0b7 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -69,8 +69,8 @@ class TransformerDecoder(keras.layers.Layer): (similar to GPT-2). If set to False, outputs of attention layer and intermediate dense layer are normalized (similar to BERT). Defaults to `False`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Examples: ```python # Create a single transformer decoder layer. decoder = keras_nlp.layers.TransformerDecoder( diff --git a/keras_nlp/layers/modeling/transformer_encoder.py b/keras_nlp/layers/modeling/transformer_encoder.py index 32cdd35547..cd45b6aebf 100644 --- a/keras_nlp/layers/modeling/transformer_encoder.py +++ b/keras_nlp/layers/modeling/transformer_encoder.py @@ -58,8 +58,8 @@ class TransformerEncoder(keras.layers.Layer): (similar to GPT-2). If set to False, outputs of attention layer and intermediate dense layer are normalized (similar to BERT). Defaults to `False`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc.
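To make the `tie_weights` and `**kwargs` documentation above concrete, here is a minimal sketch using `keras_nlp.layers.ReversibleEmbedding`, the layer whose docstring this patch touches. The vocabulary and hidden sizes are illustrative assumptions, not values from this patch:

```python
import numpy as np
import keras_nlp

# With `tie_weights=True`, the matrix used to embed token ids is reused
# for the reverse projection from hidden states back to vocabulary logits.
embedding = keras_nlp.layers.ReversibleEmbedding(
    input_dim=100,  # vocabulary size (assumed)
    output_dim=16,  # hidden size (assumed)
    tie_weights=True,
    name="token_embedding",  # `name` is forwarded through **kwargs
)
token_ids = np.random.randint(0, 100, size=(4, 10))
hidden = embedding(token_ids)  # shape: (4, 10, 16)
logits = embedding(hidden, reverse=True)  # shape: (4, 10, 100)
```

Tying the two matrices halves the embedding parameter count, which is why the docstrings above call the option out explicitly.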
Examples: From 5944635b1d192956a9b8224fe57466ca626ee3c4 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Wed, 20 Mar 2024 11:35:43 -0700 Subject: [PATCH 48/70] Standardize docstring (#1516) --- STYLE_GUIDE.md | 2 +- keras_nlp/layers/modeling/alibi_bias.py | 2 +- keras_nlp/layers/modeling/f_net_encoder.py | 2 +- keras_nlp/layers/modeling/masked_lm_head.py | 2 +- keras_nlp/layers/modeling/position_embedding.py | 2 +- keras_nlp/layers/modeling/reversible_embedding.py | 2 +- keras_nlp/layers/modeling/sine_position_encoding.py | 2 +- keras_nlp/layers/modeling/token_and_position_embedding.py | 2 +- keras_nlp/layers/modeling/transformer_decoder.py | 2 +- keras_nlp/layers/modeling/transformer_encoder.py | 2 +- keras_nlp/models/albert/albert_backbone.py | 2 +- keras_nlp/models/albert/albert_masked_lm.py | 2 +- keras_nlp/models/bert/bert_masked_lm.py | 2 +- keras_nlp/models/bloom/bloom_backbone.py | 2 +- keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py | 2 +- keras_nlp/models/distil_bert/distil_bert_masked_lm.py | 2 +- keras_nlp/models/electra/electra_backbone.py | 2 +- keras_nlp/models/f_net/f_net_masked_lm.py | 2 +- keras_nlp/models/gemma/gemma_backbone.py | 4 ++-- keras_nlp/models/gemma/gemma_causal_lm.py | 2 +- keras_nlp/models/preprocessor.py | 2 +- keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py | 2 +- keras_nlp/models/xlnet/xlnet_backbone.py | 2 +- keras_nlp/samplers/sampler.py | 2 +- keras_nlp/tokenizers/byte_pair_tokenizer.py | 4 ++-- keras_nlp/tokenizers/sentence_piece_tokenizer.py | 2 +- keras_nlp/tokenizers/tokenizer.py | 2 +- keras_nlp/tokenizers/word_piece_tokenizer.py | 2 +- 28 files changed, 30 insertions(+), 30 deletions(-) diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 3db287de99..335f7ade97 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -116,7 +116,7 @@ class PositionEmbedding(keras.layers.Layer): Args: sequence_length: The maximum length of the dynamic sequence. - Examples: + Example: Direct call. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index cc72be3f8c..c5f8706f9d 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -43,7 +43,7 @@ class AlibiBias(keras.layers.Layer): multi-head attention layer of the transformer to add alibi bias to it. With shape `(batch_size, num_heads, query_length, key_length)`. - Examples: + Example: ```python query_length = 10 key_length = 10 diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index 0732dee34c..919e3beb08 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -50,7 +50,7 @@ class FNetEncoder(keras.layers.Layer): **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single FNet encoder layer. diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py index d51f0eb50b..9f9397cb70 100644 --- a/keras_nlp/layers/modeling/masked_lm_head.py +++ b/keras_nlp/layers/modeling/masked_lm_head.py @@ -62,7 +62,7 @@ class MaskedLMHead(keras.layers.Layer): **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `trainable`, `dtype` etc. 
- Examples: + Example: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/position_embedding.py b/keras_nlp/layers/modeling/position_embedding.py index 34597cb114..9f6b314b96 100644 --- a/keras_nlp/layers/modeling/position_embedding.py +++ b/keras_nlp/layers/modeling/position_embedding.py @@ -45,7 +45,7 @@ class PositionEmbedding(keras.layers.Layer): compute the position embedding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example: Called directly on input. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index 9266b6d28d..1fa5f5f903 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -61,7 +61,7 @@ class ReversibleEmbedding(keras.layers.Embedding): from `output_dim` to `input_dim`, instead of a normal embedding call. Default to `False`. - Examples: + Example: ```python batch_size = 16 vocab_size = 100 diff --git a/keras_nlp/layers/modeling/sine_position_encoding.py b/keras_nlp/layers/modeling/sine_position_encoding.py index 5ab874c11d..b1cd7fbf42 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding.py +++ b/keras_nlp/layers/modeling/sine_position_encoding.py @@ -44,7 +44,7 @@ class SinePositionEncoding(keras.layers.Layer): compute the encoding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example: ```python # create a simple embedding layer with sinusoidal positional encoding seq_len = 100 diff --git a/keras_nlp/layers/modeling/token_and_position_embedding.py b/keras_nlp/layers/modeling/token_and_position_embedding.py index 6266963bfd..8261cc7f34 100644 --- a/keras_nlp/layers/modeling/token_and_position_embedding.py +++ b/keras_nlp/layers/modeling/token_and_position_embedding.py @@ -49,7 +49,7 @@ class TokenAndPositionEmbedding(keras.layers.Layer): **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python inputs = np.ones(shape=(1, 50), dtype="int32") embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding( diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 0de35da0b7..b8f797f2e2 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -72,7 +72,7 @@ class TransformerDecoder(keras.layers.Layer): **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single transformer decoder layer. decoder = keras_nlp.layers.TransformerDecoder( diff --git a/keras_nlp/layers/modeling/transformer_encoder.py b/keras_nlp/layers/modeling/transformer_encoder.py index cd45b6aebf..20cec4ecf1 100644 --- a/keras_nlp/layers/modeling/transformer_encoder.py +++ b/keras_nlp/layers/modeling/transformer_encoder.py @@ -61,7 +61,7 @@ class TransformerEncoder(keras.layers.Layer): **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single transformer encoder layer. 
diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 09053ff893..0cc1d4d021 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -77,7 +77,7 @@ class AlbertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index 1958713b9f..e421ef524c 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -52,7 +52,7 @@ class AlbertMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index 17b9669619..b915a99481 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -51,7 +51,7 @@ class BertMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 9b7c65a399..eb686668d8 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -58,7 +58,7 @@ class BloomBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index d050dde6c0..a794c34374 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -55,7 +55,7 @@ class DebertaV3MaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index fcf54e014d..d99234a04f 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -55,7 +55,7 @@ class DistilBertMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index f4f2a23b69..a116caa20d 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -63,7 +63,7 @@ class ElectraBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. 
- Examples: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index c715a70843..4a0ec5e254 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -51,7 +51,7 @@ class FNetMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py index 06f5b0f601..8e4bac126a 100644 --- a/keras_nlp/models/gemma/gemma_backbone.py +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -60,7 +60,7 @@ class GemmaBackbone(Backbone): computations, such as softmax and layer normalization will always be done a float32 precision regardless of dtype. - Example usage: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), @@ -205,7 +205,7 @@ def get_layout_map( backbone weights, so that you can use it to distribute weights across the accelerators. - Sample usage: + Example: ``` # Feel free to change the mesh shape to balance data and model parallel mesh = keras.distribution.DeviceMesh( diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 45c7c6abe0..58e2e302d5 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -359,7 +359,7 @@ def score( [batch_size, num_tokens, vocab_size] in "logits" mode, or [batch_size, num_tokens] in "loss" mode. - Examples: + Example: Compute gradients between embeddings and loss scores with TensorFlow: ```python diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 16a65e57c2..031a884e1b 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -75,7 +75,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example: ```python # Load a preprocessor layer from a preset. preprocessor = keras_nlp.models.{{preprocessor_name}}.from_preset( diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index e231f3dc7a..e6b5a45bb5 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -53,7 +53,7 @@ class XLMRobertaMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string inputs and pretrained backbone. ```python diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 45be1f74e7..03ea607d9e 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -65,7 +65,7 @@ class XLNetBackbone(Backbone): padding_mask: Mask to avoid performing attention on padding token indices of shape `[batch_size, sequence_length]`. - Examples: + Example: ```python import numpy as np from keras_nlp.models import XLNetBackbone diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 3ecf16ac28..a6b64b5324 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -36,7 +36,7 @@ class Sampler: computes the next token based on a probability distribution over all possible vocab entries. 
- Examples: + Example: ```python causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 2ac8832a76..a8dbc51361 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -156,7 +156,7 @@ class BytePairTokenizerCache(tf.Module): The cache key is string tensor or python strings, and the value is split tokens joined by whitespace. For example, "dragonfly" => "dragon fly" - Examples: + Example: ``` cache = BytePairTokenizerCache() cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"]) @@ -665,7 +665,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index 64e169939c..20a73d6af5 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -275,7 +275,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example: ```python # Load a preset tokenizer. tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 7da1e9d7b1..4c26e45241 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -40,7 +40,7 @@ class Tokenizer(PreprocessingLayer): "vocab free" tokenizers, such as a whitespace splitter show below, these methods do not apply and can be skipped. - Examples: + Example: ```python class WhitespaceSplitterTokenizer(keras_nlp.tokenizers.Tokenizer): diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 75f956899f..c203d50c78 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -480,7 +480,7 @@ def from_preset( Args: preset: string. Must be one of "{{preset_names}}". - Examples: + Example: ```python # Load a preset tokenizer. 
tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") From 3ddfd883341a026d974a630046c5c43802162af9 Mon Sep 17 00:00:00 2001 From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com> Date: Wed, 20 Mar 2024 21:57:15 +0200 Subject: [PATCH 49/70] Support tokenization of special tokens for word_piece_tokenizer (#1397) * Support tokenization of special tokens for word_piece_tokenizer * Add the feature to models tokenizers * Format the code * Fix Format * Small fixes * Add tests for bert * Add tests for distilbert * Small fix for bert test * Add tests for electra * Fix code format * Rename unsplittable to special * Edit special_tokens Arg * Format the code * Move special tokens checking into base class * Add special_tokens_in_strings Arg * Shorten comments * Shorten comments * Shorten the logic of splitting and add comments * Code format --- keras_nlp/models/bert/bert_tokenizer.py | 25 ++++--- keras_nlp/models/bert/bert_tokenizer_test.py | 10 +++ .../distil_bert/distil_bert_tokenizer.py | 25 ++++--- .../distil_bert/distil_bert_tokenizer_test.py | 10 +++ keras_nlp/models/electra/electra_tokenizer.py | 38 +++++++--- .../models/electra/electra_tokenizer_test.py | 10 +++ keras_nlp/tokenizers/word_piece_tokenizer.py | 70 ++++++++++++++++++- .../tokenizers/word_piece_tokenizer_test.py | 62 ++++++++++++++-- 8 files changed, 216 insertions(+), 34 deletions(-) diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py index 1b634fe9b3..4de433e43a 100644 --- a/keras_nlp/models/bert/bert_tokenizer.py +++ b/keras_nlp/models/bert/bert_tokenizer.py @@ -49,6 +49,9 @@ class BertTokenizer(WordPieceTokenizer): plain text file containing a single word piece token per line. lowercase: If `True`, the input text will be first lowered before tokenization. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: ```python @@ -76,6 +79,7 @@ def __init__( self, vocabulary=None, lowercase=False, + special_tokens_in_strings=False, **kwargs, ): self.cls_token = "[CLS]" @@ -85,6 +89,13 @@ def __init__( super().__init__( vocabulary=vocabulary, lowercase=lowercase, + special_tokens=[ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -92,15 +103,6 @@ def set_vocabulary(self, vocabulary): super().set_vocabulary(vocabulary) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.cls_token_id = self.token_to_id(self.cls_token) self.sep_token_id = self.token_to_id(self.sep_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -114,3 +116,8 @@ def set_vocabulary(self, vocabulary): @classproperty def presets(cls): return copy.deepcopy({**backbone_presets, **classifier_presets}) + + def get_config(self): + config = super().get_config() + del config["special_tokens"] # Not configurable; set in __init__.
+ return config diff --git a/keras_nlp/models/bert/bert_tokenizer_test.py b/keras_nlp/models/bert/bert_tokenizer_test.py index e53419dab4..78b4417e13 100644 --- a/keras_nlp/models/bert/bert_tokenizer_test.py +++ b/keras_nlp/models/bert/bert_tokenizer_test.py @@ -39,6 +39,16 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = BertTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): BertTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py index 4a18398a1e..29eb92e3ba 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py @@ -46,6 +46,9 @@ class DistilBertTokenizer(WordPieceTokenizer): plain text file containing a single word piece token per line. lowercase: If `True`, the input text will be first lowered before tokenization. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: @@ -74,6 +77,7 @@ def __init__( self, vocabulary, lowercase=False, + special_tokens_in_strings=False, **kwargs, ): self.cls_token = "[CLS]" @@ -83,6 +87,13 @@ def __init__( super().__init__( vocabulary=vocabulary, lowercase=lowercase, + special_tokens=[ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -90,15 +101,6 @@ def set_vocabulary(self, vocabulary): super().set_vocabulary(vocabulary) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.cls_token_id = self.token_to_id(self.cls_token) self.sep_token_id = self.token_to_id(self.sep_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -112,3 +114,8 @@ def set_vocabulary(self, vocabulary): @classproperty def presets(cls): return copy.deepcopy(backbone_presets) + + def get_config(self): + config = super().get_config() + del config["special_tokens"] # Not configurable; set in __init__. 
+ return config diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py index e4bfba41d3..42bfde39e5 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py @@ -41,6 +41,16 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = DistilBertTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): DistilBertTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index acd665c2a3..583b756165 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -36,6 +36,9 @@ class ElectraTokenizer(WordPieceTokenizer): plain text file containing a single word piece token per line. lowercase: If `True`, the input text will be first lowered before tokenization. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: ```python @@ -57,26 +60,34 @@ class ElectraTokenizer(WordPieceTokenizer): ``` """ - def __init__(self, vocabulary, lowercase=False, **kwargs): + def __init__( + self, + vocabulary, + lowercase=False, + special_tokens_in_strings=False, + **kwargs, + ): self.cls_token = "[CLS]" self.sep_token = "[SEP]" self.pad_token = "[PAD]" self.mask_token = "[MASK]" - super().__init__(vocabulary=vocabulary, lowercase=lowercase, **kwargs) + super().__init__( + vocabulary=vocabulary, + lowercase=lowercase, + special_tokens=[ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + ], + special_tokens_in_strings=special_tokens_in_strings, + **kwargs, + ) def set_vocabulary(self, vocabulary): super().set_vocabulary(vocabulary) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.cls_token_id = self.token_to_id(self.cls_token) self.sep_token_id = self.token_to_id(self.sep_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -86,3 +97,8 @@ def set_vocabulary(self, vocabulary): self.sep_token_id = None self.pad_token_id = None self.mask_token_id = None + + def get_config(self): + config = super().get_config() + del config["special_tokens"] # Not configurable; set in __init__. 
+ return config diff --git a/keras_nlp/models/electra/electra_tokenizer_test.py b/keras_nlp/models/electra/electra_tokenizer_test.py index 2e06fb900c..29c40a2f29 100644 --- a/keras_nlp/models/electra/electra_tokenizer_test.py +++ b/keras_nlp/models/electra/electra_tokenizer_test.py @@ -37,6 +37,16 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = ElectraTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): ElectraTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index c203d50c78..fc6f54e19d 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import re from typing import Iterable from typing import List @@ -101,12 +102,19 @@ ) +def get_special_tokens_pattern(special_tokens): + if special_tokens is None or len(special_tokens) == 0: + return None + return r"|".join([re.escape(token) for token in special_tokens]) + + def pretokenize( text, lowercase=False, strip_accents=True, split=True, split_on_cjk=True, + special_tokens_pattern=None, ): """Helper function that takes in a dataset element and pretokenizes it. @@ -124,7 +132,14 @@ split_on_cjk: bool. If `True`, input will be split on CJK characters, i.e., Chinese, Japanese, Korean and Vietnamese characters (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)). - Note that this is applicable only when `split` is `True`. Defaults to `True`. + Note that this is applicable only when `split` is `True`. Defaults + to `True`. + special_tokens_pattern: str. A regex pattern that contains the + special tokens that will never be split during the word-level + splitting applied before the word-piece encoding. This can be used + to ensure special tokens map to unique indices in the vocabulary, + even if these special tokens contain splittable characters such as + punctuation. Returns: A tensor containing the pre-processed and pre-tokenized `text`. @@ -154,6 +169,23 @@ else: split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX keep_split_pattern = PUNCTUATION_REGEX + if special_tokens_pattern is not None: + # The idea here is to pass the special tokens regex to the split + # function as the delimiter regex pattern, so the input will be split + # by them, while the function treats each one of them as a single + # entity that shouldn't be split even if it contains other + # delimiter regex patterns. Then pass the special tokens + # regex also as the keep delimiter regex pattern, so they will + # not be removed. + split_pattern = r"|".join( + [ + special_tokens_pattern, + split_pattern, + ] + ) + keep_split_pattern = r"|".join( + [special_tokens_pattern, keep_split_pattern] + ) text = tf_text.regex_split( text, delim_regex_pattern=split_pattern, @@ -225,6 +257,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer): oov_token: str. The string value to substitute for an unknown token. It must be included in the vocab. Defaults to `"[UNK]"`. + special_tokens: list. A list of special tokens.
When + `special_tokens_in_strings` is set to `True`, the tokenizer will map + every special token in the input strings to its id, even if these + special tokens contain characters that should be split before + tokenization, such as punctuation. `special_tokens` must be included + in `vocabulary`. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. References: - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/) @@ -303,6 +344,8 @@ def __init__( split_on_cjk: bool = True, suffix_indicator: str = "##", oov_token: str = "[UNK]", + special_tokens: List[str] = None, + special_tokens_in_strings: bool = False, dtype="int32", **kwargs, ) -> None: @@ -325,6 +368,19 @@ def __init__( self.split_on_cjk = split_on_cjk self.suffix_indicator = suffix_indicator self.oov_token = oov_token + self.special_tokens = special_tokens + self._special_tokens_pattern = None + if self.split and special_tokens_in_strings: + # The idea here is to pass the special tokens regex to the + # split function as the delimiter regex pattern, so the input + # will be split by them, while the function treats each one of + # them as a single entity that shouldn't be split even if it + # contains other delimiter regex patterns. Then pass the + # special tokens regex also as the keep delimiter regex + # pattern, so they will not be removed. + self._special_tokens_pattern = get_special_tokens_pattern( + self.special_tokens + ) self.set_vocabulary(vocabulary) def save_assets(self, dir_path): @@ -365,6 +421,16 @@ def set_vocabulary(self, vocabulary): "the `oov_token` argument when creating the tokenizer." ) + # Check for special tokens in the vocabulary. + if self.special_tokens is not None: + for token in self.special_tokens: + if token not in self.vocabulary: + raise ValueError( + f"Cannot find token `'{token}'` in the provided " + f"`vocabulary`. Please provide `'{token}'` in your " + "`vocabulary` or use a pretrained `vocabulary` name." + ) + self._fast_word_piece = tf_text.FastWordpieceTokenizer( vocab=self.vocabulary, token_out_type=self.compute_dtype, @@ -413,6 +479,7 @@ def get_config(self): "split": self.split, "suffix_indicator": self.suffix_indicator, "oov_token": self.oov_token, + "special_tokens": self.special_tokens, } ) return config @@ -436,6 +503,7 @@ def tokenize(self, inputs): self.strip_accents, self.split, self.split_on_cjk, + self._special_tokens_pattern, ) # Apply WordPiece and coerce shape for outputs.
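The end-to-end effect of the special tokens pattern wired through `pretokenize` above can be seen with a small sketch. The vocabulary and expected ids are illustrative, mirroring the tests that follow:

```python
import keras_nlp

vocab = ["[UNK]", "[MASK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    special_tokens=["[UNK]", "[MASK]"],
    special_tokens_in_strings=True,
)
# "[MASK]" maps to a single id (1), even though "[" and "]" are
# punctuation that the pretokenizer would normally split off.
tokenizer(["the [MASK] fox."])  # [[2, 1, 7, 8]]
```

Without `special_tokens_in_strings=True`, the same input would be split at the brackets and the pieces would never match the `[MASK]` vocabulary entry.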
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_test.py b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
index ead098c36c..5d80a32eee 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer_test.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -78,21 +78,38 @@ def test_error_id_out_of_vocabulary(self):
         with self.assertRaises(ValueError):
             tokenizer.id_to_token(-1)
 
-    def test_special_tokens(self):
-        input_data = ["quick brown whale"]
-        vocab_data = ["@UNK@", "qu", "@@ick", "br", "@@own", "fox"]
+    def test_special_tokens_string_dtype(self):
+        input_data = ["quick brown whale @MASK@"]
+        vocab_data = ["@UNK@", "qu", "@@ick", "br", "@@own", "fox", "@MASK@"]
+        special_tokens = ["@UNK@", "@MASK@"]
         tokenizer = WordPieceTokenizer(
             vocabulary=vocab_data,
             oov_token="@UNK@",
             suffix_indicator="@@",
             dtype="string",
+            special_tokens=special_tokens,
+            special_tokens_in_strings=True,
         )
         call_output = tokenizer(input_data)
         self.assertAllEqual(
             call_output,
-            [["qu", "@@ick", "br", "@@own", "@UNK@"]],
+            [["qu", "@@ick", "br", "@@own", "@UNK@", "@MASK@"]],
         )
 
+    def test_special_tokens_int_dtype(self):
+        input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."]
+        special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
+        vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
+        vocab_data = [*special_tokens, *vocab_data]
+
+        tokenizer = WordPieceTokenizer(
+            vocabulary=vocab_data,
+            special_tokens=special_tokens,
+            special_tokens_in_strings=True,
+        )
+        output = tokenizer(input_data)
+        self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])
+
     def test_cjk_tokens(self):
         input_data = ["ah半推zz"]
         vocab_data = ["[UNK]", "推", "敐", "乐", "半", "偷", "匕", "ah", "zz"]
@@ -217,3 +234,40 @@ def test_no_oov_token_in_vocabulary(self):
 
         with self.assertRaises(ValueError):
             WordPieceTokenizer(vocabulary=vocab_data, oov_token=None)
+
+    def test_no_splitting_with_special_tokens(self):
+        # When `split` is `False`, special tokens tokenization is not applied.
+        input_data = [
+            "[MASK] t o k e n",
+            "m i s s i n g",
+            "[MASK]",
+            "t o k e n",
+        ]
+        vocab_data = ["[UNK]", "[MASK]", "t o k e n"]
+        tokenizer = WordPieceTokenizer(
+            vocabulary=vocab_data, split=False, special_tokens=["[MASK]"]
+        )
+        output = tokenizer(input_data)
+        self.assertAllEqual(output, [0, 0, 1, 2])
+
+    def test_config_with_special_tokens(self):
+        input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."]
+        special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
+        vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
+        vocab_data = [*special_tokens, *vocab_data]
+        original_tokenizer = WordPieceTokenizer(
+            vocabulary=vocab_data,
+            lowercase=False,
+            oov_token="[UNK]",
+            suffix_indicator="##",
+            dtype="string",
+            special_tokens=special_tokens,
+        )
+        cloned_tokenizer = WordPieceTokenizer.from_config(
+            original_tokenizer.get_config()
+        )
+        cloned_tokenizer.set_vocabulary(original_tokenizer.get_vocabulary())
+        self.assertAllEqual(
+            original_tokenizer(input_data),
+            cloned_tokenizer(input_data),
+        )

From c3b2c09816d45eebae0c916859e1fc1ba56550a6 Mon Sep 17 00:00:00 2001
From: Samaneh Saadat
Date: Mon, 25 Mar 2024 10:11:44 -0700
Subject: [PATCH 50/70] Upload Model to Kaggle (#1512)

* Initial Kaggle upload.

* Address review comments.

* Add upload validations.

* Address review comments.

* Fix init.

* Address review comments.

* Improve error handling.

* Address review comments.
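
A minimal sketch of the resulting workflow (the Kaggle handle below is a
placeholder; `save_to_preset()` and `upload_preset()` are the APIs added in
this change):

```python
import keras_nlp

# Save a model and its tokenizer into a local preset directory.
backbone = keras_nlp.models.BertBackbone.from_preset("bert_tiny_en_uncased")
tokenizer = keras_nlp.models.BertTokenizer.from_preset("bert_tiny_en_uncased")
backbone.save_to_preset("./bert_preset")
tokenizer.save_to_preset("./bert_preset")

# Validate the preset directory and upload it to Kaggle Hub.
keras_nlp.upload_preset(
    "kaggle://my_username/bert/keras/bert_tiny_en_uncased",
    "./bert_preset",
)
```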
--- keras_nlp/__init__.py | 1 + keras_nlp/models/backbone.py | 9 +++ keras_nlp/tokenizers/tokenizer.py | 10 +++ keras_nlp/utils/preset_utils.py | 114 +++++++++++++++++++++++++++ keras_nlp/utils/preset_utils_test.py | 12 +++ 5 files changed, 146 insertions(+) diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 30f8a53b16..407a4b7a71 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -26,5 +26,6 @@ from keras_nlp import samplers from keras_nlp import tokenizers from keras_nlp import utils +from keras_nlp.utils.preset_utils import upload_preset from keras_nlp.version_utils import __version__ from keras_nlp.version_utils import version diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index bfdc8207ad..08b9f86e96 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -17,6 +17,7 @@ from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.python_utils import format_docstring @@ -141,6 +142,14 @@ def from_preset( config_overrides=kwargs, ) + def save_to_preset(self, preset): + """Save backbone to a preset directory. + + Args: + preset: The path to the local model preset directory. + """ + save_to_preset(self, preset) + def __init_subclass__(cls, **kwargs): # Use __init_subclass__ to setup a correct docstring for from_preset. super().__init_subclass__(**kwargs) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 4c26e45241..834b99e5b1 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -18,6 +18,8 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) +from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE +from keras_nlp.utils.preset_utils import save_to_preset @keras_nlp_export("keras_nlp.tokenizers.Tokenizer") @@ -121,5 +123,13 @@ def token_to_id(self, token: str) -> int: f"{self.__class__.__name__}." ) + def save_to_preset(self, preset): + """Save tokenizer to a preset directory. + + Args: + preset: The path to the local model preset directory. + """ + save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE) + def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 01c11a3db1..dcee9bc66f 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -16,6 +16,9 @@ import json import os +from absl import logging + +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config as backend_config from keras_nlp.backend import keras @@ -27,6 +30,8 @@ KAGGLE_PREFIX = "kaggle://" GS_PREFIX = "gs://" TOKENIZER_ASSET_DIR = "assets/tokenizer" +CONFIG_FILE = "config.json" +TOKENIZER_CONFIG_FILE = "tokenizer.json" def get_file(preset, path): @@ -155,6 +160,115 @@ def save_to_preset( metadata_file.write(json.dumps(metadata, indent=4)) +def _validate_tokenizer(preset, allow_incomplete=False): + config_path = get_file(preset, TOKENIZER_CONFIG_FILE) + if not os.path.exists(config_path): + if allow_incomplete: + logging.warning( + f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`." 
+            )
+            return
+        else:
+            raise FileNotFoundError(
+                f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. "
+                "To upload the model without a tokenizer, "
+                "set `allow_incomplete=True`."
+            )
+    try:
+        with open(config_path) as config_file:
+            config = json.load(config_file)
+    except Exception as e:
+        raise ValueError(
+            f"Tokenizer config file `{config_path}` is an invalid json file. "
+            f"Error message: {e}"
+        )
+    layer = keras.saving.deserialize_keras_object(config)
+
+    if not config["assets"]:
+        raise ValueError(
+            f"Tokenizer config file `{config_path}` is missing `assets`."
+        )
+
+    for asset in config["assets"]:
+        asset_path = os.path.join(preset, asset)
+        if not os.path.exists(asset_path):
+            raise FileNotFoundError(
+                f"Asset `{asset}` doesn't exist in the preset directory `{preset}`."
+            )
+    config_dir = os.path.dirname(config_path)
+    asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR)
+
+    tokenizer = get_tokenizer(layer)
+    if not tokenizer:
+        raise ValueError(f"Model or layer `{layer}` is missing a tokenizer.")
+    tokenizer.load_assets(asset_dir)
+
+
+def _validate_backbone(preset):
+    config_path = os.path.join(preset, CONFIG_FILE)
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(
+            f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`."
+        )
+    try:
+        with open(config_path) as config_file:
+            config = json.load(config_file)
+    except Exception as e:
+        raise ValueError(
+            f"Config file `{config_path}` is an invalid json file. "
+            f"Error message: {e}"
+        )
+
+    if config["weights"]:
+        weights_path = os.path.join(preset, config["weights"])
+        if not os.path.exists(weights_path):
+            raise FileNotFoundError(
+                f"The weights file doesn't exist in preset directory `{preset}`."
+            )
+    else:
+        raise ValueError(
+            f"No weights listed in `{CONFIG_FILE}`. Make sure to use "
+            "`save_to_preset()` which adds additional data to a serialized "
+            "Keras object."
+        )
+
+
+@keras_nlp_export("keras_nlp.upload_preset")
+def upload_preset(
+    uri,
+    preset,
+    allow_incomplete=False,
+):
+    """Upload a preset directory to a model hub.
+
+    Args:
+        uri: The URI identifying the model to upload to.
+             URIs with format
+             `kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`
+             will be uploaded to Kaggle Hub.
+        preset: The path to the local model preset directory.
+        allow_incomplete: If `True`, allows the upload of presets without
+            a tokenizer configuration. Otherwise, a tokenizer
+            is required.
+    """
+
+    # Check if the preset directory exists.
+    if not os.path.exists(preset):
+        raise FileNotFoundError(f"The preset directory {preset} doesn't exist.")
+
+    _validate_backbone(preset)
+    _validate_tokenizer(preset, allow_incomplete)
+
+    if uri.startswith(KAGGLE_PREFIX):
+        kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
+        kagglehub.model_upload(kaggle_handle, preset)
+    else:
+        raise ValueError(
+            f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
+            "`kaggle://<KAGGLE_USERNAME>/<MODEL>/<FRAMEWORK>/<VARIATION>`."
+        )
+
+
 def load_from_preset(
     preset,
     load_weights=True,
diff --git a/keras_nlp/utils/preset_utils_test.py b/keras_nlp/utils/preset_utils_test.py
index 44dc39f477..289e13b6ab 100644
--- a/keras_nlp/utils/preset_utils_test.py
+++ b/keras_nlp/utils/preset_utils_test.py
@@ -18,6 +18,7 @@
 import pytest
 from absl.testing import parameterized
 
+from keras_nlp import upload_preset
 from keras_nlp.models.albert.albert_classifier import AlbertClassifier
 from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.bert.bert_classifier import BertClassifier
@@ -105,3 +106,14 @@ def test_preset_errors(self):
 
         with self.assertRaisesRegex(ValueError, "Unknown preset identifier"):
             AlbertClassifier.from_preset("snaggle://bort/bort/bort")
+
+    def test_upload_empty_preset(self):
+        temp_dir = self.get_temp_dir()
+        empty_preset = os.path.join(temp_dir, "empty")
+        os.mkdir(empty_preset)
+        uri = "kaggle://test/test/test"
+
+        with self.assertRaises(FileNotFoundError):
+            upload_preset(uri, empty_preset)
+
+    # TODO: add more tests to cover various invalid scenarios such as
+    # invalid json, missing files, etc.

From eb4ef205cc98a441bb0edd43f4ce58b205776a3c Mon Sep 17 00:00:00 2001
From: Ryan Mullins
Date: Mon, 25 Mar 2024 13:24:29 -0400
Subject: [PATCH 51/70] Add scoring mode to MistralCausalLM (#1521)

* Add scoring mode to MistralCausalLM

* Fixing names in Docstring

* Fix padding mask arg name

* Fix embedded shape in test

* Remove errant underscore in Docstring

---
 keras_nlp/models/mistral/mistral_causal_lm.py | 126 ++++++++++++++++++
 .../models/mistral/mistral_causal_lm_test.py  |  85 ++++++++++++
 2 files changed, 211 insertions(+)

diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py
index e9bd4e5616..800df1c8cb 100644
--- a/keras_nlp/models/mistral/mistral_causal_lm.py
+++ b/keras_nlp/models/mistral/mistral_causal_lm.py
@@ -215,6 +215,132 @@ def next(prompt, cache, index):
             "padding_mask": padding_mask,
         }
 
+    def score(
+        self,
+        token_ids,
+        padding_mask=None,
+        scoring_mode="logits",
+        layer_intercept_fn=None,
+        target_ids=None,
+    ):
+        """Score a generation represented by the provided token ids.
+
+        Args:
+            token_ids: A [batch_size, num_tokens] tensor containing tokens
+                to score. Typically, this tensor captures the output from a call
+                to `MistralCausalLM.generate()`, i.e., tokens for both the input
+                text and the model-generated text.
+            padding_mask: A [batch_size, num_tokens] tensor indicating the
+                tokens that should be preserved during generation. This is an
+                artifact required by the MistralBackbone and isn't influential
+                on the computation of this function. If omitted, this function
+                uses `keras.ops.ones()` to create a tensor of the appropriate
+                shape.
+            scoring_mode: The type of scores to return, either "logits" or
+                "loss", both per input token.
+            layer_intercept_fn: An optional function for augmenting activations
+                with additional computation, for example, as part of
+                interpretability research. This function will be passed the
+                activations as its first parameter and a numeric index
+                associated with that backbone layer. This index is _not_ an
+                index into `self.backbone.layers`. The index -1 accompanies the
+                embeddings returned by calling `self.backbone.token_embedding()`
+                on `token_ids` in the forward direction. All subsequent indexes
+                will be 0-based indices for the activations returned by each of
+                the Transformers layers in the backbone.
This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Examples: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + mistral_lm = keras_nlp.models.MistralCausalLM.from_preset( + "mistral_7b_en" + ) + generations = mistral_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = mistral_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = mistral_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss + @classproperty def presets(cls): return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_causal_lm_test.py b/keras_nlp/models/mistral/mistral_causal_lm_test.py index 3f9d7fab36..13f0dad907 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_test.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_test.py @@ -128,3 +128,88 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. 
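+        # With the toy test config, prompts are padded to a sequence length
+        # of 8 and the vocabulary has 10 entries, which is where the expected
+        # (2, 8, 10) logits shape below comes from.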
+ prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MistralCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 10) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MistralCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MistralCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 8) + expected_score_shape = (2, 8, 10) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Setup a custom intercept function that extracts the embeddings to a + # a variable from the embeddings layer and otherwise asserts on shapes. + embedded_prompts = None + + def layer_intercept_fn_for_testing(x, i): + if i == -1: + nonlocal embedded_prompts + embedded_prompts = x + else: + nonlocal expected_embedded_shape + self.assertEqual(ops.shape(x), expected_embedded_shape) + return x + + # Get the scores. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + layer_intercept_fn=layer_intercept_fn_for_testing, + ) + + # Assert shapes for info exfiltrated into the parent context. 
+        self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
+        self.assertEqual(ops.shape(scores), expected_score_shape)

From f1714e1cc6ab4a08b630952644351dd069c2a2c3 Mon Sep 17 00:00:00 2001
From: Tirth Patel
Date: Mon, 25 Mar 2024 16:02:36 -0700
Subject: [PATCH 52/70] Add Mistral Instruct V0.2 preset (#1520)

---
 keras_nlp/models/mistral/mistral_presets.py   | 10 +++++++++
 .../mistral/mistral_transformer_decoder.py    | 21 +++++++++++--------
 .../convert_mistral_checkpoints.py            |  1 +
 3 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/keras_nlp/models/mistral/mistral_presets.py b/keras_nlp/models/mistral/mistral_presets.py
index 7fb4b4e0a6..fdee396300 100644
--- a/keras_nlp/models/mistral/mistral_presets.py
+++ b/keras_nlp/models/mistral/mistral_presets.py
@@ -35,4 +35,14 @@
         },
         "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
     },
+    "mistral_0.2_instruct_7b_en": {
+        "metadata": {
+            "description": "Mistral 7B instruct Version 0.2 model",
+            "params": 7241732096,
+            "official_name": "Mistral",
+            "path": "mistral",
+            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1",
+    },
 }
diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py
index 7c90ab91b9..36b7f5944d 100644
--- a/keras_nlp/models/mistral/mistral_transformer_decoder.py
+++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -207,17 +207,20 @@ def _compute_self_attention_mask(
             else self_attention_cache_update_index
         )
 
-        # Mistral uses a banded attention mask
-        causal_mask_lower = compute_causal_mask(
+        # The lower triangular attention mask
+        causal_mask = compute_causal_mask(
             batch_size, input_length, output_length, cache_update_index
         )
-        # Below is a workaround for `ops.triu` for Keras 2.
-        # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
-        # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window)
-        i = ops.arange(output_length)[:, None] + cache_update_index
-        j = ops.arange(input_length)[None, :]
-        causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
-        causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper)
+
+        # Mistral uses a banded attention mask if `sliding_window` is not `None`.
+        if self.sliding_window is not None:
+            # Below is a workaround for `ops.triu` for Keras 2.
+            # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
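+            # The commented line below is the intent; the workaround keeps
+            # key positions `j` with i < j + sliding_window (queries look back
+            # at most `sliding_window` tokens), and the elementwise minimum
+            # with the lower triangular mask (j <= i) yields the banded mask.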
+ # causal_mask = ops.triu(causal_mask, k=-self.sliding_window) + i = ops.arange(output_length)[:, None] + cache_update_index + j = ops.arange(input_length)[None, :] + causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32") + causal_mask = ops.minimum(causal_mask, causal_mask_upper) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py index 7e13b9dd7a..ae3b1f5b83 100644 --- a/tools/checkpoint_conversion/convert_mistral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -33,6 +33,7 @@ PRESET_MAP = { "mistral_7b_en": "mistralai/Mistral-7B-v0.1", "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1", + "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2", } FLAGS = flags.FLAGS From 6703d76772cd19ea89c4290f7d03e98e910fae17 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Mon, 25 Mar 2024 17:08:20 -0700 Subject: [PATCH 53/70] Add Tests for Kaggle Upload Validation (#1524) * Add Kaggle upload validation tests. * Use bert_tiny as test model. --- keras_nlp/utils/preset_utils.py | 2 +- keras_nlp/utils/preset_utils_test.py | 63 +++++++++++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index dcee9bc66f..e2a4452714 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -223,7 +223,7 @@ def _validate_backbone(preset): weights_path = os.path.join(preset, config["weights"]) if not os.path.exists(weights_path): raise FileNotFoundError( - f"The weights file doesn't exist in preset directory `{preset}`." + f"The weights file is missing from the preset directory `{preset}`." ) else: raise ValueError( diff --git a/keras_nlp/utils/preset_utils_test.py b/keras_nlp/utils/preset_utils_test.py index 289e13b6ab..694b02c3f4 100644 --- a/keras_nlp/utils/preset_utils_test.py +++ b/keras_nlp/utils/preset_utils_test.py @@ -19,12 +19,16 @@ from absl.testing import parameterized from keras_nlp import upload_preset -from keras_nlp.models.albert.albert_classifier import AlbertClassifier -from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bert.bert_classifier import BertClassifier -from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier -from keras_nlp.models.task import Task +from keras_nlp.models import AlbertClassifier +from keras_nlp.models import Backbone +from keras_nlp.models import BertBackbone +from keras_nlp.models import BertClassifier +from keras_nlp.models import BertTokenizer +from keras_nlp.models import RobertaClassifier +from keras_nlp.models import Task from keras_nlp.tests.test_case import TestCase +from keras_nlp.utils.preset_utils import CONFIG_FILE +from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.utils.preset_utils import check_preset_class from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.preset_utils import save_to_preset @@ -116,4 +120,51 @@ def test_upload_empty_preset(self): with self.assertRaises(FileNotFoundError): upload_preset(uri, empty_preset) - # TODO: add more test to cover various invalid scenarios such as invalid json, missing files, etc. 
+ @parameterized.parameters( + (TOKENIZER_CONFIG_FILE), (CONFIG_FILE), ("model.weights.h5") + ) + @pytest.mark.keras_3_only + @pytest.mark.large + def test_upload_with_missing_file(self, missing_file): + # Load a model from Kaggle to use as a test model. + preset = "bert_tiny_en_uncased" + backbone = BertBackbone.from_preset(preset) + tokenizer = BertTokenizer.from_preset(preset) + + # Save the model on a local directory. + temp_dir = self.get_temp_dir() + local_preset_dir = os.path.join(temp_dir, "bert_preset") + backbone.save_to_preset(local_preset_dir) + tokenizer.save_to_preset(local_preset_dir) + + # Delete the file that is supposed to be missing. + missing_path = os.path.join(local_preset_dir, missing_file) + os.remove(missing_path) + + # Verify error handling. + with self.assertRaisesRegex(FileNotFoundError, "is missing"): + upload_preset("kaggle://test/test/test", local_preset_dir) + + @parameterized.parameters((TOKENIZER_CONFIG_FILE), (CONFIG_FILE)) + @pytest.mark.keras_3_only + @pytest.mark.large + def test_upload_with_invalid_json(self, json_file): + # Load a model from Kaggle to use as a test model. + preset = "bert_tiny_en_uncased" + backbone = BertBackbone.from_preset(preset) + tokenizer = BertTokenizer.from_preset(preset) + + # Save the model on a local directory. + temp_dir = self.get_temp_dir() + local_preset_dir = os.path.join(temp_dir, "bert_preset") + backbone.save_to_preset(local_preset_dir) + tokenizer.save_to_preset(local_preset_dir) + + # Re-write json file content to an invalid format. + json_path = os.path.join(local_preset_dir, json_file) + with open(json_path, "w") as file: + file.write("Invalid!") + + # Verify error handling. + with self.assertRaisesRegex(ValueError, "is an invalid json"): + upload_preset("kaggle://test/test/test", local_preset_dir) From 6a8166eb07fb686d5fc3b58fd1b2c4d6987b13ea Mon Sep 17 00:00:00 2001 From: Pranav Prajapati <94780581+pranavvp16@users.noreply.github.com> Date: Wed, 27 Mar 2024 03:47:28 +0530 Subject: [PATCH 54/70] Add presets for Electra and checkpoint conversion script (#1384) * Added ElectraBackbone * Added backbone tests for ELECTRA * Fix config * Add model import to __init__ * add electra tokenizer * add tests for tokenizer * add __init__ file * add tokenizer and backbone to models __init__ * Fix Failing tokenization test * Add example on usage of the tokenizer with custom vocabulary * Add conversion script to convert weights from checkpoint * Add electra preprocessor * Add presets and tests * Add presets config with model weights * Add checkpoint conversion script * Name conversion for electra models * Update naming conventions according to preset names * Fix failing tokenizer tests * Update checkpoint conversion script according to kaggle * Add validate function * Kaggle preset * update preset link * Add electra presets * Complete run_small_preset test for electra * Add large variations of electra in presets * Fix case issues with electra presets * Fix format --------- Co-authored-by: Matt Watson --- keras_nlp/models/__init__.py | 1 + keras_nlp/models/electra/electra_backbone.py | 20 +- .../models/electra/electra_backbone_test.py | 34 +++ .../models/electra/electra_preprocessor.py | 163 ++++++++++ .../electra/electra_preprocessor_test.py | 67 +++++ keras_nlp/models/electra/electra_presets.py | 95 ++++++ keras_nlp/models/electra/electra_tokenizer.py | 8 + .../models/electra/electra_tokenizer_test.py | 20 ++ .../convert_electra_checkpoints.py | 278 ++++++++++++++++++ 9 files changed, 684 insertions(+), 2 deletions(-) create 
mode 100644 keras_nlp/models/electra/electra_preprocessor.py create mode 100644 keras_nlp/models/electra/electra_preprocessor_test.py create mode 100644 keras_nlp/models/electra/electra_presets.py create mode 100644 tools/checkpoint_conversion/convert_electra_checkpoints.py diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 033a9dc874..0a74cba869 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -72,6 +72,7 @@ DistilBertTokenizer, ) from keras_nlp.models.electra.electra_backbone import ElectraBackbone +from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_classifier import FNetClassifier diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index a116caa20d..7ecca892d9 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone +from keras_nlp.models.electra.electra_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate +from keras_nlp.utils.python_utils import classproperty def electra_kernel_initializer(stddev=0.02): @@ -36,8 +40,9 @@ class ElectraBackbone(Backbone): or classification task networks. The default constructor gives a fully customizable, randomly initialized - Electra encoder with any number of layers, heads, and embedding - dimensions. + ELECTRA encoder with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the + `from_preset()` constructor. Disclaimer: Pre-trained models are provided on an "as is" basis, without warranties or conditions of any kind. The underlying model is provided by a @@ -70,6 +75,13 @@ class ElectraBackbone(Backbone): "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]), "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), } + + # Pre-trained ELECTRA encoder. 
+ model = keras_nlp.models.ElectraBackbone.from_preset( + "electra_base_discriminator_en" + ) + model(input_data) + # Randomly initialized Electra encoder backbone = keras_nlp.models.ElectraBackbone( vocabulary_size=1000, @@ -234,3 +246,7 @@ def get_config(self): } ) return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/electra/electra_backbone_test.py b/keras_nlp/models/electra/electra_backbone_test.py index 09e6c53344..c51bb579e1 100644 --- a/keras_nlp/models/electra/electra_backbone_test.py +++ b/keras_nlp/models/electra/electra_backbone_test.py @@ -54,3 +54,37 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=ElectraBackbone, + preset="electra_small_discriminator_uncased_en", + input_data={ + "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"), + "segment_ids": ops.zeros((1, 4), dtype="int32"), + "padding_mask": ops.ones((1, 4), dtype="int32"), + }, + expected_output_shape={ + "sequence_output": (1, 4, 256), + "pooled_output": (1, 256), + }, + # The forward pass from a preset should be stable! + expected_partial_output={ + "sequence_output": ( + ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977]) + ), + "pooled_output": ( + ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822]) + ), + }, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ElectraBackbone.presets: + self.run_preset_test( + cls=ElectraBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/electra/electra_preprocessor.py b/keras_nlp/models/electra/electra_preprocessor.py new file mode 100644 index 0000000000..1e3ac2454c --- /dev/null +++ b/keras_nlp/models/electra/electra_preprocessor.py @@ -0,0 +1,163 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.multi_segment_packer import ( + MultiSegmentPacker, +) +from keras_nlp.models.electra.electra_presets import backbone_presets +from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.ElectraPreprocessor") +class ElectraPreprocessor(Preprocessor): + """A ELECTRA preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do three things: + + 1. Tokenize any number of input segments using the `tokenizer`. + 2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`. + with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens. + 3. 
Construct a dictionary with keys `"token_ids"`, `"segment_ids"` and
+        `"padding_mask"`, that can be passed directly to an ELECTRA model.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be converted
+            to tensors. For multiple sequences, pass tensors directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
+        "electra_base_discriminator_en"
+    )
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Custom vocabulary.
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
+    preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer)
+    preprocessor("The quick brown fox jumped.")
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
+        "electra_base_discriminator_en"
+    )
+
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    # Map unlabeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices((first, second))
+
+    # Watch out for tf.data's default unpacking of tuples here!
+    # Best to invoke the `preprocessor` directly in this case.
+ ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + """ + + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.cls_token_id, + end_value=self.tokenizer.sep_token_id, + pad_value=self.tokenizer.pad_token_id, + truncate=truncate, + sequence_length=sequence_length, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.packer.sequence_length, + "truncate": self.packer.truncate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + x = convert_inputs_to_list_of_tensor_segments(x) + x = [self.tokenizer(segment) for segment in x] + token_ids, segment_ids = self.packer(x) + x = { + "token_ids": token_ids, + "segment_ids": segment_ids, + "padding_mask": token_ids != self.tokenizer.pad_token_id, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + @classproperty + def tokenizer_cls(cls): + return ElectraTokenizer + + @classproperty + def presets(cls): + return copy.deepcopy({**backbone_presets}) diff --git a/keras_nlp/models/electra/electra_preprocessor_test.py b/keras_nlp/models/electra/electra_preprocessor_test.py new file mode 100644 index 0000000000..5dd48fe40a --- /dev/null +++ b/keras_nlp/models/electra/electra_preprocessor_test.py @@ -0,0 +1,67 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor +from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer +from keras_nlp.tests.test_case import TestCase + + +class ElectraPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.tokenizer = ElectraTokenizer(vocabulary=self.vocab) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["THE QUICK BROWN FOX."], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=ElectraPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[2, 5, 6, 7, 8, 1, 3, 0]], + "segment_ids": [[0, 0, 0, 0, 0, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. 
+ ), + ) + + def test_errors_for_2d_list_input(self): + preprocessor = ElectraPreprocessor(**self.init_kwargs) + ambiguous_input = [["one", "two"], ["three", "four"]] + with self.assertRaises(ValueError): + preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ElectraPreprocessor.presets: + self.run_preset_test( + cls=ElectraPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/electra/electra_presets.py b/keras_nlp/models/electra/electra_presets.py new file mode 100644 index 0000000000..68d709b1fd --- /dev/null +++ b/keras_nlp/models/electra/electra_presets.py @@ -0,0 +1,95 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ELECTRA model preset configurations.""" + +backbone_presets = { + "electra_small_discriminator_uncased_en": { + "metadata": { + "description": ( + "12-layer small ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 13548800, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1", + }, + "electra_small_generator_uncased_en": { + "metadata": { + "description": ( + "12-layer small ELECTRA generator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 13548800, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_small_generator_uncased_en/1", + }, + "electra_base_discriminator_uncased_en": { + "metadata": { + "description": ( + "12-layer base ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 109482240, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1", + }, + "electra_base_generator_uncased_en": { + "metadata": { + "description": ( + "12-layer base ELECTRA generator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 33576960, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_base_generator_uncased_en/1", + }, + "electra_large_discriminator_uncased_en": { + "metadata": { + "description": ( + "24-layer large ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." 
+ ), + "params": 335141888, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1", + }, + "electra_large_generator_uncased_en": { + "metadata": { + "description": ( + "24-layer large ELECTRA generator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 51065344, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_large_generator_uncased_en/1", + }, +} diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index 583b756165..12f5ecfec6 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.electra.electra_presets import backbone_presets from keras_nlp.tokenizers import WordPieceTokenizer +from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.ElectraTokenizer") @@ -102,3 +106,7 @@ def get_config(self): config = super().get_config() del config["special_tokens"] # Not configurable; set in __init__. return config + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/electra/electra_tokenizer_test.py b/keras_nlp/models/electra/electra_tokenizer_test.py index 29c40a2f29..9126ddc8a7 100644 --- a/keras_nlp/models/electra/electra_tokenizer_test.py +++ b/keras_nlp/models/electra/electra_tokenizer_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.tests.test_case import TestCase @@ -50,3 +52,21 @@ def test_tokenizer_special_tokens(self): def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): ElectraTokenizer(vocabulary=["a", "b", "c"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=ElectraTokenizer, + preset="electra_small_discriminator_uncased_en", + input_data=["the quick brown fox."], + expected_output=[[1996, 4248, 2829, 4419, 1012]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ElectraTokenizer.presets: + self.run_preset_test( + cls=ElectraTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/tools/checkpoint_conversion/convert_electra_checkpoints.py b/tools/checkpoint_conversion/convert_electra_checkpoints.py new file mode 100644 index 0000000000..8bbafb1305 --- /dev/null +++ b/tools/checkpoint_conversion/convert_electra_checkpoints.py @@ -0,0 +1,278 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Electra weights conversion script. +""" + +import json +import os + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import huggingface_hub # noqa: E402 +import numpy as np # noqa: E402 +import tensorflow as tf # noqa: E402 +import transformers # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 + +import keras_nlp # noqa: E402 +from keras_nlp.utils.preset_utils import save_to_preset # noqa: E402 + +PRESET_MAP = { + "electra_base_generator_en": "google/electra-base-generator", + "electra_small_generator_en": "google/electra-small-generator", + "electra_base_discriminator_en": "google/electra-base-discriminator", + "electra_small_discriminator_en": "google/electra-small-discriminator", + "electra_large_discriminator_en": "google/electra-large-discriminator", + "electra_large_generator_en": "google/electra-large-generator", +} + +EXTRACT_DIR = "./model" + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + "electra_base_discriminator_en", + f'Must be one of {",".join(PRESET_MAP)}', +) +flags.mark_flag_as_required("preset") + + +def download_hf_model(hf_model_name): + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.bin"], + ignore_patterns=["onx/*"], + local_dir=EXTRACT_DIR, + ) + return hf_model_dir + + +def convert_model(hf_model): + hf_config = hf_model.config.to_dict() + cfg = {} + cfg["vocab_size"] = hf_config["vocab_size"] + cfg["embedding_dim"] = hf_config["embedding_size"] + cfg["num_layers"] = hf_config["num_hidden_layers"] + cfg["num_heads"] = hf_config["num_attention_heads"] + cfg["hidden_dim"] = hf_config["hidden_size"] + cfg["intermediate_dim"] = hf_config["intermediate_size"] + cfg["dropout"] = hf_config["hidden_dropout_prob"] + cfg["max_sequence_length"] = hf_config["max_position_embeddings"] + return keras_nlp.models.ElectraBackbone(**cfg) + + +def convert_tokenizer(hf_model_dir): + tokenizer_path = os.path.join(hf_model_dir, "tokenizer.json") + with open(tokenizer_path) as f: + hf_tokenizer = json.load(f) + vocab = hf_tokenizer["model"]["vocab"] + + return keras_nlp.models.ElectraTokenizer(vocabulary=vocab) + + +def convert_weights(keras_model, hf_model): + hf_model_dict = hf_model.state_dict() + + keras_model.get_layer("token_embedding").embeddings.assign( + hf_model_dict["embeddings.word_embeddings.weight"].numpy() + ) + keras_model.get_layer("position_embedding").position_embeddings.assign( + hf_model_dict["embeddings.position_embeddings.weight"].numpy() + ) + keras_model.get_layer("segment_embedding").embeddings.assign( + hf_model_dict["embeddings.token_type_embeddings.weight"].numpy() + ) + keras_model.get_layer("embeddings_layer_norm").gamma.assign( + hf_model_dict["embeddings.LayerNorm.weight"] + ) + keras_model.get_layer("embeddings_layer_norm").beta.assign( + hf_model_dict["embeddings.LayerNorm.bias"] + ) + + if any( + layer.name == "embeddings_projection" for layer in keras_model.layers + ): + keras_model.get_layer("embeddings_projection").kernel.assign( + hf_model_dict["embeddings_project.weight"].transpose(1, 0).numpy() + ) + keras_model.get_layer("embeddings_projection").bias.assign( + hf_model_dict["embeddings_project.bias"] + ) + + for i in range(keras_model.num_layers): + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.kernel.assign( + 
hf_model_dict[f"encoder.layer.{i}.attention.self.query.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.query.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.key.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.key.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.value.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.value.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._output_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.output.dense.weight"] + .transpose(1, 0) + .reshape((keras_model.num_heads, -1, keras_model.hidden_dim)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._output_dense.bias.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.dense.bias" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer_norm.gamma.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.LayerNorm.weight" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer_norm.beta.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.LayerNorm.bias" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.intermediate.dense.weight"] + .transpose(1, 0) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.intermediate.dense.bias"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.output.dense.weight"] + .transpose(1, 0) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.output.dense.bias"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.gamma.assign( + hf_model_dict[f"encoder.layer.{i}.output.LayerNorm.weight"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.beta.assign( + hf_model_dict[f"encoder.layer.{i}.output.LayerNorm.bias"].numpy() + ) + + +def validate_output(keras_model, hf_model, keras_tokenizer, hf_tokenizer): + input_str = ["The quick brown fox jumps over the lazy dog."] + + keras_nlp_preprocessor = keras_nlp.models.ElectraPreprocessor( + keras_tokenizer + ) + keras_nlp_inputs = 
keras_nlp_preprocessor(tf.constant(input_str)) + keras_nlp_output = keras_model.predict(keras_nlp_inputs).get( + "sequence_output" + ) + + hf_inputs = hf_tokenizer( + input_str, padding="max_length", return_tensors="pt" + ) + hf_output = hf_model(**hf_inputs).last_hidden_state.detach().numpy() + print("🔶 KerasNLP output:", keras_nlp_output[0, 0, :10]) + print("🔶 HF output:", hf_output[0, 0, :10]) + print("Difference: ", np.mean(keras_nlp_output - hf_output)) + + +def main(_): + preset = FLAGS.preset + assert preset in PRESET_MAP.keys(), f"Invalid preset: {preset}" + print(f"✅ Converting {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("✅ Downloaded model from Hugging face hub") + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_name) + hf_model = transformers.AutoModel.from_pretrained(hf_model_name) + print(f"✅ Loaded {preset} from Hugging Face") + + keras_model = convert_model(hf_model) + keras_tokenizer = convert_tokenizer(hf_model_dir) + print("✅ Keras model loaded") + + convert_weights(keras_model, hf_model) + print("✅ Weights converted") + + validate_output(keras_model, hf_model, keras_tokenizer, hf_tokenizer) + print("✅ Validation complete") + + save_to_preset(keras_model, preset) + save_to_preset(keras_tokenizer, preset, config_filename="tokenizer.json") + + print("✅ Preset saved") + + +if __name__ == "__main__": + app.run(main) From 6ea1e63dd50d69ad7e6f5aa6d53ccda20ac58a0c Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 27 Mar 2024 17:41:01 +0100 Subject: [PATCH 55/70] Allow saving / loading from Huggingface Hub preset (#1510) * first draft * update upload_preset * lint * consistent error messages * lint --- keras_nlp/utils/preset_utils.py | 61 ++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index e2a4452714..5ddde4415a 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -27,8 +27,16 @@ except ImportError: kagglehub = None +try: + import huggingface_hub + from huggingface_hub.utils import HFValidationError +except ImportError: + huggingface_hub = None + KAGGLE_PREFIX = "kaggle://" GS_PREFIX = "gs://" +HF_PREFIX = "hf://" + TOKENIZER_ASSET_DIR = "assets/tokenizer" CONFIG_FILE = "config.json" TOKENIZER_CONFIG_FILE = "tokenizer.json" @@ -69,15 +77,33 @@ def get_file(preset, path): url, cache_subdir=os.path.join("models", subdir), ) + elif preset.startswith(HF_PREFIX): + if huggingface_hub is None: + raise ImportError( + f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " + "Please install with `pip install huggingface_hub`." + ) + hf_handle = preset.removeprefix(HF_PREFIX) + try: + return huggingface_hub.hf_hub_download( + repo_id=hf_handle, filename=path + ) + except HFValidationError as e: + raise ValueError( + "Unexpected Hugging Face preset. Hugging Face model handles " + "should have the form 'hf://{org}/{model}'. For example, " + f"'hf://username/bert_base_en'. Received: preset={preset}." + ) from e elif os.path.exists(preset): # Assume a local filepath. return os.path.join(preset, path) else: raise ValueError( "Unknown preset identifier. 
A preset must be a one of:\n"
-            "1) a built in preset identifier like `'bert_base_en'`\n"
+            "1) a built-in preset identifier like `'bert_base_en'`\n"
             "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n"
-            "3) a path to a local preset directory like `'./bert_base_en`\n"
+            "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n"
+            "4) a path to a local preset directory like `'./bert_base_en'`\n"
             "Use `print(cls.presets.keys())` to view all built-in presets for "
             "API symbol `cls`.\n"
             f"Received: preset='{preset}'"
@@ -245,7 +271,9 @@ def upload_preset(
         uri: The URI identifying model to upload to. URIs with format
             `kaggle://///`
-            will be uploaded to Kaggle Hub.
+            will be uploaded to Kaggle Hub while URIs with format
+            `hf://[/]` will be uploaded to the Hugging
+            Face Hub.
         preset: The path to the local model preset directory.
         allow_incomplete: If True, allows the upload of presets without
             a tokenizer configuration. Otherwise, a tokenizer
@@ -262,10 +290,33 @@
     if uri.startswith(KAGGLE_PREFIX):
         kaggle_handle = uri.removeprefix(KAGGLE_PREFIX)
         kagglehub.model_upload(kaggle_handle, preset)
+    elif uri.startswith(HF_PREFIX):
+        if huggingface_hub is None:
+            raise ImportError(
+                f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. "
+                "Please install with `pip install huggingface_hub`."
+            )
+        hf_handle = uri.removeprefix(HF_PREFIX)
+        try:
+            repo_url = huggingface_hub.create_repo(
+                repo_id=hf_handle, exist_ok=True
+            )
+        except HFValidationError as e:
+            raise ValueError(
+                "Unexpected Hugging Face URI. Hugging Face model handles "
+                "should have the form 'hf://[{org}/]{model}'. For example, "
+                "'hf://username/bert_base_en' or 'hf://bert_base_en' to implicitly "
+                f"upload to your user account. Received: URI={uri}."
+            ) from e
+        huggingface_hub.upload_folder(
+            repo_id=repo_url.repo_id, folder_path=preset
+        )
     else:
         raise ValueError(
-            f"Unexpected URI `'{uri}'`. Kaggle upload format should follow "
-            "`kaggle://///`."
+            "Unknown URI. A URI must be one of:\n"
+            "1) a Kaggle Model handle like `'kaggle://///'`\n"
+            "2) a Hugging Face handle like `'hf://[/]'`\n"
+            f"Received: uri='{uri}'."
) From 6e946e231e57852f08294115fa8fb0016fcd3637 Mon Sep 17 00:00:00 2001 From: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:45:08 -0400 Subject: [PATCH 56/70] Stop on multiple end tokens (#1518) * Add multitoken stopping * Update gemma_causal_lm.py * Add further multitoken support * Formatting * Revert tokenizer changes * Move multi token stop to generative task * None check * None check * Error message * Add stop_token_ids * Util testing * Fix sampler tests * All multitoken stop to all models * Sampler multi token * Formatting * Tuple required * Tuple docstring * Pytorch GPU fix * Numpy fix --- keras_nlp/models/bart/bart_seq_2_seq_lm.py | 18 ++++---- keras_nlp/models/bloom/bloom_causal_lm.py | 19 +++++---- keras_nlp/models/gemma/gemma_causal_lm.py | 19 +++++---- .../gemma/gemma_causal_lm_preprocessor.py | 5 +-- .../models/gemma/gemma_causal_lm_test.py | 20 +++++++++ keras_nlp/models/generative_task.py | 42 ++++++++++++++----- keras_nlp/models/gpt2/gpt2_causal_lm.py | 19 +++++---- .../models/gpt_neo_x/gpt_neo_x_causal_lm.py | 19 +++++---- keras_nlp/models/mistral/mistral_causal_lm.py | 19 +++++---- keras_nlp/models/opt/opt_causal_lm.py | 19 +++++---- keras_nlp/samplers/beam_sampler.py | 9 ++-- keras_nlp/samplers/beam_sampler_test.py | 2 +- keras_nlp/samplers/contrastive_sampler.py | 9 ++-- .../samplers/contrastive_sampler_test.py | 2 +- keras_nlp/samplers/greedy_sampler_test.py | 14 ++++++- keras_nlp/samplers/random_sampler_test.py | 2 +- keras_nlp/samplers/sampler.py | 9 ++-- keras_nlp/samplers/top_k_sampler_test.py | 2 +- keras_nlp/samplers/top_p_sampler_test.py | 2 +- keras_nlp/utils/tensor_utils.py | 24 +++++++++++ keras_nlp/utils/tensor_utils_test.py | 40 ++++++++++++++++++ 21 files changed, 219 insertions(+), 95 deletions(-) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index c530555b3d..2f0fa1104c 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -24,6 +24,7 @@ ) from keras_nlp.models.generative_task import GenerativeTask from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM") @@ -398,7 +399,7 @@ def _build_cache( def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a batch of inputs. @@ -412,8 +413,8 @@ def generate_step( inputs: A dictionary with four keys - `"encoder_token_ids"`, `"encoder_padding_mask"`, `"decoder_token_ids"` and `"decoder_padding_mask"`, with batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. """ ( @@ -477,17 +478,18 @@ def repeat_tensor(x): cache=self_attention_cache, index=index, mask=decoder_padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original # prompt (not in locations where `decoder_padding_mask` is True). 
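+            # A stop token that already appeared inside the prompt should not
+            # be allowed to end generation early.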
- end_locations = ops.logical_and( - ops.equal(decoder_token_ids, end_token_id), + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, ops.logical_not(decoder_padding_mask), ) end_locations = ops.cast(end_locations, "int32") diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py index 31eae30c6b..7d189d17e4 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm.py +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -24,6 +24,7 @@ from keras_nlp.models.bloom.bloom_presets import backbone_presets from keras_nlp.models.generative_task import GenerativeTask from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.BloomCausalLM") @@ -245,7 +246,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -256,8 +257,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -290,19 +291,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 58e2e302d5..346b6b362f 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -24,6 +24,7 @@ from keras_nlp.models.gemma.gemma_presets import backbone_presets from keras_nlp.models.generative_task import GenerativeTask from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GemmaCausalLM") @@ -238,7 +239,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -249,8 +250,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. 
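+                For example, `stop_token_ids=(1, 2)` ends a sequence on
+                whichever of the two ids is produced first.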
""" token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -283,19 +284,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py index 04a067be82..ca6b826abc 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py @@ -148,10 +148,7 @@ def generate_preprocess( "padding_mask": padding_mask, } - def generate_postprocess( - self, - x, - ): + def generate_postprocess(self, x): """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py index 5ed1ce015c..4a47d162ef 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm_test.py +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -130,6 +130,26 @@ def wrapper(*args, **kwargs): # We should immediately abort and output the prompt. self.assertEqual(prompt, output) + def test_multitoken_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + + output = causal_lm.generate(prompt, stop_token_ids=(3,)) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + def test_generate_compilation(self): causal_lm = GemmaCausalLM(**self.init_kwargs) # Assert we do not recompile with successive calls. diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/generative_task.py index 99c447e8ef..30e11a4655 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/generative_task.py @@ -13,6 +13,7 @@ # limitations under the License. 
 import itertools
+from functools import partial

 import tensorflow as tf
 import tree
@@ -64,10 +65,10 @@ def make_generate_function(self):

             def wrapped_generate_function(
                 inputs,
-                end_token_id=None,
+                stop_token_ids=None,
             ):
                 with torch.no_grad():
-                    return self.generate_step(inputs, end_token_id)
+                    return self.generate_step(inputs, stop_token_ids)

             self.generate_function = wrapped_generate_function
         elif config.backend() == "tensorflow" and not self.run_eagerly:
@@ -80,8 +81,8 @@ def wrapped_generate_function(
         elif config.backend() == "jax" and not self.run_eagerly:
             import jax

-            @jax.jit
-            def compiled_generate_function(inputs, end_token_id, state):
+            @partial(jax.jit, static_argnames=["stop_token_ids"])
+            def compiled_generate_function(inputs, stop_token_ids, state):
                 (
                     sampler_variables,
                     trainable_variables,
@@ -94,7 +95,7 @@ def compiled_generate_function(inputs, end_token_id, state):
                 )
                 with keras.StatelessScope(state_mapping=mapping) as scope:
-                    outputs = self.generate_step(inputs, end_token_id)
+                    outputs = self.generate_step(inputs, stop_token_ids)

                 # Get updated sampler variables from the stateless scope.
                 sampler_variables = []
@@ -105,8 +106,11 @@ def compiled_generate_function(inputs, end_token_id, state):

             def wrapped_generate_function(
                 inputs,
-                end_token_id=None,
+                stop_token_ids=None,
             ):
+                if isinstance(stop_token_ids, list):
+                    stop_token_ids = tuple(stop_token_ids)
+
                 # Create an explicit tuple of all variable state.
                 state = (
                     self._sampler.variables,
@@ -118,7 +122,7 @@ def wrapped_generate_function(
                 inputs = tree.map_structure(ops.convert_to_tensor, inputs)
                 outputs, sampler_variables = compiled_generate_function(
                     inputs,
-                    end_token_id,
+                    stop_token_ids,
                     state,
                 )
                 # Only assign the sampler variables (random seeds), as other
@@ -206,6 +210,7 @@ def generate(
         self,
         inputs,
         max_length=None,
+        stop_token_ids="auto",
     ):
         """Generate text given prompt `inputs`.
@@ -234,15 +239,30 @@
                 `preprocessor`. If `preprocessor` is `None`, `inputs` should be
                 should be padded to the desired maximum length and this
                 argument will be ignored.
+            stop_token_ids: Optional. `None`, `"auto"`, or tuple of token ids. Defaults
+                to `"auto"`, which uses the `preprocessor.tokenizer.end_token_id`;
+                using `"auto"` without an attached `preprocessor` raises an error.
+                `None` stops generation only after `max_length` tokens. You may also
+                pass a tuple of token ids the model should stop on. Each id is
+                treated as an independent stop token, e.g. `stop_token_ids=(1, 2)`;
+                multi-token stop sequences are not supported.
         """
         # Setup our three main passes.
         # 1. Optionally preprocessing strings to dense integer tensors.
         # 2. Generate new tokens via a compiled function on dense tensors.
         # 3. Optionally postprocess dense integer tensors back to string.
         generate_function = self.make_generate_function()
-        end_token_id = None
-        if self.preprocessor is not None:
-            end_token_id = self.preprocessor.tokenizer.end_token_id
+
+        if self.preprocessor is None and stop_token_ids == "auto":
+            raise ValueError(
+                'A `preprocessor` must be attached to the model if `stop_token_ids="auto"`. '
+                "Currently `preprocessor=None`. To call `generate()` with preprocessing "
+                "detached, either pass `stop_token_ids=None` to always generate until "
+                "`max_length` or pass a tuple of token ids that should terminate generation "
+                "as `stop_token_ids`."
+ ) + elif stop_token_ids == "auto": + stop_token_ids = [self.preprocessor.tokenizer.end_token_id] def preprocess(x): return self.preprocessor.generate_preprocess( @@ -250,7 +270,7 @@ def preprocess(x): ) def generate(x): - return generate_function(x, end_token_id=end_token_id) + return generate_function(x, stop_token_ids=stop_token_ids) def postprocess(x): return self.preprocessor.generate_postprocess(x) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index b0bd529da4..18e6ead7a2 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -24,6 +24,7 @@ ) from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GPT2CausalLM") @@ -251,7 +252,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -262,8 +263,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: List of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -296,19 +297,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop tokens locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index b1df4a6706..797f3e2eb3 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -21,6 +21,7 @@ GPTNeoXCausalLMPreprocessor, ) from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GPTNeoXCausalLM") @@ -141,7 +142,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -152,8 +153,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. 
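+                If None, generation only stops once the preallocated output
+                buffer is full.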
""" token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -186,19 +187,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop_tokens locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py index 800df1c8cb..20e19e6c31 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm.py +++ b/keras_nlp/models/mistral/mistral_causal_lm.py @@ -23,6 +23,7 @@ ) from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.MistralCausalLM") @@ -143,7 +144,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -154,8 +155,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: List of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -188,19 +189,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop_tokens locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. 
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 2ca8ee07b4..6133ef227e 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -24,6 +24,7 @@ ) from keras_nlp.models.opt.opt_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.OPTCausalLM") @@ -247,7 +248,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -258,8 +259,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of id's of end token's to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -292,19 +293,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 297ec203de..3a34217952 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -18,6 +18,7 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.samplers.BeamSampler") @@ -70,7 +71,7 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, model=None, ): @@ -109,10 +110,10 @@ def unflatten_beams(x): log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0)) def cond(prompt, cache, index, log_probs): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* stop token. 
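+            # "New" means outside the original prompt, i.e. where `mask` is
+            # False.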
+ end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index fca37cd85b..496e892e62 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -110,7 +110,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index 36d10690d7..24f983fd0b 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -17,6 +17,7 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") @@ -70,7 +71,7 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, model=None, ): @@ -106,10 +107,10 @@ def unflatten_beams(x): cache = cache if has_cache else () def cond(prompt, cache, index, logits, hidden_states): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* stop token. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index 824d47a705..9cdd9de546 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -98,7 +98,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], index=0, hidden_states=self.hidden_states, ) diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 618c94d118..38c85c950b 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -86,10 +86,22 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) + def test_multitoken_early_stopping(self): + cache_chars = list("sequentially") + cache = ops.array([[self.char_lookup[c] for c in cache_chars]]) + prompt = ops.full((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + cache=cache, + stop_token_ids=[self.char_lookup["t"], self.char_lookup["n"]], + ) + self.assertEqual(self.join_as_string(output), ["sequenzzzzzz"]) + def test_is_greedy(self): def next(prompt, cache, index): # Dummy hidden states. 
diff --git a/keras_nlp/samplers/random_sampler_test.py b/keras_nlp/samplers/random_sampler_test.py index 3149d11dba..6f6fd53f51 100644 --- a/keras_nlp/samplers/random_sampler_test.py +++ b/keras_nlp/samplers/random_sampler_test.py @@ -98,7 +98,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index a6b64b5324..43950dea2f 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -17,6 +17,7 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.samplers.Sampler") @@ -90,7 +91,7 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, model=None, ): @@ -106,10 +107,10 @@ def __call__( cache = () if cache is None else cache def cond(prompt, cache, index): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* id from stop_token_ids. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index c0f516a214..c0ebbc85d0 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -86,7 +86,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index fea5ca4110..9bd91b4a1a 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -87,7 +87,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index a88d80a4da..8b5759b1cf 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -170,3 +170,27 @@ def is_int_dtype(dtype): def is_string_dtype(dtype): return "string" in standardize_dtype(dtype) + + +def any_equal(inputs, values, padding_mask): + """Return a mask that is True anywhere `inputs` has a value in `values`. + + Final mask has `padding_mask` applied. + + Args: + inputs: Input tensor. + values: List or iterable of tensors shaped like `inputs` or broadcastable + by bit operators. + padding_mask: Tensor with shape compatible with inputs that will condition + output. + + Returns: + A tensor with `inputs` shape where each position is True if it contains + a value from any `values`. 
Padding mask will be applied before + returning.""" + output = ops.equal(inputs, values[0]) + for value in values[1:]: + value_equality = ops.equal(inputs, value) + output = ops.logical_or(output, value_equality) + + return ops.logical_and(output, padding_mask) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py index 317e42ade8..ec27832ed9 100644 --- a/keras_nlp/utils/tensor_utils_test.py +++ b/keras_nlp/utils/tensor_utils_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import tensorflow as tf from keras_nlp.backend import ops from keras_nlp.tests.test_case import TestCase +from keras_nlp.utils.tensor_utils import any_equal from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import tensor_to_list @@ -97,3 +99,41 @@ def test_convert_ragged(self): self.assertAllEqual(outputs, [[1, 2], [1]]) self.assertFalse(unbatched) self.assertFalse(rectangular) + + +class MaskedAnyEqualTest(tf.test.TestCase): + def test_basic_equality(self): + inputs = ops.array([1, 2, 3, 5]) + values = [3, 5] + padding_mask = ops.array([True, True, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_multiple_values(self): + inputs = ops.array([2, 4, 7, 9]) + values = [5, 4, 9] + padding_mask = ops.array([True, True, True, True]) + expected_output = np.array([False, True, False, True]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_padding_mask(self): + inputs = ops.array([1, 5, 3, 2]) + values = [5, 3] + padding_mask = ops.array([True, False, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_input_shaped_values(self): + inputs = ops.array([1, 5, 3, 2]) + values = [[5, 5, 5, 5], [3, 3, 3, 3]] + padding_mask = ops.array([True, False, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) From e5b2833dff87dfdadea684dff4fc039b967f58e0 Mon Sep 17 00:00:00 2001 From: asmith26 Date: Wed, 27 Mar 2024 23:00:20 +0000 Subject: [PATCH 57/70] Update mistral_tokenizer.py (#1528) --- keras_nlp/models/mistral/mistral_tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index 59a00d302f..c7d27684f8 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -49,7 +49,7 @@ class MistralTokenizer(SentencePieceTokenizer): ```python # Unbatched input. tokenizer = keras_nlp.models.MistralTokenizer.from_preset( - "mistral_base_en", + "mistral_7b_en", ) tokenizer("The quick brown fox jumped.") From 2be333c6d308becd4b1d3bbc118649b0d6907333 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Wed, 27 Mar 2024 16:03:41 -0700 Subject: [PATCH 58/70] Add lora example to GemmaCausalLM docstring (#1527) * Add lora example to GemmaCausalLM docstring. * Address review. 
---
 keras_nlp/models/gemma/gemma_causal_lm.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py
index 346b6b362f..30d0171844 100644
--- a/keras_nlp/models/gemma/gemma_causal_lm.py
+++ b/keras_nlp/models/gemma/gemma_causal_lm.py
@@ -98,6 +98,14 @@ class GemmaCausalLM(GenerativeTask):
     gemma_lm.fit(x=features, batch_size=2)
     ```
 
+    Call `fit()` with LoRA fine-tuning enabled.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+    gemma_lm.backbone.enable_lora(rank=4)
+    gemma_lm.fit(x=features, batch_size=2)
+    ```
+
     Call `fit()` without preprocessing.
     ```python
     x = {

From 859b1bfd2103fdde65275545261c2ce4a3a9bc55 Mon Sep 17 00:00:00 2001
From: Tirth Patel
Date: Wed, 27 Mar 2024 18:36:04 -0700
Subject: [PATCH 59/70] Add LLaMA Causal LM with 7B presets (#1526)

* Add LLaMA Causal LM
* Add causal lm to the public API
* Update preset names and fix checkpoint script
* Fix discrepancies and add tests
* Add tests for CausalLM
* end_token -> stop_token_ids
---
 keras_nlp/models/__init__.py                  |   6 +
 keras_nlp/models/llama/llama_backbone.py      |  15 +-
 keras_nlp/models/llama/llama_backbone_test.py |  31 ++
 keras_nlp/models/llama/llama_causal_lm.py     | 221 ++++++++++++
 .../models/llama/llama_causal_lm_test.py      | 130 +++++++
 keras_nlp/models/llama/llama_preprocessor.py  |   7 +
 .../models/llama/llama_preprocessor_test.py   |  11 +
 keras_nlp/models/llama/llama_presets.py       |  38 ++
 keras_nlp/models/llama/llama_tokenizer.py     |   8 +
 .../models/llama/llama_tokenizer_test.py      |  20 +
 .../convert_llama_checkpoints.py              | 341 +++++++++++++-----
 11 files changed, 723 insertions(+), 105 deletions(-)
 create mode 100644 keras_nlp/models/llama/llama_causal_lm.py
 create mode 100644 keras_nlp/models/llama/llama_causal_lm_test.py
 create mode 100644 keras_nlp/models/llama/llama_presets.py

diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py
index 0a74cba869..2eb1f5c8dc 100644
--- a/keras_nlp/models/__init__.py
+++ b/keras_nlp/models/__init__.py
@@ -108,6 +108,12 @@
 )
 from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer
 from keras_nlp.models.llama.llama_backbone import LlamaBackbone
+from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM
+from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
+    LlamaCausalLMPreprocessor,
+)
+from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
 from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
 from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM
 from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py
index b5383d528a..ec35989e01 100644
--- a/keras_nlp/models/llama/llama_backbone.py
+++ b/keras_nlp/models/llama/llama_backbone.py
@@ -11,20 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
- -# import copy +import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone - -# from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm - -# from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty def _llama_kernel_initializer(stddev=0.02): @@ -191,6 +188,6 @@ def get_config(self): ) return config - # @classproperty - # def presets(cls): - # return copy.deepcopy(backbone_presets) + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index 56d8c44bd3..b641a0152e 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -49,3 +49,34 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + def test_num_parameters(self): + model = LlamaBackbone(**self.init_kwargs) + # Reference value calculated using the PyTorch model + self.assertEqual(model.count_params(), 968) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaBackbone, + preset="llama2_7b_en", + input_data={ + "token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 4096), + # The forward pass from a preset should be stable! + # Reference values computed using PyTorch HF model. + expected_partial_output=ops.array( + [0.0153, 1.1657, 2.2452, -2.0192, -0.5801] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaBackbone.presets: + self.run_preset_test( + cls=LlamaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py new file mode 100644 index 0000000000..7527766f01 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -0,0 +1,221 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal + + +@keras_nlp_export("keras_nlp.models.LlamaCausalLM") +class LlamaCausalLM(GenerativeTask): + """An end-to-end Llama model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a LLaMA model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_nlp.models.LlamaBackbone` instance. + preprocessor: A `keras_nlp.models.LlamaCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.inputs + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + jit_compile=True, + ) + + @classproperty + def backbone_cls(cls): + return LlamaBackbone + + @classproperty + def preprocessor_cls(cls): + return LlamaCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `LlamaCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] 
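+            # `current_cache` holds this layer's keys and values, shaped
+            # `(batch_size, 2, max_length, num_key_value_heads, head_dim)`.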
+            x, next_cache = self.backbone.transformer_layers[i](
+                x,
+                self_attention_cache=current_cache,
+                self_attention_cache_update_index=cache_update_index,
+            )
+            updated_cache.append(next_cache)
+        cache = ops.stack(updated_cache, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_layers
+        num_key_value_heads = self.backbone.num_key_value_heads
+        head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads
+        shape = [
+            batch_size,
+            num_layers,
+            2,
+            max_length,
+            num_key_value_heads,
+            head_dim,
+        ]
+        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        # Seed the cache.
+        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        stop_token_ids=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
+            stop_token_ids: Tuple of token ids to stop generation on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(token_ids)
+        # Compute the lengths of all user inputted token ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user inputted id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+                cache_update_index,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self._sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of stop token locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids,
+                ops.logical_not(padding_mask),
+            )
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_causal_lm_test.py b/keras_nlp/models/llama/llama_causal_lm_test.py new file mode 100644 index 0000000000..ff71a75b38 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_test.py @@ -0,0 +1,130 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaCausalLMTest(TestCase): + def setUp(self): + self.preprocessor = LlamaCausalLMPreprocessor( + LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "llama_test_vocab.spm" + ) + ), + sequence_length=8, + ) + self.backbone = LlamaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 10), + ) + + def test_generate(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. 
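+        # Only the first five positions are compared; they correspond to the
+        # tokenized prompt.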
+ self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaCausalLM.presets: + self.run_preset_test( + cls=LlamaCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py index 580557f50d..a24c425082 100644 --- a/keras_nlp/models/llama/llama_preprocessor.py +++ b/keras_nlp/models/llama/llama_preprocessor.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( @@ -189,3 +192,7 @@ def sequence_length(self, value): @classproperty def tokenizer_cls(cls): return LlamaTokenizer + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py index 6807886812..52a559aa2e 100644 --- a/keras_nlp/models/llama/llama_preprocessor_test.py +++ b/keras_nlp/models/llama/llama_preprocessor_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.tests.test_case import TestCase @@ -55,3 +57,12 @@ def test_errors_for_2d_list_input(self): ambiguous_input = [["one", "two"], ["three", "four"]] with self.assertRaises(ValueError): preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaPreprocessor.presets: + self.run_preset_test( + cls=LlamaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_presets.py b/keras_nlp/models/llama/llama_presets.py new file mode 100644 index 0000000000..292848a11d --- /dev/null +++ b/keras_nlp/models/llama/llama_presets.py @@ -0,0 +1,38 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "llama2_7b_en": { + "metadata": { + "description": "LLaMA 2 7B Base model", + "params": 6738415616, + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", + }, + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1", + }, + "llama2_instruct_7b_en": { + "metadata": { + "description": "LLaMA 2 7B Chat model", + "params": 6738415616, + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", + }, + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py index 7acdf8687c..07b0f21037 100644 --- a/keras_nlp/models/llama/llama_tokenizer.py +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import copy + from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.LlamaTokenizer") @@ -79,3 +83,7 @@ def set_proto(self, proto): self.start_token_id = None self.end_token_id = None self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_tokenizer_test.py b/keras_nlp/models/llama/llama_tokenizer_test.py index 9a3c225456..51687731e5 100644 --- a/keras_nlp/models/llama/llama_tokenizer_test.py +++ b/keras_nlp/models/llama/llama_tokenizer_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.tests.test_case import TestCase @@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self): self.get_test_data_dir(), "no_special_token_vocab.spm" ) ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaTokenizer, + preset="llama2_7b_en", + input_data=["The quick brown fox."], + expected_output=[[450, 4996, 17354, 1701, 29916, 29889]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaTokenizer.presets: + self.run_preset_test( + cls=LlamaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py index 5eb3973f36..4e127b2c7d 100644 --- a/tools/checkpoint_conversion/convert_llama_checkpoints.py +++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py @@ -11,131 +11,280 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import os +import shutil +import tempfile +import traceback -import torch -from transformers import AutoModel +import numpy as np +from absl import app +from absl import flags +from keras import ops +from transformers import AutoTokenizer +from transformers import LlamaForCausalLM -from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models import LlamaBackbone +from keras_nlp.models import LlamaCausalLMPreprocessor +from keras_nlp.models import LlamaTokenizer +from keras_nlp.utils.preset_utils import save_to_preset -os.environ["KERAS_BACKEND"] = "torch" +PRESET_MAP = { + "llama2_7b_en": "meta-llama/Llama-2-7b-hf", + "llama2_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf", +} -# from huggingface_hub import login -# llama weights as of now are on request access -# login(token=' Huggingface model and tokenizer loaded") - # MLP - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_intermediate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.up_proj.weight"].numpy().T - ) + # === Load the KerasNLP model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = LlamaBackbone(**backbone_kwargs) - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_gate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.gate_proj.weight"].numpy().T - ) + # === Get the tokenizer from the Huggingface model === + tokenizer_path = hf_tokenizer.vocab_file + keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_output_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.down_proj.weight"].numpy().T - ) + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") - # LAYERNORM - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._self_attention_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.input_layernorm.weight"] - ) + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.post_attention_layernorm.weight"] - ) + # === Save the model weights in float32 format === + keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + del keras_nlp_model, hf_model + gc.collect() -keras_model.get_layer("layer_norm").gamma.assign(hf_wts["norm.weight"]) + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_nlp_model = LlamaBackbone(**backbone_kwargs) + keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_nlp_model, preset) + print("\n-> Saved the model preset in float16") -token_ids = [1, 2181, 8522, 338] -padding_mask = [1, 1, 1, 1] + # === Save the tokenizer === + save_to_preset( + keras_nlp_tokenizer, preset, 
config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) -keras_inputs = { - "token_ids": torch.tensor([token_ids]), - "padding_mask": torch.tensor([padding_mask]), -} -with torch.no_grad(): - keras_outputs = keras_model(keras_inputs) -print("Keras output = ", keras_outputs.numpy()) +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From f8aba3c2ca6b384b00885eb2107c6a82ce4b1fbc Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 28 Mar 2024 12:28:03 -0700 Subject: [PATCH 60/70] Add task base classes (#1517) This PR grew as I was writing it, and now adds a number of new features: 1. Exposed base classes. Sets us on a path for better documentation, a more "introspectable" library, and allow sub-classing. 2. Enable `from_preset()` on base classes for any subclass preset. This gives us similar functionality to "auto classes" in huggingface, without the extra overhead of needing a new symbol. 3. An ability to register new tasks/backbones/tokenizers from out of tree code with `keras.saving.register_keras_serializable()`. Try a colab: https://colab.research.google.com/gist/mattdangerw/da885f050fa8baef9b4f9a4ec68d6567/kerasnlp-base-classes.ipynb --- keras_nlp/models/__init__.py | 6 + keras_nlp/models/albert/__init__.py | 7 + keras_nlp/models/albert/albert_backbone.py | 8 - keras_nlp/models/albert/albert_classifier.py | 23 +-- keras_nlp/models/albert/albert_masked_lm.py | 23 +-- .../models/albert/albert_preprocessor.py | 14 +- keras_nlp/models/albert/albert_tokenizer.py | 8 - keras_nlp/models/backbone.py | 123 ++++++++++----- keras_nlp/models/backbone_test.py | 44 ++++++ keras_nlp/models/bart/__init__.py | 9 +- keras_nlp/models/bart/bart_backbone.py | 8 - keras_nlp/models/bart/bart_preprocessor.py | 13 +- keras_nlp/models/bart/bart_seq_2_seq_lm.py | 22 +-- .../bart/bart_seq_2_seq_lm_preprocessor.py | 7 - keras_nlp/models/bart/bart_tokenizer.py | 7 - keras_nlp/models/bert/__init__.py | 10 ++ keras_nlp/models/bert/bert_backbone.py | 8 - keras_nlp/models/bert/bert_classifier.py | 24 +-- keras_nlp/models/bert/bert_masked_lm.py | 23 +-- keras_nlp/models/bert/bert_preprocessor.py | 17 +- keras_nlp/models/bert/bert_tokenizer.py | 11 -- keras_nlp/models/bloom/__init__.py | 7 + keras_nlp/models/bloom/bloom_backbone.py | 7 - keras_nlp/models/bloom/bloom_causal_lm.py | 22 +-- keras_nlp/models/bloom/bloom_preprocessor.py | 13 +- keras_nlp/models/bloom/bloom_tokenizer.py | 7 - .../{generative_task.py => causal_lm.py} | 44 +++++- keras_nlp/models/classifier.py | 51 ++++++ keras_nlp/models/deberta_v3/__init__.py | 7 + .../models/deberta_v3/deberta_v3_backbone.py | 7 - .../deberta_v3/deberta_v3_classifier.py | 22 +-- .../models/deberta_v3/deberta_v3_masked_lm.py | 22 +-- .../deberta_v3/deberta_v3_preprocessor.py | 13 +- .../models/deberta_v3/deberta_v3_tokenizer.py | 7 - keras_nlp/models/distil_bert/__init__.py | 9 ++ .../distil_bert/distil_bert_backbone.py | 7 - .../distil_bert/distil_bert_classifier.py | 22 +-- .../distil_bert/distil_bert_masked_lm.py | 22 +-- .../distil_bert/distil_bert_preprocessor.py | 13 +- .../distil_bert/distil_bert_tokenizer.py | 7 - keras_nlp/models/electra/__init__.py | 7 + keras_nlp/models/electra/electra_backbone.py | 8 - .../models/electra/electra_preprocessor.py | 14 +- keras_nlp/models/electra/electra_tokenizer.py | 8 - keras_nlp/models/f_net/__init__.py | 7 + keras_nlp/models/f_net/f_net_backbone.py | 7 - keras_nlp/models/f_net/f_net_classifier.py 
| 22 +-- keras_nlp/models/f_net/f_net_masked_lm.py | 22 +-- keras_nlp/models/f_net/f_net_preprocessor.py | 13 +- keras_nlp/models/f_net/f_net_tokenizer.py | 7 - keras_nlp/models/falcon/__init__.py | 9 +- .../models/falcon/falcon_preprocessor.py | 13 +- keras_nlp/models/falcon/falcon_tokenizer.py | 7 - keras_nlp/models/gemma/__init__.py | 7 + keras_nlp/models/gemma/gemma_backbone.py | 7 - keras_nlp/models/gemma/gemma_causal_lm.py | 22 +-- keras_nlp/models/gemma/gemma_preprocessor.py | 13 +- keras_nlp/models/gemma/gemma_tokenizer.py | 7 - keras_nlp/models/gpt2/__init__.py | 7 + keras_nlp/models/gpt2/gpt2_backbone.py | 7 - keras_nlp/models/gpt2/gpt2_causal_lm.py | 22 +-- keras_nlp/models/gpt2/gpt2_preprocessor.py | 13 +- keras_nlp/models/gpt2/gpt2_tokenizer.py | 7 - .../models/gpt_neo_x/gpt_neo_x_causal_lm.py | 16 +- .../gpt_neo_x/gpt_neo_x_preprocessor.py | 7 +- keras_nlp/models/llama/__init__.py | 7 + keras_nlp/models/llama/llama_backbone.py | 8 - keras_nlp/models/llama/llama_causal_lm.py | 11 +- keras_nlp/models/llama/llama_preprocessor.py | 14 +- keras_nlp/models/masked_lm.py | 42 +++++ keras_nlp/models/mistral/__init__.py | 7 + keras_nlp/models/mistral/mistral_backbone.py | 7 - keras_nlp/models/mistral/mistral_causal_lm.py | 22 +-- .../models/mistral/mistral_preprocessor.py | 13 +- keras_nlp/models/mistral/mistral_tokenizer.py | 7 - keras_nlp/models/opt/__init__.py | 7 + keras_nlp/models/opt/opt_backbone.py | 7 - keras_nlp/models/opt/opt_causal_lm.py | 22 +-- keras_nlp/models/opt/opt_preprocessor.py | 13 +- keras_nlp/models/opt/opt_tokenizer.py | 7 - keras_nlp/models/preprocessor.py | 119 +++++++++----- keras_nlp/models/preprocessor_test.py | 37 +++++ keras_nlp/models/roberta/__init__.py | 7 + keras_nlp/models/roberta/roberta_backbone.py | 7 - .../models/roberta/roberta_classifier.py | 22 +-- keras_nlp/models/roberta/roberta_masked_lm.py | 22 +-- .../models/roberta/roberta_preprocessor.py | 13 +- keras_nlp/models/roberta/roberta_tokenizer.py | 7 - keras_nlp/models/seq_2_seq_lm.py | 54 +++++++ keras_nlp/models/t5/__init__.py | 7 + keras_nlp/models/t5/t5_backbone.py | 7 - keras_nlp/models/t5/t5_tokenizer.py | 7 - keras_nlp/models/task.py | 148 ++++++++++++------ keras_nlp/models/task_test.py | 20 +++ keras_nlp/models/whisper/__init__.py | 7 + .../whisper_audio_feature_extractor.py | 53 ------- keras_nlp/models/whisper/whisper_backbone.py | 7 - .../models/whisper/whisper_preprocessor.py | 17 +- keras_nlp/models/whisper/whisper_tokenizer.py | 7 - keras_nlp/models/xlm_roberta/__init__.py | 9 ++ .../xlm_roberta/xlm_roberta_backbone.py | 7 - .../xlm_roberta/xlm_roberta_classifier.py | 22 +-- .../xlm_roberta/xlm_roberta_masked_lm.py | 22 +-- .../xlm_roberta/xlm_roberta_preprocessor.py | 13 +- .../xlm_roberta/xlm_roberta_tokenizer.py | 7 - keras_nlp/tests/test_case.py | 2 - keras_nlp/tokenizers/byte_pair_tokenizer.py | 68 -------- .../tokenizers/sentence_piece_tokenizer.py | 68 -------- keras_nlp/tokenizers/tokenizer.py | 79 +++++++++- keras_nlp/tokenizers/tokenizer_test.py | 26 +++ keras_nlp/tokenizers/word_piece_tokenizer.py | 68 -------- keras_nlp/utils/preset_utils.py | 47 ++++-- keras_nlp/utils/preset_utils_test.py | 10 +- keras_nlp/utils/python_utils.py | 25 --- keras_nlp/utils/python_utils_test.py | 58 ------- 115 files changed, 965 insertions(+), 1331 deletions(-) create mode 100644 keras_nlp/models/backbone_test.py rename keras_nlp/models/{generative_task.py => causal_lm.py} (87%) create mode 100644 keras_nlp/models/classifier.py create mode 100644 keras_nlp/models/masked_lm.py create 
mode 100644 keras_nlp/models/preprocessor_test.py create mode 100644 keras_nlp/models/seq_2_seq_lm.py diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 2eb1f5c8dc..4139656fbf 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -43,6 +43,8 @@ ) from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.models.causal_lm import CausalLM +from keras_nlp.models.classifier import Classifier from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( DebertaV3Classifier, @@ -114,6 +116,7 @@ ) from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( @@ -128,6 +131,7 @@ ) from keras_nlp.models.opt.opt_preprocessor import OPTPreprocessor from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.models.roberta.roberta_masked_lm import RobertaMaskedLM @@ -136,6 +140,7 @@ ) from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer from keras_nlp.models.task import Task @@ -162,3 +167,4 @@ XLMRobertaTokenizer, ) from keras_nlp.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_nlp.tokenizers.tokenizer import Tokenizer diff --git a/keras_nlp/models/albert/__init__.py b/keras_nlp/models/albert/__init__.py index ba0c2545e4..c0ae8e8fa1 100644 --- a/keras_nlp/models/albert/__init__.py +++ b/keras_nlp/models/albert/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_presets import backbone_presets +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (AlbertBackbone, AlbertTokenizer)) diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 0cc1d4d021..e8acef81d5 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder -from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.backbone import Backbone from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def albert_kernel_initializer(stddev=0.02): @@ -266,7 +262,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 32a4e0847d..7471393cc7 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.albert.albert_backbone import AlbertBackbone from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor -from keras_nlp.models.albert.albert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.classifier import Classifier @keras_nlp_export("keras_nlp.models.AlbertClassifier") -class AlbertClassifier(Task): +class AlbertClassifier(Classifier): """An end-to-end ALBERT model for classification tasks This model attaches a classification head to a `keras_nlp.model.AlbertBackbone` @@ -146,6 +142,9 @@ class AlbertClassifier(Task): ``` """ + backbone_cls = AlbertBackbone + preprocessor_cls = AlbertPreprocessor + def __init__( self, backbone, @@ -209,15 +208,3 @@ def get_config(self): ) return config - - @classproperty - def backbone_cls(cls): - return AlbertBackbone - - @classproperty - def preprocessor_cls(cls): - return AlbertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets}) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index e421ef524c..01892bdcab 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead @@ -22,14 +20,12 @@ from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( AlbertMaskedLMPreprocessor, ) -from keras_nlp.models.albert.albert_presets import backbone_presets -from keras_nlp.models.task import Task +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.AlbertMaskedLM") -class AlbertMaskedLM(Task): +class AlbertMaskedLM(MaskedLM): """An end-to-end ALBERT model for the masked language modeling task. This model will train ALBERT on a masked language modeling task. 
@@ -96,6 +92,9 @@ class AlbertMaskedLM(Task): ``` """ + backbone_cls = AlbertBackbone + preprocessor_cls = AlbertMaskedLMPreprocessor + def __init__(self, backbone, preprocessor=None, **kwargs): # === Layers === self.backbone = backbone @@ -133,15 +132,3 @@ def __init__(self, backbone, preprocessor=None, **kwargs): weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return AlbertBackbone - - @classproperty - def preprocessor_cls(cls): - return AlbertMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index 19f4bd9a7b..8e49d4e650 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.AlbertPreprocessor") @@ -149,6 +145,8 @@ class AlbertPreprocessor(Preprocessor): ``` """ + tokenizer_cls = AlbertTokenizer + def __init__( self, tokenizer, @@ -205,11 +203,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return AlbertTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index 44aed44cf5..8887893afc 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy
-
 from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.models.albert.albert_presets import backbone_presets
 from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
-from keras_nlp.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.AlbertTokenizer")
@@ -119,7 +115,3 @@ def set_proto(self, proto):
         self.sep_token_id = None
         self.pad_token_id = None
         self.mask_token_id = None
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py
index 08b9f86e96..e8ececa6f1 100644
--- a/keras_nlp/models/backbone.py
+++ b/keras_nlp/models/backbone.py
@@ -15,15 +15,51 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import config
 from keras_nlp.backend import keras
-from keras_nlp.utils.preset_utils import check_preset_class
+from keras_nlp.utils.preset_utils import check_config_class
+from keras_nlp.utils.preset_utils import list_presets
+from keras_nlp.utils.preset_utils import list_subclasses
 from keras_nlp.utils.preset_utils import load_from_preset
 from keras_nlp.utils.preset_utils import save_to_preset
 from keras_nlp.utils.python_utils import classproperty
-from keras_nlp.utils.python_utils import format_docstring
 
 
 @keras_nlp_export("keras_nlp.models.Backbone")
 class Backbone(keras.Model):
+    """Base class for all `Backbone` models.
+
+    A `Backbone` is the basic architecture for a given NLP model. Unlike a
+    `keras_nlp.models.Task`, a `Backbone` is not tailored to any specific loss
+    function and training setup. A `Backbone` generally outputs the last hidden
+    states of an architecture before any output predictions.
+
+    A `Backbone` can be used in one of two ways:
+
+    1. Through a `Task` class, which will wrap and extend a `Backbone` so it
+       can be used with high level Keras functions like `fit()`, `predict()` or
+       `evaluate()`. `Task` classes are built with a particular training
+       objective in mind (e.g. classification or language modeling).
+    2. Directly, by extending the underlying functional model with additional
+       outputs and training setup. This is the most flexible approach, and can
+       allow for any outputs, loss, or custom training loop.
+
+    All backbones include a `from_preset()` constructor which can be used to
+    load a pre-trained config and weights.
+
+    Example:
+    ```python
+    # Load a BERT backbone with pre-trained weights.
+    backbone = keras_nlp.models.Backbone.from_preset(
+        "bert_base_en",
+    )
+    # Load a GPT2 backbone with pre-trained weights at bfloat16 precision.
+    backbone = keras_nlp.models.Backbone.from_preset(
+        "gpt2_base_en",
+        dtype="bfloat16",
+        trainable=False,
+    )
+    ```
+    """
+
     def __init__(self, *args, dtype=None, **kwargs):
         super().__init__(*args, **kwargs)
         self._functional_layer_ids = set(
@@ -100,7 +136,11 @@ def from_config(cls, config):
 
     @classproperty
     def presets(cls):
-        return {}
+        """List builtin presets for a `Backbone` subclass."""
+        presets = list_presets(cls)
+        for subclass in list_subclasses(cls):
+            presets.update(subclass.presets)
+        return presets
 
     @classmethod
     def from_preset(
@@ -109,33 +149,54 @@ def from_preset(
         load_weights=True,
         **kwargs,
     ):
-        """Instantiate {{model_name}} model from preset architecture and weights.
+        """Instantiate a `keras_nlp.models.Backbone` from a model preset.
+
+        A preset is a directory of configs, weights and other file assets used
+        to save and load a pre-trained model. The `preset` can be passed as
+        one of:
+
+        1.
a built in preset identifier like `'bert_base_en'`
+        2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
+        3. a Hugging Face handle like `'hf://user/bert_base_en'`
+        4. a path to a local preset directory like `'./bert_base_en'`
+
+        This constructor can be called in one of two ways. Either from the base
+        class like `keras_nlp.models.Backbone.from_preset()`, or from
+        a model class like `keras_nlp.models.GemmaBackbone.from_preset()`.
+        If calling from the base class, the subclass of the returned object
+        will be inferred from the config in the preset directory.
+
+        For any `Backbone` subclass, you can run `cls.presets.keys()` to list
+        all built-in presets available on the class.
 
         Args:
-            preset: string. Must be one of "{{preset_names}}".
-            load_weights: Whether to load pre-trained weights into model.
-                Defaults to `True`.
+            preset: string. A built in preset identifier, a Kaggle Models
+                handle, a Hugging Face handle, or a path to a local directory.
+            load_weights: bool. If `True`, the weights will be loaded into the
+                model architecture. If `False`, the weights will be randomly
+                initialized.
 
         Examples:
         ```python
-        # Load architecture and weights from preset
-        model = keras_nlp.models.{{model_name}}.from_preset(
-            "{{example_preset_name}}"
+        # Load a Gemma backbone with pre-trained weights.
+        model = keras_nlp.models.Backbone.from_preset(
+            "gemma_2b_en",
         )
-        # Load randomly initialized model from preset architecture
-        model = keras_nlp.models.{{model_name}}.from_preset(
-            "{{example_preset_name}}",
-            load_weights=False
+        # Load a Bert backbone with a pre-trained config and random weights.
+        model = keras_nlp.models.Backbone.from_preset(
+            "bert_base_en",
+            load_weights=False,
         )
         ```
         """
-        # We support short IDs for official presets, e.g. `"bert_base_en"`.
-        # Map these to a Kaggle Models handle.
-        if preset in cls.presets:
-            preset = cls.presets[preset]["kaggle_handle"]
-
-        check_preset_class(preset, cls)
+        preset_cls = check_config_class(preset)
+        if not issubclass(preset_cls, cls):
+            raise ValueError(
+                f"Preset has type `{preset_cls.__name__}` which is not "
+                f"a subclass of the calling class `{cls.__name__}`. Call "
+                f"`from_preset` directly on `{preset_cls.__name__}` instead."
+            )
         return load_from_preset(
            preset,
            load_weights=load_weights,
@@ -150,28 +211,6 @@ def save_to_preset(self, preset):
         """
         save_to_preset(self, preset)
 
-    def __init_subclass__(cls, **kwargs):
-        # Use __init_subclass__ to setup a correct docstring for from_preset.
-        super().__init_subclass__(**kwargs)
-
-        # If the subclass does not define from_preset, assign a wrapper so that
-        # each class can have a distinct docstring.
-        if "from_preset" not in cls.__dict__:
-
-            def from_preset(calling_cls, *args, **kwargs):
-                return super(cls, calling_cls).from_preset(*args, **kwargs)
-
-            cls.from_preset = classmethod(from_preset)
-
-        # Format and assign the docstring unless the subclass has overridden it.
-        if cls.from_preset.__doc__ is None:
-            cls.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__
-            format_docstring(
-                model_name=cls.__name__,
-                example_preset_name=next(iter(cls.presets), ""),
-                preset_names='", "'.join(cls.presets),
-            )(cls.from_preset.__func__)
-
     def enable_lora(self, rank):
         """Enable Lora on the backbone.
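Taken together, the `presets` aggregation and the `check_config_class` dispatch above are what give the base classes their "auto class" behavior. Below is a minimal sketch of the out-of-tree registration flow the commit message describes, using `keras.saving.register_keras_serializable()`; the `MyBackbone` class, the `my_package` name, and the `./my_backbone_preset` path are hypothetical stand-ins, not part of this patch:

```python
import keras

from keras_nlp.models import Backbone


# Registering the subclass with Keras serialization is what lets the
# preset loader resolve the class name recorded in a preset's config.
@keras.saving.register_keras_serializable(package="my_package")
class MyBackbone(Backbone):
    ...


# A preset saved from a `MyBackbone` instance, e.g. with
# `my_backbone.save_to_preset("./my_backbone_preset")`, can then be
# restored through the base class, which infers the subclass from the
# saved config:
restored = Backbone.from_preset("./my_backbone_preset")
assert isinstance(restored, MyBackbone)
```

The `backbone_test.py` file added next exercises the same base-class dispatch against the built-in presets.
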
diff --git a/keras_nlp/models/backbone_test.py b/keras_nlp/models/backbone_test.py
new file mode 100644
index 0000000000..7e4cd24f27
--- /dev/null
+++ b/keras_nlp/models/backbone_test.py
@@ -0,0 +1,44 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.bert.bert_backbone import BertBackbone
+from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
+from keras_nlp.tests.test_case import TestCase
+
+
+class TestBackbone(TestCase):
+    def test_preset_accessors(self):
+        bert_presets = set(BertBackbone.presets.keys())
+        gpt2_presets = set(GPT2Backbone.presets.keys())
+        all_presets = set(Backbone.presets.keys())
+        self.assertContainsSubset(bert_presets, all_presets)
+        self.assertContainsSubset(gpt2_presets, all_presets)
+
+    @pytest.mark.large
+    def test_from_preset(self):
+        self.assertIsInstance(
+            Backbone.from_preset("bert_tiny_en_uncased", load_weights=False),
+            BertBackbone,
+        )
+        self.assertIsInstance(
+            Backbone.from_preset("gpt2_base_en", load_weights=False),
+            GPT2Backbone,
+        )
+
+    @pytest.mark.large
+    def test_from_preset_errors(self):
+        with self.assertRaises(ValueError):
+            GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)
diff --git a/keras_nlp/models/bart/__init__.py b/keras_nlp/models/bart/__init__.py
index 6e4df4e727..b14f1f06b6 100644
--- a/keras_nlp/models/bart/__init__.py
+++ b/keras_nlp/models/bart/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2022 The KerasNLP Authors
+# Copyright 2023 The KerasNLP Authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,3 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from keras_nlp.models.bart.bart_backbone import BartBackbone
+from keras_nlp.models.bart.bart_presets import backbone_presets
+from keras_nlp.models.bart.bart_tokenizer import BartTokenizer
+from keras_nlp.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (BartBackbone, BartTokenizer))
diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py
index f100133d25..d10f1cc240 100644
--- a/keras_nlp/models/bart/bart_backbone.py
+++ b/keras_nlp/models/bart/bart_backbone.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding @@ -21,8 +19,6 @@ from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bart.bart_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def bart_kernel_initializer(stddev=0.02): @@ -260,7 +256,3 @@ def get_config(self): ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index 3310b1e532..276d73af04 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartPreprocessor") @@ -131,6 +128,8 @@ class BartPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BartTokenizer + def __init__( self, tokenizer, @@ -274,11 +273,3 @@ def sequence_length(self): @sequence_length.setter def sequence_length(self, value): self.decoder_sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return BartTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index 2f0fa1104c..e13e74b769 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( BartSeq2SeqLMPreprocessor, ) -from keras_nlp.models.generative_task import GenerativeTask -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM") -class BartSeq2SeqLM(GenerativeTask): +class BartSeq2SeqLM(Seq2SeqLM): """An end-to-end BART model for seq2seq language modeling. 
A seq2seq language model (LM) is an encoder-decoder model which is used for @@ -180,6 +177,9 @@ class BartSeq2SeqLM(GenerativeTask): ``` """ + backbone_cls = BartBackbone + preprocessor_cls = BartSeq2SeqLMPreprocessor + def __init__( self, backbone, @@ -208,18 +208,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return BartBackbone - - @classproperty - def preprocessor_cls(cls): - return BartSeq2SeqLMPreprocessor - def call_decoder_with_cache( self, encoder_hidden_states, diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index 1c72e6e935..4d90fd87e8 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tensorflow as tf from absl import logging @@ -20,12 +19,10 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor") @@ -266,7 +263,3 @@ def generate_postprocess( decoder_token_ids, decoder_padding_mask ) return self.tokenizer.detokenize(decoder_token_ids) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py index 17fb237b88..c4e3d1204d 100644 --- a/keras_nlp/models/bart/bart_tokenizer.py +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartTokenizer") @@ -118,10 +115,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.pad_token_id = None self.end_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/bert/__init__.py b/keras_nlp/models/bert/__init__.py index ba0c2545e4..a34b85d0eb 100644 --- a/keras_nlp/models/bert/__init__.py +++ b/keras_nlp/models/bert/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.bert.bert_backbone import BertBackbone +from keras_nlp.models.bert.bert_classifier import BertClassifier +from keras_nlp.models.bert.bert_presets import backbone_presets +from keras_nlp.models.bert.bert_presets import classifier_presets +from keras_nlp.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (BertBackbone, BertTokenizer)) +register_presets(classifier_presets, (BertClassifier, BertTokenizer)) diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 320dc1c2ee..1f5a02dbb1 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bert.bert_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def bert_kernel_initializer(stddev=0.02): @@ -226,7 +222,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 09d2b8810c..27bb076ea9 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -12,21 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.models.bert.bert_backbone import bert_kernel_initializer from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.classifier import Classifier @keras_nlp_export("keras_nlp.models.BertClassifier") -class BertClassifier(Task): +class BertClassifier(Classifier): """An end-to-end BERT model for classification tasks. 
This model attaches a classification head to a @@ -131,6 +126,9 @@ class BertClassifier(Task): ``` """ + backbone_cls = BertBackbone + preprocessor_cls = BertPreprocessor + def __init__( self, backbone, @@ -193,15 +191,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return BertBackbone - - @classproperty - def preprocessor_cls(cls): - return BertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index b915a99481..1166963625 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead @@ -22,13 +20,11 @@ from keras_nlp.models.bert.bert_masked_lm_preprocessor import ( BertMaskedLMPreprocessor, ) -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.BertMaskedLM") -class BertMaskedLM(Task): +class BertMaskedLM(MaskedLM): """An end-to-end BERT model for the masked language modeling task. This model will train BERT on a masked language modeling task. @@ -95,6 +91,9 @@ class BertMaskedLM(Task): ``` """ + backbone_cls = BertBackbone + preprocessor_cls = BertMaskedLMPreprocessor + def __init__( self, backbone, @@ -139,15 +138,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return BertBackbone - - @classproperty - def preprocessor_cls(cls): - return BertMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index 02f5a45985..2975eae9c6 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -12,23 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty - -PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets)) @keras_nlp_export("keras_nlp.models.BertPreprocessor") @@ -130,6 +123,8 @@ class BertPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BertTokenizer + def __init__( self, tokenizer, @@ -186,11 +181,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return BertTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py index 4de433e43a..819bf2f63f 100644 --- a/keras_nlp/models/bert/bert_tokenizer.py +++ b/keras_nlp/models/bert/bert_tokenizer.py @@ -12,15 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_nlp.utils.python_utils import classproperty - -PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets)) @keras_nlp_export("keras_nlp.models.BertTokenizer") @@ -113,10 +106,6 @@ def set_vocabulary(self, vocabulary): self.pad_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) - def get_config(self): config = super().get_config() del config["special_tokens"] # Not configurable; set in __init__. diff --git a/keras_nlp/models/bloom/__init__.py b/keras_nlp/models/bloom/__init__.py index ba0c2545e4..2b7ad787b4 100644 --- a/keras_nlp/models/bloom/__init__.py +++ b/keras_nlp/models/bloom/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (BloomBackbone, BloomTokenizer)) diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index eb686668d8..4e153b868d 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone from keras_nlp.models.bloom.bloom_decoder import BloomDecoder -from keras_nlp.models.bloom.bloom_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def _bloom_kernel_initializer(stddev=0.02): @@ -171,7 +168,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py index 7d189d17e4..914107f101 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm.py +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,14 +20,12 @@ from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( BloomCausalLMPreprocessor, ) -from keras_nlp.models.bloom.bloom_presets import backbone_presets -from keras_nlp.models.generative_task import GenerativeTask -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.BloomCausalLM") -class BloomCausalLM(GenerativeTask): +class BloomCausalLM(CausalLM): """An end-to-end BLOOM model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -147,6 +144,9 @@ class BloomCausalLM(GenerativeTask): ``` """ + backbone_cls = BloomBackbone + preprocessor_cls = BloomCausalLMPreprocessor + def __init__( self, backbone, @@ -176,18 +176,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return BloomBackbone - - @classproperty - def preprocessor_cls(cls): - return BloomCausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index 8eb693cb50..dfaac0332d 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.bloom.bloom_presets import backbone_presets from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BloomPreprocessor") @@ -107,6 +104,8 @@ class BloomPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BloomTokenizer + def __init__( self, tokenizer, @@ -183,11 +182,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return BloomTokenizer diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index 0d7f74b163..6c6097e4ce 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bloom.bloom_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BloomTokenizer") @@ -110,10 +107,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.end_token_id = None self.pad_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/causal_lm.py similarity index 87% rename from keras_nlp/models/generative_task.py rename to keras_nlp/models/causal_lm.py index 30e11a4655..98867e9ad2 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/causal_lm.py @@ -18,6 +18,7 @@ import tensorflow as tf import tree +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.backend import ops @@ -26,9 +27,46 @@ from keras_nlp.utils.tensor_utils import tensor_to_list -@keras.saving.register_keras_serializable(package="keras_nlp") -class GenerativeTask(Task): - """Base class for Generative Task models.""" +@keras_nlp_export("keras_nlp.models.CausalLM") +class CausalLM(Task): + """Base class for generative language modeling tasks. + + `CausalLM` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `CausalLM` tasks provide an additional, high-level `generate()` function + which can be used to auto-regressively sample a model token by token with a + string in, string out signature. The `compile()` method of all `CausalLM` + classes contains an additional `sampler` argument, which can be used to pass + a `keras_nlp.samplers.Sampler` to control how the predicted distribution + will be sampled. 
+
+    When calling `fit()`, the tokenized input will be predicted token-by-token
+    with a causal mask applied, which gives both a pre-training and supervised
+    fine-tuning setup for controlling inference-time generation.
+
+    All `CausalLM` tasks include a `from_preset()` constructor which can be used
+    to load a pre-trained config and weights.
+
+    Example:
+    ```python
+    # Load a GPT2 causal LM with pre-trained weights.
+    causal_lm = keras_nlp.models.CausalLM.from_preset(
+        "gpt2_base_en",
+    )
+    causal_lm.compile(sampler="top_k")
+    causal_lm.generate("Keras is a", max_length=64)
+
+    # Load a Mistral instruction-tuned checkpoint at bfloat16 precision.
+    causal_lm = keras_nlp.models.CausalLM.from_preset(
+        "mistral_instruct_7b_en",
+        dtype="bfloat16",
+    )
+    causal_lm.compile(sampler="greedy")
+    causal_lm.generate("Keras is a", max_length=64)
+    ```
+    """
 
     def compile(
         self,
diff --git a/keras_nlp/models/classifier.py b/keras_nlp/models/classifier.py
new file mode 100644
index 0000000000..f6c6a88720
--- /dev/null
+++ b/keras_nlp/models/classifier.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.task import Task
+
+
+@keras_nlp_export("keras_nlp.models.Classifier")
+class Classifier(Task):
+    """Base class for all classification tasks.
+
+    `Classifier` tasks wrap a `keras_nlp.models.Backbone` and
+    a `keras_nlp.models.Preprocessor` to create a model that can be used for
+    sequence classification. `Classifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is a string and `y` is an integer from `[0, num_classes)`.
+
+    All `Classifier` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+    ```python
+    # Load a BERT classifier with pre-trained weights.
+    classifier = keras_nlp.models.Classifier.from_preset(
+        "bert_base_en",
+        num_classes=2,
+    )
+    # Fine-tune on IMDb movie reviews (or any dataset).
+    imdb_train, imdb_test = tfds.load(
+        "imdb_reviews",
+        split=["train", "test"],
+        as_supervised=True,
+        batch_size=16,
+    )
+    classifier.fit(imdb_train, validation_data=imdb_test)
+    # Predict two new examples.
+    classifier.predict(["What an amazing movie!", "A total waste of my time."])
+    ```
+    """
diff --git a/keras_nlp/models/deberta_v3/__init__.py b/keras_nlp/models/deberta_v3/__init__.py
index ba0c2545e4..4a5f6df5cd 100644
--- a/keras_nlp/models/deberta_v3/__init__.py
+++ b/keras_nlp/models/deberta_v3/__init__.py
@@ -11,3 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+ +from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone +from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets +from keras_nlp.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (DebertaV3Backbone, DebertaV3Tokenizer)) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index 9063b11df5..0029013098 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.deberta_v3.disentangled_attention_encoder import ( DisentangledAttentionEncoder, ) from keras_nlp.models.deberta_v3.relative_embedding import RelativeEmbedding -from keras_nlp.utils.python_utils import classproperty def deberta_kernel_initializer(stddev=0.02): @@ -208,7 +205,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index d6eea63601..e8cb7a60ed 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_backbone import ( deberta_kernel_initializer, @@ -23,13 +23,10 @@ from keras_nlp.models.deberta_v3.deberta_v3_preprocessor import ( DebertaV3Preprocessor, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Classifier") -class DebertaV3Classifier(Task): +class DebertaV3Classifier(Classifier): """An end-to-end DeBERTa model for classification tasks. 
This model attaches a classification head to a @@ -153,6 +150,9 @@ class DebertaV3Classifier(Task): ``` """ + backbone_cls = DebertaV3Backbone + preprocessor_cls = DebertaV3Preprocessor + def __init__( self, backbone, @@ -234,15 +234,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return DebertaV3Backbone - - @classproperty - def preprocessor_cls(cls): - return DebertaV3Preprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index a794c34374..7bb613b96c 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,13 +23,11 @@ from keras_nlp.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( DebertaV3MaskedLMPreprocessor, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.DebertaV3MaskedLM") -class DebertaV3MaskedLM(Task): +class DebertaV3MaskedLM(MaskedLM): """An end-to-end DeBERTaV3 model for the masked language modeling task. This model will train DeBERTaV3 on a masked language modeling task. @@ -98,6 +95,9 @@ class DebertaV3MaskedLM(Task): ``` """ + backbone_cls = DebertaV3Backbone + preprocessor_cls = DebertaV3MaskedLMPreprocessor + def __init__( self, backbone, @@ -138,15 +138,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return DebertaV3Backbone - - @classproperty - def preprocessor_cls(cls): - return DebertaV3MaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 88fa08fd70..67466ef948 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Preprocessor") @@ -147,6 +144,8 @@ class DebertaV3Preprocessor(Preprocessor): ``` """ + tokenizer_cls = DebertaV3Tokenizer + def __init__( self, tokenizer, @@ -202,11 +201,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return DebertaV3Tokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py index e66c373e65..9c14e3b618 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Tokenizer") @@ -151,7 +148,3 @@ def token_to_id(self, token): def detokenize(self, ids): ids = tf.ragged.boolean_mask(ids, tf.not_equal(ids, self.mask_token_id)) return super().detokenize(ids) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/__init__.py b/keras_nlp/models/distil_bert/__init__.py index ba0c2545e4..beb3742c9e 100644 --- a/keras_nlp/models/distil_bert/__init__.py +++ b/keras_nlp/models/distil_bert/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone +from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets +from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (DistilBertBackbone, DistilBertTokenizer)) diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index 73634b4216..4374707f13 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def distilbert_kernel_initializer(stddev=0.02): @@ -187,7 +184,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index e82aaf2781..f816e40a1b 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone from keras_nlp.models.distil_bert.distil_bert_backbone import ( distilbert_kernel_initializer, @@ -23,13 +23,10 @@ from keras_nlp.models.distil_bert.distil_bert_preprocessor import ( DistilBertPreprocessor, ) -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DistilBertClassifier") -class DistilBertClassifier(Task): +class DistilBertClassifier(Classifier): """An end-to-end DistilBERT model for classification tasks. This model attaches a classification head to a @@ -140,6 +137,9 @@ class DistilBertClassifier(Task): ``` """ + backbone_cls = DistilBertBackbone + preprocessor_cls = DistilBertPreprocessor + def __init__( self, backbone, @@ -214,15 +214,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return DistilBertBackbone - - @classproperty - def preprocessor_cls(cls): - return DistilBertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index d99234a04f..80b4c17bb0 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,13 +23,11 @@ from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( DistilBertMaskedLMPreprocessor, ) -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.DistilBertMaskedLM") -class DistilBertMaskedLM(Task): +class DistilBertMaskedLM(MaskedLM): """An end-to-end DistilBERT model for the masked language modeling task. This model will train DistilBERT on a masked language modeling task. 
@@ -98,6 +95,9 @@ class DistilBertMaskedLM(Task): ``` """ + backbone_cls = DistilBertBackbone + preprocessor_cls = DistilBertMaskedLMPreprocessor + def __init__( self, backbone, @@ -140,15 +140,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return DistilBertBackbone - - @classproperty - def preprocessor_cls(cls): - return DistilBertMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py index 63f4e3637b..1163f7f029 100644 --- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py +++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( DistilBertTokenizer, ) @@ -27,7 +25,6 @@ convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DistilBertPreprocessor") @@ -118,6 +115,8 @@ class DistilBertPreprocessor(Preprocessor): ``` """ + tokenizer_cls = DistilBertTokenizer + def __init__( self, tokenizer, @@ -173,11 +172,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return DistilBertTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py index 29eb92e3ba..c03792f761 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DistilBertTokenizer") @@ -111,10 +108,6 @@ def set_vocabulary(self, vocabulary): self.pad_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() del config["special_tokens"] # Not configurable; set in __init__. diff --git a/keras_nlp/models/electra/__init__.py b/keras_nlp/models/electra/__init__.py index ba0c2545e4..0717b97a72 100644 --- a/keras_nlp/models/electra/__init__.py +++ b/keras_nlp/models/electra/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.electra.electra_backbone import ElectraBackbone +from keras_nlp.models.electra.electra_presets import backbone_presets +from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (ElectraBackbone, ElectraTokenizer)) diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 7ecca892d9..3ee88de826 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.electra.electra_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def electra_kernel_initializer(stddev=0.02): @@ -246,7 +242,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/electra/electra_preprocessor.py b/keras_nlp/models/electra/electra_preprocessor.py index 1e3ac2454c..2ee3e294d8 100644 --- a/keras_nlp/models/electra/electra_preprocessor.py +++ b/keras_nlp/models/electra/electra_preprocessor.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.electra.electra_presets import backbone_presets from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.ElectraPreprocessor") @@ -116,6 +112,8 @@ class ElectraPreprocessor(Preprocessor): ``` """ + tokenizer_cls = ElectraTokenizer + def __init__( self, tokenizer, @@ -153,11 +151,3 @@ def call(self, x, y=None, sample_weight=None): "padding_mask": token_ids != self.tokenizer.pad_token_id, } return pack_x_y_sample_weight(x, y, sample_weight) - - @classproperty - def tokenizer_cls(cls): - return ElectraTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets}) diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py index 12f5ecfec6..583b756165 100644 --- a/keras_nlp/models/electra/electra_tokenizer.py +++ b/keras_nlp/models/electra/electra_tokenizer.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.electra.electra_presets import backbone_presets from keras_nlp.tokenizers import WordPieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.ElectraTokenizer") @@ -106,7 +102,3 @@ def get_config(self): config = super().get_config() del config["special_tokens"] # Not configurable; set in __init__. return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/__init__.py b/keras_nlp/models/f_net/__init__.py index ba0c2545e4..7921ed6ca3 100644 --- a/keras_nlp/models/f_net/__init__.py +++ b/keras_nlp/models/f_net/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_presets import backbone_presets +from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (FNetBackbone, FNetTokenizer)) diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index ab056c84c7..91034a4359 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -20,9 +19,7 @@ from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def f_net_kernel_initializer(stddev=0.02): @@ -234,7 +231,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index 512182d2cd..a5c0bf6525 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor -from keras_nlp.models.f_net.f_net_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetClassifier") -class FNetClassifier(Task): +class FNetClassifier(Classifier): """An end-to-end f_net model for classification tasks. 
This model attaches a classification head to a @@ -100,6 +97,9 @@ class FNetClassifier(Task): ``` """ + backbone_cls = FNetBackbone + preprocessor_cls = FNetPreprocessor + def __init__( self, backbone, @@ -162,15 +162,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return FNetBackbone - - @classproperty - def preprocessor_cls(cls): - return FNetPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index 4a0ec5e254..83c7e62719 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -22,13 +21,11 @@ from keras_nlp.models.f_net.f_net_masked_lm_preprocessor import ( FNetMaskedLMPreprocessor, ) -from keras_nlp.models.f_net.f_net_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.FNetMaskedLM") -class FNetMaskedLM(Task): +class FNetMaskedLM(MaskedLM): """An end-to-end FNet model for the masked language modeling task. This model will train FNet on a masked language modeling task. @@ -95,6 +92,9 @@ class FNetMaskedLM(Task): ``` """ + backbone_cls = FNetBackbone + preprocessor_cls = FNetMaskedLMPreprocessor + def __init__( self, backbone, @@ -137,15 +137,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return FNetBackbone - - @classproperty - def preprocessor_cls(cls): - return FNetMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index b4cb5836bb..dfe1b71cfe 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetPreprocessor") @@ -120,6 +117,8 @@ class FNetPreprocessor(Preprocessor): ``` """ + tokenizer_cls = FNetTokenizer + def __init__( self, tokenizer, @@ -175,11 +174,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return FNetTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_tokenizer.py b/keras_nlp/models/f_net/f_net_tokenizer.py index ae3f569b1d..12a055b16c 100644 --- a/keras_nlp/models/f_net/f_net_tokenizer.py +++ b/keras_nlp/models/f_net/f_net_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetTokenizer") @@ -94,7 +91,3 @@ def set_proto(self, proto): self.sep_token_id = None self.pad_token_id = None self.mask_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/falcon/__init__.py b/keras_nlp/models/falcon/__init__.py index 3364a6bd16..cfc0b821cb 100644 --- a/keras_nlp/models/falcon/__init__.py +++ b/keras_nlp/models/falcon/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 The KerasNLP Authors +# Copyright 2023 The KerasNLP Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.models.falcon.falcon_presets import backbone_presets +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (FalconBackbone, FalconTokenizer)) diff --git a/keras_nlp/models/falcon/falcon_preprocessor.py b/keras_nlp/models/falcon/falcon_preprocessor.py index b37d641467..8a14f3c255 100644 --- a/keras_nlp/models/falcon/falcon_preprocessor.py +++ b/keras_nlp/models/falcon/falcon_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.falcon.falcon_presets import backbone_presets from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FalconPreprocessor") @@ -109,6 +106,8 @@ class FalconPreprocessor(Preprocessor): ``` """ + tokenizer_cls = FalconTokenizer + def __init__( self, tokenizer, @@ -185,11 +184,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return FalconTokenizer diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py index 3201d27a63..80d7334fe7 100644 --- a/keras_nlp/models/falcon/falcon_tokenizer.py +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.falcon.falcon_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FalconTokenizer") @@ -104,10 +101,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.start_token_id = None self.pad_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/gemma/__init__.py b/keras_nlp/models/gemma/__init__.py index ba0c2545e4..b390926a21 100644 --- a/keras_nlp/models/gemma/__init__.py +++ b/keras_nlp/models/gemma/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (GemmaBackbone, GemmaTokenizer)) diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py index 8e4bac126a..a7973f9dec 100644 --- a/keras_nlp/models/gemma/gemma_backbone.py +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config @@ -21,9 +20,7 @@ from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone from keras_nlp.models.gemma.gemma_decoder_block import GemmaDecoderBlock -from keras_nlp.models.gemma.gemma_presets import backbone_presets from keras_nlp.models.gemma.rms_normalization import RMSNormalization -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GemmaBackbone") @@ -189,10 +186,6 @@ def get_config(self): ) return config - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - @staticmethod def get_layout_map( device_mesh, diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py index 30d0171844..34b0a43126 100644 --- a/keras_nlp/models/gemma/gemma_causal_lm.py +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( GemmaCausalLMPreprocessor, ) -from keras_nlp.models.gemma.gemma_presets import backbone_presets -from keras_nlp.models.generative_task import GenerativeTask -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GemmaCausalLM") -class GemmaCausalLM(GenerativeTask): +class GemmaCausalLM(CausalLM): """An end-to-end Gemma model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -148,6 +145,9 @@ class GemmaCausalLM(GenerativeTask): ``` """ + backbone_cls = GemmaBackbone + preprocessor_cls = GemmaCausalLMPreprocessor + def __init__( self, backbone, @@ -177,18 +177,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return GemmaBackbone - - @classproperty - def preprocessor_cls(cls): - return GemmaCausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py index 8fc3beb48c..86db9a4e81 100644 --- a/keras_nlp/models/gemma/gemma_preprocessor.py +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.gemma.gemma_presets import backbone_presets from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GemmaPreprocessor") @@ -124,6 +121,8 @@ class GemmaPreprocessor(Preprocessor): ``` """ + tokenizer_cls = GemmaTokenizer + def __init__( self, tokenizer, @@ -189,11 +188,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return GemmaTokenizer diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py index 6a4bb76ea0..7722d35f35 100644 --- a/keras_nlp/models/gemma/gemma_tokenizer.py +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.gemma.gemma_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GemmaTokenizer") @@ -102,7 +99,3 @@ def set_proto(self, proto): self.start_token_id = None self.end_token_id = None self.pad_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gpt2/__init__.py b/keras_nlp/models/gpt2/__init__.py index ba0c2545e4..ad86022a86 100644 --- a/keras_nlp/models/gpt2/__init__.py +++ b/keras_nlp/models/gpt2/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.models.gpt2.gpt2_presets import backbone_presets +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (GPT2Backbone, GPT2Tokenizer)) diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index b7d2b10acf..49929e4091 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -20,9 +19,7 @@ from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def _gpt_2_kernel_initializer(stddev=0.02): @@ -197,7 +194,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 18e6ead7a2..41728f7433 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( GPT2CausalLMPreprocessor, ) -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GPT2CausalLM") -class GPT2CausalLM(GenerativeTask): +class GPT2CausalLM(CausalLM): """An end-to-end GPT2 model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -150,6 +147,9 @@ class GPT2CausalLM(GenerativeTask): ``` """ + backbone_cls = GPT2Backbone + preprocessor_cls = GPT2CausalLMPreprocessor + def __init__( self, backbone, @@ -178,18 +178,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return GPT2Backbone - - @classproperty - def preprocessor_cls(cls): - return GPT2CausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 82be34776f..4641a32020 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GPT2Preprocessor") @@ -109,6 +106,8 @@ class GPT2Preprocessor(Preprocessor): ``` """ + tokenizer_cls = GPT2Tokenizer + def __init__( self, tokenizer, @@ -185,11 +184,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return GPT2Tokenizer diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 15b35bed87..4a585c3176 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GPT2Tokenizer") @@ -104,10 +101,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.start_token_id = None self.pad_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py index 797f3e2eb3..119e51cc75 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py @@ -15,17 +15,16 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone from keras_nlp.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import ( GPTNeoXCausalLMPreprocessor, ) -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GPTNeoXCausalLM") -class GPTNeoXCausalLM(GenerativeTask): +class GPTNeoXCausalLM(CausalLM): """An end-to-end GPTNeoX model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -47,6 +46,9 @@ class GPTNeoXCausalLM(GenerativeTask): should be preprocessed before calling the model. 
""" + backbone_cls = GPTNeoXBackbone + preprocessor_cls = GPTNeoXCausalLMPreprocessor + def __init__( self, backbone, @@ -75,14 +77,6 @@ def __init__( jit_compile=True, ) - @classproperty - def backbone_cls(cls): - return GPTNeoXBackbone - - @classproperty - def preprocessor_cls(cls): - return GPTNeoXCausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py index 8dc374332b..d5406a6a3d 100644 --- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py +++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py @@ -20,7 +20,6 @@ convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GPTNeoXPreprocessor") @@ -65,6 +64,8 @@ class GPTNeoXPreprocessor(Preprocessor): the layer. """ + tokenizer_cls = GPTNeoXTokenizer + def __init__( self, tokenizer, @@ -141,7 +142,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return GPTNeoXTokenizer diff --git a/keras_nlp/models/llama/__init__.py b/keras_nlp/models/llama/__init__.py index ba0c2545e4..3a57fccd42 100644 --- a/keras_nlp/models/llama/__init__.py +++ b/keras_nlp/models/llama/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (LlamaBackbone, LlamaTokenizer)) diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index ec35989e01..e586fa97f5 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops @@ -20,8 +18,6 @@ from keras_nlp.models.backbone import Backbone from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm -from keras_nlp.models.llama.llama_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def _llama_kernel_initializer(stddev=0.02): @@ -187,7 +183,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py index 7527766f01..7f17645618 100644 --- a/keras_nlp/models/llama/llama_causal_lm.py +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -11,23 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.llama.llama_backbone import LlamaBackbone from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( LlamaCausalLMPreprocessor, ) -from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.LlamaCausalLM") -class LlamaCausalLM(GenerativeTask): +class LlamaCausalLM(CausalLM): """An end-to-end Llama model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -215,7 +212,3 @@ def next(prompt, cache, index): "token_ids": token_ids, "padding_mask": padding_mask, } - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py index a24c425082..f3aaa208a8 100644 --- a/keras_nlp/models/llama/llama_preprocessor.py +++ b/keras_nlp/models/llama/llama_preprocessor.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.LlamaPreprocessor") @@ -113,6 +109,8 @@ class LlamaPreprocessor(Preprocessor): ``` """ + tokenizer_cls = LlamaTokenizer + def __init__( self, tokenizer, @@ -188,11 +186,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return LlamaTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/masked_lm.py b/keras_nlp/models/masked_lm.py new file mode 100644 index 0000000000..136dbf0b8e --- /dev/null +++ b/keras_nlp/models/masked_lm.py @@ -0,0 +1,42 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.task import Task + + +@keras_nlp_export("keras_nlp.models.MaskedLM") +class MaskedLM(Task): + """Base class for masked language modeling tasks. 
+
+    `MaskedLM` tasks wrap a `keras_nlp.models.Backbone` and
+    a `keras_nlp.models.Preprocessor` to create a model that can be used for
+    unsupervised fine-tuning with a masked language modeling loss.
+
+    When calling `fit()`, all input will be tokenized, and random tokens in
+    the input sequence will be masked. The positions of these masked tokens
+    will be fed as an additional model input, and the original values of the
+    masked tokens will be predicted by the model outputs.
+
+    All `MaskedLM` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+    ```python
+    # Load a Bert MaskedLM with pre-trained weights.
+    masked_lm = keras_nlp.models.MaskedLM.from_preset(
+        "bert_base_en",
+    )
+    masked_lm.fit(train_ds)
+    ```
+    """
diff --git a/keras_nlp/models/mistral/__init__.py b/keras_nlp/models/mistral/__init__.py
index ba0c2545e4..2593ab0201 100644
--- a/keras_nlp/models/mistral/__init__.py
+++ b/keras_nlp/models/mistral/__init__.py
@@ -11,3 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
+from keras_nlp.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (MistralBackbone, MistralTokenizer))
diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py
index 52de945760..1ee2dfce66 100644
--- a/keras_nlp/models/mistral/mistral_backbone.py
+++ b/keras_nlp/models/mistral/mistral_backbone.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
@@ -21,11 +20,9 @@
 from keras_nlp.models.mistral.mistral_layer_norm import (
     MistralLayerNormalization,
 )
-from keras_nlp.models.mistral.mistral_presets import backbone_presets
 from keras_nlp.models.mistral.mistral_transformer_decoder import (
     MistralTransformerDecoder,
 )
-from keras_nlp.utils.python_utils import classproperty


 def _mistral_kernel_initializer(stddev=0.02):
@@ -201,7 +198,3 @@ def get_config(self):
             }
         )
         return config
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py
index 20e19e6c31..754c07d2a5 100644
--- a/keras_nlp/models/mistral/mistral_causal_lm.py
+++ b/keras_nlp/models/mistral/mistral_causal_lm.py
@@ -11,23 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( MistralCausalLMPreprocessor, ) -from keras_nlp.models.mistral.mistral_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.MistralCausalLM") -class MistralCausalLM(GenerativeTask): +class MistralCausalLM(CausalLM): """An end-to-end Mistral model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -49,6 +46,9 @@ class MistralCausalLM(GenerativeTask): should be preprocessed before calling the model. """ + backbone_cls = MistralBackbone + preprocessor_cls = MistralCausalLMPreprocessor + def __init__(self, backbone, preprocessor=None, **kwargs): # === Layers === self.backbone = backbone @@ -72,14 +72,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs): jit_compile=True, ) - @classproperty - def backbone_cls(cls): - return MistralBackbone - - @classproperty - def preprocessor_cls(cls): - return MistralCausalLMPreprocessor - def call_with_cache( self, token_ids, @@ -341,7 +333,3 @@ def default_layer_intercept_fn(x, unused_i): ) per_token_loss = per_token_loss_fn(target_ids, logits) return per_token_loss - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index 38dc6da5b6..3df849fd06 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -11,18 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.MistralPreprocessor") @@ -113,6 +110,8 @@ class MistralPreprocessor(Preprocessor): ``` """ + tokenizer_cls = MistralTokenizer + def __init__( self, tokenizer, @@ -188,11 +187,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return MistralTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index c7d27684f8..e91a20df88 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.MistralTokenizer") @@ -81,7 +78,3 @@ def set_proto(self, proto): else: self.start_token_id = None self.end_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/opt/__init__.py b/keras_nlp/models/opt/__init__.py index ba0c2545e4..a81c5f3bc1 100644 --- a/keras_nlp/models/opt/__init__.py +++ b/keras_nlp/models/opt/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.opt.opt_backbone import OPTBackbone +from keras_nlp.models.opt.opt_presets import backbone_presets +from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (OPTBackbone, OPTTokenizer)) diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index 16fe4a0218..16bd89c4d5 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.opt.opt_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def opt_kernel_initializer(stddev=0.02): @@ -169,7 +166,3 @@ def get_config(self): "dropout": self.dropout, "max_sequence_length": self.max_sequence_length, } - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 6133ef227e..1bb5bd1e87 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm_preprocessor import ( OPTCausalLMPreprocessor, ) -from keras_nlp.models.opt.opt_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.OPTCausalLM") -class OPTCausalLM(GenerativeTask): +class OPTCausalLM(CausalLM): """An end-to-end OPT model for causal language modeling. 
A causal language model (LM) predicts the next token based on previous @@ -150,6 +147,9 @@ class OPTCausalLM(GenerativeTask): ``` """ + backbone_cls = OPTBackbone + preprocessor_cls = OPTCausalLMPreprocessor + def __init__( self, backbone, @@ -178,18 +178,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return OPTBackbone - - @classproperty - def preprocessor_cls(cls): - return OPTCausalLMPreprocessor - def call_with_cache( self, token_ids, diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index 8f52bb67e6..9eae445aec 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.opt.opt_presets import backbone_presets from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.OPTPreprocessor") @@ -109,6 +106,8 @@ class OPTPreprocessor(Preprocessor): ``` """ + tokenizer_cls = OPTTokenizer + def __init__( self, tokenizer, @@ -186,11 +185,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return OPTTokenizer diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py index 4fb62ee73a..addcd0c01f 100644 --- a/keras_nlp/models/opt/opt_tokenizer.py +++ b/keras_nlp/models/opt/opt_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.opt.opt_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.OPTTokenizer") @@ -110,10 +107,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.pad_token_id = None self.end_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 031a884e1b..a91bc6fb26 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -12,19 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Preprocessor") class Preprocessor(PreprocessingLayer): - """Base class for model preprocessors.""" + """Base class for preprocessing layers. + + A `Preprocessor` layer wraps a `keras_nlp.tokenizer.Tokenizer` to provide a + complete preprocessing setup for a given task. For example a masked language + modeling preprocessor will take in raw input strings, and output + `(x, y, sample_weight)` tuples. Where `x` contains token id sequences with + some + + This class can be subclassed to implement + """ + + tokenizer_cls = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -56,13 +69,15 @@ def from_config(cls, config): config["tokenizer"] = keras.layers.deserialize(config["tokenizer"]) return cls(**config) - @classproperty - def tokenizer_cls(cls): - return None - @classproperty def presets(cls): - return {} + presets = list_presets(cls) + # We can also load backbone presets. + if cls.tokenizer_cls is not None: + presets.update(cls.tokenizer_cls.presets) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets @classmethod def from_preset( @@ -70,50 +85,68 @@ def from_preset( preset, **kwargs, ): - """Instantiate {{preprocessor_name}} from preset architecture. + """Instantiate a `keras_nlp.models.Preprocessor` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as a + one of: + + 1. a built in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + For any `Preprocessor` subclass, you can run `cls.presets.keys()` to + list all built-in presets available on the class. + + As there are usually multiple preprocessing classes for a given model, + this method should be called on a specific subclass like + `keras_nlp.models.BertPreprocessor.from_preset()`. Args: - preset: string. Must be one of "{{preset_names}}". + preset: string. A built in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. - Example: + Examples: ```python - # Load a preprocessor layer from a preset. - preprocessor = keras_nlp.models.{{preprocessor_name}}.from_preset( - "{{example_preset_name}}", + # Load a preprocessor for Gemma generation. + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( + "gemma_2b_en", + ) + + # Load a preprocessor for Bert classification. + preprocessor = keras_nlp.models.BertPreprocessor.from_preset( + "bert_base_en", ) ``` """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. 
- if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - + if cls == Preprocessor: + raise ValueError( + "Do not call `Preprocessor.from_preset()` directly. Instead call a " + "choose a particular task class, e.g. " + "`keras_nlp.models.BertPreprocessor.from_preset()`." + ) config_file = "tokenizer.json" - check_preset_class(preset, cls.tokenizer_cls, config_file=config_file) + preset_cls = check_config_class(preset, config_file=config_file) + subclasses = list_subclasses(cls) + subclasses = tuple( + filter(lambda x: x.tokenizer_cls == preset_cls, subclasses) + ) + if len(subclasses) == 0: + raise ValueError( + f"No registered subclass of `{cls.__name__}` can load " + f"a `{preset_cls.__name__}`." + ) + if len(subclasses) > 1: + names = ", ".join(f"`{x.__name__}`" for x in subclasses) + raise ValueError( + f"Ambiguous call to `{cls.__name__}.from_preset()`. " + f"Found multiple possible subclasses {names}. " + "Please call `from_preset` on a subclass directly." + ) + cls = subclasses[0] tokenizer = load_from_preset( preset, config_file=config_file, ) return cls(tokenizer=tokenizer, **kwargs) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = Preprocessor.from_preset.__doc__ - format_docstring( - preprocessor_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/models/preprocessor_test.py b/keras_nlp/models/preprocessor_test.py new file mode 100644 index 0000000000..ea78720470 --- /dev/null +++ b/keras_nlp/models/preprocessor_test.py @@ -0,0 +1,37 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.tests.test_case import TestCase + + +class TestTask(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertPreprocessor.presets.keys()) + gpt2_presets = set(GPT2Preprocessor.presets.keys()) + all_presets = set(Preprocessor.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + # No loading on a preprocessor directly (it is ambiguous). 
+ Preprocessor.from_preset("bert_tiny_en_uncased", load_weights=False) + with self.assertRaises(ValueError): + # No loading on an incorrect class. + BertPreprocessor.from_preset("gpt2_base_en", load_weights=False) diff --git a/keras_nlp/models/roberta/__init__.py b/keras_nlp/models/roberta/__init__.py index ba0c2545e4..1fe69366c7 100644 --- a/keras_nlp/models/roberta/__init__.py +++ b/keras_nlp/models/roberta/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone +from keras_nlp.models.roberta.roberta_presets import backbone_presets +from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (RobertaBackbone, RobertaTokenizer)) diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 09fe753762..6c25f2ae6f 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def roberta_kernel_initializer(stddev=0.02): @@ -184,7 +181,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 887bc657d4..57f50f4e94 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaClassifier") -class RobertaClassifier(Task): +class RobertaClassifier(Classifier): """An end-to-end RoBERTa model for classification tasks. 
This model attaches a classification head to a @@ -134,6 +131,9 @@ class RobertaClassifier(Task): ``` """ + backbone_cls = RobertaBackbone + preprocessor_cls = RobertaPreprocessor + def __init__( self, backbone, @@ -213,15 +213,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return RobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return RobertaPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index bf96189860..ef6660f777 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer from keras_nlp.models.roberta.roberta_masked_lm_preprocessor import ( RobertaMaskedLMPreprocessor, ) -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaMaskedLM") -class RobertaMaskedLM(Task): +class RobertaMaskedLM(MaskedLM): """An end-to-end RoBERTa model for the masked language modeling task. This model will train RoBERTa on a masked language modeling task. @@ -97,6 +94,9 @@ class RobertaMaskedLM(Task): ``` """ + backbone_cls = RobertaBackbone + preprocessor_cls = RobertaMaskedLMPreprocessor + def __init__( self, backbone, @@ -139,15 +139,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return RobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return RobertaMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 57a421590f..f0e027a910 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) from keras_nlp.models.preprocessor import Preprocessor -from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaPreprocessor") @@ -133,6 +130,8 @@ class RobertaPreprocessor(Preprocessor): ``` """ + tokenizer_cls = RobertaTokenizer + def __init__( self, tokenizer, @@ -190,11 +189,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return RobertaTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py index 0cfabff754..acf7f0aef9 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/models/roberta/roberta_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaTokenizer") @@ -126,10 +123,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.end_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/seq_2_seq_lm.py b/keras_nlp/models/seq_2_seq_lm.py new file mode 100644 index 0000000000..7c2ee9caa7 --- /dev/null +++ b/keras_nlp/models/seq_2_seq_lm.py @@ -0,0 +1,54 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.causal_lm import CausalLM + + +@keras_nlp_export("keras_nlp.models.Seq2SeqLM") +class Seq2SeqLM(CausalLM): + """Base class for sequence to sequence language modeling tasks. + + `Seq2SeqLM` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning, when generation is conditioned on + additional input sequence in a sequence-to-sequence setting. + + `Seq2SeqLM` tasks provide an additional, high-level `generate()` function + which can be used to auto-regressively sample an output sequence token by + token. 
The `compile()` method of `Seq2SeqLM` classes contains an additional + `sampler` argument, which can be used to pass a `keras_nlp.samplers.Sampler` + to control how the predicted distribution will be sampled. + + When calling `fit()`, each input should contain an input and output + sequence. The model will be trained to predict the output sequence + token-by-token using a causal mask, similar to a `keras_nlp.models.CausalLM` + task. Unlike the `CausalLM` task, an input sequence must be passed, and + can be attended to in full by all tokens in the output sequence. + + All `Seq2SeqLM` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + ```python + # Load a Bart backbone with pre-trained weights. + seq_2_seq_lm = keras_nlp.models.Seq2SeqLM.from_preset( + "bart_base_en", + ) + seq_2_seq_lm.compile(sampler="top_k") + # Generate conditioned on the `"The quick brown fox."` as an input sequence. + seq_2_seq_lm.generate("The quick brown fox.", max_length=30) + ``` + """ + + # TODO: fill in during https://github.com/keras-team/keras-nlp/pull/1425 diff --git a/keras_nlp/models/t5/__init__.py b/keras_nlp/models/t5/__init__.py index ba0c2545e4..10064e442b 100644 --- a/keras_nlp/models/t5/__init__.py +++ b/keras_nlp/models/t5/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.t5.t5_backbone import T5Backbone +from keras_nlp.models.t5.t5_presets import backbone_presets +from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (T5Backbone, T5Tokenizer)) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index 862c4766f4..67d3b2a6a9 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm -from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.T5Backbone") @@ -259,7 +256,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/t5/t5_tokenizer.py b/keras_nlp/models/t5/t5_tokenizer.py index b5dee49b85..ec5f0bf324 100644 --- a/keras_nlp/models/t5/t5_tokenizer.py +++ b/keras_nlp/models/t5/t5_tokenizer.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.T5Tokenizer") @@ -99,7 +96,3 @@ def set_proto(self, proto): self.end_token_id = None self.pad_token_id = None self.start_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 9957f6546f..c0158b24d2 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -19,17 +19,38 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras +from keras_nlp.models.backbone import Backbone from keras_nlp.utils.keras_utils import print_msg from keras_nlp.utils.pipeline_model import PipelineModel -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring @keras_nlp_export("keras_nlp.models.Task") class Task(PipelineModel): - """Base class for Task models.""" + """Base class for all Task models. + + A `Task` wraps a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be directly + used for training, fine-tuning, and prediction for a given text problem. + + All `Task` models have `backbone` and `preprocessor` properties. By + default `fit()`, `predict()` and `evaluate()` will preprocess all inputs + automatically. To preprocess inputs separately or with a custom function, + you can set `task.preprocessor = None`, which disable any automatic + preprocessing on inputs. + + All `Task` classes include a `from_preset()` constructor which can be used + to load a pre-trained config and weights. Calling `from_preset()` on a task + will automatically instantiate a `keras_nlp.models.Backbone` and + `keras_nlp.models.Preprocessor`. + """ + + backbone_cls = None + preprocessor_cls = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -173,17 +194,16 @@ def from_config(cls, config): ) return cls(**config) - @classproperty - def backbone_cls(cls): - return None - - @classproperty - def preprocessor_cls(cls): - return None - @classproperty def presets(cls): - return {} + """List builtin presets for a `Task` subclass.""" + presets = list_presets(cls) + # We can also load backbone presets. + if cls.backbone_cls is not None: + presets.update(cls.backbone_cls.presets) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets @classmethod def from_preset( @@ -192,25 +212,54 @@ def from_preset( load_weights=True, **kwargs, ): - """Instantiate {{model_task_name}} model from preset architecture and weights. + """Instantiate a `keras_nlp.models.Task` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as a + one of: + + 1. a built in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. 
a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + For any `Task` subclass, you can run `cls.presets.keys()` to list all + built-in presets available on the class. + + This constructor can be called in one of two ways. Either from a task + specific base class like `keras_nlp.models.CausalLM.from_preset()`, or + from a model class like `keras_nlp.models.BertClassifier.from_preset()`. + If calling from the a base class, the subclass of the returning object + will be inferred from the config in the preset directory. Args: - preset: string. Must be one of "{{preset_names}}". - load_weights: Whether to load pre-trained weights into model. - Defaults to `True`. + preset: string. A built in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. Examples: ```python - # Load architecture and weights from preset - model = {{model_task_name}}.from_preset("{{example_preset_name}}") + # Load a Gemma generative task. + causal_lm = keras_nlp.models.CausalLM.from_preset( + "gemma_2b_en", + ) - # Load randomly initialized model from preset architecture - model = {{model_task_name}}.from_preset( - "{{example_preset_name}}", - load_weights=False + # Load a Bert classification task. + model = keras_nlp.models.Classifier.from_preset( + "bert_base_en", + num_classes=2, ) ``` """ + if cls == Task: + raise ValueError( + "Do not call `Task.from_preset()` directly. Instead call a " + "particular task class, e.g. " + "`keras_nlp.models.Classifier.from_preset()` or " + "`keras_nlp.models.BertClassifier.from_preset()`." + ) if "backbone" in kwargs: raise ValueError( "You cannot pass a `backbone` argument to the `from_preset` " @@ -218,15 +267,28 @@ def from_preset( "constructor with a `backbone` argument. " f"Received: backbone={kwargs['backbone']}." ) - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - preset_cls = check_preset_class(preset, (cls, cls.backbone_cls)) + preset_cls = check_config_class(preset) # Backbone case. - if preset_cls == cls.backbone_cls: + if issubclass(preset_cls, Backbone): + if preset_cls is not cls.backbone_cls: + subclasses = list_subclasses(cls) + subclasses = tuple( + filter(lambda x: x.backbone_cls == preset_cls, subclasses) + ) + if len(subclasses) == 0: + raise ValueError( + f"No registered subclass of `{cls.__name__}` can load " + f"a `{preset_cls.__name__}`." + ) + if len(subclasses) > 1: + names = ", ".join(f"`{x.__name__}`" for x in subclasses) + raise ValueError( + f"Ambiguous call to `{cls.__name__}.from_preset()`. " + f"Found multiple possible subclasses {names}. " + "Please call `from_preset` on a subclass directly." + ) + cls = subclasses[0] # Forward dtype to the backbone. config_overrides = {} if "dtype" in kwargs: @@ -247,34 +309,18 @@ def from_preset( return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) # Task case. + if not issubclass(preset_cls, cls): + raise ValueError( + f"Preset has type `{preset_cls.__name__}` which is not a " + f"a subclass of calling class `{cls.__name__}`. Call " + f"`from_preset` directly on `{preset_cls.__name__}` instead." 
+ ) return load_from_preset( preset, load_weights=load_weights, config_overrides=kwargs, ) - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define `from_preset`, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = Task.from_preset.__doc__ - format_docstring( - model_task_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) - @property def layers(self): # Remove preprocessor from layers so it does not show up in the summary. diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py index bf82e4fa68..63fd189c12 100644 --- a/keras_nlp/models/task_test.py +++ b/keras_nlp/models/task_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from keras_nlp.backend import keras +from keras_nlp.models.bert.bert_classifier import BertClassifier +from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.models.task import Task from keras_nlp.tests.test_case import TestCase @@ -40,6 +44,22 @@ def __init__(self, preprocessor=None, activation=None, **kwargs): class TestTask(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertClassifier.presets.keys()) + gpt2_presets = set(GPT2CausalLM.presets.keys()) + all_presets = set(Task.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + # No loading on a task directly (it is ambiguous). + Task.from_preset("bert_tiny_en_uncased", load_weights=False) + with self.assertRaises(ValueError): + # No loading on an incorrect class. + BertClassifier.from_preset("gpt2_base_en", load_weights=False) + def test_summary_with_preprocessor(self): preprocessor = SimplePreprocessor() model = SimpleTask(preprocessor) diff --git a/keras_nlp/models/whisper/__init__.py b/keras_nlp/models/whisper/__init__.py index ba0c2545e4..328814046c 100644 --- a/keras_nlp/models/whisper/__init__.py +++ b/keras_nlp/models/whisper/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone +from keras_nlp.models.whisper.whisper_presets import backbone_presets +from keras_nlp.models.whisper.whisper_tokenizer import WhisperTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer)) diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index e41519bbc9..fb2c1a9cdd 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import numpy as np import tensorflow as tf @@ -21,9 +20,6 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_nlp.models.whisper.whisper_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring @keras_nlp_export("keras_nlp.models.WhisperAudioFeatureExtractor") @@ -265,52 +261,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate Whisper audio feature extractor from a given preset. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Examples: - ```python - # Load a preset tokenizer. - audio_feature_extractor = WhisperAudioFeatureExtractor.from_preset( - "{{example_preset_name}}" - ) - - # Compute the log-mel spectrogram. - audio_tensor = tf.ones((8000,), dtype=tf.float32) - audio_feature_extractor(audio_tensor) - ``` - """ - - if not cls.presets: - raise NotImplementedError( - "No presets have been created for this class" - ) - - if preset not in cls.presets: - raise ValueError( - "`preset` must be one of " - f"""{", ".join(cls.presets)}. Received: {preset}.""" - ) - - config = cls.presets[preset]["audio_feature_extractor_config"] - - return cls.from_config({**config, **kwargs}) - - -format_docstring( - example_preset_name=next(iter(backbone_presets), ""), - preset_names='", "'.join(backbone_presets), -)(WhisperAudioFeatureExtractor.from_preset.__func__) diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index a2b685544e..3f5a017f14 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,8 +23,6 @@ from keras_nlp.models.backbone import Backbone from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder from keras_nlp.models.whisper.whisper_encoder import WhisperEncoder -from keras_nlp.models.whisper.whisper_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import assert_tf_backend @@ -305,7 +302,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index c21705a481..a78ea96639 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from absl import logging @@ -23,13 +22,11 @@ from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) -from keras_nlp.models.whisper.whisper_presets import backbone_presets from keras_nlp.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.WhisperPreprocessor") @@ -154,6 +151,8 @@ class WhisperPreprocessor(Preprocessor): ``` """ + tokenizer_cls = WhisperTokenizer + def __init__( self, tokenizer, @@ -326,15 +325,3 @@ def sequence_length(self): @sequence_length.setter def sequence_length(self, value): self.decoder_sequence_length = value - - @classproperty - def audio_feature_extractor_cls(cls): - return WhisperAudioFeatureExtractor - - @classproperty - def tokenizer_cls(cls): - return WhisperTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py index 7b68dfd790..f14fd1ee98 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/models/whisper/whisper_tokenizer.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import json from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.whisper.whisper_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty def _load_dict(dict_or_path): @@ -164,7 +161,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/__init__.py b/keras_nlp/models/xlm_roberta/__init__.py index ba0c2545e4..f8be3006ad 100644 --- a/keras_nlp/models/xlm_roberta/__init__.py +++ b/keras_nlp/models/xlm_roberta/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets +from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (XLMRobertaBackbone, XLMRobertaTokenizer)) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py index c74a0fd6fc..28d7f18728 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.roberta import roberta_backbone -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaBackbone") @@ -82,7 +79,3 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone): model(input_data) ``` """ - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index fcd8bfe9b8..14d41d233c 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -12,22 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer -from keras_nlp.models.task import Task from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import ( XLMRobertaPreprocessor, ) -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaClassifier") -class XLMRobertaClassifier(Task): +class XLMRobertaClassifier(Classifier): """An end-to-end XLM-RoBERTa model for classification tasks. This model attaches a classification head to a @@ -147,6 +144,9 @@ def train_sentencepiece(ds, vocab_size): ``` """ + backbone_cls = XLMRobertaBackbone + preprocessor_cls = XLMRobertaPreprocessor + def __init__( self, backbone, @@ -229,15 +229,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return XLMRobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return XLMRobertaPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index e6b5a45bb5..e687bac525 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer -from keras_nlp.models.task import Task from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( XLMRobertaMaskedLMPreprocessor, ) -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLM") -class XLMRobertaMaskedLM(Task): +class XLMRobertaMaskedLM(MaskedLM): """An end-to-end XLM-RoBERTa model for the masked language modeling task. This model will train XLM-RoBERTa on a masked language modeling task. @@ -100,6 +97,9 @@ class XLMRobertaMaskedLM(Task): ``` """ + backbone_cls = XLMRobertaBackbone + preprocessor_cls = XLMRobertaMaskedLMPreprocessor + def __init__( self, backbone, @@ -142,15 +142,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return XLMRobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return XLMRobertaMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index c94f5f2421..7901e35720 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) from keras_nlp.models.preprocessor import Preprocessor -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( XLMRobertaTokenizer, ) @@ -27,7 +25,6 @@ convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaPreprocessor") @@ -146,6 +143,8 @@ def train_sentencepiece(ds, vocab_size): ``` """ + tokenizer_cls = XLMRobertaTokenizer + def __init__( self, tokenizer, @@ -203,11 +202,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return XLMRobertaTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py index 576f30bca1..8a6a592937 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import tensor_to_list @@ -187,7 +184,3 @@ def detokenize(self, inputs): # the `detokenize` method will return empty strings for these tokens. # This is a vagary of the `sentencepiece` library. return super().detokenize(tokens) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 6b88757c64..91ea6aba76 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -458,8 +458,6 @@ def run_preset_test( expected_partial_output=None, ): """Run instantiation and a forward pass for a preset.""" - self.assertRegex(cls.from_preset.__doc__, preset) - with self.assertRaises(Exception): cls.from_preset("clowntown", **init_kwargs) diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index a8dbc51361..0142062d4f 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -29,10 +29,6 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers import tokenizer -from keras_nlp.utils.preset_utils import check_preset_class -from keras_nlp.utils.preset_utils import load_from_preset -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -649,67 +645,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Example: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, - ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. 
- if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = ( - BytePairTokenizer.from_preset.__doc__ - ) - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index 20a73d6af5..e52b5633a9 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -21,10 +21,6 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers import tokenizer -from keras_nlp.utils.preset_utils import check_preset_class -from keras_nlp.utils.preset_utils import load_from_preset -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -259,67 +255,3 @@ def detokenize(self, inputs): if unbatched: outputs = tf.squeeze(outputs, 0) return outputs - - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Example: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, - ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. 
- if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = ( - SentencePieceTokenizer.from_preset.__doc__ - ) - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 834b99e5b1..bfd18781ec 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -19,10 +19,20 @@ PreprocessingLayer, ) from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses +from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.preset_utils import save_to_preset +from keras_nlp.utils.python_utils import classproperty -@keras_nlp_export("keras_nlp.tokenizers.Tokenizer") +@keras_nlp_export( + [ + "keras_nlp.models.Tokenizer", + "keras_nlp.tokenizers.Tokenizer", + ] +) class Tokenizer(PreprocessingLayer): """A base class for tokenizer layers. @@ -133,3 +143,70 @@ def save_to_preset(self, preset): def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) + + @classproperty + def presets(cls): + """List builtin presets for a `Task` subclass.""" + presets = list_presets(cls) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets + + @classmethod + def from_preset( + cls, + preset, + **kwargs, + ): + """Instantiate a `keras_nlp.models.Tokenizer` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as a + one of: + + 1. a built in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + For any `Tokenizer` subclass, you can run `cls.presets.keys()` to list + all built-in presets available on the class. + + This constructor can be called in one of two ways. Either from the base + class like `keras_nlp.models.Tokenizer.from_preset()`, or from + a model class like `keras_nlp.models.GemmaTokenizer.from_preset()`. + If calling from the base class, the subclass of the returning object + will be inferred from the config in the preset directory. + + Args: + preset: string. A built in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. + + Examples: + ```python + # Load a preset tokenizer. + tokenizer = keras_nlp.tokenizerTokenizer.from_preset("bert_base_en") + + # Tokenize some input. + tokenizer("The quick brown fox tripped.") + + # Detokenize some input. + tokenizer.detokenize([5, 6, 7, 8, 9]) + ``` + """ + config_file = "tokenizer.json" + preset_cls = check_config_class(preset, config_file=config_file) + if not issubclass(preset_cls, cls): + raise ValueError( + f"Preset has type `{preset_cls.__name__}` which is not a " + f"a subclass of calling class `{cls.__name__}`. Call " + f"`from_preset` directly on `{preset_cls.__name__}` instead." 
+ ) + return load_from_preset( + preset, + config_file=config_file, + config_overrides=kwargs, + ) diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py index 509503d921..f9611e5eb8 100644 --- a/keras_nlp/tokenizers/tokenizer_test.py +++ b/keras_nlp/tokenizers/tokenizer_test.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import tensorflow as tf +from keras_nlp.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.tests.test_case import TestCase from keras_nlp.tokenizers.tokenizer import Tokenizer @@ -29,6 +32,29 @@ def detokenize(self, inputs): class TokenizerTest(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertTokenizer.presets.keys()) + gpt2_presets = set(GPT2Tokenizer.presets.keys()) + all_presets = set(Tokenizer.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + Tokenizer.from_preset("bert_tiny_en_uncased"), + BertTokenizer, + ) + self.assertIsInstance( + Tokenizer.from_preset("gpt2_base_en"), + GPT2Tokenizer, + ) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + GPT2Tokenizer.from_preset("bert_tiny_en_uncased") + def test_tokenize(self): input_data = ["the quick brown fox"] tokenizer = SimpleTokenizer() diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index fc6f54e19d..0e04c5ff36 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -21,10 +21,6 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers import tokenizer -from keras_nlp.utils.preset_utils import check_preset_class -from keras_nlp.utils.preset_utils import load_from_preset -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -532,67 +528,3 @@ def detokenize(self, inputs): if unbatched: outputs = tf.squeeze(outputs, 0) return outputs - - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Example: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, - ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. 
- super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = ( - WordPieceTokenizer.from_preset.__doc__ - ) - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 5ddde4415a..f843f4e7c4 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import datetime +import inspect import json import os @@ -41,6 +43,32 @@ CONFIG_FILE = "config.json" TOKENIZER_CONFIG_FILE = "tokenizer.json" +# Global state for preset registry. +BUILTIN_PRESETS = {} +BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict) + + +def register_presets(presets, classes): + for preset in presets: + BUILTIN_PRESETS[preset] = presets[preset] + for cls in classes: + BUILTIN_PRESETS_FOR_CLASS[cls][preset] = presets[preset] + + +def list_presets(cls): + """Find all registered builtin presets for a class.""" + return dict(BUILTIN_PRESETS_FOR_CLASS[cls]) + + +def list_subclasses(cls): + """Find all registered subclasses of a class.""" + custom_objects = keras.saving.get_custom_objects().values() + subclasses = [] + for x in custom_objects: + if inspect.isclass(x) and x != cls and issubclass(x, cls): + subclasses.append(x) + return subclasses + def get_file(preset, path): """Download a preset file in necessary and return the local path.""" @@ -48,6 +76,8 @@ def get_file(preset, path): raise ValueError( f"A preset identifier must be a string. Received: preset={preset}" ) + if preset in BUILTIN_PRESETS: + preset = BUILTIN_PRESETS[preset]["kaggle_handle"] if preset.startswith(KAGGLE_PREFIX): if kagglehub is None: raise ImportError( @@ -359,25 +389,12 @@ def load_from_preset( return layer -def check_preset_class( +def check_config_class( preset, - classes, config_file="config.json", ): """Validate a preset is being loaded on the correct class.""" config_path = get_file(preset, config_file) with open(config_path) as config_file: config = json.load(config_file) - cls = keras.saving.get_registered_object(config["registered_name"]) - if not isinstance(classes, (tuple, list)): - classes = (classes,) - # Allow subclasses for testing a base class, e.g. - # `check_preset_class(preset, Backbone)` - if not any(issubclass(cls, x) for x in classes): - raise ValueError( - f"Unexpected class in preset `'{preset}'`. " - "When calling `from_preset()` on a class object, the preset class " - f"much match allowed classes. Allowed classes are `{classes}`. " - f"Received: `{cls}`." 
- ) - return cls + return keras.saving.get_registered_object(config["registered_name"]) diff --git a/keras_nlp/utils/preset_utils_test.py b/keras_nlp/utils/preset_utils_test.py index 694b02c3f4..32547b4ad0 100644 --- a/keras_nlp/utils/preset_utils_test.py +++ b/keras_nlp/utils/preset_utils_test.py @@ -20,16 +20,14 @@ from keras_nlp import upload_preset from keras_nlp.models import AlbertClassifier -from keras_nlp.models import Backbone from keras_nlp.models import BertBackbone from keras_nlp.models import BertClassifier from keras_nlp.models import BertTokenizer from keras_nlp.models import RobertaClassifier -from keras_nlp.models import Task from keras_nlp.tests.test_case import TestCase from keras_nlp.utils.preset_utils import CONFIG_FILE from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.preset_utils import save_to_preset @@ -81,11 +79,7 @@ def test_preset_saving(self, cls, preset_name, tokenizer_type): self.assertEqual(config["weights"], "model.weights.h5") # Try loading the model from preset directory - self.assertEqual(cls, check_preset_class(save_dir, cls)) - self.assertEqual(cls, check_preset_class(save_dir, Task)) - with self.assertRaises(ValueError): - # Preset is a subclass of Task, not Backbone. - check_preset_class(save_dir, Backbone) + self.assertEqual(cls, check_config_class(save_dir)) # Try loading the model from preset directory restored_model = load_from_preset(save_dir) diff --git a/keras_nlp/utils/python_utils.py b/keras_nlp/utils/python_utils.py index 82c4bcf1a8..d29c71dbbe 100644 --- a/keras_nlp/utils/python_utils.py +++ b/keras_nlp/utils/python_utils.py @@ -19,28 +19,3 @@ class classproperty(property): def __get__(self, _, owner_cls): return self.fget(owner_cls) - - -def format_docstring(**replacements): - """Format a python docstring using a dictionary of replacements. - - This decorator can be placed on a function, class or method to format it's - docstring with python variables. - - The decorator will replace any double bracketed variable with a kwargs - value passed to the decorator itself. For example - `@format_docstring(name="foo")` will replace any occurance of `{{name}}` in - the docstring with the string literal `foo`. - """ - - def decorate(obj): - doc = obj.__doc__ - # We use `str.format()` to replace variables in the docstring, but use - # double brackets, e.g. {{var}}, to mark format strings. So we need to - # to swap all double and single brackets in the source docstring. 
- doc = "{".join(part.replace("{", "{{") for part in doc.split("{{")) - doc = "}".join(part.replace("}", "}}") for part in doc.split("}}")) - obj.__doc__ = doc.format(**replacements) - return obj - - return decorate diff --git a/keras_nlp/utils/python_utils_test.py b/keras_nlp/utils/python_utils_test.py index 60590dd47b..7d08ca091f 100644 --- a/keras_nlp/utils/python_utils_test.py +++ b/keras_nlp/utils/python_utils_test.py @@ -14,7 +14,6 @@ from keras_nlp.tests.test_case import TestCase from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring class ClassPropertyTest(TestCase): @@ -25,60 +24,3 @@ def bar(cls): return "class property" self.assertAllEqual(Foo.bar, "class property") - - -class FormatDocstringTest(TestCase): - def test_function(self): - @format_docstring(adjective="salubrious") - def foo(): - """It was a {{adjective}} November day.""" - return "function" - - self.assertAllEqual(foo(), "function") - self.assertAllEqual(foo.__doc__, "It was a salubrious November day.") - - def test_class(self): - @format_docstring(adjective="smelly", name="Mortimer") - class Foo: - """I saw my {{adjective}} friend {{name}}.""" - - def __init__(self): - self.bar = "property" - - self.assertAllEqual(Foo().bar, "property") - self.assertAllEqual(Foo.__doc__, "I saw my smelly friend Mortimer.") - - def test_class_method(self): - @format_docstring(adjective="smelly", name="Mortimer") - class Foo: - """I saw my {{adjective}} friend {{name}}.""" - - def __init__(self): - self.bar = "property" - - @classmethod - @format_docstring(noun="cactus", bodypart="nostril") - def baz(cls): - """He was holding a {{noun}} in his {{bodypart}}.""" - return "class method" - - self.assertAllEqual(Foo.baz(), "class method") - self.assertAllEqual( - Foo.baz.__doc__, - "He was holding a cactus in his nostril.", - ) - self.assertAllEqual( - Foo.baz.__func__.__doc__, - "He was holding a cactus in his nostril.", - ) - - def test_brackets(self): - @format_docstring(nickname="dumdum") - def bar(): - """Use `{}` to create a dictionary, {{nickname}}.""" - return "function" - - self.assertAllEqual(bar(), "function") - self.assertAllEqual( - bar.__doc__, "Use `{}` to create a dictionary, dumdum." - ) From db831d7ce40965d9a04393bcfb5375c82ded3601 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Thu, 28 Mar 2024 15:47:32 -0700 Subject: [PATCH 61/70] Doc fixes (#1530) --- keras_nlp/models/backbone.py | 2 +- keras_nlp/models/preprocessor.py | 4 +++- keras_nlp/models/task.py | 2 +- keras_nlp/tokenizers/tokenizer.py | 2 +- keras_nlp/utils/preset_utils.py | 7 ++++++- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index e8ececa6f1..c94fe3d68d 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -136,7 +136,7 @@ def from_config(cls, config): @classproperty def presets(cls): - """List builtin presets for a `Task` subclass.""" + """List built-in presets for a `Task` subclass.""" presets = list_presets(cls) for subclass in list_subclasses(cls): presets.update(subclass.presets) diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index a91bc6fb26..a4a9d6ee74 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -34,7 +34,9 @@ class Preprocessor(PreprocessingLayer): `(x, y, sample_weight)` tuples. 
Where `x` contains token id sequences with some - This class can be subclassed to implement + This class can be subclassed similar to any `keras.layers.Layer`, by + defining `build()`, `call()` and `get_config()` methods. All subclasses + should set the `tokenizer` property on construction. """ tokenizer_cls = None diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index c0158b24d2..7858b84709 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -196,7 +196,7 @@ def from_config(cls, config): @classproperty def presets(cls): - """List builtin presets for a `Task` subclass.""" + """List built-in presets for a `Task` subclass.""" presets = list_presets(cls) # We can also load backbone presets. if cls.backbone_cls is not None: diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index bfd18781ec..f522098fb2 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -146,7 +146,7 @@ def call(self, inputs, *args, training=None, **kwargs): @classproperty def presets(cls): - """List builtin presets for a `Task` subclass.""" + """List built-in presets for a `Task` subclass.""" presets = list_presets(cls) for subclass in list_subclasses(cls): presets.update(subclass.presets) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index f843f4e7c4..d64daeb2b2 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -49,6 +49,11 @@ def register_presets(presets, classes): + """Register built-in presets for a set of classes. + + Note that this is intended only for models and presets shipped in the + library itself. + """ for preset in presets: BUILTIN_PRESETS[preset] = presets[preset] for cls in classes: @@ -56,7 +61,7 @@ def register_presets(presets, classes): def list_presets(cls): - """Find all registered builtin presets for a class.""" + """Find all registered built-in presets for a class.""" return dict(BUILTIN_PRESETS_FOR_CLASS[cls]) From 8c7aa4d7a0d5c791ac987d4ab8fe2749d2d99063 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Fri, 29 Mar 2024 11:23:29 -0700 Subject: [PATCH 62/70] Run the LLaMA and Mistral RMS Layer Norm in float32 (#1532) * Run the LLaMA RMS Layer Norm in float32 * Also use float32 in Mistral Layer Norm * Address review comments - Change private variables to public vars - Change `self._weight` to `self.scale` - Don't persist the input dim - Move the var computation to its own line for readability * Change weights to scale in layer norm --- keras_nlp/models/llama/llama_layernorm.py | 31 +++++++++++++------ .../models/mistral/mistral_layer_norm.py | 21 +++++++------ 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/keras_nlp/models/llama/llama_layernorm.py b/keras_nlp/models/llama/llama_layernorm.py index 0e85a45625..8d19c9bc19 100644 --- a/keras_nlp/models/llama/llama_layernorm.py +++ b/keras_nlp/models/llama/llama_layernorm.py @@ -14,24 +14,35 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops -# TODO: Should be replaced with LayerNormalization with `rms_scaling` param -# https://github.com/keras-team/keras-core/pull/726 - +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. 
class LlamaLayerNorm(keras.layers.Layer): + """A normalization layer for Llama that implements RMS normalization.""" + def __init__(self, epsilon=1e-6, **kwargs): super().__init__(**kwargs) self.epsilon = epsilon def build(self, input_shape): - self.weight = self.add_weight( - name="weight", - shape=(input_shape[-1],), + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), initializer="ones", + dtype=self.variable_dtype, ) self.built = True - def call(self, hidden_states): - variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True) - hidden_states = hidden_states * 1 / ops.sqrt(variance + self.epsilon) - return self.weight * hidden_states + def call(self, x): + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x, self.compute_dtype) * self.scale + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_nlp/models/mistral/mistral_layer_norm.py b/keras_nlp/models/mistral/mistral_layer_norm.py index e714a8540d..6801a4679b 100644 --- a/keras_nlp/models/mistral/mistral_layer_norm.py +++ b/keras_nlp/models/mistral/mistral_layer_norm.py @@ -23,25 +23,26 @@ class MistralLayerNormalization(keras.layers.Layer): def __init__(self, epsilon=1e-6, **kwargs): super().__init__(**kwargs) - self._epsilon = epsilon + self.epsilon = epsilon def build(self, input_shape): - self._dim = input_shape[-1] - self._weight = self.add_weight( - name="weight", + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", trainable=True, - shape=(self._dim,), + shape=(dim,), initializer="ones", + dtype=self.variable_dtype, ) self.built = True def call(self, x): - x = x * ops.rsqrt( - ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon - ) - return x * self._weight + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x, self.compute_dtype) * self.scale def get_config(self): config = super().get_config() - config.update({"epsilon": self._epsilon}) + config.update({"epsilon": self.epsilon}) return config From 1192db436c1d78438ed9148072847449cd86d507 Mon Sep 17 00:00:00 2001 From: Ryan Mullins Date: Fri, 29 Mar 2024 14:25:47 -0400 Subject: [PATCH 63/70] Adds score API to GPT-2 (#1533) * Adds score API to GPT-2 * Addressing reviewer comments --- keras_nlp/models/gpt2/gpt2_causal_lm.py | 131 +++++++++++++++++++ keras_nlp/models/gpt2/gpt2_causal_lm_test.py | 85 ++++++++++++ 2 files changed, 216 insertions(+) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index 41728f7433..40d4787119 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -311,3 +311,134 @@ def next(prompt, cache, index): "token_ids": token_ids, "padding_mask": padding_mask, } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `GPT2CausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. 
This is an
+                artifact required by the `GPT2Backbone` and isn't influential on
+                the computation of this function. If omitted, this function uses
+                `keras.ops.ones()` to create a tensor of the appropriate shape.
+            scoring_mode: The type of scores to return, either "logits" or
+                "loss", both will be per input token.
+            layer_intercept_fn: An optional function for augmenting activations
+                with additional computation, for example, as part of
+                interpretability research. This function will be passed the
+                activations as its first parameter and a numeric index
+                associated with that backbone layer. This index is not an index
+                into `self.backbone.layers`. The index -1 accompanies the
+                embeddings returned by calling `self.backbone.token_embedding()`
+                on `token_ids` in the forward direction. All subsequent indexes
+                will be 0-based indices for the activations returned by each of
+                the Transformers layers in the backbone. This function must
+                return a [batch_size, num_tokens, hidden_dims] tensor
+                that can be passed as an input to the next layer in the model.
+            target_ids: A [batch_size, num_tokens] tensor containing the
+                predicted tokens against which the loss should be computed. If a
+                span of tokens is provided (sequential truthy values along
+                axis=1 in the tensor), the loss will be computed as the
+                aggregate across those tokens.
+
+        Raises:
+            ValueError: If an unsupported scoring_mode is provided, or if the
+                target_ids are not provided when using scoring_mode="loss".
+
+        Returns:
+            The per-token scores as a tensor of size
+            [batch_size, num_tokens, vocab_size] in "logits" mode, or
+            [batch_size, num_tokens] in "loss" mode.
+
+        Example:
+
+        Compute gradients between embeddings and loss scores with TensorFlow:
+        ```python
+        gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
+        generations = gpt2_lm.generate(
+            ["This is a", "Where are you"],
+            max_length=30
+        )
+        preprocessed = gpt2_lm.preprocessor.generate_preprocess(generations)
+        generation_ids = preprocessed["token_ids"]
+        padding_mask = preprocessed["padding_mask"]
+        target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
+
+        embeddings = None
+        with tf.GradientTape(watch_accessed_variables=True) as tape:
+            def layer_intercept_fn(x, i):
+                if i == -1:
+                    nonlocal embeddings, tape
+                    embeddings = x
+                    tape.watch(embeddings)
+                return x
+
+            losses = gpt2_lm.score(
+                token_ids=generation_ids,
+                padding_mask=padding_mask,
+                scoring_mode="loss",
+                layer_intercept_fn=layer_intercept_fn,
+                target_ids=target_ids,
+            )
+
+            grads = tape.gradient(losses, embeddings)
+        ```
+        """
+
+        if scoring_mode not in ("logits", "loss"):
+            raise ValueError(
+                "Unsupported scoring_mode. Must be one of 'logits' or 'loss'."
+            )
+
+        if scoring_mode == "loss" and target_ids is None:
+            raise ValueError(
+                "Cannot compute loss without targets. Please provide target "
+                "token ids via the target_ids parameter."
+ ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + position_embeddings = self.backbone.position_embedding(token_embeddings) + summed_embeddings = self.backbone.embeddings_add( + (token_embeddings, position_embeddings) + ) + x = layer_intercept_fn(summed_embeddings, -1) + x = self.backbone.embeddings_dropout(x) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py index f34b6baa47..8999ebd9af 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py @@ -128,3 +128,88 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = GPT2CausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 7) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = GPT2CausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = GPT2CausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 4) + expected_score_shape = (2, 8, 7) + + # Preprocess prompts to get tokenized representations and padding masks. 
+        preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+            prompts
+        )
+        token_ids = preprocessed_prompts["token_ids"]
+        padding_mask = preprocessed_prompts["padding_mask"]
+
+        # Setup a custom intercept function that extracts the embeddings to a
+        # variable from the embeddings layer and otherwise asserts on shapes.
+        embedded_prompts = None
+
+        def layer_intercept_fn_for_testing(x, i):
+            if i == -1:
+                nonlocal embedded_prompts
+                embedded_prompts = x
+            else:
+                nonlocal expected_embedded_shape
+                self.assertEqual(ops.shape(x), expected_embedded_shape)
+            return x
+
+        # Get the scores.
+        scores = causal_lm.score(
+            token_ids=token_ids,
+            padding_mask=padding_mask,
+            scoring_mode="logits",
+            layer_intercept_fn=layer_intercept_fn_for_testing,
+        )
+
+        # Assert shapes for info exfiltrated into the parent context.
+        self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
+        self.assertEqual(ops.shape(scores), expected_score_shape)

From 035a776154a979781d19af813db24ab7006bef58 Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Fri, 29 Mar 2024 12:13:29 -0700
Subject: [PATCH 64/70] increase pip timeout to 1000s to avoid connection
 resets (#1535)

---
 .kokoro/github/ubuntu/gpu/build.sh | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index b8d47dbe9c..04c6c7c454 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -35,23 +35,25 @@ pip install -U pip setuptools psutil
 if [ "${KERAS2:-0}" == "1" ]
 then
    echo "Keras2 detected."
-   pip install -r requirements-common.txt --progress-bar off
-   pip install tensorflow-text==2.15 tensorflow[and-cuda]~=2.15 keras-core
+   pip install -r requirements-common.txt --progress-bar off --timeout 1000
+   pip install tensorflow-text==2.15 tensorflow[and-cuda]~=2.15 keras-core \
+      --timeout 1000
 elif [ "$KERAS_BACKEND" == "tensorflow" ]
 then
    echo "TensorFlow backend detected."
-   pip install -r requirements-tensorflow-cuda.txt --progress-bar off
+   pip install -r requirements-tensorflow-cuda.txt --progress-bar off \
+      --timeout 1000
 elif [ "$KERAS_BACKEND" == "jax" ]
 then
    echo "JAX backend detected."
-   pip install -r requirements-jax-cuda.txt --progress-bar off
+   pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
 elif [ "$KERAS_BACKEND" == "torch" ]
 then
    echo "PyTorch backend detected."
-   pip install -r requirements-torch-cuda.txt --progress-bar off
+   pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
 fi
 
 pip install --no-deps -e "." 
--progress-bar off

From 298e15c3a913862c45466bfa0834203e9ab39bec Mon Sep 17 00:00:00 2001
From: Ryan Mullins 
Date: Fri, 29 Mar 2024 15:15:08 -0400
Subject: [PATCH 65/70] Adds the score API to LlamaCausalLM (#1534)

---
 keras_nlp/models/llama/llama_causal_lm.py     | 127 ++++++++++++++++++
 .../models/llama/llama_causal_lm_test.py      |  85 ++++++++++++
 2 files changed, 212 insertions(+)

diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py
index 7f17645618..48b5fdb4c2 100644
--- a/keras_nlp/models/llama/llama_causal_lm.py
+++ b/keras_nlp/models/llama/llama_causal_lm.py
@@ -212,3 +212,130 @@ def next(prompt, cache, index):
             "token_ids": token_ids,
             "padding_mask": padding_mask,
         }
+
+    def score(
+        self,
+        token_ids,
+        padding_mask=None,
+        scoring_mode="logits",
+        layer_intercept_fn=None,
+        target_ids=None,
+    ):
+        """Score a generation represented by the provided token ids.
+
+        Args:
+            token_ids: A [batch_size, num_tokens] tensor containing tokens
+                to score. Typically, this tensor captures the output from a call
+                to `LlamaCausalLM.generate()`, i.e., tokens for both the input
+                text and the model-generated text.
+            padding_mask: A [batch_size, num_tokens] tensor indicating the
+                tokens that should be preserved during generation. This is an
+                artifact required by the `LlamaBackbone` and isn't influential
+                on the computation of this function. If omitted, this function
+                uses `keras.ops.ones()` to create a tensor of the appropriate
+                shape.
+            scoring_mode: The type of scores to return, either "logits" or
+                "loss", both will be per input token.
+            layer_intercept_fn: An optional function for augmenting activations
+                with additional computation, for example, as part of
+                interpretability research. This function will be passed the
+                activations as its first parameter and a numeric index
+                associated with that backbone layer. This index is not an
+                index into `self.backbone.layers`. The index -1 accompanies the
+                embeddings returned by calling `self.backbone.token_embedding()`
+                on `token_ids` in the forward direction. All subsequent indexes
+                will be 0-based indices for the activations returned by each of
+                the Transformers layers in the backbone. This function must
+                return a [batch_size, num_tokens, hidden_dims] tensor
+                that can be passed as an input to the next layer in the model.
+            target_ids: A [batch_size, num_tokens] tensor containing the
+                predicted tokens against which the loss should be computed. If a
+                span of tokens is provided (sequential truthy values along
+                axis=1 in the tensor), the loss will be computed as the
+                aggregate across those tokens.
+
+        Raises:
+            ValueError: If an unsupported scoring_mode is provided, or if the
+                target_ids are not provided when using scoring_mode="loss".
+
+        Returns:
+            The per-token scores as a tensor of size
+            [batch_size, num_tokens, vocab_size] in "logits" mode, or
+            [batch_size, num_tokens] in "loss" mode. 
+ + Example: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + llama_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en") + generations = llama_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = llama_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = llama_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/llama/llama_causal_lm_test.py b/keras_nlp/models/llama/llama_causal_lm_test.py index ff71a75b38..c006f72783 100644 --- a/keras_nlp/models/llama/llama_causal_lm_test.py +++ b/keras_nlp/models/llama/llama_causal_lm_test.py @@ -128,3 +128,88 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = LlamaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 10) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = LlamaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. 
+        preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+            prompts
+        )
+        token_ids = preprocessed_prompts["token_ids"]
+        padding_mask = preprocessed_prompts["padding_mask"]
+        target_ids = ops.roll(token_ids, shift=-1, axis=1)
+
+        # Get the scores and assert their shape.
+        scores = causal_lm.score(
+            token_ids=token_ids,
+            padding_mask=padding_mask,
+            scoring_mode="loss",
+            target_ids=target_ids,
+        )
+
+        self.assertEqual(ops.shape(scores), expected_score_shape)
+
+    def test_score_layer_intercept_fn_exfiltration(self):
+        # Setup prompts, models, and associated expected shapes.
+        prompts = ["the quick brown fox", "the quick brown fox"]
+        causal_lm = LlamaCausalLM(**self.init_kwargs)
+        expected_embedded_shape = (2, 8, 8)
+        expected_score_shape = (2, 8, 10)
+
+        # Preprocess prompts to get tokenized representations and padding masks.
+        preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+            prompts
+        )
+        token_ids = preprocessed_prompts["token_ids"]
+        padding_mask = preprocessed_prompts["padding_mask"]
+
+        # Setup a custom intercept function that extracts the embeddings to a
+        # variable from the embeddings layer and otherwise asserts on shapes.
+        embedded_prompts = None
+
+        def layer_intercept_fn_for_testing(x, i):
+            if i == -1:
+                nonlocal embedded_prompts
+                embedded_prompts = x
+            else:
+                nonlocal expected_embedded_shape
+                self.assertEqual(ops.shape(x), expected_embedded_shape)
+            return x
+
+        # Get the scores.
+        scores = causal_lm.score(
+            token_ids=token_ids,
+            padding_mask=padding_mask,
+            scoring_mode="logits",
+            layer_intercept_fn=layer_intercept_fn_for_testing,
+        )
+
+        # Assert shapes for info exfiltrated into the parent context.
+        self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
+        self.assertEqual(ops.shape(scores), expected_score_shape)

From 91aa6541eb1ffd07b1185f2a748337383f3a8d5e Mon Sep 17 00:00:00 2001
From: briango28 <72905199+briango28@users.noreply.github.com>
Date: Sat, 30 Mar 2024 04:18:16 +0900
Subject: [PATCH 66/70] Implement compute_output_spec() for tokenizers with
 vocabulary. (#1523)

* Implement compute_output_spec() for tokenizers with vocabulary. 
(restarted from new point in master branch)

* Remove type annotation from compute_output_spec() in tokenizers
---
 keras_nlp/tokenizers/byte_pair_tokenizer.py | 6 ++++++
 keras_nlp/tokenizers/sentence_piece_tokenizer.py | 6 ++++++
 keras_nlp/tokenizers/word_piece_tokenizer.py | 6 ++++++
 3 files changed, 18 insertions(+)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index 0142062d4f..688d216a1c 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -24,6 +24,7 @@
 from typing import Iterable
 from typing import List
 
+import keras
 import regex as re
 import tensorflow as tf
 
@@ -617,6 +618,11 @@ def detokenize(self, inputs):
             outputs = tf.squeeze(outputs, 0)
         return outputs
 
+    def compute_output_spec(self, input_spec):
+        return keras.KerasTensor(
+            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
+        )
+
     def _transform_bytes(self, tokens):
         """Map token bytes to unicode using `byte2unicode`."""
         split_bytes = tf.strings.bytes_split(tokens)
diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
index e52b5633a9..da7b002454 100644
--- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -17,6 +17,7 @@
 import os
 from typing import List
 
+import keras
 import tensorflow as tf
 
 from keras_nlp.api_export import keras_nlp_export
@@ -255,3 +256,8 @@ def detokenize(self, inputs):
         if unbatched:
             outputs = tf.squeeze(outputs, 0)
         return outputs
+
+    def compute_output_spec(self, input_spec):
+        return keras.KerasTensor(
+            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
+        )
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index 0e04c5ff36..bcf3a7cb5e 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -17,6 +17,7 @@
 from typing import Iterable
 from typing import List
 
+import keras
 import tensorflow as tf
 
 from keras_nlp.api_export import keras_nlp_export
@@ -528,3 +529,8 @@ def detokenize(self, inputs):
         if unbatched:
             outputs = tf.squeeze(outputs, 0)
         return outputs
+
+    def compute_output_spec(self, input_spec):
+        return keras.KerasTensor(
+            input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
+        )

From d95c271fb4fb8633e366629114d4866557ac1adb Mon Sep 17 00:00:00 2001
From: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
Date: Fri, 29 Mar 2024 12:58:56 -0700
Subject: [PATCH 67/70] Remove straggler type annotations (#1536)

Currently Keras as a whole is not doing type annotations, but we still
have a few stragglers. Removing them as they occasionally cause
confusion. 
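
For illustration, a sketch of the pattern applied below (`token_to_id`
is one of the methods this diff touches; the bodies are elided here):

    # Before: annotated signature, duplicating what the docstring says.
    def token_to_id(self, token: str) -> int:
        """Convert a string token to an integer id."""

    # After: plain signature; the docstring alone documents the types.
    def token_to_id(self, token):
        """Convert a string token to an integer id."""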
--- keras_nlp/tokenizers/byte_pair_tokenizer.py | 11 ++++--- keras_nlp/tokenizers/byte_tokenizer.py | 14 ++++----- .../tokenizers/sentence_piece_tokenizer.py | 13 ++++----- keras_nlp/tokenizers/tokenizer.py | 10 +++---- .../tokenizers/unicode_codepoint_tokenizer.py | 18 ++++++------ keras_nlp/tokenizers/word_piece_tokenizer.py | 29 +++++++++---------- 6 files changed, 45 insertions(+), 50 deletions(-) diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 688d216a1c..d007845415 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -22,7 +22,6 @@ import json import os from typing import Iterable -from typing import List import keras import regex as re @@ -398,17 +397,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges): default=self.merge_ranks_lookup_default, ) - def get_vocabulary(self) -> List[str]: + def get_vocabulary(self): """Get the tokenizer vocabulary as a list of strings tokens.""" self._check_vocabulary() return self.vocabulary.keys() - def vocabulary_size(self) -> int: - """Get the size of the tokenizer vocabulary.""" + def vocabulary_size(self): + """Get the integer size of the tokenizer vocabulary.""" self._check_vocabulary() return len(self.vocabulary) - def id_to_token(self, id: int) -> str: + def id_to_token(self, id): """Convert an integer id to a string token.""" # This will be slow, but keep memory usage down compared to building a # dict. Assuming the main use case is looking up a few special tokens @@ -421,7 +420,7 @@ def id_to_token(self, id: int) -> str: return token raise ValueError(f"`id` is out of the vocabulary. Received: {id}") - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" self._check_vocabulary() return self.vocabulary[token] diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py index 3aefc4a01d..4d5c4a87ed 100644 --- a/keras_nlp/tokenizers/byte_tokenizer.py +++ b/keras_nlp/tokenizers/byte_tokenizer.py @@ -155,11 +155,11 @@ class ByteTokenizer(tokenizer.Tokenizer): def __init__( self, - lowercase: bool = True, - sequence_length: int = None, - normalization_form: str = None, - errors: str = "replace", - replacement_char: int = 65533, + lowercase=True, + sequence_length=None, + normalization_form=None, + errors="replace", + replacement_char=65533, dtype="int32", **kwargs, ): @@ -198,8 +198,8 @@ def __init__( [i.tobytes() for i in np.arange(256, dtype=np.uint8)] ) - def vocabulary_size(self) -> int: - """Get the size of the tokenizer vocabulary.""" + def vocabulary_size(self): + """Get the integer size of the tokenizer vocabulary.""" return 256 def tokenize(self, inputs): diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py index da7b002454..fb01828c6a 100644 --- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py +++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py @@ -15,7 +15,6 @@ import base64 import binascii import os -from typing import List import keras import tensorflow as tf @@ -108,7 +107,7 @@ def train_sentence_piece_file(ds, path, size): def __init__( self, proto=None, - sequence_length: int = None, + sequence_length=None, dtype="int32", **kwargs, ) -> None: @@ -172,12 +171,12 @@ def set_proto(self, proto): # byte array as a string for saving. 
self.proto = proto_bytes - def vocabulary_size(self) -> int: - """Get the size of the tokenizer vocabulary.""" + def vocabulary_size(self): + """Get the integer size of the tokenizer vocabulary.""" self._check_vocabulary() return int(self._sentence_piece.vocab_size().numpy()) - def get_vocabulary(self) -> List[str]: + def get_vocabulary(self): """Get the tokenizer vocabulary.""" self._check_vocabulary() return tensor_to_list( @@ -186,7 +185,7 @@ def get_vocabulary(self) -> List[str]: ) ) - def id_to_token(self, id: int) -> str: + def id_to_token(self, id): """Convert an integer id to a string token.""" self._check_vocabulary() if id >= self.vocabulary_size() or id < 0: @@ -196,7 +195,7 @@ def id_to_token(self, id: int) -> str: ) return tensor_to_list(self._sentence_piece.id_to_string(id)) - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" self._check_vocabulary() return int(self._sentence_piece.string_to_id(token).numpy()) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index f522098fb2..9418741ea2 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, @@ -105,28 +103,28 @@ def detokenize(self, inputs, *args, **kwargs): f"{self.__class__.__name__}." ) - def get_vocabulary(self) -> List[str]: + def get_vocabulary(self): """Get the tokenizer vocabulary as a list of strings terms.""" raise NotImplementedError( "No implementation of `get_vocabulary()` was found for " f"{self.__class__.__name__}." ) - def vocabulary_size(self) -> int: + def vocabulary_size(self): """Returns the total size of the token id space.""" raise NotImplementedError( "No implementation of `vocabulary_size()` was found for " f"{self.__class__.__name__}." ) - def id_to_token(self, id: int) -> str: + def id_to_token(self, id): """Convert an integer id to a string token.""" raise NotImplementedError( "No implementation of `id_to_token()` was found for " f"{self.__class__.__name__}." ) - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" raise NotImplementedError( "No implementation of `token_to_id()` was found for " diff --git a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py index 5fe8f0144d..578e03bca7 100644 --- a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py @@ -206,14 +206,14 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer): def __init__( self, - sequence_length: int = None, - lowercase: bool = True, - normalization_form: str = None, - errors: str = "replace", - replacement_char: int = 65533, - input_encoding: str = "UTF-8", - output_encoding: str = "UTF-8", - vocabulary_size: int = None, + sequence_length=None, + lowercase=True, + normalization_form=None, + errors="replace", + replacement_char=65533, + input_encoding="UTF-8", + output_encoding="UTF-8", + vocabulary_size=None, dtype="int32", **kwargs, ) -> None: @@ -275,7 +275,7 @@ def get_config(self): ) return config - def vocabulary_size(self) -> int: + def vocabulary_size(self): """Get the size of the tokenizer vocabulary. 
None implies no vocabulary size was provided"""
         return self._vocabulary_size
diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py
index bcf3a7cb5e..4b9b90a943 100644
--- a/keras_nlp/tokenizers/word_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -15,7 +15,6 @@
 import os
 import re
 from typing import Iterable
-from typing import List
 
 import keras
 import tensorflow as tf
@@ -334,15 +333,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
     def __init__(
         self,
         vocabulary=None,
-        sequence_length: int = None,
-        lowercase: bool = False,
-        strip_accents: bool = False,
-        split: bool = True,
-        split_on_cjk: bool = True,
-        suffix_indicator: str = "##",
-        oov_token: str = "[UNK]",
-        special_tokens: List[str] = None,
-        special_tokens_in_strings: bool = False,
+        sequence_length=None,
+        lowercase=False,
+        strip_accents=False,
+        split=True,
+        split_on_cjk=True,
+        suffix_indicator="##",
+        oov_token="[UNK]",
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -437,17 +436,17 @@ def set_vocabulary(self, vocabulary):
             support_detokenization=True,
         )
 
-    def get_vocabulary(self) -> List[str]:
+    def get_vocabulary(self):
         """Get the tokenizer vocabulary as a list of strings tokens."""
         self._check_vocabulary()
         return self.vocabulary
 
-    def vocabulary_size(self) -> int:
-        """Get the size of the tokenizer vocabulary."""
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
         self._check_vocabulary()
         return len(self.vocabulary)
 
-    def id_to_token(self, id: int) -> str:
+    def id_to_token(self, id):
         """Convert an integer id to a string token."""
         self._check_vocabulary()
         if id >= self.vocabulary_size() or id < 0:
@@ -457,7 +456,7 @@ def id_to_token(self, id: int) -> str:
             )
         return self.vocabulary[id]
 
-    def token_to_id(self, token: str) -> int:
+    def token_to_id(self, token):
         """Convert a string token to an integer id."""
         # This will be slow, but keep memory usage down compared to building a
         # dict. Assuming the main use case is looking up a few special tokens

From dcebc7c97b4a2834496c4323306e1f7d983a0fbe Mon Sep 17 00:00:00 2001
From: Tirth Patel 
Date: Mon, 1 Apr 2024 13:50:41 -0700
Subject: [PATCH 68/70] Always run SiLU activation in float32 for LLaMA and
 Mistral (#1540)

* Fix discrepancy between HF LLaMA and our implementation

* Fix Mistral transformer decoder
---
 keras_nlp/models/llama/llama_decoder.py           | 12 +++++++++++-
 .../models/mistral/mistral_transformer_decoder.py | 12 +++++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py
index 1ef247c575..7b4ad5f75d 100644
--- a/keras_nlp/models/llama/llama_decoder.py
+++ b/keras_nlp/models/llama/llama_decoder.py
@@ -97,7 +97,6 @@ def build(self, decoder_sequence_shape):
 
         self._feedforward_gate_dense = keras.layers.Dense(
             self.intermediate_dim,
-            activation=self.activation,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             use_bias=False,
             dtype=self.dtype_policy,
@@ -167,6 +166,17 @@ def call(
             x = self._feedforward_layernorm(x)
             gate_output = self._feedforward_gate_dense(x)
 
+            # Note that we run the activation function in full 32-bit
+            # precision since this is what `torch.nn.functional.silu`
+            # does. Internally, `torch.nn.functional.silu` converts the
+            # inputs to float32, computes SiLU, and converts the outputs
+            # back to compute dtype. 
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + x = self._feedforward_intermediate_dense(x) x = self._feedforward_output_dense(ops.multiply(x, gate_output)) diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py index 36b7f5944d..3ef91d3066 100644 --- a/keras_nlp/models/mistral/mistral_transformer_decoder.py +++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py @@ -102,7 +102,6 @@ def build(self, decoder_sequence_shape): self._feedforward_gate_dense = keras.layers.Dense( self.intermediate_dim, - activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), use_bias=False, dtype=self.dtype_policy, @@ -172,6 +171,17 @@ def call( x = self._feedforward_layernorm(x) gate_output = self._feedforward_gate_dense(x) + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. + # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + x = self._feedforward_intermediate_dense(x) x = self._feedforward_output_dense(ops.multiply(x, gate_output)) From 29873a9676b98956600fb5cc4ecfb8d13af814ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 1 Apr 2024 14:23:40 -0700 Subject: [PATCH 69/70] Bump the python group with 2 updates (#1538) Bumps the python group with 2 updates: torch and torchvision. Updates `torch` from 2.2.1+cu121 to 2.2.2+cu121 Updates `torchvision` from 0.17.1+cu121 to 0.17.2+cu121 --- updated-dependencies: - dependency-name: torch dependency-type: direct:production update-type: version-update:semver-patch dependency-group: python - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-patch dependency-group: python ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-torch-cuda.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 050dd85b1c..2ae593e057 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -4,8 +4,8 @@ tensorflow-text~=2.16.1 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 -torch==2.2.1+cu121 -torchvision==0.17.1+cu121 +torch==2.2.2+cu121 +torchvision==0.17.2+cu121 # Jax cpu-only version. 
jax[cpu]

From d0ff82632b72bc414ff989efff9621ea2c48878a Mon Sep 17 00:00:00 2001
From: abuelnasr0 
Date: Wed, 3 Apr 2024 00:25:47 +0200
Subject: [PATCH 70/70] Add special_tokens_in_strings to byte_pair_tokenizer

---
 keras_nlp/tokenizers/byte_pair_tokenizer.py   | 84 +++++++++++--------
 .../tokenizers/byte_pair_tokenizer_test.py    | 18 +++-
 2 files changed, 62 insertions(+), 40 deletions(-)

diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py
index d007845415..89e39369cd 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -59,10 +59,10 @@
 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
 
 
-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
         return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])
 
 
 def bytes_to_unicode():
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -110,26 +110,23 @@ def split_strings_for_bpe(inputs, special_tokens_pattern=None):
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )
 
-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
         raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
         )
-        split_pattern_1_with_unsplittable_tokens = r"|".join(
-            [unsplittable_tokens_pattern, SPLIT_PATTERN_1]
-        )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
         # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
         # applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
         # e.g., [" <s>"] will be [" <s>"] instead of [" ", "<s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
-            split_pattern_1_with_unsplittable_tokens,
-            split_pattern_1_with_unsplittable_tokens,
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
         )
         raw_tokens = raw_tokens.merge_dims(-2, -1)
     else:
@@ -241,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of strings that will
-            never be split during the word-level splitting applied before the
-            byte-pair encoding. 
This can be used to ensure special tokens map to
-            unique indices in the vocabulary, even if these special tokens
-            contain splittable characters such as punctuation. Special tokens
-            must still be included in `vocabulary`. Defaults to `None`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
+            tokens will never be split during the word-level splitting applied
+            before the byte-pair encoding. This can be used to ensure special
+            tokens map to unique indices in the vocabulary, even if these
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
+            `vocabulary`. Defaults to `None`.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings that should be tokenized and
+            mapped correctly to their ids. Defaults to `False`.
 
     Examples:
 
@@ -285,7 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -300,10 +303,12 @@ def __init__(
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-            unsplittable_tokens
-        )
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
+            )
 
         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -355,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                 "token to int ids. Received: "
                 f"`type(vocabulary)={type(vocabulary)}`."
             )
+
+        # Check for special tokens in vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
         if isinstance(merges, str):
             with open(merges, encoding="utf-8") as f:
                 self.merges = [bp.rstrip() for bp in f]
@@ -367,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )
 
         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)
 
         # Create mapping between string tokens to int ids, and vice versa. 
byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -550,9 +564,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)
 
-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values
@@ -646,7 +658,7 @@ def get_config(self):
             {
                 "sequence_length": self.sequence_length,
                 "add_prefix_space": self.add_prefix_space,
-                "unsplittable_tokens": self.unsplittable_tokens,
+                "special_tokens": self.special_tokens,
             }
         )
-        return config
+        return config
\ No newline at end of file
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 9752966a17..542e872e1b 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -67,30 +67,40 @@ def test_tokenize_with_special_tokens(self):
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
-            unsplittable_tokens=["s", "p"],
+            special_tokens=["s", "p"],
+            special_tokens_in_strings=True,
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [1, 2])
 
-        # If not setting special tokens, "sp" is one token.
+        # If `special_tokens_in_strings` is not `True`, "sp" is one token.
        tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
+            special_tokens=["s", "p"],
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [0])
 
+        # Test real world special tokens, e.g. <s> and </s>.
         vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
         merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
         merges += ["Ġ f", "o x", "Ġf ox"]
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
-            unsplittable_tokens=["<s>", "</s>"],
+            special_tokens=["<s>", "</s>"],
+            special_tokens_in_strings=True,
         )
         output = tokenizer("<s>a quick fox</s>")
         self.assertAllEqual(output, [0, 2, 3, 4, 1])
 
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            BytePairTokenizer(
+                vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
+            )
+
     def test_tokenize_prefix_space(self):
         input_data = ["brown.", "black."]
         tokenizer = BytePairTokenizer(
@@ -181,4 +191,4 @@ def test_config(self):
         self.assertAllEqual(
             self.tokenizer(input_data),
             cloned_tokenizer(input_data),
-        )
+        )
\ No newline at end of file
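
A minimal usage sketch of the behavior the final patch enables (the toy
vocabulary and merges mirror the test above; the expected ids follow from
the `assertAllEqual` in that test):

```python
import keras_nlp

vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
merges += ["Ġ f", "o x", "Ġf ox"]

# With `special_tokens_in_strings=True`, literal "<s>" and "</s>" in the
# input are split out before byte-pair merging and map to ids 0 and 1.
tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<s>", "</s>"],
    special_tokens_in_strings=True,
)
tokenizer("<s>a quick fox</s>")  # -> [0, 2, 3, 4, 1]
```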