remove estimator for tf 2.16
jq committed Aug 19, 2024
1 parent 5aad31a commit 06b5f3d
Showing 4 changed files with 114 additions and 67 deletions.
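The Estimator API lives in the separate tensorflow_estimator package and is no longer available starting with TF 2.16, so every estimator-dependent code path below has to become optional. A version gate like the following could express the same cut-off explicitly; this is an illustrative sketch only (the HAS_ESTIMATOR name is made up), and the commit itself relies on try/except guards, as the diffs show.

import tensorflow as tf

# Illustrative helper, not part of the commit: estimator-based paths only make
# sense on TF <= 2.15, since tf.estimator / tensorflow_estimator are gone in 2.16.
TF_MAJOR, TF_MINOR = (int(x) for x in tf.__version__.split(".")[:2])
HAS_ESTIMATOR = (TF_MAJOR, TF_MINOR) <= (2, 15)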
File 1 of 4
@@ -42,8 +42,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/util.h"
#if TF_VERSION_INTEGER >= 2160
#include "unsupported/Eigen/CXX11/Tensor"
#include "Eigen/Core"
#include "unsupported/Eigen/CXX11/Tensor"
#else
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -84,7 +84,6 @@ class SparseSegmentSumGpuOp : public AsyncOpKernel {
explicit SparseSegmentSumGpuOp(OpKernelConstruction* context)
: AsyncOpKernel(context){};


void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
const Tensor& input_data = context->input(0);
const Tensor& indices = context->input(1);
File 2 of 4
@@ -62,8 +62,13 @@
except:
from tensorflow.python.training.tracking.util import Checkpoint
from tensorflow.python.util import compat
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator import estimator_lib

try: # tf version <= 2.15
from tensorflow_estimator.python.estimator import estimator
from tensorflow_estimator.python.estimator import estimator_lib
except:
# do nothing
pass

try: # The data_structures has been moved to the new package in tf 2.11
from tensorflow.python.trackable import data_structures
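The guarded import added above keeps the test module importable when tensorflow_estimator is absent (TF >= 2.16). A minimal, self-contained sketch of the same pattern with an explicit availability flag (the ESTIMATOR_AVAILABLE name is illustrative, not from the commit):

try:  # tensorflow_estimator is only shipped for TF <= 2.15
  from tensorflow_estimator.python.estimator import estimator
  from tensorflow_estimator.python.estimator import estimator_lib
  ESTIMATOR_AVAILABLE = True
except ImportError:
  estimator = None
  estimator_lib = None
  ESTIMATOR_AVAILABLE = False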
@@ -970,52 +975,56 @@ def test_table_save_load_local_file_system(self):

del table

def test_table_save_load_local_file_system_for_estimator(self):
try: # only test for tensorflow <= 2.15

def input_fn():
return {"x": constant_op.constant([1], dtype=dtypes.int64)}
def test_table_save_load_local_file_system_for_estimator(self):

def model_fn(features, labels, mode, params):
file_system_saver = de.FileSystemSaver()
embedding = de.get_variable(
name="embedding",
dim=3,
trainable=False,
key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
initializer=-1.0,
kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
)
lookup = de.embedding_lookup(embedding, features["x"])
upsert = embedding.upsert(features["x"],
constant_op.constant([[1.0, 2.0, 3.0]]))

with ops.control_dependencies([lookup, upsert]):
train_op = state_ops.assign_add(training.get_global_step(), 1)

scaffold = training.Scaffold(
saver=saver.Saver(sharded=True,
max_to_keep=1,
keep_checkpoint_every_n_hours=None,
defer_build=True,
save_relative_paths=True))
est = estimator_lib.EstimatorSpec(mode=mode,
scaffold=scaffold,
loss=constant_op.constant(0.),
train_op=train_op,
predictions=lookup)
return est
def input_fn():
return {"x": constant_op.constant([1], dtype=dtypes.int64)}

save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
def model_fn(features, labels, mode, params):
file_system_saver = de.FileSystemSaver()
embedding = de.get_variable(
name="embedding",
dim=3,
trainable=False,
key_dtype=dtypes.int64,
value_dtype=dtypes.float32,
initializer=-1.0,
kv_creator=de.CuckooHashTableCreator(saver=file_system_saver),
)
lookup = de.embedding_lookup(embedding, features["x"])
upsert = embedding.upsert(features["x"],
constant_op.constant([[1.0, 2.0, 3.0]]))

with ops.control_dependencies([lookup, upsert]):
train_op = state_ops.assign_add(training.get_global_step(), 1)

scaffold = training.Scaffold(
saver=saver.Saver(sharded=True,
max_to_keep=1,
keep_checkpoint_every_n_hours=None,
defer_build=True,
save_relative_paths=True))
est = estimator_lib.EstimatorSpec(mode=mode,
scaffold=scaffold,
loss=constant_op.constant(0.),
train_op=train_op,
predictions=lookup)
return est

save_dir = os.path.join(self.get_temp_dir(), "save_restore")
save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

# train and save
est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
est.train(input_fn=input_fn, steps=1)
# train and save
est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
est.train(input_fn=input_fn, steps=1)

# restore and predict
predict_results = next(est.predict(input_fn=input_fn))
self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
# restore and predict
predict_results = next(est.predict(input_fn=input_fn))
self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])
except:
pass

def test_save_restore_only_table(self):
if context.executing_eagerly():
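Because the unified view above interleaves the old method with its re-indented replacement, the resulting shape of the new code is easier to read in isolation. Roughly, with indentation reconstructed from the added lines and the model_fn body elided:

try:  # only test for tensorflow <= 2.15

  def test_table_save_load_local_file_system_for_estimator(self):

    def input_fn():
      return {"x": constant_op.constant([1], dtype=dtypes.int64)}

    def model_fn(features, labels, mode, params):
      ...  # unchanged estimator model_fn, now nested one level deeper

    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    # train and save
    est = estimator.Estimator(model_fn=model_fn, model_dir=save_path)
    est.train(input_fn=input_fn, steps=1)

    # restore and predict
    predict_results = next(est.predict(input_fn=input_fn))
    self.assertAllEqual(predict_results, [1.0, 2.0, 3.0])

except:
  pass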
File 3 of 4
@@ -34,30 +34,12 @@
kinit2 = None
pass # for compatible with TF < 2.3.x

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.keras import initializers as keras_init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow_recommenders_addons import dynamic_embedding as de

import tensorflow as tf
@@ -213,10 +195,14 @@ def test_warm_start_rename(self):
self._test_warm_start_rename(num_shards, True)
self._test_warm_start_rename(num_shards, False)

def test_warm_start_estimator(self):
for num_shards in [1, 3]:
self._test_warm_start_estimator(num_shards, True)
self._test_warm_start_estimator(num_shards, False)
try: # tf version <= 2.15

def test_warm_start_estimator(self):
for num_shards in [1, 3]:
self._test_warm_start_estimator(num_shards, True)
self._test_warm_start_estimator(num_shards, False)
except:
print(f"estimator is not supported in this version of tensorflow")


if __name__ == "__main__":
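Here the commit uses the same class-level try/except around the method definition to skip test_warm_start_estimator when the estimator symbols are unavailable. An equivalent, more explicit alternative (not what the commit does) would be a skip decorator keyed on a guarded import; everything below is an illustrative sketch:

import unittest

try:  # tensorflow_estimator only exists for TF <= 2.15
  from tensorflow_estimator.python.estimator import estimator  # noqa: F401
  HAS_ESTIMATOR = True
except ImportError:
  HAS_ESTIMATOR = False


class WarmStartEstimatorSkipExample(unittest.TestCase):
  """Hypothetical test class; only the skip mechanism is being illustrated."""

  @unittest.skipUnless(HAS_ESTIMATOR, "estimator removed from TF >= 2.16")
  def test_warm_start_estimator(self):
    pass  # the real assertions would go here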
File 4 of 4
@@ -119,6 +119,10 @@ def DynamicEmbeddingOptimizer(self, bp_v2=False, synchronous=False, **kwargs):
if hasattr(self, 'add_variable_from_reference'):
original_add_variable_from_reference = self.add_variable_from_reference

def _minimize(loss, var_list, tape=None):
grads_and_vars = self.compute_gradients(loss, var_list, tape)
self.apply_gradients(grads_and_vars)

# pylint: disable=protected-access
def _distributed_apply(distribution, grads_and_vars, name, apply_state):
"""`apply_gradients` using a `DistributionStrategy`."""
@@ -292,6 +296,47 @@ def add_slot_v2_lagacy(var, slot_name, initializer="zeros", shape=None):
self._weights.append(weight)
return weight

def _distributed_tf_update_step(self, distribution, grads_and_vars,
learning_rate):

def apply_grad_to_update_var(var, grad, learning_rate):
if not isinstance(var, de.TrainableWrapper):
return self.update_step(grad, var, learning_rate)
else:
if not var.params.trainable:
return control_flow_ops.no_op()
with ops.colocate_with(None, ignore_existing=True):
_slots = [
_s for _s in self._variables
if isinstance(_s, de.TrainableWrapper)
]
var._track_optimizer_slots(_slots)

with ops.control_dependencies([grad]):
if isinstance(var, de.shadow_ops.ShadowVariable):
v0 = var.read_value(do_prefetch=False)
else:
v0 = var.read_value(do_prefetch=var.params.bp_v2)
s0 = [_s.read_value() for _s in _slots]
_before = [v0] + s0

with ops.control_dependencies(_before):
self.update_step(grad, var, learning_rate)

with ops.control_dependencies([var]):
_after = control_flow_ops.group(
[var.update_op(v0=v0)] +
[_s.update_op(v0=s0[si]) for si, _s in enumerate(_slots)])
return _after

for grad, var in grads_and_vars:
distribution.extended.update(
var,
apply_grad_to_update_var,
args=(grad, learning_rate),
group=False,
)

def _distributed_apply_gradients_fn(distribution, grads_and_vars, **kwargs):
"""`apply_gradients` using a `DistributionStrategy`."""

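The _distributed_tf_update_step added above routes every (grad, var) pair through distribution.extended.update; ordinary variables fall through to update_step, while de.TrainableWrapper variables are snapshotted first, updated under control dependencies, and then written back to the underlying hash table via update_op. A toy illustration of that snapshot/update/write-back ordering with a plain tf.Variable (not code from the commit; the write-back step is only marked by a comment):

import tensorflow as tf

@tf.function
def ordered_step(var, grad, lr):
  v0 = var.read_value()                  # snapshot before the step
  with tf.control_dependencies([v0]):
    update = var.assign_sub(lr * grad)   # stand-in for self.update_step(...)
  with tf.control_dependencies([update]):
    # In the real code this is where var.update_op(v0=v0) pushes the result
    # back into the dynamic-embedding table.
    return var.read_value()

var = tf.Variable([1.0, 2.0])
print(ordered_step(var, tf.constant([0.5, 0.5]), 0.1).numpy())  # roughly [0.95 1.95]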
@@ -849,6 +894,14 @@ def compute_gradients_horovod_wrapper_impl(*args, **kwargs):
self.apply_gradients = apply_gradients_sync_v2
else:
self.apply_gradients = apply_gradients_strategy_v2
elif hasattr(self, '_distributed_tf_update_step'):
# tf2 + keras3 optimizer
self._distributed_tf_update_step = _distributed_tf_update_step
self.minimize = _minimize
if self._custom_sync:
self.apply_gradients = apply_gradients_sync_v2
else:
self.apply_gradients = apply_gradients_strategy_v2
else:
raise Exception(f"Optimizer type is not supported! got {str(type(self))}")
else:
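With this new elif branch, a Keras 3 optimizer running on the TensorFlow backend (detected via its _distributed_tf_update_step attribute) is wrapped the same way as earlier TF2 optimizers, picking up the _minimize and _distributed_tf_update_step overrides above. A usage sketch, using the dynamic_embedding alias the test files already import (illustrative only):

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Wrap a stock Keras optimizer; with Keras 3 this takes the new code path.
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam(1e-3))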