Skip to content

Commit

Permalink
Improve error handling for non-keras model loading attempts (#1577)
Browse files Browse the repository at this point in the history
* Improve error handling when user is attempting to load a non-keras model.

* Install huggingface_hub package for gpu tests.

* Switch to bert model for testing.

* Use assertRaisesRegex in unit tests.
  • Loading branch information
SamanehSaadat authored Apr 17, 2024
1 parent ee5263b commit 16d3ebb
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ then
fi

pip install --no-deps -e "." --progress-bar off
pip install huggingface_hub

# Run Extra Large Tests for Continuous builds
if [ "${RUN_XLARGE:-0}" == "1" ]
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_metadata
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import validate_metadata
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -197,6 +198,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
)
```
"""
validate_metadata(preset)
preset_cls = check_config_class(preset)
if not issubclass(preset_cls, cls):
raise ValueError(
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/models/backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from keras_nlp.models.bert.bert_backbone import BertBackbone
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import METADATA_FILE


class TestTask(TestCase):
Expand All @@ -42,3 +43,8 @@ def test_from_preset(self):
def test_from_preset_errors(self):
    """`from_preset()` must fail for a wrong class or a non-Keras preset."""
    # Loading this preset through the GPT-2 backbone class is an error.
    with self.assertRaises(ValueError):
        GPT2Backbone.from_preset("bert_tiny_en_uncased", load_weights=False)

    # No loading on a non-keras model: the missing metadata file is
    # reported by name.
    missing_metadata = f"doesn't have a file named `{METADATA_FILE}`"
    with self.assertRaisesRegex(FileNotFoundError, missing_metadata):
        Backbone.from_preset("hf://google-bert/bert-base-uncased")
2 changes: 2 additions & 0 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import validate_metadata
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -126,6 +127,7 @@ def from_preset(
)
```
"""
validate_metadata(preset)
if cls == Preprocessor:
raise ValueError(
"Do not call `Preprocessor.from_preset()` directly. Instead call a "
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/models/preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import METADATA_FILE


class TestTask(TestCase):
Expand Down Expand Up @@ -49,5 +50,10 @@ def test_from_preset_errors(self):
with self.assertRaises(ValueError):
# No loading on an incorrect class.
BertPreprocessor.from_preset("gpt2_base_en")
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
# No loading on a non-keras model.
Preprocessor.from_preset("hf://google-bert/bert-base-uncased")

# TODO: Add more tests once we add a model that has `preprocessor.json`.
3 changes: 3 additions & 0 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import validate_metadata
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -212,6 +213,8 @@ def from_preset(
)
```
"""
validate_metadata(preset)

if cls == Task:
raise ValueError(
"Do not call `Task.from_preset()` directly. Instead call a "
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import METADATA_FILE


class SimpleTokenizer(Tokenizer):
Expand Down Expand Up @@ -69,6 +70,11 @@ def test_from_preset_errors(self):
with self.assertRaises(ValueError):
# No loading on an incorrect class.
BertClassifier.from_preset("gpt2_base_en", load_weights=False)
with self.assertRaisesRegex(
FileNotFoundError, f"doesn't have a file named `{METADATA_FILE}`"
):
# No loading on a non-keras model.
CausalLM.from_preset("hf://google-bert/bert-base-uncased")

def test_summary_with_preprocessor(self):
preprocessor = SimplePreprocessor()
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_tokenizer_assets
from keras_nlp.utils.preset_utils import validate_metadata
from keras_nlp.utils.python_utils import classproperty


Expand Down Expand Up @@ -214,6 +215,7 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
tokenizer.detokenize([5, 6, 7, 8, 9])
```
"""
validate_metadata(preset)
preset_cls = check_config_class(
preset, config_file=TOKENIZER_CONFIG_FILE
)
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/tokenizers/tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.tests.test_case import TestCase
from keras_nlp.tokenizers.tokenizer import Tokenizer
from keras_nlp.utils.preset_utils import METADATA_FILE


class SimpleTokenizer(Tokenizer):
Expand Down Expand Up @@ -54,6 +55,11 @@ def test_from_preset(self):
def test_from_preset_errors(self):
    """`from_preset()` must fail for a wrong class or a non-Keras preset."""
    # Loading this preset through the GPT-2 tokenizer class is an error.
    with self.assertRaises(ValueError):
        GPT2Tokenizer.from_preset("bert_tiny_en_uncased")

    # No loading on a non-keras model: the missing metadata file is
    # reported by name.
    missing_metadata = f"doesn't have a file named `{METADATA_FILE}`"
    with self.assertRaisesRegex(FileNotFoundError, missing_metadata):
        Tokenizer.from_preset("hf://google-bert/bert-base-uncased")

def test_tokenize(self):
input_data = ["the quick brown fox"]
Expand Down
18 changes: 17 additions & 1 deletion keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
TOKENIZER_CONFIG_FILE = "tokenizer.json"
TASK_CONFIG_FILE = "task.json"
PREPROCESSOR_CONFIG_FILE = "preprocessor.json"
METADATA_FILE = "metadata.json"

# Weight file names.
MODEL_WEIGHTS_FILE = "model.weights.h5"
Expand Down Expand Up @@ -264,7 +265,7 @@ def save_metadata(layer, preset):
"parameter_count": layer.count_params(),
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
}
metadata_path = os.path.join(preset, "metadata.json")
metadata_path = os.path.join(preset, METADATA_FILE)
with open(metadata_path, "w") as metadata_file:
metadata_file.write(json.dumps(metadata, indent=4))

Expand Down Expand Up @@ -400,6 +401,21 @@ def load_config(preset, config_file=CONFIG_FILE):
return config


def validate_metadata(preset):
    """Verify that `preset` carries the metadata of a Keras model preset.

    Args:
        preset: Preset identifier or directory. It is passed unchanged to
            `check_file_exists` and `load_config`.

    Raises:
        FileNotFoundError: If the preset has no `METADATA_FILE`.
        ValueError: If the loaded metadata is not a mapping, or is a mapping
            without a `keras_version` key.
    """
    if not check_file_exists(preset, METADATA_FILE):
        raise FileNotFoundError(
            f"The preset directory `{preset}` doesn't have a file named "
            f"`{METADATA_FILE}`. This file is required to load a Keras model "
            "preset. Please verify that the model you are trying to load is a "
            "Keras model."
        )
    metadata = load_config(preset, METADATA_FILE)
    # Require a mapping before testing for the key. A bare `in` test would
    # raise `TypeError` on a JSON `null`/number, and would wrongly accept a
    # list or string that merely contains "keras_version".
    # NOTE(review): assumes `load_config` returns the decoded JSON document
    # as-is -- confirm against its implementation.
    if not isinstance(metadata, dict) or "keras_version" not in metadata:
        raise ValueError(
            f"`{METADATA_FILE}` in the preset directory `{preset}` doesn't "
            "have `keras_version`. Please verify that the model you are "
            "trying to load is a Keras model."
        )


def load_serialized_object(
preset,
config_file=CONFIG_FILE,
Expand Down
24 changes: 24 additions & 0 deletions keras_nlp/utils/preset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import pytest
Expand All @@ -23,7 +24,9 @@
from keras_nlp.models import BertTokenizer
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import METADATA_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import validate_metadata


class PresetUtilsTest(TestCase):
Expand Down Expand Up @@ -91,3 +94,24 @@ def test_upload_with_invalid_json(self, json_file):
# Verify error handling.
with self.assertRaisesRegex(ValueError, "is an invalid json"):
upload_preset("kaggle://test/test/test", local_preset_dir)

def test_missing_metadata(self):
    """An empty preset directory is rejected with `FileNotFoundError`."""
    preset_dir = os.path.join(self.get_temp_dir(), "test_missing_metadata")
    os.mkdir(preset_dir)

    # The directory exists but holds no metadata file at all.
    expected_message = f"doesn't have a file named `{METADATA_FILE}`"
    with self.assertRaisesRegex(FileNotFoundError, expected_message):
        validate_metadata(preset_dir)

def test_incorrect_metadata(self):
    """Metadata without `keras_version` is rejected with `ValueError`."""
    preset_dir = os.path.join(self.get_temp_dir(), "test_incorrect_metadata")
    os.mkdir(preset_dir)

    # Write a well-formed metadata file that lacks the required key.
    metadata_path = os.path.join(preset_dir, METADATA_FILE)
    with open(metadata_path, "w") as metadata_file:
        json.dump({"key": "value"}, metadata_file)

    expected_message = "doesn't have `keras_version`"
    with self.assertRaisesRegex(ValueError, expected_message):
        validate_metadata(preset_dir)

0 comments on commit 16d3ebb

Please sign in to comment.