upgrade to tf 2.16.2
jq committed Aug 13, 2024
1 parent 77efd5a commit 5aad31a
Showing 3 changed files with 126 additions and 114 deletions.
@@ -19,8 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#if TF_VERSION_INTEGER >= 2160
#include "unsupported/Eigen/CXX11/Tensor"
#else
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#endif
namespace tensorflow {

class OpKernelContext;
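
The conditional above selects the Eigen include path based on TF_VERSION_INTEGER, so the same source builds against the header layouts used before and after TF 2.16. As a hedged sketch only (the helper and the encoding below are assumptions, not this repository's actual configure logic), such an integer could be derived from the installed TensorFlow and passed as a compile-time define:

    # Hedged sketch: derive a TF_VERSION_INTEGER-style value and emit it as a
    # compiler define. The encoding (major*1000 + minor*10 + patch) is assumed,
    # chosen only to be consistent with the ">= 2160" threshold above.
    import tensorflow as tf

    major, minor, patch = (int(p) for p in tf.__version__.split(".")[:3])
    tf_version_integer = major * 1000 + minor * 10 + patch  # 2.16.2 -> 2162
    print(f"--copt=-DTF_VERSION_INTEGER={tf_version_integer}")
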
@@ -41,9 +41,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
#if TF_VERSION_INTEGER >= 2160
#include "unsupported/Eigen/CXX11/Tensor"
#include "Eigen/Core"
#else
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#endif
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -80,6 +84,7 @@ class SparseSegmentSumGpuOp : public AsyncOpKernel {
explicit SparseSegmentSumGpuOp(OpKernelConstruction* context)
: AsyncOpKernel(context){};


void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
const Tensor& input_data = context->input(0);
const Tensor& indices = context->input(1);
@@ -113,123 +113,127 @@ def _de_var_fs_restore_fn(trackables, merged_prefix):
return load_ops.as_list()


class _DynamicEmbeddingSingleDeviceSaver(functional_saver._SingleDeviceSaver):

def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
tensor_names = []
tensors = []
tensor_slices = []
save_ops = tf_utils.ListWrapper([])
variables_folder_dir = string_ops.regex_replace(file_prefix,
pattern='/([^/]*)/([^/]*)$',
rewrite='')
for saveable in self._saveable_objects:
if type(saveable).__name__ in de_fs_sub_saveable_class_names:
if saveable._saver_config.save_path:
de_variable_folder_dir = saveable._saver_config.save_path
else:
de_variable_folder_dir = string_ops.string_join(
[variables_folder_dir, 'TFRADynamicEmbedding'], separator='/')

# Rewrite the saved file name with user-specified node information when using multi-process distributed training such as Horovod.
# Because table shards in different processes cannot reach each other, all original shard names would be '_mht_1of1'.
save_file_name = re.sub(
r'_mht_([^/]*)of([^/]*)',
'_mht_' + str(saveable.local_shard_idx + 1) + 'of' +
str(saveable.local_shard_num) + '_rank' + str(saveable.proc_rank) +
'_size' + str(saveable.proc_size), saveable.op._name)
_DynamicEmbeddingShardSaveable_save_op = saveable.op.save_to_file_system(
de_variable_folder_dir,
file_name=save_file_name,
buffer_size=saveable._saver_config.buffer_size)
save_ops.as_list().append(_DynamicEmbeddingShardSaveable_save_op)
for spec in saveable.specs:
tensor = spec.tensor
# A tensor value of `None` indicates that this SaveableObject gets
# recorded in the object graph, but that no value is saved in the
# checkpoint.
if tensor is not None:
tensor_names.append(spec.name)
tensors.append(tensor)
tensor_slices.append(spec.slice_spec)
save_device = options.experimental_io_device or "cpu:0"
with ops.device(save_device):
tf_save_op = io_ops.save_v2(file_prefix, tensor_names, tensor_slices,
tensors)
save_ops.as_list().append(tf_save_op)
return control_flow_ops.group(save_ops.as_list())

def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
A dictionary mapping from SaveableObject names to restore operations.
"""
options = options or checkpoint_options.CheckpointOptions()
restore_specs = []
tensor_structure = []
restore_ops = {}
variables_folder_dir = string_ops.regex_replace(file_prefix,
pattern='/([^/]*)$',
rewrite='')

for saveable in self._saveable_objects:
saveable_class_name = type(saveable).__name__
if saveable_class_name == '_DynamicEmbeddingVariabelFileSystemSaveable':
with ops.name_scope(saveable._restore_name,
"dynamic_embedding_restore"):
try: # tf version <= 2.15

class _DynamicEmbeddingSingleDeviceSaver(functional_saver._SingleDeviceSaver):

def save(self, file_prefix, options=None):
"""Save the saveable objects to a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix to
save under.
options: Optional `CheckpointOptions` object.
Returns:
An `Operation`, or None when executing eagerly.
"""
options = options or checkpoint_options.CheckpointOptions()
tensor_names = []
tensors = []
tensor_slices = []
save_ops = tf_utils.ListWrapper([])
variables_folder_dir = string_ops.regex_replace(
file_prefix, pattern='/([^/]*)/([^/]*)$', rewrite='')
for saveable in self._saveable_objects:
if type(saveable).__name__ in de_fs_sub_saveable_class_names:
if saveable._saver_config.save_path:
de_variable_folder_dir = saveable._saver_config.save_path
else:
de_variable_folder_dir = string_ops.string_join(
[variables_folder_dir, 'TFRADynamicEmbedding'], separator='/')
restore_ops[saveable.name] = load_de_variable_from_file_system(
saveable.op, de_variable_folder_dir, saveable.proc_size,
saveable.proc_rank, saveable._saver_config.buffer_size)

_unified_restore_saveable_objects = []
for saveable in self._saveable_objects:
_unified_restore_saveable_objects.append(saveable)
saveable_tensor_structure = []
tensor_structure.append(saveable_tensor_structure)
for spec in saveable.specs:
saveable_tensor_structure.append(spec.name)
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
restore_device = options.experimental_io_device or "cpu:0"
with ops.device(restore_device):
restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
tensor_slices, tensor_dtypes)
structured_restored_tensors = nest.pack_sequence_as(tensor_structure,
restored_tensors)
for saveable, restored_tensors in zip(_unified_restore_saveable_objects,
structured_restored_tensors):
saveable_class_name = type(saveable).__name__
if (saveable_class_name not in de_fs_saveable_class_names) and (
saveable_class_name not in de_fs_sub_saveable_class_names):
restore_ops[saveable.name] = saveable.restore(restored_tensors,
restored_shapes=None)
elif (saveable_class_name in de_fs_saveable_class_names):
restore_ops[saveable.name] = control_flow_ops.group([
saveable.restore(restored_tensors, restored_shapes=None),
restore_ops[saveable.name]
])
return restore_ops

# Rewrite the saved file name with user-specified node information when using multi-process distributed training such as Horovod.
# Because table shards in different processes cannot reach each other, all original shard names would be '_mht_1of1'.
save_file_name = re.sub(
r'_mht_([^/]*)of([^/]*)',
'_mht_' + str(saveable.local_shard_idx + 1) + 'of' +
str(saveable.local_shard_num) + '_rank' +
str(saveable.proc_rank) + '_size' + str(saveable.proc_size),
saveable.op._name)
_DynamicEmbeddingShardSaveable_save_op = saveable.op.save_to_file_system(
de_variable_folder_dir,
file_name=save_file_name,
buffer_size=saveable._saver_config.buffer_size)
save_ops.as_list().append(_DynamicEmbeddingShardSaveable_save_op)
for spec in saveable.specs:
tensor = spec.tensor
# A tensor value of `None` indicates that this SaveableObject gets
# recorded in the object graph, but that no value is saved in the
# checkpoint.
if tensor is not None:
tensor_names.append(spec.name)
tensors.append(tensor)
tensor_slices.append(spec.slice_spec)
save_device = options.experimental_io_device or "cpu:0"
with ops.device(save_device):
tf_save_op = io_ops.save_v2(file_prefix, tensor_names, tensor_slices,
tensors)
save_ops.as_list().append(tf_save_op)
return control_flow_ops.group(save_ops.as_list())
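
To make the shard renaming above concrete, here is a small standalone illustration; the op name and the shard/rank/size values are made up for this example and are not taken from the repository:

    # Hedged illustration of the rename with assumed values:
    # local_shard_idx=0, local_shard_num=1, proc_rank=0, proc_size=2.
    import re

    op_name = "embeddings/dynamic_embedding_mht_1of1"
    renamed = re.sub(
        r"_mht_([^/]*)of([^/]*)",
        "_mht_" + str(0 + 1) + "of" + str(1) + "_rank" + str(0) + "_size" + str(2),
        op_name)
    print(renamed)  # embeddings/dynamic_embedding_mht_1of1_rank0_size2
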

def restore(self, file_prefix, options=None):
"""Restore the saveable objects from a checkpoint with `file_prefix`.
Args:
file_prefix: A string or scalar string Tensor containing the prefix for
files to read from.
options: Optional `CheckpointOptions` object.
Returns:
A dictionary mapping from SaveableObject names to restore operations.
"""
options = options or checkpoint_options.CheckpointOptions()
restore_specs = []
tensor_structure = []
restore_ops = {}
variables_folder_dir = string_ops.regex_replace(file_prefix,
pattern='/([^/]*)$',
rewrite='')

for saveable in self._saveable_objects:
saveable_class_name = type(saveable).__name__
if saveable_class_name == '_DynamicEmbeddingVariabelFileSystemSaveable':
with ops.name_scope(saveable._restore_name,
"dynamic_embedding_restore"):
if saveable._saver_config.save_path:
de_variable_folder_dir = saveable._saver_config.save_path
else:
de_variable_folder_dir = string_ops.string_join(
[variables_folder_dir, 'TFRADynamicEmbedding'], separator='/')
restore_ops[saveable.name] = load_de_variable_from_file_system(
saveable.op, de_variable_folder_dir, saveable.proc_size,
saveable.proc_rank, saveable._saver_config.buffer_size)

_unified_restore_saveable_objects = []
for saveable in self._saveable_objects:
_unified_restore_saveable_objects.append(saveable)
saveable_tensor_structure = []
tensor_structure.append(saveable_tensor_structure)
for spec in saveable.specs:
saveable_tensor_structure.append(spec.name)
restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs)
restore_device = options.experimental_io_device or "cpu:0"
with ops.device(restore_device):
restored_tensors = io_ops.restore_v2(file_prefix, tensor_names,
tensor_slices, tensor_dtypes)
structured_restored_tensors = nest.pack_sequence_as(
tensor_structure, restored_tensors)
for saveable, restored_tensors in zip(_unified_restore_saveable_objects,
structured_restored_tensors):
saveable_class_name = type(saveable).__name__
if (saveable_class_name not in de_fs_saveable_class_names) and (
saveable_class_name not in de_fs_sub_saveable_class_names):
restore_ops[saveable.name] = saveable.restore(restored_tensors,
restored_shapes=None)
elif (saveable_class_name in de_fs_saveable_class_names):
restore_ops[saveable.name] = control_flow_ops.group([
saveable.restore(restored_tensors, restored_shapes=None),
restore_ops[saveable.name]
])
return restore_ops
except:
print(" _SingleDeviceSaver removed after tf version 2.15")


class _DynamicEmbeddingSaver(saver.Saver):
