diff --git a/docs/Usage/Fine-tuning.md b/docs/Usage/Fine-tuning.md
new file mode 100644
index 0000000..1898c9f
--- /dev/null
+++ b/docs/Usage/Fine-tuning.md
@@ -0,0 +1,65 @@
+# Fine-tuning
+MemBrain-seg is built to function optimally out-of-the-box, eliminating the need for most users to train the model themselves.
+
+However, if your tomograms differ significantly from the images used in our training dataset, fine-tuning the model on your own data may enhance performance. In this case, it can make sense to [annotate](./Annotations.md) several patches extracted from your tomogram, and fine-tune the pretrained MemBrain-seg model using your corrected data.
+
+Here are the steps you can follow to fine-tune MemBrain-seg:
+
+# Step 1: Prepare your fine-tuning dataset
+MemBrain-seg assumes a specific data structure for creating the fine-tuning dataloaders. The images can be smaller, corrected patches extracted from your own tomograms:
+
+```bash
+data_dir/
+├── imagesTr/       # Directory containing training images
+│   ├── img1.nii.gz    # Image file (currently requires nii.gz format)
+│   ├── img2.nii.gz    # Image file
+│   └── ...
+├── imagesVal/      # Directory containing validation images
+│   ├── img3.nii.gz    # Image file
+│   ├── img4.nii.gz    # Image file
+│   └── ...
+├── labelsTr/       # Directory containing training labels
+│   ├── img1.nii.gz    # Label file (currently requires nii.gz format)
+│   ├── img2.nii.gz    # Label file
+│   └── ...
+└── labelsVal/      # Directory containing validation labels
+    ├── img3.nii.gz    # Label file
+    ├── img4.nii.gz    # Label file
+    └── ...
+```
+
+The path to this data_dir folder is then passed to the fine-tuning procedure via the --finetune-data-dir argument (see [Step 2](#step-2-perform-fine-tuning)).
+
+To fine-tune the pretrained model on your own tomograms, you need to add some corrected patches from these tomograms so that the network can adapt to your data.
+
+You can find some instructions here: [How to create training annotations from your own tomogram?](./Annotations.md)
+
+# Step 2: Perform fine-tuning
+Fine-tuning starts from a pretrained model checkpoint. After activating your virtual Python environment, you can type:
+```
+membrain finetune
+```
+to receive help with the input arguments. You will see that the two parameters you need to provide are the --pretrained-checkpoint-path and the --finetune-data-dir arguments:
+
+```
+membrain finetune --pretrained-checkpoint-path <path-to-your-pretrained-model> --finetune-data-dir <path-to-your-finetuning-data>
+```
+This command fine-tunes the pretrained MemBrain-seg model using your fine-tuning dataset. Be sure to point to the correct checkpoint path containing the pretrained weights, as well as to the fine-tuning data directory.
+
+This is exactly the folder you prepared in [Step 1](#step-1-prepare-your-fine-tuning-dataset).
+
+Running this command should start the fine-tuning process and store the fine-tuned model in the ./finetuned_checkpoints folder.
+
+**Note:** Fine-tuning can take up to 24 hours. We therefore recommend that you perform fine-tuning on a device with a CUDA-enabled GPU.
+
+
+# Advanced settings
+In case you feel fancy and would like to adjust some of the default settings of MemBrain-seg, you can also use the following command to get access to more customizable options:
+```
+membrain finetune_advanced
+```
+This will display all available options that can be activated or deactivated. For example, when fine-tuning, you might want to lower the learning rate compared to training from scratch to prevent the model from "forgetting" the knowledge it learned during pretraining.
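+
+For example, a fine-tuning run with a lower learning rate than the default could look roughly like this (the paths are placeholders, and `membrain finetune_advanced --help` lists the exact option names):
+```
+membrain finetune_advanced \
+  --pretrained-checkpoint-path <path-to-your-pretrained-model> \
+  --finetune-data-dir <path-to-your-finetuning-data> \
+  --finetune-learning-rate 1e-6
+```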
For more in-depth adjustments, you will need to dig into MemBrain-seg's code or contact us. + + +# Contact +If there are any problems coming up when running the code or anything else is unclear, do not hesitate to contact us (Lorenz.Lamm@helmholtz-munich.de). We are more than happy to help. diff --git a/src/membrain_seg/segmentation/cli/__init__.py b/src/membrain_seg/segmentation/cli/__init__.py index a2c078c..aad1241 100644 --- a/src/membrain_seg/segmentation/cli/__init__.py +++ b/src/membrain_seg/segmentation/cli/__init__.py @@ -2,6 +2,7 @@ # These imports are necessary to register CLI commands. Do not remove! from .cli import cli # noqa: F401 +from .fine_tune_cli import finetune # noqa: F401 from .segment_cli import segment # noqa: F401 from .ske_cli import skeletonize # noqa: F401 from .train_cli import data_dir_help, train # noqa: F401 diff --git a/src/membrain_seg/segmentation/cli/fine_tune_cli.py b/src/membrain_seg/segmentation/cli/fine_tune_cli.py new file mode 100644 index 0000000..a9b601a --- /dev/null +++ b/src/membrain_seg/segmentation/cli/fine_tune_cli.py @@ -0,0 +1,236 @@ +from typing import List, Optional + +from typer import Option +from typing_extensions import Annotated + +from ..finetune import fine_tune as _fine_tune +from .cli import OPTION_PROMPT_KWARGS as PKWARGS +from .cli import cli + + +@cli.command(name="finetune", no_args_is_help=True) +def finetune( + pretrained_checkpoint_path: str = Option( # noqa: B008 + ..., + help="Path to the checkpoint of the pre-trained model.", + **PKWARGS, + ), + finetune_data_dir: str = Option( # noqa: B008 + ..., + help='Path to the directory containing the new data for fine-tuning. \ + Following the same required structure as the train function. \ + To learn more about the required\ + data structure, type "membrain data_structure_help"', + **PKWARGS, + ), +): + """ + Initiates fine-tuning of a pre-trained model on new datasets + and validation on original datasets. + + This function fine-tunes a pre-trained model on new datasets provided by the user. + The directory specified by `finetune_data_dir` should be structured according to the + requirements for the training function. + For more details, use "membrain data_structure_help". + + Parameters + ---------- + pretrained_checkpoint_path : str + Path to the checkpoint file of the pre-trained model. + finetune_data_dir : str + Directory containing the new dataset for fine-tuning, + structured as per the MemBrain's requirement. + Use "membrain data_structure_help" for detailed information + on the required data structure. + + Note + + ---- + + This command configures and executes a fine-tuning session + using the provided model checkpoint. + The actual fine-tuning logic resides in the function '_fine_tune'. 
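+
+    Example
+    -------
+    A typical call could look like this, where both values are placeholders
+    for your own checkpoint file and fine-tuning data directory:
+
+    membrain finetune --pretrained-checkpoint-path <ckpt> --finetune-data-dir <dir>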
+ """ + finetune_learning_rate = 1e-5 + log_dir = "logs_finetune/" + batch_size = 2 + num_workers = 8 + max_epochs = 100 + early_stop_threshold = 0.05 + aug_prob_to_one = True + use_deep_supervision = True + project_name = "membrain-seg_finetune" + sub_name = "1" + + _fine_tune( + pretrained_checkpoint_path=pretrained_checkpoint_path, + finetune_data_dir=finetune_data_dir, + finetune_learning_rate=finetune_learning_rate, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + early_stop_threshold=early_stop_threshold, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + project_name=project_name, + sub_name=sub_name, + ) + + +@cli.command(name="finetune_advanced", no_args_is_help=True) +def finetune_advanced( + pretrained_checkpoint_path: str = Option( # noqa: B008 + ..., + help="Path to the checkpoint of the pre-trained model.", + **PKWARGS, + ), + finetune_data_dir: str = Option( # noqa: B008 + ..., + help='Path to the directory containing the new data for fine-tuning. \ + Following the same required structure as the train function. \ + To learn more about the required\ + data structure, type "membrain data_structure_help"', + **PKWARGS, + ), + finetune_learning_rate: float = Option( # noqa: B008 + 1e-5, + help="Learning rate for fine-tuning the model. This parameter controls the \ + step size at each iteration while moving toward a minimum loss. \ + A smaller learning rate can lead to a more precise convergence but may \ + require more epochs. Adjust based on your dataset size and complexity.", + ), + log_dir: str = Option( # noqa: B008 + "logs_fine_tune/", + help="Log directory path. Finetuning logs will be stored here.", + ), + batch_size: int = Option( # noqa: B008 + 2, + help="Batch size for training.", + ), + num_workers: int = Option( # noqa: B008 + 8, + help="Number of worker threads for data loading.", + ), + max_epochs: int = Option( # noqa: B008 + 100, + help="Maximum number of epochs for fine-tuning.", + ), + early_stop_threshold: float = Option( # noqa: B008 + 0.05, + help="Threshold for early stopping based on validation loss deviation.", + ), + aug_prob_to_one: bool = Option( # noqa: B008 + True, + help='Whether to augment with a probability of one. This helps with the \ + model\'s generalization,\ + but also severely increases training time.\ + Pass "True" or "False".', + ), + use_surface_dice: bool = Option( # noqa: B008 + False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' + ), + surface_dice_weight: float = Option( # noqa: B008 + 1.0, help="Scaling factor for the Surface-Dice loss. " + ), + surface_dice_tokens: Annotated[ + Optional[List[str]], + Option( + help='List of tokens to \ + use for the Surface-Dice loss. \ + Pass tokens separately:\ + For example, train_advanced --surface_dice_tokens "ds1" \ + --surface_dice_tokens "ds2"' + ), + ] = None, + use_deep_supervision: bool = Option( # noqa: B008 + True, help='Whether to use deep supervision. Pass "True" or "False".' + ), + project_name: str = Option( # noqa: B008 + "membrain-seg_v0_finetune", + help="Project name. This helps to find your model again.", + ), + sub_name: str = Option( # noqa: B008 + "1", + help="Subproject name. For multiple runs in the same project,\ + please specify sub_names.", + ), +): + """ + Initiates fine-tuning of a pre-trained model on new datasets + and validation on original datasets with more advanced options. 
+ + This function finetunes a pre-trained U-Net model on new data provided by the user. + The `finetune_data_dir` should contain the following directories: + - `imagesTr` and `labelsTr` for the user's own new training data. + - `imagesVal` and `labelsVal` for the old data, which will be used + for validation to ensure that the fine-tuned model's performance + is not significantly worse on the original training data than the + pre-trained model. + + Parameters + ---------- + pretrained_checkpoint_path : str + Path to the checkpoint file of the pre-trained model. + finetune_data_dir : str + Directory containing the new dataset for fine-tuning, + structured as per the MemBrain's requirement. + Use "membrain data_structure_help" for detailed information + on the required data structure. + finetune_learning_rate : float + Learning rate for fine-tuning the model. This parameter controls the step size + at each iteration while moving toward a minimum loss. A smaller learning rate + can lead to a more precise convergence but may require more epochs. + Adjust based on your dataset size and complexity. + log_dir : str + Path to the directory where logs will be stored, by default 'logs_fine_tune/'. + batch_size : int + Number of samples per batch, by default 2. + num_workers : int + Number of worker threads for data loading, by default 8. + max_epochs : int + Maximum number of fine-tuning epochs, by default 100. + early_stop_threshold : float + Threshold for early stopping based on validation loss deviation, + by default 0.05. + aug_prob_to_one : bool + Determines whether to apply very strong data augmentation, by default True. + If set to False, data augmentation still happens, but not as frequently. + More data augmentation can lead to better performance, but also increases the + training time substantially. + use_surface_dice : bool + Determines whether to use Surface-Dice loss, by default False. + surface_dice_weight : float + Scaling factor for the Surface-Dice loss, by default 1.0. + surface_dice_tokens : list + List of tokens to use for the Surface-Dice loss. + use_deep_supervision : bool + Determines whether to use deep supervision, by default True. + project_name : str + Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'. + sub_name : str + Sub-name for the project, by default '1'. + + Note + ---- + This command configures and executes a fine-tuning session + using the provided model checkpoint. + The actual fine-tuning logic resides in the function '_fine_tune'. 
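+
+    Example
+    -------
+    A possible invocation with a lowered learning rate could look like the
+    following (paths are placeholders; the option spellings assume Typer's
+    default renaming of underscores to dashes, so check
+    "membrain finetune_advanced --help" for the exact names):
+
+    membrain finetune_advanced --pretrained-checkpoint-path <ckpt>
+        --finetune-data-dir <dir> --finetune-learning-rate 1e-6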
+ """ + _fine_tune( + pretrained_checkpoint_path=pretrained_checkpoint_path, + finetune_data_dir=finetune_data_dir, + finetune_learning_rate=finetune_learning_rate, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + early_stop_threshold=early_stop_threshold, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + use_surf_dice=use_surface_dice, + surf_dice_weight=surface_dice_weight, + surf_dice_tokens=surface_dice_tokens, + project_name=project_name, + sub_name=sub_name, + ) diff --git a/src/membrain_seg/segmentation/cli/ske_cli.py b/src/membrain_seg/segmentation/cli/ske_cli.py index a4fdcfe..212773a 100644 --- a/src/membrain_seg/segmentation/cli/ske_cli.py +++ b/src/membrain_seg/segmentation/cli/ske_cli.py @@ -2,7 +2,11 @@ from typer import Option -from membrain_seg.segmentation.dataloading.data_utils import store_tomogram +from membrain_seg.segmentation.dataloading.data_utils import ( + load_tomogram, + store_tomogram, +) + from ..skeletonize import skeletonization as _skeletonization from .cli import cli @@ -50,7 +54,13 @@ def skeletonize( --batch-size """ # Assuming _skeletonization function is already defined and can handle batch_size - ske = _skeletonization(label_path=label_path, batch_size=batch_size) + + segmentation = load_tomogram(label_path) + ske = _skeletonization(segmentation=segmentation.data, batch_size=batch_size) + + # Update the segmentation data with the skeletonized output while preserving the original header and voxel_size + segmentation.data = ske + if not os.path.exists(out_folder): os.makedirs(out_folder) @@ -59,6 +69,6 @@ def skeletonize( out_folder, os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc", ) - - store_tomogram(filename=out_file, tomogram=ske) + + store_tomogram(filename=out_file, tomogram=segmentation) print("Skeleton saved to ", out_file) diff --git a/src/membrain_seg/segmentation/finetune.py b/src/membrain_seg/segmentation/finetune.py new file mode 100644 index 0000000..abff87d --- /dev/null +++ b/src/membrain_seg/segmentation/finetune.py @@ -0,0 +1,174 @@ +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint + +from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import ( + MemBrainSegDataModule, +) +from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet +from membrain_seg.segmentation.training.optim_utils import ( + ToleranceCallback, +) +from membrain_seg.segmentation.training.training_param_summary import ( + print_training_parameters, +) + + +def fine_tune( + pretrained_checkpoint_path: str, + finetune_data_dir: str, + finetune_learning_rate: float = 1e-5, + log_dir: str = "logs_finetune/", + batch_size: int = 2, + num_workers: int = 8, + max_epochs: int = 100, + early_stop_threshold: float = 0.05, + aug_prob_to_one: bool = False, + use_deep_supervision: bool = False, + project_name: str = "membrain-seg_finetune", + sub_name: str = "1", + use_surf_dice: bool = False, + surf_dice_weight: float = 1.0, + surf_dice_tokens: list = None, +) -> None: + """ + Fine-tune a pre-trained U-Net model on new datasets. + + This function finetunes a pre-trained U-Net model on new data provided by the user. + The `finetune_data_dir` should contain the following directories: + - `imagesTr` and `labelsTr` for the user's own new training data. 
+ - `imagesVal` and `labelsVal` for the old data, which will be used + for validation to ensure that the fine-tuned model's performance + is not significantly worse on the original training data than the + pre-trained model. + + Callbacks used during the fine-tuning process + --------- + - ModelCheckpoint: Saves the model checkpoints based on training loss + and at regular intervals. + - ToleranceCallback: Stops training if the validation loss deviates significantly + from the baseline value set after the first epoch. + - LearningRateMonitor: Monitors and logs the learning rate during training. + - PrintLearningRate: Prints the current learning rate at the start of each epoch. + + Parameters + ---------- + pretrained_checkpoint_path : str + Path to the checkpoint of the pre-trained model. + finetune_data_dir : str + Path to the directory containing the new data for fine-tuning + and old data for validation. + finetune_learning_rate : float, optional + Learning rate for fine-tuning, by default 1e-5. + log_dir : str, optional + Path to the directory where logs should be stored. + batch_size : int, optional + Number of samples per batch of input data. + num_workers : int, optional + Number of subprocesses to use for data loading. + max_epochs : int, optional + Maximum number of epochs to finetune, by default 100. + early_stop_threshold : float, optional + Threshold for early stopping based on validation loss deviation, + by default 0.05. + aug_prob_to_one : bool, optional + If True, all augmentation probabilities are set to 1. + use_deep_supervision : bool, optional + If True, enables deep supervision in the U-Net model. + project_name : str, optional + Name of the project for logging purposes. + sub_name : str, optional + Sub-name of the project for logging purposes. + use_surf_dice : bool, optional + If True, enables Surface-Dice loss. + surf_dice_weight : float, optional + Weight for the Surface-Dice loss. + surf_dice_tokens : list, optional + List of tokens to use for the Surface-Dice loss. + + Returns + ------- + None + """ + # Print training parameters for verification + print_training_parameters( + data_dir=finetune_data_dir, + log_dir=log_dir, + batch_size=batch_size, + num_workers=num_workers, + max_epochs=max_epochs, + aug_prob_to_one=aug_prob_to_one, + use_deep_supervision=use_deep_supervision, + project_name=project_name, + sub_name=sub_name, + use_surf_dice=use_surf_dice, + surf_dice_weight=surf_dice_weight, + surf_dice_tokens=surf_dice_tokens, + ) + print("————————————————————————————————————————————————————————") + print( + f"Pretrained Checkpoint:\n" + f" '{pretrained_checkpoint_path}' \n" + f" Path to the pretrained model checkpoint." 
+    )
+    print("\n")
+
+    # Initialize the data module with fine-tuning datasets
+    # New data for finetuning and old data for validation
+    finetune_data_module = MemBrainSegDataModule(
+        data_dir=finetune_data_dir,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        aug_prob_to_one=aug_prob_to_one,
+    )
+
+    # Load the pre-trained model with updated learning rate
+    pretrained_model = SemanticSegmentationUnet.load_from_checkpoint(
+        pretrained_checkpoint_path, learning_rate=finetune_learning_rate
+    )
+
+    checkpointing_name = project_name + "_" + sub_name
+
+    # Set up logging
+    csv_logger = pl_loggers.CSVLogger(log_dir)
+
+    # Set up model checkpointing based on training loss
+    checkpoint_callback_train_loss = ModelCheckpoint(
+        dirpath="finetuned_checkpoints/",
+        filename=checkpointing_name + "-{epoch:02d}-{train_loss:.2f}",
+        monitor="train_loss",
+        mode="min",
+        save_top_k=3,
+    )
+
+    # Set up regular checkpointing every 5 epochs
+    checkpoint_callback_regular = ModelCheckpoint(
+        save_top_k=-1,  # Save all checkpoints
+        every_n_epochs=5,
+        dirpath="finetuned_checkpoints/",
+        filename=checkpointing_name + "-{epoch}-{train_loss:.2f}",
+        verbose=True,  # Print a message when a checkpoint is saved
+    )
+
+    # Set up ToleranceCallback by monitoring validation loss
+    early_stop_metric = "val_loss"
+    tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold)
+
+    # Monitor learning rate changes
+    lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)
+
+    # Initialize the trainer with specified precision, logger, and callbacks
+    trainer = pl.Trainer(
+        precision="16-mixed",
+        logger=[csv_logger],
+        callbacks=[
+            checkpoint_callback_train_loss,
+            checkpoint_callback_regular,
+            lr_monitor,
+            tolerance_callback,
+        ],
+        max_epochs=max_epochs,
+    )
+
+    # Start the fine-tuning process
+    trainer.fit(pretrained_model, finetune_data_module)
diff --git a/src/membrain_seg/segmentation/skeletonize.py b/src/membrain_seg/segmentation/skeletonize.py
index 82b9a1e..f4048d5 100644
--- a/src/membrain_seg/segmentation/skeletonize.py
+++ b/src/membrain_seg/segmentation/skeletonize.py
@@ -12,7 +12,7 @@
 import scipy.ndimage as ndimage
 import torch
 
-from membrain_seg.segmentation.dataloading.data_utils import load_tomogram
+
 from membrain_seg.segmentation.skeletonization.diff3d import (
     compute_gradients,
     compute_hessian,
@@ -24,18 +24,21 @@
 from membrain_seg.segmentation.training.surface_dice import apply_gaussian_filter
 
 
-def skeletonization(label_path: str, batch_size: int) -> np.ndarray:
+
+def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
     """
     Perform skeletonization on a tomogram segmentation.
 
-    This function reads a segmentation file (label_path). It performs skeletonization on
-    the segmentation where the non-zero labels represent the structures of interest.
-    The resultan skeleton is saved with '_skel' appended after the filename.
+    This function skeletonizes the input segmentation where the non-zero labels
+    represent the structures of interest. The resulting skeleton is returned
+    as a numpy array.
 
     Parameters
     ----------
-    label_path : str
-        Path to the input tomogram segmentation file.
+    segmentation : ndarray
+        Tomogram segmentation as a numpy array, where non-zero values represent
+        the structures of interest.
+
     batch_size : int
         The number of elements to process in one batch during eigen decomposition.
         Useful for managing memory usage.
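With this change, `skeletonization` operates on an in-memory array and returns the skeleton instead of reading from `label_path`; loading and saving are handled by the CLI wrapper (see the `ske_cli.py` hunk above). Below is a minimal sketch of driving the refactored function directly from Python, reusing the same `load_tomogram`/`store_tomogram` helpers as the CLI; the input file name is hypothetical.

```python
from membrain_seg.segmentation.dataloading.data_utils import (
    load_tomogram,
    store_tomogram,
)
from membrain_seg.segmentation.skeletonize import skeletonization

# Hypothetical input file; replace with your own membrane segmentation.
segmentation = load_tomogram("membrane_segmentation.mrc")

# Non-zero voxels are treated as structures of interest; batch_size bounds
# memory use during the batched eigen decomposition.
ske = skeletonization(segmentation=segmentation.data, batch_size=1000000)

# Write the skeleton back into the loaded tomogram object so the original
# header and voxel size are preserved, exactly as the CLI command does.
segmentation.data = ske
store_tomogram(filename="membrane_segmentation_skel.mrc", tomogram=segmentation)
```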
@@ -58,9 +61,6 @@ def skeletonization(label_path: str, batch_size: int) -> np.ndarray: --batch-size 1000000 This command runs the skeletonization process from the command line. """ - # Read original segmentation - segmentation = load_tomogram(label_path) - segmentation = segmentation.data # Convert non-zero segmentation values to 1.0 labels = (segmentation > 0) * 1.0 diff --git a/src/membrain_seg/segmentation/train.py b/src/membrain_seg/segmentation/train.py index 9c576f5..2e7253f 100644 --- a/src/membrain_seg/segmentation/train.py +++ b/src/membrain_seg/segmentation/train.py @@ -1,7 +1,6 @@ import warnings import pytorch_lightning as pl -from pytorch_lightning import Callback from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint @@ -9,6 +8,7 @@ MemBrainSegDataModule, ) from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet +from membrain_seg.segmentation.training.optim_utils import PrintLearningRate from membrain_seg.segmentation.training.training_param_summary import ( print_training_parameters, ) @@ -124,12 +124,8 @@ def train( lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True) - class PrintLearningRate(Callback): - def on_epoch_start(self, trainer, pl_module): - current_lr = trainer.optimizers[0].param_groups[0]["lr"] - print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}") - print_lr_cb = PrintLearningRate() + # Set up the trainer trainer = pl.Trainer( precision="16-mixed", diff --git a/src/membrain_seg/segmentation/training/optim_utils.py b/src/membrain_seg/segmentation/training/optim_utils.py index a16e136..60cbbf7 100644 --- a/src/membrain_seg/segmentation/training/optim_utils.py +++ b/src/membrain_seg/segmentation/training/optim_utils.py @@ -1,7 +1,11 @@ +from typing import Optional + +import pytorch_lightning as pl import torch from monai.losses import DiceLoss, MaskedLoss from monai.networks.nets import DynUNet from monai.utils import LossReduction +from pytorch_lightning import Callback from torch.nn.functional import ( binary_cross_entropy_with_logits, sigmoid, @@ -115,6 +119,8 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor: combined_loss = combined_loss.mean() elif self.reduction == "sum": combined_loss = combined_loss.sum() + elif self.reduction == "none": + return combined_loss else: raise ValueError( f"Invalid reduction type {self.reduction}. " @@ -266,3 +272,112 @@ def forward( # Normalize loss loss = loss / sum(self.weights) return loss + + +class PrintLearningRate(Callback): + """ + Callback to print the current learning rate at the start of each epoch. + + Methods + ------- + on_epoch_start(trainer, pl_module) + Prints the current learning rate at the start of each epoch. + + Parameters + ---------- + trainer : pl.Trainer + The trainer object that manages the training process. + pl_module : pl.LightningModule + The model being trained. + + """ + + def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + """ + Prints the current learning rate at the start of each epoch. + + Parameters + ---------- + trainer : pl.Trainer + The trainer object that manages the training process. + pl_module : pl.LightningModule + The model being trained. 
+ """ + current_lr = trainer.optimizers[0].param_groups[0]["lr"] + print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}") + + +class ToleranceCallback(Callback): + """ + Callback to stop training if the monitored metric deviates + beyond a certain threshold from the baseline value obtained + after the first epoch. + + Parameters + ---------- + metric_name : str + The name of the metric to monitor. + threshold : float + The threshold value for deviation from the baseline. + + Methods + ------- + on_validation_epoch_end(trainer, pl_module) + Checks if the monitored metric deviates beyond the threshold + and stops training if it does. + """ + + def __init__(self, metric_name: str, threshold: float): + """ + Initializes the ToleranceCallback with the metric + to monitor and the deviation threshold. + + Parameters + ---------- + metric_name : str + The name of the metric to monitor. + threshold : float + The threshold value for deviation from the baseline. + """ + super().__init__() + self.metric_name = metric_name + self.threshold = threshold + self.baseline_value: Optional[float] = ( + None # Baseline value will be set after the first epoch + ) + + def on_validation_epoch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule + ): + """ + Checks if the monitored metric deviates beyond the threshold + and stops training if it does. + + Parameters + ---------- + trainer : pl.Trainer + The trainer object that manages the training process. + pl_module : pl.LightningModule + The model being trained. + """ + # Access the metric value from the validation metrics + metric_value = trainer.callback_metrics.get(self.metric_name) + + # If the metric value is a tensor, convert it to a float + if isinstance(metric_value, torch.Tensor): + metric_value = metric_value.item() + + # Set the baseline value after the first validation epoch + if metric_value is not None: + if self.baseline_value is None: + self.baseline_value = metric_value + print(f"Baseline {self.metric_name} set to {self.baseline_value}") + return [] + + # Check if the metric value deviates beyond the threshold + if metric_value - self.baseline_value > self.threshold: + print( + f"Stopping training as {self.metric_name} " + f"deviates too far from the baseline value." + ) + trainer.should_stop = True
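The early-stopping behavior of `ToleranceCallback` can be sanity-checked without launching a training run. The sketch below drives the callback with a `SimpleNamespace` standing in for the Lightning `Trainer` (only `callback_metrics` and `should_stop` are touched); the stand-in and the metric values are illustrative assumptions, not part of this PR.

```python
from types import SimpleNamespace

import torch

from membrain_seg.segmentation.training.optim_utils import ToleranceCallback

cb = ToleranceCallback(metric_name="val_loss", threshold=0.05)
fake_trainer = SimpleNamespace(
    callback_metrics={"val_loss": torch.tensor(0.30)}, should_stop=False
)

# First validation epoch: 0.30 becomes the baseline value.
cb.on_validation_epoch_end(fake_trainer, pl_module=None)

# Later epoch: 0.40 deviates from the baseline by more than the threshold
# (0.40 - 0.30 > 0.05), so the callback flags training to stop.
fake_trainer.callback_metrics["val_loss"] = torch.tensor(0.40)
cb.on_validation_epoch_end(fake_trainer, pl_module=None)
print(fake_trainer.should_stop)  # True
```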