diff --git a/src/train.py b/src/train.py index b44c105..23aef48 100644 --- a/src/train.py +++ b/src/train.py @@ -274,6 +274,8 @@ def train( except KeyboardInterrupt: with accelerator.main_process_first(): logger.error("KeyboardInterrupt") + if not (max_steps % eval_steps == 0): + trainer.evaluate() if __name__ == "__main__":