Skip to content

Commit

Permalink
Update load_weights for Keras 2 compatibility (#2235)
Browse files Browse the repository at this point in the history
* Update SegFormer preset tests to remove input_shape arg

* Add fix for saving compatibility with Keras 2, fix formatting
  • Loading branch information
nkovela1 authored and sampathweb committed Dec 15, 2023
1 parent 4939f78 commit e0b3883
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
8 changes: 2 additions & 6 deletions keras_cv/models/segmentation/segformer/segformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ def test_segformer_call(self):
mit_output = mit_model(images)
mit_pred = mit_model.predict(images)

seg_model = SegFormer.from_preset(
"segformer_b0", num_classes=1
)
seg_model = SegFormer.from_preset("segformer_b0", num_classes=1)
seg_output = seg_model(images)
seg_pred = seg_model.predict(images)

Expand Down Expand Up @@ -123,9 +121,7 @@ def test_saved_model(self):
def test_preset_saved_model(self):
target_size = [224, 224, 3]

model = SegFormer.from_preset(
"segformer_b0", num_classes=1
)
model = SegFormer.from_preset("segformer_b0", num_classes=1)

input_batch = np.ones(shape=[2] + target_size)
model_output = model(input_batch)
Expand Down
21 changes: 20 additions & 1 deletion keras_cv/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import datetime
import inspect
import json
import os

Expand Down Expand Up @@ -164,7 +165,10 @@ def load_from_preset(
# Default to loading weights if available.
if load_weights is not False and config["weights"] is not None:
weights_path = get_file(preset, config["weights"])
layer.load_weights(weights_path)
if hasattr(layer, "_layer_checkpoint_dependencies"):
legacy_load_weights(layer, weights_path)
else:
layer.load_weights(weights_path)

return layer

Expand Down Expand Up @@ -200,3 +204,18 @@ def check_preset_class(
f"Received: `{cls}`."
)
return cls


def legacy_load_weights(layer, weights_path):
# Hacky fix for TensorFlow 2.13 and 2.14 when loading a `.weights.h5` file.
# We find the `Functional` class, and temporarily remove the
# `_layer_checkpoint_dependencies` property, which on older version of
# TensorFlow complete broke the variable paths for functional models.
functional_cls = None
for cls in inspect.getmro(layer.__class__):
if cls.__name__ == "Functional":
functional_cls = cls
property = functional_cls._layer_checkpoint_dependencies
functional_cls._layer_checkpoint_dependencies = None
layer.load_weights(weights_path)
functional_cls._layer_checkpoint_dependencies = property

0 comments on commit e0b3883

Please sign in to comment.