forked from keras-team/keras-hub
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
FalconBackbone
(keras-team#1475)
* Add Falcon backbone. * Add docstring. * Add dtype. * Add checkpoint conversion script. * Fix tests. * Random fixes. * Add cache. * Cast cumsum to int32. * Make sublayers public. * Address backbone comments. * Update attention computation to use einsum. * Falcon only works with Keras3. * Fix tests. * Remove falcon_causal_lm file. * Remove commented/unused codes.
- Loading branch information
1 parent
414b4f4
commit 8590c22
Showing
6 changed files
with
870 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import math | ||
|
||
from keras_nlp.backend import keras | ||
from keras_nlp.backend import ops | ||
|
||
|
||
class FalconAttention(keras.layers.Layer): | ||
def __init__( | ||
self, | ||
num_heads, | ||
attention_dropout_rate, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.num_heads = num_heads | ||
self.attention_dropout_rate = attention_dropout_rate | ||
|
||
def build(self, inputs_shape): | ||
# Einsum variables: | ||
# b = batch size | ||
# q = query length | ||
# m = model dim | ||
# n = num attention heads | ||
# h = head dim | ||
# k = key/value length | ||
|
||
batch_size, seq_length, hidden_dim = inputs_shape | ||
|
||
self.head_dim = hidden_dim // self.num_heads | ||
|
||
# Layer-wise attention scaling | ||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) | ||
|
||
self.query_dense = keras.layers.EinsumDense( | ||
equation="bqm,mnh->bqnh", | ||
output_shape=(None, self.num_heads, self.head_dim), | ||
bias_axes="nh", | ||
dtype=self.dtype_policy, | ||
name="query_dense", | ||
) | ||
self.query_dense.build(inputs_shape) | ||
|
||
self.key_dense = keras.layers.EinsumDense( | ||
equation="bkm,mnh->bknh", | ||
output_shape=(None, self.num_heads, self.head_dim), | ||
bias_axes="nh", | ||
dtype=self.dtype_policy, | ||
name="key_dense", | ||
) | ||
self.key_dense.build(inputs_shape) | ||
|
||
self.value_dense = keras.layers.EinsumDense( | ||
equation="bkm,mnh->bknh", | ||
output_shape=(None, self.num_heads, self.head_dim), | ||
bias_axes="nh", | ||
dtype=self.dtype_policy, | ||
name="value_dense", | ||
) | ||
self.value_dense.build(inputs_shape) | ||
|
||
self.attention_dropout = keras.layers.Dropout( | ||
rate=self.attention_dropout_rate, | ||
dtype=self.dtype_policy, | ||
name="attention_dropout", | ||
) | ||
|
||
self.output_dense = keras.layers.Dense( | ||
hidden_dim, | ||
dtype=self.dtype_policy, | ||
name="output_dense", | ||
) | ||
self.output_dense.build(inputs_shape) | ||
|
||
self.softmax = keras.layers.Softmax(dtype="float32", name="softmax") | ||
|
||
self.built = True | ||
|
||
def call( | ||
self, | ||
inputs, | ||
alibi, | ||
attention_mask=None, | ||
cache=None, | ||
cache_update_index=None, | ||
): | ||
batch_size, seq_length, hidden_dim = ops.shape(inputs) | ||
|
||
query = self.query_dense(inputs) | ||
key = self.key_dense(inputs) | ||
value = self.value_dense(inputs) | ||
|
||
if cache is not None: | ||
key_cache = cache[:, 0, ...] | ||
value_cache = cache[:, 1, ...] | ||
if cache_update_index is None: | ||
key = key_cache | ||
value = value_cache | ||
else: | ||
start = [0, cache_update_index, 0, 0] | ||
key = ops.slice_update(key_cache, start, key) | ||
value = ops.slice_update(value_cache, start, value) | ||
cache = ops.stack((key, value), axis=1) | ||
else: | ||
if cache_update_index is not None: | ||
raise ValueError( | ||
"`cache_update_index` should not be set if `cache` is " | ||
f"`None`. Received: cache={cache}, " | ||
f"cache_update_index={cache_update_index}" | ||
) | ||
|
||
attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key) | ||
attention_scores = ops.add(attention_scores, alibi) | ||
attention_scores = ( | ||
attention_scores * self.inv_norm_factor | ||
) # [batch_size, num_heads, query_length, kv_length] | ||
attention_scores = self.softmax( | ||
attention_scores, ops.expand_dims(attention_mask, 1) | ||
) | ||
attention_scores = self.attention_dropout(attention_scores) | ||
attention_output = ops.einsum( | ||
"bnqk,bknh->bqnh", attention_scores, value | ||
) | ||
attention_output = ops.reshape( | ||
attention_output, | ||
[batch_size, seq_length, self.num_heads * self.head_dim], | ||
) # [batch_size, query_length, hidden_dim] | ||
|
||
attention_output = self.output_dense(attention_output) | ||
|
||
if cache is not None: | ||
return attention_output, cache | ||
|
||
return attention_output | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_heads": self.num_heads, | ||
"attention_dropout_rate": self.attention_dropout_rate, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from keras_nlp.api_export import keras_nlp_export | ||
from keras_nlp.backend import keras | ||
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding | ||
from keras_nlp.models.backbone import Backbone | ||
from keras_nlp.models.falcon.falcon_transformer_decoder import ( | ||
FalconTransformerDecoder, | ||
) | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.FalconBackbone") | ||
class FalconBackbone(Backbone): | ||
"""The Falcon core architecure. | ||
This network implements a Transformer-based decoder-only network, | ||
[Falcon](https://arxiv.org/abs/2306.01116). | ||
Args: | ||
vocabulary_size: int. The size of the token vocabulary. | ||
num_layers: int. The number of transformer layers. | ||
num_attention_heads: int. The number of attention heads for each transformer. | ||
The hidden size must be divisible by the number of attention heads. | ||
hidden_dim: int. The dimensionality of the embeddings and hidden states. | ||
intermediate_dim: int. The output dimension of the first Dense layer in | ||
the MLP network of each transformer. | ||
layer_norm_epsilon: float. Epsilon for the layer normalization layers in | ||
the transformer decoder. | ||
attention_dropout_rate: float. Dropout probability for the attention. | ||
feedforward_dropout_rate: flaot. Dropout probability for the feedforward. | ||
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use | ||
for model computations and weights. Note that some computations, | ||
such as softmax and layer normalization, will always be done at | ||
float32 precision regardless of dtype. | ||
Examples: | ||
```python | ||
input_data = { | ||
"token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), | ||
} | ||
# Pretrained Falcon decoder. | ||
# TODO: Update the preset. | ||
model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset") | ||
model(input_data) | ||
# Randomly initialized Falcon decoder with a custom config. | ||
model = keras_nlp.models.FalconBackbone( | ||
vocabulary_size=10, | ||
num_layers=2, | ||
num_attention_heads=2, | ||
hidden_dim=32, | ||
intermediate_dim=32*4, | ||
layer_norm_epsilon=1e-5, | ||
attention_dropout_rate=0, | ||
feedforward_dropout_rate=0, | ||
dtype="float32", | ||
) | ||
model(input_data) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
vocabulary_size, | ||
num_layers, | ||
num_attention_heads, | ||
hidden_dim, | ||
intermediate_dim, | ||
layer_norm_epsilon=1e-5, | ||
attention_dropout_rate=0, | ||
feedforward_dropout_rate=0, | ||
dtype=None, | ||
**kwargs, | ||
): | ||
# === Layers === | ||
self.token_embedding = ReversibleEmbedding( | ||
input_dim=vocabulary_size, | ||
output_dim=hidden_dim, | ||
dtype=dtype, | ||
name="token_embedding", | ||
) | ||
|
||
self.transformer_layers = [] | ||
for i in range(num_layers): | ||
layer = FalconTransformerDecoder( | ||
num_attention_heads=num_attention_heads, | ||
intermediate_dim=intermediate_dim, | ||
attention_dropout_rate=attention_dropout_rate, | ||
feedforward_dropout_rate=feedforward_dropout_rate, | ||
dtype=dtype, | ||
name=f"transformer_layer_{i}", | ||
) | ||
self.transformer_layers.append(layer) | ||
|
||
self.final_layernorm = keras.layers.LayerNormalization( | ||
epsilon=layer_norm_epsilon, | ||
dtype=dtype, | ||
name="final_layernorm", | ||
) | ||
|
||
# === Functional Model === | ||
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") | ||
padding_mask = keras.Input( | ||
shape=(None,), dtype="int32", name="padding_mask" | ||
) | ||
# Embed Tokens. | ||
x = self.token_embedding(token_ids) | ||
|
||
# Apply successive transformer decoder blocks. | ||
for transformer_layer in self.transformer_layers: | ||
x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask) | ||
sequence_output = self.final_layernorm(x) | ||
|
||
super().__init__( | ||
inputs={ | ||
"token_ids": token_ids, | ||
"padding_mask": padding_mask, | ||
}, | ||
outputs=sequence_output, | ||
**kwargs, | ||
) | ||
|
||
# === Config === | ||
self.vocabulary_size = vocabulary_size | ||
self.num_layers = num_layers | ||
self.num_attention_heads = num_attention_heads | ||
self.hidden_dim = hidden_dim | ||
self.intermediate_dim = intermediate_dim | ||
self.attention_dropout_rate = attention_dropout_rate | ||
self.feedforward_dropout_rate = feedforward_dropout_rate | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"vocabulary_size": self.vocabulary_size, | ||
"num_layers": self.num_layers, | ||
"num_attention_heads": self.num_attention_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"attention_dropout_rate": self.attention_dropout_rate, | ||
"feedforward_dropout_rate": self.feedforward_dropout_rate, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
} | ||
) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import pytest | ||
|
||
from keras_nlp.backend import ops | ||
from keras_nlp.models.falcon.falcon_backbone import FalconBackbone | ||
from keras_nlp.tests.test_case import TestCase | ||
|
||
|
||
class FalconBackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"vocabulary_size": 10, | ||
"num_layers": 2, | ||
"num_attention_heads": 8, | ||
"hidden_dim": 16, | ||
"intermediate_dim": 32, | ||
} | ||
self.input_data = { | ||
"token_ids": ops.ones((2, 5), dtype="int32"), | ||
"padding_mask": ops.ones((2, 5), dtype="int32"), | ||
} | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=FalconBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 5, 16), | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=FalconBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
Oops, something went wrong.