Skip to content

Commit

Permalink
[Fix] register_checkpoint_saver API was not stable under TF version 2…
Browse files Browse the repository at this point in the history
….11.

Therefore, for compatibility considerations, before version 2.11,
SingleDeviceSaver was used to save.
  • Loading branch information
MoFHeka committed Oct 16, 2024
1 parent cff0518 commit 205dedf
Showing 1 changed file with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import functools
import os.path
from packaging import version
import re

from tensorflow_recommenders_addons import dynamic_embedding as de
Expand Down Expand Up @@ -54,6 +55,7 @@
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow import version as tf_version

tf_original_save_func = tf_saved_model_save.save
if keras_saved_model_save is not None:
Expand Down Expand Up @@ -566,7 +568,9 @@ def restore(self, sess, save_path):


def patch_on_tf_save_restore():
try:
if version.parse(tf_version.VERSION) < version.parse("2.11"):
functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver
else:
from tensorflow.python.saved_model.registration.registration import register_checkpoint_saver
class_obj = de.Variable
predicate = lambda x: isinstance(x, class_obj)
Expand All @@ -584,8 +588,6 @@ def patch_on_tf_save_restore():
k_name = param.name
kwargs[k_name] = prekwargs[k_name]
register_checkpoint_saver(**kwargs)
except:
functional_saver._SingleDeviceSaver = _DynamicEmbeddingSingleDeviceSaver
saver.Saver = _DynamicEmbeddingSaver
# # Replace origin saving function is too dangerous.
# tf_saved_model_save.save = functools.partial(de.keras.models._de_keras_save_func,
Expand Down

0 comments on commit 205dedf

Please sign in to comment.