diff --git a/keras_nlp/models/falcon/__init__.py b/keras_nlp/models/falcon/__init__.py new file mode 100644 index 0000000000..3364a6bd16 --- /dev/null +++ b/keras_nlp/models/falcon/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/keras_nlp/models/falcon/falcon_attention.py b/keras_nlp/models/falcon/falcon_attention.py new file mode 100644 index 0000000000..0358ade54b --- /dev/null +++ b/keras_nlp/models/falcon/falcon_attention.py @@ -0,0 +1,156 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class FalconAttention(keras.layers.Layer): + def __init__( + self, + num_heads, + attention_dropout_rate, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.attention_dropout_rate = attention_dropout_rate + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # m = model dim + # n = num attention heads + # h = head dim + # k = key/value length + + batch_size, seq_length, hidden_dim = inputs_shape + + self.head_dim = hidden_dim // self.num_heads + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.query_dense = keras.layers.EinsumDense( + equation="bqm,mnh->bqnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="query_dense", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="key_dense", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="value_dense", + ) + self.value_dense.build(inputs_shape) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.output_dense = keras.layers.Dense( + hidden_dim, + dtype=self.dtype_policy, + name="output_dense", + ) + self.output_dense.build(inputs_shape) + + self.softmax = keras.layers.Softmax(dtype="float32", name="softmax") + + self.built = True + + def call( + self, + inputs, + alibi, + attention_mask=None, + cache=None, + cache_update_index=None, 
+ ): + batch_size, seq_length, hidden_dim = ops.shape(inputs) + + query = self.query_dense(inputs) + key = self.key_dense(inputs) + value = self.value_dense(inputs) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + + attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key) + attention_scores = ops.add(attention_scores, alibi) + attention_scores = ( + attention_scores * self.inv_norm_factor + ) # [batch_size, num_heads, query_length, kv_length] + attention_scores = self.softmax( + attention_scores, ops.expand_dims(attention_mask, 1) + ) + attention_scores = self.attention_dropout(attention_scores) + attention_output = ops.einsum( + "bnqk,bknh->bqnh", attention_scores, value + ) + attention_output = ops.reshape( + attention_output, + [batch_size, seq_length, self.num_heads * self.head_dim], + ) # [batch_size, query_length, hidden_dim] + + attention_output = self.output_dense(attention_output) + + if cache is not None: + return attention_output, cache + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "attention_dropout_rate": self.attention_dropout_rate, + } + ) + return config diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py new file mode 100644 index 0000000000..4951189fe0 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_backbone.py @@ -0,0 +1,160 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.falcon.falcon_transformer_decoder import ( + FalconTransformerDecoder, +) + + +@keras_nlp_export("keras_nlp.models.FalconBackbone") +class FalconBackbone(Backbone): + """The Falcon core architecture. + + This network implements a Transformer-based decoder-only network, + [Falcon](https://arxiv.org/abs/2306.01116). + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_attention_heads: int. The number of attention heads for each transformer. + The hidden size must be divisible by the number of attention heads. + hidden_dim: int. The dimensionality of the embeddings and hidden states. + intermediate_dim: int. The output dimension of the first Dense layer in + the MLP network of each transformer.
+ layer_norm_epsilon: float. Epsilon for the layer normalization layers in + the transformer decoder. + attention_dropout_rate: float. Dropout probability for the attention. + feedforward_dropout_rate: float. Dropout probability for the feedforward network. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Falcon decoder. + # TODO: Update the preset. + model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset") + model(input_data) + + # Randomly initialized Falcon decoder with a custom config. + model = keras_nlp.models.FalconBackbone( + vocabulary_size=10, + num_layers=2, + num_attention_heads=2, + hidden_dim=32, + intermediate_dim=32*4, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype="float32", + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_attention_heads, + hidden_dim, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + dtype=dtype, + name="token_embedding", + ) + + self.transformer_layers = [] + for i in range(num_layers): + layer = FalconTransformerDecoder( + num_attention_heads=num_attention_heads, + intermediate_dim=intermediate_dim, + layer_norm_epsilon=layer_norm_epsilon, + attention_dropout_rate=attention_dropout_rate, + feedforward_dropout_rate=feedforward_dropout_rate, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + self.final_layernorm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_layernorm", + ) + + # === Functional Model === + token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed Tokens. + x = self.token_embedding(token_ids) + + # Apply successive transformer decoder blocks.
+ for transformer_layer in self.transformer_layers: + x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask) + sequence_output = self.final_layernorm(x) + + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_nlp/models/falcon/falcon_backbone_test.py b/keras_nlp/models/falcon/falcon_backbone_test.py new file mode 100644 index 0000000000..140ce7e7bf --- /dev/null +++ b/keras_nlp/models/falcon/falcon_backbone_test.py @@ -0,0 +1,49 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.tests.test_case import TestCase + + +class FalconBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_attention_heads": 8, + "hidden_dim": 16, + "intermediate_dim": 32, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_transformer_decoder.py b/keras_nlp/models/falcon/falcon_transformer_decoder.py new file mode 100644 index 0000000000..3b29cedd7e --- /dev/null +++ b/keras_nlp/models/falcon/falcon_transformer_decoder.py @@ -0,0 +1,254 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.falcon.falcon_attention import FalconAttention + + +class FalconTransformerDecoder(keras.layers.Layer): + def __init__( + self, + num_attention_heads, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_attention_heads = num_attention_heads + self.intermediate_dim = intermediate_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + + def build(self, decoder_sequence_shape): + self.hidden_dim = decoder_sequence_shape[-1] + self.input_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="input_layernorm", + ) + self.input_layernorm.build(decoder_sequence_shape) + + # Attention layers. + self.key_dim = self.hidden_dim // self.num_attention_heads + self.attention_layer = FalconAttention( + num_heads=self.num_attention_heads, + attention_dropout_rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention", + ) + self.attention_layer.build( + decoder_sequence_shape, + ) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.post_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(decoder_sequence_shape) + + # Feedforward layers. + # TODO: use_bias should be an argument to the transformer to support + # other sizes of models, e.g. 7B, that don't use bias. 
+ self.dense_h_to_4h = keras.layers.Dense( + self.intermediate_dim, + activation=keras.activations.gelu, + use_bias=True, + dtype=self.dtype_policy, + name="dense_h_to_4h", + ) + self.dense_h_to_4h.build(decoder_sequence_shape) + + self.dense_4h_to_h = keras.layers.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype_policy, + name="dense_4h_to_h", + ) + self.dense_4h_to_h.build( + ( + decoder_sequence_shape[0], + decoder_sequence_shape[1], + self.intermediate_dim, + ) + ) + + self.feedforward_dropout = keras.layers.Dropout( + rate=self.feedforward_dropout_rate, + dtype=self.dtype_policy, + name="feedforward_dropout", + ) + + self.built = True + + def call( + self, + inputs, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + training=None, + ): + attention_mask = self._compute_attention_mask( + decoder_sequence=inputs, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + + residual = inputs + + x = self.input_layernorm(inputs) + + alibi = self._build_alibi_tensor( + self.num_attention_heads, decoder_padding_mask + ) + + # Attention block. + attention_output = self.attention_layer( + inputs=x, + alibi=alibi, + attention_mask=attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + + if attention_cache is None: + x = attention_output + else: + x, attention_cache = attention_output + + x = self.attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self.post_attention_layernorm(x) + + x = self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + + x = self.feedforward_dropout(x, training=training) + + x = x + residual + + if attention_cache is not None: + return x, attention_cache + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_attention_heads": self.num_attention_heads, + "intermediate_dim": self.intermediate_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + } + ) + return config + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def _compute_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + attention_cache=None, + attention_cache_update_index=None, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. 
+ if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + causal_mask = compute_causal_mask( + batch_size, + input_length, + output_length, + ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ), + ) + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def _build_alibi_tensor(self, num_heads, attention_mask): + batch_size, seq_length = attention_mask.shape + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), + dtype=self.compute_dtype, + ) # num_heads + arange_tensor = ( + ( + ops.cast(ops.cumsum(attention_mask, axis=-1) - 1, dtype="int32") + * attention_mask + ) + )[:, None, :] + alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype) + alibi = ops.expand_dims( + alibi, 0 + ) # [None, batch_size, num_heads, seq_length] + return ops.transpose(alibi, [1, 2, 0, 3]) + + def _get_slopes(self, num_heads): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + ) diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py new file mode 100644 index 0000000000..90a06503dc --- /dev/null +++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py @@ -0,0 +1,238 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
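A note on the ALiBi bias built in `FalconTransformerDecoder` above: the slope recurrence in `_get_slopes` and the position-dependent bias in `_build_alibi_tensor` can be hard to follow from the layer code alone. Below is a small NumPy-only sketch of the same computation with made-up toy sizes (`num_heads=4`, `seq_length=5`); it is an illustration of the logic, not part of the patch.

```python
import math

import numpy as np


def get_slopes(num_heads):
    # Same recurrence as FalconTransformerDecoder._get_slopes: a geometric
    # sequence starting at 2 ** (-8 / num_heads) when num_heads is a power of 2.
    def power_of_2_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_2_slopes(num_heads)
    closest = 2 ** math.floor(math.log2(num_heads))
    return (
        power_of_2_slopes(closest)
        + get_slopes(2 * closest)[0::2][: num_heads - closest]
    )


num_heads, seq_length = 4, 5
padding_mask = np.array([[1, 1, 1, 1, 0]], dtype="int32")  # last token is padding

slopes = np.array(get_slopes(num_heads))  # shape: (num_heads,)
# Token positions that skip padded slots, as in _build_alibi_tensor.
positions = (np.cumsum(padding_mask, axis=-1) - 1) * padding_mask  # (batch, seq)
# Per-head linear bias, broadcast over the query axis of the attention scores.
alibi = slopes[None, :, None, None] * positions[:, None, None, :]

print(slopes)       # slopes for 4 heads: 0.25, 0.0625, 0.015625, 0.00390625
print(alibi.shape)  # (1, 4, 1, 5), i.e. (batch, num_heads, 1, key_length)
```

The resulting bias matches the `(batch_size, num_heads, 1, seq_length)` tensor that `_build_alibi_tensor` returns and that `FalconAttention` adds to the `(b, n, q, k)` attention scores.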
+import os +import tempfile + +import keras +import numpy as np +import tensorflow as tf +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone + +keras.config.disable_traceback_filtering() + + +def convert_checkpoints(hf_model): + hf_config = hf_model.config.to_dict() + cfg = {} + cfg["vocabulary_size"] = hf_config["vocab_size"] + cfg["num_layers"] = hf_config["num_hidden_layers"] + cfg["num_attention_heads"] = hf_config["num_attention_heads"] + cfg["hidden_dim"] = hf_config["hidden_size"] + cfg["intermediate_dim"] = 4 * cfg["hidden_dim"] + cfg["feedforward_dropout_rate"] = hf_config["hidden_dropout"] + cfg["attention_dropout_rate"] = hf_config["attention_dropout"] + + keras_model = FalconBackbone(**cfg) + + hf_wts = hf_model.state_dict() + + # transformer.word_embeddings.weight + keras_model.get_layer("token_embedding").embeddings.assign( + hf_wts["transformer.word_embeddings.weight"] + ) + + for i in range(keras_model.num_layers): + # split key query value + fused_qkv = ( + hf_wts[f"transformer.h.{i}.self_attention.query_key_value.weight"] + .numpy() + .T + ) + seq_length, _ = fused_qkv.shape + head_dim = cfg["hidden_dim"] // cfg["num_attention_heads"] + fused_qkv = fused_qkv.reshape( + seq_length, cfg["num_attention_heads"], 3, head_dim + ) + query, key, value = ( + fused_qkv[..., 0, :], + fused_qkv[..., 1, :], + fused_qkv[..., 2, :], + ) + + fused_bias = hf_wts[ + f"transformer.h.{i}.self_attention.query_key_value.bias" + ].numpy() + fused_bias = fused_bias.reshape(cfg["num_attention_heads"], 3, head_dim) + query_bias, key_bias, value_bias = ( + fused_bias[..., 0, :], + fused_bias[..., 1, :], + fused_bias[..., 2, :], + ) + + # TODO: check if bias is true before assigning bias. 
+ # transformer.h.0.self_attention.query_key_value.weight + # transformer.h.0.self_attention.query_key_value.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.query_dense.kernel.assign(query) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.query_dense.bias.assign(query_bias) + + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.key_dense.kernel.assign(key) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.key_dense.bias.assign(key_bias) + + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.value_dense.kernel.assign(value) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.value_dense.bias.assign(value_bias) + + # transformer.h.0.self_attention.dense.weight + # transformer.h.0.self_attention.dense.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.output_dense.kernel.assign( + hf_wts[f"transformer.h.{i}.self_attention.dense.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).attention_layer.output_dense.bias.assign( + hf_wts[f"transformer.h.{i}.self_attention.dense.bias"].numpy() + ) + + # transformer.h.0.mlp.dense_h_to_4h.weight + # transformer.h.0.mlp.dense_h_to_4h.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_h_to_4h.kernel.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_h_to_4h.bias.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_h_to_4h.bias"].numpy() + ) + + # transformer.h.0.mlp.dense_4h_to_h.weight + # transformer.h.0.mlp.dense_4h_to_h.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_4h_to_h.kernel.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).dense_4h_to_h.bias.assign( + hf_wts[f"transformer.h.{i}.mlp.dense_4h_to_h.bias"].numpy() + ) + + # transformer.h.0.input_layernorm.weight + # transformer.h.0.input_layernorm.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).input_layernorm.gamma.assign( + hf_wts[f"transformer.h.{i}.input_layernorm.weight"] + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).input_layernorm.beta.assign( + hf_wts[f"transformer.h.{i}.input_layernorm.bias"] + ) + + # transformer.h.0.post_attention_layernorm.weight + # transformer.h.0.post_attention_layernorm.bias + keras_model.get_layer( + f"transformer_layer_{i}" + ).post_attention_layernorm.gamma.assign( + hf_wts[f"transformer.h.{i}.post_attention_layernorm.weight"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + ).post_attention_layernorm.beta.assign( + hf_wts[f"transformer.h.{i}.post_attention_layernorm.bias"].numpy() + ) + + # transformer.ln_f.weight + # transformer.ln_f.bias + keras_model.get_layer("final_layernorm").gamma.assign( + hf_wts["transformer.ln_f.weight"].numpy() + ) + keras_model.get_layer("final_layernorm").beta.assign( + hf_wts["transformer.ln_f.bias"].numpy() + ) + + # TODO: Assign lm_head weights for CausalLM. + # # lm_head.weight + # keras_model.get_layer("lm_head").kernel.assign( + # hf_wts["lm_head.weight"].T.numpy() + # ) + + # Save the model. 
+ print("Save KerasNLP model weights.") + temp_dir = tempfile.mkdtemp() + keras_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + + return keras_model + + +def check_output(keras_model, hf_model, hf_model_name): + sample_text = ["I am so happy today!"] + hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name) + hf_tokenizer.pad_token = hf_tokenizer.eos_token + hf_sample_input = hf_tokenizer( + sample_text, padding="max_length", return_tensors="pt" + ) + sample_input = { + "token_ids": tf.constant(hf_sample_input["input_ids"].numpy()), + "padding_mask": tf.constant(hf_sample_input["attention_mask"].numpy()), + } + print("token_ids: ", sample_input["token_ids"][0, :7]) + print("padding_mask", sample_input["padding_mask"][0, :7]) + + keras_output = keras_model.predict(sample_input) + + activation = {} + + def get_activation(name): + def hook(hf_model, input, output): + activation[name] = output[0].detach() + + return hook + + hf_model.transformer.register_forward_hook( + get_activation("transformer.ln_f") + ) + hf_model(**hf_sample_input) + hf_output = activation["transformer.ln_f"] + print("Keras shape: ", keras_output.shape) + print("HF shape: ", hf_output.shape) + + print("KerasNLP output:", keras_output[0, 1, :5]) + print("HF output:", hf_output[0, 1, :5]) + print( + "Difference:", + np.mean( + abs(keras_output[:, :6, :] - hf_output.detach().numpy()[:, :6, :]) + ), + ) + + +def main(): + hf_model_name = "tiiuae/falcon-rw-1b" + hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name) + keras_model = convert_checkpoints(hf_model) + check_output(keras_model, hf_model, hf_model_name) + + +if __name__ == "__main__": + main()