Skip to content

Commit

Permalink
[fix] When result of importing horovod is None, TFRA DEHvdCheckpoint …
Browse files Browse the repository at this point in the history
…would not call DE variable saving and sweeping redundant DE files.
  • Loading branch information
MoFHeka authored and rhdong committed Nov 26, 2023
1 parent 6f7bbb8 commit b078929
Showing 1 changed file with 14 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,20 @@ def _get_de_dir_from_file_path(file_path):
de_dir = self._get_de_variable_folder_dir(file_path, global_step)
return file_prefix_pattern, global_step, de_dir

def _rank0_delete_files_and_return_de_dir(file_path):
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
file_path)
if global_step is not None:
ckpt_index_list = file_io.get_matching_files(file_prefix_pattern +
'-*.index')
self._delete_redundant_de_dir(
ckpt_index_list
) # Compatible with automatic sweep function of checkpointmanager
return de_dir

if self._hvd is None:
file_path = tf_write_func()
de_dir = _rank0_delete_files_and_return_de_dir(file_path)
self._de_handle_root_and_var_with_func(de_dir=de_dir,
func=self._de_var_fs_save_funtion)
else:
Expand All @@ -189,14 +201,7 @@ def _get_de_dir_from_file_path(file_path):
self._hvd.broadcast_object(file_path,
root_rank=0,
name='de_hvd_broadcast_file_path')
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
file_path)
if global_step is not None:
ckpt_index_list = file_io.get_matching_files(file_prefix_pattern +
'-*.index')
self._delete_redundant_de_dir(
ckpt_index_list
) # Compatible with automatic sweep function of checkpointmanager
de_dir = _rank0_delete_files_and_return_de_dir(file_path)
self._hvd.join() # Sync for avoiding files conflict
self._de_handle_root_and_var_with_func(
de_dir=de_dir, func=self._de_var_fs_save_funtion)
Expand All @@ -205,8 +210,7 @@ def _get_de_dir_from_file_path(file_path):
else:
file_path = self._hvd.broadcast_object(
None, root_rank=0, name='de_hvd_broadcast_file_path')
file_prefix_pattern, global_step, de_dir = _get_de_dir_from_file_path(
file_path)
_, _, de_dir = _get_de_dir_from_file_path(file_path)
self._hvd.join() # Sync for avoiding files conflict
self._de_handle_root_and_var_with_func(
de_dir=de_dir, func=self._de_var_fs_save_funtion)
Expand Down

0 comments on commit b078929

Please sign in to comment.