diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops.h index 458afb737..000d9d8cb 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops.h @@ -19,8 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" +#if TF_VERSION_INTEGER >= 2160 +#include "unsupported/Eigen/CXX11/Tensor" +#else #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - +#endif namespace tensorflow { class OpKernelContext; diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h index b8e58b49e..b7c1799e4 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h @@ -41,9 +41,13 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/util.h" +#if TF_VERSION_INTEGER >= 2160 +#include "unsupported/Eigen/CXX11/Tensor" +#include "Eigen/Core" +#else #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - +#endif #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -80,6 +84,7 @@ class SparseSegmentSumGpuOp : public AsyncOpKernel { explicit SparseSegmentSumGpuOp(OpKernelConstruction* context) : AsyncOpKernel(context){}; + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { const Tensor& input_data = context->input(0); const Tensor& indices = context->input(1); diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py index d98404222..ccd063ebb 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py @@ -113,123 +113,127 @@ def _de_var_fs_restore_fn(trackables, merged_prefix): return load_ops.as_list() -class _DynamicEmbeddingSingleDeviceSaver(functional_saver._SingleDeviceSaver): - - def save(self, file_prefix, options=None): - """Save the saveable objects to a checkpoint with `file_prefix`. - - Args: - file_prefix: A string or scalar string Tensor containing the prefix to - save under. - options: Optional `CheckpointOptions` object. - Returns: - An `Operation`, or None when executing eagerly. - """ - options = options or checkpoint_options.CheckpointOptions() - tensor_names = [] - tensors = [] - tensor_slices = [] - save_ops = tf_utils.ListWrapper([]) - variables_folder_dir = string_ops.regex_replace(file_prefix, - pattern='/([^/]*)/([^/]*)$', - rewrite='') - for saveable in self._saveable_objects: - if type(saveable).__name__ in de_fs_sub_saveable_class_names: - if saveable._saver_config.save_path: - de_variable_folder_dir = saveable._saver_config.save_path - else: - de_variable_folder_dir = string_ops.string_join( - [variables_folder_dir, 'TFRADynamicEmbedding'], separator='/') - - # Rewrite saved file name by user specified node information when use multi process distributed training such as horovod. - # Because table shards in different process couldn't touch each other, all origin shards name would be '_mht_1of1'. - save_file_name = re.sub( - r'_mht_([^/]*)of([^/]*)', - '_mht_' + str(saveable.local_shard_idx + 1) + 'of' + - str(saveable.local_shard_num) + '_rank' + str(saveable.proc_rank) + - '_size' + str(saveable.proc_size), saveable.op._name) - _DynamicEmbeddingShardSaveable_save_op = saveable.op.save_to_file_system( - de_variable_folder_dir, - file_name=save_file_name, - buffer_size=saveable._saver_config.buffer_size) - save_ops.as_list().append(_DynamicEmbeddingShardSaveable_save_op) - for spec in saveable.specs: - tensor = spec.tensor - # A tensor value of `None` indicates that this SaveableObject gets - # recorded in the object graph, but that no value is saved in the - # checkpoint. - if tensor is not None: - tensor_names.append(spec.name) - tensors.append(tensor) - tensor_slices.append(spec.slice_spec) - save_device = options.experimental_io_device or "cpu:0" - with ops.device(save_device): - tf_save_op = io_ops.save_v2(file_prefix, tensor_names, tensor_slices, - tensors) - save_ops.as_list().append(tf_save_op) - return control_flow_ops.group(save_ops.as_list()) - - def restore(self, file_prefix, options=None): - """Restore the saveable objects from a checkpoint with `file_prefix`. - - Args: - file_prefix: A string or scalar string Tensor containing the prefix for - files to read from. - options: Optional `CheckpointOptions` object. - - Returns: - A dictionary mapping from SaveableObject names to restore operations. - """ - options = options or checkpoint_options.CheckpointOptions() - restore_specs = [] - tensor_structure = [] - restore_ops = {} - variables_folder_dir = string_ops.regex_replace(file_prefix, - pattern='/([^/]*)$', - rewrite='') - - for saveable in self._saveable_objects: - saveable_class_name = type(saveable).__name__ - if saveable_class_name == '_DynamicEmbeddingVariabelFileSystemSaveable': - with ops.name_scope(saveable._restore_name, - "dynamic_embedding_restore"): +try: # tf version <= 2.15 + + class _DynamicEmbeddingSingleDeviceSaver(functional_saver._SingleDeviceSaver): + + def save(self, file_prefix, options=None): + """Save the saveable objects to a checkpoint with `file_prefix`. + + Args: + file_prefix: A string or scalar string Tensor containing the prefix to + save under. + options: Optional `CheckpointOptions` object. + Returns: + An `Operation`, or None when executing eagerly. + """ + options = options or checkpoint_options.CheckpointOptions() + tensor_names = [] + tensors = [] + tensor_slices = [] + save_ops = tf_utils.ListWrapper([]) + variables_folder_dir = string_ops.regex_replace( + file_prefix, pattern='/([^/]*)/([^/]*)$', rewrite='') + for saveable in self._saveable_objects: + if type(saveable).__name__ in de_fs_sub_saveable_class_names: if saveable._saver_config.save_path: de_variable_folder_dir = saveable._saver_config.save_path else: de_variable_folder_dir = string_ops.string_join( [variables_folder_dir, 'TFRADynamicEmbedding'], separator='/') - restore_ops[saveable.name] = load_de_variable_from_file_system( - saveable.op, de_variable_folder_dir, saveable.proc_size, - saveable.proc_rank, saveable._saver_config.buffer_size) - - _unified_restore_saveable_objects = [] - for saveable in self._saveable_objects: - _unified_restore_saveable_objects.append(saveable) - saveable_tensor_structure = [] - tensor_structure.append(saveable_tensor_structure) - for spec in saveable.specs: - saveable_tensor_structure.append(spec.name) - restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) - tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) - restore_device = options.experimental_io_device or "cpu:0" - with ops.device(restore_device): - restored_tensors = io_ops.restore_v2(file_prefix, tensor_names, - tensor_slices, tensor_dtypes) - structured_restored_tensors = nest.pack_sequence_as(tensor_structure, - restored_tensors) - for saveable, restored_tensors in zip(_unified_restore_saveable_objects, - structured_restored_tensors): - saveable_class_name = type(saveable).__name__ - if (saveable_class_name not in de_fs_saveable_class_names) and ( - saveable_class_name not in de_fs_sub_saveable_class_names): - restore_ops[saveable.name] = saveable.restore(restored_tensors, - restored_shapes=None) - elif (saveable_class_name in de_fs_saveable_class_names): - restore_ops[saveable.name] = control_flow_ops.group([ - saveable.restore(restored_tensors, restored_shapes=None), - restore_ops[saveable.name] - ]) - return restore_ops + + # Rewrite saved file name by user specified node information when use multi process distributed training such as horovod. + # Because table shards in different process couldn't touch each other, all origin shards name would be '_mht_1of1'. + save_file_name = re.sub( + r'_mht_([^/]*)of([^/]*)', + '_mht_' + str(saveable.local_shard_idx + 1) + 'of' + + str(saveable.local_shard_num) + '_rank' + + str(saveable.proc_rank) + '_size' + str(saveable.proc_size), + saveable.op._name) + _DynamicEmbeddingShardSaveable_save_op = saveable.op.save_to_file_system( + de_variable_folder_dir, + file_name=save_file_name, + buffer_size=saveable._saver_config.buffer_size) + save_ops.as_list().append(_DynamicEmbeddingShardSaveable_save_op) + for spec in saveable.specs: + tensor = spec.tensor + # A tensor value of `None` indicates that this SaveableObject gets + # recorded in the object graph, but that no value is saved in the + # checkpoint. + if tensor is not None: + tensor_names.append(spec.name) + tensors.append(tensor) + tensor_slices.append(spec.slice_spec) + save_device = options.experimental_io_device or "cpu:0" + with ops.device(save_device): + tf_save_op = io_ops.save_v2(file_prefix, tensor_names, tensor_slices, + tensors) + save_ops.as_list().append(tf_save_op) + return control_flow_ops.group(save_ops.as_list()) + + def restore(self, file_prefix, options=None): + """Restore the saveable objects from a checkpoint with `file_prefix`. + + Args: + file_prefix: A string or scalar string Tensor containing the prefix for + files to read from. + options: Optional `CheckpointOptions` object. + + Returns: + A dictionary mapping from SaveableObject names to restore operations. + """ + options = options or checkpoint_options.CheckpointOptions() + restore_specs = [] + tensor_structure = [] + restore_ops = {} + variables_folder_dir = string_ops.regex_replace(file_prefix, + pattern='/([^/]*)$', + rewrite='') + + for saveable in self._saveable_objects: + saveable_class_name = type(saveable).__name__ + if saveable_class_name == '_DynamicEmbeddingVariabelFileSystemSaveable': + with ops.name_scope(saveable._restore_name, + "dynamic_embedding_restore"): + if saveable._saver_config.save_path: + de_variable_folder_dir = saveable._saver_config.save_path + else: + de_variable_folder_dir = string_ops.string_join( + [variables_folder_dir, 'TFRADynamicEmbedding'], separator='/') + restore_ops[saveable.name] = load_de_variable_from_file_system( + saveable.op, de_variable_folder_dir, saveable.proc_size, + saveable.proc_rank, saveable._saver_config.buffer_size) + + _unified_restore_saveable_objects = [] + for saveable in self._saveable_objects: + _unified_restore_saveable_objects.append(saveable) + saveable_tensor_structure = [] + tensor_structure.append(saveable_tensor_structure) + for spec in saveable.specs: + saveable_tensor_structure.append(spec.name) + restore_specs.append((spec.name, spec.slice_spec, spec.dtype)) + tensor_names, tensor_slices, tensor_dtypes = zip(*restore_specs) + restore_device = options.experimental_io_device or "cpu:0" + with ops.device(restore_device): + restored_tensors = io_ops.restore_v2(file_prefix, tensor_names, + tensor_slices, tensor_dtypes) + structured_restored_tensors = nest.pack_sequence_as( + tensor_structure, restored_tensors) + for saveable, restored_tensors in zip(_unified_restore_saveable_objects, + structured_restored_tensors): + saveable_class_name = type(saveable).__name__ + if (saveable_class_name not in de_fs_saveable_class_names) and ( + saveable_class_name not in de_fs_sub_saveable_class_names): + restore_ops[saveable.name] = saveable.restore(restored_tensors, + restored_shapes=None) + elif (saveable_class_name in de_fs_saveable_class_names): + restore_ops[saveable.name] = control_flow_ops.group([ + saveable.restore(restored_tensors, restored_shapes=None), + restore_ops[saveable.name] + ]) + return restore_ops +except: + print(" _SingleDeviceSaver removed after tf version 2.15") class _DynamicEmbeddingSaver(saver.Saver):