Add FalconBackbone (keras-team#1475)
* Add Falcon backbone.

* Add docstring.

* Add dtype.

* Add checkpoint conversion script.

* Fix tests.

* Random fixes.

* Add cache.

* Cast cumsum to int32.

* Make sublayers public.

* Address backbone comments.

* Update attention computation to use einsum.

* Falcon only works with Keras3.

* Fix tests.

* Remove falcon_causal_lm file.

* Remove commented/unused code.
SamanehSaadat authored and abuelnasr0 committed Apr 2, 2024
1 parent 414b4f4 commit 8590c22
Showing 6 changed files with 870 additions and 0 deletions.
13 changes: 13 additions & 0 deletions keras_nlp/models/falcon/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
156 changes: 156 additions & 0 deletions keras_nlp/models/falcon/falcon_attention.py
@@ -0,0 +1,156 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

from keras_nlp.backend import keras
from keras_nlp.backend import ops


class FalconAttention(keras.layers.Layer):
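"""Falcon multi-head attention layer.

Projects the inputs into per-head queries, keys, and values with
`EinsumDense`, adds the caller-supplied ALiBi bias to the attention
scores, and supports an optional key/value `cache` for autoregressive
decoding.

Args:
    num_heads: int. Number of attention heads.
    attention_dropout_rate: float. Dropout probability applied to the
        attention scores.
"""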
def __init__(
self,
num_heads,
attention_dropout_rate,
**kwargs,
):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention_dropout_rate = attention_dropout_rate

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
# q = query length
# m = model dim
# n = num attention heads
# h = head dim
# k = key/value length

batch_size, seq_length, hidden_dim = inputs_shape

self.head_dim = hidden_dim // self.num_heads

# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,mnh->bqnh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="query_dense",
)
self.query_dense.build(inputs_shape)

self.key_dense = keras.layers.EinsumDense(
equation="bkm,mnh->bknh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="key_dense",
)
self.key_dense.build(inputs_shape)

self.value_dense = keras.layers.EinsumDense(
equation="bkm,mnh->bknh",
output_shape=(None, self.num_heads, self.head_dim),
bias_axes="nh",
dtype=self.dtype_policy,
name="value_dense",
)
self.value_dense.build(inputs_shape)

self.attention_dropout = keras.layers.Dropout(
rate=self.attention_dropout_rate,
dtype=self.dtype_policy,
name="attention_dropout",
)

self.output_dense = keras.layers.Dense(
hidden_dim,
dtype=self.dtype_policy,
name="output_dense",
)
self.output_dense.build(inputs_shape)

self.softmax = keras.layers.Softmax(dtype="float32", name="softmax")

self.built = True

def call(
self,
inputs,
alibi,
attention_mask=None,
cache=None,
cache_update_index=None,
):
batch_size, seq_length, hidden_dim = ops.shape(inputs)

query = self.query_dense(inputs)
key = self.key_dense(inputs)
value = self.value_dense(inputs)

if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key)
value = ops.slice_update(value_cache, start, value)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)

attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
attention_scores = ops.add(attention_scores, alibi)
attention_scores = (
attention_scores * self.inv_norm_factor
) # [batch_size, num_heads, query_length, kv_length]
attention_scores = self.softmax(
attention_scores, ops.expand_dims(attention_mask, 1)
)
attention_scores = self.attention_dropout(attention_scores)
attention_output = ops.einsum(
"bnqk,bknh->bqnh", attention_scores, value
)
attention_output = ops.reshape(
attention_output,
[batch_size, seq_length, self.num_heads * self.head_dim],
) # [batch_size, query_length, hidden_dim]

attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache

return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_heads": self.num_heads,
"attention_dropout_rate": self.attention_dropout_rate,
}
)
return config
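Note: `FalconAttention` adds a caller-supplied `alibi` position bias to the raw attention scores before scaling, but the layer that constructs that bias lives in `falcon_transformer_decoder.py`, which is not rendered above. The helper below is only a rough sketch of the standard ALiBi slopes from the ALiBi paper, with illustrative shapes; it is an assumption, not necessarily how this commit computes the bias.

```python
import math

from keras_nlp.backend import ops
from keras_nlp.models.falcon.falcon_attention import FalconAttention


def build_alibi_bias(num_heads, seq_length):
    # Per-head geometric slopes from the ALiBi paper; assumes num_heads
    # is a power of two.
    start = 2.0 ** (-(2.0 ** -(math.log2(num_heads) - 3)))
    slopes = ops.convert_to_tensor(
        [start ** (i + 1) for i in range(num_heads)], dtype="float32"
    )
    # Linear penalty over key positions, shaped (1, num_heads, 1, seq_length)
    # so it broadcasts against (batch, num_heads, query_length, key_length).
    positions = ops.cast(ops.arange(seq_length), "float32")
    return ops.reshape(slopes, (1, num_heads, 1, 1)) * ops.reshape(
        positions, (1, 1, 1, seq_length)
    )


attention = FalconAttention(num_heads=4, attention_dropout_rate=0.0)
x = ops.ones((1, 8, 32))  # (batch, query_length, hidden_dim)
alibi = build_alibi_bias(num_heads=4, seq_length=8)
mask = ops.ones((1, 8, 8), dtype="int32")  # (batch, query_length, key_length)
y = attention(x, alibi=alibi, attention_mask=mask)  # (1, 8, 32)
```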
160 changes: 160 additions & 0 deletions keras_nlp/models/falcon/falcon_backbone.py
@@ -0,0 +1,160 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.falcon.falcon_transformer_decoder import (
FalconTransformerDecoder,
)


@keras_nlp_export("keras_nlp.models.FalconBackbone")
class FalconBackbone(Backbone):
"""The Falcon core architecure.
This network implements a Transformer-based decoder-only network,
[Falcon](https://arxiv.org/abs/2306.01116).
Args:
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_attention_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The dimensionality of the embeddings and hidden states.
intermediate_dim: int. The output dimension of the first Dense layer in
the MLP network of each transformer.
layer_norm_epsilon: float. Epsilon for the layer normalization layers in
the transformer decoder.
attention_dropout_rate: float. Dropout probability for the attention.
feedforward_dropout_rate: float. Dropout probability for the feedforward network.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.

Examples:
```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}
# Pretrained Falcon decoder.
# TODO: Update the preset.
model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset")
model(input_data)
# Randomly initialized Falcon decoder with a custom config.
model = keras_nlp.models.FalconBackbone(
vocabulary_size=10,
num_layers=2,
num_attention_heads=2,
hidden_dim=32,
intermediate_dim=32*4,
layer_norm_epsilon=1e-5,
attention_dropout_rate=0,
feedforward_dropout_rate=0,
dtype="float32",
)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_attention_heads,
hidden_dim,
intermediate_dim,
layer_norm_epsilon=1e-5,
attention_dropout_rate=0,
feedforward_dropout_rate=0,
dtype=None,
**kwargs,
):
# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
dtype=dtype,
name="token_embedding",
)

self.transformer_layers = []
for i in range(num_layers):
layer = FalconTransformerDecoder(
num_attention_heads=num_attention_heads,
intermediate_dim=intermediate_dim,
attention_dropout_rate=attention_dropout_rate,
feedforward_dropout_rate=feedforward_dropout_rate,
dtype=dtype,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)

self.final_layernorm = keras.layers.LayerNormalization(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="final_layernorm",
)

# === Functional Model ===
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)
# Embed Tokens.
x = self.token_embedding(token_ids)

# Apply successive transformer decoder blocks.
for transformer_layer in self.transformer_layers:
x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask)
sequence_output = self.final_layernorm(x)

super().__init__(
inputs={
"token_ids": token_ids,
"padding_mask": padding_mask,
},
outputs=sequence_output,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.attention_dropout_rate = attention_dropout_rate
self.feedforward_dropout_rate = feedforward_dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"attention_dropout_rate": self.attention_dropout_rate,
"feedforward_dropout_rate": self.feedforward_dropout_rate,
"layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config
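Since the backbone returns per-token hidden states rather than vocabulary logits, a causal-LM head would typically reuse the tied `token_embedding` in reverse. Below is a minimal sketch, assuming the small illustrative config shown; the `reverse=True` projection is the standard `ReversibleEmbedding` pattern rather than anything added in this commit.

```python
import numpy as np

from keras_nlp.models.falcon.falcon_backbone import FalconBackbone

backbone = FalconBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_attention_heads=2,
    hidden_dim=32,
    intermediate_dim=128,
)
inputs = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
hidden_states = backbone(inputs)  # (1, 12, 32)
# Map hidden states back onto the vocabulary with the tied embedding weights.
logits = backbone.token_embedding(hidden_states, reverse=True)  # (1, 12, 10)
```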
49 changes: 49 additions & 0 deletions keras_nlp/models/falcon/falcon_backbone_test.py
@@ -0,0 +1,49 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from keras_nlp.backend import ops
from keras_nlp.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.tests.test_case import TestCase


class FalconBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_attention_heads": 8,
"hidden_dim": 16,
"intermediate_dim": 32,
}
self.input_data = {
"token_ids": ops.ones((2, 5), dtype="int32"),
"padding_mask": ops.ones((2, 5), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=FalconBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 16),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=FalconBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
