diff --git a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
index e963e4449..6d243528f 100644
--- a/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
+++ b/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
@@ -4,6 +4,8 @@
 from absl import flags
 from absl import app
 
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
 os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT!
 os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
 # Because of the two environment variables above no non-standard library imports should happen before this.
@@ -371,7 +373,7 @@ def __init__(self,
         embedding_initializer=embedding_initializer,
         mpi_size=mpi_size,
         mpi_rank=mpi_rank)
-
+    self.dynamic_layer_norm = de.keras.layers.LayerNormalization()
     self.dnn1 = tf.keras.layers.Dense(
         64,
         activation='relu',
@@ -427,7 +429,6 @@ def call(self, features):
         for key, value in feature_info_spec.items()
         if key in user_fea
     }
-    user_latent = self.user_embedding(user_fea_info)
     movie_fea = ['movie_id', 'movie_genres', 'user_occupation_label']
     movie_fea = [i for i in features.keys() if i in movie_fea]
     movie_fea_info = {
@@ -435,14 +436,16 @@ def call(self, features):
         for key, value in feature_info_spec.items()
         if key in movie_fea
     }
+    user_latent = self.user_embedding(user_fea_info)
     movie_latent = self.movie_embedding(movie_fea_info)
     latent = tf.concat([user_latent, movie_latent], axis=1)
 
-    x = self.dnn1(latent)
+    normalized_emb = self.dynamic_layer_norm(latent)
+    x = self.dnn1(normalized_emb)
     x = self.dnn2(x)
     x = self.dnn3(x)
 
-    bias = self.bias_net(latent)
+    bias = self.bias_net(normalized_emb)
     x = 0.2 * x + 0.8 * bias
     user_rating = tf.keras.layers.Lambda(lambda x: x, name='user_rating')(x)
     return {'user_rating': user_rating}
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/__init__.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/__init__.py
index 87696c251..c6c3251e4 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/__init__.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/__init__.py
@@ -4,3 +4,4 @@
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import FieldWiseEmbedding
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import SquashedEmbedding
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import HvdAllToAllEmbedding
+from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.dynamic_layer_normalization import LayerNormalization
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization.py
new file mode 100644
index 000000000..26f64d6d6
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization.py
@@ -0,0 +1,97 @@
+import tensorflow as tf
+from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization
+
+
+class LayerNormalization(TFLayerNormalization):
+
+  def call(self, inputs):
+    # TODO(b/229545225): Remove the RaggedTensor check.
+    is_ragged = isinstance(inputs, tf.RaggedTensor)
+    if is_ragged:
+      inputs_lengths = inputs.nested_row_lengths()
+      inputs = inputs.to_tensor()
+    inputs = tf.cast(inputs, self.compute_dtype)
+    # Compute the axes along which to reduce the mean / variance
+    input_shape = tf.shape(inputs)
+    # Get the number of dimensions dynamically
+    ndims = input_shape.shape[0]
+
+    # Broadcasting only necessary for norm when the axis is not just
+    # the last dimension
+    broadcast_shape = [1] * ndims
+    for dim in self.axis:
+      broadcast_shape[dim] = input_shape[dim]
+
+    def _broadcast(v):
+      if v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]:
+        return tf.reshape(v, broadcast_shape)
+      return v
+
+    if not self._fused:
+      input_dtype = inputs.dtype
+      if input_dtype in ("float16", "bfloat16") and self.dtype == "float32":
+        # If mixed precision is used, cast inputs to float32 so that
+        # this is at least as numerically stable as the fused version.
+        inputs = tf.cast(inputs, "float32")
+
+      # Calculate the moments on the last axis (layer activations).
+      mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)
+
+      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+      # Compute layer normalization using the batch_normalization
+      # function.
+      outputs = tf.nn.batch_normalization(
+          inputs,
+          mean,
+          variance,
+          offset=offset,
+          scale=scale,
+          variance_epsilon=self.epsilon,
+      )
+      outputs = tf.cast(outputs, input_dtype)
+    else:
+      # Collapse dims before self.axis, and dims in self.axis
+
+      axis = sorted(self.axis)
+      tensor_shape = tf.shape(inputs)
+      pre_dim = tf.reduce_prod(tensor_shape[:axis[0]])
+      in_dim = tf.reduce_prod(tensor_shape[axis[0]:])
+      squeezed_shape = [1, pre_dim, in_dim, 1]
+      # This fused operation requires reshaped inputs to be NCHW.
+      data_format = "NCHW"
+
+      inputs = tf.reshape(inputs, squeezed_shape)
+
+      # self.gamma and self.beta have the wrong shape for
+      # fused_batch_norm, so we cannot pass them as the scale and offset
+      # parameters. Therefore, we create two constant tensors in correct
+      # shapes for fused_batch_norm and later construct a separate
+      # calculation on the scale and offset.
+      scale = tf.ones([pre_dim], dtype=self.dtype)
+      offset = tf.zeros([pre_dim], dtype=self.dtype)
+
+      # Compute layer normalization using the fused_batch_norm function.
+      outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
+          inputs,
+          scale=scale,
+          offset=offset,
+          epsilon=self.epsilon,
+          data_format=data_format,
+      )
+
+      outputs = tf.reshape(outputs, tensor_shape)
+
+      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
+
+      if scale is not None:
+        outputs = outputs * tf.cast(scale, outputs.dtype)
+      if offset is not None:
+        outputs = outputs + tf.cast(offset, outputs.dtype)
+
+    # If some components of the shape got lost due to adjustments, fix that.
+    outputs = tf.reshape(outputs, input_shape)
+
+    if is_ragged:
+      outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths)
+    return outputs
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization_test.py
new file mode 100644
index 000000000..59dfdfc87
--- /dev/null
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/dynamic_layer_normalization_test.py
@@ -0,0 +1,59 @@
+import numpy as np
+
+import tensorflow as tf
+from tensorflow_recommenders_addons import dynamic_embedding as de
+
+
+class DynamicLayerNormalizationTest(tf.test.TestCase):
+
+  def test_dynamic_shape_support(self):
+    input_data = tf.keras.Input(shape=(None, 10), dtype=tf.float32)
+    layer = de.keras.layers.LayerNormalization()
+    output = layer(input_data)
+
+    model = tf.keras.models.Model(inputs=input_data, outputs=output)
+
+    np.random.seed(0)
+    test_data = np.random.randn(2, 5, 10).astype(np.float32)
+    output_data = model.predict(test_data)
+    self.assertAllEqual(output_data.shape, (2, 5, 10))
+
+    expected_mean = np.mean(test_data, axis=-1, keepdims=True)
+    expected_std = np.std(test_data, axis=-1, keepdims=True)
+    expected_normalized = (test_data - expected_mean) / (expected_std +
+                                                         layer.epsilon)
+
+    # Calculate expected output considering gamma and beta are default (i.e., gamma=1, beta=0)
+    # 1e-3 is the default value for epsilon in LayerNormalization
+    self.assertAllClose(output_data, expected_normalized, rtol=1e-3, atol=1e-3)
+
+  def test_training_with_layer_normalization(self):
+    input_dim = 10
+    num_samples = 100
+    output_dim = 1
+
+    np.random.seed(0)
+    features = np.random.randn(num_samples, input_dim).astype(np.float32)
+    labels = (np.sum(features, axis=1) +
+              np.random.randn(num_samples) * 0.5).astype(np.float32).reshape(
+                  -1, 1)
+
+    input_data = tf.keras.Input(shape=(input_dim,), dtype=tf.float32)
+    normalized = de.keras.layers.LayerNormalization()(input_data)
+    output = tf.keras.layers.Dense(output_dim)(normalized)
+    model = tf.keras.models.Model(inputs=input_data, outputs=output)
+
+    model.compile(optimizer='adam', loss='mean_squared_error')
+    initial_weights = [layer.get_weights() for layer in model.layers]
+
+    model.fit(features, labels, epochs=5, batch_size=10, verbose=0)
+
+    updated_weights = [layer.get_weights() for layer in model.layers]
+
+    for initial, updated in zip(initial_weights, updated_weights):
+      for ini_w, upd_w in zip(initial, updated):
+        self.assertGreater(np.sum(np.abs(ini_w - upd_w)), 0)
+
+    predictions = model.predict(features)
+    self.assertAllEqual(predictions.shape, (num_samples, output_dim))
+    self.assertGreater(np.std(predictions), 0.1)
diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
index 4233e6566..65fc8be54 100644
--- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
+++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/models.py
@@ -127,11 +127,11 @@ def _traverse_emb_layers_and_save(hvd_rank=0):
                                        proc_size=hvd.size(),
                                        proc_rank=hvd.rank())
 
-  _check_saveable_and_redirect_new_de_dir(hvd.rank())
   if hvd is None:
     call_original_save_func()
     _traverse_emb_layers_and_save(0)
   else:
+    _check_saveable_and_redirect_new_de_dir(hvd.rank())
    if hvd.rank() == 0:
       call_original_save_func()
     _traverse_emb_layers_and_save(hvd.rank())
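Usage note (not part of the patch): a minimal sketch of how the de.keras.layers.LayerNormalization added above might be wired into a Keras model once this change is in place. The input shape, layer sizes, and optimizer below are illustrative assumptions, mirroring the patterns in dynamic_layer_normalization_test.py.

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Functional model with a dynamic (None) middle dimension; the overridden call()
# derives the normalization shape with tf.shape, so the unknown dimension is fine.
inputs = tf.keras.Input(shape=(None, 16), dtype=tf.float32)  # illustrative shape
normalized = de.keras.layers.LayerNormalization()(inputs)
outputs = tf.keras.layers.Dense(1)(normalized)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mean_squared_error')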