From 205dedf68a89cf93d2148e8b7fd47cd4a48ff0cc Mon Sep 17 00:00:00 2001 From: MoFHeka Date: Wed, 16 Oct 2024 21:16:38 +0800 Subject: [PATCH] [Fix] register_checkpoint_saver API was not stable under TF version 2.11. Therefore, for compatibility considerations, before version 2.11, SingleDeviceSaver was used to save. --- .../dynamic_embedding/python/ops/tf_save_restore_patch.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py index d0c167af..5bcfe12a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py @@ -18,6 +18,7 @@ import inspect import functools import os.path +from packaging import version import re from tensorflow_recommenders_addons import dynamic_embedding as de @@ -54,6 +55,7 @@ from tensorflow.python.training.saving import functional_saver from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow import version as tf_version tf_original_save_func = tf_saved_model_save.save if keras_saved_model_save is not None: @@ -566,7 +568,9 @@ def restore(self, sess, save_path): def patch_on_tf_save_restore(): - try: + if version.parse(tf_version.VERSION) < version.parse("2.11"): + functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver + else: from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver class_obj = de.Variable predicate = lambda x: isinstance(x, class_obj) @@ -584,8 +588,6 @@ def patch_on_tf_save_restore(): k_name = param.name kwargs[k_name] = prekwargs[k_name] register_checkpoint_saver(**kwargs) - except: - functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver saver.Saver = _DynamicEmbeddingSaver # # Replace origin saving function is too dangerous. # tf_saved_model_save.save = functools.partial(de.keras.models._de_keras_save_func,