Skip to content

Commit

Permalink
compatiable with both set and not set TF_USE_LEGACY_KERAS
Browse files Browse the repository at this point in the history
  • Loading branch information
jq committed Sep 10, 2024
1 parent 2651e03 commit c201e8f
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 29 deletions.
2 changes: 1 addition & 1 deletion configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_tf_version_integer():
2.4.1 get 2041
2.6.3 get 2063
2.8.3 get 2083
get 2151
2.15.1 get 2151
The 4-digits-string will be passed to C macro to discriminate different
Tensorflow versions.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization
try: # tf version <= 2.15
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization
except:
from tf_keras.layers import LayerNormalization as TFLayerNormalization


class LayerNormalization(TFLayerNormalization):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
TrainableWrapperDistributedPolicy
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.tf_save_restore_patch import de_fs_saveable_class_names

try: # tf version >= 2.16
from tf_keras.layers.Layer import Layer
from tf_keras.initializers import initializers
from tf_keras.constraints import constraints
except:
from tensorflow.keras.layers.Layer import Layer
from tensorflow.keras.initializers import initializers
from tensorflow.keras.constraints import constraints


def _choose_reduce_method(combiner, sparse=False, segmented=False):
select = 'sparse' if sparse else 'math'
Expand Down Expand Up @@ -93,7 +102,7 @@ def reduce_pooling(x, combiner='sum'):
return x


class Embedding(tf.keras.layers.Layer):
class Embedding(Layer):
"""
A keras style Embedding layer. The `Embedding` layer acts same like
[tf.keras.layers.Embedding](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding),
Expand Down Expand Up @@ -180,7 +189,7 @@ def __init__(self,
self.embedding_size = embedding_size
self.combiner = combiner
if initializer is None:
initializer = tf.keras.initializers.RandomNormal()
initializer = initializers.RandomNormal()
partitioner = kwargs.get('partitioner', devar.default_partition_fn)
trainable = kwargs.get('trainable', True)
self.max_norm = kwargs.get('max_norm', None)
Expand Down Expand Up @@ -281,10 +290,10 @@ def call(self, ids):
def get_config(self):
_initializer = self.params.initializer
if _initializer is None:
_initializer = tf.keras.initializers.Zeros()
_initializer = initializers.Zeros()
_max_norm = None
if isinstance(self.max_norm, tf.keras.constraints.Constraint):
_max_norm = tf.keras.constraints.serialize(self.max_norm)
if isinstance(self.max_norm, constraints.Constraint):
_max_norm = constraints.serialize(self.max_norm)

if self.params.restrict_policy:
_restrict_policy = self.params.restrict_policy.__class__
Expand All @@ -301,7 +310,7 @@ def get_config(self):
'combiner':
self.combiner,
'initializer':
tf.keras.initializers.serialize(_initializer),
initializers.serialize(_initializer),
'devices':
self.params.devices if self.keep_distribution else None,
'name':
Expand Down Expand Up @@ -500,10 +509,10 @@ def _pooling_by_slots(self, lookup_result, ids):
def get_config(self):
_initializer = self.params.initializer
if _initializer is None:
_initializer = tf.keras.initializers.Zeros()
_initializer = initializers.Zeros()
_max_norm = None
if isinstance(self.max_norm, tf.keras.constraints.Constraint):
_max_norm = tf.keras.constraints.serialize(self.max_norm)
if isinstance(self.max_norm, constraints.Constraint):
_max_norm = constraints.serialize(self.max_norm)

config = {
'embedding_size': self.embedding_size,
Expand All @@ -512,7 +521,7 @@ def get_config(self):
'combiner': self.combiner,
'key_dtype': self.params.key_dtype,
'value_dtype': self.params.value_dtype,
'initializer': tf.keras.initializers.serialize(_initializer),
'initializer': initializers.serialize(_initializer),
'devices': self.params.devices,
'name': self.name,
'trainable': self.trainable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,18 @@
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 as optimizer_v2_legacy
from tensorflow.python.keras.optimizer_v2 import utils as optimizer_v2_legacy_utils
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2

try: # tf version >= 2.16
from tf_keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy

from tensorflow.python.eager import tape
from tensorflow.python.distribute import values_util as distribute_values_util
from tensorflow.python.distribute import distribute_utils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,22 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2

try: # tf version >= 2.16
from tf_keras.initializers import initializers
from tf_keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.initializers import initializers
try: # Keras version >= 2.12.0
from tensorflow.keras.optimizers.legacy import Optimizer as keras_OptimizerV2_legacy
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2
except:
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2_legacy
keras_OptimizerV2 = keras_OptimizerV2_legacy

from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
Expand Down Expand Up @@ -694,8 +709,7 @@ def restrict_policy(self):
def _convert_anything_to_init(self, raw_init, dim):
init = raw_init
valid_list = [
init_ops.Initializer, init_ops_v2.Initializer,
tf.keras.initializers.Initializer
init_ops.Initializer, init_ops_v2.Initializer, initializers.Initializer
]
if kinit2 is not None:
valid_list.append(kinit2.Initializer)
Expand Down Expand Up @@ -1142,8 +1156,7 @@ def get_slot_variables(self, optimizer):
Returns:
List of slot `Variable`s in optimizer.
"""
if not isinstance(optimizer,
(Optimizer, OptimizerV2, tf.keras.optimizers.Optimizer)):
if not isinstance(optimizer, (Optimizer, OptimizerV2, keras_OptimizerV2)):
raise TypeError('Expect an optimizer, but get {}'.format(type(optimizer)))
slots = []
if hasattr(optimizer, 'get_slot_names'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# lint-as: python3
"""patch on tensorflow"""
import tf_keras.initializers

from tensorflow_recommenders_addons import dynamic_embedding as de

Expand All @@ -24,11 +25,13 @@
pass # for compatible with TF < 2.3.x

try:
import tensorflow as tf
kinit_tf = tf.keras.initializers
kinit_tf = tf_keras.initializers
except ImportError:
kinit_tf = None
pass # for compatible with TF >= 2.6.x
try:
import tensorflow as tf
kinit_tf = tf.keras.initializers
except ImportError:
pass # for compatible with TF >= 2.6.x

try:
import keras as tmp_keras
Expand Down
8 changes: 7 additions & 1 deletion tensorflow_recommenders_addons/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import numpy as np
import tensorflow as tf

try: # tf version >= 2.16
from tf_keras.optimizers import Optimizer as keras_OptimizerV2
except:
# Keras version >= 2.12.0
from tensorflow.keras.optimizers import Optimizer as keras_OptimizerV2

Number = Union[
float,
int,
Expand All @@ -40,7 +46,7 @@
Regularizer = Union[None, dict, str, Callable]
Constraint = Union[None, dict, str, Callable]
Activation = Union[None, str, Callable]
Optimizer = Union[tf.keras.optimizers.Optimizer, str]
Optimizer = Union[keras_OptimizerV2, str]

TensorLike = Union[
List[Union[Number, list]],
Expand Down
3 changes: 0 additions & 3 deletions tools/install_deps/tf-keras.txt

This file was deleted.

0 comments on commit c201e8f

Please sign in to comment.