[feat] Make saving a model easier when using HvdAllToAllEmbedding by adding a save-function overwriting patch in tf_save_restore_patch.py. #362

Merged
merged 1 commit into tensorflow:master on Sep 25, 2023

Conversation

@MoFHeka (Collaborator) commented Sep 21, 2023

Make saving a model easier when using HvdAllToAllEmbedding by adding a save-function overwriting patch in tf_save_restore_patch.py.
Also fix some import bugs in tf_save_restore_patch.py.
Also fix the example in the demo where the Python code for Keras Horovod synchronous training was wrong.

Description

I have overwritten the Keras save function, so it is no longer necessary to save the embedding shard explicitly: it is enough to call model.save or tf.keras.models.save_model on each rank. tf.saved_model.save is not supported.
tf.saved_model.save could be supported in principle too, but because the object passed to save is not necessarily a Keras object, I have left it out for now; whether it is needed is open for discussion.
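
For context, a minimal usage sketch of the patched save path (not the PR's exact code). It assumes the patch takes effect once tensorflow_recommenders_addons.dynamic_embedding is imported; build_model and the save path are hypothetical placeholders.

import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_recommenders_addons.dynamic_embedding as de  # importing applies the patch (assumption)

hvd.init()
# Hypothetical helper: a Keras model containing a de.keras.layers.HvdAllToAllEmbedding.
model = build_model()

# With the patch, every rank calls the ordinary Keras save; each rank's
# embedding shard is written out internally, so no manual shard saving is needed.
model.save("/tmp/hvd_model")                         # supported
tf.keras.models.save_model(model, "/tmp/hvd_model")  # also supported
# tf.saved_model.save(model, "/tmp/hvd_model")       # NOT covered by this patch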

Type of change

  • Bug fix
  • New Tutorial
  • Updated or additional documentation
  • Additional Testing
  • New Feature

Checklist:

  • I've properly formatted my code according to the guidelines
    • By running yapf
    • By running clang-format
  • This PR addresses an already submitted issue for TensorFlow Recommenders-Addons
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works

How Has This Been Tested?

Added a test with HvdAllToAllEmbedding.
It follows the demo at demo/dynamic_embedding/movielens-1m-keras-with-horovod.
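
Since this PR also fixes the demo's Keras Horovod synchronous-training code, here is the standard pattern such a demo follows (a hedged sketch, not the demo's exact code; build_model and dataset are placeholders):

import horovod.tensorflow.keras as hvd
import tensorflow as tf

hvd.init()

# Pin each worker process to one GPU (standard Horovod setup).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  tf.config.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

model = build_model()  # hypothetical helper
opt = tf.keras.optimizers.Adam(1e-3 * hvd.size())  # scale LR with world size
opt = hvd.DistributedOptimizer(opt)  # averages gradients across ranks
model.compile(loss='binary_crossentropy', optimizer=opt)

callbacks = [
  # Broadcast initial dense variables from rank 0 so all ranks start equal.
  # Sharded dynamic-embedding weights are rank-local and must not be broadcast.
  hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),
]
model.fit(dataset, epochs=1, callbacks=callbacks,
          verbose=1 if hvd.rank() == 0 else 0)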

Comment on lines 561 to 585
try:
  import horovod.tensorflow as hvd
  try:
    hvd.rank()
  except:
    hvd = None
except:
Member

  try:
    import horovod.tensorflow as hvd
    hvd.rank()
  except:
    hvd = None
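
Presumably the intent of the flattened handler: hvd.rank() raises when hvd.init() has not been called, so both failure modes (Horovod not installed, Horovod not initialized) fall through to the same hvd = None fallback.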

Collaborator Author

Fixed

model,
filepath,
overwrite,
include_optimizer,
Member

It would be better to add comments for each important input argument.

Collaborator Author

Fixed

*args,
**kwargs)

def _traverse_emb_layers_and_save(hvd_rank):
Member

Do we have adequate UT cases to cover this function?

Collaborator Author

Done
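
For readers following along, a hypothetical sketch of what such a per-rank traverse-and-save helper could look like. This is not the PR's actual implementation; the layer class path and params.save_to_file_system are assumptions about the TFRA API.

import tensorflow_recommenders_addons.dynamic_embedding as de

def _traverse_emb_layers_and_save_sketch(model, de_dir, hvd_rank, hvd_size):
  # Walk the Keras model and, for every all-to-all dynamic-embedding
  # layer found, save the shard owned by this Horovod rank.
  for layer in model.layers:
    if isinstance(layer, de.keras.layers.HvdAllToAllEmbedding):  # assumed class path
      # Hypothetical call: each rank writes only the keys/values it owns.
      layer.params.save_to_file_system(dirpath=de_dir,
                                       proc_size=hvd_size,
                                       proc_rank=hvd_rank)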

@@ -106,6 +107,10 @@ def __init__(self, root_rank=0, device='', local_variables=None):
self.register_local_var(var)


@deprecated(
Member

Is the warning always triggered? It's recommended to show it only when users actually use the AllToAllEmbedding.

Collaborator Author @MoFHeka commented Sep 22, 2023

It only shows when the class's __init__ or __new__ is called; see the TensorFlow deprecation wrapper quoted below.
This callback class was designed only for saving with the Horovod all-to-all embedding, and with the new saving patch function it is no longer needed.

def deprecated_wrapper(func_or_class):
    """Deprecation wrapper."""
    if isinstance(func_or_class, type):
      # If a class is deprecated, you actually want to wrap the constructor.
      cls = func_or_class
      if cls.__new__ is object.__new__:
        # If a class defaults to its parent's constructor, wrap that instead.
        func = cls.__init__
        constructor_name = '__init__'
        decorators, _ = tf_decorator.unwrap(func)
        for decorator in decorators:
          if decorator.decorator_name == 'deprecated':
            # If the parent is already deprecated, there's nothing to do.
            return cls
      else:
        func = cls.__new__
        constructor_name = '__new__'

    else:
      cls = None
      constructor_name = None
      func = func_or_class
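
To illustrate why the warning fires only at construction time, a tiny hypothetical example (the class name is a stand-in, not the real callback):

from tensorflow.python.util.deprecation import deprecated

@deprecated(None, 'Use the Keras save patch in tf_save_restore_patch.py instead.')
class LegacyA2ASaveCallback:  # hypothetical stand-in for the deprecated callback
  def __init__(self):
    pass

# Importing and decorating log nothing; the deprecation warning is
# emitted only when the wrapped __init__ runs:
cb = LegacyA2ASaveCallback()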

…by adding save function overwriting patch in tf_save_restore_patch.py.

Also fix some import bugs in tf_save_restore_patch.py.
Also add a save and restore test for HvdAllToAllEmbedding.
@rhdong (Member) left a comment

LGTM

@rhdong rhdong merged commit d774172 into tensorflow:master Sep 25, 2023
33 checks passed