
Commit

tensorboard
jq committed Sep 6, 2024
1 parent 9384a0a commit 0ddd23b
Showing 3 changed files with 52 additions and 12 deletions.
@@ -67,6 +67,7 @@ def get_rank() -> int:
     'Embedding size for users and movies')
 flags.DEFINE_integer('test_steps', 128, 'test steps.')
 flags.DEFINE_integer('test_batch', 1024, 'test batch size.')
+flags.DEFINE_integer('profiles', 10, 'number of profiles')
 flags.DEFINE_bool('shuffle', True, 'shuffle dataset.')
 FLAGS = flags.FLAGS

@@ -638,7 +639,6 @@ def train():
   if os.path.exists(FLAGS.model_dir + '/variables'):
     model.load_weights(FLAGS.model_dir)
 
-  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir)
   save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
   ckpt_callback = de.keras.callbacks.ModelCheckpoint(
       filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
@@ -651,9 +651,20 @@ def train():
   else:
     callbacks_list = [ckpt_callback]
 
+  def get_profiling_batch(total_batches: int, num_profiles: int):
+    interval = total_batches // num_profiles
+    if interval == 0:
+      return None
+    else:
+      return [(i + 1) * interval for i in range(num_profiles)]
+
   # Logging callbacks only take effect on rank 0, for convenience.
   if get_rank() == 0:
-    callbacks_list.extend([tensorboard_callback])
+    profile_batch = get_profiling_batch(FLAGS.steps_per_epoch, FLAGS.profiles)
+    if profile_batch:
+      tensorboard_callback = tf.keras.callbacks.TensorBoard(
+          log_dir="logs/profile", update_freq=100, profile_batch=(50, 100))
+      callbacks_list.extend([tensorboard_callback])
   # Callbacks that invoke model computation, such as evaluation metrics, take effect on all ranks.
   # callbacks_list.extend([my_auc_callback])

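Note on the hunk above: get_profiling_batch computes evenly spaced batch indices, but the committed callback still hardcodes profile_batch=(50, 100) and only uses the helper's result as an on/off switch. A minimal sketch of how the pieces could fit together, assuming TF 2.x, where tf.keras.callbacks.TensorBoard accepts profile_batch as a single batch index or a (start, stop) range; wiring the computed indices into the callback is an assumption, not what the commit does:

import tensorflow as tf

def get_profiling_batch(total_batches: int, num_profiles: int):
  # Evenly spaced candidate batch indices, or None when there are
  # fewer batches than requested profiles.
  interval = total_batches // num_profiles
  if interval == 0:
    return None
  return [(i + 1) * interval for i in range(num_profiles)]

batches = get_profiling_batch(total_batches=20000, num_profiles=10)
# batches == [2000, 4000, 6000, ..., 20000]

if batches:
  # profile_batch takes one index or an inclusive (start, stop) range, so
  # profiling a window around the first candidate is one option (an
  # assumption: the commit itself always profiles batches 50-100).
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir="logs/profile", update_freq=100,
      profile_batch=(batches[0], batches[0] + 50))
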
@@ -22,6 +22,7 @@
 # optimal performance
 os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
 
+tf.config.optimizer.set_jit(True)
 
 def has_horovod() -> bool:
   return 'OMPI_COMM_WORLD_RANK' in os.environ or 'PMI_RANK' in os.environ
@@ -205,26 +206,53 @@ def concat_tensors(tensors):
   concatenated = tf.concat(flat_values, axis=0)
   return concatenated, row_lengths
 
 
 def concat_embedding(tensors, embeddings, row_lengths):
-  offset = 0
   results = []
+  indices = [0]
+  for length in row_lengths:  # Compute start indices for each segment
+    sum = tf.reduce_sum(length).numpy()
+    indices.append(indices[-1] + sum)
+  emb_shape = embeddings.shape[-1]
   for i, tensor in enumerate(tensors):
+    # Calculate the start and end indices for the current tensor
+    start_index = indices[i]
+    end_index = indices[i + 1]
+    begin = tf.constant([start_index, 0])
+    end = tf.constant([end_index, emb_shape])
+    emb = tf.strided_slice(
+        embeddings, begin, end)
     if isinstance(tensor, tf.RaggedTensor):
+      original = tf.RaggedTensor.from_row_lengths(emb, row_lengths[i])
+      emb = tf.reduce_mean(original, axis=1)
+    results.append(emb)
+  return tf.concat(results, axis=0)
+
+def concat_embedding_slow(tensors, embeddings, row_lengths):
+  results = []
+  start_indices = [0]
+  for length in row_lengths[:-1]:  # Compute start indices for each segment
+    start_indices.append(start_indices[-1] + length)
+
+  for i, tensor in enumerate(tensors):
+    if isinstance(tensor, tf.RaggedTensor):
-      count = tf.shape(tensor.flat_values)[0]
-      ragged_embeddings = tf.RaggedTensor.from_row_lengths(
-          embeddings[offset:offset + count], row_lengths[i])
-      pooled = tf.reduce_mean(ragged_embeddings, axis=1)
+      # Calculate the start and end indices for the current tensor
+      start_index = start_indices[i]
+      end_index = start_index + row_lengths[i]
+
+      # Extract embeddings using tf.strided_slice
+      sliced_embeddings = tf.strided_slice(
+          embeddings, [start_index, 0], [end_index, embeddings.shape[-1]])
+
+      # Pool the embeddings
+      pooled = tf.reduce_mean(sliced_embeddings, axis=0)
       results.append(pooled)
-      offset += count
     else:
       count = tf.shape(tensor)[0]
-      pooled_embeddings = tf.reshape(embeddings[offset:offset + count],
+      pooled_embeddings = tf.reshape(embeddings[start_indices[i]:start_indices[i] + count],
                                      [-1, embeddings.shape[-1]])
       results.append(pooled_embeddings)
-      offset += count
 
-  return tf.concat(results, axis=1)
+  return tf.concat(results, axis=0)
 
-
 def get_kv_creator(mpi_size: int,
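
For context, a toy invocation of the new concat_embedding, assuming row_lengths[i] holds the per-row value counts of the i-th ragged input (as produced by concat_tensors) and that embeddings stacks the lookup results for all flattened ids in order; the shapes and values here are illustrative only:

import tensorflow as tf

# Two ragged id features, e.g. watch histories for two users.
ids_a = tf.ragged.constant([[1, 2], [3]])   # row lengths [2, 1]
ids_b = tf.ragged.constant([[4], [5, 6]])   # row lengths [1, 2]
tensors = [ids_a, ids_b]
row_lengths = [ids_a.row_lengths(), ids_b.row_lengths()]

# Stand-in for the embedding lookup of the 6 flattened ids
# (in the demo these come from the TFRA dynamic embedding layer).
embeddings = tf.random.uniform([6, 8])

pooled = concat_embedding(tensors, embeddings, row_lengths)
# Segment starts are [0, 3, 6]; each feature's block is sliced out with
# tf.strided_slice, mean-pooled per row to shape [2, 8], and the results
# are concatenated along axis 0 -> shape [4, 8].

The renamed concat_embedding_slow keeps the previous per-tensor slicing for reference; the fast path appears to win by precomputing every segment boundary once, though its tf.reduce_sum(length).numpy() call ties it to eager execution.
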
1 change: 1 addition & 0 deletions demo/dynamic_embedding/seq_and_dense/start.sh
@@ -2,5 +2,6 @@
 rm -rf ./export_dir
 gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
 export gpu_num
+export TF_XLA_FLAGS=--tf_xla_auto_jit=2
 horovodrun -np $gpu_num python seq_and_dense.py --mode="train" --model_dir="./model_dir" --export_dir="./export_dir" \
   --steps_per_epoch=${1:-20000} --shuffle=${2:-True}
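
The new export enables XLA auto-clustering in the launched trainer, complementing the tf.config.optimizer.set_jit(True) call added above. A minimal sketch of the two JIT switches, assuming TF 2.x (the environment variable must be set before TensorFlow initializes, which is why start.sh exports it ahead of the launch):

import os

# Must be in the environment before TensorFlow starts to take effect.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2'

import tensorflow as tf

# Roughly the in-process counterpart: enable XLA JIT auto-clustering.
tf.config.optimizer.set_jit(True)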
