Support MultiWorkerMirroredStrategy distributed training strategy for dynamic embeddings #365
Comments
Hi @sivukhin, thank you for the feedback! We will give a resolution after the discussion. Thank you!
Hi @sivukhin, because of TF's resource locking, MirroredStrategy is not efficient for TFRA multi-tables. We recommend using Horovod for distributed training; see line 528 in 6f7bbb8:
https://github.com/tensorflow/recommenders-addons/blob/master/demo/dynamic_embedding/movielens-1m-keras-with-horovod/movielens-1m-keras-with-horovod.py
Or you could help us improve the code so that each MirroredStrategy worker creates its own DEVariable object to hold its own table and interacts with communication operators, just like HvdAllToAllEmbedding does.
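For orientation, a minimal sketch of the Horovod-based approach, condensed from the linked demo and from the code posted later in this thread; treat it as an illustration rather than a complete script (the loss and the dense layer are placeholders):

import horovod.tensorflow as hvd
import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

hvd.init()  # one process per worker

# Each rank holds a shard of the dynamic embedding table; lookups for keys
# owned by other ranks are exchanged with all-to-all communication.
model = tf.keras.Sequential([
    de.keras.layers.HvdAllToAllEmbedding(
        embedding_size=64,
        key_dtype=tf.int64,
        value_dtype=tf.float32,
        initializer=tf.random_uniform_initializer(),
        name="user-embedding",
    ),
    tf.keras.layers.Dense(32),  # placeholder dense tower
])

# The optimizer must be wrapped so it can update dynamic embedding tables.
optimizer = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
model.compile(optimizer=optimizer, loss="mse")

# Broadcast the initial dense variables from rank 0 so all workers start in sync.
callbacks = [de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(root_rank=0)]
# model.fit(train_ds, callbacks=callbacks)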
@MoFHeka, thanks for the quick reply! For now the only problem I have encountered is that the latest version of TFRA (0.6.0 on PyPI) doesn't have HvdAllToAllEmbedding.
Yes, it is not a released feature, but you can try installing it from source by following the guidance: https://github.com/tensorflow/recommenders-addons#installing-from-source. It is easy to do. If you run into any problems, you can get help here.
Yes, thanks! I did that, but ran into the error below:
Error stack trace
Training source code
import dataclasses
import os
from typing import Dict
import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_datasets as tfds
# tensorflow_recommenders_addons does some patching on TensorFlow, so it MUST be imported after importing TF
import tensorflow_recommenders as tfrs
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de
hvd.init()
redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)
cuckoo_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver(proc_size=1, proc_rank=0))
batch_size = 4096
seed = 2023
@dataclasses.dataclass(frozen=True)
class TrainingDatasets:
train_ds: tf.data.Dataset
validation_ds: tf.data.Dataset
@dataclasses.dataclass(frozen=True)
class RetrievalDatasets:
training_datasets: TrainingDatasets
candidate_dataset: tf.data.Dataset
def create_datasets():
def split_train_validation_datasets(ratings_dataset: tf.data.Dataset) -> TrainingDatasets:
train_size = int(len(ratings_dataset) * 0.9)
validation_size = len(ratings_dataset) - train_size
print(f"Train size: {train_size}")
print(f"Validation size: {validation_size}")
shuffled_dataset = ratings_dataset.shuffle(buffer_size=5 * batch_size, seed=seed)
train_ds = shuffled_dataset.skip(validation_size).shuffle(buffer_size=10 * batch_size).apply(lambda dataset: dataset.padded_batch(batch_size))
validation_ds = shuffled_dataset.take(validation_size).apply(lambda dataset: dataset.padded_batch(batch_size))
return TrainingDatasets(train_ds=train_ds, validation_ds=validation_ds)
ratings_dataset = tfds.load("movielens/100k-ratings", split="train").map(lambda x: {
'user_id': tf.strings.to_number(x["user_id"], tf.int64),
'movie_id': tf.strings.to_number(x["movie_id"], tf.int64)
})
movies_dataset = tfds.load("movielens/100k-movies", split="train").map(lambda x: tf.strings.to_number(x["movie_id"], tf.int64))
for item in ratings_dataset.take(3):
print(item)
for item in movies_dataset.take(3):
print(item)
training_datasets = split_train_validation_datasets(ratings_dataset)
return RetrievalDatasets(training_datasets=training_datasets, candidate_dataset=movies_dataset.padded_batch(batch_size))
class TwoTowerModel(tfrs.Model):
def __init__(self, user_model: tf.keras.Model, item_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
super().__init__()
self.user_model = user_model
self.item_model = item_model
self.task = task
def compute_loss(self, features: Dict[str, tf.Tensor], training=False) -> tf.Tensor:
user_embeddings = self.user_model(features["user_id"])
movie_embeddings = self.item_model(features["movie_id"])
return self.task(user_embeddings, movie_embeddings)
def build_two_tower_model(candidate_dataset: tf.data.Dataset) -> tf.keras.Model:
user_model = tf.keras.Sequential([
de.keras.layers.HvdAllToAllEmbedding(
embedding_size=64,
key_dtype=tf.int64,
value_dtype=tf.float32,
initializer=tf.random_uniform_initializer(),
init_capacity=100_000,
restrict_policy=de.FrequencyRestrictPolicy,
name="user-embedding",
kv_creator=redis_creator,
),
tf.keras.layers.Dense(64, activation="gelu"),
tf.keras.layers.Dense(32),
tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
], name='user_model')
item_model = tf.keras.models.Sequential([
de.keras.layers.HvdAllToAllEmbedding(
embedding_size=64,
key_dtype=tf.int64,
value_dtype=tf.float32,
initializer=tf.random_uniform_initializer(),
init_capacity=100_000,
restrict_policy=de.FrequencyRestrictPolicy,
name="movie-embedding",
kv_creator=redis_creator,
),
tf.keras.layers.Dense(64, activation="gelu"),
tf.keras.layers.Dense(32),
tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
], name='movie_model')
model = TwoTowerModel(user_model, item_model, task=tfrs.tasks.Retrieval(
metrics=tfrs.metrics.FactorizedTopK(candidate_dataset.map(item_model))
))
optimize = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
model.compile(optimizer=optimize)
return model
def train_multi_worker():
datasets = create_datasets()
model_dir = f'model_dir_{hvd.rank()}'
print(f'model_dir: {model_dir}')
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=model_dir)
broadcast_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(root_rank=0)
checkpoint_callback = de.keras.callbacks.DEHvdModelCheckpoint(
filepath=model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
options=tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
)
callbacks_list = [tensorboard_callback, broadcast_callback, checkpoint_callback]
model = build_two_tower_model(datasets.candidate_dataset)
history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
print(history)
if __name__ == '__main__':
    train_multi_worker()

I also tried to launch the demo from
Of course HvdAllToAllEmbedding supports training on CPU. I ran your code successfully with the modifications shown in the training source code below. Also, if the error that your GPU doesn't work comes from Horovod's all-to-all operator, it may have been caused by a third-party package, which I suspect is tensorflow_recommenders, because I also failed to run your code on the GPU. One more thing: if you want to train on CPU, a parameter server is your best choice; MirroredStrategy is much more efficient on a single multi-GPU machine.

Training source code
import dataclasses
import os
from typing import Dict
import horovod.tensorflow as hvd
import tensorflow as tf
import tensorflow_datasets as tfds
# tensorflow_recommenders_addons does some patching on TensorFlow, so it MUST be imported after importing TF
import tensorflow_recommenders as tfrs
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de
hvd.init()
redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)
cuckoo_creator = de.CuckooHashTableCreator(saver=de.FileSystemSaver(proc_size=hvd.size(), proc_rank=hvd.rank()))
batch_size = 4096
seed = 2023
@dataclasses.dataclass(frozen=True)
class TrainingDatasets:
train_ds: tf.data.Dataset
validation_ds: tf.data.Dataset
@dataclasses.dataclass(frozen=True)
class RetrievalDatasets:
training_datasets: TrainingDatasets
candidate_dataset: tf.data.Dataset
def create_datasets():
def split_train_validation_datasets(ratings_dataset: tf.data.Dataset) -> TrainingDatasets:
train_size = int(len(ratings_dataset) * 0.9)
validation_size = len(ratings_dataset) - train_size
print(f"Train size: {train_size}")
print(f"Validation size: {validation_size}")
shuffled_dataset = ratings_dataset.shuffle(buffer_size=5 * batch_size, seed=seed)
train_ds = shuffled_dataset.skip(validation_size).shuffle(buffer_size=10 * batch_size).apply(lambda dataset: dataset.padded_batch(batch_size))
validation_ds = shuffled_dataset.take(validation_size).apply(lambda dataset: dataset.padded_batch(batch_size))
return TrainingDatasets(train_ds=train_ds, validation_ds=validation_ds)
ratings_dataset =tfds.load("movielens/100k-ratings",
split="train",
data_dir=".",
download=False).map(lambda x: {
'user_id': tf.strings.to_number(x["user_id"], tf.int64),
'movie_id': tf.strings.to_number(x["movie_id"], tf.int64)
})
movies_dataset = tfds.load("movielens/100k-ratings",
split="train",
data_dir=".",
download=False).map(lambda x: tf.strings.to_number(x["movie_id"], tf.int64))
for item in ratings_dataset.take(3):
print(item)
for item in movies_dataset.take(3):
print(item)
training_datasets = split_train_validation_datasets(ratings_dataset)
return RetrievalDatasets(training_datasets=training_datasets, candidate_dataset=movies_dataset.padded_batch(batch_size))
class TwoTowerModel(tfrs.Model):
def __init__(self, user_model: tf.keras.Model, item_model: tf.keras.Model, task: tfrs.tasks.Retrieval):
super().__init__()
self.user_model = user_model
self.item_model = item_model
self.task = task
def compute_loss(self, features: Dict[str, tf.Tensor], training=False) -> tf.Tensor:
user_embeddings = self.user_model(features["user_id"])
movie_embeddings = self.item_model(features["movie_id"])
return self.task(user_embeddings, movie_embeddings)
def build_two_tower_model(candidate_dataset: tf.data.Dataset) -> tf.keras.Model:
user_model = tf.keras.Sequential([
de.keras.layers.HvdAllToAllEmbedding(
embedding_size=64,
key_dtype=tf.int64,
value_dtype=tf.float32,
initializer=tf.random_uniform_initializer(),
init_capacity=100_000,
restrict_policy=de.FrequencyRestrictPolicy,
name="user-embedding",
devices=['CPU'],
kv_creator=cuckoo_creator,
),
tf.keras.layers.Dense(64, activation="gelu"),
tf.keras.layers.Dense(32),
tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
], name='user_model')
item_model = tf.keras.models.Sequential([
de.keras.layers.HvdAllToAllEmbedding(
embedding_size=64,
key_dtype=tf.int64,
value_dtype=tf.float32,
initializer=tf.random_uniform_initializer(),
init_capacity=100_000,
restrict_policy=de.FrequencyRestrictPolicy,
name="movie-embedding",
devices=['CPU'],
kv_creator=cuckoo_creator,
),
tf.keras.layers.Dense(64, activation="gelu"),
tf.keras.layers.Dense(32),
tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
], name='movie_model')
model = TwoTowerModel(user_model, item_model, task=tfrs.tasks.Retrieval(
metrics=tfrs.metrics.FactorizedTopK(candidate_dataset.map(item_model))
))
optimize = de.DynamicEmbeddingOptimizer(tf.keras.optimizers.Adam())
model.compile(optimizer=optimize)
return model
def train_multi_worker():
datasets = create_datasets()
model_dir = f'model_dir_{hvd.rank()}'
print(f'model_dir: {model_dir}')
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=model_dir)
broadcast_callback = de.keras.callbacks.DEHvdBroadcastGlobalVariablesCallback(root_rank=0)
checkpoint_callback = de.keras.callbacks.DEHvdModelCheckpoint(
filepath=model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
options=tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
)
callbacks_list = [tensorboard_callback, broadcast_callback, checkpoint_callback]
model = build_two_tower_model(datasets.candidate_dataset)
history = model.fit(datasets.training_datasets.train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks_list, verbose=1)
print(history)
if __name__ == '__main__':
    train_multi_worker()
Hm, ok. I still got this weird error about inconsistent shapes (with both Redis and Cuckoo), but maybe I need to dig more into it... UPD: I looked more closely at your sample code and found a difference in
Why is a parameter server better for CPU? We have a very simple model with very few weights (apart from the embedding table). I thought a multi-worker strategy would be more efficient, since it only requires occasional communication between workers to accumulate updated gradients. With a parameter server I'm just not sure what would be stored on it... If all dense weights live there, it seems like this could create a huge communication overhead, no? My initial thought was that I could train the dense weights independently on multiple workers (to provide high throughput) and use Redis as external storage for the embedding table. In my head this setup implies the following communication for a single worker:
If there is a way to control the frequency of syncs between a worker and Redis, and the frequency of inter-worker communication, I thought this scheme could work for pretty high-load scenarios (with a low sync frequency we trade convergence rate for throughput, which looks fine to me at the moment)...
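For concreteness, a rough sketch of the periodic-sync idea described above, using plain Horovod on a toy dense model; SYNC_EVERY_N_STEPS, the toy dataset and the stand-in model are made up for illustration, and the Redis-backed embedding part is only indicated in a comment:

import horovod.tensorflow as hvd
import tensorflow as tf

hvd.init()

SYNC_EVERY_N_STEPS = 100  # illustrative knob for inter-worker sync frequency

model = tf.keras.Sequential([tf.keras.layers.Dense(32), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y):
    # Purely local update of the replicated dense weights, no worker-to-worker traffic.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

def average_dense_weights():
    # Occasionally average the replicated dense weights across workers.
    for var in model.trainable_variables:
        var.assign(hvd.allreduce(var, op=hvd.Average))

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([1024, 8]), tf.random.normal([1024, 1]))).batch(32)

for step, (x, y) in enumerate(dataset):
    train_step(x, y)
    # Embedding lookups and updates would go against the shared Redis table here
    # (handled by the TFRA Redis backend), so the embedding part needs no
    # inter-worker communication at all.
    if step % SYNC_EVERY_N_STEPS == 0:
        average_dense_weights()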
Ring-AllReduce vs Parameter Server
The lower communication overhead of the multi-worker strategy relies on synchronous training. If many CPU nodes train asynchronously with a small batch size, a parameter server can finish training on all samples faster at a given cluster size.

Semi-Synchronous Training = Ring-AllReduce + Parameter Server
Another method is semi-synchronous training: the parameters of the dense layers are synchronized by Horovod, while the embedding parameters are trained asynchronously on the parameter server. You can refer to: semi-synchronous training with the TF1 API. Although that demo uses the TF1 API, the principles are similar in TF2.

Semi-Synchronous Training with Redis
Redis is intended for serving, although you can certainly use it as an alternative backend for training. If you want to use a Redis-backed embedding in Horovod synchronous training, use the normal Embedding layer instead of HvdAllToAllEmbedding. In addition, enabling bp_v2 may improve model convergence (not guaranteed); the bp_v2 feature for Redis requires a separately compiled Redis module.
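To make the last point concrete, a hedged sketch of the "normal Embedding layer backed by Redis" setup; this assumes de.keras.layers.Embedding accepts the same arguments as the HvdAllToAllEmbedding layer used earlier in this thread, and the Redis config path is a placeholder:

import tensorflow as tf
import tensorflow_recommenders_addons as tfra
from tensorflow_recommenders_addons import dynamic_embedding as de

# Redis-backed key-value storage for the embedding table (config path is a placeholder).
redis_config = tfra.dynamic_embedding.RedisTableConfig(redis_config_abs_dir="redis.config")
redis_creator = tfra.dynamic_embedding.RedisTableCreator(redis_config)

# A plain dynamic-embedding layer rather than HvdAllToAllEmbedding: every worker
# reads and writes the shared Redis table directly, so the embedding part needs
# no Horovod all-to-all exchange; only the dense variables are synchronized.
user_embedding = de.keras.layers.Embedding(
    embedding_size=64,
    key_dtype=tf.int64,
    value_dtype=tf.float32,
    initializer=tf.random_uniform_initializer(),
    name="user-embedding",
    kv_creator=redis_creator,
)
# Enabling bp_v2 for the embedding (mentioned above) additionally requires the
# Redis module compiled with bp_v2 support.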
Thanks @MoFHeka, got it!
@sivukhin |
I tried to explore the available approaches for distributed training of large-scale recommendation models with huge embedding tables, and tried to use TFRA DynamicEmbedding combined with MultiWorkerMirroredStrategy. MultiWorkerMirroredStrategy can suit my needs because the model has very few parameters apart from the embeddings, so we can replicate them across all workers.

It seems like the current implementation struggles with MultiWorkerMirroredStrategy. My attempts to make it work failed with the following error. I tried to launch the following training code on 2 workers with the following commands:
Source code
Redis configuration
Relevant information
Which API type would this fall under (layer, metric, optimizer, etc.)
model.fit
Who will benefit from this feature?