
Modify finetune function #69

Merged: 14 commits, Sep 17, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.1.0
rev: 24.4.0
hooks:
- id: black

1 change: 1 addition & 0 deletions src/membrain_seg/__init__.py
@@ -1,4 +1,5 @@
"""membrane segmentation in 3D for cryo-ET."""

from importlib.metadata import PackageNotFoundError, version

try:
1 change: 1 addition & 0 deletions src/membrain_seg/annotations/__init__.py
@@ -1,4 +1,5 @@
"""empty init."""

from .cli import cli # noqa: F401
from .extract_patch_cli import extract_patches # noqa: F401
from .merge_corrections_cli import merge_corrections # noqa: F401
3 changes: 3 additions & 0 deletions src/membrain_seg/segmentation/cli/__init__.py
@@ -1,5 +1,8 @@
"""CLI init function."""

# These imports are necessary to register CLI commands. Do not remove!
from .cli import cli # noqa: F401
from .fine_tune_cli import finetune # noqa: F401
from .segment_cli import segment # noqa: F401
from .ske_cli import skeletonize # noqa: F401
from .train_cli import data_dir_help, train # noqa: F401
236 changes: 236 additions & 0 deletions src/membrain_seg/segmentation/cli/fine_tune_cli.py
@@ -0,0 +1,236 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli


@cli.command(name="finetune", no_args_is_help=True)
def finetune(
pretrained_checkpoint_path: str = Option( # noqa: B008
...,
help="Path to the checkpoint of the pre-trained model.",
**PKWARGS,
),
finetune_data_dir: str = Option( # noqa: B008
...,
help='Path to the directory containing the new data for fine-tuning. \
Following the same required structure as the train function. \
To learn more about the required\
data structure, type "membrain data_structure_help"',
**PKWARGS,
),
):
"""
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets.

This function fine-tunes a pre-trained model on new datasets provided by the user.
The directory specified by `finetune_data_dir` should be structured according to the
requirements for the training function.
For more details, use "membrain data_structure_help".

Parameters
----------
pretrained_checkpoint_path : str
Path to the checkpoint file of the pre-trained model.
finetune_data_dir : str
Directory containing the new dataset for fine-tuning,
structured as per MemBrain's requirements.
Use "membrain data_structure_help" for detailed information
on the required data structure.

Note
----
This command configures and executes a fine-tuning session
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
finetune_learning_rate = 1e-5
log_dir = "logs_finetune/"
batch_size = 2
num_workers = 8
max_epochs = 100
early_stop_threshold = 0.05
aug_prob_to_one = True
use_deep_supervision = True
project_name = "membrain-seg_finetune"
sub_name = "1"

_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
finetune_learning_rate=finetune_learning_rate,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
max_epochs=max_epochs,
early_stop_threshold=early_stop_threshold,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
project_name=project_name,
sub_name=sub_name,
)
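
# Example invocation (hypothetical paths; option names assume typer's default
# underscore-to-hyphen conversion):
#   membrain finetune \
#       --pretrained-checkpoint-path <path/to/pretrained_model.ckpt> \
#       --finetune-data-dir <path/to/finetune_data>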


@cli.command(name="finetune_advanced", no_args_is_help=True)
def finetune_advanced(
pretrained_checkpoint_path: str = Option( # noqa: B008
...,
help="Path to the checkpoint of the pre-trained model.",
**PKWARGS,
),
finetune_data_dir: str = Option( # noqa: B008
...,
help='Path to the directory containing the new data for fine-tuning. \
Following the same required structure as the train function. \
To learn more about the required\
data structure, type "membrain data_structure_help"',
**PKWARGS,
),
finetune_learning_rate: float = Option( # noqa: B008
1e-5,
help="Learning rate for fine-tuning the model. This parameter controls the \
step size at each iteration while moving toward a minimum loss. \
A smaller learning rate can lead to a more precise convergence but may \
require more epochs. Adjust based on your dataset size and complexity.",
),
log_dir: str = Option( # noqa: B008
"logs_fine_tune/",
help="Log directory path. Finetuning logs will be stored here.",
),
batch_size: int = Option( # noqa: B008
2,
help="Batch size for training.",
),
num_workers: int = Option( # noqa: B008
8,
help="Number of worker threads for data loading.",
),
max_epochs: int = Option( # noqa: B008
100,
help="Maximum number of epochs for fine-tuning.",
),
early_stop_threshold: float = Option( # noqa: B008
0.05,
help="Threshold for early stopping based on validation loss deviation.",
),
aug_prob_to_one: bool = Option( # noqa: B008
True,
help='Whether to augment with a probability of one. This helps with the \
model\'s generalization,\
but also severely increases training time.\
Pass "True" or "False".',
),
use_surface_dice: bool = Option( # noqa: B008
False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
),
surface_dice_weight: float = Option( # noqa: B008
1.0, help="Scaling factor for the Surface-Dice loss. "
),
surface_dice_tokens: Annotated[
Optional[List[str]],
Option(
help='List of tokens to \
use for the Surface-Dice loss. \
Pass tokens separately:\
For example, finetune_advanced --surface_dice_tokens "ds1" \
--surface_dice_tokens "ds2"'
),
] = None,
use_deep_supervision: bool = Option( # noqa: B008
True, help='Whether to use deep supervision. Pass "True" or "False".'
),
project_name: str = Option( # noqa: B008
"membrain-seg_v0_finetune",
help="Project name. This helps to find your model again.",
),
sub_name: str = Option( # noqa: B008
"1",
help="Subproject name. For multiple runs in the same project,\
please specify sub_names.",
),
):
"""
Initiates fine-tuning of a pre-trained model on new datasets
and validation on original datasets with more advanced options.

This function fine-tunes a pre-trained U-Net model on new data provided by the user.
The `finetune_data_dir` should contain the following directories:
- `imagesTr` and `labelsTr` for the user's own new training data.
- `imagesVal` and `labelsVal` for the old data, which will be used
for validation to ensure that the fine-tuned model's performance
is not significantly worse on the original training data than the
pre-trained model.

Parameters
----------
pretrained_checkpoint_path : str
Path to the checkpoint file of the pre-trained model.
finetune_data_dir : str
Directory containing the new dataset for fine-tuning,
structured as per MemBrain's requirements.
Use "membrain data_structure_help" for detailed information
on the required data structure.
finetune_learning_rate : float
Learning rate for fine-tuning the model. This parameter controls the step size
at each iteration while moving toward a minimum loss. A smaller learning rate
can lead to a more precise convergence but may require more epochs.
Adjust based on your dataset size and complexity.
log_dir : str
Path to the directory where logs will be stored, by default 'logs_fine_tune/'.
batch_size : int
Number of samples per batch, by default 2.
num_workers : int
Number of worker threads for data loading, by default 8.
max_epochs : int
Maximum number of fine-tuning epochs, by default 100.
early_stop_threshold : float
Threshold for early stopping based on validation loss deviation,
by default 0.05.
aug_prob_to_one : bool
Determines whether to apply very strong data augmentation, by default True.
If set to False, data augmentation still happens, but not as frequently.
More data augmentation can lead to better performance, but also increases the
training time substantially.
use_surface_dice : bool
Determines whether to use Surface-Dice loss, by default False.
surface_dice_weight : float
Scaling factor for the Surface-Dice loss, by default 1.0.
surface_dice_tokens : list
List of tokens to use for the Surface-Dice loss.
use_deep_supervision : bool
Determines whether to use deep supervision, by default True.
project_name : str
Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'.
sub_name : str
Sub-name for the project, by default '1'.

Note
----
This command configures and executes a fine-tuning session
using the provided model checkpoint.
The actual fine-tuning logic resides in the function '_fine_tune'.
"""
_fine_tune(
pretrained_checkpoint_path=pretrained_checkpoint_path,
finetune_data_dir=finetune_data_dir,
finetune_learning_rate=finetune_learning_rate,
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
max_epochs=max_epochs,
early_stop_threshold=early_stop_threshold,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
use_surf_dice=use_surface_dice,
surf_dice_weight=surface_dice_weight,
surf_dice_tokens=surface_dice_tokens,
project_name=project_name,
sub_name=sub_name,
)
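
# Example invocation of the advanced command (hypothetical values; option names
# assume typer's default underscore-to-hyphen conversion, and --surface-dice-tokens
# is repeated once per token as described in its help text):
#   membrain finetune_advanced \
#       --pretrained-checkpoint-path <path/to/pretrained_model.ckpt> \
#       --finetune-data-dir <path/to/finetune_data> \
#       --finetune-learning-rate 1e-5 \
#       --surface-dice-tokens "ds1" --surface-dice-tokens "ds2"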
64 changes: 64 additions & 0 deletions src/membrain_seg/segmentation/cli/ske_cli.py
@@ -0,0 +1,64 @@
import os

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import store_tomogram

from ..skeletonize import skeletonization as _skeletonization
from .cli import cli


@cli.command(name="skeletonize", no_args_is_help=True)
def skeletonize(
label_path: str = Option(..., help="Specifies the path for skeletonization."),
out_folder: str = Option(
"./predictions", help="Directory to save the resulting skeletons."
),
batch_size: int = Option(
None,
help="Optional batch size for processing the tomogram. If not specified, "
"the entire volume is processed at once. If operating with limited GPU "
"resources, a batch size of 1,000,000 is recommended.",
),
):
"""
Perform skeletonization on labeled tomograms using a nonmax-suppression technique.

This function reads a labeled tomogram, applies skeletonization using a specified
batch size, and stores the results in an MRC file in the specified output directory.
If batch_size is set to None, the entire tomogram is processed at once, which might
require significant memory. It is recommended to specify a batch size if memory
constraints are a concern. The maximum possible batch size is the product of the
tomogram's dimensions (Nx * Ny * Nz).


Parameters
----------
label_path : str
File path to the tomogram to be skeletonized.
out_folder : str
Output folder path for the skeletonized tomogram.
batch_size : int, optional
The size of the batch to process the tomogram. Defaults to None, which processes
the entire volume at once. For large volumes, consider setting it to a specific
value like 1,000,000 for efficient processing without exceeding memory limits.


Examples
--------
membrain skeletonize --label-path <path> --out-folder <output-directory>
--batch-size <batch-size>
"""
# Assuming _skeletonization function is already defined and can handle batch_size
ske = _skeletonization(label_path=label_path, batch_size=batch_size)

if not os.path.exists(out_folder):
os.makedirs(out_folder)

out_file = os.path.join(
out_folder,
os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc",
)

store_tomogram(filename=out_file, tomogram=ske)
Collaborator review comment:

As raised in #82, this will lose the header information.
Can we load the tomogram before calling skeletonization and pass it instead of the label path?

load_tomogram gives you a Tomogram instance. If you simply alter tomo.data, you can pass the Tomogram into store_tomogram instead of ske.

print("Skeleton saved to ", out_file)
16 changes: 10 additions & 6 deletions src/membrain_seg/segmentation/dataloading/memseg_augmentation.py
@@ -254,9 +254,11 @@ def get_training_transforms(
np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
),
loc=(-0.5, 1.5),
max_strength=lambda x, y: np.random.uniform(-5, -1)
if np.random.uniform() < 0.5
else np.random.uniform(1, 5),
max_strength=lambda x, y: (
np.random.uniform(-5, -1)
if np.random.uniform() < 0.5
else np.random.uniform(1, 5)
),
mean_centered=False,
),
prob=(1.0 if prob_to_one else 0.3),
@@ -268,9 +270,11 @@
np.random.uniform(np.log(x[y] // 6), np.log(x[y]))
),
loc=(-0.5, 1.5),
gamma=lambda: np.random.uniform(0.01, 0.8)
if np.random.uniform() < 0.5
else np.random.uniform(1.5, 4),
gamma=lambda: (
np.random.uniform(0.01, 0.8)
if np.random.uniform() < 0.5
else np.random.uniform(1.5, 4)
),
),
prob=(1.0 if prob_to_one else 0.3),
),
6 changes: 3 additions & 3 deletions src/membrain_seg/segmentation/dataloading/transforms.py
@@ -288,9 +288,9 @@ def __call__(self, data):
y = self.R.randint(0, y_max - height)
x = self.R.randint(0, x_max - width)
if self.replace_with == "mean":
image[
..., z : z + depth, y : y + height, x : x + width
] = torch.mean(torch.Tensor(image))
image[..., z : z + depth, y : y + height, x : x + width] = (
torch.mean(torch.Tensor(image))
)
elif self.replace_with == 0.0:
image[..., z : z + depth, y : y + height, x : x + width] = 0.0
d[key] = image