From 06b5f3dc108db83eb69c9f75b6ae2946e1f78789 Mon Sep 17 00:00:00 2001
From: Julian Qian
Date: Wed, 14 Aug 2024 20:38:09 -0700
Subject: [PATCH] remove estimator for tf 2.16

---
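Note (not part of the commit itself): the hunks below wrap every
tensorflow_estimator import and estimator-only test in try/except, because
the estimator package no longer ships with TensorFlow >= 2.16. A minimal
sketch of the same guard pattern for downstream code follows; the
HAS_ESTIMATOR flag and the skip helper are illustrative names, not part of
this patch:

    try:  # tensorflow_estimator is only bundled with TF <= 2.15
      from tensorflow_estimator.python.estimator import estimator as tf_estimator
      HAS_ESTIMATOR = True
    except ImportError:
      tf_estimator = None
      HAS_ESTIMATOR = False

    def skip_if_no_estimator(test_case):
      # Skip estimator-dependent tests when running against TF >= 2.16.
      if not HAS_ESTIMATOR:
        test_case.skipTest("tensorflow_estimator is unavailable on TF >= 2.16")
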
 .../core/kernels/segment_reduction_ops_impl.h |  3 +-
 .../dynamic_embedding_variable_test.py        | 95 ++++++++++---------
 .../kernel_tests/warm_start_util_test.py      | 30 ++----
 .../python/ops/dynamic_embedding_optimizer.py | 53 +++++++++++
 4 files changed, 114 insertions(+), 67 deletions(-)

diff --git a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h
index b7c1799e4..638bf5dfb 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h
+++ b/tensorflow_recommenders_addons/dynamic_embedding/core/kernels/segment_reduction_ops_impl.h
@@ -42,8 +42,8 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/util/util.h"
 #if TF_VERSION_INTEGER >= 2160
-#include "unsupported/Eigen/CXX11/Tensor"
 #include "Eigen/Core"
+#include "unsupported/Eigen/CXX11/Tensor"
 #else
 #include "third_party/eigen3/Eigen/Core"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -84,7 +84,6 @@ class SparseSegmentSumGpuOp : public AsyncOpKernel {
  public:
   explicit SparseSegmentSumGpuOp(OpKernelConstruction* context)
       : AsyncOpKernel(context){};
-
   void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
     const Tensor& input_data = context->input(0);
     const Tensor& indices = context->input(1);
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py
index 950d56adf..855738938 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/dynamic_embedding_variable_test.py
@@ -62,8 +62,13 @@
 except:
   from tensorflow.python.training.tracking.util import Checkpoint
 from tensorflow.python.util import compat
-from tensorflow_estimator.python.estimator import estimator
-from tensorflow_estimator.python.estimator import estimator_lib
+
+try:  # tf version <= 2.15
+  from tensorflow_estimator.python.estimator import estimator
+  from tensorflow_estimator.python.estimator import estimator_lib
+except:
+  # do nothing
+  pass
 
 try:  # The data_structures has been moved to the new package in tf 2.11
   from tensorflow.python.trackable import data_structures
@@ -970,52 +975,56 @@ def test_table_save_load_local_file_system(self):
 
     del table
 
-  def test_table_save_load_local_file_system_for_estimator(self):
+  try:  # only test for tensorflow <= 2.15
 
-    def input_fn():
-      return {"x": constant_op.constant([1], dtype=dtypes.int64)}
+    def test_table_save_load_local_file_system_for_estimator(self):
 
-    def model_fn(features, labels, mode, params):
-      file_system_saver = de.FileSystemSaver()
-      embedding = de.get_variable(
-          name="embedding",
-          dim=3,
-          trainable=False,
-          key_dtype=dtypes.int64,
-          value_dtype=dtypes.float32,
-          initializer=-1.0,
-          kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
-      )
-      lookup = de.embedding_lookup(embedding, features["x"])
-      upsert = embedding.upsert(features["x"],
-                                constant_op.constant([[1.0, 2.0, 3.0]]))
-
-      with ops.control_dependencies([lookup, upsert]):
-        train_op = state_ops.assign_add(training.get_global_step(), 1)
-
-      scaffold = training.Scaffold(
-          saver=saver.Saver(sharded=True,
-                            max_to_keep=1,
-                            keep_checkpoint_every_n_hours=None,
-                            defer_build=True,
-                            save_relative_paths=True))
-      est = estimator_lib.EstimatorSpec(mode=mode,
-                                        scaffold=scaffold,
-                                        loss=constant_op.constant(0.),
-                                        train_op=train_op,
-                                        predictions=lookup)
-      return est
+      def input_fn():
+        return {"x": constant_op.constant([1], dtype=dtypes.int64)}
 
-    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
-    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+      def model_fn(features, labels, mode, params):
+        file_system_saver = de.FileSystemSaver()
+        embedding = de.get_variable(
+            name="embedding",
+            dim=3,
+            trainable=False,
+            key_dtype=dtypes.int64,
+            value_dtype=dtypes.float32,
+            initializer=-1.0,
+            kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
+        )
+        lookup = de.embedding_lookup(embedding, features["x"])
+        upsert = embedding.upsert(features["x"],
+                                  constant_op.constant([[1.0, 2.0, 3.0]]))
+
+        with ops.control_dependencies([lookup, upsert]):
+          train_op = state_ops.assign_add(training.get_global_step(), 1)
+
+        scaffold = training.Scaffold(
+            saver=saver.Saver(sharded=True,
+                              max_to_keep=1,
+                              keep_checkpoint_every_n_hours=None,
+                              defer_build=True,
+                              save_relative_paths=True))
+        est = estimator_lib.EstimatorSpec(mode=mode,
+                                          scaffold=scaffold,
+                                          loss=constant_op.constant(0.),
+                                          train_op=train_op,
+                                          predictions=lookup)
+        return est
+
+      save_dir = os.path.join(self.get_temp_dir(), "save_restore")
+      save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
 
-    # train and save
-    est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
-    est.train(input_fn=input_fn, steps=1)
+      # train and save
+      est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
+      est.train(input_fn=input_fn, steps=1)
 
-    # restore and predict
-    predict_results = next(est.predict(input_fn=input_fn))
-    self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
+      # restore and predict
+      predict_results = next(est.predict(input_fn=input_fn))
+      self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
+  except:
+    pass
 
   def test_save_restore_only_table(self):
     if context.executing_eagerly():
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py
index 9e4094c3c..d4cfc9aec 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/warm_start_util_test.py
@@ -34,30 +34,12 @@
   kinit2 = None
   pass  # for compatible with TF < 2.3.x
 
-from tensorflow.core.protobuf import cluster_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras import initializers as keras_init_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import embedding_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resources
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
-from tensorflow.python.training import device_setter
 from tensorflow.python.training import saver
-from tensorflow.python.training import server_lib
-from tensorflow.python.util import compat
 from tensorflow_recommenders_addons import dynamic_embedding as de
 
 import tensorflow as tf
@@ -213,10 +195,14 @@ def test_warm_start_rename(self):
       self._test_warm_start_rename(num_shards, True)
       self._test_warm_start_rename(num_shards, False)
 
-  def test_warm_start_estimator(self):
-    for num_shards in [1, 3]:
-      self._test_warm_start_estimator(num_shards, True)
-      self._test_warm_start_estimator(num_shards, False)
+  try:  # tf version <= 2.15
+
+    def test_warm_start_estimator(self):
+      for num_shards in [1, 3]:
+        self._test_warm_start_estimator(num_shards, True)
+        self._test_warm_start_estimator(num_shards, False)
+  except:
+    print(f"estimator is not supported in this version of tensorflow")
 
 
 if __name__ == "__main__":
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
index f3100000d..b89eebaf7 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py
@@ -119,6 +119,10 @@ def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False, **kwargs):
   if hasattr(self, 'add_variable_from_reference'):
     original_add_variable_from_reference = self.add_variable_from_reference
 
+  def _minimize(loss, var_list, tape=None):
+    grads_and_vars = self.compute_gradients(loss, var_list, tape)
+    self.apply_gradients(grads_and_vars)
+
   # pylint: disable=protected-access
   def _distributed_apply(distribution, grads_and_vars, name, apply_state):
     """`apply_gradients` using a `DistributionStrategy`."""
@@ -292,6 +296,47 @@ def add_slot_v2_lagacy(var, slot_name, initializer="zeros", shape=None):
     self._weights.append(weight)
     return weight
 
+  def _distributed_tf_update_step(self, distribution, grads_and_vars,
+                                  learning_rate):
+
+    def apply_grad_to_update_var(var, grad, learning_rate):
+      if not isinstance(var, de.TrainableWrapper):
+        return self.update_step(grad, var, learning_rate)
+      else:
+        if not var.params.trainable:
+          return control_flow_ops.no_op()
+        with ops.colocate_with(None, ignore_existing=True):
+          _slots = [
+              _s for _s in self._variables
+              if isinstance(_s, de.TrainableWrapper)
+          ]
+          var._track_optimizer_slots(_slots)
+
+          with ops.control_dependencies([grad]):
+            if isinstance(var, de.shadow_ops.ShadowVariable):
+              v0 = var.read_value(do_prefetch=False)
+            else:
+              v0 = var.read_value(do_prefetch=var.params.bp_v2)
+            s0 = [_s.read_value() for _s in _slots]
+            _before = [v0] + s0
+
+          with ops.control_dependencies(_before):
+            self.update_step(grad, var, learning_rate)
+
+          with ops.control_dependencies([var]):
+            _after = control_flow_ops.group(
+                [var.update_op(v0=v0)] +
+                [_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)])
+          return _after
+
+    for grad, var in grads_and_vars:
+      distribution.extended.update(
+          var,
+          apply_grad_to_update_var,
+          args=(grad, learning_rate),
+          group=False,
+      )
+
   def _distributed_apply_gradients_fn(distribution, grads_and_vars, **kwargs):
     """`apply_gradients` using a `DistributionStrategy`."""
@@ -849,6 +894,14 @@ def compute_gradients_horovod_wrapper_impl(*args, **kwargs):
         self.apply_gradients = apply_gradients_sync_v2
       else:
         self.apply_gradients = apply_gradients_strategy_v2
+    elif hasattr(self, '_distributed_tf_update_step'):
+      # tf2 + keras3 optimizer
+      self._distributed_tf_update_step = _distributed_tf_update_step
+      self.minimize = _minimize
+      if self._custom_sync:
+        self.apply_gradients = apply_gradients_sync_v2
+      else:
+        self.apply_gradients = apply_gradients_strategy_v2
     else:
       raise Exception(f"Optimizer type is not supported! got {str(type(self))}")
   else:
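
Note (not part of the commit itself): with the `_distributed_tf_update_step`
branch added above, wrapping a Keras 3 (TF >= 2.16) optimizer stays the same
from the user's side. A minimal usage sketch; the optimizer choice and
learning rate are illustrative only:

    import tensorflow as tf
    from tensorflow_recommenders_addons import dynamic_embedding as de

    # DynamicEmbeddingOptimizer wraps a stock Keras optimizer; on TF >= 2.16
    # it hooks minimize()/_distributed_tf_update_step, the Keras 3 code path
    # added by this patch, instead of the optimizer_v2 hooks.
    optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1e-3))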