tensorboard
jq committed Sep 5, 2024
1 parent 9384a0a commit d2f7506
Showing 3 changed files with 15 additions and 2 deletions.
@@ -67,6 +67,7 @@ def get_rank() -> int:
                      'Embedding size for users and movies')
 flags.DEFINE_integer('test_steps', 128, 'test steps.')
 flags.DEFINE_integer('test_batch', 1024, 'test batch size.')
+flags.DEFINE_integer('profiles', 10, 'number of profiling points per epoch')
 flags.DEFINE_bool('shuffle', True, 'shuffle dataset.')
 FLAGS = flags.FLAGS

@@ -638,7 +639,6 @@ def train():
   if os.path.exists(FLAGS.model_dir + '/variables'):
     model.load_weights(FLAGS.model_dir)
 
-  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir)
   save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
   ckpt_callback = de.keras.callbacks.ModelCheckpoint(
       filepath=FLAGS.model_dir + '/weights_epoch{epoch:03d}_loss{loss:.4f}',
@@ -651,9 +651,20 @@ def train():
   else:
     callbacks_list = [ckpt_callback]
 
+  def get_profiling_batch(total_batches: int, num_profiles: int):
+    interval = total_batches // num_profiles
+    if interval == 0:
+      return None
+    else:
+      return ((i + 1) * interval for i in range(num_profiles))
+
   # Logging callbacks only take effect on rank 0 for convenience.
   if get_rank() == 0:
-    callbacks_list.extend([tensorboard_callback])
+    profile_batch = get_profiling_batch(FLAGS.steps_per_epoch, FLAGS.profiles)
+    if profile_batch:
+      tensorboard_callback = tf.keras.callbacks.TensorBoard(
+          log_dir="logs/profile", update_freq=100, profile_batch=(50, 100))
+      callbacks_list.extend([tensorboard_callback])
   # Callbacks that run model computation (e.g. evaluation metrics) must take effect on all ranks:
   # callbacks_list.extend([my_auc_callback])
 
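Note on the new profiling logic: tf.keras.callbacks.TensorBoard profiles a single contiguous batch range, and this commit pins that range to profile_batch=(50, 100); the value returned by get_profiling_batch is only used as an on/off switch (it is None when steps_per_epoch is smaller than profiles, which disables profiling). Below is a minimal sketch of how the evenly spaced indices could instead choose the profiled window; the standalone copy of the helper, the list return value, and the (start, start + 50) window are illustrative assumptions, not part of this commit.

# Illustrative only: pick the profiled window from the evenly spaced indices.
import tensorflow as tf

def get_profiling_batch(total_batches, num_profiles):
  interval = total_batches // num_profiles
  if interval == 0:
    return None
  return [(i + 1) * interval for i in range(num_profiles)]

batches = get_profiling_batch(total_batches=20000, num_profiles=10)
if batches:
  # profile_batch accepts an int or a (start, stop) pair of batch numbers.
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir="logs/profile",
      update_freq=100,
      profile_batch=(batches[0], batches[0] + 50))

A list is used here instead of the commit's generator expression because a generator object is always truthy, so the check "if profile_batch:" can only distinguish None from not-None.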
@@ -22,6 +22,7 @@
 # optimal performance
 os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'
 
+tf.config.optimizer.set_jit(True)
 
 def has_horovod() -> bool:
   return 'OMPI_COMM_WORLD_RANK' in os.environ or 'PMI_RANK' in os.environ
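For reference, the TF_XLA_FLAGS environment variable and the new tf.config.optimizer.set_jit(True) call are two routes to XLA auto-clustering. A small sketch, assuming the environment variable is set before TensorFlow is imported; the values simply mirror the demo's settings rather than new recommendations.

import os

# Must be set before TensorFlow initializes in order to take effect.
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'

import tensorflow as tf

tf.config.optimizer.set_jit(True)     # enable XLA auto-clustering
print(tf.config.optimizer.get_jit())  # prints 'autoclustering' when enabled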
1 change: 1 addition & 0 deletions demo/dynamic_embedding/seq_and_dense/start.sh
@@ -2,5 +2,6 @@
 rm -rf ./export_dir
 gpu_num=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
 export gpu_num
+export TF_XLA_FLAGS=--tf_xla_auto_jit=2
 horovodrun -np $gpu_num python seq_and_dense.py --mode="train" --model_dir="./model_dir" --export_dir="./export_dir" \
   --steps_per_epoch=${1:-20000} --shuffle=${2:-True}
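The launcher's positional arguments end up as absl flags inside seq_and_dense.py. A hypothetical, stripped-down sketch of that plumbing; only the flag names appear in the diffs above, while the defaults and the main wrapper here are assumptions.

# Hypothetical sketch: how --steps_per_epoch and --shuffle from start.sh are
# parsed on the Python side. Defaults here are illustrative.
from absl import app, flags

flags.DEFINE_integer('steps_per_epoch', 20000, 'training steps per epoch')
flags.DEFINE_bool('shuffle', True, 'shuffle dataset.')
FLAGS = flags.FLAGS

def main(argv):
  del argv  # unused
  print('steps_per_epoch =', FLAGS.steps_per_epoch, 'shuffle =', FLAGS.shuffle)

if __name__ == '__main__':
  app.run(main)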
