Skip to content

Commit

Permalink
compatiable with both set and not set TF_USE_LEGACY_KERAS
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed Sep 10, 2024
1 parent 2651e03 commit bad0871
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 38 deletions.
2 changes: 1 addition & 1 deletion configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_tf_version_integer():
2.4.1 get 2041
2.6.3 get 2063
2.8.3 get 2083
get 2151
2.15.1 get 2151
The 4-digits-string will be passed to C macro to discriminate different
Tensorflow versions.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization
try: # tf version <= 2.15
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization
except:
from tf_keras.layers import LayerNormalization as TFLayerNormalization


class LayerNormalization(TFLayerNormalization):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,19 @@
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

try:
from tf_keras import Input, models, layers
except:
from tensorflow.keras import Input, models, layers

class DynamicLayerNormalizationTest(tf.test.TestCase):

def test_dynamic_shape_support(self):
input_data = tf.keras.Input(shape=(None, 10), dtype=tf.float32)
input_data = Input(shape=(None, 10), dtype=tf.float32)
layer = de.keras.layers.LayerNormalization()
output = layer(input_data)

model = tf.keras.models.Model(inputs=input_data, outputs=output)
model = models.Model(inputs=input_data, outputs=output)

np.random.seed(0)
test_data = np.random.randn(2, 5, 10).astype(np.float32)
Expand All @@ -38,10 +42,10 @@ def test_training_with_layer_normalization(self):
np.random.randn(num_samples) * 0.5).astype(np.float32).reshape(
-1, 1)

input_data = tf.keras.Input(shape=(input_dim,), dtype=tf.float32)
input_data = Input(shape=(input_dim,), dtype=tf.float32)
normalized = de.keras.layers.LayerNormalization()(input_data)
output = tf.keras.layers.Dense(output_dim)(normalized)
model = tf.keras.models.Model(inputs=input_data, outputs=output)
output = layers.Dense(output_dim)(normalized)
model = models.Model(inputs=input_data, outputs=output)

model.compile(optimizer='adam', loss='mean_squared_error')
initial_weights = [layer.get_weights() for layer in model.layers]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
TrainableWrapperDistributedPolicy
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names

try: # tf version >= 2.16
from tf_keras.layers import Layer
from tf_keras.initializers import RandomNormal, Zeros, serialize
from tf_keras import constraints
except:
from tensorflow.keras.layers import Layer
from tensorflow.keras.initializers import RandomNormal, Zeros, serialize
from tensorflow.keras import constraints


def _choose_reduce_method(combiner, sparse=False, segmented=False):
select = 'sparse' if sparse else 'math'
Expand Down Expand Up @@ -93,7 +102,7 @@ def reduce_pooling(x, combiner='sum'):
return x


class Embedding(tf.keras.layers.Layer):
class Embedding(Layer):
"""
A keras style Embedding layer. The `Embedding` layer acts same like
[tf.keras.layers.Embedding](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding),
Expand Down Expand Up @@ -180,7 +189,7 @@ def __init__(self,
self.embedding_size = embedding_size
self.combiner = combiner
if initializer is None:
initializer = tf.keras.initializers.RandomNormal()
initializer = RandomNormal()
partitioner = kwargs.get('partitioner', devar.default_partition_fn)
trainable = kwargs.get('trainable', True)
self.max_norm = kwargs.get('max_norm', None)
Expand Down Expand Up @@ -281,10 +290,10 @@ def call(self, ids):
def get_config(self):
_initializer = self.params.initializer
if _initializer is None:
_initializer = tf.keras.initializers.Zeros()
_initializer = Zeros()
_max_norm = None
if isinstance(self.max_norm, tf.keras.constraints.Constraint):
_max_norm = tf.keras.constraints.serialize(self.max_norm)
if isinstance(self.max_norm, constraints.Constraint):
_max_norm = constraints.serialize(self.max_norm)

if self.params.restrict_policy:
_restrict_policy = self.params.restrict_policy.__class__
Expand All @@ -301,7 +310,7 @@ def get_config(self):
'combiner':
self.combiner,
'initializer':
tf.keras.initializers.serialize(_initializer),
serialize(_initializer),
'devices':
self.params.devices if self.keep_distribution else None,
'name':
Expand Down Expand Up @@ -500,10 +509,10 @@ def _pooling_by_slots(self, lookup_result, ids):
def get_config(self):
_initializer = self.params.initializer
if _initializer is None:
_initializer = tf.keras.initializers.Zeros()
_initializer = Zeros()
_max_norm = None
if isinstance(self.max_norm, tf.keras.constraints.Constraint):
_max_norm = tf.keras.constraints.serialize(self.max_norm)
if isinstance(self.max_norm, constraints.Constraint):
_max_norm = constraints.serialize(self.max_norm)

config = {
'embedding_size': self.embedding_size,
Expand All @@ -512,7 +521,7 @@ def get_config(self):
'combiner': self.combiner,
'key_dtype': self.params.key_dtype,
'value_dtype': self.params.value_dtype,
'initializer': tf.keras.initializers.serialize(_initializer),
'initializer': serialize(_initializer),
'devices': self.params.devices,
'name': self.name,
'trainable': self.trainable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,18 @@
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 as optimizer_v2_legacy
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_v2_legacy_utils
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2

try: # tf version >= 2.16
from tf_keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy

from tensorflow.python.eager import tape
from tensorflow.python.distribute import values_util as distribute_values_util
from tensorflow.python.distribute import distribute_utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,22 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

try: # tf version >= 2.16
from tf_keras.initializers import Initializer
from tf_keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.initializers import Initializer
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy

from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
Expand Down Expand Up @@ -693,10 +708,7 @@ def restrict_policy(self):

def _convert_anything_to_init(self, raw_init, dim):
init = raw_init
valid_list = [
init_ops.Initializer, init_ops_v2.Initializer,
tf.keras.initializers.Initializer
]
valid_list = [init_ops.Initializer, init_ops_v2.Initializer, Initializer]
if kinit2 is not None:
valid_list.append(kinit2.Initializer)
valid_list = tuple(valid_list)
Expand Down Expand Up @@ -1142,8 +1154,7 @@ def get_slot_variables(self, optimizer):
Returns:
List of slot `Variable`s in optimizer.
"""
if not isinstance(optimizer,
(Optimizer, OptimizerV2, tf.keras.optimizers.Optimizer)):
if not isinstance(optimizer, (Optimizer, OptimizerV2, keras_OptimizerV2)):
raise TypeError('Expect an optimizer, but get {}'.format(type(optimizer)))
slots = []
if hasattr(optimizer, 'get_slot_names'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

# lint-as: python3
"""patch on tensorflow"""

from tensorflow_recommenders_addons import dynamic_embedding as de

try:
Expand All @@ -24,11 +23,14 @@
pass # for compatible with TF < 2.3.x

try:
import tensorflow as tf
kinit_tf = tf.keras.initializers
except ImportError:
kinit_tf = None
pass # for compatible with TF >= 2.6.x
import tf_keras
kinit_tf = tf_keras.initializers
except:
try:
import tensorflow as tf
kinit_tf = tf.keras.initializers
except ImportError:
pass # for compatible with TF >= 2.6.x

try:
import keras as tmp_keras
Expand Down
8 changes: 7 additions & 1 deletion tensorflow_recommenders_addons/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import numpy as np
import tensorflow as tf

try: # tf version >= 2.16
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
# Keras version >= 2.12.0
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2

Number = Union[
float,
int,
Expand All @@ -40,7 +46,7 @@
Regularizer = Union[None, dict, str, Callable]
Constraint = Union[None, dict, str, Callable]
Activation = Union[None, str, Callable]
Optimizer = Union[tf.keras.optimizers.Optimizer, str]
Optimizer = Union[keras_OptimizerV2, str]

TensorLike = Union[
List[Union[Number, list]],
Expand Down
3 changes: 0 additions & 3 deletions tools/install_deps/tf-keras.txt

This file was deleted.

0 comments on commit bad0871

Please sign in to comment.