add normalized_emb
jq committed Jul 5, 2024
1 parent a8642aa commit b8906af
Showing 5 changed files with 165 additions and 5 deletions.
@@ -4,6 +4,8 @@
 from absl import flags
 from absl import app

+from tensorflow_recommenders_addons import dynamic_embedding as de
+
 os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" #VERY IMPORTANT!
 os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"
 # Because of the two environment variables above no non-standard library imports should happen before this.
@@ -371,7 +373,7 @@ def __init__(self,
         embedding_initializer=embedding_initializer,
         mpi_size=mpi_size,
         mpi_rank=mpi_rank)
-
+    self.dynamic_layer_norm = de.keras.layers.LayerNormalization()
     self.dnn1 = tf.keras.layers.Dense(
         64,
         activation='relu',
@@ -427,22 +429,23 @@ def call(self, features):
         for key, value in feature_info_spec.items()
         if key in user_fea
     }
-    user_latent = self.user_embedding(user_fea_info)
     movie_fea = ['movie_id', 'movie_genres', 'user_occupation_label']
     movie_fea = [i for i in features.keys() if i in movie_fea]
     movie_fea_info = {
         key: value
         for key, value in feature_info_spec.items()
         if key in movie_fea
     }
+    user_latent = self.user_embedding(user_fea_info)
     movie_latent = self.movie_embedding(movie_fea_info)
     latent = tf.concat([user_latent, movie_latent], axis=1)

-    x = self.dnn1(latent)
+    normalized_emb = self.dynamic_layer_norm(latent)
+    x = self.dnn1(normalized_emb)
     x = self.dnn2(x)
     x = self.dnn3(x)

-    bias = self.bias_net(latent)
+    bias = self.bias_net(normalized_emb)
     x = 0.2 * x + 0.8 * bias
     user_rating = tf.keras.layers.Lambda(lambda x: x, name='user_rating')(x)
     return {'user_rating': user_rating}
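The hunk above normalizes the concatenated user/movie embedding once and feeds the result to both the DNN tower and the bias net. A minimal functional-API sketch of that wiring, not part of the commit: the 96-dim input and the Dense stand-ins for the embedding towers and bias_net are illustrative assumptions; only de.keras.layers.LayerNormalization and the 0.2/0.8 blend come from the demo code above.

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

latent = tf.keras.Input(shape=(96,), dtype=tf.float32)  # stand-in for tf.concat([user_latent, movie_latent], axis=1)
normalized_emb = de.keras.layers.LayerNormalization()(latent)

deep = tf.keras.layers.Dense(64, activation='relu')(normalized_emb)  # dnn1 stand-in
deep = tf.keras.layers.Dense(1)(deep)                                # dnn2/dnn3 collapsed for brevity
bias = tf.keras.layers.Dense(1)(normalized_emb)                      # bias_net stand-in

user_rating = 0.2 * deep + 0.8 * bias  # same blend as the demo's call()
model = tf.keras.Model(latent, user_rating)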
@@ -4,3 +4,4 @@
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import FieldWiseEmbedding
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import SquashedEmbedding
 from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.embedding import HvdAllToAllEmbedding
+from tensorflow_recommenders_addons.dynamic_embedding.python.keras.layers.dynamic_layer_normalization import LayerNormalization
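This re-export makes the subclass reachable from the public dynamic_embedding namespace, which is how the demo above and the test below refer to it. A quick sanity check, assuming TFRA is installed at this revision:

from tensorflow_recommenders_addons import dynamic_embedding as de

# The new layer sits next to the embedding layers exported above.
assert hasattr(de.keras.layers, "LayerNormalization")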
@@ -0,0 +1,97 @@
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization as TFLayerNormalization


class LayerNormalization(TFLayerNormalization):

  def call(self, inputs):
    # TODO(b/229545225): Remove the RaggedTensor check.
    is_ragged = isinstance(inputs, tf.RaggedTensor)
    if is_ragged:
      inputs_lengths = inputs.nested_row_lengths()
      inputs = inputs.to_tensor()
    inputs = tf.cast(inputs, self.compute_dtype)
    # Compute the axes along which to reduce the mean / variance
    input_shape = tf.shape(inputs)
    # Get the number of dimensions dynamically
    ndims = input_shape.shape[0]

    # Broadcasting only necessary for norm when the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape[dim]

    def _broadcast(v):
      if v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]:
        return tf.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if input_dtype in ("float16", "bfloat16") and self.dtype == "float32":
        # If mixed precision is used, cast inputs to float32 so that
        # this is at least as numerically stable as the fused version.
        inputs = tf.cast(inputs, "float32")

      # Calculate the moments on the last axis (layer activations).
      mean, variance = tf.nn.moments(inputs, self.axis, keepdims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization
      # function.
      outputs = tf.nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon,
      )
      outputs = tf.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis

      axis = sorted(self.axis)
      tensor_shape = tf.shape(inputs)
      pre_dim = tf.reduce_prod(tensor_shape[:axis[0]])
      in_dim = tf.reduce_prod(tensor_shape[axis[0]:])
      squeezed_shape = [1, pre_dim, in_dim, 1]
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = "NCHW"

      inputs = tf.reshape(inputs, squeezed_shape)

      # self.gamma and self.beta have the wrong shape for
      # fused_batch_norm, so we cannot pass them as the scale and offset
      # parameters. Therefore, we create two constant tensors in correct
      # shapes for fused_batch_norm and later construct a separate
      # calculation on the scale and offset.
      scale = tf.ones([pre_dim], dtype=self.dtype)
      offset = tf.zeros([pre_dim], dtype=self.dtype)

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = tf.compat.v1.nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format,
      )

      outputs = tf.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * tf.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + tf.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs = tf.reshape(outputs, input_shape)

    if is_ragged:
      outputs = tf.RaggedTensor.from_tensor(outputs, inputs_lengths)
    return outputs
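The override above takes the reduction and broadcast shapes from tf.shape(inputs) at run time instead of the static shape, so the layer also works when leading dimensions are unknown at trace time, and it keeps the RaggedTensor round-trip from the stock Keras layer. A small sketch of the dynamic-shape property; the input signature, shapes, and function name are illustrative assumptions:

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

layer = de.keras.layers.LayerNormalization()

@tf.function(input_signature=[tf.TensorSpec([None, None, 8], tf.float32)])
def normalize(x):
  # Batch and time dimensions are unknown when this graph is traced;
  # the layer reads them from tf.shape(x) at run time.
  return layer(x)

y = normalize(tf.random.normal([3, 5, 8]))
print(y.shape)  # (3, 5, 8)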
@@ -0,0 +1,59 @@
import numpy as np

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de


class DynamicLayerNormalizationTest(tf.test.TestCase):

  def test_dynamic_shape_support(self):
    input_data = tf.keras.Input(shape=(None, 10), dtype=tf.float32)
    layer = de.keras.layers.LayerNormalization()
    output = layer(input_data)

    model = tf.keras.models.Model(inputs=input_data, outputs=output)

    np.random.seed(0)
    test_data = np.random.randn(2, 5, 10).astype(np.float32)
    output_data = model.predict(test_data)
    self.assertAllEqual(output_data.shape, (2, 5, 10))

    expected_mean = np.mean(test_data, axis=-1, keepdims=True)
    expected_std = np.std(test_data, axis=-1, keepdims=True)
    expected_normalized = (test_data - expected_mean) / (expected_std +
                                                         layer.epsilon)

    # Calculate expected output considering gamma and beta are default
    # (i.e., gamma=1, beta=0). 1e-3 is the default value for epsilon in
    # LayerNormalization.
    self.assertAllClose(output_data, expected_normalized, rtol=1e-3, atol=1e-3)

  def test_training_with_layer_normalization(self):
    input_dim = 10
    num_samples = 100
    output_dim = 1

    np.random.seed(0)
    features = np.random.randn(num_samples, input_dim).astype(np.float32)
    labels = (np.sum(features, axis=1) +
              np.random.randn(num_samples) * 0.5).astype(np.float32).reshape(
                  -1, 1)

    input_data = tf.keras.Input(shape=(input_dim,), dtype=tf.float32)
    normalized = de.keras.layers.LayerNormalization()(input_data)
    output = tf.keras.layers.Dense(output_dim)(normalized)
    model = tf.keras.models.Model(inputs=input_data, outputs=output)

    model.compile(optimizer='adam', loss='mean_squared_error')
    initial_weights = [layer.get_weights() for layer in model.layers]

    model.fit(features, labels, epochs=5, batch_size=10, verbose=0)

    updated_weights = [layer.get_weights() for layer in model.layers]

    for initial, updated in zip(initial_weights, updated_weights):
      for ini_w, upd_w in zip(initial, updated):
        self.assertGreater(np.sum(np.abs(ini_w - upd_w)), 0)

    predictions = model.predict(features)
    self.assertAllEqual(predictions.shape, (num_samples, output_dim))
    self.assertGreater(np.std(predictions), 0.1)
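Note that the first test's reference, (x - mean) / (std + epsilon), only approximates what layer normalization actually computes, (x - mean) / sqrt(var + epsilon) with the default gamma=1 and beta=0, which is why the comparison uses a loose rtol/atol of 1e-3. An exact NumPy reference, reusing the test's arrays (illustrative, not part of the commit), would be:

expected_exact = (test_data - expected_mean) / np.sqrt(
    np.var(test_data, axis=-1, keepdims=True) + layer.epsilon)

If the file is run directly rather than through the project's test runner, the conventional entry point is tf.test.main().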
@@ -127,11 +127,11 @@ def _traverse_emb_layers_and_save(hvd_rank=0):
           proc_size=hvd.size(),
           proc_rank=hvd.rank())

-  _check_saveable_and_redirect_new_de_dir(hvd.rank())
   if hvd is None:
     call_original_save_func()
     _traverse_emb_layers_and_save(0)
   else:
+    _check_saveable_and_redirect_new_de_dir(hvd.rank())
     if hvd.rank() == 0:
       call_original_save_func()
     _traverse_emb_layers_and_save(hvd.rank())
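This last hunk fixes saving without Horovod: _check_saveable_and_redirect_new_de_dir(hvd.rank()) used to run before the hvd is None check, so a single-process save failed on hvd.rank(); the call now only happens in the else branch. The guard pattern in isolation, with an illustrative optional import rather than code from this file:

try:
  import horovod.tensorflow as hvd
except ImportError:
  hvd = None

rank = 0 if hvd is None else hvd.rank()  # only touch hvd once we know it imported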
