Get everything working on 2.13, 2.14, and 2.15
mattdangerw committed Jan 3, 2024
1 parent 9febcc3 commit 1c0cbb2
Showing 3 changed files with 26 additions and 4 deletions.
2 changes: 0 additions & 2 deletions keras_cv/models/backbones/backbone.py
@@ -40,8 +40,6 @@ def __dir__(self):
         # Temporary fixes for weight saving. This mimics the following PR for
         # older versions of Keras: https://github.com/keras-team/keras/pull/18982
         def filter_fn(attr):
-            if attr == "_layer_checkpoint_dependencies":
-                return False
             try:
                 return id(getattr(self, attr)) not in self._functional_layer_ids
             except:
2 changes: 1 addition & 1 deletion keras_cv/models/task.py
@@ -37,7 +37,7 @@ def __dir__(self):
         # Temporary fixes for weight saving. This mimics the following PR for
         # older versions of Keras: https://github.com/keras-team/keras/pull/18982
         def filter_fn(attr):
-            if attr == "_layer_checkpoint_dependencies":
+            if attr in ["backbone", "_backbone"]:
                 return False
             try:
                 return id(getattr(self, attr)) not in self._functional_layer_ids
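
Both `filter_fn` tweaks above sit inside a `__dir__` override that hides attributes whose values the functional graph already tracks; per the comments and the referenced Keras PR, the intent is to keep weight saving from seeing the same layer under two attribute paths. A minimal illustration of the pattern, not part of the commit (`FakeModel`, its attributes, and the stand-in layer objects are hypothetical):

class FakeModel:
    def __init__(self):
        # Stand-ins for layers; `backbone` is assumed to be tracked by the
        # functional graph, `head` is not.
        self.backbone = ["tracked-layer"]
        self.head = ["untracked-layer"]
        self._functional_layer_ids = {id(self.backbone)}

    def __dir__(self):
        def filter_fn(attr):
            if attr in ["backbone", "_backbone"]:
                return False
            try:
                return id(getattr(self, attr)) not in self._functional_layer_ids
            except Exception:
                return True

        return filter(filter_fn, super().__dir__())

model = FakeModel()
print("backbone" in dir(model))  # False: hidden from attribute traversal
print("head" in dir(model))      # True: still listed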
26 changes: 25 additions & 1 deletion keras_cv/utils/preset_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import datetime
+import inspect
 import json
 import os
@@ -164,7 +165,15 @@ def load_from_preset(
     # Default to loading weights if available.
     if load_weights is not False and config["weights"] is not None:
         weights_path = get_file(preset, config["weights"])
-        layer.load_weights(weights_path)
+        import h5py
+
+        f = h5py.File(weights_path)
+        f.visititems(lambda name, _: print(name))
+
+        if hasattr(layer, "_layer_checkpoint_dependencies"):
+            legacy_load_weights(layer, weights_path)
+        else:
+            layer.load_weights(weights_path)

     return layer
@@ -200,3 +209,18 @@ def check_preset_class(
             f"Received: `{cls}`."
         )
     return cls
+
+
+def legacy_load_weights(layer, weights_path):
+    # Hacky fix for TensorFlow 2.13 and 2.14 when loading a `.weights.h5` file.
+    # We find the `Functional` class and temporarily remove its
+    # `_layer_checkpoint_dependencies` property, which on older versions of
+    # TensorFlow completely broke the variable paths for functional models.
+    functional_cls = None
+    for cls in inspect.getmro(layer.__class__):
+        if cls.__name__ == "Functional":
+            functional_cls = cls
+    property = functional_cls._layer_checkpoint_dependencies
+    functional_cls._layer_checkpoint_dependencies = {}
+    layer.load_weights(weights_path)
+    functional_cls._layer_checkpoint_dependencies = property
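
`legacy_load_weights` swaps a class-level property out for the duration of one `load_weights` call and then restores it. The same temporary swap can be written as a context manager, sketched below (not part of the commit; the helper name and usage are illustrative), which also restores the attribute if loading raises:

import contextlib

@contextlib.contextmanager
def attr_temporarily_replaced(cls, name, value):
    # Save the class attribute (the property object on Keras 2.13/2.14),
    # substitute `value`, and restore the original afterwards.
    original = getattr(cls, name)
    setattr(cls, name, value)
    try:
        yield
    finally:
        setattr(cls, name, original)

# Hypothetical usage mirroring legacy_load_weights:
# with attr_temporarily_replaced(
#     functional_cls, "_layer_checkpoint_dependencies", {}
# ):
#     layer.load_weights(weights_path)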
