Skip to content

Commit

Permalink
Add a signal handler in Python
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Jul 3, 2023
1 parent 65e3822 commit 779413b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 21 deletions.
29 changes: 18 additions & 11 deletions docs/examples/good_practices/checkpointing/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,14 @@ repository.
#SBATCH --ntasks-per-node=1
#SBATCH --mem=16G
#SBATCH --time=00:15:00
-
+#SBATCH --requeue
+#SBATCH --signal=B:TERM@300 # tells the controller to send SIGTERM to the job 5
+ # min before its time ends to give it a chance for
+ # better cleanup. If you cancel the job manually,
+ # make sure that you specify the signal as TERM like
+ # so `scancel --signal=TERM <jobid>`.
+ # https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/
+
+# trap the signal to the main BATCH script here.
+sig_handler()
+{
+ echo "BATCH interrupted"
+ wait # wait for all children, this is important!
+}
+
+trap 'sig_handler' SIGINT SIGTERM SIGCONT
# Echo time and hostname into log
echo "Date: $(date)"
Expand Down Expand Up @@ -103,8 +94,10 @@ repository.
import os
+import random
+import shutil
+import signal
+from logging import getLogger as get_logger
from pathlib import Path
+from types import FrameType
+from typing import Any, TypedDict
+import numpy
Expand Down Expand Up @@ -230,8 +223,22 @@ repository.
- # Checkout the "checkpointing and preemption" example for more info!
- logger.debug("Starting training from scratch.")
-
+ def signal_handler(signum: int, frame: FrameType | None):
+ """Called before the job gets pre-empted or reaches the time-limit.
+
+ This should run quickly. Performing a full checkpoint here mid-epoch is not recommended.
+ """
+ signal_enum = signal.Signals(signum)
+ logger.error(f"Job received a {signal_enum.name} signal!")
+ # Perform quick actions that will help the job resume later.
+ # If you use Weights & Biases: https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
+ # if wandb.run:
+ # wandb.mark_preempting()
- for epoch in range(training_epochs):
+ signal.signal(signal.SIGTERM, signal_handler) # Before getting pre-empted and requeued.
+ signal.signal(signal.SIGUSR1, signal_handler) # Before reaching the end of the time limit.
+
+ for epoch in range(start_epoch, training_epochs):
logger.debug(f"Starting epoch {epoch}/{training_epochs}")
Expand Down
10 changes: 0 additions & 10 deletions docs/examples/good_practices/checkpointing/job.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@
# so `scancel --signal=TERM <jobid>`.
# https://dhruveshp.com/blog/2021/signal-propagation-on-slurm/

# trap the signal to the main BATCH script here.
# sig_handler is invoked by the `trap` registered below. It prints a short
# notice and then blocks in `wait` until the script's child processes exit.
sig_handler()
{
echo "BATCH interrupted"
wait # wait for all children, this is important!
}

# Run sig_handler when the batch shell receives SIGINT, SIGTERM or SIGCONT.
trap 'sig_handler' SIGINT SIGTERM SIGCONT


# Echo time and hostname into log
echo "Date: $(date)"
echo "Hostname: $(hostname)"
Expand Down
17 changes: 17 additions & 0 deletions docs/examples/good_practices/checkpointing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import os
import random
import shutil
import signal
from logging import getLogger as get_logger
from pathlib import Path
from types import FrameType
from typing import Any, TypedDict

import numpy
Expand Down Expand Up @@ -126,6 +128,21 @@ def main():
shuffle=False,
)

def signal_handler(signum: int, frame: FrameType | None):
    """Log the signal received when the job is about to be pre-empted or time out.

    Intended to be installed with ``signal.signal``. Keep it fast: taking a full
    checkpoint from here, in the middle of an epoch, is not recommended.
    """
    signal_enum = signal.Signals(signum)
    logger.error(f"Job received a {signal_enum.name} signal!")
    # Only do cheap bookkeeping here that helps the job resume later.
    # With Weights & Biases, for example:
    # https://docs.wandb.ai/guides/runs/resuming#preemptible-sweeps
    # if wandb.run:
    #     wandb.mark_preempting()

signal.signal(signal.SIGTERM, signal_handler) # Before getting pre-empted and requeued.
signal.signal(signal.SIGUSR1, signal_handler) # Before reaching the end of the time limit.

for epoch in range(start_epoch, training_epochs):
logger.debug(f"Starting epoch {epoch}/{training_epochs}")

Expand Down

0 comments on commit 779413b

Please sign in to comment.