Add log_every_n_steps in training
guoli-yin committed Jul 9, 2024
1 parent e13d41a commit 4004d80
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions axlearn/common/trainer.py
@@ -167,6 +167,9 @@ class Config(Module.Config):
         # An optional recorder for measuring common metrics like step time.
         recorder: Optional[InstantiableConfig[measurement.Recorder]] = None
 
+        # The frequency of logging during training. By default, it will use 100.
+        log_every_n_steps: Optional[int] = None
+
     def __init__(
         self,
         cfg: Config,
@@ -425,7 +428,7 @@ def run(
             the specific `metric_calculator` config of the evaler.
         """
         with self._watchdog(), self.mesh(), jax.log_compiles(self.vlog_is_on(1)):
-            cfg = self.config
+            cfg: SpmdTrainer.Config = self.config
             # Check if need to force run evals at the last training step.
             force_run_eval_sets_at_max_step = self._should_force_run_evals(
                 return_evaler_summaries=return_evaler_summaries, evalers=cfg.evalers
@@ -435,6 +438,8 @@
             if not self._prepare_training(prng_key):
                 return None
 
+            # Set log_every_n_steps.
+            log_every_n_steps = cfg.log_every_n_steps or 100
             with self.checkpointer:
                 logging.info("Starting loop...")
                 start_time = time.perf_counter()
@@ -461,7 +466,7 @@
                     )
                     self.vlog(3, "Done step %s", self.step)
                     num_steps += 1
-                    if num_steps % 100 == 0:
+                    if num_steps % log_every_n_steps == 0:
                         now = time.perf_counter()
                         average_step_time = (now - start_time) / num_steps
                         self._step_log("Average step time: %s seconds", average_step_time)
@@ -800,12 +805,13 @@ def _run_step(
             A dict containing 'loss' and 'aux' outputs. If force_run_evals is a set,
             force run the evalers in the set and return 'evaler_summaries' output.
         """
+        cfg: SpmdTrainer.Config = self.config
         with jax.profiler.StepTraceAnnotation("train", step_num=self.step):
             # Note(Jan 2022):
             # pjit currently requires all parameters to be specified as positional args.
             self._trainer_state, outputs = self._jit_train_step(self._trainer_state, input_batch)
 
-        if self.step % 100 == 0 or 0 <= self.step <= 5:
+        log_every_n_steps = cfg.log_every_n_steps or 100
+        if self.step % log_every_n_steps == 0 or 0 <= self.step <= 5:
             self._step_log(
                 "loss=%s aux=%s",
                 outputs["loss"],
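For illustration only, not part of the commit: a minimal sketch of how the new option might be set on a trainer config. The log_every_n_steps field and its fall-back to 100 come from the diff above; the default_config() call and the value 50 are assumptions about typical axlearn config usage.

    # Hypothetical usage sketch (assumed API surface, not from this commit).
    trainer_cfg = SpmdTrainer.default_config()
    # Log average step time and per-step loss every 50 steps instead of every 100.
    trainer_cfg.log_every_n_steps = 50
    # Leaving the field as None keeps the previous behavior of logging every 100 steps.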
