diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 0df37b1230..eb7a6ac0c5 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -21,3 +21,6 @@ updates: python: patterns: - "*" + ignore: + # Ignore all updates for JAX GPU due to a CUDA version issue. + - dependency-name: "jax[cuda12_pip]" diff --git a/.github/workflows/auto-assignment.yml b/.github/workflows/auto-assignment.yml new file mode 100644 index 0000000000..de72da8ba2 --- /dev/null +++ b/.github/workflows/auto-assignment.yml @@ -0,0 +1,21 @@ +name: auto-assignment +on: + issues: + types: + - opened + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/github-script@v7 + with: + script: | + const script = require('./\.github/workflows/scripts/auto-assignment.js') + script({github, context}) diff --git a/.github/workflows/labeler.yaml b/.github/workflows/labeler.yaml new file mode 100644 index 0000000000..6832d9acb0 --- /dev/null +++ b/.github/workflows/labeler.yaml @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This workflow automatically identifies issues and pull requests (PRs) +# related to Gemma. It searches for the keyword "Gemma" (case-insensitive) +# in both the title and description of the issue/PR. If a match is found, +# the workflow adds the label 'Gemma' to the issue/PR. + +name: 'Labeler' +on: + issues: + types: [edited, opened] + pull_request_target: + types: [opened, edited] + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/github-script@v7 + with: + script: | + const script = require('./\.github/workflows/scripts/labeler.js') + script({github, context}) diff --git a/.github/workflows/scripts/auto-assignment.js b/.github/workflows/scripts/auto-assignment.js new file mode 100644 index 0000000000..176b305f39 --- /dev/null +++ b/.github/workflows/scripts/auto-assignment.js @@ -0,0 +1,43 @@ +/** Automatically assign issues and PRs to users in the `assigneesList` + * on a rotating basis. + + @param {!object} github An object that can call GitHub APIs using its + built-in library functions. The context object contains issue and PR + details. +*/ + +module.exports = async ({ github, context }) => { + let issueNumber; + let assigneesList; + // Is this an issue? If so, assign the issue number. Otherwise, assign the PR number. + if (context.payload.issue) { + // Assignee list for issues. + assigneesList = ["SuryanarayanaY", "sachinprasadhs"]; + issueNumber = context.payload.issue.number; + } else { + // Assignee list for PRs.
+ assigneesList = ["mattdangerw"]; + issueNumber = context.payload.number; + } + console.log("assignee list", assigneesList); + console.log("entered auto assignment for this issue: ", issueNumber); + if (!assigneesList.length) { + console.log("No assignees found for this repo."); + return; + } + let noOfAssignees = assigneesList.length; + let selection = issueNumber % noOfAssignees; + let assigneeForIssue = assigneesList[selection]; + + console.log( + "issue Number = ", + issueNumber + " , assigning to: ", + assigneeForIssue + ); + return github.rest.issues.addAssignees({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + assignees: [assigneeForIssue], + }); +}; diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js new file mode 100644 index 0000000000..aa4178645e --- /dev/null +++ b/.github/workflows/scripts/labeler.js @@ -0,0 +1,53 @@ +/* +Copyright 2024 Google LLC. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + + +/** + * Invoked from the labeler.yaml file to add the label 'Gemma' to issues + * and PRs whose title or description contains the keyword "gemma". + * @param {!object} github An object with pre-defined functions for calling + * GitHub APIs. The context object contains information about the workflow run. + */ + +module.exports = async ({ github, context }) => { + const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title + let issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body + const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number + const keyword_label = { + gemma:'Gemma' + } + const labelsToAdd = [] + console.log(issue_title,issue_description,issue_number) + if (issue_description==null) + { + issue_description = '' + } + + for(const [keyword, label] of Object.entries(keyword_label)){ + if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){ + console.log(`'${keyword}' keyword is present in the title or description.
Pushing label '${label}' to the issue.`) + labelsToAdd.push(label) + } + } + if(labelsToAdd.length > 0){ + console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`) + github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: labelsToAdd + }) + } +}; diff --git a/.github/workflows/stale-issue-pr.yml b/.github/workflows/stale-issue-pr.yml new file mode 100644 index 0000000000..034fb4c266 --- /dev/null +++ b/.github/workflows/stale-issue-pr.yml @@ -0,0 +1,50 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - name: Awaiting response issues + uses: actions/stale@v9 + with: + days-before-issue-stale: 14 + days-before-issue-close: 14 + stale-issue-label: "stale" + # Reason for closing the issue; the default value is not_planned. + close-issue-reason: completed + only-labels: "stat:awaiting response from contributor" + stale-issue-message: > + This issue is stale because it has been open for 14 days with no activity. + It will be closed if no further activity occurs. Thank you. + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:awaiting response from contributor" + close-issue-message: > + This issue was closed because it has been inactive for 28 days. + Please reopen if you'd like to work on this further. + days-before-pr-stale: 14 + days-before-pr-close: 14 + stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you." + close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further." + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Contribution issues + uses: actions/stale@v9 + with: + days-before-issue-stale: 180 + days-before-issue-close: 365 + stale-issue-label: "stale" + # Reason for closing the issue; the default value is not_planned. + close-issue-reason: not_planned + any-of-labels: "stat:contributions welcome,good first issue" + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:contributions welcome,good first issue" + stale-issue-message: > + This issue is stale because it has been open for 180 days with no activity. + It will be closed if no further activity occurs. Thank you. + close-issue-message: > + This issue was closed because it has been inactive for more than 1 year. + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 2017b77c82..04c6c7c454 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -1,9 +1,20 @@ set -e -set -x -cd "${KOKORO_ROOT}/" +export KAGGLE_KEY="$(cat ${KOKORO_KEYSTORE_DIR}/73361_keras_kaggle_secret_key)" +export KAGGLE_USERNAME="$(cat ${KOKORO_KEYSTORE_DIR}/73361_keras_kaggle_username)" + +if [[ -z "${KAGGLE_KEY}" ]]; then + echo "KAGGLE_KEY is NOT set" + exit 1 +fi -sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 +if [[ -z "${KAGGLE_USERNAME}" ]]; then + echo "KAGGLE_USERNAME is NOT set" + exit 1 +fi + +set -x +cd "${KOKORO_ROOT}/" PYTHON_BINARY="/usr/bin/python3.9" @@ -24,23 +35,25 @@ pip install -U pip setuptools psutil if [ "${KERAS2:-0}" == "1" ] then echo "Keras2 detected."
- pip install -r requirements-common.txt --progress-bar off - pip install tensorflow-text==2.15 tensorflow[and-cuda]~=2.15 keras-core + pip install -r requirements-common.txt --progress-bar off --timeout 1000 + pip install tensorflow-text==2.15 tensorflow[and-cuda]~=2.15 keras-core \ + --timeout 1000 elif [ "$KERAS_BACKEND" == "tensorflow" ] then echo "TensorFlow backend detected." - pip install -r requirements-tensorflow-cuda.txt --progress-bar off + pip install -r requirements-tensorflow-cuda.txt --progress-bar off \ + --timeout 1000 elif [ "$KERAS_BACKEND" == "jax" ] then echo "JAX backend detected." - pip install -r requirements-jax-cuda.txt --progress-bar off + pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000 elif [ "$KERAS_BACKEND" == "torch" ] then echo "PyTorch backend detected." - pip install -r requirements-torch-cuda.txt --progress-bar off + pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000 fi pip install --no-deps -e "." --progress-bar off diff --git a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg index 1b9ffb605a..63351a9f40 100644 --- a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "jax" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg index 1b9ffb605a..63351a9f40 100644 --- a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "jax" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg index 7e971ac96d..03fa92222a 100644 --- a/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "1" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg index 7e971ac96d..03fa92222a 100644 --- a/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "1" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set 
timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg index b85ee6f4eb..c707593fcb 100644 --- a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "tensorflow" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg index b85ee6f4eb..c707593fcb 100644 --- a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "tensorflow" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg index 5d25106b3f..345159e802 100644 --- a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg +++ b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg @@ -12,5 +12,23 @@ env_vars: { value: "torch" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg index 5d25106b3f..345159e802 100644 --- a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg +++ b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg @@ -12,5 +12,23 @@ env_vars: { value: "torch" } +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_username" + } + } +} + +before_action { + fetch_keystore { + keystore_resource { + keystore_config_id: 73361 + keyname: "keras_kaggle_secret_key" + } + } +} + # Set timeout to 60 mins from default 180 mins timeout_mins: 60 \ No newline at end of file diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 3db287de99..335f7ade97 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -116,7 +116,7 @@ class PositionEmbedding(keras.layers.Layer): Args: sequence_length: The maximum length of the dynamic sequence. - Examples: + Example: Direct call. 
>>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/__init__.py b/keras_nlp/__init__.py index 30f8a53b16..407a4b7a71 100644 --- a/keras_nlp/__init__.py +++ b/keras_nlp/__init__.py @@ -26,5 +26,6 @@ from keras_nlp import samplers from keras_nlp import tokenizers from keras_nlp import utils +from keras_nlp.utils.preset_utils import upload_preset from keras_nlp.version_utils import __version__ from keras_nlp.version_utils import version diff --git a/keras_nlp/conftest.py b/keras_nlp/conftest.py index b876a7a0a8..d66cad5d42 100644 --- a/keras_nlp/conftest.py +++ b/keras_nlp/conftest.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest import tensorflow as tf @@ -83,6 +85,10 @@ def pytest_configure(config): "markers", "keras_3_only: mark test as a keras 3 only test", ) + config.addinivalue_line( + "markers", + "kaggle_key_required: mark test needing a kaggle key", + ) def pytest_collection_modifyitems(config, items): @@ -107,6 +113,16 @@ def pytest_collection_modifyitems(config, items): not backend_config.keras_3(), reason="tests only run on with multi-backend keras", ) + found_kaggle_key = all( + [ + os.environ.get("KAGGLE_USERNAME", None), + os.environ.get("KAGGLE_KEY", None), + ] + ) + kaggle_key_required = pytest.mark.skipif( + not found_kaggle_key, + reason="tests only run with a kaggle api key", + ) for item in items: if "large" in item.keywords: item.add_marker(skip_large) @@ -116,6 +132,8 @@ def pytest_collection_modifyitems(config, items): item.add_marker(tf_only) if "keras_3_only" in item.keywords: item.add_marker(keras_3_only) + if "kaggle_key_required" in item.keywords: + item.add_marker(kaggle_key_required) # Disable traceback filtering for quicker debugging of tests failures. diff --git a/keras_nlp/layers/modeling/alibi_bias.py b/keras_nlp/layers/modeling/alibi_bias.py index 8a66ad05af..c5f8706f9d 100644 --- a/keras_nlp/layers/modeling/alibi_bias.py +++ b/keras_nlp/layers/modeling/alibi_bias.py @@ -35,12 +35,15 @@ class AlibiBias(keras.layers.Layer): each head. The heads' slopes are a geometric sequence that starts at `2**(-alibi_bias_max/num_heads)` and uses that same value as its ratio. Defaults to 8. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. + Call arguments: attention_scores: The result of multipying the query and the key of the multi-head attention layer of the transformer to add alibi bias to it. With shape `(batch_size, num_heads, query_length, key_length)`. 
- Examples: + Example: ```python query_length = 10 key_length = 10 @@ -94,7 +97,9 @@ def _get_alibi_bias(self, num_heads, key_length): ) slopes = ops.expand_dims(slopes, 1) - seq_range = ops.expand_dims(ops.arange(1 - key_length, 1), 0) + seq_range = ops.expand_dims( + ops.arange(1 - key_length, 1, dtype="int32"), 0 + ) seq_range = ops.cast(seq_range, dtype=self.compute_dtype) alibi_bias = ops.multiply(slopes, seq_range) diff --git a/keras_nlp/layers/modeling/alibi_bias_test.py b/keras_nlp/layers/modeling/alibi_bias_test.py index 120c7622b1..69a48f29e1 100644 --- a/keras_nlp/layers/modeling/alibi_bias_test.py +++ b/keras_nlp/layers/modeling/alibi_bias_test.py @@ -99,7 +99,6 @@ def test_correct_output(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias() output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -127,7 +126,6 @@ def test_correct_output_num_heads_not_power_of_two(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias() output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -162,7 +160,6 @@ def test_correct_output_alibi_bias_max(self): input_tensor = ops.zeros(input_shape) layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( @@ -187,7 +184,6 @@ def test_correct_output_alibi_bias_max_num_heads_not_power_of_two( input_tensor = ops.zeros(input_shape) layer = AlibiBias(alibi_bias_max=alibi_bias_max) output_tensor = layer(input_tensor) - print(output_tensor) self.assertAllClose( output_tensor, ops.convert_to_tensor( diff --git a/keras_nlp/layers/modeling/f_net_encoder.py b/keras_nlp/layers/modeling/f_net_encoder.py index a5370d960e..919e3beb08 100644 --- a/keras_nlp/layers/modeling/f_net_encoder.py +++ b/keras_nlp/layers/modeling/f_net_encoder.py @@ -47,10 +47,10 @@ class FNetEncoder(keras.layers.Layer): bias_initializer: "string" or `keras.initializers` initializer. The bias initializer for the dense layers. Defaults to `"zeros"`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single FNet encoder layer. @@ -79,10 +79,9 @@ def __init__( layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", bias_initializer="zeros", - name=None, **kwargs ): - super().__init__(name=name, **kwargs) + super().__init__(**kwargs) self.intermediate_dim = intermediate_dim self.dropout = dropout self.activation = keras.activations.get(activation) diff --git a/keras_nlp/layers/modeling/masked_lm_head.py b/keras_nlp/layers/modeling/masked_lm_head.py index eacee7e8c0..9f9397cb70 100644 --- a/keras_nlp/layers/modeling/masked_lm_head.py +++ b/keras_nlp/layers/modeling/masked_lm_head.py @@ -59,8 +59,10 @@ class MaskedLMHead(keras.layers.Layer): bias_initializer: string or `keras.initializers` initializer. The bias initializer for the dense and multiheaded attention layers. Defaults to `"zeros"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. 
- Examples: + Example: ```python batch_size = 16 diff --git a/keras_nlp/layers/modeling/position_embedding.py b/keras_nlp/layers/modeling/position_embedding.py index 6f9a44c29f..9f6b314b96 100644 --- a/keras_nlp/layers/modeling/position_embedding.py +++ b/keras_nlp/layers/modeling/position_embedding.py @@ -33,6 +33,8 @@ class PositionEmbedding(keras.layers.Layer): initializer: The initializer to use for the embedding weights. Defaults to `"glorot_uniform"`. seq_axis: The axis of the input tensor where we add the embeddings. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to compute an embedding for, with shape @@ -43,7 +45,7 @@ class PositionEmbedding(keras.layers.Layer): compute the position embedding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example: Called directly on input. >>> layer = keras_nlp.layers.PositionEmbedding(sequence_length=10) diff --git a/keras_nlp/layers/modeling/reversible_embedding.py b/keras_nlp/layers/modeling/reversible_embedding.py index baa5fb7027..1fa5f5f903 100644 --- a/keras_nlp/layers/modeling/reversible_embedding.py +++ b/keras_nlp/layers/modeling/reversible_embedding.py @@ -52,6 +52,8 @@ class ReversibleEmbedding(keras.layers.Embedding): reverse_dtype: The dtype for the reverse projection computation. For stability, it is usually best to use full precision even when working with half or mixed precision training. + **kwargs: other keyword arguments passed to `keras.layers.Embedding`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to the layer. @@ -59,7 +61,7 @@ class ReversibleEmbedding(keras.layers.Embedding): from `output_dim` to `input_dim`, instead of a normal embedding call. Default to `False`. - Examples: + Example: ```python batch_size = 16 vocab_size = 100 @@ -73,7 +75,7 @@ class ReversibleEmbedding(keras.layers.Embedding): # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`. hidden_states = embedding(token_ids) # Project hidden states to shape `(batch_size, seq_length, vocab_size)`. - logits = embedding(hidden_state, reverse=True) + logits = embedding(hidden_states, reverse=True) ``` References: diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py index 45f77ce494..1442548ea8 100644 --- a/keras_nlp/layers/modeling/rotary_embedding.py +++ b/keras_nlp/layers/modeling/rotary_embedding.py @@ -38,6 +38,8 @@ class RotaryEmbedding(keras.layers.Layer): scaling_factor: float. The scaling factor used to scale frequency range. sequence_axis: int. Sequence axis in the input tensor. feature_axis: int. Feature axis in the input tensor. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to apply the embedding to. 
This can have @@ -85,30 +87,42 @@ def __init__( self.built = True def call(self, inputs, start_index=0): + inputs = ops.moveaxis( + inputs, (self.feature_axis, self.sequence_axis), (-1, 1) + ) cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index) - return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb) + output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb) + return ops.moveaxis( + output, (-1, 1), (self.feature_axis, self.sequence_axis) + ) def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb): - x1, x2 = ops.split(tensor, 2, axis=self.feature_axis) - half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis) + x1, x2 = ops.split(tensor, 2, axis=-1) + # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA + # compilation on jax. We should be able to remove this once the + # following PR is in all jax releases we care about: + # https://github.com/openxla/xla/pull/7875 + half_rot_tensor = ops.stack((-x2, x1), axis=-2) + half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor)) return (tensor * cos_emb) + (half_rot_tensor * sin_emb) def _compute_cos_sin_embedding(self, inputs, start_index=0): - def get_axis(axis): - return axis if axis > 0 else len(inputs.shape) + axis + start_index = ops.cast(start_index, dtype="float32") - feature_axis = get_axis(self.feature_axis) - sequence_axis = get_axis(self.sequence_axis) + feature_axis = len(inputs.shape) - 1 + sequence_axis = 1 rotary_dim = ops.shape(inputs)[feature_axis] inverse_freq = self._get_inverse_freq(rotary_dim) - seq_len = ops.shape(inputs)[self.sequence_axis] - tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index + seq_len = ops.shape(inputs)[sequence_axis] + tensor = ops.arange(seq_len, dtype="float32") + start_index - tensor = ops.cast(tensor, dtype=inverse_freq.dtype) freq = ops.einsum("i,j->ij", tensor, inverse_freq) - embedding = ops.concatenate((freq, freq), axis=-1) + embedding = ops.stack((freq, freq), axis=-2) + embedding = ops.reshape( + embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2) + ) # Reshape the embedding to be broadcastable with input shape. if feature_axis < sequence_axis: @@ -117,17 +131,16 @@ def get_axis(axis): if axis != sequence_axis and axis != feature_axis: embedding = ops.expand_dims(embedding, axis) - return ops.cos(embedding), ops.sin(embedding) + cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype) + sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype) + return cos_emb, sin_emb def _get_inverse_freq(self, rotary_dim): - freq_range = ops.arange(0, rotary_dim, 2) - freq_range = ops.cast(freq_range, self.compute_dtype) - freq_range = freq_range / ops.cast( - self.scaling_factor, self.compute_dtype - ) + freq_range = ops.arange(0, rotary_dim, 2, dtype="float32") + freq_range = freq_range / ops.cast(self.scaling_factor, "float32") inverse_freq = 1.0 / ( self.max_wavelength - ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) + ** (freq_range / ops.cast(rotary_dim, "float32")) ) return inverse_freq diff --git a/keras_nlp/layers/modeling/sine_position_encoding.py b/keras_nlp/layers/modeling/sine_position_encoding.py index 6e96a77e2c..b1cd7fbf42 100644 --- a/keras_nlp/layers/modeling/sine_position_encoding.py +++ b/keras_nlp/layers/modeling/sine_position_encoding.py @@ -34,6 +34,8 @@ class SinePositionEncoding(keras.layers.Layer): max_wavelength: The maximum angular wavelength of the sine/cosine curves, as described in Attention is All You Need. Defaults to `10000`.
+ **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. Call arguments: inputs: The tensor inputs to compute an embedding for, with shape @@ -42,7 +44,7 @@ class SinePositionEncoding(keras.layers.Layer): compute the encoding from. This is useful during cached decoding, where each position is predicted separately in a loop. - Examples: + Example: ```python # create a simple embedding layer with sinusoidal positional encoding seq_len = 100 diff --git a/keras_nlp/layers/modeling/token_and_position_embedding.py b/keras_nlp/layers/modeling/token_and_position_embedding.py index bb7107f96f..8261cc7f34 100644 --- a/keras_nlp/layers/modeling/token_and_position_embedding.py +++ b/keras_nlp/layers/modeling/token_and_position_embedding.py @@ -33,6 +33,9 @@ class TokenAndPositionEmbedding(keras.layers.Layer): vocabulary_size: The size of the vocabulary. sequence_length: The maximum length of input sequence embedding_dim: The output dimension of the embedding layer + tie_weights: Boolean, whether or not the matrix for embedding and + the matrix for the `reverse` projection should share the same + weights. embeddings_initializer: The initializer to use for the Embedding Layers mask_zero: Boolean, whether or not the input value 0 is a special @@ -43,8 +46,10 @@ class TokenAndPositionEmbedding(keras.layers.Layer): If mask_zero` is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1). + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python inputs = np.ones(shape=(1, 50), dtype="int32") embedding_layer = keras_nlp.layers.TokenAndPositionEmbedding( diff --git a/keras_nlp/layers/modeling/transformer_decoder.py b/keras_nlp/layers/modeling/transformer_decoder.py index 15c245768c..b8f797f2e2 100644 --- a/keras_nlp/layers/modeling/transformer_decoder.py +++ b/keras_nlp/layers/modeling/transformer_decoder.py @@ -34,12 +34,9 @@ class TransformerDecoder(keras.layers.Layer): paper [Attention is All You Need](https://arxiv.org/abs/1706.03762). Users can instantiate multiple instances of this class to stack up a decoder. - By default, this layer will apply a causal mask to the decoder attention layer. - This layer will correctly compute an attention mask from an implicit - Keras padding mask (for example, by passing `mask_zero=True` to a - `keras.layers.Embedding` layer). See the Masking and Padding - [guide](https://keras.io/guides/understanding_masking_and_padding/) - for more details. + By default, this layer will apply a causal mask to the decoder attention + layer. You can also pass padding or attention masks directly to the layer + during call, e.g. with `decoder_padding_mask` or `decoder_attention_mask`. This layer can be called with either one or two inputs. The number of inputs must be consistent across all calls. The options are as follows: @@ -72,10 +69,10 @@ class TransformerDecoder(keras.layers.Layer): (similar to GPT-2). If set to False, outputs of attention layer and intermediate dense layer are normalized (similar to BERT). Defaults to `False`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single transformer decoder layer. 
decoder = keras_nlp.layers.TransformerDecoder( diff --git a/keras_nlp/layers/modeling/transformer_encoder.py b/keras_nlp/layers/modeling/transformer_encoder.py index 32cdd35547..20cec4ecf1 100644 --- a/keras_nlp/layers/modeling/transformer_encoder.py +++ b/keras_nlp/layers/modeling/transformer_encoder.py @@ -58,10 +58,10 @@ class TransformerEncoder(keras.layers.Layer): (similar to GPT-2). If set to False, outputs of attention layer and intermediate dense layer are normalized (similar to BERT). Defaults to `False`. - name: string. The name of the layer. Defaults to `None`. - **kwargs: other keyword arguments. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `trainable`, `dtype` etc. - Examples: + Example: ```python # Create a single transformer encoder layer. diff --git a/keras_nlp/layers/modeling/transformer_layer_utils.py b/keras_nlp/layers/modeling/transformer_layer_utils.py index 863da59a36..f375bf1b9d 100644 --- a/keras_nlp/layers/modeling/transformer_layer_utils.py +++ b/keras_nlp/layers/modeling/transformer_layer_utils.py @@ -55,9 +55,12 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0): `(batch_size, output_length, input_length)` that can be passed to a attention layer. """ - i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index - j = ops.arange(input_length) - mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0) + i = ops.arange(output_length, dtype="float32") + i = i + ops.cast(cache_index, "float32") + i = ops.expand_dims(i, axis=1) + j = ops.arange(input_length, dtype="float32") + mask = ops.expand_dims(i >= j, axis=0) + return ops.broadcast_to(mask, (batch_size, output_length, input_length)) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 8fd6a70ac0..4139656fbf 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -20,6 +20,7 @@ ) from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.models.backbone import Backbone from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM @@ -36,7 +37,14 @@ from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor from keras_nlp.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_preprocessor import BloomPreprocessor from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.models.causal_lm import CausalLM +from keras_nlp.models.classifier import Classifier from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_classifier import ( DebertaV3Classifier, @@ -66,6 +74,7 @@ DistilBertTokenizer, ) from keras_nlp.models.electra.electra_backbone import ElectraBackbone +from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_classifier import FNetClassifier @@ -75,6 +84,15 @@ ) from 
keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( @@ -92,6 +110,13 @@ ) from keras_nlp.models.gpt_neo_x.gpt_neo_x_tokenizer import GPTNeoXTokenizer from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.mistral.mistral_backbone import MistralBackbone from keras_nlp.models.mistral.mistral_causal_lm import MistralCausalLM from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import ( @@ -106,6 +131,7 @@ ) from keras_nlp.models.opt.opt_preprocessor import OPTPreprocessor from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier from keras_nlp.models.roberta.roberta_masked_lm import RobertaMaskedLM @@ -114,8 +140,10 @@ ) from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.models.task import Task from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) @@ -139,3 +167,4 @@ XLMRobertaTokenizer, ) from keras_nlp.models.xlnet.xlnet_backbone import XLNetBackbone +from keras_nlp.tokenizers.tokenizer import Tokenizer diff --git a/keras_nlp/models/albert/__init__.py b/keras_nlp/models/albert/__init__.py index ba0c2545e4..c0ae8e8fa1 100644 --- a/keras_nlp/models/albert/__init__.py +++ b/keras_nlp/models/albert/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.albert.albert_backbone import AlbertBackbone +from keras_nlp.models.albert.albert_presets import backbone_presets +from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (AlbertBackbone, AlbertTokenizer)) diff --git a/keras_nlp/models/albert/albert_backbone.py b/keras_nlp/models/albert/albert_backbone.py index 1e342e791c..e8acef81d5 100644 --- a/keras_nlp/models/albert/albert_backbone.py +++ b/keras_nlp/models/albert/albert_backbone.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder -from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.backbone import Backbone from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def albert_kernel_initializer(stddev=0.02): @@ -77,7 +73,7 @@ class AlbertBackbone(Backbone): such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. - Examples: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), @@ -230,6 +226,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) @@ -265,7 +262,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_classifier.py b/keras_nlp/models/albert/albert_classifier.py index 32a4e0847d..7471393cc7 100644 --- a/keras_nlp/models/albert/albert_classifier.py +++ b/keras_nlp/models/albert/albert_classifier.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.albert.albert_backbone import AlbertBackbone from keras_nlp.models.albert.albert_backbone import albert_kernel_initializer from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor -from keras_nlp.models.albert.albert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.classifier import Classifier @keras_nlp_export("keras_nlp.models.AlbertClassifier") -class AlbertClassifier(Task): +class AlbertClassifier(Classifier): """An end-to-end ALBERT model for classification tasks This model attaches a classification head to a `keras_nlp.model.AlbertBackbone` @@ -146,6 +142,9 @@ class AlbertClassifier(Task): ``` """ + backbone_cls = AlbertBackbone + preprocessor_cls = AlbertPreprocessor + def __init__( self, backbone, @@ -209,15 +208,3 @@ def get_config(self): ) return config - - @classproperty - def backbone_cls(cls): - return AlbertBackbone - - @classproperty - def preprocessor_cls(cls): - return AlbertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets}) diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index 1958713b9f..01892bdcab 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead @@ -22,14 +20,12 @@ from keras_nlp.models.albert.albert_masked_lm_preprocessor import ( AlbertMaskedLMPreprocessor, ) -from keras_nlp.models.albert.albert_presets import backbone_presets -from keras_nlp.models.task import Task +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.AlbertMaskedLM") -class AlbertMaskedLM(Task): +class AlbertMaskedLM(MaskedLM): """An end-to-end ALBERT model for the masked language modeling task. This model will train ALBERT on a masked language modeling task. @@ -52,7 +48,7 @@ class AlbertMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. 
```python @@ -96,6 +92,9 @@ class AlbertMaskedLM(Task): ``` """ + backbone_cls = AlbertBackbone + preprocessor_cls = AlbertMaskedLMPreprocessor + def __init__(self, backbone, preprocessor=None, **kwargs): # === Layers === self.backbone = backbone @@ -133,15 +132,3 @@ def __init__(self, backbone, preprocessor=None, **kwargs): weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return AlbertBackbone - - @classproperty - def preprocessor_cls(cls): - return AlbertMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_preprocessor.py b/keras_nlp/models/albert/albert_preprocessor.py index 19f4bd9a7b..8e49d4e650 100644 --- a/keras_nlp/models/albert/albert_preprocessor.py +++ b/keras_nlp/models/albert/albert_preprocessor.py @@ -12,20 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.AlbertPreprocessor") @@ -149,6 +145,8 @@ class AlbertPreprocessor(Preprocessor): ``` """ + tokenizer_cls = AlbertTokenizer + def __init__( self, tokenizer, @@ -205,11 +203,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return AlbertTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/albert/albert_tokenizer.py b/keras_nlp/models/albert/albert_tokenizer.py index 44aed44cf5..8887893afc 100644 --- a/keras_nlp/models/albert/albert_tokenizer.py +++ b/keras_nlp/models/albert/albert_tokenizer.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.AlbertTokenizer") @@ -119,7 +115,3 @@ def set_proto(self, proto): self.sep_token_id = None self.pad_token_id = None self.mask_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 6fccf6013a..c94fe3d68d 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,22 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Backbone") class Backbone(keras.Model): + """Base class for all `Backbone` models. + + A `Backbone` is the basic architecture for a given NLP model. Unlike a + `keras_nlp.models.Task`, a `Backbone` is not tailored to any specific loss + function and training setup. A `Backbone` generally outputs the last hidden + states of an architecture before any output predictions. + + A `Backbone` can be used in one of two ways: + + 1. Through a `Task` class, which will wrap and extend a `Backbone` so it + can be used with high-level Keras functions like `fit()`, `predict()` or + `evaluate()`. `Task` classes are built with a particular training + objective in mind (e.g. classification or language modeling). + 2. Directly, by extending the underlying functional model with additional + outputs and training setup. This is the most flexible approach, and can + allow for any outputs, loss, or custom training loop. + + All backbones include a `from_preset()` constructor which can be used to + load a pre-trained config and weights. + + Example: + ```python + # Load a BERT backbone with pre-trained weights. + backbone = keras_nlp.models.Backbone.from_preset( + "bert_base_en", + ) + # Load a GPT2 backbone with pre-trained weights at bfloat16 precision. + backbone = keras_nlp.models.Backbone.from_preset( + "gpt2_base_en", + dtype="bfloat16", + trainable=False, + ) + ``` + """ + def __init__(self, *args, dtype=None, **kwargs): super().__init__(*args, **kwargs) self._functional_layer_ids = set( id(layer) for layer in self._flatten_layers() ) self._initialized = True + if dtype is not None: + # Keras 2 and Keras 3 handle setting policy differently. + if config.keras_3(): + if isinstance(dtype, keras.DTypePolicy): + self.dtype_policy = dtype + else: + self.dtype_policy = keras.DTypePolicy(dtype) + else: + self._set_dtype_policy(dtype) def __dir__(self): if config.keras_3(): @@ -67,7 +114,7 @@ def token_embedding(self): This layer embeds integer token ids to the hidden dim of the model. """ - return self._token_embedding + return getattr(self, "_token_embedding", None) @token_embedding.setter def token_embedding(self, value): @@ -89,7 +136,11 @@ def from_config(cls, config): @classproperty def presets(cls): - return {} + """List built-in presets for a `Backbone` subclass.""" + presets = list_presets(cls) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets @classmethod def from_preset( @@ -98,57 +149,141 @@ def from_preset( load_weights=True, **kwargs, ): - """Instantiate {{model_name}} model from preset architecture and weights. + """Instantiate a `keras_nlp.models.Backbone` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as + one of: + + 1.
a built-in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + This constructor can be called in one of two ways. Either from the base + class like `keras_nlp.models.Backbone.from_preset()`, or from + a model class like `keras_nlp.models.GemmaBackbone.from_preset()`. + If calling from the base class, the subclass of the returned object + will be inferred from the config in the preset directory. + + For any `Backbone` subclass, you can run `cls.presets.keys()` to list + all built-in presets available on the class. Args: - preset: string. Must be one of "{{preset_names}}". - load_weights: Whether to load pre-trained weights into model. - Defaults to `True`. + preset: string. A built-in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. Examples: ```python - # Load architecture and weights from preset - model = keras_nlp.models.{{model_name}}.from_preset( - "{{example_preset_name}}" + # Load a Gemma backbone with pre-trained weights. + model = keras_nlp.models.Backbone.from_preset( + "gemma_2b_en", ) - # Load randomly initialized model from preset architecture - model = keras_nlp.models.{{model_name}}.from_preset( - "{{example_preset_name}}", - load_weights=False + # Load a Bert backbone with a pre-trained config and random weights. + model = keras_nlp.models.Backbone.from_preset( + "bert_base_en", + load_weights=False, ) ``` """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - check_preset_class(preset, cls) + preset_cls = check_config_class(preset) + if not issubclass(preset_cls, cls): + raise ValueError( + f"Preset has type `{preset_cls.__name__}` which is not " + f"a subclass of calling class `{cls.__name__}`. Call " + f"`from_preset` directly on `{preset_cls.__name__}` instead." + ) return load_from_preset( preset, load_weights=load_weights, config_overrides=kwargs, ) - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) + def save_to_preset(self, preset): + """Save backbone to a preset directory. + + Args: + preset: The path to the local model preset directory. + """ + save_to_preset(self, preset) + + def enable_lora(self, rank): + """Enable Lora on the backbone. + + Calling this method will freeze all weights on the backbone, + while enabling Lora on the query & value `EinsumDense` layers + of the attention layers.
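+ + Example (a minimal usage sketch; the `gemma_2b_en` preset name is + borrowed from the `from_preset` docstring above, and the rank value + is arbitrary): + ```python + backbone = keras_nlp.models.Backbone.from_preset("gemma_2b_en") + backbone.enable_lora(rank=4) + # Only the added Lora weights are now trainable; they can be saved + # and restored on their own (the filename must end in `.lora.h5`). + backbone.save_lora_weights("backbone.lora.h5") + backbone.load_lora_weights("backbone.lora.h5") + ```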
+ """ + target_names = ["query_dense", "value_dense", "query", "value"] + self.trainable = True + self._lora_enabled_layers = [] + self._lora_rank = rank + for layer in self._flatten_layers(include_self=False): + layer.trainable = False + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for i, layer in enumerate(all_layers): + for name in target_names: + if layer.name == name: + if hasattr(layer, "enable_lora"): + layer.trainable = True + layer.enable_lora(rank) + self._lora_enabled_layers.append(i) - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: + def save_lora_weights(self, filepath): + if not getattr(self, "_lora_enabled_layers", []): + raise ValueError( + "There are no lora-enabled layers in this model. " + "Make sure to call `.enable_lora(rank)` first." + ) + if not str(filepath).endswith(".lora.h5"): + raise ValueError( + "The filename must end in `.lora.h5`. " + f"Received: filepath={filepath}" + ) - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w") + lora_store = store.make("lora") + lora_store["rank"] = self._lora_rank + # We cannot identify layers by name since names are non-unique, + # so we identify them by index in the topologically sorted list + # of layers that have weights. + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + # We only lora the einsumdense layers, + # so the factored weights are always named `kernel` + layer = all_layers[layer_index] + inner_store = store.make(f"lora/{layer_index}") + inner_store["lora_kernel_a"] = layer.lora_kernel_a + inner_store["lora_kernel_b"] = layer.lora_kernel_b + store.close() - cls.from_preset = classmethod(from_preset) + def load_lora_weights(self, filepath): + store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r") + lora_store = store.get("lora") + rank = int(lora_store["rank"][()]) - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = Backbone.from_preset.__doc__ - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) + if not getattr(self, "_lora_enabled_layers", []): + self.enable_lora(rank) + else: + if self._lora_rank != rank: + raise ValueError( + f"The Lora rank expected by file '{filepath}' " + f"is rank={rank}, but the model was called with " + f"`.enable_lora(rank={self._lora_rank})`. " + "Both ranks must match." 
+ ) + all_layers = self._flatten_layers(include_self=False) + all_layers = [lyr for lyr in all_layers if lyr.weights] + for layer_index in self._lora_enabled_layers: + layer = all_layers[layer_index] + lora_kernel_a = store.get(f"lora/{layer_index}")["lora_kernel_a"] + lora_kernel_b = store.get(f"lora/{layer_index}")["lora_kernel_b"] + layer.lora_kernel_a.assign(lora_kernel_a) + layer.lora_kernel_b.assign(lora_kernel_b) + store.close() diff --git a/keras_nlp/models/backbone_test.py b/keras_nlp/models/backbone_test.py new file mode 100644 index 0000000000..7e4cd24f27 --- /dev/null +++ b/keras_nlp/models/backbone_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.bert.bert_backbone import BertBackbone +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.tests.test_case import TestCase + + +class TestBackbone(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertBackbone.presets.keys()) + gpt2_presets = set(GPT2Backbone.presets.keys()) + all_presets = set(Backbone.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + Backbone.from_preset("bert_tiny_en_uncased", load_weights=False), + BertBackbone, + ) + self.assertIsInstance( + Backbone.from_preset("gpt2_base_en", load_weights=False), + GPT2Backbone, + ) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False) diff --git a/keras_nlp/models/bart/__init__.py b/keras_nlp/models/bart/__init__.py index 6e4df4e727..b14f1f06b6 100644 --- a/keras_nlp/models/bart/__init__.py +++ b/keras_nlp/models/bart/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The KerasNLP Authors +# Copyright 2023 The KerasNLP Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.bart.bart_backbone import BartBackbone +from keras_nlp.models.bart.bart_presets import backbone_presets +from keras_nlp.models.bart.bart_tokenizer import BartTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (BartBackbone, BartTokenizer)) diff --git a/keras_nlp/models/bart/bart_backbone.py b/keras_nlp/models/bart/bart_backbone.py index 803d5a2a9f..d10f1cc240 100644 --- a/keras_nlp/models/bart/bart_backbone.py +++ b/keras_nlp/models/bart/bart_backbone.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding @@ -21,8 +19,6 @@ from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bart.bart_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def bart_kernel_initializer(stddev=0.02): @@ -232,6 +228,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) @@ -259,7 +256,3 @@ def get_config(self): ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_preprocessor.py b/keras_nlp/models/bart/bart_preprocessor.py index 3310b1e532..276d73af04 100644 --- a/keras_nlp/models/bart/bart_preprocessor.py +++ b/keras_nlp/models/bart/bart_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.models.bart.bart_tokenizer import BartTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartPreprocessor") @@ -131,6 +128,8 @@ class BartPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BartTokenizer + def __init__( self, tokenizer, @@ -274,11 +273,3 @@ def sequence_length(self): @sequence_length.setter def sequence_length(self, value): self.decoder_sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return BartTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm.py b/keras_nlp/models/bart/bart_seq_2_seq_lm.py index c17eafdb02..e13e74b769 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.models.bart.bart_backbone import BartBackbone -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.models.bart.bart_seq_2_seq_lm_preprocessor import ( BartSeq2SeqLMPreprocessor, ) -from keras_nlp.models.generative_task import GenerativeTask -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.seq_2_seq_lm import Seq2SeqLM +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.BartSeq2SeqLM") -class BartSeq2SeqLM(GenerativeTask): +class BartSeq2SeqLM(Seq2SeqLM): """An end-to-end BART model for seq2seq language modeling. 
A seq2seq language model (LM) is an encoder-decoder model which is used for @@ -179,6 +177,9 @@ class BartSeq2SeqLM(GenerativeTask): ``` """ + backbone_cls = BartBackbone + preprocessor_cls = BartSeq2SeqLMPreprocessor + def __init__( self, backbone, @@ -207,18 +208,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return BartBackbone - - @classproperty - def preprocessor_cls(cls): - return BartSeq2SeqLMPreprocessor - def call_decoder_with_cache( self, encoder_hidden_states, @@ -398,7 +387,7 @@ def _build_cache( def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a batch of inputs. @@ -412,8 +401,8 @@ def generate_step( inputs: A dictionary with four keys - `"encoder_token_ids"`, `"encoder_padding_mask"`, `"decoder_token_ids"` and `"decoder_padding_mask"`, with batched tensor values. - end_token_id: The id of the end token to stop on. If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of ids of end tokens to stop on. If all + sequences have produced a new stop token, generation will stop. """ ( @@ -477,16 +466,18 @@ def repeat_tensor(x): cache=self_attention_cache, index=index, mask=decoder_padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of `stop_token_ids` locations not in the original # prompt (not in locations where `decoder_padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(decoder_token_ids, end_token_id), + end_locations = any_equal( + decoder_token_ids, + stop_token_ids, ops.logical_not(decoder_padding_mask), ) end_locations = ops.cast(end_locations, "int32") diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py index 1c72e6e935..4d90fd87e8 100644 --- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py +++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.
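The change above is the template repeated across every model in this patch: `classproperty` overrides for `presets`, `backbone_cls`, and `preprocessor_cls` become plain class attributes, with presets registered centrally in each model's `__init__.py`. A sketch of the new convention for a hypothetical task (the subclass name is illustrative):

```python
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
from keras_nlp.models.classifier import Classifier

class MyBertClassifier(Classifier):
    # Plain class attributes now wire a task to its backbone and
    # preprocessor; `presets` is attached by `register_presets()` at
    # import time instead of a per-class `classproperty` deep copy.
    backbone_cls = BertBackbone
    preprocessor_cls = BertPreprocessor
```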
-import copy import tensorflow as tf from absl import logging @@ -20,12 +19,10 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor") @@ -266,7 +263,3 @@ def generate_postprocess( decoder_token_ids, decoder_padding_mask ) return self.tokenizer.detokenize(decoder_token_ids) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bart/bart_tokenizer.py b/keras_nlp/models/bart/bart_tokenizer.py index 17fb237b88..c4e3d1204d 100644 --- a/keras_nlp/models/bart/bart_tokenizer.py +++ b/keras_nlp/models/bart/bart_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bart.bart_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BartTokenizer") @@ -118,10 +115,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.pad_token_id = None self.end_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/bart/bart_tokenizer_test.py b/keras_nlp/models/bart/bart_tokenizer_test.py index 5a0015357b..b18629939c 100644 --- a/keras_nlp/models/bart/bart_tokenizer_test.py +++ b/keras_nlp/models/bart/bart_tokenizer_test.py @@ -37,10 +37,9 @@ def test_tokenizer_basics(self): cls=BartTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - # TODO: </s> should not get tokenized as <s> - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - "<s> airplane at airport<s><pad>", + "<s> airplane at airport</s><pad>", " airplane airport", ], ) diff --git a/keras_nlp/models/bert/__init__.py b/keras_nlp/models/bert/__init__.py index ba0c2545e4..a34b85d0eb 100644 --- a/keras_nlp/models/bert/__init__.py +++ b/keras_nlp/models/bert/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
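`any_equal`, which the BART `generate_step()` above now uses to locate stop tokens, generalizes the old single-id `ops.equal` check to a tuple of ids. Its contract, approximated in NumPy (the real utility lives in `keras_nlp.utils.tensor_utils`; this sketch is illustrative):

```python
import numpy as np

def any_equal(inputs, values, padding_mask):
    # True wherever `inputs` matches any id in `values`, restricted to
    # positions where `padding_mask` is True. `generate_step()` passes the
    # inverted decoder padding mask, i.e. only newly generated positions.
    output = np.zeros(inputs.shape, dtype=bool)
    for value in values:
        output |= inputs == value
    return output & padding_mask

token_ids = np.array([[6, 1, 3, 7, 8, 8]])
mask = np.array([[False, False, False, True, True, True]])
print(any_equal(token_ids, (7, 8), mask))  # [[False False False True True True]]
```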
+ +from keras_nlp.models.bert.bert_backbone import BertBackbone +from keras_nlp.models.bert.bert_classifier import BertClassifier +from keras_nlp.models.bert.bert_presets import backbone_presets +from keras_nlp.models.bert.bert_presets import classifier_presets +from keras_nlp.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (BertBackbone, BertTokenizer)) +register_presets(classifier_presets, (BertClassifier, BertTokenizer)) diff --git a/keras_nlp/models/bert/bert_backbone.py b/keras_nlp/models/bert/bert_backbone.py index 2248260da7..1f5a02dbb1 100644 --- a/keras_nlp/models/bert/bert_backbone.py +++ b/keras_nlp/models/bert/bert_backbone.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bert.bert_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def bert_kernel_initializer(stddev=0.02): @@ -196,6 +192,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) @@ -225,7 +222,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bert/bert_classifier.py b/keras_nlp/models/bert/bert_classifier.py index 09d2b8810c..27bb076ea9 100644 --- a/keras_nlp/models/bert/bert_classifier.py +++ b/keras_nlp/models/bert/bert_classifier.py @@ -12,21 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.models.bert.bert_backbone import BertBackbone from keras_nlp.models.bert.bert_backbone import bert_kernel_initializer from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.classifier import Classifier @keras_nlp_export("keras_nlp.models.BertClassifier") -class BertClassifier(Task): +class BertClassifier(Classifier): """An end-to-end BERT model for classification tasks. 
This model attaches a classification head to a @@ -131,6 +126,9 @@ class BertClassifier(Task): ``` """ + backbone_cls = BertBackbone + preprocessor_cls = BertPreprocessor + def __init__( self, backbone, @@ -193,15 +191,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return BertBackbone - - @classproperty - def preprocessor_cls(cls): - return BertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) diff --git a/keras_nlp/models/bert/bert_masked_lm.py b/keras_nlp/models/bert/bert_masked_lm.py index 17b9669619..1166963625 100644 --- a/keras_nlp/models/bert/bert_masked_lm.py +++ b/keras_nlp/models/bert/bert_masked_lm.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead @@ -22,13 +20,11 @@ from keras_nlp.models.bert.bert_masked_lm_preprocessor import ( BertMaskedLMPreprocessor, ) -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.BertMaskedLM") -class BertMaskedLM(Task): +class BertMaskedLM(MaskedLM): """An end-to-end BERT model for the masked language modeling task. This model will train BERT on a masked language modeling task. @@ -51,7 +47,7 @@ class BertMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python @@ -95,6 +91,9 @@ class BertMaskedLM(Task): ``` """ + backbone_cls = BertBackbone + preprocessor_cls = BertMaskedLMPreprocessor + def __init__( self, backbone, @@ -139,15 +138,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return BertBackbone - - @classproperty - def preprocessor_cls(cls): - return BertMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bert/bert_preprocessor.py b/keras_nlp/models/bert/bert_preprocessor.py index 02f5a45985..2975eae9c6 100644 --- a/keras_nlp/models/bert/bert_preprocessor.py +++ b/keras_nlp/models/bert/bert_preprocessor.py @@ -12,23 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.models.bert.bert_tokenizer import BertTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty - -PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets)) @keras_nlp_export("keras_nlp.models.BertPreprocessor") @@ -130,6 +123,8 @@ class BertPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BertTokenizer + def __init__( self, tokenizer, @@ -186,11 +181,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return BertTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) diff --git a/keras_nlp/models/bert/bert_tokenizer.py b/keras_nlp/models/bert/bert_tokenizer.py index 1b634fe9b3..819bf2f63f 100644 --- a/keras_nlp/models/bert/bert_tokenizer.py +++ b/keras_nlp/models/bert/bert_tokenizer.py @@ -12,15 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bert.bert_presets import backbone_presets -from keras_nlp.models.bert.bert_presets import classifier_presets from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer -from keras_nlp.utils.python_utils import classproperty - -PRESET_NAMES = ", ".join(list(backbone_presets) + list(classifier_presets)) @keras_nlp_export("keras_nlp.models.BertTokenizer") @@ -49,6 +42,9 @@ class BertTokenizer(WordPieceTokenizer): plain text file containing a single word piece token per line. lowercase: If `True`, the input text will be first lowered before tokenization. + special_tokens_in_strings: bool. A bool to indicate if the tokenizer + should expect special tokens in input strings that should be + tokenized and mapped correctly to their ids. Defaults to False. Examples: ```python @@ -76,6 +72,7 @@ def __init__( self, vocabulary=None, lowercase=False, + special_tokens_in_strings=False, **kwargs, ): self.cls_token = "[CLS]" @@ -85,6 +82,13 @@ def __init__( super().__init__( vocabulary=vocabulary, lowercase=lowercase, + special_tokens=[ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -92,15 +96,6 @@ def set_vocabulary(self, vocabulary): super().set_vocabulary(vocabulary) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." 
- ) - self.cls_token_id = self.token_to_id(self.cls_token) self.sep_token_id = self.token_to_id(self.sep_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -111,6 +106,7 @@ def set_vocabulary(self, vocabulary): self.pad_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy({**backbone_presets, **classifier_presets}) + def get_config(self): + config = super().get_config() + del config["special_tokens"] # Not configurable; set in __init__. + return config diff --git a/keras_nlp/models/bert/bert_tokenizer_test.py b/keras_nlp/models/bert/bert_tokenizer_test.py index e53419dab4..78b4417e13 100644 --- a/keras_nlp/models/bert/bert_tokenizer_test.py +++ b/keras_nlp/models/bert/bert_tokenizer_test.py @@ -39,6 +39,16 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = BertTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): BertTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_nlp/models/bloom/__init__.py b/keras_nlp/models/bloom/__init__.py index ba0c2545e4..2b7ad787b4 100644 --- a/keras_nlp/models/bloom/__init__.py +++ b/keras_nlp/models/bloom/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_presets import backbone_presets +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (BloomBackbone, BloomTokenizer)) diff --git a/keras_nlp/models/bloom/bloom_backbone.py b/keras_nlp/models/bloom/bloom_backbone.py index 5737dcc889..4e153b868d 100644 --- a/keras_nlp/models/bloom/bloom_backbone.py +++ b/keras_nlp/models/bloom/bloom_backbone.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone from keras_nlp.models.bloom.bloom_decoder import BloomDecoder -from keras_nlp.models.bloom.bloom_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def _bloom_kernel_initializer(stddev=0.02): @@ -53,14 +50,12 @@ class BloomBackbone(Backbone): dropout: float. Dropout probability for the Transformer decoder. layer_norm_epsilon: float. Epsilon for the layer normalization layers in the transformer decoder. - max_sequence_length: int. The maximum sequence length that this decoder - can consume. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. 
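Stepping back to the `BertTokenizer` change earlier in the patch: with `special_tokens_in_strings=True`, bracketed special tokens appearing in raw input strings map directly to their ids instead of being split into word pieces. A sketch with an illustrative vocabulary:

```python
import keras_nlp

vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "quick", "fox"]
tokenizer = keras_nlp.models.BertTokenizer(
    vocabulary=vocab,
    special_tokens_in_strings=True,
)
tokenizer(["[CLS] the fox [SEP]"])  # -> [[2, 5, 7, 3]]
```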
- Examples: + Example: ```python input_data = { "token_ids": np.ones(shape=(1, 12), dtype="int32"), @@ -80,7 +75,6 @@ class BloomBackbone(Backbone): intermediate_dim=32*4, dropout=0.0, layer_norm_epsilon=1e-5, - max_sequence_length=128, ) model(input_data) ``` @@ -96,7 +90,6 @@ def __init__( intermediate_dim, dropout=0.0, layer_norm_epsilon=1e-5, - max_sequence_length=2048, dtype=None, **kwargs, ): @@ -105,7 +98,6 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, embeddings_initializer=_bloom_kernel_initializer(stddev=0.02), - tie_weights=False, dtype=dtype, name="token_embedding", ) @@ -149,6 +141,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) @@ -160,7 +153,6 @@ def __init__( self.intermediate_dim = intermediate_dim self.dropout = dropout self.layer_norm_epsilon = layer_norm_epsilon - self.max_sequence_length = max_sequence_length def get_config(self): config = super().get_config() @@ -173,11 +165,6 @@ def get_config(self): "intermediate_dim": self.intermediate_dim, "dropout": self.dropout, "layer_norm_epsilon": self.layer_norm_epsilon, - "max_sequence_length": self.max_sequence_length, } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/bloom/bloom_backbone_test.py b/keras_nlp/models/bloom/bloom_backbone_test.py index 83732e4945..47ff7ec4cc 100644 --- a/keras_nlp/models/bloom/bloom_backbone_test.py +++ b/keras_nlp/models/bloom/bloom_backbone_test.py @@ -27,7 +27,6 @@ def setUp(self): "num_heads": 4, "hidden_dim": 8, "intermediate_dim": 32, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_nlp/models/bloom/bloom_causal_lm.py b/keras_nlp/models/bloom/bloom_causal_lm.py new file mode 100644 index 0000000000..914107f101 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm.py @@ -0,0 +1,307 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.causal_lm import CausalLM +from keras_nlp.utils.tensor_utils import any_equal + + +@keras_nlp_export("keras_nlp.models.BloomCausalLM") +class BloomCausalLM(CausalLM): + """An end-to-end BLOOM model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a BLOOM model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. 
The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to string inputs during + `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default + when creating the model with `from_preset()`. + + Args: + backbone: A `keras_nlp.models.BloomBackbone` instance. + preprocessor: A `keras_nlp.models.BloomCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + + Examples: + + Use `generate()` to do text generation. + ```python + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.generate("I want to say", max_length=30) + + # Generate with batched prompts. + bloom_lm.generate(["This is a", "Where are you"], max_length=30) + ``` + + Compile the `generate()` function with a custom sampler. + ```python + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.compile(sampler="top_k") + bloom_lm.generate("I want to say", max_length=30) + + bloom_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) + bloom_lm.generate("I want to say", max_length=30) + ``` + + Use `generate()` without preprocessing. + ```python + prompt = { + # Token ids for "<s> Keras is". + "token_ids": np.array([[1, 46, 15762, 632, 3, 3, 3, 3, 3]] * 2), + # Use `"padding_mask"` to indicate values that should not be overridden. + "padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0, 0, 0]] * 2), + } + + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset( + "bloom_560m_multi", + preprocessor=None, + ) + bloom_lm.generate(prompt) + ``` + + Call `fit()` on a single batch. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloom_560m_multi") + bloom_lm.fit(x=features, batch_size=2) + ``` + + Call `fit()` without preprocessing. + ```python + x = { + # Token ids for "<s> Keras is deep learning library" + "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2), + } + y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2) + sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2) + + bloom_lm = keras_nlp.models.BloomCausalLM.from_preset( + "bloom_560m_multi", + preprocessor=None, + ) + bloom_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2) + ``` + + Custom backbone and vocabulary.
+ ```python + features = [ + " airplane at airport", + " airplane airport", + ] + vocab = ["<unk>", "<s>", "</s>", "<pad>"] + vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + vocab = dict([(token, i) for i, token in enumerate(vocab)]) + merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + merges += ["Ġai r", "Ġa i", "pla ne"] + tokenizer = keras_nlp.models.BloomTokenizer(vocabulary=vocab, merges=merges) + preprocessor = keras_nlp.models.BloomCausalLMPreprocessor( + tokenizer=tokenizer, + sequence_length=128, + ) + backbone = keras_nlp.models.BloomBackbone( + vocabulary_size=tokenizer.vocabulary_size(), + num_layers=4, + num_heads=4, + hidden_dim=32, + intermediate_dim=128, + ) + bloom_lm = keras_nlp.models.BloomCausalLM( + backbone=backbone, + preprocessor=preprocessor, + ) + bloom_lm.fit(x=features, batch_size=2) + ``` + """ + + backbone_cls = BloomBackbone + preprocessor_cls = BloomCausalLMPreprocessor + + def __init__( + self, + backbone, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + sampler="greedy", + jit_compile=True, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `BloomCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in the multi-head attention + layers, and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A `(logits, hidden_states, cache)` tuple, where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + x = self.backbone.embeddings_layer_norm(x) + # Each decoder layer has a cache; we update them separately. + caches = [] + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + current_cache = cache[:, i, ...]
+ x, next_cache = transformer_layer( + x, + cache=current_cache, + cache_update_index=cache_update_index, + ) + caches.append(next_cache) + cache = ops.stack(caches, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_heads = self.backbone.num_heads + head_dim = self.backbone.hidden_dim // num_heads + shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of ids of end tokens to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted token ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self._sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated.
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } diff --git a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py index 01f3c88d30..b56e1a3ef0 100644 --- a/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence @@ -175,8 +175,8 @@ def generate_postprocess( # end markers). In the future we could make this configurable. padding_mask = ( padding_mask - & (token_ids != self.tokenizer.eos_token_id) - & (token_ids != self.tokenizer.bos_token_id) + & (token_ids != self.tokenizer.start_token_id) + & (token_ids != self.tokenizer.end_token_id) ) token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/bloom/bloom_causal_lm_test.py b/keras_nlp/models/bloom/bloom_causal_lm_test.py new file mode 100644 index 0000000000..70af6a2302 --- /dev/null +++ b/keras_nlp/models/bloom/bloom_causal_lm_test.py @@ -0,0 +1,188 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
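A worked example of the early-stop masking in `generate_step()` above: `cumsum` marks every position from the first stop token onward, and subtracting `end_locations` re-exposes the stop token itself, so the output mask keeps the prompt, the generated text, and the stop token, but nothing after it. A NumPy sketch:

```python
import numpy as np

# One generated row; a stop token was produced at index 3.
end_locations = np.array([0, 0, 0, 1, 0, 0])
cumsum = np.cumsum(end_locations)      # [0 0 0 1 1 1]
overflow = cumsum - end_locations      # [0 0 0 0 1 1]
padding_mask = ~overflow.astype(bool)  # [ True  True  True  True False False]
```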
+ +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.bloom.bloom_backbone import BloomBackbone +from keras_nlp.models.bloom.bloom_causal_lm import BloomCausalLM +from keras_nlp.models.bloom.bloom_causal_lm_preprocessor import ( + BloomCausalLMPreprocessor, +) +from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer +from keras_nlp.tests.test_case import TestCase + + +class BloomCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["<unk>", "<s>", "</s>", "<pad>"] + self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = BloomTokenizer( + vocabulary=self.vocab, merges=self.merges + ) + self.preprocessor = BloomCausalLMPreprocessor( + self.tokenizer, + sequence_length=8, + ) + self.backbone = BloomBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_heads=2, + hidden_dim=4, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = ( + [ + " airplane at airport", + " airplane airport", + ], + ) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + vocabulary_size = self.tokenizer.vocabulary_size() + self.run_task_test( + cls=BloomCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, vocabulary_size), + ) + + def test_generate(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + # String input. + prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_bfloat16(self): + backbone = BloomBackbone.from_config( + {**self.backbone.get_config(), "dtype": "bfloat16"} + ) + causal_lm = BloomCausalLM( + backbone=backbone, preprocessor=self.preprocessor + ) + # String input. + prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_generate_with_mixed_float16(self): + backbone = BloomBackbone.from_config( + {**self.backbone.get_config(), "dtype": "mixed_float16"} + ) + causal_lm = BloomCausalLM( + backbone=backbone, preprocessor=self.preprocessor + ) + # String input. + prompt = "airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input.
+ prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :4], + prompt_ids["token_ids"][:, :4], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :4], + prompt_ids["padding_mask"][:, :4], + ) + + def test_early_stopping(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["airplane at", "airplane"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = BloomCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate("airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=BloomCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in BloomCausalLM.presets: + self.run_preset_test( + cls=BloomCausalLM, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/bloom/bloom_decoder.py b/keras_nlp/models/bloom/bloom_decoder.py index a0c62a2541..c6478e7270 100644 --- a/keras_nlp/models/bloom/bloom_decoder.py +++ b/keras_nlp/models/bloom/bloom_decoder.py @@ -110,8 +110,8 @@ def call( decoder_sequence, decoder_padding_mask=None, decoder_attention_mask=None, - attention_cache=None, - attention_cache_update_index=None, + cache=None, + cache_update_index=None, use_causal_mask=True, ): self_attention_mask = self._compute_attention_mask( @@ -119,8 +119,8 @@ def call( decoder_padding_mask=decoder_padding_mask, decoder_attention_mask=decoder_attention_mask, use_causal_mask=use_causal_mask, - attention_cache=attention_cache, - attention_cache_update_index=attention_cache_update_index, + cache=cache, + cache_update_index=cache_update_index, ) residual = decoder_sequence @@ -129,14 +129,14 @@ def call( attention_output = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, - cache=attention_cache, - cache_update_index=attention_cache_update_index, + cache=cache, + cache_update_index=cache_update_index, ) - if attention_cache is None: + if cache is None: x = attention_output else: - x, attention_cache = attention_output + x, cache = attention_output x = x + residual residual = x @@ -147,8 +147,8 @@ def call( x = self._dropout_layer(x) x = x + residual - if attention_cache is not None: - return x, attention_cache + if cache is not None: + return x, cache else: return x @@ -158,8 +158,8 @@ def _compute_attention_mask( decoder_padding_mask, 
decoder_attention_mask, use_causal_mask, - attention_cache, - attention_cache_update_index, + cache, + cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -167,18 +167,14 @@ def _compute_attention_mask( if use_causal_mask: batch_size = ops.shape(decoder_sequence)[0] input_length = output_length = ops.shape(decoder_sequence)[1] - if attention_cache is not None: - input_length = ops.shape(attention_cache)[2] + if cache is not None: + input_length = ops.shape(cache)[2] causal_mask = compute_causal_mask( batch_size, input_length, output_length, - ( - 0 - if attention_cache_update_index is None - else attention_cache_update_index - ), + (0 if cache_update_index is None else cache_update_index), ) return ( ops.minimum(decoder_mask, causal_mask) diff --git a/keras_nlp/models/bloom/bloom_preprocessor.py b/keras_nlp/models/bloom/bloom_preprocessor.py index 734c9f4bf8..dfaac0332d 100644 --- a/keras_nlp/models/bloom/bloom_preprocessor.py +++ b/keras_nlp/models/bloom/bloom_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.bloom.bloom_presets import backbone_presets from keras_nlp.models.bloom.bloom_tokenizer import BloomTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BloomPreprocessor") @@ -107,6 +104,8 @@ class BloomPreprocessor(Preprocessor): ``` """ + tokenizer_cls = BloomTokenizer + def __init__( self, tokenizer, @@ -126,8 +125,8 @@ def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer # assets have loaded when restoring a saved model. self.packer = StartEndPacker( - start_value=self.tokenizer.bos_token_id, - end_value=self.tokenizer.eos_token_id, + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, pad_value=self.tokenizer.pad_token_id, sequence_length=self.sequence_length, return_padding_mask=True, @@ -183,11 +182,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return BloomTokenizer diff --git a/keras_nlp/models/bloom/bloom_presets.py b/keras_nlp/models/bloom/bloom_presets.py index d3e9c780c0..134de5173d 100644 --- a/keras_nlp/models/bloom/bloom_presets.py +++ b/keras_nlp/models/bloom/bloom_presets.py @@ -17,14 +17,105 @@ "bloom_560m_multi": { "metadata": { "description": ( - "24-layer Bloom model. trained on 45 natural languages and " - "12 programming languages." + "24-layer Bloom model with hidden dimension of 1024. " + "Trained on 45 natural languages and 12 programming languages."
), - "params": 816115712, + "params": 559214592, "official_name": "BLOOM", "path": "bloom", - "model_card": "https://huggingface.co/bigscience/bloom", + "model_card": "https://huggingface.co/bigscience/bloom-560m", }, - "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/1", + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/3", + }, + "bloom_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "Trained on 45 natural languages and 12 programming languages." + ), + "params": 1065314304, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/1", + }, + "bloom_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "Trained on 45 natural languages and 12 programming languages." + ), + "params": 1722408960, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/1", + }, + "bloom_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "Trained on 45 natural languages and 12 programming languages." + ), + "params": 3002557440, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/1", + }, + "bloomz_560m_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1024. " + "Finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 559214592, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-560m", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/1", + }, + "bloomz_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "Finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 1065314304, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/1", + }, + "bloomz_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "Finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 1722408960, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/1", + }, + "bloomz_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "Finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 3002557440, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/1", }, } diff --git a/keras_nlp/models/bloom/bloom_tokenizer.py b/keras_nlp/models/bloom/bloom_tokenizer.py index cc3fcc2fc3..6c6097e4ce 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer.py +++ b/keras_nlp/models/bloom/bloom_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License.
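All eight checkpoints registered above load through the standard preset API, for example (prompt and output are illustrative):

```python
import keras_nlp

bloom_lm = keras_nlp.models.BloomCausalLM.from_preset("bloomz_560m_multi")
bloom_lm.generate("Translate to Spanish: I love coffee.", max_length=30)
```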
-import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.bloom.bloom_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.BloomTokenizer") @@ -74,16 +71,16 @@ def __init__( merges=None, **kwargs, ): - self.bos_token = "<s>" - self.eos_token = "</s>" + self.start_token = "<s>" + self.end_token = "</s>" self.pad_token = "<pad>" super().__init__( vocabulary=vocabulary, merges=merges, unsplittable_tokens=[ - self.bos_token, - self.eos_token, + self.start_token, + self.end_token, self.pad_token, ], **kwargs, ) @@ -94,7 +91,7 @@ def set_vocabulary_and_merges(self, vocabulary, merges): if vocabulary is not None: # Check for necessary special tokens. - for token in [self.bos_token, self.eos_token, self.pad_token]: + for token in [self.start_token, self.end_token, self.pad_token]: if token not in self.get_vocabulary(): raise ValueError( f"Cannot find token `'{token}'` in the provided " @@ -102,18 +99,14 @@ def set_vocabulary_and_merges(self, vocabulary, merges): "your `vocabulary` or use a pretrained `vocabulary` name." ) - self.bos_token_id = self.token_to_id(self.bos_token) - self.eos_token_id = self.token_to_id(self.eos_token) + self.start_token_id = self.token_to_id(self.start_token) + self.end_token_id = self.token_to_id(self.end_token) self.pad_token_id = self.token_to_id(self.pad_token) else: - self.bos_token_id = None - self.eos_token_id = None + self.start_token_id = None + self.end_token_id = None self.pad_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/bloom/bloom_tokenizer_test.py b/keras_nlp/models/bloom/bloom_tokenizer_test.py index 9ae9c0cc00..2fbebdf1ab 100644 --- a/keras_nlp/models/bloom/bloom_tokenizer_test.py +++ b/keras_nlp/models/bloom/bloom_tokenizer_test.py @@ -28,8 +28,8 @@ def setUp(self): self.merges += ["Ġai r", "Ġa i", "pla ne"] self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} self.input_data = [ - "<s>airplane at airport<pad>", - "<s> airplane airport<pad>", + "<s>airplane at airport</s><pad>", + "<s> airplane airport</s><pad>", ] def test_tokenizer_basics(self): @@ -37,7 +37,7 @@ def test_tokenizer_basics(self): cls=BloomTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]], + expected_output=[[6, 1, 3, 4, 2, 5, 7, 8], [6, 2, 3, 2, 5, 7, 8]], ) def test_errors_missing_special_tokens(self): diff --git a/keras_nlp/models/generative_task.py b/keras_nlp/models/causal_lm.py similarity index 70% rename from keras_nlp/models/generative_task.py rename to keras_nlp/models/causal_lm.py index 9a461926e4..98867e9ad2 100644 --- a/keras_nlp/models/generative_task.py +++ b/keras_nlp/models/causal_lm.py @@ -13,10 +13,12 @@ # limitations under the License.
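The `bos`/`eos` to `start`/`end` rename above aligns `BloomTokenizer` with the attribute names the shared generation code expects (`end_token_id` is the default stop token for `generate()`). Downstream code reads:

```python
import keras_nlp

tokenizer = keras_nlp.models.BloomTokenizer.from_preset("bloom_560m_multi")
tokenizer.start_token_id  # id of "<s>"
tokenizer.end_token_id    # id of "</s>", the default stop token
tokenizer.pad_token_id    # id of "<pad>"
```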
import itertools +from functools import partial import tensorflow as tf import tree +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.backend import ops @@ -25,9 +27,46 @@ from keras_nlp.utils.tensor_utils import tensor_to_list -@keras.saving.register_keras_serializable(package="keras_nlp") -class GenerativeTask(Task): - """Base class for Generative Task models.""" +@keras_nlp_export("keras_nlp.models.CausalLM") +class CausalLM(Task): + """Base class for generative language modeling tasks. + + `CausalLM` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning. + + `CausalLM` tasks provide an additional, high-level `generate()` function + which can be used to auto-regressively sample a model token by token with a + string in, string out signature. The `compile()` method of all `CausalLM` + classes contains an additional `sampler` argument, which can be used to pass + a `keras_nlp.samplers.Sampler` to control how the predicted distribution + will be sampled. + + When calling `fit()`, the tokenized input will be predicted token-by-token + with a causal mask applied, which gives both a pre-training and supervised + fine-tuning setup for controlling inference-time generation. + + All `CausalLM` tasks include a `from_preset()` constructor which can be used + to load a pre-trained config and weights. + + Example: + ```python + # Load a GPT2 backbone with pre-trained weights. + causal_lm = keras_nlp.models.CausalLM.from_preset( + "gpt2_base_en", + ) + causal_lm.compile(sampler="top_k") + causal_lm.generate("Keras is a", max_length=64) + + # Load a Mistral instruction tuned checkpoint at bfloat16 precision. + causal_lm = keras_nlp.models.CausalLM.from_preset( + "mistral_instruct_7b_en", + dtype="bfloat16", + ) + causal_lm.compile(sampler="greedy") + causal_lm.generate("Keras is a", max_length=64) + ``` + """ def compile( self, @@ -64,10 +103,10 @@ def make_generate_function(self): def wrapped_generate_function( inputs, - end_token_id=None, + stop_token_ids=None, ): with torch.no_grad(): - return self.generate_step(inputs, end_token_id) + return self.generate_step(inputs, stop_token_ids) self.generate_function = wrapped_generate_function elif config.backend() == "tensorflow" and not self.run_eagerly: @@ -80,8 +119,8 @@ def wrapped_generate_function( elif config.backend() == "jax" and not self.run_eagerly: import jax - @jax.jit - def compiled_generate_function(inputs, end_token_id, state): + @partial(jax.jit, static_argnames=["stop_token_ids"]) + def compiled_generate_function(inputs, stop_token_ids, state): ( sampler_variables, trainable_variables, @@ -94,39 +133,39 @@ def compiled_generate_function(inputs, end_token_id, state): ) with keras.StatelessScope(state_mapping=mapping) as scope: - outputs = self.generate_step(inputs, end_token_id) + outputs = self.generate_step(inputs, stop_token_ids) # Get updated sampler variables from the stateless scope. 
sampler_variables = [] for v in self._sampler.variables: new_v = scope.get_current_value(v) sampler_variables.append(new_v if new_v is not None else v) - state = ( - sampler_variables, - trainable_variables, - non_trainable_variables, - ) - return outputs, state + return outputs, sampler_variables def wrapped_generate_function( inputs, - end_token_id=None, + stop_token_ids=None, ): + if isinstance(stop_token_ids, list): + stop_token_ids = tuple(stop_token_ids) + # Create an explicit tuple of all variable state. state = ( self._sampler.variables, - self.trainable_variables, - self.non_trainable_variables, + # Use the explicit variable.value to preserve the + # sharding spec of distribution. + [v.value for v in self.trainable_variables], + [v.value for v in self.non_trainable_variables], ) inputs = tree.map_structure(ops.convert_to_tensor, inputs) - outputs, state = compiled_generate_function( + outputs, sampler_variables = compiled_generate_function( inputs, - end_token_id, + stop_token_ids, state, ) # Only assign the sampler variables (random seeds), as other # model variables should never be updated in generation. - for ref_v, v in zip(self._sampler.variables, state[0]): + for ref_v, v in zip(self._sampler.variables, sampler_variables): ref_v.assign(v) return outputs @@ -140,7 +179,7 @@ def _normalize_generate_inputs( ): """Normalize user input to the generate function. - This function coverts all inputs to tensors, adds a batch dimension if + This function converts all inputs to tensors, adds a batch dimension if necessary, and returns a iterable "dataset like" object (either an actual `tf.data.Dataset` or a list with a single batch element). """ @@ -209,6 +248,7 @@ def generate( self, inputs, max_length=None, + stop_token_ids="auto", ): """Generate text given prompt `inputs`. @@ -237,15 +277,30 @@ def generate( `preprocessor`. If `preprocessor` is `None`, `inputs` should be should be padded to the desired maximum length and this argument will be ignored. + stop_token_ids: Optional. `None`, `"auto"`, or a tuple of token ids. + Defaults to `"auto"`, which uses `preprocessor.tokenizer.end_token_id`; + not specifying a preprocessor will produce an error in that case. + `None` stops generation only after `max_length` tokens. You may also + specify a tuple of token ids the model should stop on. Note that each + id is interpreted as a separate stop token; multi-token stop + sequences are not supported. """ # Setup our three main passes. # 1. Optionally preprocessing strings to dense integer tensors. # 2. Generate new tokens via a compiled function on dense tensors. # 3. Optionally postprocess dense integer tensors back to string. generate_function = self.make_generate_function() - end_token_id = None - if self.preprocessor is not None: - end_token_id = self.preprocessor.tokenizer.end_token_id + + if self.preprocessor is None and stop_token_ids == "auto": + raise ValueError( + 'A `preprocessor` must be attached to the model if `stop_token_ids="auto"`. ' + "Currently `preprocessor=None`. To call `generate()` with preprocessing " + "detached, either pass `stop_token_ids=None` to always generate until " + "`max_length` or pass a tuple of token ids that should terminate generation " + "as `stop_token_ids`."
+            )
+        elif stop_token_ids == "auto":
+            stop_token_ids = [self.preprocessor.tokenizer.end_token_id]

        def preprocess(x):
            return self.preprocessor.generate_preprocess(
                x, sequence_length=max_length
            )

        def generate(x):
-            return generate_function(x, end_token_id=end_token_id)
+            return generate_function(x, stop_token_ids=stop_token_ids)

        def postprocess(x):
            return self.preprocessor.generate_postprocess(x)
diff --git a/keras_nlp/models/classifier.py b/keras_nlp/models/classifier.py
new file mode 100644
index 0000000000..f6c6a88720
--- /dev/null
+++ b/keras_nlp/models/classifier.py
@@ -0,0 +1,51 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.task import Task
+
+
+@keras_nlp_export("keras_nlp.models.Classifier")
+class Classifier(Task):
+    """Base class for all classification tasks.
+
+    `Classifier` tasks wrap a `keras_nlp.models.Backbone` and
+    a `keras_nlp.models.Preprocessor` to create a model that can be used for
+    sequence classification. `Classifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is a string and `y` is an integer from `[0, num_classes)`.
+
+    All `Classifier` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+    ```python
+    # Load a BERT classifier with pre-trained weights.
+    classifier = keras_nlp.models.Classifier.from_preset(
+        "bert_base_en",
+        num_classes=2,
+    )
+    # Fine-tune on IMDb movie reviews (or any dataset).
+    imdb_train, imdb_test = tfds.load(
+        "imdb_reviews",
+        split=["train", "test"],
+        as_supervised=True,
+        batch_size=16,
+    )
+    classifier.fit(imdb_train, validation_data=imdb_test)
+    # Predict two new examples.
+    classifier.predict(["What an amazing movie!", "A total waste of my time."])
+    ```
+    """
diff --git a/keras_nlp/models/deberta_v3/__init__.py b/keras_nlp/models/deberta_v3/__init__.py
index ba0c2545e4..4a5f6df5cd 100644
--- a/keras_nlp/models/deberta_v3/__init__.py
+++ b/keras_nlp/models/deberta_v3/__init__.py
@@ -11,3 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
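The hunks above replace the single `end_token_id` with the more general `stop_token_ids`. A minimal sketch of the three calling conventions, assuming the `gpt2_base_en` preset shown in the `CausalLM` docstring:

```python
import keras_nlp

causal_lm = keras_nlp.models.CausalLM.from_preset("gpt2_base_en")

# Default ("auto"): stop at `preprocessor.tokenizer.end_token_id`.
causal_lm.generate("Keras is a", max_length=64)

# Never stop early; always generate `max_length` tokens.
causal_lm.generate("Keras is a", max_length=64, stop_token_ids=None)

# Stop on an explicit set of single-token ids. 50256 is GPT-2's
# end-of-text id; look up others via `causal_lm.preprocessor.tokenizer`.
causal_lm.generate("Keras is a", max_length=64, stop_token_ids=(50256,))
```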
+ +from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone +from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets +from keras_nlp.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (DebertaV3Backbone, DebertaV3Tokenizer)) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py index e7bd8ca20a..0029013098 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_backbone.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_backbone.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.deberta_v3.disentangled_attention_encoder import ( DisentangledAttentionEncoder, ) from keras_nlp.models.deberta_v3.relative_embedding import RelativeEmbedding -from keras_nlp.utils.python_utils import classproperty def deberta_kernel_initializer(stddev=0.02): @@ -178,6 +175,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) @@ -207,7 +205,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py index d6eea63601..e8cb7a60ed 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_classifier.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_classifier.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone from keras_nlp.models.deberta_v3.deberta_v3_backbone import ( deberta_kernel_initializer, @@ -23,13 +23,10 @@ from keras_nlp.models.deberta_v3.deberta_v3_preprocessor import ( DebertaV3Preprocessor, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Classifier") -class DebertaV3Classifier(Task): +class DebertaV3Classifier(Classifier): """An end-to-end DeBERTa model for classification tasks. 
This model attaches a classification head to a @@ -153,6 +150,9 @@ class DebertaV3Classifier(Task): ``` """ + backbone_cls = DebertaV3Backbone + preprocessor_cls = DebertaV3Preprocessor + def __init__( self, backbone, @@ -234,15 +234,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return DebertaV3Backbone - - @classproperty - def preprocessor_cls(cls): - return DebertaV3Preprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py index d050dde6c0..7bb613b96c 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,13 +23,11 @@ from keras_nlp.models.deberta_v3.deberta_v3_masked_lm_preprocessor import ( DebertaV3MaskedLMPreprocessor, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.DebertaV3MaskedLM") -class DebertaV3MaskedLM(Task): +class DebertaV3MaskedLM(MaskedLM): """An end-to-end DeBERTaV3 model for the masked language modeling task. This model will train DeBERTaV3 on a masked language modeling task. @@ -55,7 +52,7 @@ class DebertaV3MaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python @@ -98,6 +95,9 @@ class DebertaV3MaskedLM(Task): ``` """ + backbone_cls = DebertaV3Backbone + preprocessor_cls = DebertaV3MaskedLMPreprocessor + def __init__( self, backbone, @@ -138,15 +138,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return DebertaV3Backbone - - @classproperty - def preprocessor_cls(cls): - return DebertaV3MaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py index 88fa08fd70..67466ef948 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
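The pattern recurring across these hunks: per-class `@classproperty` accessors are replaced with plain class attributes, and preset dictionaries move out of the model files entirely. A minimal sketch of the new wiring for a hypothetical task subclass (the class name is illustrative, not part of this diff):

```python
from keras_nlp.models.classifier import Classifier
from keras_nlp.models.deberta_v3.deberta_v3_backbone import DebertaV3Backbone
from keras_nlp.models.deberta_v3.deberta_v3_preprocessor import (
    DebertaV3Preprocessor,
)


class MyDebertaV3Classifier(Classifier):
    # Plain class attributes now link a task to its backbone and
    # preprocessor; no `@classproperty` boilerplate is required.
    backbone_cls = DebertaV3Backbone
    preprocessor_cls = DebertaV3Preprocessor
```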
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.models.deberta_v3.deberta_v3_tokenizer import DebertaV3Tokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Preprocessor") @@ -147,6 +144,8 @@ class DebertaV3Preprocessor(Preprocessor): ``` """ + tokenizer_cls = DebertaV3Tokenizer + def __init__( self, tokenizer, @@ -202,11 +201,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return DebertaV3Tokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py index e66c373e65..9c14e3b618 100644 --- a/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py +++ b/keras_nlp/models/deberta_v3/deberta_v3_tokenizer.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.deberta_v3.deberta_v3_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DebertaV3Tokenizer") @@ -151,7 +148,3 @@ def token_to_id(self, token): def detokenize(self, ids): ids = tf.ragged.boolean_mask(ids, tf.not_equal(ids, self.mask_token_id)) return super().detokenize(ids) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/__init__.py b/keras_nlp/models/distil_bert/__init__.py index ba0c2545e4..beb3742c9e 100644 --- a/keras_nlp/models/distil_bert/__init__.py +++ b/keras_nlp/models/distil_bert/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone +from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets +from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (DistilBertBackbone, DistilBertTokenizer)) diff --git a/keras_nlp/models/distil_bert/distil_bert_backbone.py b/keras_nlp/models/distil_bert/distil_bert_backbone.py index 1ae0840ea8..4374707f13 100644 --- a/keras_nlp/models/distil_bert/distil_bert_backbone.py +++ b/keras_nlp/models/distil_bert/distil_bert_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
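With `register_presets()` running at import time in each model's `__init__.py`, preset lookup no longer relies on per-class `presets` properties. A sketch of the consumer side, assuming the registration above has executed on import (the preset name is taken from the KerasNLP preset list):

```python
import keras_nlp

# Presets registered for a class are discoverable on it...
print(list(keras_nlp.models.DistilBertBackbone.presets))

# ...and loadable through `from_preset()` exactly as before.
backbone = keras_nlp.models.DistilBertBackbone.from_preset(
    "distil_bert_base_en_uncased"
)
```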
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def distilbert_kernel_initializer(stddev=0.02): @@ -159,6 +156,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) @@ -186,7 +184,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_classifier.py b/keras_nlp/models/distil_bert/distil_bert_classifier.py index e82aaf2781..f816e40a1b 100644 --- a/keras_nlp/models/distil_bert/distil_bert_classifier.py +++ b/keras_nlp/models/distil_bert/distil_bert_classifier.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone from keras_nlp.models.distil_bert.distil_bert_backbone import ( distilbert_kernel_initializer, @@ -23,13 +23,10 @@ from keras_nlp.models.distil_bert.distil_bert_preprocessor import ( DistilBertPreprocessor, ) -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.DistilBertClassifier") -class DistilBertClassifier(Task): +class DistilBertClassifier(Classifier): """An end-to-end DistilBERT model for classification tasks. This model attaches a classification head to a @@ -140,6 +137,9 @@ class DistilBertClassifier(Task): ``` """ + backbone_cls = DistilBertBackbone + preprocessor_cls = DistilBertPreprocessor + def __init__( self, backbone, @@ -214,15 +214,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return DistilBertBackbone - - @classproperty - def preprocessor_cls(cls): - return DistilBertPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index fcf54e014d..80b4c17bb0 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,13 +23,11 @@ from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( DistilBertMaskedLMPreprocessor, ) -from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.DistilBertMaskedLM") -class DistilBertMaskedLM(Task): +class DistilBertMaskedLM(MaskedLM): """An end-to-end DistilBERT model for the masked language modeling task. This model will train DistilBERT on a masked language modeling task. 
@@ -55,7 +52,7 @@ class DistilBertMaskedLM(Task):
         `None`. If `None`, this model will not apply preprocessing, and
         inputs should be preprocessed before calling the model.
 
-    Example usage:
+    Examples:
 
     Raw string data.
     ```python
@@ -98,6 +95,9 @@ class DistilBertMaskedLM(Task):
     ```
     """
 
+    backbone_cls = DistilBertBackbone
+    preprocessor_cls = DistilBertMaskedLMPreprocessor
+
     def __init__(
         self,
         backbone,
@@ -140,15 +140,3 @@ def __init__(
             weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
             jit_compile=True,
         )
-
-    @classproperty
-    def backbone_cls(cls):
-        return DistilBertBackbone
-
-    @classproperty
-    def preprocessor_cls(cls):
-        return DistilBertMaskedLMPreprocessor
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py
index 63f4e3637b..1163f7f029 100644
--- a/keras_nlp/models/distil_bert/distil_bert_preprocessor.py
+++ b/keras_nlp/models/distil_bert/distil_bert_preprocessor.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.layers.preprocessing.multi_segment_packer import (
     MultiSegmentPacker,
 )
-from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
 from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
     DistilBertTokenizer,
 )
@@ -27,7 +25,6 @@
     convert_inputs_to_list_of_tensor_segments,
 )
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
-from keras_nlp.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.DistilBertPreprocessor")
@@ -118,6 +115,8 @@ class DistilBertPreprocessor(Preprocessor):
     ```
     """
 
+    tokenizer_cls = DistilBertTokenizer
+
     def __init__(
         self,
         tokenizer,
@@ -173,11 +172,3 @@ def sequence_length(self, value):
         self._sequence_length = value
         if self.packer is not None:
             self.packer.sequence_length = value
-
-    @classproperty
-    def tokenizer_cls(cls):
-        return DistilBertTokenizer
-
-    @classproperty
-    def presets(cls):
-        return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
index 4a18398a1e..c03792f761 100644
--- a/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
+++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer.py
@@ -12,12 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 
 from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets
 from keras_nlp.tokenizers.word_piece_tokenizer import WordPieceTokenizer
-from keras_nlp.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.DistilBertTokenizer")
@@ -46,6 +43,9 @@ class DistilBertTokenizer(WordPieceTokenizer):
         plain text file containing a single word piece token per line.
         lowercase: If `True`, the input text will be first lowered before
             tokenization.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings, and tokenize and map them
+            correctly to their ids. Defaults to `False`.
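The new `special_tokens_in_strings` flag defaults to `False`, so `[CLS]`-style markers in raw text are only mapped to their special-token ids when explicitly requested. A small sketch mirroring the tokenizer test added below (the vocabulary is illustrative):

```python
import keras_nlp

vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "the", "quick", "fox"]
tokenizer = keras_nlp.models.DistilBertTokenizer(
    vocabulary=vocab,
    special_tokens_in_strings=True,
)
# "[CLS]" and "[SEP]" in the raw string now map directly to their ids
# instead of being split into word pieces.
tokenizer("[CLS] the quick fox [SEP]")
```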
Examples: @@ -74,6 +74,7 @@ def __init__( self, vocabulary, lowercase=False, + special_tokens_in_strings=False, **kwargs, ): self.cls_token = "[CLS]" @@ -83,6 +84,13 @@ def __init__( super().__init__( vocabulary=vocabulary, lowercase=lowercase, + special_tokens=[ + self.cls_token, + self.sep_token, + self.pad_token, + self.mask_token, + ], + special_tokens_in_strings=special_tokens_in_strings, **kwargs, ) @@ -90,15 +98,6 @@ def set_vocabulary(self, vocabulary): super().set_vocabulary(vocabulary) if vocabulary is not None: - # Check for necessary special tokens. - for token in [self.cls_token, self.pad_token, self.sep_token]: - if token not in self.vocabulary: - raise ValueError( - f"Cannot find token `'{token}'` in the provided " - f"`vocabulary`. Please provide `'{token}'` in your " - "`vocabulary` or use a pretrained `vocabulary` name." - ) - self.cls_token_id = self.token_to_id(self.cls_token) self.sep_token_id = self.token_to_id(self.sep_token) self.pad_token_id = self.token_to_id(self.pad_token) @@ -109,6 +108,7 @@ def set_vocabulary(self, vocabulary): self.pad_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) + def get_config(self): + config = super().get_config() + del config["special_tokens"] # Not configurable; set in __init__. + return config diff --git a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py index e4bfba41d3..42bfde39e5 100644 --- a/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py @@ -41,6 +41,16 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = DistilBertTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): DistilBertTokenizer(vocabulary=["a", "b", "c"]) diff --git a/keras_nlp/models/electra/__init__.py b/keras_nlp/models/electra/__init__.py index ba0c2545e4..0717b97a72 100644 --- a/keras_nlp/models/electra/__init__.py +++ b/keras_nlp/models/electra/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.electra.electra_backbone import ElectraBackbone +from keras_nlp.models.electra.electra_presets import backbone_presets +from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (ElectraBackbone, ElectraTokenizer)) diff --git a/keras_nlp/models/electra/electra_backbone.py b/keras_nlp/models/electra/electra_backbone.py index 13be2d8eb8..3ee88de826 100644 --- a/keras_nlp/models/electra/electra_backbone.py +++ b/keras_nlp/models/electra/electra_backbone.py @@ -36,8 +36,9 @@ class ElectraBackbone(Backbone): or classification task networks. The default constructor gives a fully customizable, randomly initialized - Electra encoder with any number of layers, heads, and embedding - dimensions. + ELECTRA encoder with any number of layers, heads, and embedding + dimensions. 
To load preset architectures and weights, use the
+    `from_preset()` constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a
@@ -63,13 +64,20 @@ class ElectraBackbone(Backbone):
         such as softmax and layer normalization, will always be done at
         float32 precision regardless of dtype.
 
-    Examples:
+    Example:
     ```python
     input_data = {
         "token_ids": np.ones(shape=(1, 12), dtype="int32"),
         "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]]),
         "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
     }
+
+    # Pre-trained ELECTRA encoder.
+    model = keras_nlp.models.ElectraBackbone.from_preset(
+        "electra_base_discriminator_uncased_en"
+    )
+    model(input_data)
+
     # Randomly initialized ELECTRA encoder
     backbone = keras_nlp.models.ElectraBackbone(
         vocabulary_size=1000,
@@ -202,6 +210,7 @@ def __init__(
             "sequence_output": sequence_output,
             "pooled_output": pooled_output,
         },
+        dtype=dtype,
         **kwargs,
     )
diff --git a/keras_nlp/models/electra/electra_backbone_test.py b/keras_nlp/models/electra/electra_backbone_test.py
index 09e6c53344..c51bb579e1 100644
--- a/keras_nlp/models/electra/electra_backbone_test.py
+++ b/keras_nlp/models/electra/electra_backbone_test.py
@@ -54,3 +54,37 @@ def test_saved_model(self):
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
         )
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        self.run_preset_test(
+            cls=ElectraBackbone,
+            preset="electra_small_discriminator_uncased_en",
+            input_data={
+                "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
+                "segment_ids": ops.zeros((1, 4), dtype="int32"),
+                "padding_mask": ops.ones((1, 4), dtype="int32"),
+            },
+            expected_output_shape={
+                "sequence_output": (1, 4, 256),
+                "pooled_output": (1, 256),
+            },
+            # The forward pass from a preset should be stable!
+            expected_partial_output={
+                "sequence_output": (
+                    ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
+                ),
+                "pooled_output": (
+                    ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
+                ),
+            },
+        )
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in ElectraBackbone.presets:
+            self.run_preset_test(
+                cls=ElectraBackbone,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/electra/electra_preprocessor.py b/keras_nlp/models/electra/electra_preprocessor.py
new file mode 100644
index 0000000000..2ee3e294d8
--- /dev/null
+++ b/keras_nlp/models/electra/electra_preprocessor.py
@@ -0,0 +1,153 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
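Note the `dtype=dtype` threaded into the functional `super().__init__()` call here (and in the other backbones in this diff): it lets a whole backbone be constructed at reduced precision. A hedged sketch, reusing the small preset from the test above; passing `dtype` through `from_preset()` follows the pattern shown in the `CausalLM` docstring earlier in this diff:

```python
import keras_nlp

# Weights and compute in bfloat16; softmax and layer norm still run in
# float32, per the docstring above.
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_uncased_en",
    dtype="bfloat16",
)
```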
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.multi_segment_packer import (
+    MultiSegmentPacker,
+)
+from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.ElectraPreprocessor")
+class ElectraPreprocessor(Preprocessor):
+    """An ELECTRA preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`
+       with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
+    3. Construct a dictionary with keys `"token_ids"`, `"segment_ids"` and
+       `"padding_mask"`, that can be passed directly to an ELECTRA model.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        truncate: string. The algorithm to truncate a list of batched segments
+            to fit within `sequence_length`. The value can be either
+            `round_robin` or `waterfall`:
+            - `"round_robin"`: Available space is assigned one token at a
+                time in a round-robin fashion to the inputs that still need
+                some, until the limit is reached.
+            - `"waterfall"`: The allocation of the budget is done using a
+                "waterfall" algorithm that allocates quota in a
+                left-to-right manner and fills up the buckets until we run
+                out of budget. It supports an arbitrary number of segments.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be converted
+            to tensors. For multiple sequences, pass tensors directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
+        "electra_base_discriminator_uncased_en"
+    )
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Custom vocabulary.
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
+    preprocessor = keras_nlp.models.ElectraPreprocessor(tokenizer)
+    preprocessor("The quick brown fox jumped.")
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.ElectraPreprocessor.from_preset(
+        "electra_base_discriminator_uncased_en"
+    )
+
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+ ds = tf.data.Dataset.from_tensor_slices(((first, second), label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + # Map unlabeled sentence pairs. + ds = tf.data.Dataset.from_tensor_slices((first, second)) + + # Watch out for tf.data's default unpacking of tuples here! + # Best to invoke the `preprocessor` directly in this case. + ds = ds.map( + lambda first, second: preprocessor(x=(first, second)), + num_parallel_calls=tf.data.AUTOTUNE, + ) + ``` + """ + + tokenizer_cls = ElectraTokenizer + + def __init__( + self, + tokenizer, + sequence_length=512, + truncate="round_robin", + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = MultiSegmentPacker( + start_value=self.tokenizer.cls_token_id, + end_value=self.tokenizer.sep_token_id, + pad_value=self.tokenizer.pad_token_id, + truncate=truncate, + sequence_length=sequence_length, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.packer.sequence_length, + "truncate": self.packer.truncate, + } + ) + return config + + def call(self, x, y=None, sample_weight=None): + x = convert_inputs_to_list_of_tensor_segments(x) + x = [self.tokenizer(segment) for segment in x] + token_ids, segment_ids = self.packer(x) + x = { + "token_ids": token_ids, + "segment_ids": segment_ids, + "padding_mask": token_ids != self.tokenizer.pad_token_id, + } + return pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_nlp/models/electra/electra_preprocessor_test.py b/keras_nlp/models/electra/electra_preprocessor_test.py new file mode 100644 index 0000000000..5dd48fe40a --- /dev/null +++ b/keras_nlp/models/electra/electra_preprocessor_test.py @@ -0,0 +1,67 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.electra.electra_preprocessor import ElectraPreprocessor +from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer +from keras_nlp.tests.test_case import TestCase + + +class ElectraPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + self.tokenizer = ElectraTokenizer(vocabulary=self.vocab) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["THE QUICK BROWN FOX."], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=ElectraPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[2, 5, 6, 7, 8, 1, 3, 0]], + "segment_ids": [[0, 0, 0, 0, 0, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. 
+ ), + ) + + def test_errors_for_2d_list_input(self): + preprocessor = ElectraPreprocessor(**self.init_kwargs) + ambiguous_input = [["one", "two"], ["three", "four"]] + with self.assertRaises(ValueError): + preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ElectraPreprocessor.presets: + self.run_preset_test( + cls=ElectraPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/electra/electra_presets.py b/keras_nlp/models/electra/electra_presets.py new file mode 100644 index 0000000000..68d709b1fd --- /dev/null +++ b/keras_nlp/models/electra/electra_presets.py @@ -0,0 +1,95 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ELECTRA model preset configurations.""" + +backbone_presets = { + "electra_small_discriminator_uncased_en": { + "metadata": { + "description": ( + "12-layer small ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 13548800, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_small_discriminator_uncased_en/1", + }, + "electra_small_generator_uncased_en": { + "metadata": { + "description": ( + "12-layer small ELECTRA generator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 13548800, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_small_generator_uncased_en/1", + }, + "electra_base_discriminator_uncased_en": { + "metadata": { + "description": ( + "12-layer base ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 109482240, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_base_discriminator_uncased_en/1", + }, + "electra_base_generator_uncased_en": { + "metadata": { + "description": ( + "12-layer base ELECTRA generator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." + ), + "params": 33576960, + "official_name": "ELECTRA", + "path": "electra", + "model_card": "https://github.com/google-research/electra", + }, + "kaggle_handle": "kaggle://keras/electra/keras/electra_base_generator_uncased_en/1", + }, + "electra_large_discriminator_uncased_en": { + "metadata": { + "description": ( + "24-layer large ELECTRA discriminator model. All inputs are " + "lowercased. Trained on English Wikipedia + BooksCorpus." 
+            ),
+            "params": 335141888,
+            "official_name": "ELECTRA",
+            "path": "electra",
+            "model_card": "https://github.com/google-research/electra",
+        },
+        "kaggle_handle": "kaggle://keras/electra/keras/electra_large_discriminator_uncased_en/1",
+    },
+    "electra_large_generator_uncased_en": {
+        "metadata": {
+            "description": (
+                "24-layer large ELECTRA generator model. All inputs are "
+                "lowercased. Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 51065344,
+            "official_name": "ELECTRA",
+            "path": "electra",
+            "model_card": "https://github.com/google-research/electra",
+        },
+        "kaggle_handle": "kaggle://keras/electra/keras/electra_large_generator_uncased_en/1",
+    },
+}
diff --git a/keras_nlp/models/electra/electra_tokenizer.py b/keras_nlp/models/electra/electra_tokenizer.py
index acd665c2a3..583b756165 100644
--- a/keras_nlp/models/electra/electra_tokenizer.py
+++ b/keras_nlp/models/electra/electra_tokenizer.py
@@ -36,6 +36,9 @@ class ElectraTokenizer(WordPieceTokenizer):
         plain text file containing a single word piece token per line.
         lowercase: If `True`, the input text will be first lowered before
             tokenization.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings, and tokenize and map them
+            correctly to their ids. Defaults to `False`.
 
     Examples:
     ```python
@@ -57,26 +60,34 @@ class ElectraTokenizer(WordPieceTokenizer):
     ```
     """
 
-    def __init__(self, vocabulary, lowercase=False, **kwargs):
+    def __init__(
+        self,
+        vocabulary,
+        lowercase=False,
+        special_tokens_in_strings=False,
+        **kwargs,
+    ):
         self.cls_token = "[CLS]"
         self.sep_token = "[SEP]"
         self.pad_token = "[PAD]"
         self.mask_token = "[MASK]"
-        super().__init__(vocabulary=vocabulary, lowercase=lowercase, **kwargs)
+        super().__init__(
+            vocabulary=vocabulary,
+            lowercase=lowercase,
+            special_tokens=[
+                self.cls_token,
+                self.sep_token,
+                self.pad_token,
+                self.mask_token,
+            ],
+            special_tokens_in_strings=special_tokens_in_strings,
+            **kwargs,
+        )
 
     def set_vocabulary(self, vocabulary):
         super().set_vocabulary(vocabulary)
 
         if vocabulary is not None:
-            # Check for necessary special tokens.
-            for token in [self.cls_token, self.pad_token, self.sep_token]:
-                if token not in self.vocabulary:
-                    raise ValueError(
-                        f"Cannot find token `'{token}'` in the provided "
-                        f"`vocabulary`. Please provide `'{token}'` in your "
-                        "`vocabulary` or use a pretrained `vocabulary` name."
-                    )
-
             self.cls_token_id = self.token_to_id(self.cls_token)
             self.sep_token_id = self.token_to_id(self.sep_token)
             self.pad_token_id = self.token_to_id(self.pad_token)
@@ -86,3 +97,8 @@ def set_vocabulary(self, vocabulary):
             self.sep_token_id = None
             self.pad_token_id = None
             self.mask_token_id = None
+
+    def get_config(self):
+        config = super().get_config()
+        del config["special_tokens"]  # Not configurable; set in __init__.
+        return config
diff --git a/keras_nlp/models/electra/electra_tokenizer_test.py b/keras_nlp/models/electra/electra_tokenizer_test.py
index 2e06fb900c..9126ddc8a7 100644
--- a/keras_nlp/models/electra/electra_tokenizer_test.py
+++ b/keras_nlp/models/electra/electra_tokenizer_test.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
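The explicit token check deleted from `set_vocabulary()` is not lost: the `special_tokens` list now passed to the `WordPieceTokenizer` base class performs the same validation, which is why the untouched test below still expects a `ValueError`:

```python
import keras_nlp

# Still raises a ValueError: the required "[CLS]", "[SEP]", "[PAD]" and
# "[MASK]" tokens are missing from this vocabulary.
keras_nlp.models.ElectraTokenizer(vocabulary=["a", "b", "c"])
```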
+import pytest + from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer from keras_nlp.tests.test_case import TestCase @@ -37,6 +39,34 @@ def test_lowercase(self): output = tokenizer(self.input_data) self.assertAllEqual(output, [[9, 10, 11, 12], [9, 12]]) + def test_tokenizer_special_tokens(self): + input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"] + tokenizer = ElectraTokenizer( + **self.init_kwargs, special_tokens_in_strings=True + ) + output_data = tokenizer(input_data) + expected_output = [[2, 5, 4, 8, 3, 0]] + + self.assertAllEqual(output_data, expected_output) + def test_errors_missing_special_tokens(self): with self.assertRaises(ValueError): ElectraTokenizer(vocabulary=["a", "b", "c"]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=ElectraTokenizer, + preset="electra_small_discriminator_uncased_en", + input_data=["the quick brown fox."], + expected_output=[[1996, 4248, 2829, 4419, 1012]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in ElectraTokenizer.presets: + self.run_preset_test( + cls=ElectraTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/f_net/__init__.py b/keras_nlp/models/f_net/__init__.py index ba0c2545e4..7921ed6ca3 100644 --- a/keras_nlp/models/f_net/__init__.py +++ b/keras_nlp/models/f_net/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_presets import backbone_presets +from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (FNetBackbone, FNetTokenizer)) diff --git a/keras_nlp/models/f_net/f_net_backbone.py b/keras_nlp/models/f_net/f_net_backbone.py index 309f312a17..91034a4359 100644 --- a/keras_nlp/models/f_net/f_net_backbone.py +++ b/keras_nlp/models/f_net/f_net_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -20,9 +19,7 @@ from keras_nlp.layers.modeling.position_embedding import PositionEmbedding from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def f_net_kernel_initializer(stddev=0.02): @@ -206,6 +203,7 @@ def __init__( "sequence_output": sequence_output, "pooled_output": pooled_output, }, + dtype=dtype, **kwargs, ) @@ -233,7 +231,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index 512182d2cd..a5c0bf6525 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.f_net.f_net_backbone import FNetBackbone from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor -from keras_nlp.models.f_net.f_net_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetClassifier") -class FNetClassifier(Task): +class FNetClassifier(Classifier): """An end-to-end f_net model for classification tasks. This model attaches a classification head to a @@ -100,6 +97,9 @@ class FNetClassifier(Task): ``` """ + backbone_cls = FNetBackbone + preprocessor_cls = FNetPreprocessor + def __init__( self, backbone, @@ -162,15 +162,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return FNetBackbone - - @classproperty - def preprocessor_cls(cls): - return FNetPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index c715a70843..83c7e62719 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -22,13 +21,11 @@ from keras_nlp.models.f_net.f_net_masked_lm_preprocessor import ( FNetMaskedLMPreprocessor, ) -from keras_nlp.models.f_net.f_net_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.models.masked_lm import MaskedLM @keras_nlp_export("keras_nlp.models.FNetMaskedLM") -class FNetMaskedLM(Task): +class FNetMaskedLM(MaskedLM): """An end-to-end FNet model for the masked language modeling task. This model will train FNet on a masked language modeling task. @@ -51,7 +48,7 @@ class FNetMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string data. ```python @@ -95,6 +92,9 @@ class FNetMaskedLM(Task): ``` """ + backbone_cls = FNetBackbone + preprocessor_cls = FNetMaskedLMPreprocessor + def __init__( self, backbone, @@ -137,15 +137,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return FNetBackbone - - @classproperty - def preprocessor_cls(cls): - return FNetMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_preprocessor.py b/keras_nlp/models/f_net/f_net_preprocessor.py index b4cb5836bb..dfe1b71cfe 100644 --- a/keras_nlp/models/f_net/f_net_preprocessor.py +++ b/keras_nlp/models/f_net/f_net_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
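Re-parenting `FNetClassifier` under `Classifier` (and `FNetMaskedLM` under `MaskedLM`) means the generic base classes can now construct model-specific tasks from presets. A sketch, assuming the `f_net_base_en` preset from the KerasNLP preset list:

```python
import keras_nlp

# Dispatches on the preset and returns an `FNetClassifier` instance.
classifier = keras_nlp.models.Classifier.from_preset(
    "f_net_base_en",
    num_classes=2,
)
```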
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetPreprocessor") @@ -120,6 +117,8 @@ class FNetPreprocessor(Preprocessor): ``` """ + tokenizer_cls = FNetTokenizer + def __init__( self, tokenizer, @@ -175,11 +174,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return FNetTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_tokenizer.py b/keras_nlp/models/f_net/f_net_tokenizer.py index ae3f569b1d..12a055b16c 100644 --- a/keras_nlp/models/f_net/f_net_tokenizer.py +++ b/keras_nlp/models/f_net/f_net_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.FNetTokenizer") @@ -94,7 +91,3 @@ def set_proto(self, proto): self.sep_token_id = None self.pad_token_id = None self.mask_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/falcon/__init__.py b/keras_nlp/models/falcon/__init__.py new file mode 100644 index 0000000000..cfc0b821cb --- /dev/null +++ b/keras_nlp/models/falcon/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.models.falcon.falcon_presets import backbone_presets +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (FalconBackbone, FalconTokenizer)) diff --git a/keras_nlp/models/falcon/falcon_attention.py b/keras_nlp/models/falcon/falcon_attention.py new file mode 100644 index 0000000000..0358ade54b --- /dev/null +++ b/keras_nlp/models/falcon/falcon_attention.py @@ -0,0 +1,156 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class FalconAttention(keras.layers.Layer): + def __init__( + self, + num_heads, + attention_dropout_rate, + **kwargs, + ): + super().__init__(**kwargs) + self.num_heads = num_heads + self.attention_dropout_rate = attention_dropout_rate + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # m = model dim + # n = num attention heads + # h = head dim + # k = key/value length + + batch_size, seq_length, hidden_dim = inputs_shape + + self.head_dim = hidden_dim // self.num_heads + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.query_dense = keras.layers.EinsumDense( + equation="bqm,mnh->bqnh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="query_dense", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="key_dense", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + equation="bkm,mnh->bknh", + output_shape=(None, self.num_heads, self.head_dim), + bias_axes="nh", + dtype=self.dtype_policy, + name="value_dense", + ) + self.value_dense.build(inputs_shape) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.output_dense = keras.layers.Dense( + hidden_dim, + dtype=self.dtype_policy, + name="output_dense", + ) + self.output_dense.build(inputs_shape) + + self.softmax = keras.layers.Softmax(dtype="float32", name="softmax") + + self.built = True + + def call( + self, + inputs, + alibi, + attention_mask=None, + cache=None, + cache_update_index=None, + ): + batch_size, seq_length, hidden_dim = ops.shape(inputs) + + query = self.query_dense(inputs) + key = self.key_dense(inputs) + value = self.value_dense(inputs) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key) + value = ops.slice_update(value_cache, start, value) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. 
Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+
+        attention_scores = ops.einsum("bqnh,bknh->bnqk", query, key)
+        attention_scores = ops.add(attention_scores, alibi)
+        attention_scores = (
+            attention_scores * self.inv_norm_factor
+        )  # [batch_size, num_heads, query_length, kv_length]
+        attention_scores = self.softmax(
+            attention_scores, ops.expand_dims(attention_mask, 1)
+        )
+        attention_scores = self.attention_dropout(attention_scores)
+        attention_output = ops.einsum(
+            "bnqk,bknh->bqnh", attention_scores, value
+        )
+        attention_output = ops.reshape(
+            attention_output,
+            [batch_size, seq_length, self.num_heads * self.head_dim],
+        )  # [batch_size, query_length, hidden_dim]
+
+        attention_output = self.output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "attention_dropout_rate": self.attention_dropout_rate,
+            }
+        )
+        return config
diff --git a/keras_nlp/models/falcon/falcon_backbone.py b/keras_nlp/models/falcon/falcon_backbone.py
new file mode 100644
index 0000000000..5a3a0fccda
--- /dev/null
+++ b/keras_nlp/models/falcon/falcon_backbone.py
@@ -0,0 +1,161 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import keras
+from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
+from keras_nlp.models.backbone import Backbone
+from keras_nlp.models.falcon.falcon_transformer_decoder import (
+    FalconTransformerDecoder,
+)
+
+
+@keras_nlp_export("keras_nlp.models.FalconBackbone")
+class FalconBackbone(Backbone):
+    """The Falcon core architecture.
+
+    This network implements a Transformer-based decoder-only network,
+    [Falcon](https://arxiv.org/abs/2306.01116).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_attention_heads: int. The number of attention heads for each
+            transformer. The hidden size must be divisible by the number of
+            attention heads.
+        hidden_dim: int. The dimensionality of the embeddings and hidden states.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the MLP network of each transformer.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers in
+            the transformer decoder.
+        attention_dropout_rate: float. Dropout probability for the attention.
+        feedforward_dropout_rate: float. Dropout probability for the
+            feedforward network.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+ + Examples: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Falcon decoder. + # TODO: Update the preset. + model = keras_nlp.models.FalconBackbone.from_preset("falcon_preset") + model(input_data) + + # Randomly initialized Falcon decoder with a custom config. + model = keras_nlp.models.FalconBackbone( + vocabulary_size=10, + num_layers=2, + num_attention_heads=2, + hidden_dim=32, + intermediate_dim=32*4, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype="float32", + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_attention_heads, + hidden_dim, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + dtype=dtype, + name="token_embedding", + ) + + self.transformer_layers = [] + for i in range(num_layers): + layer = FalconTransformerDecoder( + num_attention_heads=num_attention_heads, + intermediate_dim=intermediate_dim, + attention_dropout_rate=attention_dropout_rate, + feedforward_dropout_rate=feedforward_dropout_rate, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + self.final_layernorm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_layernorm", + ) + + # === Functional Model === + token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + # Embed Tokens. + x = self.token_embedding(token_ids) + + # Apply successive transformer decoder blocks. + for transformer_layer in self.transformer_layers: + x = transformer_layer(inputs=x, decoder_padding_mask=padding_mask) + sequence_output = self.final_layernorm(x) + + super().__init__( + inputs={ + "token_ids": token_ids, + "padding_mask": padding_mask, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_nlp/models/falcon/falcon_backbone_test.py b/keras_nlp/models/falcon/falcon_backbone_test.py new file mode 100644 index 0000000000..140ce7e7bf --- /dev/null +++ b/keras_nlp/models/falcon/falcon_backbone_test.py @@ -0,0 +1,49 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.falcon.falcon_backbone import FalconBackbone +from keras_nlp.tests.test_case import TestCase + + +class FalconBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_attention_heads": 8, + "hidden_dim": 16, + "intermediate_dim": 32, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=FalconBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py new file mode 100644 index 0000000000..61afb9b5a7 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor.py @@ -0,0 +1,178 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import ops +from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.FalconCausalLMPreprocessor") +class FalconCausalLMPreprocessor(FalconPreprocessor): + """Falcon Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.FalconCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.FalconCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.FalconTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. 
+ add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.FalconCausalLMPreprocessor.from_preset( + "falcon_refinedweb_1b_en" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. + preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabeled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`FalconCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + input prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation.
+ + This method reverses `generate_preprocess()` by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + padding_mask = ops.convert_to_numpy(padding_mask) + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. + padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..5e812259e2 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_causal_lm_preprocessor_test.py @@ -0,0 +1,94 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.falcon.falcon_causal_lm_preprocessor import ( + FalconCausalLMPreprocessor, +) +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FalconCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=FalconCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = FalconCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "airplane at airport" + preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [6, 1, 3, 4, 2, 5, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "airplane at airport") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in FalconCausalLMPreprocessor.presets: + self.run_preset_test( + cls=FalconCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_preprocessor.py b/keras_nlp/models/falcon/falcon_preprocessor.py new file mode 100644 index 0000000000..8a14f3c255 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_preprocessor.py @@ -0,0 +1,186 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.FalconPreprocessor") +class FalconPreprocessor(Preprocessor): + """Falcon preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.FalconBackbone`. + + This layer can be used directly with `tf.data.Dataset.map` to preprocess + string data in the `(x, y, sample_weight)` format used by + `keras.Model.fit`. + + The call method of this layer accepts three arguments, `x`, `y`, and + `sample_weight`. `x` can be a python string or tensor representing a single + segment, a list of python strings representing a batch of single segments, + or a list of tensors representing multiple segments to be packed together. 
+ `y` and `sample_weight` are both optional, can have any format, and will be + passed through unaltered. + + `FalconPreprocessor` forces the input to have only one segment, as Falcon is + mainly used for generation tasks. For tasks having multi-segment inputs + like "glue/mnli", please use a model designed for classification purposes + such as BERT or RoBERTa. + + Args: + tokenizer: A `keras_nlp.models.FalconTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Any label data. Will be passed through unaltered. + sample_weight: Any label weight data. Will be passed through unaltered. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + + Directly calling the layer on data. + ```python + preprocessor = keras_nlp.models.FalconPreprocessor.from_preset("falcon_refinedweb_1b_en") + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of single sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Custom vocabulary. + features = ["a quick fox.", "a fox quick."] + vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = keras_nlp.models.FalconTokenizer( + vocabulary=vocab, + merges=merges, + ) + preprocessor = keras_nlp.models.FalconPreprocessor(tokenizer=tokenizer) + preprocessor("The quick brown fox jumped.") + ``` + + Mapping with `tf.data.Dataset`. + ```python + preprocessor = keras_nlp.models.FalconPreprocessor.from_preset("falcon_refinedweb_1b_en") + + text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."]) + label = tf.constant([1, 1]) + + # Map labeled single sentences. + ds = tf.data.Dataset.from_tensor_slices((text, label)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map unlabeled single sentences. + ds = tf.data.Dataset.from_tensor_slices(text) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + tokenizer_cls = FalconTokenizer + + def __init__( + self, + tokenizer, + sequence_length=2048, + add_start_token=True, + add_end_token=True, + **kwargs, + ): + super().__init__(**kwargs) + self.tokenizer = tokenizer + self.packer = None + self.sequence_length = sequence_length + self.add_start_token = add_start_token + self.add_end_token = add_end_token + + def build(self, input_shape): + # Defer packer creation to `build()` so that we can be sure tokenizer + # assets have loaded when restoring a saved model. + self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "Falcon requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using Falcon " + "for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa."
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/models/falcon/falcon_preprocessor_test.py b/keras_nlp/models/falcon/falcon_preprocessor_test.py new file mode 100644 index 0000000000..7676062287 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_preprocessor_test.py @@ -0,0 +1,80 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from keras_nlp.models.falcon.falcon_preprocessor import FalconPreprocessor +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FalconPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["airplane at airport"] + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=FalconPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["airplane at airport"] * 4 + + preprocessor = FalconPreprocessor( + tokenizer=FalconTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ), + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "airplane at airport" + preprocessor = FalconPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 6]) +
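Stepping back to the `(x, y, sample_weight)` split that `FalconCausalLMPreprocessor.call` builds earlier in this diff: the next-token shift is easy to sanity-check in isolation. A minimal NumPy sketch (toy ids chosen to mirror the test fixtures above, not real tokenizer output):
```python
import numpy as np

# Ids packed to sequence_length + 1: [start, tokens..., end, pad].
token_ids = np.array([6, 1, 3, 4, 2, 5, 6, 0, 0])
padding_mask = np.array([1, 1, 1, 1, 1, 1, 1, 0, 0])

# Inputs drop the last position, labels drop the first, so y[i] is the
# token that should follow x["token_ids"][i]; pad positions get weight 0.
x = {"token_ids": token_ids[:-1], "padding_mask": padding_mask[:-1]}
y, sample_weight = token_ids[1:], padding_mask[1:]
```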
+ @pytest.mark.extra_large + def test_all_presets(self): + for preset in FalconPreprocessor.presets: + self.run_preset_test( + cls=FalconPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_presets.py b/keras_nlp/models/falcon/falcon_presets.py new file mode 100644 index 0000000000..b0bb6aa54e --- /dev/null +++ b/keras_nlp/models/falcon/falcon_presets.py @@ -0,0 +1,30 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon model preset configurations.""" + +backbone_presets = { + "falcon_refinedweb_1b_en": { + "metadata": { + "description": ( + "24-layer Falcon model (Falcon with 1B parameters), trained on " + "350B tokens of the RefinedWeb dataset." + ), + "params": 1311625216, + "official_name": "Falcon", + "path": "falcon", + "model_card": "https://huggingface.co/tiiuae/falcon-rw-1b", + }, + "kaggle_handle": "kaggle://keras/falcon/keras/falcon_refinedweb_1b_en/1", + }, +} diff --git a/keras_nlp/models/falcon/falcon_tokenizer.py b/keras_nlp/models/falcon/falcon_tokenizer.py new file mode 100644 index 0000000000..80d7334fe7 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_tokenizer.py @@ -0,0 +1,110 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_nlp_export("keras_nlp.models.FalconTokenizer") +class FalconTokenizer(BytePairTokenizer): + """Falcon tokenizer based on BytePairTokenizer. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.BytePairTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by Falcon + models and provide a `from_preset()` method to automatically download + a matching vocabulary for a Falcon preset. + + This tokenizer does not provide truncation or padding of inputs. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + vocabulary: string or dict, maps token to integer ids. If it is a + string, it should be the file path to a json file. + merges: string or list, contains the merge rules. If it is a string, + it should be the file path to merge rules.
The merge rule file + should have one merge rule per line. Every merge rule contains + merge entities separated by a space. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.FalconTokenizer.from_preset("falcon_refinedweb_1b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. + vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6} + merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"] + merges += ["Ġ f", "o x", "Ġf ox"] + tokenizer = keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges) + tokenizer("a quick fox.") + ``` + """ + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Falcon uses "<|endoftext|>" as both its start and end token. + self.end_token = self.start_token = "<|endoftext|>" + + super().__init__( + vocabulary=vocabulary, + merges=merges, + unsplittable_tokens=[self.end_token], + **kwargs, + ) + + def set_vocabulary_and_merges(self, vocabulary, merges): + super().set_vocabulary_and_merges(vocabulary, merges) + + if vocabulary is not None: + # Check for necessary special tokens. + if self.end_token not in self.get_vocabulary(): + raise ValueError( + f"Cannot find token `'{self.end_token}'` in the provided " + f"`vocabulary`. Please provide `'{self.end_token}'` in " + "your `vocabulary` or use a pretrained `vocabulary` name." + ) + + self.end_token_id = self.token_to_id(self.end_token) + self.start_token_id = self.end_token_id + self.pad_token_id = 0 + else: + self.end_token_id = None + self.start_token_id = None + self.pad_token_id = None + + def get_config(self): + config = super().get_config() + # In the constructor, we pass the list of special tokens to the + # `unsplittable_tokens` arg of the superclass' constructor. Hence, we + # delete it from the config here. + del config["unsplittable_tokens"] + return config diff --git a/keras_nlp/models/falcon/falcon_tokenizer_test.py b/keras_nlp/models/falcon/falcon_tokenizer_test.py new file mode 100644 index 0000000000..735bcac4b6 --- /dev/null +++ b/keras_nlp/models/falcon/falcon_tokenizer_test.py @@ -0,0 +1,62 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
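As a side note on the merge-rule format documented in the tokenizer docstring above, here is a rough illustrative sketch of how a merge list collapses characters into subwords. It is not the `BytePairTokenizer` algorithm (which repeatedly merges the highest-priority adjacent pair), but it produces the same result on this toy input:
```python
def apply_merges(word, merges):
    # Naively apply each merge rule, in order, over the current pieces.
    parts = list(word)
    for rule in merges:
        a, b = rule.split(" ")
        i = 0
        while i < len(parts) - 1:
            if parts[i] == a and parts[i + 1] == b:
                parts[i : i + 2] = [a + b]
            else:
                i += 1
    return parts

# A subset of the toy merges used in the tests below:
# p l -> pl, n e -> ne, pl a -> pla, pla ne -> plane.
print(apply_merges("plane", ["p l", "n e", "pl a", "pla ne"]))  # ['plane']
```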
+import pytest + +from keras_nlp.models.falcon.falcon_tokenizer import FalconTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FalconTokenizerTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = [ + " airplane at airport<|endoftext|>", + " airplane airport", + ] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=FalconTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + FalconTokenizer(vocabulary=["a", "b", "c"], merges=[]) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=FalconTokenizer, + preset="falcon_refinedweb_1b_en", + input_data=["The quick brown fox."], + expected_output=[[464, 2068, 7586, 21831, 13]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in FalconTokenizer.presets: + self.run_preset_test( + cls=FalconTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/falcon/falcon_transformer_decoder.py b/keras_nlp/models/falcon/falcon_transformer_decoder.py new file mode 100644 index 0000000000..3b29cedd7e --- /dev/null +++ b/keras_nlp/models/falcon/falcon_transformer_decoder.py @@ -0,0 +1,254 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.falcon.falcon_attention import FalconAttention + + +class FalconTransformerDecoder(keras.layers.Layer): + def __init__( + self, + num_attention_heads, + intermediate_dim, + layer_norm_epsilon=1e-5, + attention_dropout_rate=0, + feedforward_dropout_rate=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_attention_heads = num_attention_heads + self.intermediate_dim = intermediate_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.attention_dropout_rate = attention_dropout_rate + self.feedforward_dropout_rate = feedforward_dropout_rate + + def build(self, decoder_sequence_shape): + self.hidden_dim = decoder_sequence_shape[-1] + self.input_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="input_layernorm", + ) + self.input_layernorm.build(decoder_sequence_shape) + + # Attention layers.
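+ # Falcon applies ALiBi position biases instead of positional
+ # embeddings; the bias tensor is built in `call` below (see
+ # `_build_alibi_tensor`) and passed into `FalconAttention`.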
+ self.key_dim = self.hidden_dim // self.num_attention_heads + self.attention_layer = FalconAttention( + num_heads=self.num_attention_heads, + attention_dropout_rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention", + ) + self.attention_layer.build( + decoder_sequence_shape, + ) + + self.attention_dropout = keras.layers.Dropout( + rate=self.attention_dropout_rate, + dtype=self.dtype_policy, + name="attention_dropout", + ) + + self.post_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="post_attention_layernorm", + ) + self.post_attention_layernorm.build(decoder_sequence_shape) + + # Feedforward layers. + # TODO: use_bias should be an argument to the transformer to support + # other sizes of models, e.g. 7B, that don't use bias. + self.dense_h_to_4h = keras.layers.Dense( + self.intermediate_dim, + activation=keras.activations.gelu, + use_bias=True, + dtype=self.dtype_policy, + name="dense_h_to_4h", + ) + self.dense_h_to_4h.build(decoder_sequence_shape) + + self.dense_4h_to_h = keras.layers.Dense( + self.hidden_dim, + use_bias=True, + dtype=self.dtype_policy, + name="dense_4h_to_h", + ) + self.dense_4h_to_h.build( + ( + decoder_sequence_shape[0], + decoder_sequence_shape[1], + self.intermediate_dim, + ) + ) + + self.feedforward_dropout = keras.layers.Dropout( + rate=self.feedforward_dropout_rate, + dtype=self.dtype_policy, + name="feedforward_dropout", + ) + + self.built = True + + def call( + self, + inputs, + decoder_padding_mask=None, + decoder_attention_mask=None, + attention_cache=None, + attention_cache_update_index=None, + training=None, + ): + attention_mask = self._compute_attention_mask( + decoder_sequence=inputs, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + attention_cache=attention_cache, + attention_cache_update_index=attention_cache_update_index, + ) + + residual = inputs + + x = self.input_layernorm(inputs) + + alibi = self._build_alibi_tensor( + self.num_attention_heads, decoder_padding_mask + ) + + # Attention block. 
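+ # `alibi` is added to the raw attention scores inside `FalconAttention`
+ # before scaling and softmax, biasing each head toward nearby tokens.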
+ attention_output = self.attention_layer( + inputs=x, + alibi=alibi, + attention_mask=attention_mask, + cache=attention_cache, + cache_update_index=attention_cache_update_index, + ) + + if attention_cache is None: + x = attention_output + else: + x, attention_cache = attention_output + + x = self.attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self.post_attention_layernorm(x) + + x = self.dense_h_to_4h(x) + x = self.dense_4h_to_h(x) + + x = self.feedforward_dropout(x, training=training) + + x = x + residual + + if attention_cache is not None: + return x, attention_cache + else: + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_attention_heads": self.num_attention_heads, + "intermediate_dim": self.intermediate_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "attention_dropout_rate": self.attention_dropout_rate, + "feedforward_dropout_rate": self.feedforward_dropout_rate, + } + ) + return config + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def _compute_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + attention_cache=None, + attention_cache_update_index=None, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if attention_cache is not None: + input_length = ops.shape(attention_cache)[2] + + causal_mask = compute_causal_mask( + batch_size, + input_length, + output_length, + ( + 0 + if attention_cache_update_index is None + else attention_cache_update_index + ), + ) + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def _build_alibi_tensor(self, num_heads, attention_mask): + batch_size, seq_length = attention_mask.shape + slopes = ops.convert_to_tensor( + self._get_slopes(num_heads), + dtype=self.compute_dtype, + ) # num_heads + arange_tensor = ( + ( + ops.cast(ops.cumsum(attention_mask, axis=-1) - 1, dtype="int32") + * attention_mask + ) + )[:, None, :] + alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype) + alibi = ops.expand_dims( + alibi, 0 + ) # [None, batch_size, num_heads, seq_length] + return ops.transpose(alibi, [1, 2, 0, 3]) + + def _get_slopes(self, num_heads): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + self._get_slopes(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + ) diff --git a/keras_nlp/models/gemma/__init__.py b/keras_nlp/models/gemma/__init__.py new file mode 100644 index 0000000000..b390926a21 --- /dev/null +++ b/keras_nlp/models/gemma/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_presets import backbone_presets +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (GemmaBackbone, GemmaTokenizer)) diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py new file mode 100644 index 0000000000..4b391264a2 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -0,0 +1,188 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_nlp.utils.keras_utils import clone_initializer + + +class CachedGemmaAttention(keras.layers.Layer): + """A cached grouped query attention layer.""" + + def __init__( + self, + head_dim, + num_query_heads, + num_key_value_heads, + kernel_initializer="glorot_uniform", + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.dropout = dropout + + self._kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + self.num_key_value_groups = num_query_heads // num_key_value_heads + + def build(self, inputs_shape): + self.hidden_dim = inputs_shape[-1] + + self.query_dense = keras.layers.EinsumDense( + "btd,ndh->btnh", + output_shape=(None, self.num_query_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self.query_dense.build(inputs_shape) + + self.key_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self.key_dense.build(inputs_shape) + + self.value_dense = keras.layers.EinsumDense( + "bsd,kdh->bskh", + output_shape=(None, self.num_key_value_heads, self.head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self.value_dense.build(inputs_shape) + + self.dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + self.output_dense = keras.layers.EinsumDense( + equation="btnh,nhd->btd", + output_shape=(None, self.hidden_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + 
name="attention_output", + ) + self.output_dense.build( + (None, None, self.num_query_heads, self.head_dim) + ) + self.softmax = keras.layers.Softmax(dtype="float32") + + self.rope_layer = RotaryEmbedding( + max_wavelength=10_000.0, dtype=self.dtype_policy + ) + + self.built = True + + def _apply_rope(self, x, start_index): + """Rope rotate q or k.""" + x = self.rope_layer(x, start_index=start_index) + # Gemma uses a different layout for positional embeddings. + # The transformation below ensures the embeddings are numerically + # equivalent to the original gemma implementation. + x = ops.reshape( + ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x) + ) + return x + + def _compute_attention( + self, + q, + k, + v, + attention_mask, + training=False, + ): + query_normalization = 1 / np.sqrt(self.head_dim) + + q *= ops.cast(query_normalization, dtype=q.dtype) + q_shape = ops.shape(q) + q = ops.reshape( + q, + ( + *q_shape[:-2], + self.num_key_value_heads, + self.num_query_heads // self.num_key_value_heads, + q_shape[-1], + ), + ) + b, q_len, _, _, h = ops.shape(q) + + attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k) + attention_mask = attention_mask[:, None, None, :, :] + orig_dtype = attention_logits.dtype + attention_softmax = self.softmax(attention_logits, mask=attention_mask) + attention_softmax = ops.cast(attention_softmax, orig_dtype) + + if self.dropout: + attention_softmax = self.dropout_layer( + attention_softmax, training=training + ) + + results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v) + return ops.reshape(results, (b, q_len, self.num_query_heads, h)) + + def call( + self, + x, + attention_mask=None, + cache=None, + cache_update_index=0, + training=False, + ): + query = self.query_dense(x) + query = self._apply_rope(query, cache_update_index) + + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + key_update = self.key_dense(x) + key_update = self._apply_rope(key_update, cache_update_index) + value_update = self.value_dense(x) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + key = self.key_dense(x) + key = self._apply_rope(key, cache_update_index) + value = self.value_dense(x) + + attention_vec = self._compute_attention( + query, key, value, attention_mask, training=training + ) + + # Wipe attn vec if there are no attended tokens. + no_attended_tokens = ops.all( + ops.equal(attention_mask, 0), axis=-1, keepdims=True + )[..., None] + attention_vec = ops.where( + no_attended_tokens, ops.zeros_like(attention_vec), attention_vec + ) + + attention_output = self.output_dense(attention_vec) + + if cache is not None: + return attention_output, cache + return attention_output diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py new file mode 100644 index 0000000000..a7973f9dec --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone.py @@ -0,0 +1,276 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_nlp.models.backbone import Backbone +from keras_nlp.models.gemma.gemma_decoder_block import GemmaDecoderBlock +from keras_nlp.models.gemma.rms_normalization import RMSNormalization + + +@keras_nlp_export("keras_nlp.models.GemmaBackbone") +class GemmaBackbone(Backbone): + """Gemma core network with hyperparameters. + + This backbone implements the base Transformer network for the Gemma model. + It includes the embedding lookups and transformer layers. This backbone + will output the final hidden states for each token, not generative + predictions over the vocabulary space. For a higher-level object for text + generation, see `keras_nlp.models.GemmaCausalLM`. + + The default constructor gives a fully customizable, randomly initialized + Gemma model with any number of layers, heads, and embedding dimensions. To + load preset architectures and weights, use the `from_preset` constructor. + + Args: + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers. + num_query_heads: int. The number of heads for the query projections in + the attention layer. + num_key_value_heads: int. The number of heads for the key and value + projections in the attention layer. + hidden_dim: int. The size of the transformer hidden state at the end + of each transformer layer. + intermediate_dim: int. The output dimension of the first Dense layer in + a two-layer feedforward network for each transformer. + head_dim: int. The size of each attention head. + layer_norm_epsilon: float. The epsilon value used for every layer norm + in the transformer model. + dropout: float. Dropout probability for the Transformer encoder. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the model's computations and weights. Note that some + computations, such as softmax and layer normalization, will always + be done at float32 precision regardless of dtype. + + Example: + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Gemma decoder. + model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en") + model(input_data) + + # Randomly initialized Gemma decoder with custom config. + model = keras_nlp.models.GemmaBackbone( + vocabulary_size=50257, + num_layers=12, + num_query_heads=12, + num_key_value_heads=1, + hidden_dim=768, + intermediate_dim=3072, + head_dim=64, + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + num_key_value_heads, + hidden_dim, + intermediate_dim, + head_dim, + layer_norm_epsilon=1e-6, + dropout=0, + dtype=None, + **kwargs, + ): + if not config.keras_3(): + raise ValueError( + "`GemmaBackbone` requires Keras 3. Run `pip install -U keras` " + "to upgrade your Keras version, or see https://keras.io/getting_started/ " + "for more info on Keras versions and installation." + ) + + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=True, + embeddings_initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="untruncated_normal", + seed=None, + ), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = GemmaDecoderBlock( + intermediate_dim=intermediate_dim, + hidden_dim=hidden_dim, + num_query_heads=num_query_heads, + head_dim=head_dim, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=dtype, + name=f"decoder_block_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = RMSNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="final_normalization", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="float32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="float32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the gemma + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model parallel + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), axis_names=('batch', 'model'), + devices=keras.distribution.list_devices()) + layout_map = GemmaBackbone.get_layout_map( + mesh, model_parallel_dim_name="model") + + distribution = keras.distribution.ModelParallel( + mesh, layout_map, batch_dim_name='batch') + with distribution.scope(): + gemma_model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en") + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partitioned on.
+ data_parallel_dim_name: The axis name of the device mesh, where + the data should be partitioned on. + Returns: + `keras.distribution.LayoutMap` that contains the sharding spec + of all the model weights. + """ + # The weight paths and shapes of the Gemma backbone are as below (for the 2B model) + # token_embedding/embeddings, (256128, 2048), 524550144 + # repeat block for decoder + # ... + # decoder_block_17/pre_attention_norm/scale, (2048,), 2048 + # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304 + # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304 + # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048 + # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432 + # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432 + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected `keras.distribution.DeviceMesh`," + f" got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_names=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_names=}" + ) + # Note that it is possible to further configure the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ( + model_dim, + data_dim, + None, + ) + layout_map["decoder_block.*attention_output.*kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim) + layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim) + + return layout_map diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py new file mode 100644 index 0000000000..7b02de2b7a --- /dev/null +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -0,0 +1,139 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
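One detail worth noting about `get_layout_map` above: the `LayoutMap` keys are regular expressions matched against weight paths. A quick standalone check (the path is hypothetical but follows the `decoder_block_{i}` naming used by the backbone):
```python
import re

rules = [
    r"token_embedding/embeddings",
    r"decoder_block.*attention.*(query|key|value).*kernel",
    r"decoder_block.*attention_output.*kernel",
    r"decoder_block.*ffw_gating.*kernel",
    r"decoder_block.*ffw_linear.*kernel",
]
path = "decoder_block_17/attention/query/kernel"
# Exactly one sharding rule should match each weight path.
matches = [rule for rule in rules if re.search(rule, path)]
assert matches == [r"decoder_block.*attention.*(query|key|value).*kernel"]
```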
+import pytest + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 256128, + "num_layers": 2, + "num_query_heads": 4, + "num_key_value_heads": 4, + "hidden_dim": 128, + "intermediate_dim": 256, + "head_dim": 128, + "layer_norm_epsilon": 1e-6, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 128), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_smallest_preset(self): + # TODO: Fails with OOM on current GPU CI + self.run_preset_test( + cls=GemmaBackbone, + preset="gemma_2b_en", + input_data={ + "token_ids": ops.array([[651, 4320, 8426, 25341, 235265]]), + "padding_mask": ops.ones((1, 5), dtype="int32"), + }, + expected_output_shape=(1, 5, 2048), + # The forward pass from a preset should be stable! + expected_partial_output=ops.array( + [1.073359, 0.262374, 0.170238, 0.605402, 2.336161] + ), + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaBackbone.presets: + self.run_preset_test( + cls=GemmaBackbone, + preset=preset, + input_data=self.input_data, + ) + + def test_architecture_characteristics(self): + model = GemmaBackbone(**self.init_kwargs) + self.assertEqual(model.count_params(), 33407616) + self.assertEqual(len(model.layers), 6) + + def test_distribution(self): + if keras.backend.backend() != "jax": + return + devices = keras.distribution.list_devices("CPU") + if len(devices) == 1: + # Need more than 1 device for distribution testing. 
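+ # (On the JAX backend, extra virtual CPU devices can be requested for
+ # local runs, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8.)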
+ return + device_mesh = keras.distribution.DeviceMesh( + shape=(1, len(devices)), + axis_names=("batch", "model"), + devices=devices, + ) + + layout_map = GemmaBackbone.get_layout_map(device_mesh) + distribution = keras.distribution.ModelParallel(device_mesh, layout_map) + with distribution.scope(): + model = GemmaBackbone(**self.init_kwargs) + + for w in model.weights: + if "token_embedding/embeddings" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) + if "attention/query/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "attention/key/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "attention/value/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch", None) + ) + if "attention/attention_output/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", None, "batch") + ) + if "ffw_gating/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "ffw_gating_2/kernel" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) + if "ffw_linear" in w.path: + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm.py b/keras_nlp/models/gemma/gemma_causal_lm.py new file mode 100644 index 0000000000..34b0a43126 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm.py @@ -0,0 +1,438 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.causal_lm import CausalLM +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.utils.tensor_utils import any_equal + + +@keras_nlp_export("keras_nlp.models.GemmaCausalLM") +class GemmaCausalLM(CausalLM): + """An end-to-end Gemma model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a Gemma model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"greedy"` sampling will be used.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case it will automatically apply preprocessing to string inputs
+    during `fit()`, `predict()`, `evaluate()` and `generate()`. This is done
+    by default when creating the model with `from_preset()`.
+
+    Args:
+        backbone: A `keras_nlp.models.GemmaBackbone` instance.
+        preprocessor: A `keras_nlp.models.GemmaCausalLMPreprocessor` or `None`.
+            If `None`, this model will not apply preprocessing, and inputs
+            should be preprocessed before calling the model.
+
+    Examples:
+
+    Use `generate()` to do text generation.
+    ```python
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+    gemma_lm.generate("I want to say", max_length=30)
+
+    # Generate with batched prompts.
+    gemma_lm.generate(["This is a", "Where are you"], max_length=30)
+    ```
+
+    Compile the `generate()` function with a custom sampler.
+    ```python
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+    gemma_lm.compile(sampler="top_k")
+    gemma_lm.generate("I want to say", max_length=30)
+
+    gemma_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
+    gemma_lm.generate("I want to say", max_length=30)
+    ```
+
+    Use `generate()` without preprocessing.
+    ```python
+    prompt = {
+        # Token ids for "<bos> Keras is".
+        "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
+        # Use `"padding_mask"` to indicate values that should not be overridden.
+        "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
+    }
+
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+        "gemma_2b_en",
+        preprocessor=None,
+    )
+    gemma_lm.generate(prompt)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+    gemma_lm.fit(x=features, batch_size=2)
+    ```
+
+    Call `fit()` with LoRA fine-tuning enabled.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+    gemma_lm.backbone.enable_lora(rank=4)
+    gemma_lm.fit(x=features, batch_size=2)
+    ```
+
+    Call `fit()` without preprocessing.
+    ```python
+    x = {
+        # Token ids for "<bos> Keras is deep learning library"
+        "token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2),
+    }
+    y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2)
+    sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2)
+
+    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+        "gemma_2b_en",
+        preprocessor=None,
+    )
+    gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = ["The quick brown fox jumped.", "I forgot my homework."]
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto="proto.spm",
+    )
+    preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(
+        tokenizer=tokenizer,
+        sequence_length=128,
+    )
+    backbone = keras_nlp.models.GemmaBackbone(
+        vocabulary_size=30552,
+        num_layers=4,
+        num_query_heads=4,
+        num_key_value_heads=4,
+        hidden_dim=256,
+        intermediate_dim=512,
+        head_dim=128,
+    )
+    gemma_lm = keras_nlp.models.GemmaCausalLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+    gemma_lm.fit(x=features, batch_size=2)
+    ```
+    """
+
+    backbone_cls = GemmaBackbone
+    preprocessor_cls = GemmaCausalLMPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        inputs = backbone.input
+        hidden_states = backbone(inputs)
+        outputs = backbone.token_embedding(hidden_states, reverse=True)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Default compilation ===
+        self.compile(
+            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+            optimizer=keras.optimizers.Adam(2e-5),
+            metrics=[keras.metrics.SparseCategoricalAccuracy()],
+            sampler="greedy",
+            jit_compile=True,
+        )
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+    ):
+        """Forward pass of `GemmaCausalLM` with cache.
+
+        `call_with_cache` adds an additional forward pass for the model for
+        autoregressive inference. Unlike calling the model directly, this
+        method allows caching previous key/value Tensors in the multi-head
+        attention layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
+            cache: a dense float Tensor, the cache of key and value.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.
+
+        Returns:
+            A (logits, hidden_states, cache) tuple. Where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the decoding cache.
+        """
+        x = self.backbone.token_embedding(token_ids)
+        x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
+        # Each decoder layer has a cache; we update them separately.
+        caches = []
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            current_cache = cache[:, i, ...]
+            x, next_cache = transformer_layer(
+                x,
+                cache=current_cache,
+                cache_update_index=cache_update_index,
+            )
+            caches.append(next_cache)
+        cache = ops.stack(caches, axis=1)
+        hidden_states = x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids):
+        """Build an empty cache for use with `call_with_cache()`."""
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_layers
+        num_heads = self.backbone.num_key_value_heads
+        head_dim = self.backbone.head_dim
+        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
+        cache = ops.zeros(shape, dtype=self.compute_dtype)
+        # Seed the cache.
+        _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
+        return hidden_states, cache
+
+    def generate_step(
+        self,
+        inputs,
+        stop_token_ids=None,
+    ):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs.
Inputs should have the same structure
+        as model inputs, a dictionary with keys `"token_ids"` and
+        `"padding_mask"`.
+
+        Args:
+            inputs: A dictionary with two keys `"token_ids"` and
+                `"padding_mask"` and batched tensor values.
+            stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(token_ids)
+        # Compute the lengths of all user-inputted token ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user-inputted id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                prompt,
+                cache,
+                cache_update_index,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self._sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of `stop_token_ids` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def score(
+        self,
+        token_ids,
+        padding_mask=None,
+        scoring_mode="logits",
+        layer_intercept_fn=None,
+        target_ids=None,
+    ):
+        """Score a generation represented by the provided token ids.
+
+        Args:
+            token_ids: A [batch_size, num_tokens] tensor containing tokens
+                to score. Typically, this tensor captures the output from a
+                call to `GemmaCausalLM.generate()`, i.e., tokens for both the
+                input text and the model-generated text.
+            padding_mask: A [batch_size, num_tokens] tensor indicating the
+                tokens that should be preserved during generation. This is an
+                artifact required by the `GemmaBackbone` and isn't influential
+                on the computation of this function. If omitted, this function
+                uses `keras.ops.ones()` to create a tensor of the appropriate
+                shape.
+            scoring_mode: The type of scores to return, either "logits" or
+                "loss", both will be per input token.
+            layer_intercept_fn: An optional function for augmenting activations
+                with additional computation, for example, as part of
+                interpretability research. This function will be passed the
+                activations as its first parameter and a numeric index
+                associated with that backbone layer. This index is not an
+                index into `self.backbone.layers`.
The index -1 accompanies
+                the embeddings returned by calling
+                `self.backbone.token_embedding()` on `token_ids` in the
+                forward direction. All subsequent indexes will be 0-based
+                indices for the activations returned by each of the
+                transformer layers in the backbone. This function must
+                return a [batch_size, num_tokens, hidden_dims] tensor
+                that can be passed as an input to the next layer in the model.
+            target_ids: A [batch_size, num_tokens] tensor containing the
+                predicted tokens against which the loss should be computed.
+                If a span of tokens is provided (sequential truthy values
+                along axis=1 in the tensor), the loss will be computed as the
+                aggregate across those tokens.
+
+        Raises:
+            ValueError: If an unsupported scoring_mode is provided, or if the
+                target_ids are not provided when using ScoringMode.LOSS.
+
+        Returns:
+            The per-token scores as a tensor of size
+            [batch_size, num_tokens, vocab_size] in "logits" mode, or
+            [batch_size, num_tokens] in "loss" mode.
+
+        Example:
+
+        Compute gradients between embeddings and loss scores with TensorFlow:
+        ```python
+        gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(
+            "gemma_2b_en"
+        )
+        generations = gemma_lm.generate(
+            ["This is a", "Where are you"],
+            max_length=30
+        )
+        preprocessed = gemma_lm.preprocessor.generate_preprocess(generations)
+        generation_ids = preprocessed["token_ids"]
+        padding_mask = preprocessed["padding_mask"]
+        target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1)
+
+        embeddings = None
+        with tf.GradientTape(watch_accessed_variables=True) as tape:
+            def layer_intercept_fn(x, i):
+                if i == -1:
+                    nonlocal embeddings, tape
+                    embeddings = x
+                    tape.watch(embeddings)
+                return x
+
+            losses = gemma_lm.score(
+                token_ids=generation_ids,
+                padding_mask=padding_mask,
+                scoring_mode="loss",
+                layer_intercept_fn=layer_intercept_fn,
+                target_ids=target_ids,
+            )
+
+        grads = tape.gradient(losses, embeddings)
+        ```
+        """
+        if scoring_mode not in ("logits", "loss"):
+            raise ValueError(
+                "Unsupported scoring_mode. Must be one of 'logits' or 'loss'."
+            )
+
+        if scoring_mode == "loss" and target_ids is None:
+            raise ValueError(
+                "Cannot compute loss without targets. Please provide target "
+                "token ids via the target_ids parameter."
+            )
+
+        batch_shape = ops.shape(token_ids)[:2]
+        assert len(batch_shape) == 2
+
+        if padding_mask is None:
+            padding_mask = ops.ones(shape=batch_shape)
+
+        if layer_intercept_fn is None:
+
+            def default_layer_intercept_fn(x, unused_i):
+                return x
+
+            layer_intercept_fn = default_layer_intercept_fn
+
+        token_embeddings = self.backbone.token_embedding(token_ids)
+        x = layer_intercept_fn(token_embeddings, -1)
+        # Use the (possibly augmented) embeddings returned by the intercept
+        # function, so its output feeds the next layer as documented above.
+        x = x * ops.cast(
+            ops.sqrt(self.backbone.hidden_dim), dtype=self.compute_dtype
+        )
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            x = transformer_layer(x, padding_mask=padding_mask)
+            x = layer_intercept_fn(x, i)
+        x = self.backbone.layer_norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+
+        if scoring_mode == "logits":
+            return logits
+
+        per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction="none"
+        )
+        per_token_loss = per_token_loss_fn(target_ids, logits)
+        return per_token_loss
diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..ca6b826abc
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor.py
@@ -0,0 +1,170 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.GemmaCausalLMPreprocessor")
+class GemmaCausalLMPreprocessor(GemmaPreprocessor):
+    """Gemma Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.GemmaCausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this
+    preprocessor is attached to a `keras_nlp.models.GemmaCausalLM` instance,
+    these methods will be called implicitly in `generate()`. They can also be
+    called standalone (e.g. to precompute preprocessing inputs for generation
+    in a separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.GemmaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the
+            tokenizer start token to each input sequence.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights.
Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( + "gemma_2b_en" + ) + + # Tokenize and pack a single sentence. + preprocessor("The quick brown fox jumped.") + + # Tokenize a batch of sentences. + preprocessor(["The quick brown fox jumped.", "Call me Ishmael."]) + + # Apply tokenization to a `tf.data.Dataset`. + features = tf.constant(["The quick brown fox.", "Call me Ishmael."]) + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Prepare tokens for generation (no end token). + preprocessor.generate_preprocess(["The quick brown fox jumped."]) + + # Map generation outputs back to strings. + preprocessor.generate_postprocess({ + 'token_ids': np.array([[2, 714, 4320, 8426, 25341, 32292, 235265, 0]]), + 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]), + }) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`GemmaCausalLMPreprocessor` generates `y` and `sample_weight` " + "based on your input data, but your data already contains `y` " + "or `sample_weight`. Your `y` and `sample_weight` will be " + "ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess(self, x): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + if not self.built: + self.build(None) + + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + token_ids = ops.convert_to_numpy(token_ids) + mask = ops.convert_to_numpy(padding_mask) + # Also strip any special tokens during detokenization (e.g. 
the start + # and end markers). In the future we could make this configurable. + mask = mask & (token_ids != self.tokenizer.start_token_id) + mask = mask & (token_ids != self.tokenizer.pad_token_id) + mask = mask & (token_ids != self.tokenizer.end_token_id) + token_ids = tf.ragged.boolean_mask(token_ids, mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..c3305afe91 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_preprocessor_test.py @@ -0,0 +1,93 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import ( + GemmaCausalLMPreprocessor, +) +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[4, 9, 5, 7, 2, 0, 0, 0]], # Labels shifted. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Zero out unlabeled examples. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = GemmaCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[9, 5, 7, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 5, 7, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 4, 9, 5, 7, 2, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0], + } + preprocessor = GemmaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLMPreprocessor.presets: + self.run_preset_test( + cls=GemmaCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_causal_lm_test.py b/keras_nlp/models/gemma/gemma_causal_lm_test.py new file mode 100644 index 0000000000..4a47d162ef --- /dev/null +++ b/keras_nlp/models/gemma/gemma_causal_lm_test.py @@ -0,0 +1,267 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
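
The preprocessor tests above pin down the `(x, y, sample_weight)` contract: the packer emits `sequence_length + 1` tokens, inputs drop the last position, labels drop the first, and weights zero out positions with no next token. A minimal NumPy sketch of that slicing, using the toy token ids from the test vocabulary (this mirrors the `call()` logic above, it is not the layer's actual implementation):

```python
import numpy as np

# Toy packed output: <bos>=1, "the quick brown fox"=[4, 9, 5, 7], <eos>=2,
# padded out to sequence_length + 1 = 9 positions.
packed_ids = np.array([[1, 4, 9, 5, 7, 2, 0, 0, 0]])
packed_mask = np.array([[1, 1, 1, 1, 1, 1, 0, 0, 0]])

x = {
    "token_ids": packed_ids[..., :-1],  # [[1 4 9 5 7 2 0 0]]
    "padding_mask": packed_mask[..., :-1],
}
y = packed_ids[..., 1:]  # [[4 9 5 7 2 0 0 0]] -- labels shifted left by one.
sample_weight = packed_mask[..., 1:]  # [[1 1 1 1 1 0 0 0]]
```

These values match the `expected_output` asserted in `test_preprocessor_basics` above.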
+
+import os
+from unittest.mock import patch
+
+import pytest
+
+from keras_nlp.backend import keras
+from keras_nlp.backend import ops
+from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone
+from keras_nlp.models.gemma.gemma_causal_lm import GemmaCausalLM
+from keras_nlp.models.gemma.gemma_causal_lm_preprocessor import (
+    GemmaCausalLMPreprocessor,
+)
+from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+@pytest.mark.keras_3_only
+class GemmaCausalLMTest(TestCase):
+    def setUp(self):
+        self.tokenizer = GemmaTokenizer(
+            proto=os.path.join(
+                self.get_test_data_dir(), "gemma_test_vocab.spm"
+            ),
+        )
+        self.preprocessor = GemmaCausalLMPreprocessor(
+            self.tokenizer,
+            sequence_length=8,
+        )
+        self.backbone = GemmaBackbone(
+            vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+            num_layers=2,
+            num_query_heads=2,
+            num_key_value_heads=1,
+            hidden_dim=4,
+            intermediate_dim=8,
+            head_dim=2,
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+        self.train_data = (["the quick brown fox", "the quick brown fox"],)
+        self.input_data = self.preprocessor(*self.train_data)[0]
+
+    def test_causal_lm_basics(self):
+        self.run_task_test(
+            cls=GemmaCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 8, 11),
+        )
+
+    def test_generate(self):
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        # String input.
+        prompt = "the quick brown fox"
+        output = causal_lm.generate(prompt)
+        self.assertTrue(prompt in output)
+        # Int tensor input.
+        prompt_ids = self.preprocessor.generate_preprocess([prompt])
+        causal_lm.preprocessor = None
+        outputs = causal_lm.generate(prompt_ids)
+        # Assert prompt is in output in token id space.
+        self.assertAllEqual(
+            outputs["token_ids"][:, :4],
+            prompt_ids["token_ids"][:, :4],
+        )
+        self.assertAllEqual(
+            outputs["padding_mask"][:, :4],
+            prompt_ids["padding_mask"][:, :4],
+        )
+
+    def test_generate_with_bfloat16(self):
+        original_floatx = keras.config.floatx()
+        keras.config.set_floatx("bfloat16")
+        try:
+            causal_lm = GemmaCausalLM(**self.init_kwargs)
+            # String input.
+            prompt = "the quick brown fox"
+            output = causal_lm.generate(prompt)
+            self.assertTrue(prompt in output)
+            # Int tensor input.
+            prompt_ids = self.preprocessor.generate_preprocess([prompt])
+            causal_lm.preprocessor = None
+            outputs = causal_lm.generate(prompt_ids)
+            # Assert prompt is in output in token id space.
+            self.assertAllEqual(
+                outputs["token_ids"][:, :4],
+                prompt_ids["token_ids"][:, :4],
+            )
+            self.assertAllEqual(
+                outputs["padding_mask"][:, :4],
+                prompt_ids["padding_mask"][:, :4],
+            )
+        finally:
+            # Restore floatx to the original value to prevent impact on other
+            # tests even if there is an exception.
+ keras.config.set_floatx(original_floatx) + + def test_early_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_multitoken_stopping(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the quick"] + + output = causal_lm.generate(prompt, stop_token_ids=(3,)) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = GemmaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.kaggle_key_required + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=GemmaCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaCausalLM.presets: + self.run_preset_test( + cls=GemmaCausalLM, + preset=preset, + input_data=self.input_data, + ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = GemmaCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 11) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. 
+        prompts = ["the quick brown fox", "the quick brown fox"]
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        expected_score_shape = (2, 8)
+
+        # Preprocess prompts to get tokenized representations and padding masks.
+        preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+            prompts
+        )
+        token_ids = preprocessed_prompts["token_ids"]
+        padding_mask = preprocessed_prompts["padding_mask"]
+        target_ids = keras.ops.roll(token_ids, shift=-1, axis=1)
+
+        # Get the scores and assert their shape.
+        scores = causal_lm.score(
+            token_ids=token_ids,
+            padding_mask=padding_mask,
+            scoring_mode="loss",
+            target_ids=target_ids,
+        )
+
+        self.assertEqual(ops.shape(scores), expected_score_shape)
+
+    def test_score_layer_intercept_fn_exfiltration(self):
+        # Setup prompts, models, and associated expected shapes.
+        prompts = ["the quick brown fox", "the quick brown fox"]
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        expected_embedded_shape = (2, 8, 4)
+        expected_score_shape = (2, 8, 11)
+
+        # Preprocess prompts to get tokenized representations and padding masks.
+        preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+            prompts
+        )
+        token_ids = preprocessed_prompts["token_ids"]
+        padding_mask = preprocessed_prompts["padding_mask"]
+
+        # Setup a custom intercept function that extracts the embeddings to a
+        # variable from the embeddings layer and otherwise asserts on shapes.
+        embedded_prompts = None
+
+        def layer_intercept_fn_for_testing(x, i):
+            if i == -1:
+                nonlocal embedded_prompts
+                embedded_prompts = x
+            else:
+                nonlocal expected_embedded_shape
+                self.assertEqual(ops.shape(x), expected_embedded_shape)
+            return x
+
+        # Get the scores.
+        scores = causal_lm.score(
+            token_ids=token_ids,
+            padding_mask=padding_mask,
+            scoring_mode="logits",
+            layer_intercept_fn=layer_intercept_fn_for_testing,
+        )
+
+        # Assert shapes for info exfiltrated into the parent context.
+        self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
+        self.assertEqual(ops.shape(scores), expected_score_shape)
diff --git a/keras_nlp/models/gemma/gemma_decoder_block.py b/keras_nlp/models/gemma/gemma_decoder_block.py
new file mode 100644
index 0000000000..0a91655fc4
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_decoder_block.py
@@ -0,0 +1,189 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
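
The decoder block defined below uses a gated feedforward (GeGLU-style) in place of a single dense projection: two parallel `EinsumDense` projections to `intermediate_dim // 2`, a GELU gate on one branch, and a linear projection back to `hidden_dim`. A standalone sketch of that computation using Keras 3's `keras.ops` (shapes and weights are illustrative, not the layer's real variables):

```python
import numpy as np
from keras import ops

hidden_dim, intermediate_dim = 8, 16
x = np.random.normal(size=(2, 5, hidden_dim)).astype("float32")

# Two parallel projections, as in `gating_ffw` and `gating_ffw_2` below.
w1 = np.random.normal(size=(hidden_dim, intermediate_dim // 2)).astype("float32")
w2 = np.random.normal(size=(hidden_dim, intermediate_dim // 2)).astype("float32")
w_out = np.random.normal(size=(intermediate_dim // 2, hidden_dim)).astype("float32")

x1 = ops.einsum("btd,df->btf", x, w1)
x2 = ops.einsum("btd,df->btf", x, w2)
gated = ops.gelu(x1, approximate=True) * x2  # GELU gate on one branch.
out = ops.einsum("btf,fd->btd", gated, w_out)
print(out.shape)  # (2, 5, 8) -- same shape in, same shape out.
```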
+from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_nlp.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_nlp.models.gemma.gemma_attention import CachedGemmaAttention +from keras_nlp.models.gemma.rms_normalization import RMSNormalization + + +class GemmaDecoderBlock(keras.layers.Layer): + def __init__( + self, + hidden_dim, + intermediate_dim, + head_dim, + num_query_heads, + num_key_value_heads, + layer_norm_epsilon=1e-6, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + self.pre_attention_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_attention_norm", + ) + + self.attention = CachedGemmaAttention( + head_dim=head_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + dropout=dropout, + dtype=self.dtype_policy, + name="attention", + ) + + if self.dropout > 0: + self.attention_dropout = keras.layers.Dropout(rate=dropout) + self.feedforward_dropout = keras.layers.Dropout(rate=dropout) + + self.pre_ffw_norm = RMSNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="pre_ffw_norm", + ) + + self.gating_ffw = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating", + ) + + self.gating_ffw_2 = keras.layers.EinsumDense( + equation="btd,df->btf", + output_shape=(None, self.intermediate_dim // 2), + dtype=self.dtype_policy, + name="ffw_gating_2", + ) + + self.ffw_linear = keras.layers.EinsumDense( + equation="btf,fd->btd", + output_shape=(None, self.hidden_dim), + dtype=self.dtype_policy, + name="ffw_linear", + ) + + def build(self, input_shape): + self.pre_attention_norm.build(input_shape) + self.attention.build(input_shape) + + shape = input_shape + self.pre_ffw_norm.build(shape) + self.gating_ffw.build(shape) + self.gating_ffw_2.build(shape) + + shape = self.gating_ffw.compute_output_shape(shape) + self.ffw_linear.build(shape) + self.built = True + + def compute_output_shape(self, input_shape): + # Isometric + return input_shape + + def _compute_attention_mask( + self, x, padding_mask, cache, cache_update_index + ): + decoder_mask = merge_padding_and_attention_mask( + inputs=x, padding_mask=padding_mask, attention_mask=None + ) + batch_size = ops.shape(x)[0] + input_length = output_length = ops.shape(x)[1] + if cache is not None: + input_length = ops.shape(cache)[2] + + causal_mask = compute_causal_mask( + batch_size=batch_size, + input_length=input_length, + output_length=output_length, + cache_index=cache_update_index, + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def call( + self, + x, + padding_mask=None, + cache=None, + cache_update_index=0, + ): + normalized_x = self.pre_attention_norm(x) + attention_mask = self._compute_attention_mask( + normalized_x, padding_mask, cache, cache_update_index + ) + if cache is not None: + attention, new_cache = self.attention( + normalized_x, + attention_mask=attention_mask, + cache=cache, + cache_update_index=cache_update_index, + ) + else: + 
attention = self.attention( + normalized_x, + attention_mask=attention_mask, + ) + + if self.dropout: + attention = self.attention_dropout(attention) + + attention_x = x + attention + normalized_x = self.pre_ffw_norm(attention_x) + + x1 = self.gating_ffw(normalized_x) + x2 = self.gating_ffw_2(normalized_x) + x = keras.activations.gelu(x1, approximate=True) * x2 + x = self.ffw_linear(x) + + x = x + attention_x + + if cache is not None: + return x, new_cache + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "head_dim": self.head_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_nlp/models/gemma/gemma_lora_test.py b/keras_nlp/models/gemma/gemma_lora_test.py new file mode 100644 index 0000000000..1cbbdfa67f --- /dev/null +++ b/keras_nlp/models/gemma/gemma_lora_test.py @@ -0,0 +1,102 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest + +from keras_nlp.models.gemma.gemma_backbone import GemmaBackbone +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaLoraTest(TestCase): + def setUp(self): + self._init_kwargs = { + "vocabulary_size": 50, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 2, + "hidden_dim": 32, + "intermediate_dim": 16, + "head_dim": 16, + "layer_norm_epsilon": 1e-6, + } + + def test_lora_fine_tuning(self): + # Set up backbone and preprocessor. + backbone = GemmaBackbone(**self._init_kwargs) + backbone.enable_lora(4) + # 4 layers, 2 weights per layer + self.assertLen(backbone.trainable_weights, 4 * 2) + self.assertLen(backbone.non_trainable_weights, 20) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + + # Test fine-tuning + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + # Test saving and reloading. 
+ temp_filepath = os.path.join( + self.get_temp_dir(), "lora_model.weights.h5" + ) + backbone.save_weights(temp_filepath) + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(temp_filepath) + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + def test_lora_saving_and_reloading(self): + backbone = GemmaBackbone(**self._init_kwargs) + initial_model_filepath = os.path.join( + self.get_temp_dir(), "base.weights.h5" + ) + backbone.save_weights(initial_model_filepath) + + backbone.enable_lora(4) + input_data = { + "token_ids": np.ones((2, 5), dtype="int32"), + "padding_mask": np.ones((2, 5), dtype="int32"), + } + targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"])) + backbone.compile(optimizer="sgd", loss="mse") + backbone.fit(input_data, targets, epochs=1) + + lora_filepath = os.path.join(self.get_temp_dir(), "lora_model.lora.h5") + backbone.save_lora_weights(lora_filepath) + + # New backbone with same initial weights + new_backbone = GemmaBackbone(**self._init_kwargs) + new_backbone.load_weights(initial_model_filepath) + new_backbone.enable_lora(4) + new_backbone.load_lora_weights(lora_filepath) + + ref_out = backbone(input_data) + new_out = new_backbone(input_data) + self.assertAllClose(ref_out, new_out) + + # Test exceptions + backbone = GemmaBackbone(**self._init_kwargs) + with self.assertRaisesRegex(ValueError, "no lora-enabled layers"): + backbone.save_lora_weights(lora_filepath) + backbone.enable_lora(5) + with self.assertRaisesRegex(ValueError, "ranks must match"): + backbone.load_lora_weights(lora_filepath) + with self.assertRaisesRegex(ValueError, "filename must end in"): + backbone.save_lora_weights("bad_filepath") diff --git a/keras_nlp/models/gemma/gemma_preprocessor.py b/keras_nlp/models/gemma/gemma_preprocessor.py new file mode 100644 index 0000000000..86db9a4e81 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor.py @@ -0,0 +1,190 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.GemmaPreprocessor") +class GemmaPreprocessor(Preprocessor): + """Gemma preprocessing layer which tokenizes and packs inputs. + + This preprocessing layer will do 2 things: + + - Tokenize the inputs using the `tokenizer`. + - Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can + be passed directly to a `keras_nlp.models.GemmaBackbone`. 
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    The call method of this layer accepts three arguments, `x`, `y`, and
+    `sample_weight`. `x` can be a python string or tensor representing a
+    single segment, a list of python strings representing a batch of single
+    segments, or a list of tensors representing multiple segments to be
+    packed together. `y` and `sample_weight` are both optional, can have any
+    format, and will be passed through unaltered.
+
+    `GemmaPreprocessor` expects the input to have only one segment, as Gemma
+    is mainly used for generation tasks. For tasks having multi-segment
+    inputs, please combine inputs into a single string input before passing
+    to the preprocessor layer.
+
+    Args:
+        tokenizer: A `keras_nlp.models.GemmaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the
+            tokenizer start token to each input sequence.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through
+            unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset(
+        "gemma_2b_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Custom vocabulary.
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    preprocessor = keras_nlp.models.GemmaPreprocessor(tokenizer=tokenizer)
+    preprocessor("The quick brown fox jumped.")
+    ```
+
+    Apply preprocessing to a `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.GemmaPreprocessor.from_preset(
+        "gemma_2b_en"
+    )
+
+    text = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((text, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(text)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    tokenizer_cls = GemmaTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=8192,
+        add_start_token=True,
+        add_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.tokenizer = tokenizer
+        self.sequence_length = sequence_length
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + pad_value=self.tokenizer.pad_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "GemmaPreprocessor requires each input to contain only " + f"one segment, but received {len(x)}. If you are using Gemma " + "for a multi-segment classification task, please combine your " + "input into a single string." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config diff --git a/keras_nlp/models/gemma/gemma_preprocessor_test.py b/keras_nlp/models/gemma/gemma_preprocessor_test.py new file mode 100644 index 0000000000..0cb427af03 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_preprocessor_test.py @@ -0,0 +1,75 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
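
One subtlety worth noting before the tests: when a shorter `sequence_length` is passed at call time, the packer truncates but still terminates the sequence with the end token, which is what `test_sequence_length_override` below asserts. A toy re-implementation of that behavior (the semantics of `StartEndPacker` are assumed here from the test expectations, this is not its real code):

```python
import numpy as np

def toy_start_end_pack(ids, sequence_length, start_id=1, end_id=2, pad_id=0):
    # Prepend start and append end, then truncate while keeping the end
    # token in the final position (assumed StartEndPacker semantics).
    packed = [start_id, *ids, end_id]
    if len(packed) > sequence_length:
        packed = packed[: sequence_length - 1] + [end_id]
    mask = [1] * len(packed) + [0] * (sequence_length - len(packed))
    packed = packed + [pad_id] * (sequence_length - len(packed))
    return np.array(packed), np.array(mask)

token_ids, padding_mask = toy_start_end_pack([4, 9, 5, 7], sequence_length=4)
print(token_ids)  # [1 4 9 2] -- matches test_sequence_length_override below.
```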
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_preprocessor import GemmaPreprocessor +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = GemmaTokenizer( + proto=os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ), + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["the quick brown fox"] + + def test_preprocessor_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output={ + "token_ids": [[1, 4, 9, 5, 7, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + preprocessor = GemmaPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[4, 9, 5, 7, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + + def test_sequence_length_override(self): + input_data = "the quick brown fox" + preprocessor = GemmaPreprocessor(**self.init_kwargs) + x = preprocessor(input_data, sequence_length=4) + self.assertAllEqual(x["token_ids"], [1, 4, 9, 2]) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaPreprocessor.presets: + self.run_preset_test( + cls=GemmaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/gemma_presets.py b/keras_nlp/models/gemma/gemma_presets.py new file mode 100644 index 0000000000..72360f72dc --- /dev/null +++ b/keras_nlp/models/gemma/gemma_presets.py @@ -0,0 +1,66 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "gemma_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_2b_en/2", + }, + "gemma_instruct_2b_en": { + "metadata": { + "description": ( + "18-layer Gemma model (Gemma with 2B parameters). " + ), + "params": 2506172416, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_2b_en/2", + }, + "gemma_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). 
" + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_7b_en/2", + }, + "gemma_instruct_7b_en": { + "metadata": { + "description": ( + "28-layer Gemma model (Gemma with 7B parameters). " + ), + "params": 8537680896, + "official_name": "Gemma", + "path": "gemma", + "model_card": "https://www.kaggle.com/models/google/gemma", + }, + "kaggle_handle": "kaggle://keras/gemma/keras/gemma_instruct_7b_en/2", + }, +} diff --git a/keras_nlp/models/gemma/gemma_tokenizer.py b/keras_nlp/models/gemma/gemma_tokenizer.py new file mode 100644 index 0000000000..7722d35f35 --- /dev/null +++ b/keras_nlp/models/gemma/gemma_tokenizer.py @@ -0,0 +1,101 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer + + +@keras_nlp_export("keras_nlp.models.GemmaTokenizer") +class GemmaTokenizer(SentencePieceTokenizer): + """Gemma tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Gemma models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Gemma preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + + ```python + # Unbatched input. + tokenizer = keras_nlp.models.GemmaTokenizer.from_preset("gemma_2b_en") + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. + tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + + # Custom vocabulary. 
+    bytes_io = io.BytesIO()
+    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+    sentencepiece.SentencePieceTrainer.train(
+        sentence_iterator=ds.as_numpy_iterator(),
+        model_writer=bytes_io,
+        vocab_size=8,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+    tokenizer = keras_nlp.models.GemmaTokenizer(
+        proto=bytes_io.getvalue(),
+    )
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    def __init__(self, proto, **kwargs):
+        self.start_token = "<bos>"
+        self.end_token = "<eos>"
+        self.pad_token = "<pad>"
+
+        super().__init__(proto=proto, **kwargs)
+
+    def set_proto(self, proto):
+        super().set_proto(proto)
+        if proto is not None:
+            for token in [self.end_token, self.pad_token]:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+            self.start_token_id = self.token_to_id(self.start_token)
+            self.end_token_id = self.token_to_id(self.end_token)
+            self.pad_token_id = self.token_to_id(self.pad_token)
+        else:
+            self.start_token_id = None
+            self.end_token_id = None
+            self.pad_token_id = None
diff --git a/keras_nlp/models/gemma/gemma_tokenizer_test.py b/keras_nlp/models/gemma/gemma_tokenizer_test.py
new file mode 100644
index 0000000000..65569c8174
--- /dev/null
+++ b/keras_nlp/models/gemma/gemma_tokenizer_test.py
@@ -0,0 +1,69 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
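
The tokenizer tests that follow rely on `gemma_test_vocab.spm`, which their comments attribute to a `create_gemma_test_proto.py` script not included in this diff. A plausible sketch of producing a compatible toy proto (it mirrors the docstring example above; the exact flags used by the real script are an assumption):

```python
import io

import sentencepiece
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
bytes_io = io.BytesIO()
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=ds.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=11,  # 7 words + <pad>, <bos>, <eos>, <unk>.
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="<pad>",
    bos_piece="<bos>",
    eos_piece="<eos>",
    unk_piece="<unk>",
)
with open("gemma_test_vocab.spm", "wb") as f:
    f.write(bytes_io.getvalue())
```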
+ +import os + +import pytest + +from keras_nlp.models.gemma.gemma_tokenizer import GemmaTokenizer +from keras_nlp.tests.test_case import TestCase + + +@pytest.mark.keras_3_only +class GemmaTokenizerTest(TestCase): + def setUp(self): + self.init_kwargs = { + # Generated using create_gemma_test_proto.py + "proto": os.path.join( + self.get_test_data_dir(), "gemma_test_vocab.spm" + ) + } + self.input_data = ["the quick brown fox", "the earth is round"] + + def test_tokenizer_basics(self): + self.run_preprocessing_layer_test( + cls=GemmaTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=[[4, 9, 5, 7], [4, 6, 8, 10]], + ) + + def test_errors_missing_special_tokens(self): + with self.assertRaises(ValueError): + GemmaTokenizer( + # Generated using create_no_special_token_proto.py + proto=os.path.join( + self.get_test_data_dir(), "no_special_token_vocab.spm" + ) + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=GemmaTokenizer, + preset="gemma_2b_en", + input_data=["The quick brown fox."], + expected_output=[[651, 4320, 8426, 25341, 235265]], + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + for preset in GemmaTokenizer.presets: + self.run_preset_test( + cls=GemmaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/gemma/rms_normalization.py b/keras_nlp/models/gemma/rms_normalization.py new file mode 100644 index 0000000000..c3e4296020 --- /dev/null +++ b/keras_nlp/models/gemma/rms_normalization.py @@ -0,0 +1,40 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_nlp.backend import keras +from keras_nlp.backend import ops + + +class RMSNormalization(keras.layers.Layer): + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(input_shape[-1],), + initializer="zeros", + ) + self.built = True + + def call(self, x): + # Always compute normalization in float32. + x = ops.cast(x, "float32") + scale = ops.cast(self.scale, "float32") + var = ops.mean(ops.square(x), axis=-1, keepdims=True) + normed_inputs = x * ops.reciprocal(ops.sqrt(var + self.epsilon)) + normed_inputs = normed_inputs * (1 + scale) + return ops.cast(normed_inputs, self.compute_dtype) diff --git a/keras_nlp/models/gpt2/__init__.py b/keras_nlp/models/gpt2/__init__.py index ba0c2545e4..ad86022a86 100644 --- a/keras_nlp/models/gpt2/__init__.py +++ b/keras_nlp/models/gpt2/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone +from keras_nlp.models.gpt2.gpt2_presets import backbone_presets +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (GPT2Backbone, GPT2Tokenizer)) diff --git a/keras_nlp/models/gpt2/gpt2_backbone.py b/keras_nlp/models/gpt2/gpt2_backbone.py index d93b2199b0..49929e4091 100644 --- a/keras_nlp/models/gpt2/gpt2_backbone.py +++ b/keras_nlp/models/gpt2/gpt2_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -20,9 +19,7 @@ from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.utils.keras_utils import gelu_approximate -from keras_nlp.utils.python_utils import classproperty def _gpt_2_kernel_initializer(stddev=0.02): @@ -170,6 +167,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=sequence_output, + dtype=dtype, **kwargs, ) @@ -196,7 +194,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index e154c88bb1..40d4787119 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import ( GPT2CausalLMPreprocessor, ) -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.GPT2CausalLM") -class GPT2CausalLM(GenerativeTask): +class GPT2CausalLM(CausalLM): """An end-to-end GPT2 model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -149,6 +147,9 @@ class GPT2CausalLM(GenerativeTask): ``` """ + backbone_cls = GPT2Backbone + preprocessor_cls = GPT2CausalLMPreprocessor + def __init__( self, backbone, @@ -177,18 +178,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return GPT2Backbone - - @classproperty - def preprocessor_cls(cls): - return GPT2CausalLMPreprocessor - def call_with_cache( self, token_ids, @@ -251,7 +240,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -262,8 +251,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. 
If all
- sequences have produced a new `end_token_id`, generation
+ stop_token_ids: List of token ids to stop on. If all
+ sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
@@ -296,18 +285,19 @@ def next(prompt, cache, index):
cache=cache,
index=index,
mask=padding_mask,
- end_token_id=end_token_id,
+ stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
+ model=self,
)
# Compute an output padding mask with the token ids we updated.
- if end_token_id is not None:
- # Build a mask of `end_token_id` locations not in the original
+ if stop_token_ids is not None:
+ # Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
- end_locations = ops.logical_and(
- ops.equal(token_ids, end_token_id),
- ops.logical_not(padding_mask),
+ end_locations = any_equal(
+ token_ids, stop_token_ids, ops.logical_not(padding_mask)
)
+ end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
@@ -321,3 +311,134 @@
"token_ids": token_ids,
"padding_mask": padding_mask,
}
+
+ def score(
+ self,
+ token_ids,
+ padding_mask=None,
+ scoring_mode="logits",
+ layer_intercept_fn=None,
+ target_ids=None,
+ ):
+ """Score a generation represented by the provided token ids.
+
+ Args:
+ token_ids: A [batch_size, num_tokens] tensor containing tokens
+ to score. Typically, this tensor captures the output from a call
+ to `GPT2CausalLM.generate()`, i.e., tokens for both the input
+ text and the model-generated text.
+ padding_mask: A [batch_size, num_tokens] tensor indicating the
+ tokens that should be preserved during generation. This is an
+ artifact required by the `GPT2Backbone` and isn't influential on
+ the computation of this function. If omitted, this function uses
+ `keras.ops.ones()` to create a tensor of the appropriate shape.
+ scoring_mode: The type of scores to return, either "logits" or
+ "loss"; both are computed per input token.
+ layer_intercept_fn: An optional function for augmenting activations
+ with additional computation, for example, as part of
+ interpretability research. This function will be passed the
+ activations as its first parameter and a numeric index
+ associated with that backbone layer. This index is not an index
+ into `self.backbone.layers`. The index -1 accompanies the
+ embeddings returned by calling `self.backbone.token_embedding()`
+ on `token_ids` in the forward direction. All subsequent indices
+ will be 0-based indices for the activations returned by each of
+ the transformer layers in the backbone. This function must
+ return a [batch_size, num_tokens, hidden_dims] tensor
+ that can be passed as an input to the next layer in the model.
+ target_ids: A [batch_size, num_tokens] tensor containing the
+ predicted tokens against which the loss should be computed. If a
+ span of tokens is provided (sequential truthy values along
+ axis=1 in the tensor), the loss will be computed as the
+ aggregate across those tokens.
+
+ Raises:
+ ValueError: If an unsupported `scoring_mode` is provided, or if
+ `target_ids` is not provided when using the "loss" `scoring_mode`.
+
+ Returns:
+ The per-token scores as a tensor of size
+ [batch_size, num_tokens, vocab_size] in "logits" mode, or
+ [batch_size, num_tokens] in "loss" mode.
+ + Example: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + generations = gpt2_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = gpt2_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = gpt2_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + position_embeddings = self.backbone.position_embedding(token_embeddings) + summed_embeddings = self.backbone.embeddings_add( + (token_embeddings, position_embeddings) + ) + x = layer_intercept_fn(summed_embeddings, -1) + x = self.backbone.embeddings_dropout(x) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py index 97d0b42d97..3278b18a4f 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. 
This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the integer sequence
diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
index f34b6baa47..8999ebd9af 100644
--- a/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
+++ b/keras_nlp/models/gpt2/gpt2_causal_lm_test.py
@@ -128,3 +128,88 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)
+
+ def test_score_logits(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = [" airplane at airport", " airplane at airport"]
+ causal_lm = GPT2CausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8, 7)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_loss(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = [" airplane at airport", " airplane at airport"]
+ causal_lm = GPT2CausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+ target_ids = ops.roll(token_ids, shift=-1, axis=1)
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="loss",
+ target_ids=target_ids,
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_layer_intercept_fn_exfiltration(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = [" airplane at airport", " airplane at airport"]
+ causal_lm = GPT2CausalLM(**self.init_kwargs)
+ expected_embedded_shape = (2, 8, 4)
+ expected_score_shape = (2, 8, 7)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Setup a custom intercept function that extracts the embeddings to a
+ # variable from the embeddings layer and otherwise asserts on shapes.
+ embedded_prompts = None
+
+ def layer_intercept_fn_for_testing(x, i):
+ if i == -1:
+ nonlocal embedded_prompts
+ embedded_prompts = x
+ else:
+ nonlocal expected_embedded_shape
+ self.assertEqual(ops.shape(x), expected_embedded_shape)
+ return x
+
+ # Get the scores.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ layer_intercept_fn=layer_intercept_fn_for_testing,
+ )
+
+ # Assert shapes for info exfiltrated into the parent context.
+ self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_nlp/models/gpt2/gpt2_preprocessor.py b/keras_nlp/models/gpt2/gpt2_preprocessor.py index 82be34776f..4641a32020 100644 --- a/keras_nlp/models/gpt2/gpt2_preprocessor.py +++ b/keras_nlp/models/gpt2/gpt2_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.gpt2.gpt2_presets import backbone_presets from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.GPT2Preprocessor") @@ -109,6 +106,8 @@ class GPT2Preprocessor(Preprocessor): ``` """ + tokenizer_cls = GPT2Tokenizer + def __init__( self, tokenizer, @@ -185,11 +184,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return GPT2Tokenizer diff --git a/keras_nlp/models/gpt2/gpt2_tokenizer.py b/keras_nlp/models/gpt2/gpt2_tokenizer.py index 15b35bed87..4a585c3176 100644 --- a/keras_nlp/models/gpt2/gpt2_tokenizer.py +++ b/keras_nlp/models/gpt2/gpt2_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy
from keras_nlp.api_export import keras_nlp_export
-from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer
-from keras_nlp.utils.python_utils import classproperty
@@ -104,10 +101,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
self.start_token_id = None
self.pad_token_id = None
- @classproperty
- def presets(cls):
- return copy.deepcopy(backbone_presets)
-
def get_config(self):
config = super().get_config()
# In the constructor, we pass the list of special tokens to the
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
index 1955ed5801..415fa56af2 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_backbone.py
@@ -137,6 +137,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
+ dtype=dtype,
**kwargs,
)
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py
index bef32017ea..119e51cc75 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm.py
@@ -15,16 +15,16 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
-from keras_nlp.models.generative_task import GenerativeTask
+from keras_nlp.models.causal_lm import CausalLM
from keras_nlp.models.gpt_neo_x.gpt_neo_x_backbone import GPTNeoXBackbone
from keras_nlp.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
GPTNeoXCausalLMPreprocessor,
)
-from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.tensor_utils import any_equal


@keras_nlp_export("keras_nlp.models.GPTNeoXCausalLM")
-class GPTNeoXCausalLM(GenerativeTask):
+class GPTNeoXCausalLM(CausalLM):
"""An end-to-end GPTNeoX model for causal language modeling.

A causal language model (LM) predicts the next token based on previous
@@ -46,6 +46,9 @@ class GPTNeoXCausalLM(GenerativeTask):
should be preprocessed before calling the model.
"""

+ backbone_cls = GPTNeoXBackbone
+ preprocessor_cls = GPTNeoXCausalLMPreprocessor
+
def __init__(
self,
backbone,
@@ -74,14 +77,6 @@ def __init__(
jit_compile=True,
)

- @classproperty
- def backbone_cls(cls):
- return GPTNeoXBackbone
-
- @classproperty
- def preprocessor_cls(cls):
- return GPTNeoXCausalLMPreprocessor
-
def call_with_cache(
self,
token_ids,
@@ -141,7 +136,7 @@ def _build_cache(self, token_ids):
def generate_step(
self,
inputs,
- end_token_id=None,
+ stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.

@@ -152,8 +147,8 @@
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
- end_token_id: The id of the end token to stop on. If all
- sequences have produced a new `end_token_id`, generation
+ stop_token_ids: Tuple of token ids to stop on. If all
+ sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
@@ -186,18 +181,19 @@ def next(prompt, cache, index):
cache=cache,
index=index,
mask=padding_mask,
- end_token_id=end_token_id,
+ stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
+ model=self,
)

# Compute an output padding mask with the token ids we updated.
- if end_token_id is not None:
- # Build a mask of `end_token_id` locations not in the original
+ if stop_token_ids is not None:
+ # Build a mask of stop token locations not in the original
# prompt (not in locations where `padding_mask` is True).
- end_locations = ops.logical_and(
- ops.equal(token_ids, end_token_id),
- ops.logical_not(padding_mask),
+ end_locations = any_equal(
+ token_ids, stop_token_ids, ops.logical_not(padding_mask)
)
+ end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py
index 92ff9bbb03..71b3d8ce04 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py
@@ -99,7 +99,7 @@ def generate_preprocess(
x,
sequence_length=None,
):
- """Covert strings to integer token input for generation.
+ """Convert strings to integer token input for generation.

Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
@@ -127,7 +127,7 @@ def generate_postprocess(
self,
x,
):
- """Covert integer token output to strings for generation.
+ """Convert integer token output to strings for generation.

This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the integer sequence
diff --git a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
index 8dc374332b..d5406a6a3d 100644
--- a/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
+++ b/keras_nlp/models/gpt_neo_x/gpt_neo_x_preprocessor.py
@@ -20,7 +20,6 @@
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
-from keras_nlp.utils.python_utils import classproperty


@keras_nlp_export("keras_nlp.models.GPTNeoXPreprocessor")
@@ -65,6 +64,8 @@ class GPTNeoXPreprocessor(Preprocessor):
the layer.
"""

+ tokenizer_cls = GPTNeoXTokenizer
+
def __init__(
self,
tokenizer,
@@ -141,7 +142,3 @@ def sequence_length(self, value):
self._sequence_length = value
if self.packer is not None:
self.packer.sequence_length = value
-
- @classproperty
- def tokenizer_cls(cls):
- return GPTNeoXTokenizer
diff --git a/keras_nlp/models/llama/__init__.py b/keras_nlp/models/llama/__init__.py
index ba0c2545e4..3a57fccd42 100644
--- a/keras_nlp/models/llama/__init__.py
+++ b/keras_nlp/models/llama/__init__.py
@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from keras_nlp.models.llama.llama_backbone import LlamaBackbone
+from keras_nlp.models.llama.llama_presets import backbone_presets
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (LlamaBackbone, LlamaTokenizer))
diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py
index 529e73b009..33ffcef209 100644
--- a/keras_nlp/models/llama/llama_attention.py
+++ b/keras_nlp/models/llama/llama_attention.py
@@ -18,34 +18,33 @@
class LlamaAttention(keras.layers.Layer):
- """Grouped query attention for Llama models"""
+ """A cached grouped query attention layer for Llama models."""
def __init__(
self,
num_query_heads,
num_key_value_heads,
+ rope_max_wavelength=10000,
rope_scaling_factor=1.0,
kernel_initializer="glorot_uniform",
- rope_max_wavelength=10000,
- max_sequence_length=512,
+ dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
+ self.dropout = dropout
self.num_key_value_groups = num_query_heads // num_key_value_heads
+ self.rope_max_wavelength = rope_max_wavelength
- self.kernel_initializer = keras.initializers.get(kernel_initializer)
- self.max_sequence_length = max_sequence_length
+ self.kernel_initializer = keras.initializers.get(
+ clone_initializer(kernel_initializer)
+ )
self.rope_scaling_factor = rope_scaling_factor
- self.rope_max_wavelength = rope_max_wavelength
def build(self, inputs_shape):
- self.hidden_dim = inputs_shape[-1]
- self.attn_head_size = self.hidden_dim // self.num_query_heads
-
# Einsum variables:
# b = batch size
# q = query length
# k = key/value length
# m = model dim
# u = num query heads
# v = num key/value heads
# h = head dim
+ hidden_dim = inputs_shape[-1]
+ head_dim = hidden_dim // self.num_query_heads
+ self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype))
+
self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
- output_shape=(None, self.num_query_heads, self.attn_head_size),
- kernel_initializer=clone_initializer(self.kernel_initializer),
+ output_shape=(None, self.num_query_heads, head_dim),
+ kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self._query_dense.build(inputs_shape)
+
self._key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
- output_shape=(None, self.num_key_value_heads, self.attn_head_size),
- kernel_initializer=clone_initializer(self.kernel_initializer),
+ output_shape=(
+ None,
+ self.num_key_value_heads,
+ head_dim,
+ ),
+ kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
@@ -73,8 +81,12 @@ def build(self, inputs_shape):
self._value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
- output_shape=(None, self.num_key_value_heads, self.attn_head_size),
- kernel_initializer=clone_initializer(self.kernel_initializer),
+ output_shape=(
+ None,
+ self.num_key_value_heads,
+ head_dim,
+ ),
+ kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
@@ -86,21 +98,28 @@ def build(self, inputs_shape):
name="attention_softmax",
)
+ self._dropout_layer = keras.layers.Dropout(
+ rate=self.dropout,
+ dtype=self.dtype_policy,
+ )
+
self._output_dense = keras.layers.EinsumDense(
- equation="bqm,mh->bqh",
- output_shape=(None, self.hidden_dim),
- kernel_initializer=clone_initializer(self.kernel_initializer),
+ equation="bquh,uhm->bqm",
output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="attention_output", ) - self._output_dense.build(inputs_shape) + self._output_dense.build((None, None, self.num_query_heads, head_dim)) - self._rotary_embedding_layer = RotaryEmbedding( + self.rotary_embedding_layer = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, scaling_factor=self.rope_scaling_factor, dtype=self.dtype_policy, ) - self._rotary_embedding_layer.build(inputs_shape) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" self.built = True @@ -110,6 +129,7 @@ def call( attention_mask=None, cache=None, cache_update_index=None, + training=None, ): query = self._query_dense(hidden_states) @@ -136,75 +156,61 @@ def call( key = self._key_dense(hidden_states) value = self._value_dense(hidden_states) - query = self._rotary_embedding_layer(query) - key = self._rotary_embedding_layer(key) + query = self.rotary_embedding_layer(query) + key = self.rotary_embedding_layer(key) - key = ops.tile(key, [1, 1, self.num_key_value_groups, 1]) - value = ops.tile(value, [1, 1, self.num_key_value_groups, 1]) + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - attention_output, attention_scores = self._compute_attention( + attention_output = self._compute_attention( query, key, value, attention_mask ) - attention_output_shape = ops.shape(attention_output) - - attention_output = ops.reshape( - attention_output, - [ - attention_output_shape[0], - attention_output_shape[1], - self.hidden_dim, - ], + attention_output = self._dropout_layer( + attention_output, training=training ) attention_output = self._output_dense(attention_output) if cache is not None: - return (attention_output, cache) + return attention_output, cache return attention_output def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - mask_expansion_axis = -3 - for _ in range( - len(attention_scores.shape) - len(attention_mask.shape) - ): - attention_mask = ops.expand_dims( - attention_mask, axis=mask_expansion_axis - ) - return self._softmax(attention_scores, attention_mask) + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): - attention_scores = ops.einsum("aecd,abcd->acbe", key, query) - - norm_factor = ops.sqrt( - ops.convert_to_tensor(self.attn_head_size, self.compute_dtype) - ) + attention_scores = ops.einsum(self._dot_product_equation, query, key) - attention_scores /= norm_factor + attention_scores = attention_scores / self._norm_factor attention_scores = self._masked_softmax( attention_scores, attention_mask ) attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( - "acbe,aecd->abcd", attention_scores, value + self._combine_equation, attention_scores, value ) - return attention_output, attention_scores + return attention_output def get_config(self): config = super().get_config() config.update( { "num_query_heads": self.num_query_heads, - "hidden_dim": self.hidden_dim, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, "kernel_initializer": 
keras.initializers.serialize(
self.kernel_initializer
),
- "rope_max_wavelength": self.rope_max_wavelength,
- "rope_scaling_factor": self.rope_scaling_factor,
- "num_key_value_heads": self.num_key_value_heads,
- "max_sequence_length": self.max_sequence_length,
+ "dropout": self.dropout,
}
)
return config
diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py
index cc628ad7a5..e586fa97f5 100644
--- a/keras_nlp/models/llama/llama_backbone.py
+++ b/keras_nlp/models/llama/llama_backbone.py
@@ -16,7 +16,7 @@
from keras_nlp.backend import ops
from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding
from keras_nlp.models.backbone import Backbone
-from keras_nlp.models.llama.llama_decoder import LlamaDecoder
+from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder
from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm
@@ -27,41 +27,64 @@ def _llama_kernel_initializer(stddev=0.02):
@keras_nlp_export("keras_nlp.models.LlamaBackbone")
class LlamaBackbone(Backbone):
"""
- LLaMA core network with hyperparameters.
+ The Llama Transformer core architecture with hyperparameters.
This network implements a Transformer-based decoder network,
- LLaMA, as described in ["LLaMA: Open Foundation and Fine-Tuned Language Models"](https://arxiv.org/abs/2302.13971).
+ Llama, as described in
+ ["LLaMA: Open and Efficient Foundation Language Models"](https://arxiv.org/abs/2302.13971).
+ It includes the embedding lookups and transformer layers.
The default constructor gives a fully customizable, randomly initialized
- LLaMA model with any number of layers, heads, and embedding
- dimensions. This backbone also supports LLaMA2 checkpoints.
+ Llama model with any number of layers, heads, and embedding
+ dimensions. To load preset architectures and weights, use the `from_preset`
+ constructor.
Args:
- vocabulary_size: int. The size of the token vocabulary.
- num_layers: int. The number of transformer layers.
- num_query_heads: int. The number of attention heads for each transformer.
- The hidden size must be divisible by the number of attention heads.
- hidden_dim: int. The size of the transformer encoding and pooler layers.
- intermediate_dim: int. The output dimension of the first Dense layer in
- a two-layer feedforward network for each transformer.
- num_key_value_heads: int. This is the number of key_value heads that
- should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads,
- the model will use Multi Head Attention (MHA), if num_key_value_heads=1
- the model will use Multi Query Attention (MQA)
- rope_scaling_factor: float. The scaling factor for calculation of rotary
- embedding
- rope_max_wavelength: int. The maximum angular wavelength of the
- sine/cosine curves, for rotary embeddings.
- layer_norm_epsilon: float. a value added to the denominator for
- numerical stability.
- max_sequence_length: int. The maximum sequence length that this encoder
- can consume. If `None`, `max_sequence_length` uses the value from
- sequence length. This determines the variable shape for positional
- embeddings.
+ vocabulary_size (int): The size of the token vocabulary.
+ num_layers (int): The number of transformer layers.
+ num_query_heads (int): The number of query attention heads for
+ each transformer.
+ hidden_dim (int): The size of the transformer encoding and pooling layers.
+ intermediate_dim (int): The output dimension of the first Dense layer in a
+ three-layer feedforward network for each transformer.
+ num_key_value_heads (int): The number of key and value attention heads for
+ each transformer.
+ rope_max_wavelength (int, optional): The maximum angular wavelength of the
+ sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+ rope_scaling_factor (float, optional): The scaling factor for calculation
+ of rotary embeddings. Defaults to `1.0`.
+ layer_norm_epsilon (float, optional): Epsilon for the layer normalization
+ layers in the transformer decoder. Defaults to `1e-6`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.
+
+ Examples:
+
+ ```python
+ input_data = {
+ "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+ "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+ }
+
+ # Pretrained Llama decoder.
+ model = keras_nlp.models.LlamaBackbone.from_preset("llama2_7b_en")
+ model(input_data)
+
+ # Randomly initialized Llama decoder with custom config.
+ model = keras_nlp.models.LlamaBackbone(
+ vocabulary_size=10,
+ hidden_dim=512,
+ num_layers=2,
+ num_query_heads=32,
+ num_key_value_heads=8,
+ intermediate_dim=1024,
+ layer_norm_epsilon=1e-6,
+ dtype="float32"
+ )
+ model(input_data)
+ ```
"""
def __init__(
@@ -72,10 +95,10 @@
hidden_dim,
intermediate_dim,
num_key_value_heads,
- rope_scaling_factor=1.0,
rope_max_wavelength=10000,
- layer_norm_epsilon=1e-5,
- max_sequence_length=4096,
+ rope_scaling_factor=1.0,
+ layer_norm_epsilon=1e-6,
+ dropout=0,
dtype=None,
**kwargs,
):
@@ -83,31 +106,31 @@
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
- embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
tie_weights=False,
+ embeddings_initializer=_llama_kernel_initializer(stddev=0.01),
dtype=dtype,
name="token_embedding",
)
self.transformer_layers = []
for i in range(num_layers):
- layer = LlamaDecoder(
+ layer = LlamaTransformerDecoder(
intermediate_dim=intermediate_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
- rope_scaling_factor=rope_scaling_factor,
- max_sequence_length=max_sequence_length,
rope_max_wavelength=rope_max_wavelength,
+ rope_scaling_factor=rope_scaling_factor,
layer_norm_epsilon=layer_norm_epsilon,
activation=ops.silu,
kernel_initializer=_llama_kernel_initializer(stddev=0.02),
+ dropout=dropout,
dtype=dtype,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)
self.layer_norm = LlamaLayerNorm(
- dtype=dtype,
epsilon=layer_norm_epsilon,
- name="layer_norm",
+ dtype=dtype,
+ name="sequence_output_layernorm",
)
# === Functional Model ===
@@ -127,6 +150,7 @@
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
+ dtype=dtype,
**kwargs,
)
@@ -139,8 +163,8 @@
self.rope_max_wavelength = rope_max_wavelength
self.num_key_value_heads = num_key_value_heads
self.rope_scaling_factor = rope_scaling_factor
- self.max_sequence_length = max_sequence_length
self.layer_norm_epsilon = layer_norm_epsilon
+ self.dropout = dropout
def get_config(self):
config = super().get_config()
@@ -154,8 +178,8 @@
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"num_key_value_heads": self.num_key_value_heads,
- "max_sequence_length": self.max_sequence_length,
"layer_norm_epsilon": self.layer_norm_epsilon,
+ "dropout": self.dropout,
}
)
return config
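A note on the attention rewrite above: expanding key/value heads up to the query-head count for grouped query attention now uses `ops.repeat` on the head axis instead of `ops.tile`. The ordering difference matters because the new einsum equations (`"bquh,bkuh->buqk"`) pair query head `u` with the `u`-th entry on the expanded head axis, so consecutive query heads must share a key/value head. A minimal NumPy sketch of the two layouts (toy shapes, illustrative only, not part of this diff):

```python
import numpy as np

# Toy shapes: batch=1, seq_len=1, num_key_value_heads=2, head_dim=1,
# num_query_heads=4, so num_key_value_groups = 4 // 2 = 2.
key = np.arange(2.0).reshape(1, 1, 2, 1)  # KV head 0 holds 0., head 1 holds 1.

# `repeat` keeps each KV head's copies adjacent on the head axis:
# query heads (0, 1) share KV head 0 and (2, 3) share KV head 1.
print(np.repeat(key, repeats=2, axis=2).ravel())  # [0. 0. 1. 1.]

# `tile` concatenates whole blocks instead, interleaving the heads:
# query head 1 would read KV head 1 rather than KV head 0.
print(np.tile(key, [1, 1, 2, 1]).ravel())  # [0. 1. 0. 1.]
```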
diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index efff972c6b..b641a0152e 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -28,7 +28,6 @@ def setUp(self): "num_key_value_heads": 2, "hidden_dim": 8, "intermediate_dim": 8, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), @@ -50,3 +49,34 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) + + def test_num_parameters(self): + model = LlamaBackbone(**self.init_kwargs) + # Reference value calculated using the PyTorch model + self.assertEqual(model.count_params(), 968) + + @pytest.mark.extra_large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaBackbone, + preset="llama2_7b_en", + input_data={ + "token_ids": ops.array([[1, 1824, 349, 524, 11234, 28804]]), + "padding_mask": ops.ones((1, 6), dtype="int32"), + }, + expected_output_shape=(1, 6, 4096), + # The forward pass from a preset should be stable! + # Reference values computed using PyTorch HF model. + expected_partial_output=ops.array( + [0.0153, 1.1657, 2.2452, -2.0192, -0.5801] + ), + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaBackbone.presets: + self.run_preset_test( + cls=LlamaBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_causal_lm.py b/keras_nlp/models/llama/llama_causal_lm.py new file mode 100644 index 0000000000..48b5fdb4c2 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm.py @@ -0,0 +1,341 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.causal_lm import CausalLM +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal + + +@keras_nlp_export("keras_nlp.models.LlamaCausalLM") +class LlamaCausalLM(CausalLM): + """An end-to-end Llama model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a LLaMA model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_nlp.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. 
+ + Args: + backbone: A `keras_nlp.models.LlamaBackbone` instance. + preprocessor: A `keras_nlp.models.LlamaCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + inputs = backbone.inputs + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Default compilation === + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(2e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + jit_compile=True, + ) + + @classproperty + def backbone_cls(cls): + return LlamaBackbone + + @classproperty + def preprocessor_cls(cls): + return LlamaCausalLMPreprocessor + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `LlamaCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] + x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. 
+
+ Args:
+ inputs: A dictionary with two keys `"token_ids"` and
+ `"padding_mask"` and batched tensor values.
+ stop_token_ids: Tuple of token ids to stop on. If all
+ sequences have produced a new stop token, generation
+ will stop.
+ """
+ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
+ # Create and seed cache with a single forward pass.
+ hidden_states, cache = self._build_cache(token_ids)
+ # Compute the lengths of all user-provided token ids.
+ row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+ # Start at the first index that has no user-provided id.
+ index = ops.min(row_lengths)
+
+ def next(prompt, cache, index):
+ # The cache index is the index of our previous token.
+ cache_update_index = index - 1
+ batch_size = ops.shape(prompt)[0]
+ prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
+ logits, hidden_states, cache = self.call_with_cache(
+ prompt,
+ cache,
+ cache_update_index,
+ )
+ return (
+ ops.squeeze(logits, axis=1),
+ ops.squeeze(hidden_states, axis=1),
+ cache,
+ )
+
+ token_ids = self._sampler(
+ next=next,
+ prompt=token_ids,
+ cache=cache,
+ index=index,
+ mask=padding_mask,
+ stop_token_ids=stop_token_ids,
+ hidden_states=hidden_states,
+ model=self,
+ )
+
+ # Compute an output padding mask with the token ids we updated.
+ if stop_token_ids is not None:
+ # Build a mask of stop token locations not in the original
+ # prompt (not in locations where `padding_mask` is True).
+ end_locations = ops.logical_and(
+ any_equal(token_ids, stop_token_ids),
+ ops.logical_not(padding_mask),
+ )
+ end_locations = ops.cast(end_locations, "int32")
+ # Use cumsum to get ones in all locations after end_locations.
+ cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+ overflow = cumsum - end_locations
+ # Our padding mask is the inverse of these overflow locations.
+ padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+ else:
+ # Without early stopping, all locations will have been updated.
+ padding_mask = ops.ones_like(token_ids, dtype="bool")
+ return {
+ "token_ids": token_ids,
+ "padding_mask": padding_mask,
+ }
+
+ def score(
+ self,
+ token_ids,
+ padding_mask=None,
+ scoring_mode="logits",
+ layer_intercept_fn=None,
+ target_ids=None,
+ ):
+ """Score a generation represented by the provided token ids.
+
+ Args:
+ token_ids: A [batch_size, num_tokens] tensor containing tokens
+ to score. Typically, this tensor captures the output from a call
+ to `LlamaCausalLM.generate()`, i.e., tokens for both the input
+ text and the model-generated text.
+ padding_mask: A [batch_size, num_tokens] tensor indicating the
+ tokens that should be preserved during generation. This is an
+ artifact required by the `LlamaBackbone` and isn't influential
+ on the computation of this function. If omitted, this function
+ uses `keras.ops.ones()` to create a tensor of the appropriate
+ shape.
+ scoring_mode: The type of scores to return, either "logits" or
+ "loss"; both are computed per input token.
+ layer_intercept_fn: An optional function for augmenting activations
+ with additional computation, for example, as part of
+ interpretability research. This function will be passed the
+ activations as its first parameter and a numeric index
+ associated with that backbone layer. This index is not an
+ index into `self.backbone.layers`. The index -1 accompanies the
+ embeddings returned by calling `self.backbone.token_embedding()`
+ on `token_ids` in the forward direction.
All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Example: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + llama_lm = keras_nlp.models.LlamaCausalLM.from_preset("llama2_7b_en") + generations = llama_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = llama_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = llama_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape) + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py new file mode 100644 index 0000000000..a221185582 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py @@ -0,0 +1,185 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
+from keras_nlp.utils.keras_utils import (
+ convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
+class LlamaCausalLMPreprocessor(LlamaPreprocessor):
+ """Llama Causal LM preprocessor.
+
+ This preprocessing layer is meant for use with
+ `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
+ strings, and return outputs in a `(x, y, sample_weight)` format, where the
+ `y` label is the next token id in the `x` sequence.
+
+ For use with generation, the layer also exposes two methods
+ `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+ is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods
+ will be called implicitly in `generate()`. They can also be called
+ standalone (e.g. to precompute preprocessing inputs for generation in a
+ separate process).
+
+ Args:
+ tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+ sequence_length: The length of the packed inputs.
+ add_start_token: If `True`, the preprocessor will prepend the tokenizer
+ start token to each input sequence. Defaults to `True`.
+ add_end_token: If `True`, the preprocessor will append the tokenizer
+ end token to each input sequence. Defaults to `False`.
+
+ Call arguments:
+ x: A string, `tf.Tensor` or list of python strings.
+ y: Label data. Should always be `None` as the layer generates labels.
+ sample_weight: Label weights. Should always be `None` as the layer
+ generates label weights.
+ sequence_length: Pass to override the configured `sequence_length` of
+ the layer.
+
+ Examples:
+ ```python
+ # Load the preprocessor from a preset.
+ preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
+ "llama2_7b_en"
+ )
+
+ # Tokenize and pack a single sentence.
+ sentence = tf.constant("League of legends")
+ preprocessor(sentence)
+ # Same output.
+ preprocessor("League of legends")
+
+ # Tokenize a batch of sentences.
+ sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+ preprocessor(sentences)
+ # Same output.
+ preprocessor(["Taco tuesday", "Fish taco please!"])
+
+ # Map a dataset to preprocess a single sentence.
+ features = tf.constant(
+ [
+ "Avatar 2 is amazing!",
+ "Well, I am not sure.",
+ ]
+ )
+ labels = tf.constant([1, 0])
+ ds = tf.data.Dataset.from_tensor_slices((features, labels))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map a dataset to preprocess unlabeled sentences.
+ ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`LlamaCausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. 
+ padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + padding_mask = padding_mask & ( + token_ids != self.tokenizer.start_token_id + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..aa4d155c8c --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = (["the quick brown fox"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=LlamaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [[3, 8, 4, 6, 0, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = LlamaCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 8, 4, 6, 0, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaCausalLMPreprocessor.presets: + self.run_preset_test( + cls=LlamaCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_causal_lm_test.py b/keras_nlp/models/llama/llama_causal_lm_test.py new file mode 100644 index 0000000000..c006f72783 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_test.py @@ -0,0 +1,215 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
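The expected tensors in the preprocessor test above follow from the shift-by-one labeling in `LlamaCausalLMPreprocessor.call()`: the packer emits `sequence_length + 1` tokens, the model input drops the last position, and the labels drop the first. A minimal sketch of that slicing with plain Python lists, using the exact values asserted in `test_preprocessor_basics`:

```python
# Packed output for "the quick brown fox" with sequence_length=8:
# a start token, four word tokens, then padding, packed to length 9.
packed_ids = [1, 3, 8, 4, 6, 0, 0, 0, 0]
packed_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0]

# Model input: every position except the last.
token_ids = packed_ids[:-1]       # [1, 3, 8, 4, 6, 0, 0, 0]
padding_mask = packed_mask[:-1]   # [1, 1, 1, 1, 1, 0, 0, 0]

# Labels: every position except the first, i.e. the "next token" at
# each step. Sample weights zero out the padded label positions.
y = packed_ids[1:]                # [3, 8, 4, 6, 0, 0, 0, 0]
sample_weight = packed_mask[1:]   # [1, 1, 1, 1, 0, 0, 0, 0]
```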
+ +import os +from unittest.mock import patch + +import pytest + +from keras_nlp.backend import ops +from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models.llama.llama_causal_lm import LlamaCausalLM +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaCausalLMTest(TestCase): + def setUp(self): + self.preprocessor = LlamaCausalLMPreprocessor( + LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "llama_test_vocab.spm" + ) + ), + sequence_length=8, + ) + self.backbone = LlamaBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=LlamaCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 10), + ) + + def test_generate(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids) + # Assert prompt is in output in token id space. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = LlamaCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy")
+ self.assertIsNone(causal_lm.generate_function)
+
+ @pytest.mark.large
+ def test_saved_model(self):
+ self.run_model_saving_test(
+ cls=LlamaCausalLM,
+ init_kwargs=self.init_kwargs,
+ input_data=self.input_data,
+ )
+
+ @pytest.mark.extra_large
+ def test_all_presets(self):
+ for preset in LlamaCausalLM.presets:
+ self.run_preset_test(
+ cls=LlamaCausalLM,
+ preset=preset,
+ input_data=self.input_data,
+ )
+
+ def test_score_logits(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = LlamaCausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8, 10)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_loss(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = LlamaCausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+ target_ids = ops.roll(token_ids, shift=-1, axis=1)
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="loss",
+ target_ids=target_ids,
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_layer_intercept_fn_exfiltration(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = LlamaCausalLM(**self.init_kwargs)
+ expected_embedded_shape = (2, 8, 8)
+ expected_score_shape = (2, 8, 10)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Setup a custom intercept function that extracts the embeddings to
+ # a variable from the embeddings layer and otherwise asserts on shapes.
+ embedded_prompts = None
+
+ def layer_intercept_fn_for_testing(x, i):
+ if i == -1:
+ nonlocal embedded_prompts
+ embedded_prompts = x
+ else:
+ nonlocal expected_embedded_shape
+ self.assertEqual(ops.shape(x), expected_embedded_shape)
+ return x
+
+ # Get the scores.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ layer_intercept_fn=layer_intercept_fn_for_testing,
+ )
+
+ # Assert shapes for info exfiltrated into the parent context.
+ self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 3b9d6906b8..7b4ad5f75d 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -24,20 +24,20 @@ from keras_nlp.utils.keras_utils import clone_initializer -class LlamaDecoder(keras.layers.Layer): - """Llama decoder block.""" +class LlamaTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Llama backbone.""" def __init__( self, intermediate_dim, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, - activation="relu", + activation="silu", layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) @@ -48,61 +48,80 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length + self.dropout = dropout + self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.supports_masking = True + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape self.hidden_dim = decoder_sequence_shape[-1] - # Self attention layers. + # Self attention layer. self._self_attention_layer = LlamaAttention( num_query_heads=self.num_query_heads, num_key_value_heads=self.num_key_value_heads, rope_max_wavelength=self.rope_max_wavelength, - max_sequence_length=self.max_sequence_length, rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, dtype=self.dtype_policy, + name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="self_attention_layernorm", ) self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) # Feedforward layers. 
self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) self._feedforward_gate_dense = keras.layers.Dense( self.intermediate_dim, - activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_gate_dense", ) self._feedforward_gate_dense.build(decoder_sequence_shape) self._feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_output_dense", ) - intermediate_shape = list(decoder_sequence_shape) - intermediate_shape[-1] = self.intermediate_dim - self._feedforward_output_dense.build(tuple(intermediate_shape)) + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) self._feedforward_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="feedforward_layernorm", ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -115,6 +134,7 @@ def call( decoder_attention_mask=None, self_attention_cache=None, self_attention_cache_update_index=None, + training=None, ): self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, @@ -125,10 +145,9 @@ def call( ) residual = decoder_sequence - x = self._self_attention_layernorm( - decoder_sequence, - ) + x = self._self_attention_layernorm(decoder_sequence) + # Self attention block. x = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, @@ -139,12 +158,25 @@ def call( if self_attention_cache is not None: x, self_attention_cache = x + x = self._self_attention_dropout(x, training=training) + x = x + residual residual = x x = self._feedforward_layernorm(x) gate_output = self._feedforward_gate_dense(x) + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. 
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + x = self._feedforward_intermediate_dense(x) x = self._feedforward_output_dense(ops.multiply(x, gate_output)) @@ -152,7 +184,7 @@ def call( decoder_output = x + residual if self_attention_cache is not None: - return (decoder_output, self_attention_cache) + return decoder_output, self_attention_cache return decoder_output def _compute_self_attention_mask( @@ -160,8 +192,8 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask, - self_attention_cache=None, - self_attention_cache_update_index=None, + self_attention_cache, + self_attention_cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -174,16 +206,16 @@ def _compute_self_attention_mask( if self_attention_cache is not None: input_length = ops.shape(self_attention_cache)[2] + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + causal_mask = compute_causal_mask( - batch_size, - input_length, - output_length, - ( - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index - ), + batch_size, input_length, output_length, cache_update_index ) + return ( ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None @@ -198,17 +230,16 @@ def get_config(self): config.update( { "intermediate_dim": self.intermediate_dim, - "hidden_dim": self.hidden_dim, "num_query_heads": self.num_query_heads, "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), + "dropout": self.dropout, } ) return config diff --git a/keras_nlp/models/llama/llama_layernorm.py b/keras_nlp/models/llama/llama_layernorm.py index 0e85a45625..8d19c9bc19 100644 --- a/keras_nlp/models/llama/llama_layernorm.py +++ b/keras_nlp/models/llama/llama_layernorm.py @@ -14,24 +14,35 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops -# TODO: Should be replaced with LayerNormalization with `rms_scaling` param -# https://github.com/keras-team/keras-core/pull/726 - +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. 
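Taken together, the feedforward edits above give the decoder Llama's gated feedforward (SwiGLU): a gate projection whose SiLU activation now runs in float32 for parity with `torch.nn.functional.silu`, multiplied elementwise with the intermediate projection before the down-projection. A rough sketch of that computation, assuming plain Keras 3 `Dense` layers and ignoring dtype policies beyond the float32 upcast:

```python
import keras
from keras import ops

def swiglu_ffn(x, gate_dense, intermediate_dense, output_dense,
               compute_dtype="float16"):
    # Gate branch: project, then apply SiLU in full float32 precision.
    gate = ops.cast(gate_dense(x), "float32")
    gate = keras.activations.silu(gate)
    gate = ops.cast(gate, compute_dtype)
    # Up-projection branch, elementwise gating, then down-projection.
    return output_dense(ops.multiply(intermediate_dense(x), gate))
```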
class LlamaLayerNorm(keras.layers.Layer):
+ """A normalization layer for Llama that implements RMS normalization."""
+
+ def __init__(self, epsilon=1e-6, **kwargs):
+ super().__init__(**kwargs)
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
- self.weight = self.add_weight(
- name="weight",
- shape=(input_shape[-1],),
+ dim = input_shape[-1]
+ self.scale = self.add_weight(
+ name="scale",
+ trainable=True,
+ shape=(dim,),
initializer="ones",
+ dtype=self.variable_dtype,
)
self.built = True
- def call(self, hidden_states):
- variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
- hidden_states = hidden_states * 1 / ops.sqrt(variance + self.epsilon)
- return self.weight * hidden_states
+ def call(self, x):
+ x = ops.cast(x, "float32")
+ var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+ x = x * ops.rsqrt(var + self.epsilon)
+ return ops.cast(x, self.compute_dtype) * self.scale
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({"epsilon": self.epsilon})
+ return config
diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py
new file mode 100644
index 0000000000..f3aaa208a8
--- /dev/null
+++ b/keras_nlp/models/llama/llama_preprocessor.py
@@ -0,0 +1,188 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+ convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.LlamaPreprocessor")
+class LlamaPreprocessor(Preprocessor):
+ """A Llama preprocessing layer which tokenizes and packs inputs.
+
+ This preprocessing layer will do three things:
+
+ 1. Tokenize any number of input segments using the `tokenizer`.
+ 2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`
+ with the appropriate tokens.
+ 3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`
+ that can be passed directly to `keras_nlp.models.LlamaBackbone`.
+
+ This layer can be used directly with `tf.data.Dataset.map` to preprocess
+ string data in the `(x, y, sample_weight)` format used by
+ `keras.Model.fit`.
+
+ Args:
+ tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+ sequence_length: The length of the packed inputs.
+ add_start_token: If `True`, the preprocessor will prepend the tokenizer
+ start token to each input sequence. Defaults to `True`.
+ add_end_token: If `True`, the preprocessor will append the tokenizer
+ end token to each input sequence. Defaults to `False`.
+
+ Call arguments:
+ x: A tensor of single string sequences, or a tuple of multiple
+ tensor sequences to be packed together. Inputs may be batched or
+ unbatched.
For single sequences, raw python inputs will be converted
+ to tensors. For multiple sequences, pass tensors directly.
+ y: Any label data. Will be passed through unaltered.
+ sample_weight: Any label weight data. Will be passed through unaltered.
+ sequence_length: Pass to override the configured `sequence_length` of
+ the layer.
+
+ Examples:
+
+ Directly calling the layer on data.
+ ```python
+ preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+ "llama2_7b_en"
+ )
+
+ # Tokenize and pack a single sentence.
+ preprocessor("The quick brown fox jumped.")
+
+ # Tokenize a batch of single sentences.
+ preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+ # Preprocess a batch of sentence pairs.
+ # When handling multiple sequences, always convert to tensors first!
+ first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+ second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+ preprocessor((first, second))
+ ```
+
+ Mapping with `tf.data.Dataset`.
+ ```python
+ preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+ "llama2_7b_en"
+ )
+ first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+ second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+ label = tf.constant([1, 1])
+
+ # Map labeled single sentences.
+ ds = tf.data.Dataset.from_tensor_slices((first, label))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map unlabeled single sentences.
+ ds = tf.data.Dataset.from_tensor_slices(first)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map labeled sentence pairs.
+ ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Map unlabeled sentence pairs.
+ ds = tf.data.Dataset.from_tensor_slices((first, second))
+
+ # Watch out for tf.data's default unpacking of tuples here!
+ # Best to invoke the `preprocessor` directly in this case.
+ ds = ds.map(
+ lambda first, second: preprocessor(x=(first, second)),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
+ ```
+ """
+
+ tokenizer_cls = LlamaTokenizer
+
+ def __init__(
+ self,
+ tokenizer,
+ sequence_length=1024,
+ add_start_token=True,
+ add_end_token=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.tokenizer = tokenizer
+ self.packer = None
+ self.add_start_token = add_start_token
+ self.add_end_token = add_end_token
+ self.sequence_length = sequence_length
+
+ def build(self, input_shape):
+ # Defer packer creation to `build()` so that we can be sure tokenizer
+ # assets have loaded when restoring a saved model.
+ self.packer = StartEndPacker(
+ start_value=self.tokenizer.start_token_id,
+ end_value=self.tokenizer.end_token_id,
+ sequence_length=self.sequence_length,
+ return_padding_mask=True,
+ )
+ self.built = True
+
+ def get_config(self):
+ config = super().get_config()
+ config.update(
+ {
+ "sequence_length": self.sequence_length,
+ "add_start_token": self.add_start_token,
+ "add_end_token": self.add_end_token,
+ }
+ )
+ return config
+
+ def call(
+ self,
+ x,
+ y=None,
+ sample_weight=None,
+ sequence_length=None,
+ ):
+ x = convert_inputs_to_list_of_tensor_segments(x)
+ if len(x) != 1:
+ raise ValueError(
+ "Llama requires each input feature to contain only "
+ f"one segment, but received {len(x)}. If you are using Llama"
+ " for a multi-segment classification task, please refer to "
+ "classification models like BERT or RoBERTa."
+ ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py new file mode 100644 index 0000000000..52a559aa2e --- /dev/null +++ b/keras_nlp/models/llama/llama_preprocessor_test.py @@ -0,0 +1,68 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["the quick brown fox"], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=LlamaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ), + ) + + def test_errors_for_2d_list_input(self): + preprocessor = LlamaPreprocessor(**self.init_kwargs) + ambiguous_input = [["one", "two"], ["three", "four"]] + with self.assertRaises(ValueError): + preprocessor(ambiguous_input) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaPreprocessor.presets: + self.run_preset_test( + cls=LlamaPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/llama/llama_presets.py b/keras_nlp/models/llama/llama_presets.py new file mode 100644 index 0000000000..292848a11d --- /dev/null +++ b/keras_nlp/models/llama/llama_presets.py @@ -0,0 +1,38 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Llama model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = { + "llama2_7b_en": { + "metadata": { + "description": "LLaMA 2 7B Base model", + "params": 6738415616, + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", + }, + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1", + }, + "llama2_instruct_7b_en": { + "metadata": { + "description": "LLaMA 2 7B Chat model", + "params": 6738415616, + "official_name": "LLaMA 2", + "path": "llama2", + "model_card": "https://github.com/meta-llama/llama", + }, + "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1", + }, +} diff --git a/keras_nlp/models/llama/llama_tokenizer.py b/keras_nlp/models/llama/llama_tokenizer.py index 7acdf8687c..07b0f21037 100644 --- a/keras_nlp/models/llama/llama_tokenizer.py +++ b/keras_nlp/models/llama/llama_tokenizer.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy + from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.llama.llama_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer +from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.LlamaTokenizer") @@ -79,3 +83,7 @@ def set_proto(self, proto): self.start_token_id = None self.end_token_id = None self.pad_token_id = None + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_tokenizer_test.py b/keras_nlp/models/llama/llama_tokenizer_test.py index 9a3c225456..51687731e5 100644 --- a/keras_nlp/models/llama/llama_tokenizer_test.py +++ b/keras_nlp/models/llama/llama_tokenizer_test.py @@ -14,6 +14,8 @@ import os +import pytest + from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer from keras_nlp.tests.test_case import TestCase @@ -44,3 +46,21 @@ def test_errors_missing_special_tokens(self): self.get_test_data_dir(), "no_special_token_vocab.spm" ) ) + + @pytest.mark.large + def test_smallest_preset(self): + self.run_preset_test( + cls=LlamaTokenizer, + preset="llama2_7b_en", + input_data=["The quick brown fox."], + expected_output=[[450, 4996, 17354, 1701, 29916, 29889]], + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in LlamaTokenizer.presets: + self.run_preset_test( + cls=LlamaTokenizer, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_nlp/models/masked_lm.py b/keras_nlp/models/masked_lm.py new file mode 100644 index 0000000000..136dbf0b8e --- /dev/null +++ b/keras_nlp/models/masked_lm.py @@ -0,0 +1,42 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.models.task import Task
+
+
+@keras_nlp_export("keras_nlp.models.MaskedLM")
+class MaskedLM(Task):
+ """Base class for masked language modeling tasks.
+
+ `MaskedLM` tasks wrap a `keras_nlp.models.Backbone` and
+ a `keras_nlp.models.Preprocessor` to create a model that can be used for
+ unsupervised fine-tuning with a masked language modeling loss.
+
+ When calling `fit()`, all input will be tokenized, and random tokens in
+ the input sequence will be masked. The positions of these masked tokens
+ will be fed as an additional model input, and the original values of the
+ masked tokens will be the labels the model is trained to predict.
+
+ All `MaskedLM` tasks include a `from_preset()` constructor which can be used
+ to load a pre-trained config and weights.
+
+ Example:
+ ```python
+ # Load a Bert MaskedLM with pre-trained weights.
+ masked_lm = keras_nlp.models.MaskedLM.from_preset(
+ "bert_base_en",
+ )
+ masked_lm.fit(train_ds)
+ ```
+ """
diff --git a/keras_nlp/models/mistral/__init__.py b/keras_nlp/models/mistral/__init__.py
index ba0c2545e4..2593ab0201 100644
--- a/keras_nlp/models/mistral/__init__.py
+++ b/keras_nlp/models/mistral/__init__.py
@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
+from keras_nlp.models.mistral.mistral_presets import backbone_presets
+from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer
+from keras_nlp.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (MistralBackbone, MistralTokenizer))
diff --git a/keras_nlp/models/mistral/mistral_backbone.py b/keras_nlp/models/mistral/mistral_backbone.py
index 3e2cfae148..1ee2dfce66 100644
--- a/keras_nlp/models/mistral/mistral_backbone.py
+++ b/keras_nlp/models/mistral/mistral_backbone.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
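The `register_presets` call in `mistral/__init__.py` above replaces the per-class `presets` classproperties that this diff deletes from the backbone, tokenizer, and preprocessor below. A hypothetical sketch of the registry idea; the actual `keras_nlp.utils.preset_utils` internals are not shown in this diff and may differ:

```python
# Hypothetical registry sketch, not the real keras_nlp implementation.
_PRESET_REGISTRY = {}

def register_presets(presets, classes):
    # Record every preset configuration against each class that can load it.
    for cls in classes:
        _PRESET_REGISTRY.setdefault(cls, {}).update(presets)

def list_presets(cls):
    # Presets registered directly against `cls`.
    return dict(_PRESET_REGISTRY.get(cls, {}))
```

With a scheme like this, `MistralBackbone.presets` and `MistralTokenizer.presets` can resolve through one shared registry instead of each class deep-copying `backbone_presets`.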
-import copy
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
@@ -21,11 +20,9 @@
from keras_nlp.models.mistral.mistral_layer_norm import (
MistralLayerNormalization,
)
-from keras_nlp.models.mistral.mistral_presets import backbone_presets
from keras_nlp.models.mistral.mistral_transformer_decoder import (
MistralTransformerDecoder,
)
-from keras_nlp.utils.python_utils import classproperty
def _mistral_kernel_initializer(stddev=0.02):
@@ -166,6 +163,7 @@ def __init__(
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
+ dtype=dtype,
**kwargs,
)
@@ -200,7 +198,3 @@ def get_config(self):
}
)
return config
-
- @classproperty
- def presets(cls):
- return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/mistral/mistral_causal_lm.py b/keras_nlp/models/mistral/mistral_causal_lm.py
index 3296bb9495..754c07d2a5 100644
--- a/keras_nlp/models/mistral/mistral_causal_lm.py
+++ b/keras_nlp/models/mistral/mistral_causal_lm.py
@@ -11,22 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.backend import ops
-from keras_nlp.models.generative_task import GenerativeTask
+from keras_nlp.models.causal_lm import CausalLM
from keras_nlp.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.models.mistral.mistral_causal_lm_preprocessor import (
MistralCausalLMPreprocessor,
)
-from keras_nlp.models.mistral.mistral_presets import backbone_presets
-from keras_nlp.utils.python_utils import classproperty
+from keras_nlp.utils.tensor_utils import any_equal
@keras_nlp_export("keras_nlp.models.MistralCausalLM")
-class MistralCausalLM(GenerativeTask):
+class MistralCausalLM(CausalLM):
"""An end-to-end Mistral model for causal language modeling.
A causal language model (LM) predicts the next token based on previous
@@ -48,6 +46,9 @@ class MistralCausalLM(GenerativeTask):
should be preprocessed before calling the model.
"""
+ backbone_cls = MistralBackbone
+ preprocessor_cls = MistralCausalLMPreprocessor
+
def __init__(self, backbone, preprocessor=None, **kwargs):
# === Layers ===
self.backbone = backbone
@@ -71,14 +72,6 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
jit_compile=True,
)
- @classproperty
- def backbone_cls(cls):
- return MistralBackbone
-
- @classproperty
- def preprocessor_cls(cls):
- return MistralCausalLMPreprocessor
-
def call_with_cache(
self,
token_ids,
@@ -143,7 +136,7 @@ def _build_cache(self, token_ids):
def generate_step(
self,
inputs,
- end_token_id=None,
+ stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.
@@ -154,8 +147,8 @@ def generate_step(
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
- end_token_id: The id of the end token to stop on. If all
- sequences have produced a new `end_token_id`, generation
+ stop_token_ids: List of token ids to stop generation on. If all
+ sequences have produced a new stop token, generation
will stop.
""" token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -188,18 +181,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop_tokens locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") @@ -214,6 +208,128 @@ def next(prompt, cache, index): "padding_mask": padding_mask, } - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `MistralCausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the MistralBackbone and isn't influential + on the computation of this function. If omitted, this function + uses `keras.ops.ones()` to create a tensor of the appropriate + shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. + layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. 
+ + Examples: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + mistral_lm = keras_nlp.models.MistralCausalLM.from_preset( + "mistral_7b_en" + ) + generations = mistral_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = mistral_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = mistral_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py index 893036cd58..624c37c9a1 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py @@ -131,7 +131,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -159,7 +159,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/mistral/mistral_causal_lm_test.py b/keras_nlp/models/mistral/mistral_causal_lm_test.py index 3f9d7fab36..13f0dad907 100644 --- a/keras_nlp/models/mistral/mistral_causal_lm_test.py +++ b/keras_nlp/models/mistral/mistral_causal_lm_test.py @@ -128,3 +128,88 @@ def test_all_presets(self): preset=preset, input_data=self.input_data, ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. 
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = MistralCausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8, 10)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_loss(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = MistralCausalLM(**self.init_kwargs)
+ expected_score_shape = (2, 8)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+ target_ids = ops.roll(token_ids, shift=-1, axis=1)
+
+ # Get the scores and assert their shape.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="loss",
+ target_ids=target_ids,
+ )
+
+ self.assertEqual(ops.shape(scores), expected_score_shape)
+
+ def test_score_layer_intercept_fn_exfiltration(self):
+ # Setup prompts, models, and associated expected shapes.
+ prompts = ["the quick brown fox", "the quick brown fox"]
+ causal_lm = MistralCausalLM(**self.init_kwargs)
+ expected_embedded_shape = (2, 8, 8)
+ expected_score_shape = (2, 8, 10)
+
+ # Preprocess prompts to get tokenized representations and padding masks.
+ preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
+ prompts
+ )
+ token_ids = preprocessed_prompts["token_ids"]
+ padding_mask = preprocessed_prompts["padding_mask"]
+
+ # Setup a custom intercept function that extracts the embeddings to
+ # a variable from the embeddings layer and otherwise asserts on shapes.
+ embedded_prompts = None
+
+ def layer_intercept_fn_for_testing(x, i):
+ if i == -1:
+ nonlocal embedded_prompts
+ embedded_prompts = x
+ else:
+ nonlocal expected_embedded_shape
+ self.assertEqual(ops.shape(x), expected_embedded_shape)
+ return x
+
+ # Get the scores.
+ scores = causal_lm.score(
+ token_ids=token_ids,
+ padding_mask=padding_mask,
+ scoring_mode="logits",
+ layer_intercept_fn=layer_intercept_fn_for_testing,
+ )
+
+ # Assert shapes for info exfiltrated into the parent context.
+ self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_nlp/models/mistral/mistral_layer_norm.py b/keras_nlp/models/mistral/mistral_layer_norm.py index e714a8540d..6801a4679b 100644 --- a/keras_nlp/models/mistral/mistral_layer_norm.py +++ b/keras_nlp/models/mistral/mistral_layer_norm.py @@ -23,25 +23,26 @@ class MistralLayerNormalization(keras.layers.Layer): def __init__(self, epsilon=1e-6, **kwargs): super().__init__(**kwargs) - self._epsilon = epsilon + self.epsilon = epsilon def build(self, input_shape): - self._dim = input_shape[-1] - self._weight = self.add_weight( - name="weight", + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", trainable=True, - shape=(self._dim,), + shape=(dim,), initializer="ones", + dtype=self.variable_dtype, ) self.built = True def call(self, x): - x = x * ops.rsqrt( - ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + self._epsilon - ) - return x * self._weight + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x, self.compute_dtype) * self.scale def get_config(self): config = super().get_config() - config.update({"epsilon": self._epsilon}) + config.update({"epsilon": self.epsilon}) return config diff --git a/keras_nlp/models/mistral/mistral_preprocessor.py b/keras_nlp/models/mistral/mistral_preprocessor.py index 38dc6da5b6..3df849fd06 100644 --- a/keras_nlp/models/mistral/mistral_preprocessor.py +++ b/keras_nlp/models/mistral/mistral_preprocessor.py @@ -11,18 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
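Both `LlamaLayerNorm` and `MistralLayerNormalization` above implement RMS normalization: the statistics are computed in float32, the result is cast back to the compute dtype, and a learned per-feature `scale` is applied. A minimal numpy sketch of the same math, for reference only:

```python
import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Compute mean-square statistics in float32, as the layers above do.
    x32 = x.astype("float32")
    mean_square = np.mean(np.square(x32), axis=-1, keepdims=True)
    normed = x32 / np.sqrt(mean_square + epsilon)
    # Cast back to the input dtype before applying the learned scale.
    return normed.astype(x.dtype) * scale

x = np.random.rand(2, 4).astype("float16")
print(rms_norm(x, scale=np.ones(4, dtype="float16")))
```

Unlike standard layer normalization, no mean is subtracted and no bias is added.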
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.models.mistral.mistral_tokenizer import MistralTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.MistralPreprocessor") @@ -113,6 +110,8 @@ class MistralPreprocessor(Preprocessor): ``` """ + tokenizer_cls = MistralTokenizer + def __init__( self, tokenizer, @@ -188,11 +187,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return MistralTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/mistral/mistral_presets.py b/keras_nlp/models/mistral/mistral_presets.py index 82a2ec44f6..fdee396300 100644 --- a/keras_nlp/models/mistral/mistral_presets.py +++ b/keras_nlp/models/mistral/mistral_presets.py @@ -23,7 +23,7 @@ "path": "mistral", "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", }, - "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3", + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6", }, "mistral_instruct_7b_en": { "metadata": { @@ -33,6 +33,16 @@ "path": "mistral", "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", }, - "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3", + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6", + }, + "mistral_0.2_instruct_7b_en": { + "metadata": { + "description": "Mistral 7B instruct Version 0.2 model", + "params": 7241732096, + "official_name": "Mistral", + "path": "mistral", + "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md", + }, + "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/1", }, } diff --git a/keras_nlp/models/mistral/mistral_tokenizer.py b/keras_nlp/models/mistral/mistral_tokenizer.py index 59a00d302f..e91a20df88 100644 --- a/keras_nlp/models/mistral/mistral_tokenizer.py +++ b/keras_nlp/models/mistral/mistral_tokenizer.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.mistral.mistral_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.MistralTokenizer") @@ -49,7 +46,7 @@ class MistralTokenizer(SentencePieceTokenizer): ```python # Unbatched input. 
tokenizer = keras_nlp.models.MistralTokenizer.from_preset(
- "mistral_base_en",
+ "mistral_7b_en",
)
tokenizer("The quick brown fox jumped.")
@@ -81,7 +78,3 @@ def set_proto(self, proto):
else:
self.start_token_id = None
self.end_token_id = None
-
- @classproperty
- def presets(cls):
- return copy.deepcopy(backbone_presets)
diff --git a/keras_nlp/models/mistral/mistral_transformer_decoder.py b/keras_nlp/models/mistral/mistral_transformer_decoder.py
index 7c90ab91b9..3ef91d3066 100644
--- a/keras_nlp/models/mistral/mistral_transformer_decoder.py
+++ b/keras_nlp/models/mistral/mistral_transformer_decoder.py
@@ -102,7 +102,6 @@ def build(self, decoder_sequence_shape):
self._feedforward_gate_dense = keras.layers.Dense(
self.intermediate_dim,
- activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
use_bias=False,
dtype=self.dtype_policy,
@@ -172,6 +171,17 @@ def call(
x = self._feedforward_layernorm(x)
gate_output = self._feedforward_gate_dense(x)
+ # Note that we run the activation function in full 32-bit
+ # precision since this is what `torch.nn.functional.silu`
+ # does. Internally, `torch.nn.functional.silu` converts the
+ # inputs to float32, computes SiLU, and converts the outputs
+ # back to compute dtype.
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+ # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+ gate_output = ops.cast(gate_output, "float32")
+ gate_output = self.activation(gate_output)
+ gate_output = ops.cast(gate_output, self.compute_dtype)
+
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(ops.multiply(x, gate_output))
@@ -207,17 +217,20 @@ def _compute_self_attention_mask(
else self_attention_cache_update_index
)
- # Mistral uses a banded attention mask
- causal_mask_lower = compute_causal_mask(
+ # The lower triangular attention mask
+ causal_mask = compute_causal_mask(
batch_size, input_length, output_length, cache_update_index
)
- # Below is a workaround for `ops.triu` for Keras 2.
- # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
- # causal_mask = ops.triu(causal_mask_lower, k=-self.sliding_window)
- i = ops.arange(output_length)[:, None] + cache_update_index
- j = ops.arange(input_length)[None, :]
- causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
- causal_mask = ops.minimum(causal_mask_lower, causal_mask_upper)
+
+ # Mistral uses a banded attention mask if sliding window is not None
+ if self.sliding_window is not None:
+ # Below is a workaround for `ops.triu` for Keras 2.
+ # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is removed.
+ # causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
+ i = ops.arange(output_length)[:, None] + cache_update_index
+ j = ops.arange(input_length)[None, :]
+ causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
+ causal_mask = ops.minimum(causal_mask, causal_mask_upper)
return (
ops.minimum(decoder_mask, causal_mask)
diff --git a/keras_nlp/models/opt/__init__.py b/keras_nlp/models/opt/__init__.py
index ba0c2545e4..a81c5f3bc1 100644
--- a/keras_nlp/models/opt/__init__.py
+++ b/keras_nlp/models/opt/__init__.py
@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
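The banded mask above intersects the usual causal mask with the constraint `i < j + sliding_window`, so query position `i` can attend only to key positions `j` with `i - sliding_window < j <= i`. A small standalone numpy sketch of the same construction, ignoring batching and the cache offset:

```python
import numpy as np

def sliding_window_causal_mask(seq_len, sliding_window):
    # Lower triangular causal mask: query i may attend to keys j <= i.
    causal = np.tril(np.ones((seq_len, seq_len), dtype="int32"))
    # Banded constraint: query i may not look back more than
    # sliding_window positions, i.e. it requires i < j + sliding_window.
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    banded = (i < j + sliding_window).astype("int32")
    return np.minimum(causal, banded)

# With sliding_window=2, each query sees at most its own position and
# the one before it:
print(sliding_window_causal_mask(4, 2))
```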
# See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.opt.opt_backbone import OPTBackbone +from keras_nlp.models.opt.opt_presets import backbone_presets +from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (OPTBackbone, OPTTokenizer)) diff --git a/keras_nlp/models/opt/opt_backbone.py b/keras_nlp/models/opt/opt_backbone.py index 0b98a6c64e..16bd89c4d5 100644 --- a/keras_nlp/models/opt/opt_backbone.py +++ b/keras_nlp/models/opt/opt_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.opt.opt_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def opt_kernel_initializer(stddev=0.02): @@ -146,6 +143,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) @@ -168,7 +166,3 @@ def get_config(self): "dropout": self.dropout, "max_sequence_length": self.max_sequence_length, } - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/opt/opt_causal_lm.py b/keras_nlp/models/opt/opt_causal_lm.py index 9715bc6b75..1bb5bd1e87 100644 --- a/keras_nlp/models/opt/opt_causal_lm.py +++ b/keras_nlp/models/opt/opt_causal_lm.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops -from keras_nlp.models.generative_task import GenerativeTask +from keras_nlp.models.causal_lm import CausalLM from keras_nlp.models.opt.opt_backbone import OPTBackbone from keras_nlp.models.opt.opt_causal_lm_preprocessor import ( OPTCausalLMPreprocessor, ) -from keras_nlp.models.opt.opt_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty +from keras_nlp.utils.tensor_utils import any_equal @keras_nlp_export("keras_nlp.models.OPTCausalLM") -class OPTCausalLM(GenerativeTask): +class OPTCausalLM(CausalLM): """An end-to-end OPT model for causal language modeling. A causal language model (LM) predicts the next token based on previous @@ -149,6 +147,9 @@ class OPTCausalLM(GenerativeTask): ``` """ + backbone_cls = OPTBackbone + preprocessor_cls = OPTCausalLMPreprocessor + def __init__( self, backbone, @@ -177,18 +178,6 @@ def __init__( jit_compile=True, ) - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def backbone_cls(cls): - return OPTBackbone - - @classproperty - def preprocessor_cls(cls): - return OPTCausalLMPreprocessor - def call_with_cache( self, token_ids, @@ -247,7 +236,7 @@ def _build_cache(self, token_ids): def generate_step( self, inputs, - end_token_id=None, + stop_token_ids=None, ): """A compilable generation function for a single batch of inputs. @@ -258,8 +247,8 @@ def generate_step( Args: inputs: A dictionary with two keys `"token_ids"` and `"padding_mask"` and batched tensor values. - end_token_id: The id of the end token to stop on. 
If all - sequences have produced a new `end_token_id`, generation + stop_token_ids: Tuple of ids of end tokens to stop on. If all + sequences have produced a new stop token, generation will stop. """ token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] @@ -292,18 +281,19 @@ def next(prompt, cache, index): cache=cache, index=index, mask=padding_mask, - end_token_id=end_token_id, + stop_token_ids=stop_token_ids, hidden_states=hidden_states, + model=self, ) # Compute an output padding mask with the token ids we updated. - if end_token_id is not None: - # Build a mask of `end_token_id` locations not in the original + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original # prompt (not in locations where `padding_mask` is True). - end_locations = ops.logical_and( - ops.equal(token_ids, end_token_id), - ops.logical_not(padding_mask), + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) ) + end_locations = ops.cast(end_locations, "int32") # Use cumsum to get ones in all locations after end_locations. cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") diff --git a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py index 1895854e41..0a9ab86b00 100644 --- a/keras_nlp/models/opt/opt_causal_lm_preprocessor.py +++ b/keras_nlp/models/opt/opt_causal_lm_preprocessor.py @@ -132,7 +132,7 @@ def generate_preprocess( x, sequence_length=None, ): - """Covert strings to integer token input for generation. + """Convert strings to integer token input for generation. Similar to calling the layer for training, this method takes in strings or tensor strings, tokenizes and packs the input, and computes a padding @@ -160,7 +160,7 @@ def generate_postprocess( self, x, ): - """Covert integer token output to strings for generation. + """Convert integer token output to strings for generation. This method reverses `generate_preprocess()`, by first removing all padding and start/end tokens, and then converting the integer sequence diff --git a/keras_nlp/models/opt/opt_preprocessor.py b/keras_nlp/models/opt/opt_preprocessor.py index 8f52bb67e6..9eae445aec 100644 --- a/keras_nlp/models/opt/opt_preprocessor.py +++ b/keras_nlp/models/opt/opt_preprocessor.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License.
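The Mistral decoder change above only applies the banded mask when `sliding_window` is set. A minimal NumPy sketch of the same arithmetic (sizes here are made up for illustration) shows how taking the elementwise minimum of the causal mask and the upper band yields a sliding-window attention pattern:

```python
import numpy as np

# Illustrative sizes; the real values come from the decoder call.
seq_len, sliding_window, cache_update_index = 6, 3, 0

# Lower-triangular causal mask: position i may attend to positions j <= i.
causal_mask = np.tril(np.ones((seq_len, seq_len), dtype="int32"))

# Upper band: position i may only attend to positions j > i - sliding_window.
i = np.arange(seq_len)[:, None] + cache_update_index
j = np.arange(seq_len)[None, :]
causal_mask_upper = (i < j + sliding_window).astype("int32")

# The elementwise minimum keeps entries allowed by both constraints:
# a band of width `sliding_window` ending at the diagonal.
banded_mask = np.minimum(causal_mask, causal_mask_upper)
print(banded_mask)
```

Each query position ends up attending to itself and at most `sliding_window - 1` earlier positions, which is the banded pattern the commented-out `ops.triu` workaround is standing in for.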
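Similarly, the `generate_step` change in `opt_causal_lm.py` above builds the post-generation padding mask from a tuple of stop tokens. A small NumPy sketch of that masking logic, with a stand-in for `any_equal` (its exact signature is only inferred from the call sites in this diff):

```python
import numpy as np

def any_equal(tensor, values, valid_mask):
    # True where `tensor` equals any id in `values` and `valid_mask` is True.
    hits = np.zeros(tensor.shape, dtype=bool)
    for value in values:
        hits |= tensor == value
    return hits & valid_mask

token_ids = np.array([[5, 7, 9, 2, 6, 6]])  # Pretend id 2 is a stop token.
padding_mask = np.array([[True, True, False, False, False, False]])  # Prompt.
stop_token_ids = (2,)

# Stop tokens that were generated, i.e. not part of the original prompt.
end_locations = any_equal(
    token_ids, stop_token_ids, np.logical_not(padding_mask)
).astype("int32")
# Cumsum marks every location at or after a stop token...
cumsum = np.cumsum(end_locations, axis=-1)
# ...and subtracting `end_locations` keeps the stop token itself visible.
overflow = cumsum - end_locations
new_padding_mask = overflow == 0
print(new_padding_mask)  # [[ True  True  True  True False False]]
```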
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker -from keras_nlp.models.opt.opt_presets import backbone_presets from keras_nlp.models.opt.opt_tokenizer import OPTTokenizer from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.OPTPreprocessor") @@ -109,6 +106,8 @@ class OPTPreprocessor(Preprocessor): ``` """ + tokenizer_cls = OPTTokenizer + def __init__( self, tokenizer, @@ -186,11 +185,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classproperty - def tokenizer_cls(cls): - return OPTTokenizer diff --git a/keras_nlp/models/opt/opt_tokenizer.py b/keras_nlp/models/opt/opt_tokenizer.py index 4fb62ee73a..addcd0c01f 100644 --- a/keras_nlp/models/opt/opt_tokenizer.py +++ b/keras_nlp/models/opt/opt_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.opt.opt_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.OPTTokenizer") @@ -110,10 +107,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.pad_token_id = None self.end_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index 16a65e57c2..a4a9d6ee74 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -12,19 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Preprocessor") class Preprocessor(PreprocessingLayer): - """Base class for model preprocessors.""" + """Base class for preprocessing layers. + + A `Preprocessor` layer wraps a `keras_nlp.tokenizer.Tokenizer` to provide a + complete preprocessing setup for a given task. For example, a masked language + modeling preprocessor will take in raw input strings, and output + `(x, y, sample_weight)` tuples, where `x` contains token id sequences with + some tokens masked. + + This class can be subclassed similar to any `keras.layers.Layer`, by + defining `build()`, `call()` and `get_config()` methods.
All subclasses + should set the `tokenizer` property on construction. + """ + + tokenizer_cls = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -56,13 +71,15 @@ def from_config(cls, config): config["tokenizer"] = keras.layers.deserialize(config["tokenizer"]) return cls(**config) - @classproperty - def tokenizer_cls(cls): - return None - @classproperty def presets(cls): - return {} + presets = list_presets(cls) + # We can also load backbone presets. + if cls.tokenizer_cls is not None: + presets.update(cls.tokenizer_cls.presets) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets @classmethod def from_preset( @@ -70,50 +87,68 @@ def from_preset( preset, **kwargs, ): - """Instantiate {{preprocessor_name}} from preset architecture. + """Instantiate a `keras_nlp.models.Preprocessor` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as + one of: + + 1. a built-in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + For any `Preprocessor` subclass, you can run `cls.presets.keys()` to + list all built-in presets available on the class. + + As there are usually multiple preprocessing classes for a given model, + this method should be called on a specific subclass like + `keras_nlp.models.BertPreprocessor.from_preset()`. Args: - preset: string. Must be one of "{{preset_names}}". + preset: string. A built-in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. Examples: ```python - # Load a preprocessor layer from a preset. - preprocessor = keras_nlp.models.{{preprocessor_name}}.from_preset( - "{{example_preset_name}}", + # Load a preprocessor for Gemma generation. + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor.from_preset( + "gemma_2b_en", + ) + + # Load a preprocessor for Bert classification. + preprocessor = keras_nlp.models.BertPreprocessor.from_preset( + "bert_base_en", ) ``` """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - + if cls == Preprocessor: + raise ValueError( + "Do not call `Preprocessor.from_preset()` directly. Instead " + "choose a particular task class, e.g. " + "`keras_nlp.models.BertPreprocessor.from_preset()`." + ) config_file = "tokenizer.json" - check_preset_class(preset, cls.tokenizer_cls, config_file=config_file) + preset_cls = check_config_class(preset, config_file=config_file) + subclasses = list_subclasses(cls) + subclasses = tuple( + filter(lambda x: x.tokenizer_cls == preset_cls, subclasses) + ) + if len(subclasses) == 0: + raise ValueError( + f"No registered subclass of `{cls.__name__}` can load " + f"a `{preset_cls.__name__}`." + ) + if len(subclasses) > 1: + names = ", ".join(f"`{x.__name__}`" for x in subclasses) + raise ValueError( + f"Ambiguous call to `{cls.__name__}.from_preset()`. " + f"Found multiple possible subclasses {names}. " + "Please call `from_preset` on a subclass directly."
+ ) + cls = subclasses[0] tokenizer = load_from_preset( preset, config_file=config_file, ) return cls(tokenizer=tokenizer, **kwargs) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = Preprocessor.from_preset.__doc__ - format_docstring( - preprocessor_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/models/preprocessor_test.py b/keras_nlp/models/preprocessor_test.py new file mode 100644 index 0000000000..ea78720470 --- /dev/null +++ b/keras_nlp/models/preprocessor_test.py @@ -0,0 +1,37 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor +from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor +from keras_nlp.models.preprocessor import Preprocessor +from keras_nlp.tests.test_case import TestCase + + +class TestTask(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertPreprocessor.presets.keys()) + gpt2_presets = set(GPT2Preprocessor.presets.keys()) + all_presets = set(Preprocessor.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + # No loading on a preprocessor directly (it is ambiguous). + Preprocessor.from_preset("bert_tiny_en_uncased", load_weights=False) + with self.assertRaises(ValueError): + # No loading on an incorrect class. + BertPreprocessor.from_preset("gpt2_base_en", load_weights=False) diff --git a/keras_nlp/models/roberta/__init__.py b/keras_nlp/models/roberta/__init__.py index ba0c2545e4..1fe69366c7 100644 --- a/keras_nlp/models/roberta/__init__.py +++ b/keras_nlp/models/roberta/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
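The `Preprocessor.from_preset()` rewrite above resolves the concrete class by matching each registered subclass's `tokenizer_cls` against the tokenizer class recorded in the preset. A toy sketch of that resolution pattern follows; the class names are illustrative, and the real `check_config_class`/`list_subclasses` helpers in `preset_utils` differ in detail:

```python
class Tokenizer: ...
class BertTokenizer(Tokenizer): ...
class GPT2Tokenizer(Tokenizer): ...

class ToyPreprocessor:
    tokenizer_cls = None

class BertToyPreprocessor(ToyPreprocessor):
    tokenizer_cls = BertTokenizer

class GPT2ToyPreprocessor(ToyPreprocessor):
    tokenizer_cls = GPT2Tokenizer

def list_subclasses(cls):
    # Recursively walk the subclass tree.
    result = []
    for subclass in cls.__subclasses__():
        result.append(subclass)
        result.extend(list_subclasses(subclass))
    return result

def resolve(preset_tokenizer_cls):
    matches = [
        cls
        for cls in list_subclasses(ToyPreprocessor)
        if cls.tokenizer_cls is preset_tokenizer_cls
    ]
    if len(matches) != 1:
        raise ValueError("No unique preprocessor for this tokenizer class.")
    return matches[0]

print(resolve(GPT2Tokenizer).__name__)  # GPT2ToyPreprocessor
```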
+ +from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone +from keras_nlp.models.roberta.roberta_presets import backbone_presets +from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (RobertaBackbone, RobertaTokenizer)) diff --git a/keras_nlp/models/roberta/roberta_backbone.py b/keras_nlp/models/roberta/roberta_backbone.py index 1ab61eeeb7..6c25f2ae6f 100644 --- a/keras_nlp/models/roberta/roberta_backbone.py +++ b/keras_nlp/models/roberta/roberta_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -21,8 +20,6 @@ ) from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder from keras_nlp.models.backbone import Backbone -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty def roberta_kernel_initializer(stddev=0.02): @@ -156,6 +153,7 @@ def __init__( "padding_mask": padding_mask_input, }, outputs=x, + dtype=dtype, **kwargs, ) @@ -183,7 +181,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_classifier.py b/keras_nlp/models/roberta/roberta_classifier.py index 887bc657d4..57f50f4e94 100644 --- a/keras_nlp/models/roberta/roberta_classifier.py +++ b/keras_nlp/models/roberta/roberta_classifier.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaClassifier") -class RobertaClassifier(Task): +class RobertaClassifier(Classifier): """An end-to-end RoBERTa model for classification tasks. This model attaches a classification head to a @@ -134,6 +131,9 @@ class RobertaClassifier(Task): ``` """ + backbone_cls = RobertaBackbone + preprocessor_cls = RobertaPreprocessor + def __init__( self, backbone, @@ -213,15 +213,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return RobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return RobertaPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index bf96189860..ef6660f777 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
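Each model package's `__init__.py` now calls `register_presets(...)` once, replacing the per-class `presets` classproperties deleted throughout this diff. The registry implementation itself is not shown here; a plausible minimal shape, purely as a hypothetical sketch:

```python
# Hypothetical sketch only; the real implementation lives in
# `keras_nlp/utils/preset_utils.py` and may differ.
PRESET_REGISTRY = {}

def register_presets(presets, classes):
    # Associate every preset config with each class that can load it.
    for cls in classes:
        PRESET_REGISTRY.setdefault(cls, {}).update(presets)

def list_presets(cls):
    # Return a copy so callers cannot mutate the registry.
    return dict(PRESET_REGISTRY.get(cls, {}))
```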
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer from keras_nlp.models.roberta.roberta_masked_lm_preprocessor import ( RobertaMaskedLMPreprocessor, ) -from keras_nlp.models.roberta.roberta_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaMaskedLM") -class RobertaMaskedLM(Task): +class RobertaMaskedLM(MaskedLM): """An end-to-end RoBERTa model for the masked language modeling task. This model will train RoBERTa on a masked language modeling task. @@ -97,6 +94,9 @@ class RobertaMaskedLM(Task): ``` """ + backbone_cls = RobertaBackbone + preprocessor_cls = RobertaMaskedLMPreprocessor + def __init__( self, backbone, @@ -139,15 +139,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return RobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return RobertaMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_preprocessor.py b/keras_nlp/models/roberta/roberta_preprocessor.py index 57a421590f..f0e027a910 100644 --- a/keras_nlp/models/roberta/roberta_preprocessor.py +++ b/keras_nlp/models/roberta/roberta_preprocessor.py @@ -12,20 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) from keras_nlp.models.preprocessor import Preprocessor -from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaPreprocessor") @@ -133,6 +130,8 @@ class RobertaPreprocessor(Preprocessor): ``` """ + tokenizer_cls = RobertaTokenizer + def __init__( self, tokenizer, @@ -190,11 +189,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return RobertaTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/roberta/roberta_tokenizer.py b/keras_nlp/models/roberta/roberta_tokenizer.py index 0cfabff754..acf7f0aef9 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer.py +++ b/keras_nlp/models/roberta/roberta_tokenizer.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.RobertaTokenizer") @@ -126,10 +123,6 @@ def set_vocabulary_and_merges(self, vocabulary, merges): self.end_token_id = None self.mask_token_id = None - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - def get_config(self): config = super().get_config() # In the constructor, we pass the list of special tokens to the diff --git a/keras_nlp/models/roberta/roberta_tokenizer_test.py b/keras_nlp/models/roberta/roberta_tokenizer_test.py index 3b2305608d..572bc03151 100644 --- a/keras_nlp/models/roberta/roberta_tokenizer_test.py +++ b/keras_nlp/models/roberta/roberta_tokenizer_test.py @@ -28,7 +28,7 @@ def setUp(self): self.merges += ["Ġ ai r", "Ġ a i", "pla ne"] self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} self.input_data = [ - "<s> airplane at airport</s><pad>", + "<s> airplane at airport<mask></s><pad>", " airplane airport", ] @@ -37,10 +37,9 @@ def test_tokenizer_basics(self): cls=RobertaTokenizer, init_kwargs=self.init_kwargs, input_data=self.input_data, - # TODO: </s> should not get tokenized as <s> - expected_output=[[0, 4, 5, 6, 4, 7, 0, 1], [4, 5, 4, 7]], + expected_output=[[0, 4, 5, 6, 4, 7, 8, 2, 1], [4, 5, 4, 7]], expected_detokenize_output=[ - "<s> airplane at airport</s><pad>", + "<s> airplane at airport<mask></s><pad>", " airplane airport", ], ) diff --git a/keras_nlp/models/seq_2_seq_lm.py b/keras_nlp/models/seq_2_seq_lm.py new file mode 100644 index 0000000000..7c2ee9caa7 --- /dev/null +++ b/keras_nlp/models/seq_2_seq_lm.py @@ -0,0 +1,54 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.models.causal_lm import CausalLM + + +@keras_nlp_export("keras_nlp.models.Seq2SeqLM") +class Seq2SeqLM(CausalLM): + """Base class for sequence to sequence language modeling tasks. + + `Seq2SeqLM` tasks wrap a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be used for + generation and generative fine-tuning, when generation is conditioned on + an additional input sequence in a sequence-to-sequence setting. + + `Seq2SeqLM` tasks provide an additional, high-level `generate()` function + which can be used to auto-regressively sample an output sequence token by + token. The `compile()` method of `Seq2SeqLM` classes contains an additional + `sampler` argument, which can be used to pass a `keras_nlp.samplers.Sampler` + to control how the predicted distribution will be sampled. + + When calling `fit()`, each input should contain an input and output + sequence. The model will be trained to predict the output sequence + token-by-token using a causal mask, similar to a `keras_nlp.models.CausalLM` + task.
Unlike the `CausalLM` task, an input sequence must be passed, and + can be attended to in full by all tokens in the output sequence. + + All `Seq2SeqLM` tasks include a `from_preset()` constructor which can be + used to load a pre-trained config and weights. + + Example: + ```python + # Load a Bart backbone with pre-trained weights. + seq_2_seq_lm = keras_nlp.models.Seq2SeqLM.from_preset( + "bart_base_en", + ) + seq_2_seq_lm.compile(sampler="top_k") + # Generate text conditioned on the input sequence `"The quick brown fox."`. + seq_2_seq_lm.generate("The quick brown fox.", max_length=30) + ``` + """ + + # TODO: fill in during https://github.com/keras-team/keras-nlp/pull/1425 diff --git a/keras_nlp/models/t5/__init__.py b/keras_nlp/models/t5/__init__.py index ba0c2545e4..10064e442b 100644 --- a/keras_nlp/models/t5/__init__.py +++ b/keras_nlp/models/t5/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.t5.t5_backbone import T5Backbone +from keras_nlp.models.t5.t5_presets import backbone_presets +from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (T5Backbone, T5Tokenizer)) diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py index cf747c503c..67d3b2a6a9 100644 --- a/keras_nlp/models/t5/t5_backbone.py +++ b/keras_nlp/models/t5/t5_backbone.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm -from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.T5Backbone") @@ -224,6 +221,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) @@ -258,7 +256,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/t5/t5_tokenizer.py b/keras_nlp/models/t5/t5_tokenizer.py index b5dee49b85..ec5f0bf324 100644 --- a/keras_nlp/models/t5/t5_tokenizer.py +++ b/keras_nlp/models/t5/t5_tokenizer.py @@ -11,12 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
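The `Seq2SeqLM` docstring above notes that `fit()` expects paired input and output sequences. For BART-style presets this is a features dict; the `"encoder_text"`/`"decoder_text"` key names below are an assumption based on the BART preprocessor, so treat this as a sketch rather than a guaranteed API:

```python
import keras_nlp

seq_2_seq_lm = keras_nlp.models.Seq2SeqLM.from_preset("bart_base_en")
# Assumed feature names for a BART-style seq2seq preprocessor.
features = {
    "encoder_text": ["The quick brown fox jumped."],
    "decoder_text": ["The speedy tan fox leapt."],
}
seq_2_seq_lm.fit(x=features, batch_size=1, epochs=1)
```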
-import copy from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.t5.t5_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.T5Tokenizer") @@ -99,7 +96,3 @@ def set_proto(self, proto): self.end_token_id = None self.pad_token_id = None self.start_token_id = None - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py index 697af20899..ddff7a164a 100644 --- a/keras_nlp/models/t5/t5_transformer_layer.py +++ b/keras_nlp/models/t5/t5_transformer_layer.py @@ -131,8 +131,7 @@ def call( shape = ops.shape(hidden_states) batch_size, length = shape[0], shape[1] causal_mask = compute_causal_mask(batch_size, length, length) - attention_mask = ops.cast(attention_mask, "int32") - attention_mask = causal_mask & attention_mask + attention_mask = causal_mask & ops.cast(attention_mask, "bool") x = hidden_states # Intermediate result. diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 783cc0b41b..7858b84709 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -16,19 +16,41 @@ from rich import markup from rich import table as rich_table +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras +from keras_nlp.models.backbone import Backbone from keras_nlp.utils.keras_utils import print_msg from keras_nlp.utils.pipeline_model import PipelineModel -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Task") class Task(PipelineModel): - """Base class for Task models.""" + """Base class for all Task models. + + A `Task` wraps a `keras_nlp.models.Backbone` and + a `keras_nlp.models.Preprocessor` to create a model that can be directly + used for training, fine-tuning, and prediction for a given text problem. + + All `Task` models have `backbone` and `preprocessor` properties. By + default `fit()`, `predict()` and `evaluate()` will preprocess all inputs + automatically. To preprocess inputs separately or with a custom function, + you can set `task.preprocessor = None`, which disables any automatic + preprocessing on inputs. + + All `Task` classes include a `from_preset()` constructor which can be used + to load a pre-trained config and weights. Calling `from_preset()` on a task + will automatically instantiate a `keras_nlp.models.Backbone` and + `keras_nlp.models.Preprocessor`. + """ + + backbone_cls = None + preprocessor_cls = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -36,6 +58,12 @@ def __init__(self, *args, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True + if self.backbone is not None: + # Keras 2 and Keras 3 handle setting policy differently.
+ if config.keras_3(): + self.dtype_policy = self._backbone.dtype_policy + else: + self._set_dtype_policy(self._backbone.dtype_policy) def __dir__(self): if config.keras_3(): @@ -128,7 +156,7 @@ def __setattr__(self, name, value): @property def backbone(self): """A `keras.Model` instance providing the backbone sub-model.""" - return self._backbone + return getattr(self, "_backbone", None) @backbone.setter def backbone(self, value): @@ -137,7 +165,7 @@ def backbone(self, value): @property def preprocessor(self): """A `keras.layers.Layer` instance used to preprocess inputs.""" - return self._preprocessor + return getattr(self, "_preprocessor", None) @preprocessor.setter def preprocessor(self, value): @@ -166,17 +194,16 @@ def from_config(cls, config): ) return cls(**config) - @classproperty - def backbone_cls(cls): - return None - - @classproperty - def preprocessor_cls(cls): - return None - @classproperty def presets(cls): - return {} + """List built-in presets for a `Task` subclass.""" + presets = list_presets(cls) + # We can also load backbone presets. + if cls.backbone_cls is not None: + presets.update(cls.backbone_cls.presets) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets @classmethod def from_preset( @@ -185,25 +212,54 @@ def from_preset( cls, preset, load_weights=True, **kwargs, ): - """Instantiate {{model_task_name}} model from preset architecture and weights. + """Instantiate a `keras_nlp.models.Task` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. The `preset` can be passed as + one of: + + 1. a built-in preset identifier like `'bert_base_en'` + 2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'` + 3. a Hugging Face handle like `'hf://user/bert_base_en'` + 4. a path to a local preset directory like `'./bert_base_en'` + + For any `Task` subclass, you can run `cls.presets.keys()` to list all + built-in presets available on the class. + + This constructor can be called in one of two ways. Either from a task + specific base class like `keras_nlp.models.CausalLM.from_preset()`, or + from a model class like `keras_nlp.models.BertClassifier.from_preset()`. + If calling from a base class, the subclass of the returned object + will be inferred from the config in the preset directory. Args: - preset: string. Must be one of "{{preset_names}}". - load_weights: Whether to load pre-trained weights into model. - Defaults to `True`. + preset: string. A built-in preset identifier, a Kaggle Models + handle, a Hugging Face handle, or a path to a local directory. + load_weights: bool. If `True`, the weights will be loaded into the + model architecture. If `False`, the weights will be randomly + initialized. Examples: ```python - # Load architecture and weights from preset - model = {{model_task_name}}.from_preset("{{example_preset_name}}") + # Load a Gemma generative task. + causal_lm = keras_nlp.models.CausalLM.from_preset( + "gemma_2b_en", + ) - # Load randomly initialized model from preset architecture - model = {{model_task_name}}.from_preset( - "{{example_preset_name}}", - load_weights=False + # Load a Bert classification task. + model = keras_nlp.models.Classifier.from_preset( + "bert_base_en", + num_classes=2, ) ``` """ + if cls == Task: + raise ValueError( + "Do not call `Task.from_preset()` directly. Instead call a " + "particular task class, e.g. " + "`keras_nlp.models.Classifier.from_preset()` or " + "`keras_nlp.models.BertClassifier.from_preset()`." + ) if "backbone" in kwargs: raise ValueError( "You cannot pass a `backbone` argument to the `from_preset` " @@ -211,15 +267,28 @@ def from_preset( "constructor with a `backbone` argument. " f"Received: backbone={kwargs['backbone']}." ) - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - preset_cls = check_preset_class(preset, (cls, cls.backbone_cls)) + preset_cls = check_config_class(preset) # Backbone case. - if preset_cls == cls.backbone_cls: + if issubclass(preset_cls, Backbone): + if preset_cls is not cls.backbone_cls: + subclasses = list_subclasses(cls) + subclasses = tuple( + filter(lambda x: x.backbone_cls == preset_cls, subclasses) + ) + if len(subclasses) == 0: + raise ValueError( + f"No registered subclass of `{cls.__name__}` can load " + f"a `{preset_cls.__name__}`." + ) + if len(subclasses) > 1: + names = ", ".join(f"`{x.__name__}`" for x in subclasses) + raise ValueError( + f"Ambiguous call to `{cls.__name__}.from_preset()`. " + f"Found multiple possible subclasses {names}. " + "Please call `from_preset` on a subclass directly." + ) + cls = subclasses[0] # Forward dtype to the backbone. config_overrides = {} if "dtype" in kwargs: @@ -240,34 +309,18 @@ def from_preset( return cls(backbone=backbone, preprocessor=preprocessor, **kwargs) # Task case. + if not issubclass(preset_cls, cls): + raise ValueError( + f"Preset has type `{preset_cls.__name__}` which is not " + f"a subclass of calling class `{cls.__name__}`. Call " + f"`from_preset` directly on `{preset_cls.__name__}` instead." + ) return load_from_preset( preset, load_weights=load_weights, config_overrides=kwargs, ) - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define `from_preset`, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = Task.from_preset.__doc__ - format_docstring( - model_task_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) - @property def layers(self): # Remove preprocessor from layers so it does not show up in the summary. diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py index bf82e4fa68..63fd189c12 100644 --- a/keras_nlp/models/task_test.py +++ b/keras_nlp/models/task_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License.
+import pytest + from keras_nlp.backend import keras +from keras_nlp.models.bert.bert_classifier import BertClassifier +from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.models.task import Task from keras_nlp.tests.test_case import TestCase @@ -40,6 +44,22 @@ def __init__(self, preprocessor=None, activation=None, **kwargs): class TestTask(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertClassifier.presets.keys()) + gpt2_presets = set(GPT2CausalLM.presets.keys()) + all_presets = set(Task.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + # No loading on a task directly (it is ambiguous). + Task.from_preset("bert_tiny_en_uncased", load_weights=False) + with self.assertRaises(ValueError): + # No loading on an incorrect class. + BertClassifier.from_preset("gpt2_base_en", load_weights=False) + def test_summary_with_preprocessor(self): preprocessor = SimplePreprocessor() model = SimpleTask(preprocessor) diff --git a/keras_nlp/models/whisper/__init__.py b/keras_nlp/models/whisper/__init__.py index ba0c2545e4..328814046c 100644 --- a/keras_nlp/models/whisper/__init__.py +++ b/keras_nlp/models/whisper/__init__.py @@ -11,3 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.whisper.whisper_backbone import WhisperBackbone +from keras_nlp.models.whisper.whisper_presets import backbone_presets +from keras_nlp.models.whisper.whisper_tokenizer import WhisperTokenizer +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer)) diff --git a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py index e41519bbc9..fb2c1a9cdd 100644 --- a/keras_nlp/models/whisper/whisper_audio_feature_extractor.py +++ b/keras_nlp/models/whisper/whisper_audio_feature_extractor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import numpy as np import tensorflow as tf @@ -21,9 +20,6 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_nlp.models.whisper.whisper_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring @keras_nlp_export("keras_nlp.models.WhisperAudioFeatureExtractor") @@ -265,52 +261,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate Whisper audio feature extractor from a given preset. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Examples: - ```python - # Load a preset tokenizer. - audio_feature_extractor = WhisperAudioFeatureExtractor.from_preset( - "{{example_preset_name}}" - ) - - # Compute the log-mel spectrogram. 
- audio_tensor = tf.ones((8000,), dtype=tf.float32) - audio_feature_extractor(audio_tensor) - ``` - """ - - if not cls.presets: - raise NotImplementedError( - "No presets have been created for this class" - ) - - if preset not in cls.presets: - raise ValueError( - "`preset` must be one of " - f"""{", ".join(cls.presets)}. Received: {preset}.""" - ) - - config = cls.presets[preset]["audio_feature_extractor_config"] - - return cls.from_config({**config, **kwargs}) - - -format_docstring( - example_preset_name=next(iter(backbone_presets), ""), - preset_names='", "'.join(backbone_presets), -)(WhisperAudioFeatureExtractor.from_preset.__func__) diff --git a/keras_nlp/models/whisper/whisper_backbone.py b/keras_nlp/models/whisper/whisper_backbone.py index c66a61d4e5..3f5a017f14 100644 --- a/keras_nlp/models/whisper/whisper_backbone.py +++ b/keras_nlp/models/whisper/whisper_backbone.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras @@ -24,8 +23,6 @@ from keras_nlp.models.backbone import Backbone from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder from keras_nlp.models.whisper.whisper_encoder import WhisperEncoder -from keras_nlp.models.whisper.whisper_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import assert_tf_backend @@ -274,6 +271,7 @@ def __init__( "encoder_sequence_output": encoder_output, "decoder_sequence_output": decoder_output, }, + dtype=dtype, **kwargs, ) @@ -304,7 +302,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/whisper/whisper_preprocessor.py b/keras_nlp/models/whisper/whisper_preprocessor.py index c21705a481..a78ea96639 100644 --- a/keras_nlp/models/whisper/whisper_preprocessor.py +++ b/keras_nlp/models/whisper/whisper_preprocessor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
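With the `Task.from_preset()` changes earlier in this diff, a task base class can load a bare backbone preset by inferring the matching subclass, and a `dtype` kwarg is forwarded to the backbone as a config override. A usage sketch (the preset name and dtype value are illustrative):

```python
import keras_nlp

# Called on the `Classifier` base class, the concrete `BertClassifier`
# subclass is inferred from the preset's saved config.
classifier = keras_nlp.models.Classifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
)

# `dtype` is forwarded to the backbone via `config_overrides`.
classifier = keras_nlp.models.Classifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
    load_weights=False,
    dtype="bfloat16",
)
```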
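Relatedly, the backbones touched in this diff (`OPTBackbone`, `RobertaBackbone`, `T5Backbone`, `WhisperBackbone`, `XLNetBackbone`) now pass `dtype` through to their functional constructors so the whole graph is built under one policy. A standalone Keras 3 sketch of threading a dtype policy through a small functional model (the layer sizes are arbitrary):

```python
import keras

def build_tiny_backbone(dtype="float32"):
    # Build every layer under the requested dtype policy, mirroring the
    # `dtype=dtype` forwarding added to the backbones in this diff.
    token_ids = keras.Input(shape=(8,), dtype="int32", name="token_ids")
    x = keras.layers.Embedding(100, 16, dtype=dtype)(token_ids)
    outputs = keras.layers.Dense(16, dtype=dtype)(x)
    return keras.Model(inputs=token_ids, outputs=outputs)

model = build_tiny_backbone(dtype="mixed_float16")
# Under the mixed policy, compute happens in float16 with float32 variables.
print(model.layers[-1].compute_dtype)  # float16
```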
-import copy from absl import logging @@ -23,13 +22,11 @@ from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) -from keras_nlp.models.whisper.whisper_presets import backbone_presets from keras_nlp.models.whisper.whisper_tokenizer import WhisperTokenizer from keras_nlp.utils.keras_utils import ( convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.WhisperPreprocessor") @@ -154,6 +151,8 @@ class WhisperPreprocessor(Preprocessor): ``` """ + tokenizer_cls = WhisperTokenizer + def __init__( self, tokenizer, @@ -326,15 +325,3 @@ def sequence_length(self): @sequence_length.setter def sequence_length(self, value): self.decoder_sequence_length = value - - @classproperty - def audio_feature_extractor_cls(cls): - return WhisperAudioFeatureExtractor - - @classproperty - def tokenizer_cls(cls): - return WhisperTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/whisper/whisper_tokenizer.py b/keras_nlp/models/whisper/whisper_tokenizer.py index 7b68dfd790..f14fd1ee98 100644 --- a/keras_nlp/models/whisper/whisper_tokenizer.py +++ b/keras_nlp/models/whisper/whisper_tokenizer.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import json from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.whisper.whisper_presets import backbone_presets from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer -from keras_nlp.utils.python_utils import classproperty def _load_dict(dict_or_path): @@ -164,7 +161,3 @@ def get_config(self): } ) return config - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/__init__.py b/keras_nlp/models/xlm_roberta/__init__.py index ba0c2545e4..f8be3006ad 100644 --- a/keras_nlp/models/xlm_roberta/__init__.py +++ b/keras_nlp/models/xlm_roberta/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone +from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets +from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( + XLMRobertaTokenizer, +) +from keras_nlp.utils.preset_utils import register_presets + +register_presets(backbone_presets, (XLMRobertaBackbone, XLMRobertaTokenizer)) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py index c74a0fd6fc..28d7f18728 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_backbone.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.models.roberta import roberta_backbone -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaBackbone") @@ -82,7 +79,3 @@ class XLMRobertaBackbone(roberta_backbone.RobertaBackbone): model(input_data) ``` """ - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py index fcd8bfe9b8..14d41d233c 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_classifier.py @@ -12,22 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras +from keras_nlp.models.classifier import Classifier from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer -from keras_nlp.models.task import Task from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_preprocessor import ( XLMRobertaPreprocessor, ) -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaClassifier") -class XLMRobertaClassifier(Task): +class XLMRobertaClassifier(Classifier): """An end-to-end XLM-RoBERTa model for classification tasks. This model attaches a classification head to a @@ -147,6 +144,9 @@ def train_sentencepiece(ds, vocab_size): ``` """ + backbone_cls = XLMRobertaBackbone + preprocessor_cls = XLMRobertaPreprocessor + def __init__( self, backbone, @@ -229,15 +229,3 @@ def get_config(self): } ) return config - - @classproperty - def backbone_cls(cls): - return XLMRobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return XLMRobertaPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py index e231f3dc7a..e687bac525 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_masked_lm.py @@ -12,23 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.layers.modeling.masked_lm_head import MaskedLMHead +from keras_nlp.models.masked_lm import MaskedLM from keras_nlp.models.roberta.roberta_backbone import roberta_kernel_initializer -from keras_nlp.models.task import Task from keras_nlp.models.xlm_roberta.xlm_roberta_backbone import XLMRobertaBackbone from keras_nlp.models.xlm_roberta.xlm_roberta_masked_lm_preprocessor import ( XLMRobertaMaskedLMPreprocessor, ) -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaMaskedLM") -class XLMRobertaMaskedLM(Task): +class XLMRobertaMaskedLM(MaskedLM): """An end-to-end XLM-RoBERTa model for the masked language modeling task. This model will train XLM-RoBERTa on a masked language modeling task. 
@@ -53,7 +50,7 @@ class XLMRobertaMaskedLM(Task): `None`. If `None`, this model will not apply preprocessing, and inputs should be preprocessed before calling the model. - Example usage: + Examples: Raw string inputs and pretrained backbone. ```python @@ -100,6 +97,9 @@ class XLMRobertaMaskedLM(Task): ``` """ + backbone_cls = XLMRobertaBackbone + preprocessor_cls = XLMRobertaMaskedLMPreprocessor + def __init__( self, backbone, @@ -142,15 +142,3 @@ def __init__( weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], jit_compile=True, ) - - @classproperty - def backbone_cls(cls): - return XLMRobertaBackbone - - @classproperty - def preprocessor_cls(cls): - return XLMRobertaMaskedLMPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py index c94f5f2421..7901e35720 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_preprocessor.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.multi_segment_packer import ( MultiSegmentPacker, ) from keras_nlp.models.preprocessor import Preprocessor -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets from keras_nlp.models.xlm_roberta.xlm_roberta_tokenizer import ( XLMRobertaTokenizer, ) @@ -27,7 +25,6 @@ convert_inputs_to_list_of_tensor_segments, ) from keras_nlp.utils.keras_utils import pack_x_y_sample_weight -from keras_nlp.utils.python_utils import classproperty @keras_nlp_export("keras_nlp.models.XLMRobertaPreprocessor") @@ -146,6 +143,8 @@ def train_sentencepiece(ds, vocab_size): ``` """ + tokenizer_cls = XLMRobertaTokenizer + def __init__( self, tokenizer, @@ -203,11 +202,3 @@ def sequence_length(self, value): self._sequence_length = value if self.packer is not None: self.packer.sequence_length = value - - @classproperty - def tokenizer_cls(cls): - return XLMRobertaTokenizer - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py index 576f30bca1..8a6a592937 100644 --- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py +++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tensorflow as tf from keras_nlp.api_export import keras_nlp_export -from keras_nlp.models.xlm_roberta.xlm_roberta_presets import backbone_presets from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer -from keras_nlp.utils.python_utils import classproperty from keras_nlp.utils.tensor_utils import tensor_to_list @@ -187,7 +184,3 @@ def detokenize(self, inputs): # the `detokenize` method will return empty strings for these tokens. # This is a vagary of the `sentencepiece` library. 
return super().detokenize(tokens) - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/xlnet/xlnet_backbone.py b/keras_nlp/models/xlnet/xlnet_backbone.py index 0d660bead9..03ea607d9e 100644 --- a/keras_nlp/models/xlnet/xlnet_backbone.py +++ b/keras_nlp/models/xlnet/xlnet_backbone.py @@ -65,7 +65,7 @@ class XLNetBackbone(Backbone): padding_mask: Mask to avoid performing attention on padding token indices of shape `[batch_size, sequence_length]`. - Examples: + Example: ```python import numpy as np from keras_nlp.models import XLNetBackbone @@ -184,6 +184,7 @@ def __init__( "segment_ids": segment_id_input, }, outputs=output, + dtype=dtype, **kwargs, ) diff --git a/keras_nlp/samplers/beam_sampler.py b/keras_nlp/samplers/beam_sampler.py index 9562f95d14..3a34217952 100644 --- a/keras_nlp/samplers/beam_sampler.py +++ b/keras_nlp/samplers/beam_sampler.py @@ -18,11 +18,9 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring +from keras_nlp.utils.tensor_utils import any_equal -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.BeamSampler") class BeamSampler(Sampler): """Beam Sampler class. @@ -42,55 +40,17 @@ class BeamSampler(Sampler): {{call_args}} Examples: - Return only the beam with the highest accumulated probability. ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.BeamSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzeeeeeee'] - ``` + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - Return all beams and their probabilities. - ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 8, len(int_lookup) - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - beams, probs = keras_nlp.samplers.BeamSampler(return_all_beams=True)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - - print(beams.shape) - # >>> (1, 5, 8) - print(probs.shape) - # >>> (1, 5) - print(["".join([int_lookup[i] for i in s]) for s in beams[0].numpy()]) - # >>> ['zzzzzeee', 'zzzzzeed', 'zzzzzeec', 'zzzzzeea', 'zzzzzeeb'] + # Pass by name to compile. + causal_lm.compile(sampler="beam") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.BeamSampler(num_beams=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ @@ -111,8 +71,9 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, + model=None, ): batch_size, max_length = ops.shape(prompt)[0], ops.shape(prompt)[1] index = ops.cast(index, "int32") @@ -149,10 +110,10 @@ def unflatten_beams(x): log_probs = flatten_beams(ops.repeat(log_probs, batch_size, axis=0)) def cond(prompt, cache, index, log_probs): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* stop token. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) @@ -208,6 +169,7 @@ def gather_beams(x): body=body, loop_vars=(prompt, cache, index, log_probs), maximum_iterations=(max_length - index), + model=model, ) all_prompts = unflatten_beams(prompt) diff --git a/keras_nlp/samplers/beam_sampler_test.py b/keras_nlp/samplers/beam_sampler_test.py index fca37cd85b..496e892e62 100644 --- a/keras_nlp/samplers/beam_sampler_test.py +++ b/keras_nlp/samplers/beam_sampler_test.py @@ -110,7 +110,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index bac65bcfbe..24f983fd0b 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -17,11 +17,9 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring +from keras_nlp.utils.tensor_utils import any_equal -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.ContrastiveSampler") class ContrastiveSampler(Sampler): """Contrastive Sampler class. @@ -37,35 +35,22 @@ class ContrastiveSampler(Sampler): alpha: float, the weight of minus max similarity in joint score computation. The larger the value of `alpha`, the score relies more on the similarity than the token probability. - seed: int. The random seed. Defaults to `None`. Call arguments: {{call_args}} Examples: ```python - # Use a simple alphabet of lowercase characters to [0, 26). - int_lookup = {i: chr(i + ord("a")) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - hidden_size = 5 - index = 5 - - def next(prompt, cache, index): - prompt_batch_size = tf.shape(prompt)[0] - hidden_states = np.ones((prompt_batch_size, hidden_size)) - # A uniform distribution over our alphabet. 
- logits = np.ones((prompt_batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.ContrastiveSampler()( - next=next, - prompt=np.full((batch_size, length), char_lookup["z"], dtype="int32"), - index=index, - hidden_states=np.ones([batch_size, index, hidden_size]), - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> "zzzzzeeeeeee" + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="contrastive") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.ContrastiveSampler(k=5) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ @@ -86,8 +71,9 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, + model=None, ): if hidden_states is None: raise ValueError( @@ -121,10 +107,10 @@ def unflatten_beams(x): cache = cache if has_cache else () def cond(prompt, cache, index, logits, hidden_states): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* stop token. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) @@ -224,6 +210,7 @@ def gather_best_token(beams): body=body, loop_vars=(prompt, cache, index, logits, hidden_states), maximum_iterations=(max_length - index), + model=model, ) return prompt diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index 824d47a705..9cdd9de546 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -98,7 +98,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], index=0, hidden_states=self.hidden_states, ) diff --git a/keras_nlp/samplers/greedy_sampler.py b/keras_nlp/samplers/greedy_sampler.py index 8e178b7468..ee8a6ecc2d 100644 --- a/keras_nlp/samplers/greedy_sampler.py +++ b/keras_nlp/samplers/greedy_sampler.py @@ -15,11 +15,8 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import ops from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.GreedySampler") class GreedySampler(Sampler): """Greedy sampler class. @@ -27,29 +24,18 @@ class GreedySampler(Sampler): This sampler is implemented on greedy search, i.e., always picking up the token of the largest probability as the next token. - Call arguments: - {{call_args}} - Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. 
- logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Pass by name to compile. + causal_lm.compile(sampler="greedy") + causal_lm.generate(["Keras is a"]) + + # Pass by object to compile. + sampler = keras_nlp.samplers.GreedySampler() + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/greedy_sampler_test.py b/keras_nlp/samplers/greedy_sampler_test.py index 618c94d118..38c85c950b 100644 --- a/keras_nlp/samplers/greedy_sampler_test.py +++ b/keras_nlp/samplers/greedy_sampler_test.py @@ -86,10 +86,22 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) + def test_multitoken_early_stopping(self): + cache_chars = list("sequentially") + cache = ops.array([[self.char_lookup[c] for c in cache_chars]]) + prompt = ops.full((self.batch_size, self.length), self.char_lookup["z"]) + output = self.sampler( + next=self.next, + prompt=prompt, + cache=cache, + stop_token_ids=[self.char_lookup["t"], self.char_lookup["n"]], + ) + self.assertEqual(self.join_as_string(output), ["sequenzzzzzz"]) + def test_is_greedy(self): def next(prompt, cache, index): # Dummy hidden states. diff --git a/keras_nlp/samplers/random_sampler.py b/keras_nlp/samplers/random_sampler.py index b922d29b2a..1ff39c9f9b 100644 --- a/keras_nlp/samplers/random_sampler.py +++ b/keras_nlp/samplers/random_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.RandomSampler") class RandomSampler(Sampler): """Random Sampler class. @@ -37,24 +34,16 @@ class RandomSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, state, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, state + # Pass by name to compile. + causal_lm.compile(sampler="random") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.RandomSampler()( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzcpnjqij'] + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.RandomSampler(temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/random_sampler_test.py b/keras_nlp/samplers/random_sampler_test.py index 3149d11dba..6f6fd53f51 100644 --- a/keras_nlp/samplers/random_sampler_test.py +++ b/keras_nlp/samplers/random_sampler_test.py @@ -98,7 +98,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index e28fbe9d6e..43950dea2f 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -17,33 +17,9 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.backend import random -from keras_nlp.utils.python_utils import format_docstring - -call_args_docstring = """next: A function which takes in the - `prompt, cache, index` of the current generation loop, and outputs - a tuple `(logits, hidden_states, cache)` with `logits` being the - logits of next token, `hidden_states` being the representation of - the next token, and `cache` for next iteration. - prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This - tensor will be iteratively updated column by column with new sampled - values, starting at `index`. - cache: Optional. A tensor or nested structure of tensors that will be - updated by each call to `next`. This can be used to cache - computations from early iterations of the generative loop. - index: Optional. The first index of `prompt` to start sampling at. - Usually this is set as the length of the shortest non-padded - sequence in `prompt`. - mask: Optional. A 2D integer tensor with the same shape as `prompt`. - Locations which are `True` in the mask are never updated during - sampling. Usually used to mark all locations in the dense prompt - tensor which were present in a user input. - end_token_id: Optional. The token marking the end of the sequence. If - specified, sampling will stop as soon as all sequences in the prompt - produce a `end_token_id` in a location where `mask` is `False`. -""" - - -@format_docstring(call_args=call_args_docstring) +from keras_nlp.utils.tensor_utils import any_equal + + @keras_nlp_export("keras_nlp.samplers.Sampler") class Sampler: """Base sampler class. @@ -57,35 +33,32 @@ class Sampler: {{call_args}} This base class can be extended to implement different auto-regressive - sampling methods. Subclasses can either: - - - Override the `get_next_token()` method, which computes the next token - based on a probability distribution over all possible vocab entries. - - Override `__call__`, if the sampling method needs additional information - beyond the next tokens probability distribution to sample a sequence. - - Please check available subclass samplers for examples. + sampling methods. To do so, override the `get_next_token()` method, which + computes the next token based on a probability distribution over all + possible vocab entries. - Examples: + Example: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) - - def next(prompt, cache, index): - # return a uniform distribution over our alphabet. 
- logits = ops.ones((batch_size, vocab_size)) - return logits, None, cache - - output = keras_nlp.samplers.GreedySampler()( - next=next, - prompt=ops.fill((batch_size, length,), char_lookup['z']), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzaaaaaaa'] + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") + + # Greedy search with some tokens forbidden. + class CustomSampler(keras_nlp.samplers.Sampler): + def __init__(self, forbidden_tokens, **kwargs): + super().__init__(**kwargs) + self.forbidden_tokens = forbidden_tokens + + def get_next_token(self, probs): + batch_size, vocab_size = keras.ops.shape(probs) + for id in self.forbidden_tokens: + update = keras.ops.zeros((batch_size, 1)) + probs = keras.ops.slice_update(probs, (0, id), update) + return keras.ops.argmax(probs, axis=-1) + + # 257 = "a" with a leading space, 262 = "the" with a leading space. + causal_lm.compile(sampler=CustomSampler(forbidden_tokens=[257, 262])) + causal_lm.summary() + causal_lm.generate(["That's strange"]) ``` """ @@ -118,8 +91,9 @@ def __call__( cache=None, index=0, mask=None, - end_token_id=None, + stop_token_ids=None, hidden_states=None, + model=None, ): max_length = ops.shape(prompt)[-1] # Make sure `max_length` and `index` are the same dtype. @@ -133,10 +107,10 @@ def __call__( cache = () if cache is None else cache def cond(prompt, cache, index): - if end_token_id is None: + if stop_token_ids is None: return True - # Stop if all sequences have produced a *new* end_token_id. - end_tokens = (prompt == end_token_id) & (~mask) + # Stop if all sequences have produced a *new* id from stop_token_ids. + end_tokens = any_equal(prompt, stop_token_ids, ~mask) prompt_done = ops.any(end_tokens, axis=-1) return ops.logical_not(ops.all(prompt_done)) @@ -161,6 +135,7 @@ def body(prompt, cache, index): body, loop_vars=(prompt, cache, index), maximum_iterations=(max_length - index), + model=model, ) return prompt @@ -175,32 +150,68 @@ def compute_probabilities(self, logits): probs = keras.activations.softmax(logits / self.temperature) return ops.cast(probs, logits_dtype) - def run_loop(self, cond, body, loop_vars=None, maximum_iterations=None): + def run_loop( + self, cond, body, model=None, loop_vars=None, maximum_iterations=None + ): """Run ops.while_loops with a `StatelessScope` if necessary.""" if config.backend() == "jax": + import itertools + + if model: + model_trainable_variables = model.trainable_variables + model_non_trainable_variables = model.non_trainable_variables + else: + model_trainable_variables = [] + model_non_trainable_variables = [] - def stateless_cond(variables, *loop_vars): + def stateless_cond(state, *loop_vars): return cond(*loop_vars) - def stateless_body(variables, *loop_vars): - mapping = zip(self.variables, variables) + def stateless_body(state, *loop_vars): + ( + sampler_variables, + trainable_variables, + non_trainable_variables, + ) = state + mapping = itertools.chain( + zip(self.variables, sampler_variables), + zip(model_trainable_variables, trainable_variables), + zip(model_non_trainable_variables, non_trainable_variables), + ) with keras.StatelessScope(state_mapping=mapping) as scope: loop_vars = body(*loop_vars) - variables = [] + sampler_variables = [] for v in self.variables: new_v = scope.get_current_value(v) - variables.append(new_v if new_v is not None else v) - return variables, *loop_vars + sampler_variables.append(new_v if new_v is not None else v) + state = ( + sampler_variables, + trainable_variables, + 
non_trainable_variables, + ) + return state, *loop_vars variables = [ops.convert_to_tensor(v) for v in self.variables] - variables, *loop_vars = ops.while_loop( + trainable_variables = [ + ops.convert_to_tensor(v) for v in model_trainable_variables + ] + non_trainable_variables = [ + ops.convert_to_tensor(v) for v in model_non_trainable_variables + ] + state = ( + variables, + trainable_variables, + non_trainable_variables, + ) + state, *loop_vars = ops.while_loop( cond=stateless_cond, body=stateless_body, - loop_vars=(variables, *loop_vars), + loop_vars=(state, *loop_vars), maximum_iterations=maximum_iterations, ) - [ref_v.assign(v) for ref_v, v in zip(self.variables, variables)] + for ref_v, v in zip(self.variables, state[0]): + ref_v.assign(v) else: loop_vars = ops.while_loop( cond=cond, diff --git a/keras_nlp/samplers/top_k_sampler.py b/keras_nlp/samplers/top_k_sampler.py index 3456694848..513dd738c7 100644 --- a/keras_nlp/samplers/top_k_sampler.py +++ b/keras_nlp/samplers/top_k_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopKSampler") class TopKSampler(Sampler): """Top-K Sampler class. @@ -38,24 +35,16 @@ class TopKSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_k") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopKSampler(k=3)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtypes="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzacbbcaa'] + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.TopKSampler(k=5, temperature=0.7) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/top_k_sampler_test.py b/keras_nlp/samplers/top_k_sampler_test.py index c0f516a214..c0ebbc85d0 100644 --- a/keras_nlp/samplers/top_k_sampler_test.py +++ b/keras_nlp/samplers/top_k_sampler_test.py @@ -86,7 +86,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/samplers/top_p_sampler.py b/keras_nlp/samplers/top_p_sampler.py index a04b39aa2b..326f5797a6 100644 --- a/keras_nlp/samplers/top_p_sampler.py +++ b/keras_nlp/samplers/top_p_sampler.py @@ -16,11 +16,8 @@ from keras_nlp.backend import ops from keras_nlp.backend import random from keras_nlp.samplers.sampler import Sampler -from keras_nlp.samplers.sampler import call_args_docstring -from keras_nlp.utils.python_utils import format_docstring -@format_docstring(call_args=call_args_docstring) @keras_nlp_export("keras_nlp.samplers.TopPSampler") class TopPSampler(Sampler): """Top-P Sampler class. @@ -46,24 +43,16 @@ class TopPSampler(Sampler): Examples: ```python - # Use a simple alphabet of lowercase characters with ids in range [0, 25]. - int_lookup = {i: chr(i + ord('a')) for i in range(26)} - char_lookup = {v: k for k, v in int_lookup.items()} - batch_size, length, vocab_size = 1, 12, len(int_lookup) + causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") - def next(prompt, cache, index): - hidden_states = np.ones((batch_size, 10)) - # A uniform distribution over our alphabet. - logits = np.ones((batch_size, vocab_size)) - return logits, hidden_states, cache + # Pass by name to compile. + causal_lm.compile(sampler="top_p") + causal_lm.generate(["Keras is a"]) - output = keras_nlp.samplers.TopPSampler(p=0.1)( - next=next, - prompt=np.full((batch_size, length,), char_lookup['z'], dtype="int32"), - index=5, - ) - print(["".join([int_lookup[i] for i in s]) for s in output.numpy()]) - # >>> ['zzzzzbabcccb'] + # Pass by object to compile. 
+ sampler = keras_nlp.samplers.TopPSampler(p=0.1, k=1_000) + causal_lm.compile(sampler=sampler) + causal_lm.generate(["Keras is a"]) ``` """ diff --git a/keras_nlp/samplers/top_p_sampler_test.py b/keras_nlp/samplers/top_p_sampler_test.py index fea5ca4110..9bd91b4a1a 100644 --- a/keras_nlp/samplers/top_p_sampler_test.py +++ b/keras_nlp/samplers/top_p_sampler_test.py @@ -87,7 +87,7 @@ def test_early_stopping(self): next=self.next, prompt=prompt, cache=cache, - end_token_id=self.char_lookup["t"], + stop_token_ids=[self.char_lookup["t"]], ) self.assertEqual(self.join_as_string(output), ["sequentzzzzz"]) diff --git a/keras_nlp/tests/test_case.py b/keras_nlp/tests/test_case.py index 0541ae6451..91ea6aba76 100644 --- a/keras_nlp/tests/test_case.py +++ b/keras_nlp/tests/test_case.py @@ -329,10 +329,10 @@ def run_precision_test(self, cls, init_kwargs, input_data): for weight in layer.weights: if is_float_dtype(weight.dtype): self.assertDTypeEqual(weight, policy.variable_dtype) - for sublayer in layer._flatten_layers(include_self=False): - if isinstance( - sublayer, (keras.layers.Softmax, keras.layers.InputLayer) - ): + for sublayer in layer._flatten_layers(): + if isinstance(sublayer, keras.layers.Softmax): + continue + if isinstance(sublayer, keras.layers.InputLayer): continue self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, sublayer.variable_dtype) @@ -458,8 +458,6 @@ def run_preset_test( expected_partial_output=None, ): """Run instantiation and a forward pass for a preset.""" - self.assertRegex(cls.from_preset.__doc__, preset) - with self.assertRaises(Exception): cls.from_preset("clowntown", **init_kwargs) diff --git a/keras_nlp/tests/test_data/gemma_test_vocab.spm b/keras_nlp/tests/test_data/gemma_test_vocab.spm new file mode 100644 index 0000000000..a049c032c2 Binary files /dev/null and b/keras_nlp/tests/test_data/gemma_test_vocab.spm differ diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer.py b/keras_nlp/tokenizers/byte_pair_tokenizer.py index 902af812e9..89e39369cd 100644 --- a/keras_nlp/tokenizers/byte_pair_tokenizer.py +++ b/keras_nlp/tokenizers/byte_pair_tokenizer.py @@ -22,17 +22,13 @@ import json import os from typing import Iterable -from typing import List +import keras import regex as re import tensorflow as tf from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers import tokenizer -from keras_nlp.utils.preset_utils import check_preset_class -from keras_nlp.utils.preset_utils import load_from_preset -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -63,17 +59,10 @@ SPLIT_PATTERN_2 = rf"""[\są„¬{SPECIAL_WHITESPACES}]$""" -def create_alts_for_unsplittable_tokens(unsplittable_tokens): - # Create alternates for all special tokens that will be not split during - # tokenization. - alts = [] - prefix = "Ä“" - # Trim out splitters. 
-    replace_pattern = r"'|\s+|[^\p{L}\p{N}]+"
-    for token in unsplittable_tokens:
-        token = re.sub(replace_pattern, "", token)
-        alts.append(prefix + token)
-    return alts
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
+        return None
+    return r"|".join([re.escape(token) for token in special_tokens])
 
 
 def bytes_to_unicode():
@@ -108,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs, unsplittable_tokens=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -120,24 +109,35 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None):
     inputs = tf.strings.regex_replace(
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1ą„¬"
     )
-    if unsplittable_tokens:
-        alts = create_alts_for_unsplittable_tokens(unsplittable_tokens)
-        for token, alt in zip(unsplittable_tokens, alts):
-            escaped_token = re.escape(token)
-            inputs = tf_text.regex_split(inputs, escaped_token, escaped_token)
-            inputs = tf.strings.regex_replace(inputs, escaped_token, alt)
-    raw_tokens = tf_text.regex_split(inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1)
+
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
+        raw_tokens = tf_text.regex_split(
+            inputs, special_tokens_pattern, special_tokens_pattern
+        )
+        # Then split using both `special_tokens_pattern` and
+        # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
+        # affecting the special tokens.
+        # We split on the special tokens first and only then apply this
+        # split; applying this split directly would not separate the
+        # special tokens from the input properly, because of the pattern
+        # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
+        # e.g., [" <s>"] would become [" <", "s", ">"] instead of [" ", "<s>"]
+        raw_tokens = tf_text.regex_split(
+            raw_tokens,
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+        )
+        raw_tokens = raw_tokens.merge_dims(-2, -1)
+    else:
+        raw_tokens = tf_text.regex_split(
+            inputs, SPLIT_PATTERN_1, SPLIT_PATTERN_1
+        )
+
     # Second pass splits out the last whilespace char or "ą„¬".
     raw_tokens = tf_text.regex_split(
         raw_tokens, SPLIT_PATTERN_2, SPLIT_PATTERN_2
     )
-    if unsplittable_tokens:
-        # Replace special tokens alternate with originals.
-        for token, alt in zip(unsplittable_tokens, alts):
-            escaped_alt = re.escape(alt)
-            raw_tokens = tf.strings.regex_replace(
-                raw_tokens, escaped_alt, token
-            )
     while raw_tokens.shape.rank > 2:
         raw_tokens = raw_tokens.merge_dims(1, 2)
     return remove_strings_from_inputs(raw_tokens, "ą„¬")
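The replacement helper reduces the old alternate-token machinery to a single regex: an alternation of escaped literal tokens, matched before the word-level split so each special token survives as one piece. A minimal pure-Python sketch of that idea (plain `re` instead of `tf_text`; the `<s>`/`</s>` tokens are hypothetical):

```python
import re

def get_special_tokens_pattern(special_tokens):
    # Mirrors the helper above: escape each literal token, join with "|".
    if not special_tokens:
        return None
    return r"|".join(re.escape(token) for token in special_tokens)

pattern = get_special_tokens_pattern(["<s>", "</s>"])

# Splitting on the special tokens first keeps them whole; the plain
# word-level split pattern would shatter "<s>" into "<", "s", ">".
text = "<s>a quick fox</s>"
pieces = [piece for piece in re.split(f"({pattern})", text) if piece]
print(pieces)  # ['<s>', 'a quick fox', '</s>']
```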
@@ -149,7 +149,7 @@ class BytePairTokenizerCache(tf.Module):
     The cache key is string tensor or python strings, and the value is split
     tokens joined by whitespace. For example, "dragonfly" => "dragon fly"
 
-    Examples:
+    Example:
     ```
     cache = BytePairTokenizerCache()
     cache.insert(["butterfly", "dragonfly"], ["but ter fly", "dragon fly"])
@@ -238,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of strings that will
-            never be split during the word-level splitting applied before the
-            byte-pair encoding. This can be used to ensure special tokens map to
-            unique indices in the vocabulary, even if these special tokens
-            contain splittable characters such as punctuation. Special tokens
-            must still be included in `vocabulary`. Defaults to `None`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
+            tokens will never be split during the word-level splitting applied
+            before the byte-pair encoding. This can be used to ensure special
+            tokens map to unique indices in the vocabulary, even if these
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
+            `vocabulary`. Defaults to `None`.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings that should be tokenized and
+            mapped correctly to their ids. Defaults to `False`.
 
     Examples:
 
@@ -282,7 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -297,7 +303,12 @@ def __init__(
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
+            )
 
         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -349,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                     "token to int ids. Received: "
                     f"`type(vocabulary)={type(vocabulary)}`."
                 )
+
+        # Check for special tokens in vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
         if isinstance(merges, str):
             with open(merges, encoding="utf-8") as f:
                 self.merges = [bp.rstrip() for bp in f]
@@ -361,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )
 
         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)
 
         # Create mapping between string tokens to int ids, and vice versa.
         byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -391,17 +411,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             default=self.merge_ranks_lookup_default,
         )
 
-    def get_vocabulary(self) -> List[str]:
+    def get_vocabulary(self):
         """Get the tokenizer vocabulary as a list of strings tokens."""
         self._check_vocabulary()
         return self.vocabulary.keys()
 
-    def vocabulary_size(self) -> int:
-        """Get the size of the tokenizer vocabulary."""
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
         self._check_vocabulary()
         return len(self.vocabulary)
 
-    def id_to_token(self, id: int) -> str:
+    def id_to_token(self, id):
         """Convert an integer id to a string token."""
         # This will be slow, but keep memory usage down compared to building a
        # dict.
Assuming the main use case is looking up a few special tokens @@ -414,7 +434,7 @@ def id_to_token(self, id: int) -> str: return token raise ValueError(f"`id` is out of the vocabulary. Received: {id}") - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" self._check_vocabulary() return self.vocabulary[token] @@ -544,7 +564,7 @@ def tokenize(self, inputs): if scalar_input: inputs = tf.expand_dims(inputs, 0) - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern) token_row_splits = raw_tokens.row_splits flat_tokens = raw_tokens.flat_values @@ -609,6 +629,11 @@ def detokenize(self, inputs): outputs = tf.squeeze(outputs, 0) return outputs + def compute_output_spec(self, input_spec): + return keras.KerasTensor( + input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype + ) + def _transform_bytes(self, tokens): """Map token bytes to unicode using `byte2unicode`.""" split_bytes = tf.strings.bytes_split(tokens) @@ -633,71 +658,7 @@ def get_config(self): { "sequence_length": self.sequence_length, "add_prefix_space": self.add_prefix_space, - "unsplittable_tokens": self.unsplittable_tokens, + "special_tokens": self.special_tokens, } ) - return config - - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Examples: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, - ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. 
-        if cls.from_preset.__doc__ is None:
-            cls.from_preset.__func__.__doc__ = (
-                BytePairTokenizer.from_preset.__doc__
-            )
-            format_docstring(
-                model_name=cls.__name__,
-                example_preset_name=next(iter(cls.presets), ""),
-                preset_names='", "'.join(cls.presets),
-            )(cls.from_preset.__func__)
+        return config
\ No newline at end of file
diff --git a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
index 00f8f9b87f..542e872e1b 100644
--- a/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -67,19 +67,40 @@ def test_tokenize_with_special_tokens(self):
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
-            unsplittable_tokens=["s", "p"],
+            special_tokens=["s", "p"],
+            special_tokens_in_strings=True,
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [1, 2])
 
-        # If not setting special tokens, "sp" is one token.
+        # If `special_tokens_in_strings` is not `True`, "sp" is one token.
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
+            special_tokens=["s", "p"],
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [0])
 
+        # Test real world special tokens, e.g. <s> and </s>.
+        vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
+        merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+        merges += ["Ġ f", "o x", "Ġf ox"]
+        tokenizer = BytePairTokenizer(
+            vocabulary=vocab,
+            merges=merges,
+            special_tokens=["<s>", "</s>"],
+            special_tokens_in_strings=True,
+        )
+        output = tokenizer("<s>a quick fox</s>")
+        self.assertAllEqual(output, [0, 2, 3, 4, 1])
+
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            BytePairTokenizer(
+                vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
+            )
+
     def test_tokenize_prefix_space(self):
         input_data = ["brown.", "black."]
         tokenizer = BytePairTokenizer(
@@ -170,4 +191,4 @@ def test_config(self):
         self.assertAllEqual(
             self.tokenizer(input_data),
             cloned_tokenizer(input_data),
-        )
+        )
\ No newline at end of file
diff --git a/keras_nlp/tokenizers/byte_tokenizer.py b/keras_nlp/tokenizers/byte_tokenizer.py
index 3aefc4a01d..4d5c4a87ed 100644
--- a/keras_nlp/tokenizers/byte_tokenizer.py
+++ b/keras_nlp/tokenizers/byte_tokenizer.py
@@ -155,11 +155,11 @@ class ByteTokenizer(tokenizer.Tokenizer):
 
     def __init__(
         self,
-        lowercase: bool = True,
-        sequence_length: int = None,
-        normalization_form: str = None,
-        errors: str = "replace",
-        replacement_char: int = 65533,
+        lowercase=True,
+        sequence_length=None,
+        normalization_form=None,
+        errors="replace",
+        replacement_char=65533,
         dtype="int32",
         **kwargs,
     ):
@@ -198,8 +198,8 @@ def __init__(
             [i.tobytes() for i in np.arange(256, dtype=np.uint8)]
         )
 
-    def vocabulary_size(self) -> int:
-        """Get the size of the tokenizer vocabulary."""
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
         return 256
 
     def tokenize(self, inputs):
diff --git a/keras_nlp/tokenizers/sentence_piece_tokenizer.py b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
index ae655aceb6..fb01828c6a 100644
--- a/keras_nlp/tokenizers/sentence_piece_tokenizer.py
+++ b/keras_nlp/tokenizers/sentence_piece_tokenizer.py
@@ -15,16 +15,12 @@
 import base64
 import binascii
 import os
-from typing import List
 
+import keras
 import tensorflow as tf
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.tokenizers import tokenizer
-from keras_nlp.utils.preset_utils import check_preset_class
-from keras_nlp.utils.preset_utils import load_from_preset
-from keras_nlp.utils.python_utils import
classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -111,7 +107,7 @@ def train_sentence_piece_file(ds, path, size): def __init__( self, proto=None, - sequence_length: int = None, + sequence_length=None, dtype="int32", **kwargs, ) -> None: @@ -175,12 +171,12 @@ def set_proto(self, proto): # byte array as a string for saving. self.proto = proto_bytes - def vocabulary_size(self) -> int: - """Get the size of the tokenizer vocabulary.""" + def vocabulary_size(self): + """Get the integer size of the tokenizer vocabulary.""" self._check_vocabulary() return int(self._sentence_piece.vocab_size().numpy()) - def get_vocabulary(self) -> List[str]: + def get_vocabulary(self): """Get the tokenizer vocabulary.""" self._check_vocabulary() return tensor_to_list( @@ -189,7 +185,7 @@ def get_vocabulary(self) -> List[str]: ) ) - def id_to_token(self, id: int) -> str: + def id_to_token(self, id): """Convert an integer id to a string token.""" self._check_vocabulary() if id >= self.vocabulary_size() or id < 0: @@ -199,7 +195,7 @@ def id_to_token(self, id: int) -> str: ) return tensor_to_list(self._sentence_piece.id_to_string(id)) - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" self._check_vocabulary() return int(self._sentence_piece.string_to_id(token).numpy()) @@ -253,71 +249,14 @@ def tokenize(self, inputs): def detokenize(self, inputs): self._check_vocabulary() inputs, unbatched, _ = convert_to_ragged_batch(inputs) + # tf-text sentencepiece does not handle int64. + inputs = tf.cast(inputs, "int32") outputs = self._sentence_piece.detokenize(inputs) if unbatched: outputs = tf.squeeze(outputs, 0) return outputs - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Examples: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, + def compute_output_spec(self, input_spec): + return keras.KerasTensor( + input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. 
- if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = ( - SentencePieceTokenizer.from_preset.__doc__ - ) - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 7da1e9d7b1..9418741ea2 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -12,15 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - from keras_nlp.api_export import keras_nlp_export from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) - - -@keras_nlp_export("keras_nlp.tokenizers.Tokenizer") +from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE +from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import list_presets +from keras_nlp.utils.preset_utils import list_subclasses +from keras_nlp.utils.preset_utils import load_from_preset +from keras_nlp.utils.preset_utils import save_to_preset +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export( + [ + "keras_nlp.models.Tokenizer", + "keras_nlp.tokenizers.Tokenizer", + ] +) class Tokenizer(PreprocessingLayer): """A base class for tokenizer layers. @@ -40,7 +50,7 @@ class Tokenizer(PreprocessingLayer): "vocab free" tokenizers, such as a whitespace splitter show below, these methods do not apply and can be skipped. - Examples: + Example: ```python class WhitespaceSplitterTokenizer(keras_nlp.tokenizers.Tokenizer): @@ -93,33 +103,108 @@ def detokenize(self, inputs, *args, **kwargs): f"{self.__class__.__name__}." ) - def get_vocabulary(self) -> List[str]: + def get_vocabulary(self): """Get the tokenizer vocabulary as a list of strings terms.""" raise NotImplementedError( "No implementation of `get_vocabulary()` was found for " f"{self.__class__.__name__}." ) - def vocabulary_size(self) -> int: + def vocabulary_size(self): """Returns the total size of the token id space.""" raise NotImplementedError( "No implementation of `vocabulary_size()` was found for " f"{self.__class__.__name__}." ) - def id_to_token(self, id: int) -> str: + def id_to_token(self, id): """Convert an integer id to a string token.""" raise NotImplementedError( "No implementation of `id_to_token()` was found for " f"{self.__class__.__name__}." ) - def token_to_id(self, token: str) -> int: + def token_to_id(self, token): """Convert a string token to an integer id.""" raise NotImplementedError( "No implementation of `token_to_id()` was found for " f"{self.__class__.__name__}." ) + def save_to_preset(self, preset): + """Save tokenizer to a preset directory. + + Args: + preset: The path to the local model preset directory. + """ + save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE) + def call(self, inputs, *args, training=None, **kwargs): return self.tokenize(inputs, *args, **kwargs) + + @classproperty + def presets(cls): + """List built-in presets for a `Task` subclass.""" + presets = list_presets(cls) + for subclass in list_subclasses(cls): + presets.update(subclass.presets) + return presets + + @classmethod + def from_preset( + cls, + preset, + **kwargs, + ): + """Instantiate a `keras_nlp.models.Tokenizer` from a model preset. + + A preset is a directory of configs, weights and other file assets used + to save and load a pre-trained model. 
The `preset` can be passed as
+        one of:
+
+        1. a built-in preset identifier like `'bert_base_en'`
+        2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
+        3. a Hugging Face handle like `'hf://user/bert_base_en'`
+        4. a path to a local preset directory like `'./bert_base_en'`
+
+        For any `Tokenizer` subclass, you can run `cls.presets.keys()` to list
+        all built-in presets available on the class.
+
+        This constructor can be called in one of two ways. Either from the base
+        class like `keras_nlp.models.Tokenizer.from_preset()`, or from
+        a model class like `keras_nlp.models.GemmaTokenizer.from_preset()`.
+        If calling from the base class, the subclass of the returned object
+        will be inferred from the config in the preset directory.
+
+        Args:
+            preset: string. A built-in preset identifier, a Kaggle Models
+                handle, a Hugging Face handle, or a path to a local directory.
+            load_weights: bool. If `True`, the weights will be loaded into the
+                model architecture. If `False`, the weights will be randomly
+                initialized.
+
+        Examples:
+        ```python
+        # Load a preset tokenizer.
+        tokenizer = keras_nlp.models.Tokenizer.from_preset("bert_base_en")
+
+        # Tokenize some input.
+        tokenizer("The quick brown fox tripped.")
+
+        # Detokenize some input.
+        tokenizer.detokenize([5, 6, 7, 8, 9])
+        ```
+        """
+        config_file = "tokenizer.json"
+        preset_cls = check_config_class(preset, config_file=config_file)
+        if not issubclass(preset_cls, cls):
+            raise ValueError(
+                f"Preset has type `{preset_cls.__name__}` which is not "
+                f"a subclass of calling class `{cls.__name__}`. Call "
+                f"`from_preset` directly on `{preset_cls.__name__}` instead."
+            )
+        return load_from_preset(
+            preset,
+            config_file=config_file,
+            config_overrides=kwargs,
+        )
diff --git a/keras_nlp/tokenizers/tokenizer_test.py b/keras_nlp/tokenizers/tokenizer_test.py
index 509503d921..f9611e5eb8 100644
--- a/keras_nlp/tokenizers/tokenizer_test.py
+++ b/keras_nlp/tokenizers/tokenizer_test.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pytest import tensorflow as tf +from keras_nlp.models.bert.bert_tokenizer import BertTokenizer +from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer from keras_nlp.tests.test_case import TestCase from keras_nlp.tokenizers.tokenizer import Tokenizer @@ -29,6 +32,29 @@ def detokenize(self, inputs): class TokenizerTest(TestCase): + def test_preset_accessors(self): + bert_presets = set(BertTokenizer.presets.keys()) + gpt2_presets = set(GPT2Tokenizer.presets.keys()) + all_presets = set(Tokenizer.presets.keys()) + self.assertContainsSubset(bert_presets, all_presets) + self.assertContainsSubset(gpt2_presets, all_presets) + + @pytest.mark.large + def test_from_preset(self): + self.assertIsInstance( + Tokenizer.from_preset("bert_tiny_en_uncased"), + BertTokenizer, + ) + self.assertIsInstance( + Tokenizer.from_preset("gpt2_base_en"), + GPT2Tokenizer, + ) + + @pytest.mark.large + def test_from_preset_errors(self): + with self.assertRaises(ValueError): + GPT2Tokenizer.from_preset("bert_tiny_en_uncased") + def test_tokenize(self): input_data = ["the quick brown fox"] tokenizer = SimpleTokenizer() diff --git a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py index 5fe8f0144d..578e03bca7 100644 --- a/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py +++ b/keras_nlp/tokenizers/unicode_codepoint_tokenizer.py @@ -206,14 +206,14 @@ class UnicodeCodepointTokenizer(tokenizer.Tokenizer): def __init__( self, - sequence_length: int = None, - lowercase: bool = True, - normalization_form: str = None, - errors: str = "replace", - replacement_char: int = 65533, - input_encoding: str = "UTF-8", - output_encoding: str = "UTF-8", - vocabulary_size: int = None, + sequence_length=None, + lowercase=True, + normalization_form=None, + errors="replace", + replacement_char=65533, + input_encoding="UTF-8", + output_encoding="UTF-8", + vocabulary_size=None, dtype="int32", **kwargs, ) -> None: @@ -275,7 +275,7 @@ def get_config(self): ) return config - def vocabulary_size(self) -> int: + def vocabulary_size(self): """Get the size of the tokenizer vocabulary. None implies no vocabulary size was provided""" return self._vocabulary_size diff --git a/keras_nlp/tokenizers/word_piece_tokenizer.py b/keras_nlp/tokenizers/word_piece_tokenizer.py index 75f956899f..4b9b90a943 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer.py @@ -13,17 +13,14 @@ # limitations under the License. import os +import re from typing import Iterable -from typing import List +import keras import tensorflow as tf from keras_nlp.api_export import keras_nlp_export from keras_nlp.tokenizers import tokenizer -from keras_nlp.utils.preset_utils import check_preset_class -from keras_nlp.utils.preset_utils import load_from_preset -from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring from keras_nlp.utils.tensor_utils import assert_tf_text_installed from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import is_int_dtype @@ -101,12 +98,19 @@ ) +def get_special_tokens_pattern(special_tokens): + if special_tokens is None or len(special_tokens) == 0: + return None + return r"|".join([re.escape(token) for token in special_tokens]) + + def pretokenize( text, lowercase=False, strip_accents=True, split=True, split_on_cjk=True, + special_tokens_pattern=None, ): """Helper function that takes in a dataset element and pretokenizes it. 
@@ -124,7 +128,14 @@ def pretokenize(
         split_on_cjk: bool. If `True`, input will be split on CJK characters,
             i.e., Chinese, Japanese, Korean and Vietnamese characters
             (https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)).
-            Note that this is applicable only when `split` is `True`. Defaults to `True`.
+            Note that this is applicable only when `split` is `True`. Defaults
+            to `True`.
+        special_tokens_pattern: str. A regex pattern that contains the
+            special tokens that will never be split during the word-level
+            splitting applied before the word-piece encoding. This can be used
+            to ensure special tokens map to unique indices in the vocabulary,
+            even if these special tokens contain splittable characters such as
+            punctuation.
 
     Returns:
         A tensor containing the pre-processed and pre-tokenized `text`.
@@ -154,6 +165,23 @@ def pretokenize(
     else:
         split_pattern = WHITESPACE_AND_PUNCTUATION_REGEX
         keep_split_pattern = PUNCTUATION_REGEX
+    if special_tokens_pattern is not None:
+        # The idea here is to pass the special tokens regex to the split
+        # function as the delimiter regex pattern, so the input will be split
+        # by them, while the function still treats each of them as a single
+        # entity that shouldn't be split even if they contain other
+        # delimiter regex patterns inside them. The special tokens regex is
+        # then also passed as the keep-delimiter regex pattern, so the
+        # special tokens are not removed.
+        split_pattern = r"|".join(
+            [
+                special_tokens_pattern,
+                split_pattern,
+            ]
+        )
+        keep_split_pattern = r"|".join(
+            [special_tokens_pattern, keep_split_pattern]
+        )
     text = tf_text.regex_split(
         text,
         delim_regex_pattern=split_pattern,
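On the WordPiece side the same pattern is prepended to both the delimiter and the keep-delimiter regexes, so `tf_text.regex_split` cuts around the special tokens without splitting or discarding them. A hedged usage sketch against the new constructor arguments (the vocabulary and tokens are illustrative; compare the tests further below):

```python
from keras_nlp.tokenizers import WordPieceTokenizer

vocab = ["[UNK]", "[CLS]", "[SEP]", "the", "qu", "##ick", "br", "##own", "fox"]
tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    special_tokens=["[UNK]", "[CLS]", "[SEP]"],
    special_tokens_in_strings=True,
)
# "[CLS]" and "[SEP]" map to single ids instead of being split apart
# on the "[" and "]" punctuation.
print(tokenizer(["[CLS] the quick brown fox [SEP]"]))
```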
@@ -225,6 +253,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
         oov_token: str. The string value to substitute for an unknown token. It
             must be included in the vocab. Defaults to `"[UNK]"`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, the tokenizer will map
+            every special token in the input strings to its id, even if these
+            special tokens contain characters that would otherwise be split
+            before tokenization, such as punctuation. `special_tokens` must be
+            included in `vocabulary`.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings that should be tokenized and
+            mapped correctly to their ids. Defaults to `False`.
 
     References:
      - [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/)
@@ -296,13 +333,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
     def __init__(
         self,
         vocabulary=None,
-        sequence_length: int = None,
-        lowercase: bool = False,
-        strip_accents: bool = False,
-        split: bool = True,
-        split_on_cjk: bool = True,
-        suffix_indicator: str = "##",
-        oov_token: str = "[UNK]",
+        sequence_length=None,
+        lowercase=False,
+        strip_accents=False,
+        split=True,
+        split_on_cjk=True,
+        suffix_indicator="##",
+        oov_token="[UNK]",
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -325,6 +364,19 @@ def __init__(
         self.split_on_cjk = split_on_cjk
         self.suffix_indicator = suffix_indicator
         self.oov_token = oov_token
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if self.split and special_tokens_in_strings:
+            # The idea here is to pass the special tokens regex to the
+            # split function as the delimiter regex pattern, so the input
+            # will be split by them, while the function still treats each
+            # of them as a single entity that shouldn't be split even if
+            # they contain other delimiter regex patterns inside them. The
+            # special tokens regex is then also passed as the keep-delimiter
+            # regex pattern, so the special tokens are not removed.
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                self.special_tokens
+            )
         self.set_vocabulary(vocabulary)
 
     def save_assets(self, dir_path):
@@ -365,6 +417,16 @@ def set_vocabulary(self, vocabulary):
                 "the `oov_token` argument when creating the tokenizer."
             )
 
+        # Check for special tokens in the vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.vocabulary:
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
         self._fast_word_piece = tf_text.FastWordpieceTokenizer(
             vocab=self.vocabulary,
             token_out_type=self.compute_dtype,
@@ -374,17 +436,17 @@ def set_vocabulary(self, vocabulary):
             support_detokenization=True,
         )
 
-    def get_vocabulary(self) -> List[str]:
+    def get_vocabulary(self):
         """Get the tokenizer vocabulary as a list of strings tokens."""
         self._check_vocabulary()
         return self.vocabulary
 
-    def vocabulary_size(self) -> int:
-        """Get the size of the tokenizer vocabulary."""
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
         self._check_vocabulary()
         return len(self.vocabulary)
 
-    def id_to_token(self, id: int) -> str:
+    def id_to_token(self, id):
         """Convert an integer id to a string token."""
         self._check_vocabulary()
         if id >= self.vocabulary_size() or id < 0:
@@ -394,7 +456,7 @@ def id_to_token(self, id):
             )
         return self.vocabulary[id]
 
-    def token_to_id(self, token: str) -> int:
+    def token_to_id(self, token):
         """Convert a string token to an integer id."""
         # This will be slow, but keep memory usage down compared to building a
         # dict. Assuming the main use case is looking up a few special tokens
@@ -413,6 +475,7 @@ def get_config(self):
                 "split": self.split,
                 "suffix_indicator": self.suffix_indicator,
                 "oov_token": self.oov_token,
+                "special_tokens": self.special_tokens,
             }
         )
         return config
@@ -436,6 +499,7 @@ def tokenize(self, inputs):
             self.strip_accents,
             self.split,
             self.split_on_cjk,
+            self._special_tokens_pattern,
         )
 
         # Apply WordPiece and coerce shape for outputs.
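The vocabulary check added to `set_vocabulary` fails fast when a declared special token is absent. A small sketch of the expected failure mode (tokens are illustrative):

```python
from keras_nlp.tokenizers import WordPieceTokenizer

try:
    WordPieceTokenizer(
        vocabulary=["[UNK]", "the", "qu", "##ick"],
        special_tokens=["[PAD]"],  # deliberately absent from the vocabulary
    )
except ValueError as e:
    print(e)  # Cannot find token `'[PAD]'` in the provided `vocabulary`. ...
```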
@@ -465,66 +529,7 @@ def detokenize(self, inputs): outputs = tf.squeeze(outputs, 0) return outputs - @classproperty - def presets(cls): - return {} - - @classmethod - def from_preset( - cls, - preset, - **kwargs, - ): - """Instantiate {{model_name}} tokenizer from preset vocabulary. - - Args: - preset: string. Must be one of "{{preset_names}}". - - Examples: - ```python - # Load a preset tokenizer. - tokenizer = {{model_name}}.from_preset("{{example_preset_name}}") - - # Tokenize some input. - tokenizer("The quick brown fox tripped.") - - # Detokenize some input. - tokenizer.detokenize([5, 6, 7, 8, 9]) - ``` - """ - # We support short IDs for official presets, e.g. `"bert_base_en"`. - # Map these to a Kaggle Models handle. - if preset in cls.presets: - preset = cls.presets[preset]["kaggle_handle"] - - config_file = "tokenizer.json" - check_preset_class(preset, cls, config_file=config_file) - return load_from_preset( - preset, - config_file=config_file, - config_overrides=kwargs, + def compute_output_spec(self, input_spec): + return keras.KerasTensor( + input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype ) - - def __init_subclass__(cls, **kwargs): - # Use __init_subclass__ to setup a correct docstring for from_preset. - super().__init_subclass__(**kwargs) - - # If the subclass does not define from_preset, assign a wrapper so that - # each class can have a distinct docstring. - if "from_preset" not in cls.__dict__: - - def from_preset(calling_cls, *args, **kwargs): - return super(cls, calling_cls).from_preset(*args, **kwargs) - - cls.from_preset = classmethod(from_preset) - - # Format and assign the docstring unless the subclass has overridden it. - if cls.from_preset.__doc__ is None: - cls.from_preset.__func__.__doc__ = ( - WordPieceTokenizer.from_preset.__doc__ - ) - format_docstring( - model_name=cls.__name__, - example_preset_name=next(iter(cls.presets), ""), - preset_names='", "'.join(cls.presets), - )(cls.from_preset.__func__) diff --git a/keras_nlp/tokenizers/word_piece_tokenizer_test.py b/keras_nlp/tokenizers/word_piece_tokenizer_test.py index ead098c36c..5d80a32eee 100644 --- a/keras_nlp/tokenizers/word_piece_tokenizer_test.py +++ b/keras_nlp/tokenizers/word_piece_tokenizer_test.py @@ -78,21 +78,38 @@ def test_error_id_out_of_vocabulary(self): with self.assertRaises(ValueError): tokenizer.id_to_token(-1) - def test_special_tokens(self): - input_data = ["quick brown whale"] - vocab_data = ["@UNK@", "qu", "@@ick", "br", "@@own", "fox"] + def test_special_tokens_string_dtype(self): + input_data = ["quick brown whale @MASK@"] + vocab_data = ["@UNK@", "qu", "@@ick", "br", "@@own", "fox", "@MASK@"] + special_tokens = ["@UNK@", "@MASK@"] tokenizer = WordPieceTokenizer( vocabulary=vocab_data, oov_token="@UNK@", suffix_indicator="@@", dtype="string", + special_tokens=special_tokens, + special_tokens_in_strings=True, ) call_output = tokenizer(input_data) self.assertAllEqual( call_output, - [["qu", "@@ick", "br", "@@own", "@UNK@"]], + [["qu", "@@ick", "br", "@@own", "@UNK@", "@MASK@"]], ) + def test_special_tokens_int_dtype(self): + input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."] + special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"] + vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."] + vocab_data = [*special_tokens, *vocab_data] + + tokenizer = WordPieceTokenizer( + vocabulary=vocab_data, + special_tokens=special_tokens, + special_tokens_in_strings=True, + ) + output = tokenizer(input_data) + self.assertAllEqual(output, [[0, 1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11]]) + def test_cjk_tokens(self): input_data = ["ah半ęŽØzz"] vocab_data = ["[UNK]", "ęŽØ", "ꕐ", "乐", "半", "偷", "匕", "ah", "zz"] @@ -217,3 +234,40 @@ def test_no_oov_token_in_vocabulary(self): with self.assertRaises(ValueError): WordPieceTokenizer(vocabulary=vocab_data, oov_token=None) + + def test_no_splitting_with_special_tokens(self): + # When `split` is `False`, no special tokens tokenization will be done. + input_data = [ + "[MASK] t o k e n", + "m i s s i n g", + "[MASK]", + "t o k e n", + ] + vocab_data = ["[UNK]", "[MASK]", "t o k e n"] + tokenizer = WordPieceTokenizer( + vocabulary=vocab_data, split=False, special_tokens=["[MASK]"] + ) + output = tokenizer(input_data) + self.assertAllEqual(output, [0, 0, 1, 2]) + + def test_config_with_special_tokens(self): + input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."] + special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"] + vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."] + vocab_data = [*special_tokens, *vocab_data] + original_tokenizer = WordPieceTokenizer( + vocabulary=vocab_data, + lowercase=False, + oov_token="[UNK]", + suffix_indicator="##", + dtype="string", + special_tokens=special_tokens, + ) + cloned_tokenizer = WordPieceTokenizer.from_config( + original_tokenizer.get_config() + ) + cloned_tokenizer.set_vocabulary(original_tokenizer.get_vocabulary()) + self.assertAllEqual( + original_tokenizer(input_data), + cloned_tokenizer(input_data), + ) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index 6bb2748fd9..d64daeb2b2 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import datetime +import inspect import json import os +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import config as backend_config from keras_nlp.backend import keras try: @@ -23,9 +29,50 @@ except ImportError: kagglehub = None +try: + import huggingface_hub + from huggingface_hub.utils import HFValidationError +except ImportError: + huggingface_hub = None + KAGGLE_PREFIX = "kaggle://" GS_PREFIX = "gs://" +HF_PREFIX = "hf://" + TOKENIZER_ASSET_DIR = "assets/tokenizer" +CONFIG_FILE = "config.json" +TOKENIZER_CONFIG_FILE = "tokenizer.json" + +# Global state for preset registry. +BUILTIN_PRESETS = {} +BUILTIN_PRESETS_FOR_CLASS = collections.defaultdict(dict) + + +def register_presets(presets, classes): + """Register built-in presets for a set of classes. + + Note that this is intended only for models and presets shipped in the + library itself. + """ + for preset in presets: + BUILTIN_PRESETS[preset] = presets[preset] + for cls in classes: + BUILTIN_PRESETS_FOR_CLASS[cls][preset] = presets[preset] + + +def list_presets(cls): + """Find all registered built-in presets for a class.""" + return dict(BUILTIN_PRESETS_FOR_CLASS[cls]) + + +def list_subclasses(cls): + """Find all registered subclasses of a class.""" + custom_objects = keras.saving.get_custom_objects().values() + subclasses = [] + for x in custom_objects: + if inspect.isclass(x) and x != cls and issubclass(x, cls): + subclasses.append(x) + return subclasses def get_file(preset, path): @@ -34,6 +81,8 @@ def get_file(preset, path): raise ValueError( f"A preset identifier must be a string. 
Received: preset={preset}" ) + if preset in BUILTIN_PRESETS: + preset = BUILTIN_PRESETS[preset]["kaggle_handle"] if preset.startswith(KAGGLE_PREFIX): if kagglehub is None: raise ImportError( @@ -63,15 +112,33 @@ def get_file(preset, path): url, cache_subdir=os.path.join("models", subdir), ) + elif preset.startswith(HF_PREFIX): + if huggingface_hub is None: + raise ImportError( + f"`from_preset()` requires the `huggingface_hub` package to load from '{preset}'. " + "Please install with `pip install huggingface_hub`." + ) + hf_handle = preset.removeprefix(HF_PREFIX) + try: + return huggingface_hub.hf_hub_download( + repo_id=hf_handle, filename=path + ) + except HFValidationError as e: + raise ValueError( + "Unexpected Hugging Face preset. Hugging Face model handles " + "should have the form 'hf://{org}/{model}'. For example, " + f"'hf://username/bert_base_en'. Received: preset={preset}." + ) from e elif os.path.exists(preset): # Assume a local filepath. return os.path.join(preset, path) else: raise ValueError( "Unknown preset identifier. A preset must be a one of:\n" - "1) a built in preset identifier like `'bert_base_en'`\n" + "1) a built-in preset identifier like `'bert_base_en'`\n" "2) a Kaggle Models handle like `'kaggle://keras/bert/keras/bert_base_en'`\n" - "3) a path to a local preset directory like `'./bert_base_en`\n" + "3) a Hugging Face handle like `'hf://username/bert_base_en'`\n" + "4) a path to a local preset directory like `'./bert_base_en'`\n" "Use `print(cls.presets.keys())` to view all built-in presets for " "API symbol `cls`.\n" f"Received: preset='{preset}'" @@ -154,6 +221,140 @@ def save_to_preset( metadata_file.write(json.dumps(metadata, indent=4)) +def _validate_tokenizer(preset, allow_incomplete=False): + config_path = get_file(preset, TOKENIZER_CONFIG_FILE) + if not os.path.exists(config_path): + if allow_incomplete: + logging.warning( + f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`." + ) + return + else: + raise FileNotFoundError( + f"`{TOKENIZER_CONFIG_FILE}` is missing from the preset directory `{preset}`. " + "To upload the model without a tokenizer, " + "set `allow_incomplete=True`." + ) + try: + with open(config_path) as config_file: + config = json.load(config_file) + except Exception as e: + raise ValueError( + f"Tokenizer config file `{config_path}` is an invalid json file. " + f"Error message: {e}" + ) + layer = keras.saving.deserialize_keras_object(config) + + if not config["assets"]: + raise ValueError( + f"Tokenizer config file {config_path} is missing `assets`." + ) + + for asset in config["assets"]: + asset_path = os.path.join(preset, asset) + if not os.path.exists(asset_path): + raise FileNotFoundError( + f"Asset `{asset}` doesn't exist in the preset directory `{preset}`." + ) + config_dir = os.path.dirname(config_path) + asset_dir = os.path.join(config_dir, TOKENIZER_ASSET_DIR) + + tokenizer = get_tokenizer(layer) + if not tokenizer: + raise ValueError(f"Model or layer `{layer}` is missing tokenizer.") + tokenizer.load_assets(asset_dir) + + +def _validate_backbone(preset): + config_path = os.path.join(preset, CONFIG_FILE) + if not os.path.exists(config_path): + raise FileNotFoundError( + f"`{CONFIG_FILE}` is missing from the preset directory `{preset}`." + ) + try: + with open(config_path) as config_file: + config = json.load(config_file) + except Exception as e: + raise ValueError( + f"Config file `{config_path}` is an invalid json file. 
" + f"Error message: {e}" + ) + + if config["weights"]: + weights_path = os.path.join(preset, config["weights"]) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"The weights file is missing from the preset directory `{preset}`." + ) + else: + raise ValueError( + f"No weights listed in `{CONFIG_FILE}`. Make sure to use " + "`save_to_preset()` which adds additional data to a serialized " + "Keras object." + ) + + +@keras_nlp_export("keras_nlp.upload_preset") +def upload_preset( + uri, + preset, + allow_incomplete=False, +): + """Upload a preset directory to a model hub. + + Args: + uri: The URI identifying model to upload to. + URIs with format + `kaggle://///` + will be uploaded to Kaggle Hub while URIs with format + `hf://[/]` will be uploaded to the Hugging + Face Hub. + preset: The path to the local model preset directory. + allow_incomplete: If True, allows the upload of presets without + a tokenizer configuration. Otherwise, a tokenizer + is required. + """ + + # Check if preset directory exists. + if not os.path.exists(preset): + raise FileNotFoundError(f"The preset directory {preset} doesn't exist.") + + _validate_backbone(preset) + _validate_tokenizer(preset, allow_incomplete) + + if uri.startswith(KAGGLE_PREFIX): + kaggle_handle = uri.removeprefix(KAGGLE_PREFIX) + kagglehub.model_upload(kaggle_handle, preset) + elif uri.startswith(HF_PREFIX): + if huggingface_hub is None: + raise ImportError( + f"`upload_preset()` requires the `huggingface_hub` package to upload to '{uri}'. " + "Please install with `pip install huggingface_hub`." + ) + hf_handle = uri.removeprefix(HF_PREFIX) + try: + repo_url = huggingface_hub.create_repo( + repo_id=hf_handle, exist_ok=True + ) + except HFValidationError as e: + raise ValueError( + "Unexpected Hugging Face URI. Hugging Face model handles " + "should have the form 'hf://[{org}/]{model}'. For example, " + "'hf://username/bert_base_en' or 'hf://bert_case_en' to implicitly" + f"upload to your user account. Received: URI={uri}." + ) from e + huggingface_hub.upload_folder( + repo_id=repo_url.repo_id, folder_path=preset + ) + else: + raise ValueError( + "Unknown URI. An URI must be a one of:\n" + "1) a Kaggle Model handle like `'kaggle://///'`\n" + "2) a Hugging Face handle like `'hf://[/]'`\n" + f"Received: uri='{uri}'." + ) + + def load_from_preset( preset, load_weights=True, @@ -180,31 +381,25 @@ def load_from_preset( # Optionally load weights. load_weights = load_weights and config["weights"] if load_weights: + # For jax, delete all previous allocated memory to avoid temporarily + # duplicating variable allocations. torch and tensorflow have stateful + # variable types and do not need this fix. + if backend_config.backend() == "jax": + for weight in layer.weights: + if getattr(weight, "_value", None) is not None: + weight._value.delete() weights_path = get_file(preset, config["weights"]) layer.load_weights(weights_path) return layer -def check_preset_class( +def check_config_class( preset, - classes, config_file="config.json", ): """Validate a preset is being loaded on the correct class.""" config_path = get_file(preset, config_file) with open(config_path) as config_file: config = json.load(config_file) - cls = keras.saving.get_registered_object(config["registered_name"]) - if not isinstance(classes, (tuple, list)): - classes = (classes,) - # Allow subclasses for testing a base class, e.g. 
- # `check_preset_class(preset, Backbone)` - if not any(issubclass(cls, x) for x in classes): - raise ValueError( - f"Unexpected class in preset `'{preset}'`. " - "When calling `from_preset()` on a class object, the preset class " - f"much match allowed classes. Allowed classes are `{classes}`. " - f"Received: `{cls}`." - ) - return cls + return keras.saving.get_registered_object(config["registered_name"]) diff --git a/keras_nlp/utils/preset_utils_test.py b/keras_nlp/utils/preset_utils_test.py index 44dc39f477..32547b4ad0 100644 --- a/keras_nlp/utils/preset_utils_test.py +++ b/keras_nlp/utils/preset_utils_test.py @@ -18,13 +18,16 @@ import pytest from absl.testing import parameterized -from keras_nlp.models.albert.albert_classifier import AlbertClassifier -from keras_nlp.models.backbone import Backbone -from keras_nlp.models.bert.bert_classifier import BertClassifier -from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier -from keras_nlp.models.task import Task +from keras_nlp import upload_preset +from keras_nlp.models import AlbertClassifier +from keras_nlp.models import BertBackbone +from keras_nlp.models import BertClassifier +from keras_nlp.models import BertTokenizer +from keras_nlp.models import RobertaClassifier from keras_nlp.tests.test_case import TestCase -from keras_nlp.utils.preset_utils import check_preset_class +from keras_nlp.utils.preset_utils import CONFIG_FILE +from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE +from keras_nlp.utils.preset_utils import check_config_class from keras_nlp.utils.preset_utils import load_from_preset from keras_nlp.utils.preset_utils import save_to_preset @@ -76,11 +79,7 @@ def test_preset_saving(self, cls, preset_name, tokenizer_type): self.assertEqual(config["weights"], "model.weights.h5") # Try loading the model from preset directory - self.assertEqual(cls, check_preset_class(save_dir, cls)) - self.assertEqual(cls, check_preset_class(save_dir, Task)) - with self.assertRaises(ValueError): - # Preset is a subclass of Task, not Backbone. - check_preset_class(save_dir, Backbone) + self.assertEqual(cls, check_config_class(save_dir)) # Try loading the model from preset directory restored_model = load_from_preset(save_dir) @@ -105,3 +104,61 @@ def test_preset_errors(self): with self.assertRaisesRegex(ValueError, "Unknown preset identifier"): AlbertClassifier.from_preset("snaggle://bort/bort/bort") + + def test_upload_empty_preset(self): + temp_dir = self.get_temp_dir() + empty_preset = os.path.join(temp_dir, "empty") + os.mkdir(empty_preset) + uri = "kaggle://test/test/test" + + with self.assertRaises(FileNotFoundError): + upload_preset(uri, empty_preset) + + @parameterized.parameters( + (TOKENIZER_CONFIG_FILE), (CONFIG_FILE), ("model.weights.h5") + ) + @pytest.mark.keras_3_only + @pytest.mark.large + def test_upload_with_missing_file(self, missing_file): + # Load a model from Kaggle to use as a test model. + preset = "bert_tiny_en_uncased" + backbone = BertBackbone.from_preset(preset) + tokenizer = BertTokenizer.from_preset(preset) + + # Save the model on a local directory. + temp_dir = self.get_temp_dir() + local_preset_dir = os.path.join(temp_dir, "bert_preset") + backbone.save_to_preset(local_preset_dir) + tokenizer.save_to_preset(local_preset_dir) + + # Delete the file that is supposed to be missing. + missing_path = os.path.join(local_preset_dir, missing_file) + os.remove(missing_path) + + # Verify error handling. 
+ with self.assertRaisesRegex(FileNotFoundError, "is missing"): + upload_preset("kaggle://test/test/test", local_preset_dir) + + @parameterized.parameters((TOKENIZER_CONFIG_FILE), (CONFIG_FILE)) + @pytest.mark.keras_3_only + @pytest.mark.large + def test_upload_with_invalid_json(self, json_file): + # Load a model from Kaggle to use as a test model. + preset = "bert_tiny_en_uncased" + backbone = BertBackbone.from_preset(preset) + tokenizer = BertTokenizer.from_preset(preset) + + # Save the model on a local directory. + temp_dir = self.get_temp_dir() + local_preset_dir = os.path.join(temp_dir, "bert_preset") + backbone.save_to_preset(local_preset_dir) + tokenizer.save_to_preset(local_preset_dir) + + # Re-write json file content to an invalid format. + json_path = os.path.join(local_preset_dir, json_file) + with open(json_path, "w") as file: + file.write("Invalid!") + + # Verify error handling. + with self.assertRaisesRegex(ValueError, "is an invalid json"): + upload_preset("kaggle://test/test/test", local_preset_dir) diff --git a/keras_nlp/utils/python_utils.py b/keras_nlp/utils/python_utils.py index 82c4bcf1a8..d29c71dbbe 100644 --- a/keras_nlp/utils/python_utils.py +++ b/keras_nlp/utils/python_utils.py @@ -19,28 +19,3 @@ class classproperty(property): def __get__(self, _, owner_cls): return self.fget(owner_cls) - - -def format_docstring(**replacements): - """Format a python docstring using a dictionary of replacements. - - This decorator can be placed on a function, class or method to format it's - docstring with python variables. - - The decorator will replace any double bracketed variable with a kwargs - value passed to the decorator itself. For example - `@format_docstring(name="foo")` will replace any occurance of `{{name}}` in - the docstring with the string literal `foo`. - """ - - def decorate(obj): - doc = obj.__doc__ - # We use `str.format()` to replace variables in the docstring, but use - # double brackets, e.g. {{var}}, to mark format strings. So we need to - # to swap all double and single brackets in the source docstring. 
- doc = "{".join(part.replace("{", "{{") for part in doc.split("{{")) - doc = "}".join(part.replace("}", "}}") for part in doc.split("}}")) - obj.__doc__ = doc.format(**replacements) - return obj - - return decorate diff --git a/keras_nlp/utils/python_utils_test.py b/keras_nlp/utils/python_utils_test.py index 60590dd47b..7d08ca091f 100644 --- a/keras_nlp/utils/python_utils_test.py +++ b/keras_nlp/utils/python_utils_test.py @@ -14,7 +14,6 @@ from keras_nlp.tests.test_case import TestCase from keras_nlp.utils.python_utils import classproperty -from keras_nlp.utils.python_utils import format_docstring class ClassPropertyTest(TestCase): @@ -25,60 +24,3 @@ def bar(cls): return "class property" self.assertAllEqual(Foo.bar, "class property") - - -class FormatDocstringTest(TestCase): - def test_function(self): - @format_docstring(adjective="salubrious") - def foo(): - """It was a {{adjective}} November day.""" - return "function" - - self.assertAllEqual(foo(), "function") - self.assertAllEqual(foo.__doc__, "It was a salubrious November day.") - - def test_class(self): - @format_docstring(adjective="smelly", name="Mortimer") - class Foo: - """I saw my {{adjective}} friend {{name}}.""" - - def __init__(self): - self.bar = "property" - - self.assertAllEqual(Foo().bar, "property") - self.assertAllEqual(Foo.__doc__, "I saw my smelly friend Mortimer.") - - def test_class_method(self): - @format_docstring(adjective="smelly", name="Mortimer") - class Foo: - """I saw my {{adjective}} friend {{name}}.""" - - def __init__(self): - self.bar = "property" - - @classmethod - @format_docstring(noun="cactus", bodypart="nostril") - def baz(cls): - """He was holding a {{noun}} in his {{bodypart}}.""" - return "class method" - - self.assertAllEqual(Foo.baz(), "class method") - self.assertAllEqual( - Foo.baz.__doc__, - "He was holding a cactus in his nostril.", - ) - self.assertAllEqual( - Foo.baz.__func__.__doc__, - "He was holding a cactus in his nostril.", - ) - - def test_brackets(self): - @format_docstring(nickname="dumdum") - def bar(): - """Use `{}` to create a dictionary, {{nickname}}.""" - return "function" - - self.assertAllEqual(bar(), "function") - self.assertAllEqual( - bar.__doc__, "Use `{}` to create a dictionary, dumdum." - ) diff --git a/keras_nlp/utils/tensor_utils.py b/keras_nlp/utils/tensor_utils.py index a88d80a4da..8b5759b1cf 100644 --- a/keras_nlp/utils/tensor_utils.py +++ b/keras_nlp/utils/tensor_utils.py @@ -170,3 +170,27 @@ def is_int_dtype(dtype): def is_string_dtype(dtype): return "string" in standardize_dtype(dtype) + + +def any_equal(inputs, values, padding_mask): + """Return a mask that is True anywhere `inputs` has a value in `values`. + + Final mask has `padding_mask` applied. + + Args: + inputs: Input tensor. + values: List or iterable of tensors shaped like `inputs` or broadcastable + by bit operators. + padding_mask: Tensor with shape compatible with inputs that will condition + output. + + Returns: + A tensor with `inputs` shape where each position is True if it contains + a value from any `values`. 
Padding mask will be applied before + returning.""" + output = ops.equal(inputs, values[0]) + for value in values[1:]: + value_equality = ops.equal(inputs, value) + output = ops.logical_or(output, value_equality) + + return ops.logical_and(output, padding_mask) diff --git a/keras_nlp/utils/tensor_utils_test.py b/keras_nlp/utils/tensor_utils_test.py index 317e42ade8..ec27832ed9 100644 --- a/keras_nlp/utils/tensor_utils_test.py +++ b/keras_nlp/utils/tensor_utils_test.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import tensorflow as tf from keras_nlp.backend import ops from keras_nlp.tests.test_case import TestCase +from keras_nlp.utils.tensor_utils import any_equal from keras_nlp.utils.tensor_utils import convert_to_ragged_batch from keras_nlp.utils.tensor_utils import tensor_to_list @@ -97,3 +99,41 @@ def test_convert_ragged(self): self.assertAllEqual(outputs, [[1, 2], [1]]) self.assertFalse(unbatched) self.assertFalse(rectangular) + + +class MaskedAnyEqualTest(tf.test.TestCase): + def test_basic_equality(self): + inputs = ops.array([1, 2, 3, 5]) + values = [3, 5] + padding_mask = ops.array([True, True, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_multiple_values(self): + inputs = ops.array([2, 4, 7, 9]) + values = [5, 4, 9] + padding_mask = ops.array([True, True, True, True]) + expected_output = np.array([False, True, False, True]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_padding_mask(self): + inputs = ops.array([1, 5, 3, 2]) + values = [5, 3] + padding_mask = ops.array([True, False, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) + + def test_input_shaped_values(self): + inputs = ops.array([1, 5, 3, 2]) + values = [[5, 5, 5, 5], [3, 3, 3, 3]] + padding_mask = ops.array([True, False, True, False]) + expected_output = np.array([False, False, True, False]) + result = any_equal(inputs, values, padding_mask) + result = ops.convert_to_numpy(result) + self.assertAllEqual(result, expected_output) diff --git a/keras_nlp/version_utils.py b/keras_nlp/version_utils.py index 4d6a8186d4..5735e45aaa 100644 --- a/keras_nlp/version_utils.py +++ b/keras_nlp/version_utils.py @@ -15,7 +15,7 @@ from keras_nlp.api_export import keras_nlp_export # Unique source of truth for the version number. -__version__ = "0.8.0" +__version__ = "0.9.0" @keras_nlp_export("keras_nlp.version") diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 903a603352..2ded131217 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.1 # Torch cpu-only version. 
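The behavior of the `any_equal` helper added to `keras_nlp/utils/tensor_utils.py` above is easiest to see in a concrete trace. The sketch below is a minimal NumPy re-implementation for illustration only (the shipped function uses the backend-agnostic `ops` module), with inputs mirroring the new `test_padding_mask` case:

```python
import numpy as np

def any_equal(inputs, values, padding_mask):
    # OR together one equality mask per candidate value.
    output = np.equal(inputs, values[0])
    for value in values[1:]:
        output = np.logical_or(output, np.equal(inputs, value))
    # Positions excluded by the padding mask are forced to False.
    return np.logical_and(output, padding_mask)

inputs = np.array([1, 5, 3, 2])
padding_mask = np.array([True, False, True, False])
# The 5 matches a value but is masked out; only the 3 survives.
print(any_equal(inputs, [5, 3], padding_mask))  # [False False  True False]
```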
--extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index be95915996..5426beb5a3 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. -tf-nightly[and-cuda]==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.1 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 7ea2981478..2ae593e057 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,11 +1,11 @@ # Tensorflow cpu-only version. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.1 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 -torch==2.1.2 -torchvision==0.16.2 +torch==2.2.2+cu121 +torchvision==0.17.2+cu121 # Jax cpu-only version. jax[cpu] diff --git a/requirements.txt b/requirements.txt index b226229d15..8578a4199b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tf-nightly-cpu==2.16.0.dev20240201 # Pin a working nightly until rc0. -tensorflow-text-nightly==2.16.0.dev20240201 # Pin a working nightly until rc0. +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.1 # Torch. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/tools/checkpoint_conversion/bert_base_cased.ipynb b/tools/checkpoint_conversion/bert_base_cased.ipynb deleted file mode 100644 index a7fc9db3ef..0000000000 --- a/tools/checkpoint_conversion/bert_base_cased.ipynb +++ /dev/null @@ -1,1073 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "564b86a2-a7ed-4e22-f1fa-4246368d30a7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.8 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 49.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 45.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 49.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 59.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 57.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 48.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 
67.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 11.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 51.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 73.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 60.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 71.3 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/jbischof/keras-nlp.git@bert_ckpt tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "source": [ - "TOKEN_TYPE = \"cased\"\n", - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_NAME = MODEL_TYPE + \"_\" + TOKEN_TYPE\n", - "VOCAB_SIZE = 28996" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "JdXFWsMVEf-x", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "4d87dbf9-dda5-4afe-fc07-533c84ad768a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\n", - "401886519/401886519 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{TOKEN_TYPE}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "LHYiSsvYtfEU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "8c63f095-4a49-4de0-b61a-8b2fe4705ff8" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = 
extract_dir + \"vocab.txt\"\n", - "checkpoint_path = extract_dir + \"bert_model.ckpt\"\n", - "config_path = extract_dir + \"bert_config.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "5f3b159b-381e-48e8-a77b-138b251a458c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [28996, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - 
"encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE 
[768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { 
- "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "220a2d92-27b0-41d7-a081-d39aea11d89a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 22268928 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - 
"==================================================================================================\n", - "Total params: 108,310,272\n", - "Trainable params: 108,310,272\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\"encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "3627782d-92d2-4299-8a5e-55a74402d01e" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, 
- "execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Swp2Y0OvYIId", - "outputId": "f413f078-4d4a-42ee-a048-bcf77e655f4b" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 27 - } - ], - "source": [ - "keras_nlp_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "b46aab26-c8c2-467f-e8f8-6c3fa992d5b2" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 28 - } - ], - "source": [ - "mg_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "4447e234-bf17-46f2-f09f-86493a61e1cb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "42e126a0-92af-4779-f652-47b3396d38a9" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "f506140a-217b-4f45-b331-793e47fc5bae" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "210441" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "07debeae-6f54-4201-8d15-8cfa992fb121" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "f30ac6fac11115322e3d0c61e87e98b2 bert_base_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "b96360bc-47d6-4611-e727-f861d6038726" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_cased/model.h5\n", - "433474808/433474808 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "f1aa7c12-6681-4ae2-c689-f4b7ddfd1701" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 23 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
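The `tf.reduce_mean(a - b)` checks used throughout these notebooks can understate disagreement, since positive and negative elementwise errors cancel. A stricter variant (my addition, assuming the `keras_nlp_output` and `mg_output` tensors from the comparison cells above) looks at the worst-case element instead:

```python
# Sketch: a stricter equivalence check than the signed mean used in the notebook.
import numpy as np
import tensorflow as tf

# Worst-case elementwise disagreement between the two encoders.
print("max abs diff:", float(tf.reduce_max(tf.abs(keras_nlp_output - mg_output))))

# Or fail loudly if the converted model drifts beyond float32 noise.
np.testing.assert_allclose(
    keras_nlp_output.numpy(), mg_output.numpy(), atol=1e-4, rtol=1e-4
)
```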
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, 0:10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OKfvONXmzX_L", - "outputId": "ece23e28-78b2-415a-faba-be5a3825812f" - }, - "execution_count": 26, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 26 - } - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "checkpoint convert model garden -> keras-nlp Bert cased", - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_multi_cased.ipynb b/tools/checkpoint_conversion/bert_base_multi_cased.ipynb deleted file mode 100644 index c87bbdbda0..0000000000 --- a/tools/checkpoint_conversion/bert_base_multi_cased.ipynb +++ /dev/null @@ -1,1084 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "52ac4b95-b8c9-4c2d-b4ca-eb57505d2cae" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.3 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 42.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 49.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 69.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 48.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 57.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 72.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 10.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 63.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 64.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 63.4 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 100.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 73.7 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-base-chinese tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_SUFFIX = \"multi_cased\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 119547" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JdXFWsMVEf-x", - "outputId": "043339cd-3d95-4083-c915-f9632effc319" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\n", - "660198400/660198400 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{MODEL_SUFFIX}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LHYiSsvYtfEU", - "outputId": "c185b640-d358-4d25-c249-7ac5ec4234df" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "d1f58eaa-a601-4bf1-bd1f-0b5ad1a28a28" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH 
[]\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [119547, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-14/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - 
"encoder/layer_with_weights-6/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - 
"encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "937d8db5-4567-4fdc-e4c8-3a14c1f50b36" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 91812096 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 
177,853,440\n", - "Trainable params: 177,853,440\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 
4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"The ą¤ą¤Ÿą¤Ŗą¤Ÿ brown ą¤²ą„‹ą¤®ą¤”ą¤¼ą„€.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "9cdd88a8-efc8-43bb-c469-9496578f46d6" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": 
{}, - "execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"input_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9gF51d-CbwUh", - "outputId": "39dc40e7-7e9d-4a70-a55e-2ecc8d3e91ce" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "48a0565d-30bb-4db0-c658-a4a4af7dc599" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "12d8cb41-85f9-44a5-e990-9e497ac45fff" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "32cb37c4-344e-4b34-dab9-907f4f6b7216" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "3289713c-5ba2-4c9a-d3fa-fefa52a6223a" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "764415" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "d394313b-ccf0-4db5-b270-aee6f27d1f06" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "b0631cec0a1f2513c6cfd75ba29c33aa bert_base_multi_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "d66aca01-c1aa-404a-d2a1-090644f842bd" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_multi_cased/model.h5\n", - "711647480/711647480 [==============================] - 8s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_SUFFIX)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "330f724f-29a4-4e90-d1b8-669501575d80" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", 
- " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "57dd61f6-5415-47a6-af9b-01d082f6b745" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "UqSREee-sLmP" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_uncased.ipynb b/tools/checkpoint_conversion/bert_base_uncased.ipynb deleted file mode 100644 index 282998b76e..0000000000 --- a/tools/checkpoint_conversion/bert_base_uncased.ipynb +++ /dev/null @@ -1,1073 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "6fa43e61-f3ca-449b-a6ea-971bae8f5020" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.0 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 46.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 57.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 66.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 56.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 26.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 9.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 70.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 59.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 52.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 75.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.0 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 68.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 76.8 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/jbischof/keras-nlp.git@bert_ckpt tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "source": [ - "TOKEN_TYPE = \"uncased\"\n", - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_NAME = MODEL_TYPE + \"_\" + TOKEN_TYPE\n", - "VOCAB_SIZE = 30522" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "JdXFWsMVEf-x", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "fe21cd4d-7f08-4d49-ef07-1660e3e76fa4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\n", - "405351189/405351189 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{TOKEN_TYPE}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "LHYiSsvYtfEU", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "77580043-09a7-4950-99bb-473d71614365" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = extract_dir + \"vocab.txt\"\n", - "checkpoint_path = extract_dir + \"bert_model.ckpt\"\n", - "config_path = extract_dir + \"bert_config.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "67ca9d60-c88e-4ad1-c93a-a06eeaf2e631" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE 
[30522, 768]\n", - "encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE 
[768]\n", - "encoder/layer_with_weights-8/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "next_sentence..pooler_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "next_sentence..pooler_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
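
The shapes in the checkpoint listing above fully determine the parameter count that `model.summary()` reports in the next cell (109,482,240). A quick back-of-the-envelope check in plain Python — the variable names below are local to this sketch, not part of the notebook:

```python
# Reproduce BERT-base's parameter count from the shapes in the checkpoint
# listing above: 4 embedding-side entries, 12 transformer layers, 1 pooler.
vocab, seq, hidden, heads, head_dim, ffn, layers = 30522, 512, 768, 12, 64, 3072, 12

embeddings = vocab * hidden + seq * hidden + 2 * hidden   # token + position + segment
emb_norm = 2 * hidden                                     # gamma + beta

qkv = 3 * (hidden * heads * head_dim + heads * head_dim)  # query/key/value dense
attn_out = heads * head_dim * hidden + hidden             # attention output dense
ffn_params = (hidden * ffn + ffn) + (ffn * hidden + hidden)
layer_norms = 2 * (2 * hidden)                            # attention + output layer norms
per_layer = qkv + attn_out + ffn_params + layer_norms     # 7,087,872

pooler = hidden * hidden + hidden                         # 590,592

print(embeddings + emb_norm + layers * per_layer + pooler)  # 109482240
```

The total matches the summary below, which confirms that every checkpoint variable in the listing has a destination in the KerasNLP model.
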
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "e130445e-1676-4a21-b332-b373814aef14" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 23440896 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 109,482,240\n", 
- "Trainable params: 109,482,240\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " 
]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"next_sentence..pooler_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\"next_sentence..pooler_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "72dfdb08-dbbb-4f8a-8a0a-2d881659edeb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 11 - } - ], - "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", 
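
The long conversion cell above rebuilds the same checkpoint key by hand for every weight. A small helper capturing the layout (embedding entries at indices 0–3, transformer layers at offset 4) would make the mapping explicit; a minimal sketch, where `weights` and `model` are the objects from the cells above and the helper name is ours:

```python
# Hypothetical helper condensing the repeated f-strings in the conversion
# loop above. Transformer layer i lives at layer_with_weights-(i + 4); the
# first four slots hold the token/position/segment embeddings and layer norm.
def ckpt_key(layer_index, suffix):
    return (
        f"encoder/layer_with_weights-{layer_index + 4}/"
        f"{suffix}/.ATTRIBUTES/VARIABLE_VALUE"
    )

# e.g. the loop body for the query projection kernel becomes:
# layer = model.get_layer(f"transformer_layer_{i}")
# layer._self_attention_layer._query_dense.kernel.assign(
#     weights[ckpt_key(i, "_attention_layer/_query_dense/kernel")]
# )
```
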
- "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Swp2Y0OvYIId", - "outputId": "8a981dbb-eb5b-432c-ee6d-52419e962cfd" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "2cb14435-68bc-48e3-b462-80842686d0e9" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, 0:10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "c7fc2d2d-919f-443e-85dd-d7139d01b173" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "0b8e4fbc-c356-4f66-d01c-d43fc1a326fb" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "76d51be8-d503-4807-c78e-0b94a64ecb6c" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "9312bab4-4889-42b4-a7d6-b7e59c45dfcc" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "94d5ec8bea7816f0c29904c959284343 bert_base_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "c92995ea-0e01-4289-f45b-3a52900ac7e7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_uncased/model.h5\n", - "438162680/438162680 [==============================] - 5s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=MODEL_NAME)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "fff1c4e5-a41c-42d4-df2d-303d4bc8a0ec" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 23 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, 0:10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "6b0bdfbe-27df-483c-a761-2ef34d8e5417" - }, - "execution_count": 24, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 24 - } - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "checkpoint convert model garden -> keras-nlp Bert uncased", - "provenance": [], - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_base_zh.ipynb b/tools/checkpoint_conversion/bert_base_zh.ipynb deleted file mode 100644 index 7dbf1dac8a..0000000000 --- a/tools/checkpoint_conversion/bert_base_zh.ipynb +++ /dev/null @@ -1,1084 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "0e8b9004-fe3f-4fb7-b905-cebb361887ab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.2 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 47.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 43.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 51.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 72.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 54.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 71.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 60.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 68.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 67.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 75.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 
71.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 12.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-base-chinese tensorflow tf-models-official --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "import tensorflow_models as tfm\n", - "from tensorflow import keras" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_base\"\n", - "MODEL_SUFFIX = \"chinese\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 21128" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "L3j5PFBt5JeR" - }, - "source": [ - "## Load the model garden checkpoints and weights" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JdXFWsMVEf-x", - "outputId": "e47ef71a-11c2-4c53-f398-ca3de4da77f9" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\n", - "379546341/379546341 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Model garden BERT paths.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/{MODEL_SUFFIX}_L-12_H-768_A-12.tar.gz\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"tar\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "LHYiSsvYtfEU", - "outputId": "47518515-60ce-49f3-e6ac-39273b1d4616" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "tmp/temp_dir/raw/\n", - "tmp/temp_dir/raw/vocab.txt\n", - "tmp/temp_dir/raw/bert_model.ckpt.index\n", - "tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001\n", - "tmp/temp_dir/raw/bert_config.json\n" - ] - } - ], - "source": [ - "!tar -xvf \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "Q9QRSp47tnVo" - }, - "outputs": [], - "source": [ - "# Model garden BERT paths.\n", - "extract_dir = \"/content/tmp/temp_dir/raw/\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vCUJk59B2rai", - "outputId": "59d95000-83a8-48f8-870a-51c2dbb5db6c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "_CHECKPOINTABLE_OBJECT_GRAPH []\n", - "encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE [21128, 768]\n", - 
"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE [512, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-10/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-10/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-10/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-11/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-11/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-11/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-12/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-12/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-12/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-12/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-13/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-13/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-13/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-14/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-14/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-14/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-14/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-15/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-15/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-15/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-16/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 768]\n", - "encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE [2, 768]\n", - "encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-4/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-4/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-4/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-4/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-5/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-5/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-5/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - 
"encoder/layer_with_weights-6/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-6/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-6/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-6/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-7/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-7/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-7/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-8/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - 
"encoder/layer_with_weights-8/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-8/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-8/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-8/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [12, 64, 768]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 12, 64]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [3072]\n", - "encoder/layer_with_weights-9/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [768, 3072]\n", - "encoder/layer_with_weights-9/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE [3072, 768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE [768]\n", - "encoder/layer_with_weights-9/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE [768]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XrwVEkDP5RjE" - }, - "source": [ - "## Load BertBase model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Y5gJgD2j5BTG", - "outputId": "6b7cd86d-33af-4864-a3c7-63acf53fecc2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 768) 16226304 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 768) 393216 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 768) 1536 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 768) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 768) 1536 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 768) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 768) 7087872 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 768) 7087872 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 768) 7087872 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 768) 7087872 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 768) 7087872 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 768) 7087872 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 768) 7087872 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 768) 7087872 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 768) 7087872 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 768) 7087872 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 768) 7087872 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 768) 7087872 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 768) 0 ['transformer_layer_11[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 768) 590592 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 
102,267,648\n", - "Trainable params: 102,267,648\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZitZxcFyvlb" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "wmuCofTwBgfo" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-1/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\n", - " \"encoder/layer_with_weights-2/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"encoder/layer_with_weights-3/gamma/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"encoder/layer_with_weights-3/beta/.ATTRIBUTES/VARIABLE_VALUE\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_key_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_query_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_value_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 
4}/_attention_layer/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_attention_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_intermediate_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_dense/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/gamma/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{i + 4}/_output_layer_norm/beta/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/kernel/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(\n", - " weights[\n", - " f\"encoder/layer_with_weights-{model.num_layers + 4}/bias/.ATTRIBUTES/VARIABLE_VALUE\"\n", - " ]\n", - ")\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tNxzCIaF_-IG" - }, - "source": [ - "## Compare Output" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "plQq3yxI_ry_" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path,\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"ę•ę·ēš„ę£•č‰²ē‹ē‹ø\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "chg-o813CJNJ", - "outputId": "5c8144f8-4a07-4a6d-8c06-6bca7e8625cf" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 11 - } - ], 
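One aside on the comparison that follows: the notebook judges agreement by eyeballing `tf.reduce_mean` of the difference. A stricter elementwise check is easy to add once `keras_nlp_output` and `mg_output` have been computed below; the tolerance here is our assumption, not a value the notebook specifies:

```python
import numpy as np

# Hypothetical stricter check than the mean-difference eyeball below:
# assert elementwise agreement within a float32-friendly tolerance.
np.testing.assert_allclose(
    keras_nlp_output.numpy(), mg_output.numpy(), atol=1e-4
)
```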
- "source": [ - "encoder_config = tfm.nlp.encoders.EncoderConfig(\n", - " type=\"bert\",\n", - " bert=json.load(tf.io.gfile.GFile(config_path)),\n", - ")\n", - "mg_model = tfm.nlp.encoders.build_encoder(encoder_config)\n", - "checkpoint = tf.train.Checkpoint(encoder=mg_model)\n", - "checkpoint.read(checkpoint_path).assert_consumed()" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "oFG-tRIKEzer" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "\n", - "mg_output = mg_model(\n", - " {\n", - " \"input_word_ids\": token_ids,\n", - " \"input_type_ids\": segment_ids,\n", - " \"input_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9gF51d-CbwUh", - "outputId": "dc0fdef7-1f6c-4e7d-ac84-71066108846d" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zmhZYi-1YMGH", - "outputId": "893835bc-6438-45e5-c326-0ba631e1a4d6" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "mg_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "TvNj8DFBYNPT", - "outputId": "e8094ce2-2cf5-4c6a-8dc3-474f20b9b58e" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 15 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "tf.reduce_mean(keras_nlp_output - mg_output)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "aXu7Y_zue70C" - }, - "outputs": [], - "source": [ - "# Save BertBase checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "bwlLYTFeg1og" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertBase(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yTqPl39qhPMV", - "outputId": "3282584c-38a9-486c-a3bf-1589dc17083f" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output2)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cXiV4NvilLg6", - "outputId": "8da86cef-d78f-4979-db6b-5fe8c9428c19" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "75770" - ] - }, - "metadata": {}, - "execution_count": 19 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lVutnU-hB6IQ", - "outputId": "438c793e-bbc8-4c46-808c-a6bb69ef68e6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "79afa421e386076e62ab42dad555ab0c bert_base_chinese.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "5db37007-d7eb-4431-b182-8304425d8f71" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_base_zh/model.h5\n", - "409304312/409304312 [==============================] - 3s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertBase(weights=\"zh\")" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "ba0b05d4-f1f2-41d0-ab93-3fa5fe7ea2e5" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " 
\"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "0f3528b5-262f-4736-e377-bd450ee586de" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 22 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "ae9jwK4xqZQf" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_large_cased_en.ipynb b/tools/checkpoint_conversion/bert_large_cased_en.ipynb deleted file mode 100644 index 0c164a5b5b..0000000000 --- a/tools/checkpoint_conversion/bert_large_cased_en.ipynb +++ /dev/null @@ -1,1280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "fcd8f38b-e213-456f-b49f-397b21076120" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.8 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 46.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 58.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 59.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 64.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 74.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 62.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 60.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 11.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 62.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 68.5 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 71.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 69.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-large-vars tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_large\"\n", - "MODEL_SUFFIX = \"cased\"\n", - "MODEL_SPEC_STR = \"L-24_H-1024_A-16\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 28996\n", - "NUM_LAYERS = 24\n", - "NUM_ATTN_HEADS = 16\n", - "EMBEDDING_SIZE = 1024" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2018_10_18/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "534dc0a1-97bb-412a-bc20-e02af3e88afc" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2018_10_18/cased_L-24_H-1024_A-16.zip\n", - "1242178883/1242178883 [==============================] - 17s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "774aec39-bd51-46c5-ccb4-3abc99acee5b" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_large_cased\n", - " creating: cased_L-24_H-1024_A-16/\n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.meta \n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: cased_L-24_H-1024_A-16/vocab.txt \n", - " inflating: cased_L-24_H-1024_A-16/bert_model.ckpt.index \n", - " inflating: cased_L-24_H-1024_A-16/bert_config.json \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", 
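- "# Note: `vars` shadows the Python builtin of the same name, which is\n", - "# harmless in a one-off conversion notebook. `weights` maps each\n", - "# checkpoint variable name to a NumPy array for the assign() calls\n", - "# further down.\n",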
- "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "86b04c44-0640-46fc-8e61-1706e3778dfc" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [1024]\n", - "bert/embeddings/LayerNorm/gamma [1024]\n", - "bert/embeddings/position_embeddings [512, 1024]\n", - "bert/embeddings/token_type_embeddings [2, 1024]\n", - "bert/embeddings/word_embeddings [28996, 1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/key/bias [1024]\n", - "bert/encoder/layer_0/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/query/bias [1024]\n", - "bert/encoder/layer_0/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/value/bias [1024]\n", - "bert/encoder/layer_0/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_0/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/output/dense/bias [1024]\n", - "bert/encoder/layer_0/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/key/bias [1024]\n", - "bert/encoder/layer_1/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/query/bias [1024]\n", - "bert/encoder/layer_1/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/value/bias [1024]\n", - "bert/encoder/layer_1/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_1/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/output/dense/bias [1024]\n", - "bert/encoder/layer_1/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_10/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/key/bias [1024]\n", - "bert/encoder/layer_10/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/query/bias [1024]\n", - "bert/encoder/layer_10/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/value/bias [1024]\n", - "bert/encoder/layer_10/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_10/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_10/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_10/output/LayerNorm/beta [1024]\n", - 
"bert/encoder/layer_10/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/output/dense/bias [1024]\n", - "bert/encoder/layer_10/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_11/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/key/bias [1024]\n", - "bert/encoder/layer_11/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/query/bias [1024]\n", - "bert/encoder/layer_11/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/value/bias [1024]\n", - "bert/encoder/layer_11/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_11/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_11/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_11/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/output/dense/bias [1024]\n", - "bert/encoder/layer_11/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_12/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/key/bias [1024]\n", - "bert/encoder/layer_12/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/query/bias [1024]\n", - "bert/encoder/layer_12/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/value/bias [1024]\n", - "bert/encoder/layer_12/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_12/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_12/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_12/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/output/dense/bias [1024]\n", - "bert/encoder/layer_12/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_13/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/key/bias [1024]\n", - "bert/encoder/layer_13/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/query/bias [1024]\n", - "bert/encoder/layer_13/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/value/bias [1024]\n", - "bert/encoder/layer_13/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_13/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_13/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_13/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/output/dense/bias [1024]\n", - "bert/encoder/layer_13/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_14/attention/output/dense/kernel [1024, 1024]\n", - 
"bert/encoder/layer_14/attention/self/key/bias [1024]\n", - "bert/encoder/layer_14/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/query/bias [1024]\n", - "bert/encoder/layer_14/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/value/bias [1024]\n", - "bert/encoder/layer_14/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_14/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_14/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_14/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/output/dense/bias [1024]\n", - "bert/encoder/layer_14/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_15/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/key/bias [1024]\n", - "bert/encoder/layer_15/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/query/bias [1024]\n", - "bert/encoder/layer_15/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/value/bias [1024]\n", - "bert/encoder/layer_15/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_15/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_15/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_15/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/output/dense/bias [1024]\n", - "bert/encoder/layer_15/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_16/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/key/bias [1024]\n", - "bert/encoder/layer_16/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/query/bias [1024]\n", - "bert/encoder/layer_16/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/value/bias [1024]\n", - "bert/encoder/layer_16/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_16/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_16/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_16/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/output/dense/bias [1024]\n", - "bert/encoder/layer_16/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_17/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/key/bias [1024]\n", - "bert/encoder/layer_17/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/query/bias [1024]\n", - "bert/encoder/layer_17/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/value/bias [1024]\n", - "bert/encoder/layer_17/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_17/intermediate/dense/bias [4096]\n", - 
"bert/encoder/layer_17/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_17/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/output/dense/bias [1024]\n", - "bert/encoder/layer_17/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_18/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/key/bias [1024]\n", - "bert/encoder/layer_18/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/query/bias [1024]\n", - "bert/encoder/layer_18/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/value/bias [1024]\n", - "bert/encoder/layer_18/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_18/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_18/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_18/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/output/dense/bias [1024]\n", - "bert/encoder/layer_18/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_19/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/key/bias [1024]\n", - "bert/encoder/layer_19/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/query/bias [1024]\n", - "bert/encoder/layer_19/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/value/bias [1024]\n", - "bert/encoder/layer_19/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_19/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_19/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_19/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/output/dense/bias [1024]\n", - "bert/encoder/layer_19/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/key/bias [1024]\n", - "bert/encoder/layer_2/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/query/bias [1024]\n", - "bert/encoder/layer_2/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/value/bias [1024]\n", - "bert/encoder/layer_2/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_2/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/output/dense/bias [1024]\n", - "bert/encoder/layer_2/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/attention/output/dense/bias 
[1024]\n", - "bert/encoder/layer_20/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/key/bias [1024]\n", - "bert/encoder/layer_20/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/query/bias [1024]\n", - "bert/encoder/layer_20/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/value/bias [1024]\n", - "bert/encoder/layer_20/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_20/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_20/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_20/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/output/dense/bias [1024]\n", - "bert/encoder/layer_20/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_21/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/key/bias [1024]\n", - "bert/encoder/layer_21/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/query/bias [1024]\n", - "bert/encoder/layer_21/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/value/bias [1024]\n", - "bert/encoder/layer_21/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_21/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_21/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_21/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/output/dense/bias [1024]\n", - "bert/encoder/layer_21/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_22/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/key/bias [1024]\n", - "bert/encoder/layer_22/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/query/bias [1024]\n", - "bert/encoder/layer_22/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/value/bias [1024]\n", - "bert/encoder/layer_22/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_22/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_22/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_22/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/output/dense/bias [1024]\n", - "bert/encoder/layer_22/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_23/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/key/bias [1024]\n", - "bert/encoder/layer_23/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/query/bias [1024]\n", - "bert/encoder/layer_23/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/value/bias [1024]\n", - "bert/encoder/layer_23/attention/self/value/kernel [1024, 1024]\n", - 
"bert/encoder/layer_23/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_23/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_23/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/output/dense/bias [1024]\n", - "bert/encoder/layer_23/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/key/bias [1024]\n", - "bert/encoder/layer_3/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/query/bias [1024]\n", - "bert/encoder/layer_3/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/value/bias [1024]\n", - "bert/encoder/layer_3/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_3/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/output/dense/bias [1024]\n", - "bert/encoder/layer_3/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/key/bias [1024]\n", - "bert/encoder/layer_4/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/query/bias [1024]\n", - "bert/encoder/layer_4/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/value/bias [1024]\n", - "bert/encoder/layer_4/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_4/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/output/dense/bias [1024]\n", - "bert/encoder/layer_4/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/key/bias [1024]\n", - "bert/encoder/layer_5/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/query/bias [1024]\n", - "bert/encoder/layer_5/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/value/bias [1024]\n", - "bert/encoder/layer_5/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_5/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/output/dense/bias [1024]\n", - "bert/encoder/layer_5/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_6/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_6/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/key/bias [1024]\n", - "bert/encoder/layer_6/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/query/bias [1024]\n", - "bert/encoder/layer_6/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/value/bias [1024]\n", - "bert/encoder/layer_6/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_6/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/output/dense/bias [1024]\n", - "bert/encoder/layer_6/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/key/bias [1024]\n", - "bert/encoder/layer_7/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/query/bias [1024]\n", - "bert/encoder/layer_7/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/value/bias [1024]\n", - "bert/encoder/layer_7/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_7/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/output/dense/bias [1024]\n", - "bert/encoder/layer_7/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_8/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/key/bias [1024]\n", - "bert/encoder/layer_8/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/query/bias [1024]\n", - "bert/encoder/layer_8/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/value/bias [1024]\n", - "bert/encoder/layer_8/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_8/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_8/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_8/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/output/dense/bias [1024]\n", - "bert/encoder/layer_8/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_9/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/key/bias [1024]\n", - "bert/encoder/layer_9/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/query/bias [1024]\n", - "bert/encoder/layer_9/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/value/bias [1024]\n", - "bert/encoder/layer_9/attention/self/value/kernel [1024, 1024]\n", - 
"bert/encoder/layer_9/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_9/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_9/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/output/dense/bias [1024]\n", - "bert/encoder/layer_9/output/dense/kernel [4096, 1024]\n", - "bert/pooler/dense/bias [1024]\n", - "bert/pooler/dense/kernel [1024, 1024]\n", - "cls/predictions/output_bias [28996]\n", - "cls/predictions/transform/LayerNorm/beta [1024]\n", - "cls/predictions/transform/LayerNorm/gamma [1024]\n", - "cls/predictions/transform/dense/bias [1024]\n", - "cls/predictions/transform/dense/kernel [1024, 1024]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 1024]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertLarge model with KerasNLP." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "07380744-32ea-4200-effe-ef697bcc8dca", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 1024) 29691904 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 1024) 524288 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 1024) 2048 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 1024) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 1024) 2048 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 1024) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 1024) 12596224 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 1024) 12596224 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 1024) 12596224 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 1024) 12596224 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 1024) 12596224 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 1024) 12596224 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 1024) 12596224 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 1024) 12596224 ['transformer_layer_6[0][0]', \n", - " erEncoder) 
'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 1024) 12596224 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 1024) 12596224 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 1024) 12596224 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 1024) 12596224 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_12 (Transfor (None, None, 1024) 12596224 ['transformer_layer_11[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_13 (Transfor (None, None, 1024) 12596224 ['transformer_layer_12[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_14 (Transfor (None, None, 1024) 12596224 ['transformer_layer_13[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_15 (Transfor (None, None, 1024) 12596224 ['transformer_layer_14[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_16 (Transfor (None, None, 1024) 12596224 ['transformer_layer_15[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_17 (Transfor (None, None, 1024) 12596224 ['transformer_layer_16[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_18 (Transfor (None, None, 1024) 12596224 ['transformer_layer_17[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_19 (Transfor (None, None, 1024) 12596224 ['transformer_layer_18[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_20 (Transfor (None, None, 1024) 12596224 ['transformer_layer_19[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_21 (Transfor (None, None, 1024) 12596224 ['transformer_layer_20[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_22 (Transfor (None, None, 1024) 12596224 ['transformer_layer_21[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_23 (Transfor (None, None, 1024) 12596224 ['transformer_layer_22[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 1024) 0 ['transformer_layer_23[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 1024) 1049600 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 333,579,264\n", - "Trainable params: 333,579,264\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - 
" weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_output_dense.bias.assign(\n", - "        weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - "    )\n", - "    model.get_layer(\n", - "        f\"transformer_layer_{i}\"\n", - "    )._feedforward_layer_norm.gamma.assign(\n", - "        weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - "    )\n", - "    model.get_layer(\n", - "        f\"transformer_layer_{i}\"\n", - "    )._feedforward_layer_norm.beta.assign(\n", - "        weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - "    )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - "    weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load Bert Large from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"\n", - "\n", - "### TF Hub statement:\n", - "\"The weights of this model are those released by the original BERT authors.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - "    \"https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - "    preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - "    f\"https://tfhub.dev/tensorflow/bert_en_cased_{MODEL_SPEC_STR}/4\",\n", - "    trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"]  # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"]  # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - "    tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - "        vocabulary=vocab_path, lowercase=False\n", - "    )\n", - "    packer = keras_nlp.layers.MultiSegmentPacker(\n", - "        sequence_length=model.max_sequence_length,\n", - "        start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - "        end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - "    )\n", - "    return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"The quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - "    {\n", - "        \"token_ids\": token_ids,\n", - "        \"segment_ids\": segment_ids,\n", - "        \"padding_mask\": token_ids != 0,\n", - "    }\n", - ")\n", - "\n", - "orig_pooled_output, orig_sequence_output = embedding_model(\n", - "    tf.constant([\"The quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code",
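The conversion loop in the deleted notebook turns on one layout detail: the TF checkpoint stores each attention projection as a single fused [1024, 1024] kernel, while KerasNLP's attention layer expects a per-head view of the same weights. A minimal sketch of that reshape, using random NumPy stand-ins rather than the real checkpoint tensors:

```python
# Minimal sketch of the head-splitting reshape used in the conversion loop.
# The arrays here are random stand-ins, not the real checkpoint tensors.
import numpy as np

EMBEDDING_SIZE = 1024  # hidden size of BERT Large
NUM_ATTN_HEADS = 16
HEAD_DIM = EMBEDDING_SIZE // NUM_ATTN_HEADS  # 64

# Stand-in for e.g. weights["bert/encoder/layer_0/attention/self/query/kernel"].
fused_kernel = np.random.rand(EMBEDDING_SIZE, EMBEDDING_SIZE)
fused_bias = np.random.rand(EMBEDDING_SIZE)

# Query/key/value kernels: split the *output* axis into (head, unit-within-head).
split_kernel = fused_kernel.reshape(EMBEDDING_SIZE, NUM_ATTN_HEADS, HEAD_DIM)
split_bias = fused_bias.reshape(NUM_ATTN_HEADS, HEAD_DIM)

# No values move; head h simply owns output columns [h * 64, (h + 1) * 64).
assert np.array_equal(split_kernel[:, 0, :], fused_kernel[:, :HEAD_DIM])

# The attention *output* projection splits the input axis instead, matching the
# (NUM_ATTN_HEADS, -1, EMBEDDING_SIZE) reshape in the loop.
output_kernel = fused_kernel.reshape(NUM_ATTN_HEADS, HEAD_DIM, EMBEDDING_SIZE)
assert np.array_equal(output_kernel[0], fused_kernel[:HEAD_DIM, :])
```

Everything else in the loop (LayerNorm scales, feedforward kernels, the pooler) has the same layout in both implementations, which is why those weights are assigned verbatim.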
- "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], orig_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "598bad82-1f9a-4869-e9e4-9ac38157cd89", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(, )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a35bbd8a-7f87-4c6a-fd63-7ba227e4cdf0", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - orig_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - orig_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertLarge checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "144ded0b-512a-4f3a-9552-17ec709dd11c", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "eb9adff4-7aad-44a3-f6aa-49f83199e00f", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "210441" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "d16ef506-b1d0-450e-cae8-240c53896ede", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "8b8ab82290bbf4f8db87d4f100648890 bert_large_cased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload 
model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertLarge(weights=\"cased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": 21, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_large_uncased_en.ipynb b/tools/checkpoint_conversion/bert_large_uncased_en.ipynb deleted file mode 100644 index a80069e199..0000000000 --- a/tools/checkpoint_conversion/bert_large_uncased_en.ipynb +++ /dev/null @@ -1,1280 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "c5e29c2b-f3ba-46eb-c604-74545a36204f" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.6 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 64.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 63.0 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 48.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 72.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 61.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 63.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 60.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 69.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 1.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 56.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 70.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 71.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 10.0 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@bert-large-vars tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_large\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-24_H-1024_A-16\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 24\n", - "NUM_ATTN_HEADS = 16\n", - "EMBEDDING_SIZE = 1024" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "41332d42-8e39-408f-fcc0-803cef5ccfb7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip\n", - "1247797031/1247797031 [==============================] - 23s 0us/step\n" - ] - } - ], - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2018_10_18/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - 
"base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "3496afcb-a342-449e-b5b4-975238cec430" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_large_uncased\n", - " creating: uncased_L-24_H-1024_A-16/\n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.meta \n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-24_H-1024_A-16/vocab.txt \n", - " inflating: uncased_L-24_H-1024_A-16/bert_model.ckpt.index \n", - " inflating: uncased_L-24_H-1024_A-16/bert_config.json \n" - ] - } - ], - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "OGij7IQU4rJL" - }, - "outputs": [], - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "ba1605fe-4503-49dc-ebc5-d89148527b5d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [1024]\n", - "bert/embeddings/LayerNorm/gamma [1024]\n", - "bert/embeddings/position_embeddings [512, 1024]\n", - "bert/embeddings/token_type_embeddings [2, 1024]\n", - "bert/embeddings/word_embeddings [30522, 1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/key/bias [1024]\n", - "bert/encoder/layer_0/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/query/bias [1024]\n", - "bert/encoder/layer_0/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_0/attention/self/value/bias [1024]\n", - "bert/encoder/layer_0/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_0/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_0/output/dense/bias [1024]\n", - "bert/encoder/layer_0/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_1/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/key/bias [1024]\n", - "bert/encoder/layer_1/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/query/bias [1024]\n", - "bert/encoder/layer_1/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_1/attention/self/value/bias [1024]\n", - "bert/encoder/layer_1/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_1/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_1/output/dense/bias [1024]\n", - "bert/encoder/layer_1/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_10/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/key/bias [1024]\n", - "bert/encoder/layer_10/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/query/bias [1024]\n", - "bert/encoder/layer_10/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_10/attention/self/value/bias [1024]\n", - "bert/encoder/layer_10/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_10/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_10/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_10/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_10/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_10/output/dense/bias [1024]\n", - "bert/encoder/layer_10/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_11/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/key/bias [1024]\n", - "bert/encoder/layer_11/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/query/bias [1024]\n", - "bert/encoder/layer_11/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_11/attention/self/value/bias [1024]\n", - "bert/encoder/layer_11/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_11/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_11/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_11/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_11/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_11/output/dense/bias [1024]\n", - "bert/encoder/layer_11/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_12/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/key/bias [1024]\n", - "bert/encoder/layer_12/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/query/bias [1024]\n", - "bert/encoder/layer_12/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_12/attention/self/value/bias [1024]\n", - "bert/encoder/layer_12/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_12/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_12/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_12/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_12/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_12/output/dense/bias [1024]\n", - "bert/encoder/layer_12/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_13/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/key/bias [1024]\n", - 
"bert/encoder/layer_13/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/query/bias [1024]\n", - "bert/encoder/layer_13/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_13/attention/self/value/bias [1024]\n", - "bert/encoder/layer_13/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_13/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_13/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_13/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_13/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_13/output/dense/bias [1024]\n", - "bert/encoder/layer_13/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_14/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/key/bias [1024]\n", - "bert/encoder/layer_14/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/query/bias [1024]\n", - "bert/encoder/layer_14/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_14/attention/self/value/bias [1024]\n", - "bert/encoder/layer_14/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_14/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_14/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_14/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_14/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_14/output/dense/bias [1024]\n", - "bert/encoder/layer_14/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_15/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/key/bias [1024]\n", - "bert/encoder/layer_15/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/query/bias [1024]\n", - "bert/encoder/layer_15/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_15/attention/self/value/bias [1024]\n", - "bert/encoder/layer_15/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_15/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_15/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_15/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_15/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_15/output/dense/bias [1024]\n", - "bert/encoder/layer_15/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_16/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/key/bias [1024]\n", - "bert/encoder/layer_16/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/query/bias [1024]\n", - "bert/encoder/layer_16/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_16/attention/self/value/bias [1024]\n", - "bert/encoder/layer_16/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_16/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_16/intermediate/dense/kernel [1024, 4096]\n", - 
"bert/encoder/layer_16/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_16/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_16/output/dense/bias [1024]\n", - "bert/encoder/layer_16/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_17/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/key/bias [1024]\n", - "bert/encoder/layer_17/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/query/bias [1024]\n", - "bert/encoder/layer_17/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_17/attention/self/value/bias [1024]\n", - "bert/encoder/layer_17/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_17/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_17/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_17/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_17/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_17/output/dense/bias [1024]\n", - "bert/encoder/layer_17/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_18/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/key/bias [1024]\n", - "bert/encoder/layer_18/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/query/bias [1024]\n", - "bert/encoder/layer_18/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_18/attention/self/value/bias [1024]\n", - "bert/encoder/layer_18/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_18/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_18/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_18/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_18/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_18/output/dense/bias [1024]\n", - "bert/encoder/layer_18/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_19/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/key/bias [1024]\n", - "bert/encoder/layer_19/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/query/bias [1024]\n", - "bert/encoder/layer_19/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_19/attention/self/value/bias [1024]\n", - "bert/encoder/layer_19/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_19/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_19/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_19/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_19/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_19/output/dense/bias [1024]\n", - "bert/encoder/layer_19/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/attention/output/dense/bias [1024]\n", - 
"bert/encoder/layer_2/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/key/bias [1024]\n", - "bert/encoder/layer_2/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/query/bias [1024]\n", - "bert/encoder/layer_2/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_2/attention/self/value/bias [1024]\n", - "bert/encoder/layer_2/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_2/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_2/output/dense/bias [1024]\n", - "bert/encoder/layer_2/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_20/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/key/bias [1024]\n", - "bert/encoder/layer_20/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/query/bias [1024]\n", - "bert/encoder/layer_20/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_20/attention/self/value/bias [1024]\n", - "bert/encoder/layer_20/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_20/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_20/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_20/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_20/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_20/output/dense/bias [1024]\n", - "bert/encoder/layer_20/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_21/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/key/bias [1024]\n", - "bert/encoder/layer_21/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/query/bias [1024]\n", - "bert/encoder/layer_21/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_21/attention/self/value/bias [1024]\n", - "bert/encoder/layer_21/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_21/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_21/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_21/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_21/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_21/output/dense/bias [1024]\n", - "bert/encoder/layer_21/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_22/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/key/bias [1024]\n", - "bert/encoder/layer_22/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/query/bias [1024]\n", - "bert/encoder/layer_22/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_22/attention/self/value/bias [1024]\n", - "bert/encoder/layer_22/attention/self/value/kernel [1024, 1024]\n", - 
"bert/encoder/layer_22/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_22/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_22/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_22/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_22/output/dense/bias [1024]\n", - "bert/encoder/layer_22/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_23/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/key/bias [1024]\n", - "bert/encoder/layer_23/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/query/bias [1024]\n", - "bert/encoder/layer_23/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_23/attention/self/value/bias [1024]\n", - "bert/encoder/layer_23/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_23/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_23/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_23/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_23/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_23/output/dense/bias [1024]\n", - "bert/encoder/layer_23/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/key/bias [1024]\n", - "bert/encoder/layer_3/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/query/bias [1024]\n", - "bert/encoder/layer_3/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_3/attention/self/value/bias [1024]\n", - "bert/encoder/layer_3/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_3/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_3/output/dense/bias [1024]\n", - "bert/encoder/layer_3/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/key/bias [1024]\n", - "bert/encoder/layer_4/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/query/bias [1024]\n", - "bert/encoder/layer_4/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_4/attention/self/value/bias [1024]\n", - "bert/encoder/layer_4/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_4/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_4/output/dense/bias [1024]\n", - "bert/encoder/layer_4/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [1024]\n", - 
"bert/encoder/layer_5/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/key/bias [1024]\n", - "bert/encoder/layer_5/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/query/bias [1024]\n", - "bert/encoder/layer_5/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_5/attention/self/value/bias [1024]\n", - "bert/encoder/layer_5/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_5/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_5/output/dense/bias [1024]\n", - "bert/encoder/layer_5/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_6/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/key/bias [1024]\n", - "bert/encoder/layer_6/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/query/bias [1024]\n", - "bert/encoder/layer_6/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_6/attention/self/value/bias [1024]\n", - "bert/encoder/layer_6/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_6/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_6/output/dense/bias [1024]\n", - "bert/encoder/layer_6/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/key/bias [1024]\n", - "bert/encoder/layer_7/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/query/bias [1024]\n", - "bert/encoder/layer_7/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_7/attention/self/value/bias [1024]\n", - "bert/encoder/layer_7/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_7/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_7/output/dense/bias [1024]\n", - "bert/encoder/layer_7/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_8/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/key/bias [1024]\n", - "bert/encoder/layer_8/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/query/bias [1024]\n", - "bert/encoder/layer_8/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_8/attention/self/value/bias [1024]\n", - "bert/encoder/layer_8/attention/self/value/kernel [1024, 1024]\n", - 
"bert/encoder/layer_8/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_8/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_8/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_8/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_8/output/dense/bias [1024]\n", - "bert/encoder/layer_8/output/dense/kernel [4096, 1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/attention/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/attention/output/dense/bias [1024]\n", - "bert/encoder/layer_9/attention/output/dense/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/key/bias [1024]\n", - "bert/encoder/layer_9/attention/self/key/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/query/bias [1024]\n", - "bert/encoder/layer_9/attention/self/query/kernel [1024, 1024]\n", - "bert/encoder/layer_9/attention/self/value/bias [1024]\n", - "bert/encoder/layer_9/attention/self/value/kernel [1024, 1024]\n", - "bert/encoder/layer_9/intermediate/dense/bias [4096]\n", - "bert/encoder/layer_9/intermediate/dense/kernel [1024, 4096]\n", - "bert/encoder/layer_9/output/LayerNorm/beta [1024]\n", - "bert/encoder/layer_9/output/LayerNorm/gamma [1024]\n", - "bert/encoder/layer_9/output/dense/bias [1024]\n", - "bert/encoder/layer_9/output/dense/kernel [4096, 1024]\n", - "bert/pooler/dense/bias [1024]\n", - "bert/pooler/dense/kernel [1024, 1024]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [1024]\n", - "cls/predictions/transform/LayerNorm/gamma [1024]\n", - "cls/predictions/transform/dense/bias [1024]\n", - "cls/predictions/transform/dense/kernel [1024, 1024]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 1024]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertLarge model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "g1kp1M9b6hdU", - "outputId": "8eb94045-3b29-400e-92e7-329cbc1250d7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 1024) 31254528 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 1024) 524288 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 1024) 2048 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 1024) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 1024) 2048 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 1024) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 1024) 12596224 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 1024) 12596224 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 1024) 12596224 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 1024) 12596224 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 1024) 12596224 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 1024) 12596224 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 1024) 12596224 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 1024) 12596224 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_8 (Transform (None, None, 1024) 12596224 ['transformer_layer_7[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_9 (Transform (None, None, 1024) 12596224 ['transformer_layer_8[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_10 (Transfor (None, None, 1024) 12596224 ['transformer_layer_9[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_11 (Transfor (None, None, 1024) 12596224 ['transformer_layer_10[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_12 (Transfor (None, None, 1024) 12596224 ['transformer_layer_11[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_13 (Transfor (None, None, 1024) 12596224 ['transformer_layer_12[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " 
transformer_layer_14 (Transfor (None, None, 1024) 12596224 ['transformer_layer_13[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_15 (Transfor (None, None, 1024) 12596224 ['transformer_layer_14[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_16 (Transfor (None, None, 1024) 12596224 ['transformer_layer_15[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_17 (Transfor (None, None, 1024) 12596224 ['transformer_layer_16[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_18 (Transfor (None, None, 1024) 12596224 ['transformer_layer_17[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_19 (Transfor (None, None, 1024) 12596224 ['transformer_layer_18[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_20 (Transfor (None, None, 1024) 12596224 ['transformer_layer_19[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_21 (Transfor (None, None, 1024) 12596224 ['transformer_layer_20[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_22 (Transfor (None, None, 1024) 12596224 ['transformer_layer_21[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_23 (Transfor (None, None, 1024) 12596224 ['transformer_layer_22[0][0]', \n", - " merEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 1024) 0 ['transformer_layer_23[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 1024) 1049600 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 335,141,888\n", - "Trainable params: 335,141,888\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " 
weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ByCEoIyn-_Ld" - }, - "source": [ - "## Load Bert Large from TF-Hub.\n", - "\n", - "These weights have been ratified by the 
authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"\n", - "\n", - "### TF Hub statement:\n", - "\"The weights of this model are those released by the original BERT authors.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "outputs": [], - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - "    \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - "    preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - "    f\"https://tfhub.dev/tensorflow/bert_en_uncased_{MODEL_SPEC_STR}/4\",\n", - "    trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"]  # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"]  # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "w4OhasjS9Ozn" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - "    tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - "        vocabulary=vocab_path, lowercase=False\n", - "    )\n", - "    packer = keras_nlp.layers.MultiSegmentPacker(\n", - "        sequence_length=model.max_sequence_length,\n", - "        start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - "        end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - "    )\n", - "    return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - "    {\n", - "        \"token_ids\": token_ids,\n", - "        \"segment_ids\": segment_ids,\n", - "        \"padding_mask\": token_ids != 0,\n", - "    }\n", - ")\n", - "\n", - "orig_pooled_output, orig_sequence_output = embedding_model(\n", - "    tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HzUii8Tp9qth", - "outputId": "5c7585f3-831e-4da2-abdf-07a968f86856" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(, )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ], - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], orig_pooled_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "II0akvof9qth", - "outputId": "0f4d085a-7835-41da-80de-4282fe953f6b" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - orig_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - orig_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertLarge checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertLarge(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OD0B0UxN-QdB", - "outputId": "89ffaadc-32ee-4a16-bfc0-6f61f163bc73" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q0K9JAY5-QdD", - "outputId": "1a4a04e2-ffda-4f68-d396-1490faa1c5e2" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-jVECpzp-QdD", - "outputId": "9b05e1b1-adcb-4073-950b-1ab0ab579a92" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cc5cacc9565ef400ee4376105f40ddae bert_large_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "b3061cdc-a831-44ae-a6a5-a999b1bf2544" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_uncased/model.h5\n", - "1341009960/1341009960 [==============================] - 46s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertLarge(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = 
model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ycvzjqdZYjNo", - "outputId": "42366183-9d87-4409-a853-de745c38babb" - }, - "execution_count": 20, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "7ff6fcf4-64fb-4f58-c8b3-1bbc05376e85" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "KcwejBTMXsIc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb b/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb deleted file mode 100644 index 01e5169a63..0000000000 --- a/tools/checkpoint_conversion/bert_medium_uncased_en.ipynb +++ /dev/null @@ -1,971 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "86338ead-40b4-4149-f30f-8f2b28d3426d" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.1 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 52.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 49.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 55.9 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 68.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 54.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 69.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 54.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 8.9 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 66.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 56.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 70.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 58.8 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_medium\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-8_H-512_A-8\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 8\n", - "NUM_ATTN_HEADS = 8\n", - "EMBEDDING_SIZE = 512" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "9acfd0df-e230-4698-9178-08c7e1854fa4" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-8_H-512_A-8.zip\n", - "154608092/154608092 [==============================] - 1s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "4d6a2bda-7482-42a1-aff5-8a28a6796085" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_medium_uncased\n", - " inflating: uncased_L-8_H-512_A-8/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-8_H-512_A-8/bert_config.json \n", - " inflating: uncased_L-8_H-512_A-8/vocab.txt \n", - " inflating: uncased_L-8_H-512_A-8/bert_model.ckpt.index \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - 
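"# Note: \"bert_model.ckpt\" is a checkpoint *prefix*, not a literal file (the\n", - "# zip holds .index/.data shards, per the unzip output above); this prefix is\n", - "# the form that tf.train.list_variables/load_variable expect.\n", -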
"checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "c4cc8f57-f252-49fe-a9b6-de96c2162813" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [512]\n", - "bert/embeddings/LayerNorm/gamma [512]\n", - "bert/embeddings/position_embeddings [512, 512]\n", - "bert/embeddings/token_type_embeddings [2, 512]\n", - "bert/embeddings/word_embeddings [30522, 512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/attention/output/dense/bias [512]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/key/bias [512]\n", - "bert/encoder/layer_0/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/query/bias [512]\n", - "bert/encoder/layer_0/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/value/bias [512]\n", - "bert/encoder/layer_0/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_0/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/output/dense/bias [512]\n", - "bert/encoder/layer_0/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/attention/output/dense/bias [512]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/key/bias [512]\n", - "bert/encoder/layer_1/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/query/bias [512]\n", - "bert/encoder/layer_1/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/value/bias [512]\n", - "bert/encoder/layer_1/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_1/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/output/dense/bias [512]\n", - "bert/encoder/layer_1/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/attention/output/dense/bias [512]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/key/bias [512]\n", - "bert/encoder/layer_2/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/query/bias [512]\n", - "bert/encoder/layer_2/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/value/bias 
[512]\n", - "bert/encoder/layer_2/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_2/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/output/dense/bias [512]\n", - "bert/encoder/layer_2/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/attention/output/dense/bias [512]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/key/bias [512]\n", - "bert/encoder/layer_3/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/query/bias [512]\n", - "bert/encoder/layer_3/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/value/bias [512]\n", - "bert/encoder/layer_3/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_3/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/output/dense/bias [512]\n", - "bert/encoder/layer_3/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_4/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_4/attention/output/dense/bias [512]\n", - "bert/encoder/layer_4/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/key/bias [512]\n", - "bert/encoder/layer_4/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/query/bias [512]\n", - "bert/encoder/layer_4/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_4/attention/self/value/bias [512]\n", - "bert/encoder/layer_4/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_4/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_4/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_4/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_4/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_4/output/dense/bias [512]\n", - "bert/encoder/layer_4/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_5/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_5/attention/output/dense/bias [512]\n", - "bert/encoder/layer_5/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/key/bias [512]\n", - "bert/encoder/layer_5/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/query/bias [512]\n", - "bert/encoder/layer_5/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_5/attention/self/value/bias [512]\n", - "bert/encoder/layer_5/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_5/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_5/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_5/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_5/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_5/output/dense/bias [512]\n", - "bert/encoder/layer_5/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_6/attention/output/LayerNorm/gamma [512]\n", - 
"bert/encoder/layer_6/attention/output/dense/bias [512]\n", - "bert/encoder/layer_6/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/key/bias [512]\n", - "bert/encoder/layer_6/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/query/bias [512]\n", - "bert/encoder/layer_6/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_6/attention/self/value/bias [512]\n", - "bert/encoder/layer_6/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_6/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_6/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_6/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_6/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_6/output/dense/bias [512]\n", - "bert/encoder/layer_6/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_7/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_7/attention/output/dense/bias [512]\n", - "bert/encoder/layer_7/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/key/bias [512]\n", - "bert/encoder/layer_7/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/query/bias [512]\n", - "bert/encoder/layer_7/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_7/attention/self/value/bias [512]\n", - "bert/encoder/layer_7/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_7/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_7/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_7/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_7/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_7/output/dense/bias [512]\n", - "bert/encoder/layer_7/output/dense/kernel [2048, 512]\n", - "bert/pooler/dense/bias [512]\n", - "bert/pooler/dense/kernel [512, 512]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [512]\n", - "cls/predictions/transform/LayerNorm/gamma [512]\n", - "cls/predictions/transform/dense/bias [512]\n", - "cls/predictions/transform/dense/kernel [512, 512]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 512]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertMedium model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "c4c615a2-2997-4682-a8fa-97da40289e23", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 512) 15627264 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 512) 262144 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 512) 1024 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 512) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 512) 1024 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 512) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 512) 3152384 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 512) 3152384 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 512) 3152384 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 512) 3152384 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_4 (Transform (None, None, 512) 3152384 ['transformer_layer_3[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_5 (Transform (None, None, 512) 3152384 ['transformer_layer_4[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_6 (Transform (None, None, 512) 3152384 ['transformer_layer_5[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_7 (Transform (None, None, 512) 3152384 ['transformer_layer_6[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 512) 0 ['transformer_layer_7[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 512) 262656 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 41,373,184\n", - "Trainable params: 41,373,184\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertMedium(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - 
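"# The TF checkpoint stores each attention projection as one fused 2-D kernel\n", - "# of shape (hidden_dim, hidden_dim); KerasNLP's attention layers keep per-head\n", - "# kernels of shape (hidden_dim, num_heads, head_dim) and biases of shape\n", - "# (num_heads, head_dim), hence the reshapes below.\n", -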
"model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "source": [ - "## Load Bert Medium from TF-Hub.\n", - "\n", - "These weights have been released and validated by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 512].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 512].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " 
tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "06d8485d-84f0-47ea-eab9-7d229da581ce", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(, )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1fbc2ea7-577f-491d-97d1-2becf294eb94", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertMedium checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertMedium(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0c3ae2e3-f23e-46a3-a7bc-2b9e29e72a21", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "68d4f6b4-dd02-423c-ca45-c18ede23a3dc", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0e9970fb-cb39-4255-c1bb-65d891a8b845", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bb990e1184ec6b6185450c73833cd661 bert_medium_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertMedium(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/bert_small_uncased_en.ipynb b/tools/checkpoint_conversion/bert_small_uncased_en.ipynb deleted file mode 100644 index ded288b2d3..0000000000 --- a/tools/checkpoint_conversion/bert_small_uncased_en.ipynb +++ /dev/null @@ -1,895 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Szd6xKUd2tIE", - "outputId": "33a180e3-462f-4f6a-b641-0104f18f96de" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.8 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 48.2 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 48.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 52.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 64.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 56.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 68.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 38.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 74.8 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 10.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.2 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 50.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 56.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 69.5 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "DmVlNiSexzR7" - }, - "outputs": [], - "source": [ - "MODEL_TYPE = \"bert_small\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-4_H-512_A-8\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 4\n", - "NUM_ATTN_HEADS = 8\n", - "EMBEDDING_SIZE = 512" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "37545bbe-70bc-4824-f96e-3eda6b99e709" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-4_H-512_A-8.zip\n", - "107814641/107814641 [==============================] - 1s 0us/step\n" - ] - } - ], - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "8ca6d690-8f73-45e7-a0e9-16b5901a4297" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Archive: bert_small_uncased\n", - " inflating: uncased_L-4_H-512_A-8/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-4_H-512_A-8/bert_config.json \n", - " inflating: uncased_L-4_H-512_A-8/vocab.txt \n", - " inflating: uncased_L-4_H-512_A-8/bert_model.ckpt.index \n" - ] - } - ], - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "OGij7IQU4rJL" - }, - "outputs": [], - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "405b2156-e025-4848-8aee-2589c4354311" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bert/embeddings/LayerNorm/beta [512]\n", - "bert/embeddings/LayerNorm/gamma [512]\n", - "bert/embeddings/position_embeddings [512, 512]\n", - "bert/embeddings/token_type_embeddings [2, 512]\n", - "bert/embeddings/word_embeddings [30522, 512]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [512]\n", - 
"bert/encoder/layer_0/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/attention/output/dense/bias [512]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/key/bias [512]\n", - "bert/encoder/layer_0/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/query/bias [512]\n", - "bert/encoder/layer_0/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_0/attention/self/value/bias [512]\n", - "bert/encoder/layer_0/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_0/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_0/output/dense/bias [512]\n", - "bert/encoder/layer_0/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/attention/output/dense/bias [512]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/key/bias [512]\n", - "bert/encoder/layer_1/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/query/bias [512]\n", - "bert/encoder/layer_1/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_1/attention/self/value/bias [512]\n", - "bert/encoder/layer_1/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_1/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_1/output/dense/bias [512]\n", - "bert/encoder/layer_1/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/attention/output/dense/bias [512]\n", - "bert/encoder/layer_2/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/key/bias [512]\n", - "bert/encoder/layer_2/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/query/bias [512]\n", - "bert/encoder/layer_2/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_2/attention/self/value/bias [512]\n", - "bert/encoder/layer_2/attention/self/value/kernel [512, 512]\n", - "bert/encoder/layer_2/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_2/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_2/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_2/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_2/output/dense/bias [512]\n", - "bert/encoder/layer_2/output/dense/kernel [2048, 512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/attention/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/attention/output/dense/bias [512]\n", - "bert/encoder/layer_3/attention/output/dense/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/key/bias [512]\n", - "bert/encoder/layer_3/attention/self/key/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/query/bias [512]\n", - "bert/encoder/layer_3/attention/self/query/kernel [512, 512]\n", - "bert/encoder/layer_3/attention/self/value/bias [512]\n", - "bert/encoder/layer_3/attention/self/value/kernel [512, 512]\n", - 
"bert/encoder/layer_3/intermediate/dense/bias [2048]\n", - "bert/encoder/layer_3/intermediate/dense/kernel [512, 2048]\n", - "bert/encoder/layer_3/output/LayerNorm/beta [512]\n", - "bert/encoder/layer_3/output/LayerNorm/gamma [512]\n", - "bert/encoder/layer_3/output/dense/bias [512]\n", - "bert/encoder/layer_3/output/dense/kernel [2048, 512]\n", - "bert/pooler/dense/bias [512]\n", - "bert/pooler/dense/kernel [512, 512]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [512]\n", - "cls/predictions/transform/LayerNorm/gamma [512]\n", - "cls/predictions/transform/dense/bias [512]\n", - "cls/predictions/transform/dense/kernel [512, 512]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 512]\n" - ] - } - ], - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertSmall model with KerasNLP." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "g1kp1M9b6hdU", - "outputId": "24c6056a-ef5a-426c-b172-f5aeec699fad" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 512) 15627264 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 512) 262144 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 512) 1024 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 512) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 512) 1024 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 512) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 512) 3152384 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 512) 3152384 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_2 (Transform (None, None, 512) 3152384 ['transformer_layer_1[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_3 (Transform (None, None, 512) 3152384 ['transformer_layer_2[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 512) 0 ['transformer_layer_3[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 512) 262656 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 28,763,648\n", 
- "Trainable params: 28,763,648\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertSmall(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " 
)._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ByCEoIyn-_Ld" - }, - "source": [ - "## Load Bert Small from TF-Hub.\n", - "\n", - "These weights have been released and validated by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded to TensorFlow Hub. 
See run_classifier_with_tfhub.py for an example of how to use the TF Hub module, or run an example in the browser on Colab.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "outputs": [], - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 512].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 512].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "iAubWsWj9qtg" - }, - "outputs": [], - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HzUii8Tp9qth", - "outputId": "97bd5f3c-8440-4f86-f87d-2940370463e7" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(, )" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "II0akvof9qth", - "outputId": "e7b49897-abb3-45da-9ad8-45a62b22a37d" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertSmall checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertSmall(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OD0B0UxN-QdB", - "outputId": "6d4feea6-bc13-4fc2-a9b0-33e929ff9d0c" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "q0K9JAY5-QdD", - "outputId": "697099b1-de1e-4393-a272-12144d29d155" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "228209" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-jVECpzp-QdD", - "outputId": "7e75dcb5-b985-4b44-801c-e8d63d34f035" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "08632c9479b034f342ba2c2b7afba5f7 bert_small_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "wTd-5vUyVG0Q", - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertSmall(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": 
"https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "keras_nlp_output_cloud[0, :10]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "gpuClass": "standard", - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb b/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb deleted file mode 100644 index d1fb1119cd..0000000000 --- a/tools/checkpoint_conversion/bert_tiny_uncased_en.ipynb +++ /dev/null @@ -1,858 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vGp_yrJi5Ehf" - }, - "source": [ - "## Install deps" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "Szd6xKUd2tIE", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "5107d4a7-7205-448d-8989-9d81d01b7195" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 511.7 MB 6.7 kB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 2.1 MB 47.0 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 4.6 MB 49.1 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 5.8 MB 58.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 438 kB 69.4 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.6 MB 58.6 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 99 kB 11.7 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 636 kB 69.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.3 MB 56.7 MB/s \n", - "\u001b[K 
|ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 352 kB 74.5 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 43 kB 2.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 116 kB 68.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 1.1 MB 61.3 MB/s \n", - "\u001b[K |ā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆā–ˆ| 238 kB 67.1 MB/s \n", - "\u001b[?25h Building wheel for keras-nlp (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for py-cpuinfo (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for seqeval (setup.py) ... \u001b[?25l\u001b[?25hdone\n" - ] - } - ], - "source": [ - "!pip install git+https://github.com/abheesht17/keras-nlp.git@more-bert-variants tensorflow tf-models-official tensorflow_hub --upgrade --quiet" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "JsbnAdSz5DzZ" - }, - "outputs": [], - "source": [ - "import json\n", - "import os\n", - "\n", - "import keras_nlp\n", - "import tensorflow as tf\n", - "from tensorflow import keras\n", - "\n", - "import tensorflow_hub as hub" - ] - }, - { - "cell_type": "code", - "source": [ - "MODEL_TYPE = \"bert_tiny\"\n", - "MODEL_SUFFIX = \"uncased\"\n", - "MODEL_SPEC_STR = \"L-2_H-128_A-2\"\n", - "MODEL_NAME = f\"{MODEL_TYPE}_{MODEL_SUFFIX}\"\n", - "VOCAB_SIZE = 30522\n", - "NUM_LAYERS = 2\n", - "NUM_ATTN_HEADS = 2\n", - "EMBEDDING_SIZE = 128" - ], - "metadata": { - "id": "DmVlNiSexzR7" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "# BERT ckpt https://github.com/google-research/bert/blob/master/README.md.\n", - "zip_path = f\"\"\"https://storage.googleapis.com/bert_models/2020_02_20/{MODEL_SUFFIX}_{MODEL_SPEC_STR}.zip\"\"\"\n", - "zip_file = keras.utils.get_file(\n", - " f\"\"\"/content/{MODEL_NAME}\"\"\",\n", - " zip_path,\n", - " extract=True,\n", - " archive_format=\"zip\",\n", - ")" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "FXid57wR3tE5", - "outputId": "8e952e27-282d-440e-ba72-b526fe586b75" - }, - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip\n", - "16529104/16529104 [==============================] - 0s 0us/step\n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "!unzip \"\"\"{MODEL_NAME}\"\"\" -d \"\"\"{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\"\"" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "j-VBpV0n4VA3", - "outputId": "15e3842f-e312-4e79-829c-7fc136e29f60" - }, - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Archive: bert_tiny_uncased\n", - " inflating: uncased_L-2_H-128_A-2/bert_model.ckpt.data-00000-of-00001 \n", - " inflating: uncased_L-2_H-128_A-2/bert_config.json \n", - " inflating: uncased_L-2_H-128_A-2/vocab.txt \n", - " inflating: uncased_L-2_H-128_A-2/bert_model.ckpt.index \n" - ] - } - ] - }, - { - "cell_type": "code", - "source": [ - "# BERT paths.\n", - "extract_dir = f\"/content/{MODEL_SUFFIX}_{MODEL_SPEC_STR}\"\n", - "vocab_path = os.path.join(extract_dir, \"vocab.txt\")\n", - "checkpoint_path = 
os.path.join(extract_dir, \"bert_model.ckpt\")\n", - "config_path = os.path.join(extract_dir, \"bert_config.json\")" - ], - "metadata": { - "id": "OGij7IQU4rJL" - }, - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "vars = tf.train.list_variables(checkpoint_path)\n", - "weights = {}\n", - "for name, shape in vars:\n", - " print(name, shape)\n", - " weight = tf.train.load_variable(checkpoint_path, name)\n", - " weights[name] = weight" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RC6DqSfo4iPR", - "outputId": "01e19274-afa7-45ae-f4f8-2c46538ed97c" - }, - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "bert/embeddings/LayerNorm/beta [128]\n", - "bert/embeddings/LayerNorm/gamma [128]\n", - "bert/embeddings/position_embeddings [512, 128]\n", - "bert/embeddings/token_type_embeddings [2, 128]\n", - "bert/embeddings/word_embeddings [30522, 128]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_0/attention/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_0/attention/output/dense/bias [128]\n", - "bert/encoder/layer_0/attention/output/dense/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/key/bias [128]\n", - "bert/encoder/layer_0/attention/self/key/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/query/bias [128]\n", - "bert/encoder/layer_0/attention/self/query/kernel [128, 128]\n", - "bert/encoder/layer_0/attention/self/value/bias [128]\n", - "bert/encoder/layer_0/attention/self/value/kernel [128, 128]\n", - "bert/encoder/layer_0/intermediate/dense/bias [512]\n", - "bert/encoder/layer_0/intermediate/dense/kernel [128, 512]\n", - "bert/encoder/layer_0/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_0/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_0/output/dense/bias [128]\n", - "bert/encoder/layer_0/output/dense/kernel [512, 128]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_1/attention/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_1/attention/output/dense/bias [128]\n", - "bert/encoder/layer_1/attention/output/dense/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/key/bias [128]\n", - "bert/encoder/layer_1/attention/self/key/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/query/bias [128]\n", - "bert/encoder/layer_1/attention/self/query/kernel [128, 128]\n", - "bert/encoder/layer_1/attention/self/value/bias [128]\n", - "bert/encoder/layer_1/attention/self/value/kernel [128, 128]\n", - "bert/encoder/layer_1/intermediate/dense/bias [512]\n", - "bert/encoder/layer_1/intermediate/dense/kernel [128, 512]\n", - "bert/encoder/layer_1/output/LayerNorm/beta [128]\n", - "bert/encoder/layer_1/output/LayerNorm/gamma [128]\n", - "bert/encoder/layer_1/output/dense/bias [128]\n", - "bert/encoder/layer_1/output/dense/kernel [512, 128]\n", - "bert/pooler/dense/bias [128]\n", - "bert/pooler/dense/kernel [128, 128]\n", - "cls/predictions/output_bias [30522]\n", - "cls/predictions/transform/LayerNorm/beta [128]\n", - "cls/predictions/transform/LayerNorm/gamma [128]\n", - "cls/predictions/transform/dense/bias [128]\n", - "cls/predictions/transform/dense/kernel [128, 128]\n", - "cls/seq_relationship/output_bias [2]\n", - "cls/seq_relationship/output_weights [2, 128]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FTIwxvcB6hc-" - }, - "source": [ - "## Load BertTiny model with KerasNLP." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "89fd51e3-8fa8-4045-de21-ec90a4d515dd", - "id": "g1kp1M9b6hdU" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Model: \"bert_custom\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " token_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " token_embedding (Embedding) (None, None, 128) 3906816 ['token_ids[0][0]'] \n", - " \n", - " segment_ids (InputLayer) [(None, None)] 0 [] \n", - " \n", - " position_embedding (PositionEm (None, None, 128) 65536 ['token_embedding[0][0]'] \n", - " bedding) \n", - " \n", - " segment_embedding (Embedding) (None, None, 128) 256 ['segment_ids[0][0]'] \n", - " \n", - " add (Add) (None, None, 128) 0 ['token_embedding[0][0]', \n", - " 'position_embedding[0][0]', \n", - " 'segment_embedding[0][0]'] \n", - " \n", - " embeddings_layer_norm (LayerNo (None, None, 128) 256 ['add[0][0]'] \n", - " rmalization) \n", - " \n", - " embeddings_dropout (Dropout) (None, None, 128) 0 ['embeddings_layer_norm[0][0]'] \n", - " \n", - " padding_mask (InputLayer) [(None, None)] 0 [] \n", - " \n", - " transformer_layer_0 (Transform (None, None, 128) 198272 ['embeddings_dropout[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " transformer_layer_1 (Transform (None, None, 128) 198272 ['transformer_layer_0[0][0]', \n", - " erEncoder) 'padding_mask[0][0]'] \n", - " \n", - " tf.__operators__.getitem (Slic (None, 128) 0 ['transformer_layer_1[0][0]'] \n", - " ingOpLambda) \n", - " \n", - " pooled_dense (Dense) (None, 128) 16512 ['tf.__operators__.getitem[0][0]'\n", - " ] \n", - " \n", - "==================================================================================================\n", - "Total params: 4,385,920\n", - "Trainable params: 4,385,920\n", - "Non-trainable params: 0\n", - "__________________________________________________________________________________________________\n" - ] - } - ], - "source": [ - "model = keras_nlp.models.BertTiny(vocabulary_size=VOCAB_SIZE)\n", - "model.summary()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PxG_evKB6hdU" - }, - "source": [ - "## Convert Weights" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "VGEx-zLM6hdV" - }, - "outputs": [], - "source": [ - "model.get_layer(\"token_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/word_embeddings\"]\n", - ")\n", - "model.get_layer(\"position_embedding\").position_embeddings.assign(\n", - " weights[\"bert/embeddings/position_embeddings\"]\n", - ")\n", - "model.get_layer(\"segment_embedding\").embeddings.assign(\n", - " weights[\"bert/embeddings/token_type_embeddings\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").gamma.assign(\n", - " weights[\"bert/embeddings/LayerNorm/gamma\"]\n", - ")\n", - "model.get_layer(\"embeddings_layer_norm\").beta.assign(\n", - " weights[\"bert/embeddings/LayerNorm/beta\"]\n", - ")\n", - "\n", - "for i in range(model.num_layers):\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, 
-1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._key_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/key/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._query_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/query/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/kernel\"].reshape(\n", - " (EMBEDDING_SIZE, NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._value_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/self/value/bias\"].reshape(\n", - " (NUM_ATTN_HEADS, -1)\n", - " )\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.kernel.assign(\n", - " weights[\n", - " f\"bert/encoder/layer_{i}/attention/output/dense/kernel\"\n", - " ].reshape((NUM_ATTN_HEADS, -1, EMBEDDING_SIZE))\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer._output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._self_attention_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/attention/output/LayerNorm/beta\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_intermediate_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/intermediate/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.kernel.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/kernel\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_output_dense.bias.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/dense/bias\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.gamma.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/gamma\"]\n", - " )\n", - " model.get_layer(\n", - " f\"transformer_layer_{i}\"\n", - " )._feedforward_layer_norm.beta.assign(\n", - " weights[f\"bert/encoder/layer_{i}/output/LayerNorm/beta\"]\n", - " )\n", - "\n", - "model.get_layer(\"pooled_dense\").kernel.assign(\n", - " weights[\"bert/pooler/dense/kernel\"]\n", - ")\n", - "model.get_layer(\"pooled_dense\").bias.assign(weights[\"bert/pooler/dense/bias\"])\n", - "pass" - ] - }, - { - "cell_type": "markdown", 
- "source": [ - "## Load Bert Tiny from TF-Hub.\n", - "\n", - "These weights have been ratified by the authors of BERT: https://github.com/google-research/bert/blob/master/README.md.\n", - "\n", - "### BERT README statement:\n", - "\n", - "\"***** New February 7th, 2019: TfHub Module *****\n", - "BERT has been uploaded toĀ TensorFlow Hub. SeeĀ run_classifier_with_tfhub.pyĀ for an example of how to use the TF Hub module, or run an example in the browser onĀ Colab.\"" - ], - "metadata": { - "id": "ByCEoIyn-_Ld" - } - }, - { - "cell_type": "code", - "source": [ - "text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)\n", - "\n", - "preprocessor = hub.load(\n", - " \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\"\n", - ")\n", - "tokenizer = hub.KerasLayer(preprocessor.tokenize, name=\"tokenizer\")\n", - "tokenized_text = tokenizer(text_input)\n", - "\n", - "packer = hub.KerasLayer(\n", - " preprocessor.bert_pack_inputs, arguments=dict(seq_length=512), name=\"packer\"\n", - ")\n", - "encoder_inputs = packer([tokenized_text])\n", - "\n", - "encoder = hub.KerasLayer(\n", - " f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_{MODEL_SPEC_STR}/2\",\n", - " trainable=True,\n", - ")\n", - "outputs = encoder(encoder_inputs)\n", - "pooled_output = outputs[\"pooled_output\"] # [batch_size, 1024].\n", - "sequence_output = outputs[\"sequence_output\"] # [batch_size, seq_length, 1024].\n", - "\n", - "embedding_model = tf.keras.Model(text_input, (pooled_output, sequence_output))" - ], - "metadata": { - "id": "hQ0lMSluxMx1" - }, - "execution_count": 10, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "def preprocess(x):\n", - " tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(\n", - " vocabulary=vocab_path, lowercase=False\n", - " )\n", - " packer = keras_nlp.layers.MultiSegmentPacker(\n", - " sequence_length=model.max_sequence_length,\n", - " start_value=tokenizer.token_to_id(\"[CLS]\"),\n", - " end_value=tokenizer.token_to_id(\"[SEP]\"),\n", - " )\n", - " return packer(tokenizer(x))\n", - "\n", - "\n", - "token_ids, segment_ids = preprocess([\"the quick brown fox.\"])" - ], - "metadata": { - "id": "iAubWsWj9qtg" - }, - "execution_count": 11, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "-JvyB96k9qtg" - }, - "outputs": [], - "source": [ - "keras_nlp_output = model(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "hub_pooled_output, hub_sequence_output = embedding_model(\n", - " tf.constant([\"the quick brown fox.\"])\n", - ")" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output[\"pooled_output\"][0, :10], hub_pooled_output[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "aea2b522-a267-4f9e-ffcc-e5160d4ad04d", - "id": "HzUii8Tp9qth" - }, - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 13 - } - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a669a92c-cb29-4673-e1cb-5f8ad7ab2c23", - "id": "II0akvof9qth" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 14 - } - ], - "source": [ - "# Very close! 
Though not 100% exact.\n", - "(\n", - " tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - hub_pooled_output),\n", - " tf.reduce_mean(keras_nlp_output[\"sequence_output\"] - hub_sequence_output),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "78sejS0B-Qce" - }, - "outputs": [], - "source": [ - "# Save BertTiny checkpoint\n", - "model.save_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "bVlbhdZX-QdA" - }, - "outputs": [], - "source": [ - "model2 = keras_nlp.models.BertTiny(vocabulary_size=VOCAB_SIZE)\n", - "model2.load_weights(f\"\"\"{MODEL_NAME}.h5\"\"\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "a83056b5-9673-4e88-b81b-a92209c2305f", - "id": "OD0B0UxN-QdB" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(,\n", - " )" - ] - }, - "metadata": {}, - "execution_count": 17 - } - ], - "source": [ - "# Same output from loaded checkpoint\n", - "keras_nlp_output2 = model2(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")\n", - "\n", - "(\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"pooled_output\"] - keras_nlp_output2[\"pooled_output\"]\n", - " ),\n", - " tf.reduce_mean(\n", - " keras_nlp_output[\"sequence_output\"]\n", - " - keras_nlp_output2[\"sequence_output\"]\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0dc1fcf7-abf6-472b-acb3-5cafc8cf85cf", - "id": "q0K9JAY5-QdD" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "228209" - ] - }, - "metadata": {}, - "execution_count": 18 - } - ], - "source": [ - "# Save vocab file as well\n", - "vocab_info = tf.io.gfile.GFile(vocab_path).read()\n", - "f = open(\"vocab.txt\", \"w\")\n", - "f.write(vocab_info)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "07cae6fa-240b-473b-9b10-3734b9da0593", - "id": "-jVECpzp-QdD" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "c2b29fcbf8f814a0812e4ab89ef5c068 bert_tiny_uncased.h5\n" - ] - } - ], - "source": [ - "# Get MD5 of model\n", - "!md5sum \"\"\"{MODEL_NAME}.h5\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "z_0iMTCdFl8t" - }, - "outputs": [], - "source": [ - "# Upload model to drive\n", - "# from google.colab import drive\n", - "# drive.mount('/content/drive')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wTd-5vUyVG0Q", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "083acd06-b55f-4b07-82e4-3f71f866500e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from https://storage.googleapis.com/keras-nlp/models/bert_large_en_cased/model.h5\n", - "1334759464/1334759464 [==============================] - 41s 0us/step\n" - ] - } - ], - "source": [ - "# Check uploaded model once added to repo\n", - "model_cloud = keras_nlp.models.BertTiny(weights=\"uncased_en\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": 
"https://localhost:8080/" - }, - "id": "zs5x_f6GVdNY", - "outputId": "9ea2098f-4c71-4d8c-9991-6672b1de9f34" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 20 - } - ], - "source": [ - "# Same output from cloud model\n", - "keras_nlp_output_cloud = model_cloud(\n", - " {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - ")[\"pooled_output\"]\n", - "tf.reduce_mean(keras_nlp_output[\"pooled_output\"] - keras_nlp_output_cloud)" - ] - }, - { - "cell_type": "code", - "source": [ - "keras_nlp_output_cloud[0, :10]" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "RAwrhAcSzHWa", - "outputId": "92e1ecc4-b783-4f60-f65f-c2895ba1218f" - }, - "execution_count": null, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "code", - "source": [], - "metadata": { - "id": "S2JGnbTYaeGc" - }, - "execution_count": null, - "outputs": [] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU", - "gpuClass": "standard" - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/tools/checkpoint_conversion/convert_bloom_checkpoints.py b/tools/checkpoint_conversion/convert_bloom_checkpoints.py index 38acd099cf..a9e833d1b0 100644 --- a/tools/checkpoint_conversion/convert_bloom_checkpoints.py +++ b/tools/checkpoint_conversion/convert_bloom_checkpoints.py @@ -15,19 +15,16 @@ import json import os -os.environ["KERAS_BACKEND"] = "torch" -os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import huggingface_hub +import numpy as np +import transformers +from absl import app +from absl import flags -import huggingface_hub # noqa: E402 -import numpy as np # noqa: E402 -import torch # noqa: E402 -import transformers # noqa: E402 -from absl import app # noqa: E402 -from absl import flags # noqa: E402 - -import keras_nlp # noqa: E402 -from keras_nlp.models import BloomBackbone # noqa: E402 -from keras_nlp.models import BloomTokenizer # noqa: E402 +import keras_nlp +from keras_nlp.models import BloomBackbone +from keras_nlp.models import BloomPreprocessor +from keras_nlp.models import BloomTokenizer FLAGS = flags.FLAGS @@ -38,22 +35,46 @@ "bloom_3b_multi": "bigscience/bloom-3b", "bloom_7b_multi": "bigscience/bloom-7b1", "bloom_176b_multi": "bigscience/bloom", + # Multitask finetuned on xP3 (Crosslingual Public Pool of Prompts) https://huggingface.co/datasets/bigscience/xP3 + # xP3 is a mixture of 13 training tasks in 46 languages with English prompts + "bloomz_560m_multi": "bigscience/bloomz-560m", + "bloomz_1.1b_multi": "bigscience/bloomz-1b1", + "bloomz_1.7b_multi": "bigscience/bloomz-1b7", + "bloomz_3b_multi": "bigscience/bloomz-3b", + "bloomz_7b_multi": "bigscience/bloomz-7b1", + "bloomz_176b_multi": "bigscience/bloomz", + # Multitask finetuned on xP3mt + # (Crosslingual Public Pool of Prompts machine-translated) https://huggingface.co/datasets/bigscience/xP3 + # xP3mt is Mixture of 13 training tasks in 46 languages with prompts in 20 + # languages (machine-translated from English) + "bloomz_7b_mt": "bigscience/bloomz-7b1-mt", + "bloomz_176b_mt": "bigscience/bloomz-mt", + # Multitask finetuned on P3 (Public 
Pool of Prompts) https://huggingface.co/datasets/Muennighoff/P3
+    # P3 is a mixture of 8 training tasks with English-only prompts
+    "bloomz_7b_p3": "bigscience/bloomz-7b1-p3",
+    "bloomz_176b_p3": "bigscience/bloomz-p3",
 }
 
 EXTRACT_DIR = "./model"
 
 flags.DEFINE_string(
-    "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}'
+    "preset", None, f'Must be one of {", ".join(PRESET_MAP.keys())}'
 )
 flags.mark_flag_as_required("preset")
+flags.DEFINE_boolean(
+    "validate_only",
+    False,
+    "Validate the output of a preset that has already been uploaded. "
+    "No weights conversion will happen.",
+)
 
 
 def download_hf_model(hf_model_name):
     hf_model_dir = huggingface_hub.snapshot_download(
         repo_id=hf_model_name,
         allow_patterns=["*.json", "*.bin"],
-        ignore_patterns=["onnx/*"],
+        ignore_patterns=["*/*"],
         local_dir=EXTRACT_DIR,
     )
@@ -99,51 +120,61 @@ def convert_weights(keras_model, hf_model):
     # assign huggingface weights to the keras model.
     # Embedding layer.
     keras_model.get_layer("token_embedding").embeddings.assign(
-        hf_wts["word_embeddings.weight"]
+        hf_wts["word_embeddings.weight"].detach().numpy()
     )
     # LayerNorm.
     keras_model.get_layer("token_embedding_layernorm").gamma.assign(
-        hf_wts["word_embeddings_layernorm.weight"]
+        hf_wts["word_embeddings_layernorm.weight"].detach().numpy()
     )
     keras_model.get_layer("token_embedding_layernorm").beta.assign(
-        hf_wts["word_embeddings_layernorm.bias"]
+        hf_wts["word_embeddings_layernorm.bias"].detach().numpy()
     )
 
-    keras_model.get_layer("final_layernorm").gamma.assign(hf_wts["ln_f.weight"])
-    keras_model.get_layer("final_layernorm").beta.assign(hf_wts["ln_f.bias"])
+    keras_model.get_layer("final_layernorm").gamma.assign(
+        hf_wts["ln_f.weight"].detach().numpy()
+    )
+    keras_model.get_layer("final_layernorm").beta.assign(
+        hf_wts["ln_f.bias"].detach().numpy()
+    )
 
     # Decoder layers.
     for i in range(num_layers):
         decoder_layer = keras_model.get_layer(f"transformer_layer_{i}")
         # LayrNorm.
         decoder_layer._pre_attention_layernorm.gamma.assign(
-            hf_wts[f"h.{i}.input_layernorm.weight"]
+            hf_wts[f"h.{i}.input_layernorm.weight"].detach().numpy()
         )
         decoder_layer._pre_attention_layernorm.beta.assign(
-            hf_wts[f"h.{i}.input_layernorm.bias"]
+            hf_wts[f"h.{i}.input_layernorm.bias"].detach().numpy()
        )
         decoder_layer._post_attention_layernorm.gamma.assign(
-            hf_wts[f"h.{i}.post_attention_layernorm.weight"]
+            hf_wts[f"h.{i}.post_attention_layernorm.weight"].detach().numpy()
         )
         decoder_layer._post_attention_layernorm.beta.assign(
-            hf_wts[f"h.{i}.post_attention_layernorm.bias"]
+            hf_wts[f"h.{i}.post_attention_layernorm.bias"].detach().numpy()
         )
 
         # Attention layer.
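+        # The HF checkpoint stores the attention projection as a single
+        # fused `query_key_value` matrix. The conversion below moves it
+        # to numpy and reshapes the transposed kernel to
+        # (hidden_dim, num_heads, 3, head_dim), so that slicing index
+        # 0/1/2 on the third axis recovers the per-head query/key/value
+        # kernels. Illustrative shapes only, assuming bloom_560m-like
+        # sizes (hidden_dim=1024, num_heads=16, head_dim=64):
+        #   fused kernel: (1024, 3072) -> (1024, 16, 3, 64)
+        #   query kernel: fused[..., 0, :] -> (1024, 16, 64)
+        #   fused bias:   (3072,) -> (16, 3, 64)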
attention_layer = decoder_layer._self_attention_layer - fused_qkv_kernal = hf_wts[ - f"h.{i}.self_attention.query_key_value.weight" - ].T - fused_qkv_kernal = fused_qkv_kernal.view( + fused_qkv_kernal = ( + hf_wts[f"h.{i}.self_attention.query_key_value.weight"] + .T.detach() + .numpy() + ) + fused_qkv_kernal = fused_qkv_kernal.reshape( hidden_dim, num_heads, 3, head_dim ) query_kernal = fused_qkv_kernal[..., 0, :] key_kernal = fused_qkv_kernal[..., 1, :] value_kernl = fused_qkv_kernal[..., 2, :] - fused_qkv_bais = hf_wts[f"h.{i}.self_attention.query_key_value.bias"] - fused_qkv_bais = fused_qkv_bais.view(num_heads, 3, head_dim) + fused_qkv_bais = ( + hf_wts[f"h.{i}.self_attention.query_key_value.bias"] + .detach() + .numpy() + ) + fused_qkv_bais = fused_qkv_bais.reshape(num_heads, 3, head_dim) query_bais = fused_qkv_bais[:, 0, :] key_bais = fused_qkv_bais[:, 1, :] value_bais = fused_qkv_bais[:, 2, :] @@ -156,24 +187,24 @@ def convert_weights(keras_model, hf_model): attention_layer._value_dense.bias.assign(value_bais) attention_layer._output_dense.kernel.assign( - hf_wts[f"h.{i}.self_attention.dense.weight"].T + hf_wts[f"h.{i}.self_attention.dense.weight"].T.detach().numpy() ) attention_layer._output_dense.bias.assign( - hf_wts[f"h.{i}.self_attention.dense.bias"] + hf_wts[f"h.{i}.self_attention.dense.bias"].detach().numpy() ) # mlp. decoder_layer._mlp_intermediate_dense.kernel.assign( - hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T + hf_wts[f"h.{i}.mlp.dense_h_to_4h.weight"].T.detach().numpy() ) decoder_layer._mlp_intermediate_dense.bias.assign( - hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"] + hf_wts[f"h.{i}.mlp.dense_h_to_4h.bias"].detach().numpy() ) decoder_layer._mlp_output_dense.kernel.assign( - hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T + hf_wts[f"h.{i}.mlp.dense_4h_to_h.weight"].T.detach().numpy() ) decoder_layer._mlp_output_dense.bias.assign( - hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"] + hf_wts[f"h.{i}.mlp.dense_4h_to_h.bias"].detach().numpy() ) @@ -185,20 +216,21 @@ def validate_output( ): input_str = ["the quick brown fox ran, galloped and jumped."] - # KerasNLP - token_ids = torch.tensor(keras_tokenizer(input_str)) - padding_mask = token_ids != 3 - keras_model_input = { - "token_ids": token_ids, - "padding_mask": padding_mask, - } - keras_model_outputs = keras_model.predict(keras_model_input) - + # HuggingFace hf_model_input = hf_tokenizer(input_str, return_tensors="pt") - hf_model_outputs = hf_model(**hf_model_input).last_hidden_state hf_model_outputs = hf_model_outputs.detach().numpy() + # KerasNLP + preprocessor = BloomPreprocessor( + tokenizer=keras_tokenizer, + sequence_length=hf_model_outputs.shape[1], + add_end_token=False, + add_start_token=False, + ) + keras_model_input = preprocessor(input_str) + keras_model_outputs = keras_model.predict(keras_model_input) + # Comparing the outputs. print("šŸ”¶ KerasNLP output:", keras_model_outputs[0, 0, :10]) print("šŸ”¶ HF output:", hf_model_outputs[0, 0, :10]) @@ -207,41 +239,87 @@ def validate_output( def main(_): preset = FLAGS.preset - assert ( preset in PRESET_MAP.keys() - ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}' + ), f'Invalid preset {preset}. 
Must be one of {", ".join(PRESET_MAP.keys())}'
+
+    validate_only = FLAGS.validate_only
+
+    if not validate_only:
+        print(f"āœ… Converting {preset}")
-    print(f"āœ… Coverting {preset}")
+        hf_model_name = PRESET_MAP[preset]
+        hf_model_dir = download_hf_model(hf_model_name)
+        print("āœ… Huggingface model downloaded from hub")
-    hf_model_name = PRESET_MAP[preset]
-    hf_model_dir = download_hf_model(hf_model_name)
-    print("āœ… Huggingface model downloaded from hub")
+        hf_model = transformers.BloomModel.from_pretrained(
+            hf_model_dir,
+        )
+        hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained(
+            hf_model_dir
+        )
+        print("āœ… Huggingface model loaded")
-    hf_model = transformers.BloomModel.from_pretrained(hf_model_dir)
-    hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained(hf_model_dir)
-    print("āœ… Huggingface model loaded")
+        keras_model = convert_model(hf_model)
+        keras_tokenizer = convert_tokenizer(hf_model_dir)
+        print("āœ… Keras model loaded")
-    keras_model = convert_model(hf_model)
-    keras_tokenizer = convert_tokenizer(hf_model_dir)
-    print("āœ… Keras model loaded")
+        convert_weights(keras_model, hf_model)
+        print("āœ… Weights converted")
-    convert_weights(keras_model, hf_model)
-    print("āœ… Weights converted")
+        validate_output(
+            hf_model,
+            keras_model,
+            hf_tokenizer,
+            keras_tokenizer,
+        )
+        print("āœ… Numerics validated")
-    validate_output(
-        hf_model,
-        keras_model,
-        hf_tokenizer,
-        keras_tokenizer,
-    )
-    print("āœ… Numerics validated")
+        # Delete huggingface model
+        del hf_model
+        del hf_tokenizer
-    keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
-    keras_nlp.src.utils.preset_utils.save_to_preset(
-        keras_tokenizer, preset, config_filename="tokenizer.json"
-    )
-    print("āœ… Preset saved")
+        # Save float32 keras preset
+        keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+
+        # Delete float32 Keras model
+        del keras_model
+
+        # Load the model in float16 precision
+        preset_path = os.path.join(os.getcwd(), preset)
+        keras_model = BloomBackbone.from_preset(preset_path, dtype="float16")
+
+        # Save float16 keras model
+        keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+        keras_nlp.src.utils.preset_utils.save_to_preset(
+            keras_tokenizer, preset, config_filename="tokenizer.json"
+        )
+
+        print("āœ… Preset saved")
+    else:
+        print(f"āœ… Validating {preset}")
+
+        hf_model_name = PRESET_MAP[preset]
+        hf_model_dir = download_hf_model(hf_model_name)
+        print("āœ… Huggingface model downloaded from hub")
+
+        hf_model = transformers.BloomModel.from_pretrained(
+            hf_model_dir,
+        )
+        hf_tokenizer = transformers.BloomTokenizerFast.from_pretrained(
+            hf_model_dir
+        )
+
+        keras_model = BloomBackbone.from_preset(preset)
+        keras_tokenizer = BloomTokenizer.from_preset(preset)
+
+        validate_output(
+            hf_model,
+            keras_model,
+            hf_tokenizer,
+            keras_tokenizer,
+        )
+        print("āœ… Numerics validated")
 
 
 if __name__ == "__main__":
diff --git a/tools/checkpoint_conversion/convert_electra_checkpoints.py b/tools/checkpoint_conversion/convert_electra_checkpoints.py
new file mode 100644
index 0000000000..8bbafb1305
--- /dev/null
+++ b/tools/checkpoint_conversion/convert_electra_checkpoints.py
@@ -0,0 +1,278 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Electra weights conversion script.
+"""
+
+import json
+import os
+
+os.environ["KERAS_BACKEND"] = "torch"
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+import huggingface_hub  # noqa: E402
+import numpy as np  # noqa: E402
+import tensorflow as tf  # noqa: E402
+import transformers  # noqa: E402
+from absl import app  # noqa: E402
+from absl import flags  # noqa: E402
+
+import keras_nlp  # noqa: E402
+from keras_nlp.utils.preset_utils import save_to_preset  # noqa: E402
+
+PRESET_MAP = {
+    "electra_base_generator_en": "google/electra-base-generator",
+    "electra_small_generator_en": "google/electra-small-generator",
+    "electra_base_discriminator_en": "google/electra-base-discriminator",
+    "electra_small_discriminator_en": "google/electra-small-discriminator",
+    "electra_large_discriminator_en": "google/electra-large-discriminator",
+    "electra_large_generator_en": "google/electra-large-generator",
+}
+
+EXTRACT_DIR = "./model"
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "preset",
+    "electra_base_discriminator_en",
+    f'Must be one of {", ".join(PRESET_MAP)}',
+)
+flags.mark_flag_as_required("preset")
+
+
+def download_hf_model(hf_model_name):
+    hf_model_dir = huggingface_hub.snapshot_download(
+        repo_id=hf_model_name,
+        allow_patterns=["*.json", "*.bin"],
+        ignore_patterns=["onnx/*"],
+        local_dir=EXTRACT_DIR,
+    )
+    return hf_model_dir
+
+
+def convert_model(hf_model):
+    hf_config = hf_model.config.to_dict()
+    cfg = {}
+    cfg["vocab_size"] = hf_config["vocab_size"]
+    cfg["embedding_dim"] = hf_config["embedding_size"]
+    cfg["num_layers"] = hf_config["num_hidden_layers"]
+    cfg["num_heads"] = hf_config["num_attention_heads"]
+    cfg["hidden_dim"] = hf_config["hidden_size"]
+    cfg["intermediate_dim"] = hf_config["intermediate_size"]
+    cfg["dropout"] = hf_config["hidden_dropout_prob"]
+    cfg["max_sequence_length"] = hf_config["max_position_embeddings"]
+    return keras_nlp.models.ElectraBackbone(**cfg)
+
+
+def convert_tokenizer(hf_model_dir):
+    tokenizer_path = os.path.join(hf_model_dir, "tokenizer.json")
+    with open(tokenizer_path) as f:
+        hf_tokenizer = json.load(f)
+    vocab = hf_tokenizer["model"]["vocab"]
+
+    return keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
+
+
+def convert_weights(keras_model, hf_model):
+    hf_model_dict = hf_model.state_dict()
+
+    keras_model.get_layer("token_embedding").embeddings.assign(
+        hf_model_dict["embeddings.word_embeddings.weight"].numpy()
+    )
+    keras_model.get_layer("position_embedding").position_embeddings.assign(
+        hf_model_dict["embeddings.position_embeddings.weight"].numpy()
+    )
+    keras_model.get_layer("segment_embedding").embeddings.assign(
+        hf_model_dict["embeddings.token_type_embeddings.weight"].numpy()
+    )
+    keras_model.get_layer("embeddings_layer_norm").gamma.assign(
+        hf_model_dict["embeddings.LayerNorm.weight"]
+    )
+    keras_model.get_layer("embeddings_layer_norm").beta.assign(
+        hf_model_dict["embeddings.LayerNorm.bias"]
+    )
+
+    if any(
+        layer.name == "embeddings_projection" for layer in keras_model.layers
+    ):
+        keras_model.get_layer("embeddings_projection").kernel.assign(
+            hf_model_dict["embeddings_project.weight"].transpose(1,
0).numpy() + ) + keras_model.get_layer("embeddings_projection").bias.assign( + hf_model_dict["embeddings_project.bias"] + ) + + for i in range(keras_model.num_layers): + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.query.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._query_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.query.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.key.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._key_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.key.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.value.weight"] + .transpose(1, 0) + .reshape((keras_model.hidden_dim, keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._value_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.attention.self.value.bias"] + .reshape((keras_model.num_heads, -1)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._output_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.attention.output.dense.weight"] + .transpose(1, 0) + .reshape((keras_model.num_heads, -1, keras_model.hidden_dim)) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer._output_dense.bias.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.dense.bias" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer_norm.gamma.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.LayerNorm.weight" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._self_attention_layer_norm.beta.assign( + hf_model_dict[ + f"encoder.layer.{i}.attention.output.LayerNorm.bias" + ].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.intermediate.dense.weight"] + .transpose(1, 0) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_intermediate_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.intermediate.dense.bias"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.kernel.assign( + hf_model_dict[f"encoder.layer.{i}.output.dense.weight"] + .transpose(1, 0) + .numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_output_dense.bias.assign( + hf_model_dict[f"encoder.layer.{i}.output.dense.bias"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.gamma.assign( + hf_model_dict[f"encoder.layer.{i}.output.LayerNorm.weight"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{i}" + )._feedforward_layer_norm.beta.assign( + hf_model_dict[f"encoder.layer.{i}.output.LayerNorm.bias"].numpy() + ) + + +def 
validate_output(keras_model, hf_model, keras_tokenizer, hf_tokenizer): + input_str = ["The quick brown fox jumps over the lazy dog."] + + keras_nlp_preprocessor = keras_nlp.models.ElectraPreprocessor( + keras_tokenizer + ) + keras_nlp_inputs = keras_nlp_preprocessor(tf.constant(input_str)) + keras_nlp_output = keras_model.predict(keras_nlp_inputs).get( + "sequence_output" + ) + + hf_inputs = hf_tokenizer( + input_str, padding="max_length", return_tensors="pt" + ) + hf_output = hf_model(**hf_inputs).last_hidden_state.detach().numpy() + print("šŸ”¶ KerasNLP output:", keras_nlp_output[0, 0, :10]) + print("šŸ”¶ HF output:", hf_output[0, 0, :10]) + print("Difference: ", np.mean(keras_nlp_output - hf_output)) + + +def main(_): + preset = FLAGS.preset + assert preset in PRESET_MAP.keys(), f"Invalid preset: {preset}" + print(f"āœ… Converting {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("āœ… Downloaded model from Hugging face hub") + + hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_model_name) + hf_model = transformers.AutoModel.from_pretrained(hf_model_name) + print(f"āœ… Loaded {preset} from Hugging Face") + + keras_model = convert_model(hf_model) + keras_tokenizer = convert_tokenizer(hf_model_dir) + print("āœ… Keras model loaded") + + convert_weights(keras_model, hf_model) + print("āœ… Weights converted") + + validate_output(keras_model, hf_model, keras_tokenizer, hf_tokenizer) + print("āœ… Validation complete") + + save_to_preset(keras_model, preset) + save_to_preset(keras_tokenizer, preset, config_filename="tokenizer.json") + + print("āœ… Preset saved") + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/checkpoint_conversion/convert_falcon_checkpoints.py b/tools/checkpoint_conversion/convert_falcon_checkpoints.py new file mode 100644 index 0000000000..fdbdffd670 --- /dev/null +++ b/tools/checkpoint_conversion/convert_falcon_checkpoints.py @@ -0,0 +1,307 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Falcon weight conversion script. + +To run, install the CPU only development environment and huggingface libraries: +``` +pip install -r requirements.txt +pip install transformers huggingface-cli +``` + +Login to Huggingface: +``` +huggingface-cli login +``` + +Finally run this script to convert, validate and upload weights. 
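+(If `huggingface-cli` is not installable under that name, the same CLI
+ships with the `huggingface_hub` package, so `pip install huggingface_hub`
+is a safe fallback.)
+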
+``` +python tools/checkpoint_conversion/convert_falcon_checkpoints.py \ + --preset falcon_refinedweb_1b_en +``` +""" + +import json +import os + +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import absl # noqa: E402 +import huggingface_hub # noqa: E402 +import numpy as np # noqa: E402 +import torch # noqa: E402 +import transformers # noqa: E402 + +import keras_nlp # noqa: E402 + +PRESET_MAP = { + "falcon_refinedweb_1b_en": "tiiuae/falcon-rw-1b", +} + +EXTRACT_DIR = "./model" + +FLAGS = absl.flags.FLAGS +absl.flags.DEFINE_string( + "preset", + "falcon_refinedweb_1b_en", + f'Must be one of {",".join(PRESET_MAP.keys())}.', +) + + +def download_hf_model(hf_model_name): + hf_model_dir = huggingface_hub.snapshot_download( + repo_id=hf_model_name, + allow_patterns=["*.json", "*.bin"], + ignore_patterns=["onnx/*"], + local_dir=EXTRACT_DIR, + ) + + return hf_model_dir + + +def convert_model(hf_model): + hf_config = hf_model.config.to_dict() + kwargs = {} + kwargs["vocabulary_size"] = hf_config["vocab_size"] + kwargs["num_layers"] = hf_config["num_hidden_layers"] + kwargs["num_attention_heads"] = hf_config["num_attention_heads"] + kwargs["hidden_dim"] = hf_config["hidden_size"] + kwargs["intermediate_dim"] = 4 * kwargs["hidden_dim"] + kwargs["feedforward_dropout_rate"] = hf_config["hidden_dropout"] + kwargs["attention_dropout_rate"] = hf_config["attention_dropout"] + + return keras_nlp.models.FalconBackbone(**kwargs) + + +def convert_tokenizer(hf_model_dir): + tokenizer_file_path = os.path.join(hf_model_dir, "tokenizer.json") + with open(tokenizer_file_path) as tokenizer_file: + hf_tokenizer = json.load(tokenizer_file) + + vocab = hf_tokenizer["model"]["vocab"] + merges = hf_tokenizer["model"]["merges"] + return keras_nlp.models.FalconTokenizer(vocabulary=vocab, merges=merges) + + +def convert_weights(keras_model, hf_model): + hf_model.eval() + hf_wts = hf_model.state_dict() + + # token_embedding. + keras_model.get_layer("token_embedding").embeddings.assign( + hf_wts["word_embeddings.weight"] + ) + + for ilayer in range(keras_model.num_layers): + # Split key query value. + fused_qkv = ( + hf_wts[f"h.{ilayer}.self_attention.query_key_value.weight"] + .numpy() + .T + ) + seq_length, _ = fused_qkv.shape + head_dim = keras_model.hidden_dim // keras_model.num_attention_heads + fused_qkv = fused_qkv.reshape( + seq_length, keras_model.num_attention_heads, 3, head_dim + ) + query, key, value = ( + fused_qkv[..., 0, :], + fused_qkv[..., 1, :], + fused_qkv[..., 2, :], + ) + + fused_bias = hf_wts[ + f"h.{ilayer}.self_attention.query_key_value.bias" + ].numpy() + fused_bias = fused_bias.reshape( + keras_model.num_attention_heads, 3, head_dim + ) + query_bias, key_bias, value_bias = ( + fused_bias[..., 0, :], + fused_bias[..., 1, :], + fused_bias[..., 2, :], + ) + + # TODO: check if bias is true before assigning bias. + # Attention/query. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.query_dense.kernel.assign(query) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.query_dense.bias.assign(query_bias) + + # Attention/key. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.key_dense.kernel.assign(key) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.key_dense.bias.assign(key_bias) + + # Attention/value. 
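+        # As with query/key above, `value` is the index-2 slice of the
+        # fused QKV kernel. Note that the `seq_length` name bound by the
+        # reshape above is really the kernel's hidden dimension, not a
+        # sequence length; the slices are already shaped
+        # (hidden_dim, num_heads, head_dim), so they can be assigned
+        # without a further transpose.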
+ keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.value_dense.kernel.assign(value) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.value_dense.bias.assign(value_bias) + + # Attention/dense. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.output_dense.kernel.assign( + hf_wts[f"h.{ilayer}.self_attention.dense.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).attention_layer.output_dense.bias.assign( + hf_wts[f"h.{ilayer}.self_attention.dense.bias"].numpy() + ) + + # MLP/dense_h_to_4h. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).dense_h_to_4h.kernel.assign( + hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).dense_h_to_4h.bias.assign( + hf_wts[f"h.{ilayer}.mlp.dense_h_to_4h.bias"].numpy() + ) + + # MLP/dense_4h_to_h. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).dense_4h_to_h.kernel.assign( + hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.weight"].T.numpy() + ) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).dense_4h_to_h.bias.assign( + hf_wts[f"h.{ilayer}.mlp.dense_4h_to_h.bias"].numpy() + ) + + # input_layernorm. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).input_layernorm.gamma.assign( + hf_wts[f"h.{ilayer}.input_layernorm.weight"] + ) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).input_layernorm.beta.assign( + hf_wts[f"h.{ilayer}.input_layernorm.bias"] + ) + + # post_attention_layernorm. + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).post_attention_layernorm.gamma.assign( + hf_wts[f"h.{ilayer}.post_attention_layernorm.weight"].numpy() + ) + keras_model.get_layer( + f"transformer_layer_{ilayer}" + ).post_attention_layernorm.beta.assign( + hf_wts[f"h.{ilayer}.post_attention_layernorm.bias"].numpy() + ) + + # final_layernorm. + keras_model.get_layer("final_layernorm").gamma.assign( + hf_wts["ln_f.weight"].numpy() + ) + keras_model.get_layer("final_layernorm").beta.assign( + hf_wts["ln_f.bias"].numpy() + ) + + +def validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, +): + input_str = ["the quick brown fox ran, galloped and jumped."] + + # KerasNLP model. + token_ids = torch.tensor(keras_tokenizer(input_str)) + padding_mask = token_ids != 3 + keras_model_input = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + keras_model_outputs = keras_model.predict(keras_model_input) + + # HuggingFace model. + hf_model_input = hf_tokenizer(input_str, return_tensors="pt") + + activation = {} + + def get_activation(name): + def hook(hf_model, input, output): + activation[name] = output[0].detach() + + return hook + + hf_model.register_forward_hook(get_activation("ln_f")) + hf_model(**hf_model_input) + hf_model_outputs = activation["ln_f"].detach().numpy() + + # Comparing the outputs. 
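+    # Optional stricter check (a sketch, not exercised by default; the
+    # tolerance here is an assumption and may need loosening for
+    # float16 re-exports):
+    # np.testing.assert_allclose(
+    #     keras_model_outputs, hf_model_outputs, atol=1e-4
+    # )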
+ print("šŸ”¶ KerasNLP tokens ids:", keras_model_input["token_ids"]) + print("šŸ”¶ HF tokens ids:", hf_model_input["input_ids"]) + print("šŸ”¶ KerasNLP output:", keras_model_outputs[0, 1, :10]) + print("šŸ”¶ HF output:", hf_model_outputs[0, 1, :10]) + print("šŸ”¶ Difference:", np.mean(keras_model_outputs - hf_model_outputs)) + + +def main(_): + preset = FLAGS.preset + print(f"āœ… Coverting {preset}") + + hf_model_name = PRESET_MAP[preset] + hf_model_dir = download_hf_model(hf_model_name) + print("āœ… Huggingface model downloaded from hub") + + hf_model = transformers.FalconModel.from_pretrained(hf_model_dir) + # Falcon uses GPT2 tokenizer. + hf_tokenizer = transformers.GPT2TokenizerFast.from_pretrained(hf_model_dir) + print("āœ… Huggingface model loaded") + + keras_model = convert_model(hf_model) + keras_tokenizer = convert_tokenizer(hf_model_dir) + print("āœ… Keras model loaded") + + convert_weights(keras_model, hf_model) + print("āœ… Weights converted") + + validate_output( + hf_model, + keras_model, + hf_tokenizer, + keras_tokenizer, + ) + print("āœ… Numerics validated") + + keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) + keras_nlp.src.utils.preset_utils.save_to_preset( + keras_tokenizer, preset, config_filename="tokenizer.json" + ) + print("āœ… Preset saved") + + +if __name__ == "__main__": + absl.app.run(main) diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py new file mode 100644 index 0000000000..ed81e023d4 --- /dev/null +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -0,0 +1,224 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert Gemma flax checkpoints to the Keras format. + +Setup: +pip install -r requirements.txt +pip install git+https://github.com/google-deepmind/gemma.git +python pip_build.py --install + +Usage: +cd tools/checkpoint_conversion +python convert_gemma_checkpoints.py --preset gemma_2b_en +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" +# No GPU for conversion, makes memory management easier. 
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import kagglehub # noqa: E402 +import keras # noqa: E402 +import numpy as np # noqa: E402 +import sentencepiece # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 +from gemma import params as params_lib # noqa: E402 +from gemma import sampler as sampler_lib # noqa: E402 +from gemma import transformer as transformer_lib # noqa: E402 + +import keras_nlp # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "gemma_2b_en": "google/gemma/flax/2b", + "gemma_7b_en": "google/gemma/flax/7b", + "gemma_instruct_2b_en": "google/gemma/flax/2b-it", + "gemma_instruct_7b_en": "google/gemma/flax/7b-it", +} + + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) + + +def download_flax_model(handle): + return kagglehub.model_download(handle) + + +def convert_model(flax_config, vocab_size): + return keras_nlp.models.GemmaBackbone( + vocabulary_size=vocab_size, + num_layers=flax_config.num_layers, + num_query_heads=flax_config.num_heads, + num_key_value_heads=flax_config.num_kv_heads, + hidden_dim=flax_config.embed_dim, + intermediate_dim=flax_config.hidden_dim * 2, + head_dim=flax_config.head_dim, + ) + + +def convert_tokenizer(proto_path): + return keras_nlp.models.GemmaTokenizer(proto=proto_path) + + +def convert_weights(keras_model, flax_config, flax_params): + # Chomp the embedding weights. Upstream pads for TPU efficiency, but this + # leads to weird gotchas (you need to disregard part of your output logits). + embeddings = flax_params["transformer"]["embedder"]["input_embedding"] + embeddings = np.asarray(embeddings[: keras_model.vocabulary_size, :]) + keras_model.get_layer("token_embedding").set_weights([embeddings]) + keras_model.get_layer("final_normalization").set_weights( + [np.asarray(flax_params["transformer"]["final_norm"]["scale"])] + ) + for i in range(flax_config.num_layers): + flax_layer_name = f"layer_{i}" + keras_block = keras_model.get_layer(f"decoder_block_{i}") + + flax_block = flax_params["transformer"][flax_layer_name] + keras_block.pre_attention_norm.set_weights( + [flax_block["pre_attention_norm"]["scale"]] + ) + keras_block.pre_ffw_norm.set_weights( + [flax_block["pre_ffw_norm"]["scale"]] + ) + + keras_block.gating_ffw.set_weights( + [flax_block["mlp"]["gating_einsum"][0]] + ) + keras_block.gating_ffw_2.set_weights( + [flax_block["mlp"]["gating_einsum"][1]] + ) + keras_block.ffw_linear.set_weights([flax_block["mlp"]["linear"]]) + + attn_block = flax_block["attn"] + if flax_config.num_heads != flax_config.num_kv_heads: + # MQA. + keras_block.attention.query_dense.kernel.assign( + np.asarray(attn_block["q_einsum"]["w"][:, :, :]) + ) + keras_block.attention.key_dense.kernel.assign( + np.asarray(attn_block["kv_einsum"]["w"][0, :, :, :]) + ) + keras_block.attention.value_dense.kernel.assign( + np.asarray(attn_block["kv_einsum"]["w"][1, :, :, :]) + ) + else: + # MHA. + keras_block.attention.query_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][0, :, :, :]) + ) + keras_block.attention.key_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][1, :, :, :]) + ) + keras_block.attention.value_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][2, :, :, :]) + ) + keras_block.attention.output_dense.kernel.assign( + flax_block["attn"]["attn_vec_einsum"]["w"] + ) + + +def validate_output( + keras_model, + keras_tokenizer, + flax_params, + flax_tokenizer, +): + input_str = "What is Keras?" 
+    length = 32
+
+    # KerasNLP
+    preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(keras_tokenizer)
+    gemma_lm = keras_nlp.models.GemmaCausalLM(
+        backbone=keras_model,
+        preprocessor=preprocessor,
+    )
+    keras_output = gemma_lm.generate([input_str], max_length=length)
+    keras_output = keras_output[0]
+
+    # Flax
+    transformer_config = transformer_lib.TransformerConfig.from_params(
+        flax_params,
+        cache_size=length,
+    )
+    transformer = transformer_lib.Transformer(transformer_config)
+    sampler = sampler_lib.Sampler(
+        transformer=transformer,
+        vocab=flax_tokenizer,
+        params=flax_params["transformer"],
+    )
+    flax_output = sampler(
+        input_strings=[input_str],
+        total_generation_steps=length - 5,  # Length of "What is Keras?"
+    )
+    flax_output = input_str + flax_output.text[0]
+
+    # Comparing the outputs.
+    print("šŸ”¶ KerasNLP output:", keras_output)
+    print("šŸ”¶ Flax output:", flax_output)
+
+
+def main(_):
+    preset = FLAGS.preset
+
+    assert (
+        preset in PRESET_MAP.keys()
+    ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}'
+
+    print(f"šŸƒ Converting {preset}")
+
+    # Currently all flax weights are bfloat16 (and have much faster download
+    # times for it). We follow suit with Keras weights.
+    keras.config.set_floatx("bfloat16")
+
+    handle = PRESET_MAP[preset]
+    flax_dir = download_flax_model(handle)
+    proto_path = flax_dir + "/tokenizer.model"
+    print("āœ… Flax model downloaded from Kaggle")
+
+    variant = handle.split("/")[-1]
+    flax_tokenizer = sentencepiece.SentencePieceProcessor()
+    flax_tokenizer.Load(proto_path)
+    flax_params = params_lib.load_and_format_params(flax_dir + "/" + variant)
+    flax_config = transformer_lib.TransformerConfig.from_params(flax_params)
+    print("āœ… Flax model loaded")
+
+    keras_tokenizer = convert_tokenizer(proto_path)
+    vocab_size = keras_tokenizer.vocabulary_size()
+    keras_model = convert_model(flax_config, vocab_size)
+    print("āœ… Keras model loaded")
+
+    convert_weights(keras_model, flax_config, flax_params)
+    print("āœ… Weights converted")
+
+    validate_output(keras_model, keras_tokenizer, flax_params, flax_tokenizer)
+    print("āœ… Output validated")
+
+    keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset)
+    keras_nlp.src.utils.preset_utils.save_to_preset(
+        keras_tokenizer, preset, config_filename="tokenizer.json"
+    )
+    print(f"šŸ Preset saved to ./{preset}")
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/tools/checkpoint_conversion/convert_llama_checkpoints.py b/tools/checkpoint_conversion/convert_llama_checkpoints.py
index 5eb3973f36..4e127b2c7d 100644
--- a/tools/checkpoint_conversion/convert_llama_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_llama_checkpoints.py
@@ -11,131 +11,280 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc import os +import shutil +import tempfile +import traceback -import torch -from transformers import AutoModel +import numpy as np +from absl import app +from absl import flags +from keras import ops +from transformers import AutoTokenizer +from transformers import LlamaForCausalLM -from keras_nlp.models.llama.llama_backbone import LlamaBackbone +from keras_nlp.models import LlamaBackbone +from keras_nlp.models import LlamaCausalLMPreprocessor +from keras_nlp.models import LlamaTokenizer +from keras_nlp.utils.preset_utils import save_to_preset -os.environ["KERAS_BACKEND"] = "torch" +PRESET_MAP = { + "llama2_7b_en": "meta-llama/Llama-2-7b-hf", + "llama2_instruct_7b_en": "meta-llama/Llama-2-7b-chat-hf", +} -# from huggingface_hub import login -# llama weights as of now are on request access -# login(token=' Huggingface model and tokenizer loaded") - # MLP - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_intermediate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.up_proj.weight"].numpy().T - ) + # === Load the KerasNLP model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = LlamaBackbone(**backbone_kwargs) - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_gate_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.gate_proj.weight"].numpy().T - ) + # === Get the tokenizer from the Huggingface model === + tokenizer_path = hf_tokenizer.vocab_file + keras_nlp_tokenizer = LlamaTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_output_dense.kernel.assign( - hf_wts[f"layers.{ilayer}.mlp.down_proj.weight"].numpy().T - ) + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") - # LAYERNORM - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._self_attention_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.input_layernorm.weight"] - ) + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") - keras_model.get_layer( - f"transformer_layer_{ilayer}" - )._feedforward_layernorm.weight.assign( - hf_wts[f"layers.{ilayer}.post_attention_layernorm.weight"] - ) + # === Save the model weights in float32 format === + keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + del keras_nlp_model, hf_model + gc.collect() -keras_model.get_layer("layer_norm").gamma.assign(hf_wts["norm.weight"]) + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_nlp_model = LlamaBackbone(**backbone_kwargs) + keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_nlp_model, preset) + print("\n-> Saved the model preset in float16") -token_ids = [1, 2181, 8522, 338] -padding_mask = [1, 1, 1, 1] + # === Save the tokenizer === + save_to_preset( + keras_nlp_tokenizer, preset, 
config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) -keras_inputs = { - "token_ids": torch.tensor([token_ids]), - "padding_mask": torch.tensor([padding_mask]), -} -with torch.no_grad(): - keras_outputs = keras_model(keras_inputs) -print("Keras output = ", keras_outputs.numpy()) +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) diff --git a/tools/checkpoint_conversion/convert_mistral_checkpoints.py b/tools/checkpoint_conversion/convert_mistral_checkpoints.py index 8e10089efd..ae3b1f5b83 100644 --- a/tools/checkpoint_conversion/convert_mistral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mistral_checkpoints.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime import gc -import json import os -import pathlib +import shutil +import tempfile import traceback -import keras import numpy as np import requests from absl import app @@ -27,14 +25,15 @@ from transformers import AutoTokenizer from transformers import MistralForCausalLM -import keras_nlp from keras_nlp.models import MistralBackbone from keras_nlp.models import MistralCausalLMPreprocessor from keras_nlp.models import MistralTokenizer +from keras_nlp.utils.preset_utils import save_to_preset PRESET_MAP = { "mistral_7b_en": "mistralai/Mistral-7B-v0.1", "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1", + "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2", } FLAGS = flags.FLAGS @@ -227,124 +226,74 @@ def main(_): preset = FLAGS.preset hf_preset = PRESET_MAP[preset] - # === Create the save directories === - model_dir = pathlib.Path(__file__).parent / f"{preset}" - tokenizer_dir = model_dir / "assets" / "tokenizer" - if not model_dir.exists(): - os.makedirs(model_dir) - if not tokenizer_dir.exists(): - os.makedirs(tokenizer_dir) + # === Create the temporary save directories === + temp_dir = tempfile.mkdtemp() - # === Load the Huggingface model === - hf_model = MistralForCausalLM.from_pretrained(hf_preset) - hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) - hf_model.eval() - print("\n-> Huggingface model and tokenizer loaded") - - # === Load the KerasNLP model === - keras_nlp_config = dict( - vocabulary_size=hf_model.config.vocab_size, - hidden_dim=hf_model.config.hidden_size, - num_layers=hf_model.config.num_hidden_layers, - num_query_heads=hf_model.config.num_attention_heads, - num_key_value_heads=hf_model.config.num_key_value_heads, - intermediate_dim=hf_model.config.intermediate_size, - sliding_window=hf_model.config.sliding_window, - layer_norm_epsilon=hf_model.config.rms_norm_eps, - rope_max_wavelength=hf_model.config.rope_theta, - dtype="float32", - ) - keras_nlp_model = MistralBackbone(**keras_nlp_config) - - # === Download the tokenizer from Huggingface model card === - spm_path = ( - f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" - ) - response = requests.get(spm_path) - if not response.ok: - raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") - tokenizer_path = tokenizer_dir / "vocabulary.spm" - with open(tokenizer_path, "wb") as tokenizer_file: - tokenizer_file.write(response.content) - keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute())) - print("\n-> Keras 3 model and tokenizer loaded.") - - # === Port the weights === - convert_checkpoints(keras_nlp_model, hf_model) - print("\n-> Weight 
transfer done.") - - # === Check that the models and tokenizers outputs match === - test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) - test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) - print("\n-> Tests passed!") - - # === Save the model weights in float32 format === - keras_nlp_model.save_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - print("\n-> Saved the model weights in float16") - - del keras_nlp_model, hf_model - gc.collect() - - keras_nlp_config["dtype"] = "float16" - - # === Save the weights again in float16 === - keras_nlp_model = MistralBackbone(**keras_nlp_config) - keras_nlp_model.load_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - keras_nlp_model.save_weights( - str((model_dir / "model.weights.h5").absolute()) - ) - print("-> Saved the model weights in float16") - - # === Save the model config === - keras_nlp_config["dtype"] = "bfloat16" - model_config = { - "module": "keras_nlp.src.models.mistral.mistral_backbone", - "class_name": "MistralBackbone", - "config": {**keras_nlp_config}, - "registered_name": "keras_nlp>MistralBackbone", - "assets": [], - "weights": "model.weights.h5", - } - model_config_json = json.dumps(model_config) - with open(model_dir / "config.json", "w") as model_config_file: - model_config_file.write(model_config_json) - print("\n-> Saved model config") - - # === Save the tokenizer config === - tokenizer_config = { - "module": "keras_nlp.src.models.mistral.Mistral_tokenizer", - "class_name": "MistralTokenizer", - "config": { - "name": "mistral_tokenizer", - "trainable": True, - "dtype": "int32", - "proto": None, - "sequence_length": None, - }, - "registered_name": "keras_nlp>MistralTokenizer", - "assets": ["assets/tokenizer/vocabulary.spm"], - "weights": None, - } - tokenizer_config_json = json.dumps(tokenizer_config) - with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file: - tokenizer_config_file.write(tokenizer_config_json) - print("\n-> Saved tokenizer config") + try: + # === Load the Huggingface model === + hf_model = MistralForCausalLM.from_pretrained(hf_preset) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) + hf_model.eval() + print("\n-> Huggingface model and tokenizer loaded") + + # === Load the KerasNLP model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + sliding_window=hf_model.config.sliding_window, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_nlp_model = MistralBackbone(**backbone_kwargs) - # === Save metadata === - metadata_config = { - "keras_version": keras.__version__, - "keras_nlp_version": keras_nlp.__version__, - "parameter_count": keras_nlp_model.count_params(), - "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"), - } - metadata_config_json = json.dumps(metadata_config) - with open(model_dir / "metadata.json", "w") as metadata_config_file: - metadata_config_file.write(metadata_config_json) - print("\n-> Saved metadata") + # === Download the tokenizer from Huggingface model card === + spm_path = ( + f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" + ) + response = requests.get(spm_path) + if not response.ok: + raise ValueError(f"Couldn't 
fetch {preset}'s tokenizer.") + tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") + with open(tokenizer_path, "wb") as tokenizer_file: + tokenizer_file.write(response.content) + keras_nlp_tokenizer = MistralTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") + + # === Port the weights === + convert_checkpoints(keras_nlp_model, hf_model) + print("\n-> Weight transfer done.") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_nlp_tokenizer, hf_tokenizer) + test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") + + # === Save the model weights in float32 format === + keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + + del keras_nlp_model, hf_model + gc.collect() + + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_nlp_model = MistralBackbone(**backbone_kwargs) + keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_nlp_model, preset) + print("\n-> Saved the model preset in float16") + + # === Save the tokenizer === + save_to_preset( + keras_nlp_tokenizer, preset, config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) if __name__ == "__main__": diff --git a/tools/gemma/export_gemma_to_hf.py b/tools/gemma/export_gemma_to_hf.py new file mode 100644 index 0000000000..6f1fdf24d2 --- /dev/null +++ b/tools/gemma/export_gemma_to_hf.py @@ -0,0 +1,349 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os + +import torch +import transformers +from absl import app +from absl import flags + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a keras model to HuggingFace format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`), the model size (`2b` or `7b`), and the tokenizer +vocabulary file (`.spm`, `.model`, or equivalent) to +`--weights_file`, `--size`, and `--vocab_path`, respectively. + +Optionally, you can specify the output directory +for the converted model at `--output_dir`. (defaults to `gg_hf`) +``` +python tools/gemma/export_gemma_to_hf.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --vocab_path gemma_lm_tokenizer/vocabulary.spm \ + --output_dir fine_tuned_gg_hf +``` + +For converting a Keras model to HuggingFace format from a preset, +simply pass the Keras preset name to `--preset` and its model size +(`2b` or `7b`) to `--size`. 
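+The converted model and tokenizer are written to `--output_dir` in the
+Hugging Face Transformers `save_pretrained` format.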
+```
+python tools/gemma/export_gemma_to_hf.py \
+    --preset gemma_2b_en \
+    --size 2b \
+    --output_dir keras_hf_model/
+```
+"""
+
+
+PRESET_MAP = {
+    "gemma_2b_en": "gg-hf/gemma-2b",
+    "gemma_instruct_2b_en": "gg-hf/gemma-2b",
+    "gemma_7b_en": "gg-hf/gemma-7b",
+    "gemma_instruct_7b_en": "gg-hf/gemma-7b",
+}
+
+SIZE_MAP = {
+    "2b": ("gg-hf/gemma-2b", "gemma_2b_en"),
+    "7b": ("gg-hf/gemma-7b", "gemma_7b_en"),
+}
+
+gemma_2b_config = transformers.GemmaConfig(
+    num_hidden_layers=18,
+    num_attention_heads=8,
+    num_key_value_heads=1,
+    hidden_size=2048,
+    intermediate_size=16384,
+)
+
+gemma_7b_config = transformers.GemmaConfig()
+
+CONFIG_MAPPING = {"2b": gemma_2b_config, "7b": gemma_7b_config}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "hf_token",
+    None,
+    "Your HuggingFace token. Needed for access to the HuggingFace Gemma "
+    "implementation since the repository is private, for now.",
+)
+flags.DEFINE_string(
+    "preset",
+    None,
+    f'Must be one of {",".join(PRESET_MAP.keys())}.'
+    " Alternatively, a Keras weights file (`.weights.h5`) can be passed"
+    " to the --weights_file flag.",
+)
+flags.DEFINE_string(
+    "weights_file",
+    None,
+    "A Keras weights file (`.weights.h5`)."
+    " Alternatively, a model preset can be passed to the --preset flag.",
+)
+flags.DEFINE_string(
+    "size",
+    None,
+    "Size of model. Must be passed if `weights_file` is passed. "
+    "This should be either `2b` or `7b`.",
+)
+flags.DEFINE_string(
+    "output_dir",
+    "gg_hf",
+    "An output directory for the converted HuggingFace model and tokenizer.",
+)
+flags.DEFINE_string(
+    "vocab_path",
+    None,
+    "A path containing the vocabulary (must be a `.spm` file or equivalent). "
+    "If not passed, the vocabulary of the preset will be used.",
+)
+flags.DEFINE_string(
+    "dtype",
+    "float32",
+    "Set the precision of the converted checkpoint. 
Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def convert_checkpoints(preset, weights_file, size, output_dir, vocab_path): + if preset is not None: + hf_id = PRESET_MAP[preset] + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + hf_id, keras_preset = SIZE_MAP[size.lower()] + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + keras_preset + ) + keras_nlp_model.load_weights(weights_file) + + print(f"\n-> Loading HuggingFace Gemma `{size.upper()}` model...") + hf_model = transformers.GemmaForCausalLM(CONFIG_MAPPING[size.lower()]) + + print("\nāœ… Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to HuggingFace Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + hf_vocab_size = hf_model.model.embed_tokens.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if hf_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - hf_vocab_size + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + hf_model.model.embed_tokens, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + + # Pre-attention norm + update_state_dict( + hf_model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + query_target_shape = hf_model.model.layers[ + i + ].self_attn.q_proj.weight.shape + query_tensor = decoder_block.attention.query_dense.weights[0].value + query_tensor = query_tensor.transpose(1, 2).reshape(query_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.q_proj, "weight", query_tensor + ) + + key_target_shape = hf_model.model.layers[ + i + ].self_attn.k_proj.weight.shape + key_tensor = decoder_block.attention.key_dense.weights[0].value + key_tensor = key_tensor.transpose(1, 2).reshape(key_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.k_proj, "weight", key_tensor + ) + + value_target_shape = hf_model.model.layers[ + i + ].self_attn.v_proj.weight.shape + value_tensor = decoder_block.attention.value_dense.weights[0].value + value_tensor = value_tensor.transpose(1, 2).reshape(value_target_shape) + update_state_dict( + hf_model.model.layers[i].self_attn.v_proj, "weight", value_tensor + ) + + out_target_shape = hf_model.model.layers[ + i + ].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + update_state_dict( + hf_model.model.layers[i].self_attn.o_proj, "weight", out_tensor + ) + + # Post-attention norm + update_state_dict( + hf_model.model.layers[i].post_attention_layernorm, + "weight", + decoder_block.pre_ffw_norm.weights[0].value, + ) + + # MLP (Feed-forward) + update_state_dict( + hf_model.model.layers[i].mlp.gate_proj, + "weight", + 
decoder_block.gating_ffw.weights[0].value.transpose(0, 1),
+        )
+        update_state_dict(
+            hf_model.model.layers[i].mlp.up_proj,
+            "weight",
+            decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1),
+        )
+        update_state_dict(
+            hf_model.model.layers[i].mlp.down_proj,
+            "weight",
+            decoder_block.ffw_linear.weights[0].value.transpose(0, 1),
+        )
+
+    # Final norm
+    update_state_dict(
+        hf_model.model.norm,
+        "weight",
+        keras_nlp_model.backbone.layers[-1].weights[0].value,
+    )
+
+    print("\nāœ… Weights converted successfully.")
+    print(f"\n-> Saving HuggingFace model to `{output_dir}`...")
+
+    # Save model to HF Transformers format
+    os.makedirs(output_dir, exist_ok=True)
+    hf_model.save_pretrained(output_dir)
+
+    print(f"\nāœ… Saving complete. Model saved at `{output_dir}`.")
+
+    # Tokenizer
+
+    if not vocab_path:
+        tokenizer_preset = preset or SIZE_MAP[size.lower()][1]
+        print(
+            "\n-> Loading KerasNLP Gemma tokenizer with "
+            f"preset `{tokenizer_preset}`..."
+        )
+        keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(
+            tokenizer_preset
+        )
+        # Save tokenizer state
+        keras_nlp_tokenizer.save_assets(output_dir)
+        vocab_path = os.path.join(output_dir, "vocabulary.spm")
+        print("\nāœ… Tokenizer loading complete.")
+
+    hf_tokenizer = transformers.GemmaTokenizer(vocab_path)
+
+    print(f"\n-> Saving HuggingFace Gemma tokenizer to `{output_dir}`...")
+    # Save tokenizer to HF Transformers format
+    hf_tokenizer.save_pretrained(output_dir)
+
+    print(f"\nāœ… Saving complete. Tokenizer saved at `{output_dir}`.")
+
+
+def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None:
+    """Updates the state dict for a weight given a tensor."""
+    assert (
+        tensor.shape == layer.state_dict()[weight_name].shape
+    ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}"
+    layer.state_dict()[weight_name].copy_(tensor)
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.weights_file:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a Keras weights file (`.weights.h5`) and model size"
+            " (`2b` or `7b`) to `--weights_file` and `--size`, respectively."
+        )
+    if FLAGS.weights_file:
+        if FLAGS.preset:
+            raise ValueError(
+                "Both `--preset` and `--weights_file` flags cannot be supplied "
+                "at the same time. Either supply a valid Keras preset to "
+                "`--preset` or supply a Keras `.weights.h5` file and "
+                "model size (`2b` or `7b`) to `--weights_file` and `--size`, "
+                "respectively."
+            )
+        if not str(FLAGS.weights_file).endswith(".weights.h5"):
+            raise ValueError(
+                "Please pass a valid Keras weights file ending in `.weights.h5`."
+            )
+        if not FLAGS.size:
+            raise ValueError(
+                "The `size` flag must be passed if a weights file is passed. "
+                "Please pass the appropriate size (`2b` or `7b`) for your "
+                "model to the `--size` flag."
+            )
+        if FLAGS.size.lower() not in ["2b", "7b"]:
+            raise ValueError(
+                "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+                "for your model to the `--size` flag."
+            )
+    if FLAGS.dtype:
+        dtype = getattr(torch, FLAGS.dtype, None)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(
+                "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. "
+                "`float32`, `float16`, etc.) to the `--dtype` flag."
+ ) + + +def main(_): + flag_error_handler() + with _set_default_tensor_type(getattr(torch, FLAGS.dtype)): + convert_checkpoints( + FLAGS.preset, + FLAGS.weights_file, + FLAGS.size, + FLAGS.output_dir, + FLAGS.vocab_path, + ) + + +if __name__ == "__main__": + flags.mark_flag_as_required("size") + app.run(main) diff --git a/tools/gemma/export_gemma_to_torch_xla.py b/tools/gemma/export_gemma_to_torch_xla.py new file mode 100644 index 0000000000..08d4b3ac98 --- /dev/null +++ b/tools/gemma/export_gemma_to_torch_xla.py @@ -0,0 +1,344 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prior to running this conversion script, please install the PyTorch +implementation of Gemma and `torch_xla`: + +`pip install git+https://github.com/google/gemma_pytorch.git` +`pip install torch_xla` + +Please also ensure that your installed versions of `torch_xla` and `torch` are +compatible. +""" + +import contextlib +import os + +import gemma +import torch +import torch_xla.core.xla_model as xm +from absl import app +from absl import flags +from gemma import model_xla as gemma_model + +import keras_nlp + +os.environ["KERAS_BACKEND"] = "torch" + +""" +Sample usage: + +For converting a Keras model to PyTorch format using a custom or fine-tuned +checkpoint from Keras, make sure to pass the path for the Keras weights file +(ending in `.weights.h5`) and the model size (`2b` or `7b`) to `--weights_file` +and `--size`, respectively. + +Optionally, you can specify the output path for the converted model at +`--output_file`. (This defaults to `gemma.ckpt`) +``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --weights_file fine_tuned_imdb.weights.h5 \ + --size 2b \ + --output_file fine_tuned_imdb.ckpt +``` + +For converting a Keras model to PyTorch format from a preset, +simply pass the Keras preset name to `--preset`. +``` +python tools/gemma/export_gemma_to_torch_xla.py \ + --preset gemma_2b_en \ + --output_file path/to/keras_torch_model.ckpt +``` + +Following this usage, you can run the verification script to confirm +functionality of the converted checkpoint: + +``` +python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \ + --size 2b \ + --checkpoint_file fine_tuned_imdb.ckpt \ + --vocab_file gemma_tokenizer/vocabulary.spm \ + --prompt "Inception is about" +``` +""" + + +PRESET_MAP = { + "gemma_2b_en": gemma.config.get_config_for_2b(), + "gemma_instruct_2b_en": gemma.config.get_config_for_2b(), + "gemma_7b_en": gemma.config.get_config_for_7b(), + "gemma_instruct_7b_en": gemma.config.get_config_for_7b(), +} + +SIZE_MAP = { + "2b": (gemma.config.get_config_for_2b(), "gemma_2b_en"), + "7b": (gemma.config.get_config_for_7b(), "gemma_7b_en"), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}' + " Alternatively, a Keras weights file (`.weights.h5`) can be passed" + " to --weights_file flag.", +) +flags.DEFINE_string( + "weights_file", + None, + "A Keras weights file (`.weights.h5`)." 
+ " Alternatively, a model preset can be passed to --preset flag.", +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `weights_file` is passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "output_file", + "gemma.ckpt", + "An output file for the converted PyTorch checkpoint. Default: `gemma.ckpt`", +) +flags.DEFINE_string( + "vocab_dir", + "gemma_tokenizer", + "A directory in which the vocabulary for the tokenizer will be stored.", +) +flags.DEFINE_string( + "dtype", + "float32", + "Set the precision of the converted checkpoint. Must be a valid PyTorch dtype.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def _reconcile_attention_dims(qkv, target_shape): + return torch.cat(qkv).reshape(tuple(target_shape)) + + +def convert_checkpoints(preset, weights_file, size, output_file, vocab_dir): + device = xm.xla_device() + + if preset is not None: + print( + f"\n-> Loading PyTorch Gemma model config for preset `{preset}`..." + ) + model = gemma_model.GemmaForCausalLM( + PRESET_MAP[preset], world_size=1, rank=0, device=device + ) + print(f"\n-> Loading KerasNLP Gemma model with preset `{preset}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset(preset) + else: + print(f"\n-> Loading PyTorch Gemma model config for `{size}` model...") + config, size_preset = SIZE_MAP[size.lower()] + model = gemma_model.GemmaForCausalLM( + config, world_size=1, rank=0, device=device + ) + print(f"\n-> Loading Keras weights from file `{weights_file}`...") + keras_nlp_model = keras_nlp.models.GemmaCausalLM.from_preset( + size_preset + ) + keras_nlp_model.load_weights(weights_file) + + print("\nāœ… Model loading complete.") + print("\n-> Converting weights from KerasNLP Gemma to PyTorch Gemma...") + + # Token embedding (with vocab size difference handling) + keras_embedding = keras_nlp_model.backbone.token_embedding.weights[0] + torch_vocab_size = model.embedder.weight.shape[0] + keras_nlp_vocab_size = keras_embedding.value.shape[0] + if torch_vocab_size < keras_nlp_vocab_size: + diff = keras_nlp_vocab_size - torch_vocab_size + update_state_dict( + model.embedder, + "weight", + keras_embedding.value[:-diff, :], + ) + else: + update_state_dict( + model.embedder, + "weight", + keras_embedding.value, + ) + + # Decoder blocks + for i in range(keras_nlp_model.backbone.num_layers): + decoder_block = keras_nlp_model.backbone.get_layer(f"decoder_block_{i}") + # Pre-attention norm + update_state_dict( + model.model.layers[i].input_layernorm, + "weight", + decoder_block.pre_attention_norm.weights[0].value, + ) + + # Attention + qkv = ( + decoder_block.attention.query_dense.weights[0].value.transpose( + 1, 2 + ), + decoder_block.attention.key_dense.weights[0].value.transpose(1, 2), + decoder_block.attention.value_dense.weights[0].value.transpose( + 1, 2 + ), + ) + qkv_target_shape = model.model.layers[i].self_attn.qkv_proj.weight.shape + combined_tensor = _reconcile_attention_dims(qkv, qkv_target_shape) + + update_state_dict( + model.model.layers[i].self_attn.qkv_proj, "weight", combined_tensor + ) + + out_target_shape = model.model.layers[i].self_attn.o_proj.weight.shape + keras_out_tensor = decoder_block.attention.output_dense.weights[0].value + out_tensor = keras_out_tensor.reshape( + (out_target_shape[1], out_target_shape[0]) # Transpose target size + ).transpose(0, 1) + + 
update_state_dict(
+            model.model.layers[i].self_attn.o_proj, "weight", out_tensor
+        )
+
+        # Post-attention norm
+        update_state_dict(
+            model.model.layers[i].post_attention_layernorm,
+            "weight",
+            decoder_block.pre_ffw_norm.weights[0].value,
+        )
+
+        # MLP (Feed-forward)
+        update_state_dict(
+            model.model.layers[i].mlp.gate_proj,
+            "weight",
+            decoder_block.gating_ffw.weights[0].value.transpose(0, 1),
+        )
+        update_state_dict(
+            model.model.layers[i].mlp.up_proj,
+            "weight",
+            decoder_block.gating_ffw_2.weights[0].value.transpose(0, 1),
+        )
+        update_state_dict(
+            model.model.layers[i].mlp.down_proj,
+            "weight",
+            decoder_block.ffw_linear.weights[0].value.transpose(0, 1),
+        )
+
+    # Final norm
+    update_state_dict(
+        model.model.norm,
+        "weight",
+        keras_nlp_model.backbone.layers[-1].weights[0].value,
+    )
+
+    print("\nāœ… Weights converted successfully.")
+    print(f"\n-> Saving PyTorch model checkpoint to `{output_file}`...")
+
+    # Save model checkpoint
+    torch.save({"model_state_dict": model.state_dict()}, output_file)
+
+    print(
+        f"\nāœ… Saving complete. Model checkpoint available at `{output_file}`."
+    )
+
+    if preset is not None:
+        # Tokenizer
+        print(
+            f"\n-> Loading KerasNLP Gemma tokenizer with preset `{preset}`..."
+        )
+        keras_nlp_tokenizer = keras_nlp.models.GemmaTokenizer.from_preset(
+            preset
+        )
+        print("\nāœ… Tokenizer loading complete.")
+        print(f"\n-> Saving tokenizer state to directory `{vocab_dir}`...")
+
+        # Save tokenizer state
+        os.makedirs(vocab_dir, exist_ok=True)
+        keras_nlp_tokenizer.save_assets(vocab_dir)
+
+        print(
+            "\nāœ… Saving complete. Tokenizer state "
+            f"available at `{vocab_dir}/vocabulary.spm`."
+        )
+
+
+def update_state_dict(layer, weight_name: str, tensor: torch.Tensor) -> None:
+    """Updates the state dict for a weight given a tensor."""
+    assert (
+        tensor.shape == layer.state_dict()[weight_name].shape
+    ), f"{tensor.shape} vs {layer.state_dict()[weight_name].shape}"
+    layer.state_dict()[weight_name].copy_(tensor)
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.weights_file:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a Keras weights file (`.weights.h5`) and model size"
+            " (`2b` or `7b`) to `--weights_file` and `--size`, respectively."
+        )
+    if FLAGS.weights_file:
+        if FLAGS.preset:
+            raise ValueError(
+                "Both `--preset` and `--weights_file` flags cannot be supplied "
+                "at the same time. Either supply a valid Keras preset to "
+                "`--preset` or supply a Keras `.weights.h5` file and "
+                "model size (`2b` or `7b`) to `--weights_file` and `--size`, "
+                "respectively."
+            )
+        if not str(FLAGS.weights_file).endswith(".weights.h5"):
+            raise ValueError(
+                "Please pass a valid Keras weights file ending in `.weights.h5`."
+            )
+        if not FLAGS.size:
+            raise ValueError(
+                "The `size` flag must be passed if a weights file is passed. "
+                "Please pass the appropriate size (`2b` or `7b`) for your "
+                "model to the `--size` flag."
+            )
+        if FLAGS.size.lower() not in ["2b", "7b"]:
+            raise ValueError(
+                "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+                "for your model to the `--size` flag."
+            )
+    if FLAGS.dtype:
+        dtype = getattr(torch, FLAGS.dtype, None)
+        if not isinstance(dtype, torch.dtype):
+            raise ValueError(
+                "Invalid `dtype`. Please pass a valid PyTorch data type (e.g. "
+                "`float32`, `float16`, etc.) to the `--dtype` flag."
+ ) + + +def main(_): + flag_error_handler() + with _set_default_tensor_type(getattr(torch, FLAGS.dtype)): + convert_checkpoints( + FLAGS.preset, + FLAGS.weights_file, + FLAGS.size, + FLAGS.output_file, + FLAGS.vocab_dir, + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/tools/gemma/run_gemma_xla.py b/tools/gemma/run_gemma_xla.py new file mode 100644 index 0000000000..f212154c99 --- /dev/null +++ b/tools/gemma/run_gemma_xla.py @@ -0,0 +1,340 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is a modified version of `run_xla.py` script in the PyTorch Gemma repo +to ensure proper functionality after porting checkpoints from Keras. Please +run `export_gemma_to_torch_xla.py` prior to running this verification script. + +As with the conversion script, ensure that `torch_xla` and the PyTorch +implementation of Gemma are properly installed: + +`pip install git+https://github.com/google/gemma_pytorch.git` +`pip install torch_xla` + +Note that this verification script can take several minutes to run. +""" + +import contextlib +import os +import random +import sys +from typing import List + +import gemma.xla_model_parallel as xla_model_parallel +import numpy as np +import torch +import torch.multiprocessing +import torch_xla.core.xla_model as xm +import torch_xla.distributed.xla_multiprocessing as xmp +from absl import app +from absl import flags +from gemma.config import GemmaConfig +from gemma.config import get_config_for_2b +from gemma.config import get_config_for_7b +from gemma.model_xla import GemmaForCausalLM +from gemma.tokenizer import Tokenizer + +""" +Sample usage: + +Run the verification script supplying your model size, converted checkpoint file, +vocabulary file, and test prompt. 
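+Sampling parameters (temperature, top-p, top-k) are fixed in `main()` below.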
+ +``` +python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \ + --size 2b \ + --checkpoint_file fine_tuned_imdb.ckpt \ + --vocab_file gemma_tokenizer/vocabulary.spm \ + --prompt "Three Billboards" +``` + +After a delay (a couple minutes if running on CPU), this should produce: +``` +====================================== +PROMPT: Three Billboards +RESULT: Outside Ebbing, Missouri is a film in the tradition of Hollywood westerns +====================================== +``` + +If running from a preset, instead provide your converted checkpoint file and +the associated preset name: + +``` +python keras-nlp-gemma/tools/gemma/run_gemma_xla.py \ + --preset gemma_2b_en \ + --checkpoint_file gemma_2b.ckpt \ + --prompt "California is the largest" +``` + +After a delay (a couple minutes if running on CPU), this should produce: +``` +====================================== +PROMPT: California is the largest +RESULT: producer of strawberries in the world, and is a +====================================== +``` +""" + +PAD_TOKEN_ID = -1 + +FILE_PATH = "gemma.ckpt" +TOKENIZER_DIR = "gemma_tokenizer" + +PRESET_MAP = { + "gemma_2b_en": get_config_for_2b(), + "gemma_instruct_2b_en": get_config_for_2b(), + "gemma_7b_en": get_config_for_7b(), + "gemma_instruct_7b_en": get_config_for_7b(), +} + +SIZE_MAP = { + "2b": get_config_for_2b(), + "7b": get_config_for_7b(), +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f'Must be one of {",".join(PRESET_MAP.keys())}' +) +flags.DEFINE_string( + "size", + None, + "Size of model. Must be passed if `preset` is not passed. " + "This should be either `2b` or `7b`.", +) +flags.DEFINE_string( + "checkpoint_file", + "gemma.ckpt", + "A PyTorch checkpoint file containing the converted weights.", +) +flags.DEFINE_string( + "vocab_file", + "gemma_tokenizer/vocabulary.spm", + "The file containing the vocabulary for the tokenizer.", +) +flags.DEFINE_string( + "prompt", + "The capital of France is", + "A test prompt for verifying functionality of the PyTorch Gemma model.", +) + + +@contextlib.contextmanager +def _set_default_tensor_type(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(torch.float) + + +def generate( + i: int, + model_config: GemmaConfig, + checkpoint_file: str, + vocab_file: str, + prompts: List[str], + output_lens: List[int], + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], +): + # Set seed from config + seed = model_config.seed + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + device = xm.xla_device() + xm.set_rng_state(seed, device) + + rank = xla_model_parallel.get_model_parallel_rank() + world_size = xla_model_parallel.get_model_parallel_world_size() + if rank > 0: + sys.stdout = open(os.devnull, "w") + + # Load model with ported weights and place on device + with _set_default_tensor_type(model_config.get_dtype()): + model = GemmaForCausalLM(model_config, world_size, rank, device) + model.load_weights(checkpoint_file) + model = model.to(device).eval() + + # Create tokenizer with saved Keras tokenizer state + tokenizer = Tokenizer(vocab_file) + + prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts] + min_prompt_len = min(len(p) for p in prompt_tokens) + + batch_size = len(prompts) + assert batch_size == len(temperatures) + assert batch_size == len(top_ps) + assert batch_size == len(top_ks) + max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)]) + assert max_seq_len <= 
model_config.max_position_embeddings + if model_config.num_key_value_heads < world_size: + assert world_size % model_config.num_key_value_heads == 0 + n_local_heads = 1 + else: + assert model_config.num_key_value_heads % world_size == 0 + n_local_heads = model_config.num_key_value_heads // world_size + + # build KV caches + kv_caches = [] + for _ in range(model_config.num_hidden_layers): + k_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + v_cache = torch.zeros( + size=( + batch_size, + max_seq_len, + n_local_heads, + model_config.head_dim, + ), + dtype=model_config.get_dtype(), + device=device, + ) + kv_caches.append((k_cache, v_cache)) + + # prepare inputs + token_ids_tensor = torch.full( + (batch_size, max_seq_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + input_token_ids_tensor = torch.full( + (batch_size, min_prompt_len), PAD_TOKEN_ID, dtype=torch.int64 + ) + for i, p in enumerate(prompt_tokens): + token_ids_tensor[i, : len(p)] = torch.tensor(p) + input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( + p[:min_prompt_len] + ) + token_ids_tensor = token_ids_tensor.to(device) + prompt_mask_tensor = token_ids_tensor != PAD_TOKEN_ID + input_token_ids_tensor = input_token_ids_tensor.to(device) + input_positions_tensor = torch.arange( + 0, min_prompt_len, dtype=torch.int64 + ).to(device) + mask_tensor = torch.full( + (1, 1, max_seq_len, max_seq_len), -2.3819763e38 + ).to(torch.float) + mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) + temperatures_tensor = torch.FloatTensor(temperatures).to(device) + top_ps_tensor = torch.FloatTensor(top_ps).to(device) + top_ks_tensor = torch.LongTensor(top_ks).to(device) + output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) + xm.mark_step() + + # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output. + for i in range(max_seq_len - min_prompt_len): + next_token_ids = model( + input_token_ids=input_token_ids_tensor, + input_positions=input_positions_tensor, + kv_write_indices=None, + kv_caches=kv_caches, + mask=curr_mask_tensor, + output_positions=output_positions_tensor, + temperatures=temperatures_tensor, + top_ps=top_ps_tensor, + top_ks=top_ks_tensor, + ) + curr_prompt_mask = prompt_mask_tensor.index_select( + 1, output_index + ).squeeze(dim=1) + curr_token_ids = token_ids_tensor.index_select(1, output_index).squeeze( + dim=1 + ) + output_token_ids = torch.where( + curr_prompt_mask, curr_token_ids, next_token_ids + ).unsqueeze(dim=1) + token_ids_tensor.index_copy_(1, output_index, output_token_ids) + + input_token_ids_tensor = output_token_ids + input_positions_tensor = output_index + curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) + output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) + output_index = output_index + 1 + xm.mark_step() + + # Detokenization. 
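+    # (Each sequence is trimmed to its requested output length and truncated
+    # at the first EOS token before decoding.)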
+    token_ids = token_ids_tensor.tolist()
+    results = []
+    for i, tokens in enumerate(token_ids):
+        trimmed_output = tokens[
+            len(prompt_tokens[i]) : len(prompt_tokens[i]) + output_lens[i]
+        ]
+        if tokenizer.eos_id in trimmed_output:
+            eos_index = trimmed_output.index(tokenizer.eos_id)
+            trimmed_output = trimmed_output[:eos_index]
+        results.append(tokenizer.decode(trimmed_output))
+
+    for prompt, result in zip(prompts, results):
+        print("======================================")
+        print(f"PROMPT: {prompt}")
+        print(f"RESULT: {result}")
+        print("======================================")
+
+
+def flag_error_handler():
+    if not FLAGS.preset and not FLAGS.size:
+        raise ValueError(
+            "Please pass either a valid Keras preset to `--preset`"
+            " or supply a model size (`2b` or `7b`) to `--size`."
+        )
+    if FLAGS.size and FLAGS.size.lower() not in ["2b", "7b"]:
+        raise ValueError(
+            "Invalid `size`. Please pass the appropriate size (`2b` or `7b`) "
+            "for your model to the `--size` flag."
+        )
+
+
+def main(_):
+    flag_error_handler()
+    if FLAGS.preset:
+        model_config = PRESET_MAP[FLAGS.preset]
+    else:
+        model_config = SIZE_MAP[FLAGS.size.lower()]
+    prompts = [
+        FLAGS.prompt,
+    ]
+    n = len(prompts)
+    output_lengths = [10] * n
+    temperatures = [0.95] * n
+    top_ps = [1.0] * n
+    top_ks = [100] * n
+    xmp.spawn(
+        generate,
+        args=(
+            model_config,
+            FLAGS.checkpoint_file,
+            FLAGS.vocab_file,
+            prompts,
+            output_lengths,
+            temperatures,
+            top_ps,
+            top_ks,
+        ),
+    )
+
+
+if __name__ == "__main__":
+    app.run(main)
diff --git a/tools/sentencepiece_testing/create_gemma_test_proto.py b/tools/sentencepiece_testing/create_gemma_test_proto.py
new file mode 100644
index 0000000000..c3ce418a4b
--- /dev/null
+++ b/tools/sentencepiece_testing/create_gemma_test_proto.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tools.sentencepiece_testing.utils import train_sentencepiece
+
+
+def main():
+    train_sentencepiece(
+        ["the quick brown fox", "the earth is round"],
+        "gemma_test_vocab.spm",
+        vocab_size=11,
+        model_type="WORD",
+        pad_id=0,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
+        pad_piece="<pad>",
+        bos_piece="<bos>",
+        eos_piece="<eos>",
+        unk_piece="<unk>",
+    )
+
+
+if __name__ == "__main__":
+    main()