diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py b/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py index d32188d3f..7a6b73094 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/train/checkpoint.py @@ -178,8 +178,20 @@ def _get_de_dir_from_file_path(file_path): de_dir = self._get_de_variable_folder_dir(file_path, global_step) return file_prefix_pattern, global_step, de_dir + def _rank0_delete_files_and_return_de_dir(file_path): + file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path( + file_path) + if global_step is not None: + ckpt_index_list = file_io.get_matching_files(file_prefix_pattern + + '-*.index') + self._delete_redundant_de_dir( + ckpt_index_list + ) # Compatible with automatic sweep function of checkpointmanager + return de_dir + if self._hvd is None: file_path = tf_write_func() + de_dir = _rank0_delete_files_and_return_de_dir(file_path) self._de_handle_root_and_var_with_func(de_dir=de_dir, func=self._de_var_fs_save_funtion) else: @@ -189,14 +201,7 @@ def _get_de_dir_from_file_path(file_path): self._hvd.broadcast_object(file_path, root_rank=0, name='de_hvd_broadcast_file_path') - file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path( - file_path) - if global_step is not None: - ckpt_index_list = file_io.get_matching_files(file_prefix_pattern + - '-*.index') - self._delete_redundant_de_dir( - ckpt_index_list - ) # Compatible with automatic sweep function of checkpointmanager + de_dir = _rank0_delete_files_and_return_de_dir(file_path) self._hvd.join() # Sync for avoiding files conflict self._de_handle_root_and_var_with_func( de_dir=de_dir, func=self._de_var_fs_save_funtion) @@ -205,8 +210,7 @@ def _get_de_dir_from_file_path(file_path): else: file_path = self._hvd.broadcast_object( None, root_rank=0, name='de_hvd_broadcast_file_path') - file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path( - file_path) + _, _, de_dir = _get_de_dir_from_file_path(file_path) self._hvd.join() # Sync for avoiding files conflict self._de_handle_root_and_var_with_func( de_dir=de_dir, func=self._de_var_fs_save_funtion)