Skip to content

Commit

Permalink
Modify finetune function
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanyi Zhang committed Jun 3, 2024
1 parent 9d5fd59 commit 85261ec
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 59 deletions.
85 changes: 84 additions & 1 deletion src/membrain_seg/segmentation/cli/fine_tune_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,83 @@ def finetune(
data structure, type "membrain data_structure_help"',
**PKWARGS,
),
):
"""
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets.
This function fine-tunes a pre-trained model on new datasets provided by the user.
The directory specified by `finetune_data_dir` should be structured according to the
requirements for the training function.
For more details, use "membrain data_structure_help".
Parameters
----------
pretrained_checkpoint_path : str
Path to the checkpoint file of the pre-trained model.
finetune_data_dir : str
Directory containing the new dataset for fine-tuning,
structured as per the MemBrain's requirement.
Use "membrain data_structure_help" for detailed information
on the required data structure.
Note
----
This command configures and executes a fine-tuning session
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
finetune_learning_rate = 1e-5
log_dir = "logs_finetune/"
batch_size = 2
num_workers = 8
max_epochs = 100
early_stop_threshold = 0.05
aug_prob_to_one = True
use_deep_supervision = True
project_name = "membrain-seg_finetune"
sub_name = "1"

_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
finetune_learning_rate=finetune_learning_rate,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
max_epochs=max_epochs,
early_stop_threshold=early_stop_threshold,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
project_name=project_name,
sub_name=sub_name,
)


@cli.command(name="finetune_advanced", no_args_is_help=True)
def finetune_advanced(
pretrained_checkpoint_path: str = Option( # noqa: B008
...,
help="Path to the checkpoint of the pre-trained model.",
**PKWARGS,
),
finetune_data_dir: str = Option( # noqa: B008
...,
help='Path to the directory containing the new data for fine-tuning. \
Following the same required structure as the train function. \
To learn more about the required\
data structure, type "membrain data_structure_help"',
**PKWARGS,
),
finetune_learning_rate: float = Option( # noqa: B008
1e-5,
help="Learning rate for fine-tuning the model. This parameter controls the \
step size at each iteration while moving toward a minimum loss. \
A smaller learning rate can lead to a more precise convergence but may \
require more epochs. Adjust based on your dataset size and complexity.",
),
log_dir: str = Option( # noqa: B008
"logs_fine_tune/",
help="Log directory path. Finetuning logs will be stored here.",
Expand Down Expand Up @@ -81,7 +158,7 @@ def finetune(
):
"""
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets.
and validation on original datasets with more advanced options.
This function finetunes a pre-trained U-Net model on new data provided by the user.
The `finetune_data_dir` should contain the following directories:
Expand All @@ -100,6 +177,11 @@ def finetune(
structured as per the MemBrain's requirement.
Use "membrain data_structure_help" for detailed information
on the required data structure.
finetune_learning_rate : float
Learning rate for fine-tuning the model. This parameter controls the step size
at each iteration while moving toward a minimum loss. A smaller learning rate
can lead to a more precise convergence but may require more epochs.
Adjust based on your dataset size and complexity.
log_dir : str
Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
batch_size : int
Expand Down Expand Up @@ -138,6 +220,7 @@ def finetune(
_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
finetune_learning_rate=finetune_learning_rate,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
Expand Down
58 changes: 6 additions & 52 deletions src/membrain_seg/segmentation/finetune.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import (
MemBrainSegDataModule,
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.training.optim_utils import (
PrintLearningRate,
ToleranceCallback,
)
from membrain_seg.segmentation.training.training_param_summary import (
print_training_parameters,
)
Expand Down Expand Up @@ -151,60 +151,14 @@ def fine_tune(
verbose=True, # Print a message when a checkpoint is saved
)

class ToleranceCallback(Callback):
    """
    Callback to stop training if the monitored metric deviates
    beyond a certain threshold from the baseline value obtained
    after the first real validation epoch.

    Parameters
    ----------
    metric_name : str
        The name of the metric to monitor (looked up in
        ``trainer.callback_metrics``).
    threshold : float
        The maximum allowed absolute deviation from the baseline.
    """

    def __init__(self, metric_name: str, threshold: float):
        super().__init__()
        self.metric_name = metric_name
        self.threshold = threshold
        # Baseline value will be set after the first validation epoch
        self.baseline_value: Optional[float] = None

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Record the baseline on the first epoch, afterwards check deviation."""
        # Lightning runs a short sanity-check validation pass before training.
        # Metrics from those few batches must not become the baseline.
        if getattr(trainer, "sanity_checking", False):
            return

        # Access the metric value from the validation metrics
        metric_value = trainer.callback_metrics.get(self.metric_name)
        if metric_value is None:
            # Metric was not logged this epoch; nothing to compare.
            return

        # If the metric value is a tensor, convert it to a float
        if isinstance(metric_value, torch.Tensor):
            metric_value = metric_value.item()

        # Set the baseline value after the first validation epoch
        if self.baseline_value is None:
            self.baseline_value = metric_value
            print(f"Baseline {self.metric_name} set to {self.baseline_value}")
            return

        # Check if the metric value deviates beyond the threshold
        if abs(metric_value - self.baseline_value) > self.threshold:
            print(
                f"Stopping training as {self.metric_name} "
                f"deviates too far from the baseline value."
            )
            trainer.should_stop = True

# Set up ToleranceCallback by monitoring validation loss
early_stop_metric = "val_loss"

tolerance_callback = ToleranceCallback(early_stop_metric, early_stop_threshold)

# Monitor learning rate changes
lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

class PrintLearningRate(Callback):
    """Callback to print the current learning rate at the start of each epoch."""

    def on_train_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """Print the learning rate of the first optimizer's first param group."""
        # `on_epoch_start` was removed from Lightning's Callback API (v1.8),
        # so the logic lives in the supported training-epoch hook.
        if not trainer.optimizers:
            return
        current_lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")

    def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Legacy hook name, kept so direct callers keep working."""
        self.on_train_epoch_start(trainer, pl_module)

# Print the current learning rate at the start of each epoch
print_lr_cb = PrintLearningRate()

# Initialize the trainer with specified precision, logger, and callbacks
Expand Down
8 changes: 2 additions & 6 deletions src/membrain_seg/segmentation/train.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import warnings

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import (
MemBrainSegDataModule,
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.training.optim_utils import PrintLearningRate
from membrain_seg.segmentation.training.training_param_summary import (
print_training_parameters,
)
Expand Down Expand Up @@ -124,12 +124,8 @@ def train(

lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

class PrintLearningRate(Callback):
    """Callback to print the current learning rate at the start of each epoch."""

    def on_train_epoch_start(self, trainer, pl_module):
        """Print the learning rate of the first optimizer's first param group."""
        # `on_epoch_start` was removed from Lightning's Callback API (v1.8),
        # so the logic lives in the supported training-epoch hook.
        if not trainer.optimizers:
            return
        current_lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")

    def on_epoch_start(self, trainer, pl_module):
        """Legacy hook name, kept so direct callers keep working."""
        self.on_train_epoch_start(trainer, pl_module)

print_lr_cb = PrintLearningRate()

# Set up the trainer
trainer = pl.Trainer(
precision="16-mixed",
Expand Down
113 changes: 113 additions & 0 deletions src/membrain_seg/segmentation/training/optim_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Optional

import pytorch_lightning as pl
import torch
from monai.losses import DiceLoss, MaskedLoss
from monai.networks.nets import DynUNet
from monai.utils import LossReduction
from pytorch_lightning import Callback
from torch.nn.functional import (
binary_cross_entropy_with_logits,
sigmoid,
Expand Down Expand Up @@ -268,3 +272,112 @@ def forward(
# Normalize loss
loss = loss / sum(self.weights)
return loss


class PrintLearningRate(Callback):
    """
    Callback to print the current learning rate at the start of each
    training epoch.

    Methods
    -------
    on_train_epoch_start(trainer, pl_module)
        Prints the current learning rate at the start of each training epoch.
    on_epoch_start(trainer, pl_module)
        Legacy hook name; delegates to ``on_train_epoch_start``.

    Notes
    -----
    The generic ``Callback.on_epoch_start`` hook was removed from PyTorch
    Lightning in v1.8, so a callback that only overrides it is never
    triggered by recent trainers. The logic therefore lives in
    ``on_train_epoch_start``.
    """

    def on_train_epoch_start(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """
        Prints the current learning rate at the start of each training epoch.

        Parameters
        ----------
        trainer : pl.Trainer
            The trainer object that manages the training process.
        pl_module : pl.LightningModule
            The model being trained.
        """
        # Nothing to report when no optimizer is configured.
        if not trainer.optimizers:
            return
        # Only the first param group of the first optimizer is reported.
        current_lr = trainer.optimizers[0].param_groups[0]["lr"]
        print(f"Epoch {trainer.current_epoch}: Learning Rate = {current_lr}")

    def on_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """
        Legacy entry point kept for backward compatibility.

        Parameters
        ----------
        trainer : pl.Trainer
            The trainer object that manages the training process.
        pl_module : pl.LightningModule
            The model being trained.
        """
        self.on_train_epoch_start(trainer, pl_module)


class ToleranceCallback(Callback):
    """
    Callback to stop training if the monitored metric deviates
    beyond a certain threshold from the baseline value obtained
    after the first real validation epoch.

    Parameters
    ----------
    metric_name : str
        The name of the metric to monitor.
    threshold : float
        The threshold value for deviation from the baseline.

    Methods
    -------
    on_validation_epoch_end(trainer, pl_module)
        Checks if the monitored metric deviates beyond the threshold
        and stops training if it does.
    """

    def __init__(self, metric_name: str, threshold: float):
        """
        Initializes the ToleranceCallback with the metric
        to monitor and the deviation threshold.

        Parameters
        ----------
        metric_name : str
            The name of the metric to monitor.
        threshold : float
            The threshold value for deviation from the baseline.
        """
        super().__init__()
        self.metric_name = metric_name
        self.threshold = threshold
        # Baseline value will be set after the first validation epoch
        self.baseline_value: Optional[float] = None

    def on_validation_epoch_end(
        self, trainer: pl.Trainer, pl_module: pl.LightningModule
    ):
        """
        Checks if the monitored metric deviates beyond the threshold
        and stops training if it does.

        The first validation epoch that reports the metric only records the
        baseline. Later epochs set ``trainer.should_stop`` when the absolute
        difference to that baseline exceeds ``threshold``.

        Parameters
        ----------
        trainer : pl.Trainer
            The trainer object that manages the training process.
        pl_module : pl.LightningModule
            The model being trained.
        """
        # Lightning runs a short sanity-check validation pass before training.
        # Metrics from those few batches must not become the baseline.
        if getattr(trainer, "sanity_checking", False):
            return

        # Access the metric value from the validation metrics
        metric_value = trainer.callback_metrics.get(self.metric_name)
        if metric_value is None:
            # Metric was not logged this epoch; nothing to compare.
            return

        # If the metric value is a tensor, convert it to a float
        if isinstance(metric_value, torch.Tensor):
            metric_value = metric_value.item()

        # Set the baseline value after the first validation epoch
        if self.baseline_value is None:
            self.baseline_value = metric_value
            print(f"Baseline {self.metric_name} set to {self.baseline_value}")
            return

        # Check if the metric value deviates beyond the threshold
        if abs(metric_value - self.baseline_value) > self.threshold:
            print(
                f"Stopping training as {self.metric_name} "
                f"deviates too far from the baseline value."
            )
            trainer.should_stop = True

0 comments on commit 85261ec

Please sign in to comment.