Modify finetune function #69
Merged
Commits (14)
71bf422  add skeletonization code
08cf4df  Second commit
9ebbcdf  Second commit
e1c4fe9  Second commit
ea9712d  Second commit
1af6edd  Third commit
0d8b9de  Third commit
e622259  Fourth commit
ab13ceb  Fourth commit
d005be5  Fix data type warning and absolute value error
9d5fd59  Add finetune function
85261ec  Modify finetune function
d8db3a1  Add Fine-tuning.md file
1a6e349  Merge branch 'finetuning' into finetune (Hanyi11)
__init__.py (package init):
@@ -1,4 +1,5 @@
"""empty init."""

from .cli import cli  # noqa: F401
from .extract_patch_cli import extract_patches  # noqa: F401
from .merge_corrections_cli import merge_corrections  # noqa: F401
cli/__init__.py (CLI command registration):
@@ -1,5 +1,8 @@
"""CLI init function."""

# These imports are necessary to register CLI commands. Do not remove!
from .cli import cli  # noqa: F401
from .fine_tune_cli import finetune  # noqa: F401
from .segment_cli import segment  # noqa: F401
from .ske_cli import skeletonize  # noqa: F401
from .train_cli import data_dir_help, train  # noqa: F401
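The "Do not remove!" comment is the whole mechanism here: Typer registers a command when the @cli.command decorator runs at import time, so importing a module is what makes its command available on the shared app. A minimal standalone sketch of this pattern (toy command and names, not part of membrain-seg):

import typer

cli = typer.Typer()

# The decorator executes at import time; importing this module is therefore
# enough to register the "hello" command on the shared app.
@cli.command(name="hello")
def hello(name: str = typer.Option("world", help="Who to greet.")):
    """Toy command illustrating import-time registration."""
    print(f"Hello, {name}!")

if __name__ == "__main__":
    cli()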
fine_tune_cli.py (new file, imported above as .fine_tune_cli):
@@ -0,0 +1,236 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli


@cli.command(name="finetune", no_args_is_help=True)
def finetune(
    pretrained_checkpoint_path: str = Option(  # noqa: B008
        ...,
        help="Path to the checkpoint of the pre-trained model.",
        **PKWARGS,
    ),
    finetune_data_dir: str = Option(  # noqa: B008
        ...,
        help='Path to the directory containing the new data for fine-tuning. \
            It follows the same required structure as the train function. \
            To learn more about the required data structure, \
            type "membrain data_structure_help".',
        **PKWARGS,
    ),
):
    """
    Initiate fine-tuning of a pre-trained model on new datasets
    and validation on original datasets.

    This function fine-tunes a pre-trained model on new datasets provided
    by the user. The directory specified by `finetune_data_dir` should be
    structured according to the requirements for the training function.
    For more details, use "membrain data_structure_help".

    Parameters
    ----------
    pretrained_checkpoint_path : str
        Path to the checkpoint file of the pre-trained model.
    finetune_data_dir : str
        Directory containing the new dataset for fine-tuning,
        structured as per MemBrain's requirements.
        Use "membrain data_structure_help" for detailed information
        on the required data structure.

    Note
    ----
    This command configures and executes a fine-tuning session
    using the provided model checkpoint.
    The actual fine-tuning logic resides in the function '_fine_tune'.
    """
    finetune_learning_rate = 1e-5
    log_dir = "logs_finetune/"
    batch_size = 2
    num_workers = 8
    max_epochs = 100
    early_stop_threshold = 0.05
    aug_prob_to_one = True
    use_deep_supervision = True
    project_name = "membrain-seg_finetune"
    sub_name = "1"

    _fine_tune(
        pretrained_checkpoint_path=pretrained_checkpoint_path,
        finetune_data_dir=finetune_data_dir,
        finetune_learning_rate=finetune_learning_rate,
        log_dir=log_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
        early_stop_threshold=early_stop_threshold,
        aug_prob_to_one=aug_prob_to_one,
        use_deep_supervision=use_deep_supervision,
        project_name=project_name,
        sub_name=sub_name,
    )


@cli.command(name="finetune_advanced", no_args_is_help=True)
def finetune_advanced(
    pretrained_checkpoint_path: str = Option(  # noqa: B008
        ...,
        help="Path to the checkpoint of the pre-trained model.",
        **PKWARGS,
    ),
    finetune_data_dir: str = Option(  # noqa: B008
        ...,
        help='Path to the directory containing the new data for fine-tuning. \
            It follows the same required structure as the train function. \
            To learn more about the required data structure, \
            type "membrain data_structure_help".',
        **PKWARGS,
    ),
    finetune_learning_rate: float = Option(  # noqa: B008
        1e-5,
        help="Learning rate for fine-tuning the model. This parameter controls \
            the step size at each iteration while moving toward a minimum loss. \
            A smaller learning rate can lead to a more precise convergence but \
            may require more epochs. Adjust based on your dataset size and \
            complexity.",
    ),
    log_dir: str = Option(  # noqa: B008
        "logs_fine_tune/",
        help="Log directory path. Fine-tuning logs will be stored here.",
    ),
    batch_size: int = Option(  # noqa: B008
        2,
        help="Batch size for training.",
    ),
    num_workers: int = Option(  # noqa: B008
        8,
        help="Number of worker threads for data loading.",
    ),
    max_epochs: int = Option(  # noqa: B008
        100,
        help="Maximum number of epochs for fine-tuning.",
    ),
    early_stop_threshold: float = Option(  # noqa: B008
        0.05,
        help="Threshold for early stopping based on validation loss deviation.",
    ),
    aug_prob_to_one: bool = Option(  # noqa: B008
        True,
        help='Whether to augment with a probability of one. This helps with \
            the model\'s generalization, but also severely increases training \
            time. Pass "True" or "False".',
    ),
    use_surface_dice: bool = Option(  # noqa: B008
        False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".'
    ),
    surface_dice_weight: float = Option(  # noqa: B008
        1.0, help="Scaling factor for the Surface-Dice loss."
    ),
    surface_dice_tokens: Annotated[
        Optional[List[str]],
        Option(
            help='List of tokens to use for the Surface-Dice loss. \
                Pass tokens separately, e.g. \
                finetune_advanced --surface_dice_tokens "ds1" \
                --surface_dice_tokens "ds2"'
        ),
    ] = None,
    use_deep_supervision: bool = Option(  # noqa: B008
        True, help='Whether to use deep supervision. Pass "True" or "False".'
    ),
    project_name: str = Option(  # noqa: B008
        "membrain-seg_v0_finetune",
        help="Project name. This helps to find your model again.",
    ),
    sub_name: str = Option(  # noqa: B008
        "1",
        help="Subproject name. For multiple runs in the same project, \
            please specify sub_names.",
    ),
):
    """
    Initiate fine-tuning of a pre-trained model on new datasets
    and validation on original datasets, with more advanced options.

    This function fine-tunes a pre-trained U-Net model on new data provided by
    the user. The `finetune_data_dir` should contain the following directories:
    - `imagesTr` and `labelsTr` for the user's own new training data.
    - `imagesVal` and `labelsVal` for the old data, which will be used
      for validation to ensure that the fine-tuned model's performance
      is not significantly worse on the original training data than that
      of the pre-trained model.

    Parameters
    ----------
    pretrained_checkpoint_path : str
        Path to the checkpoint file of the pre-trained model.
    finetune_data_dir : str
        Directory containing the new dataset for fine-tuning,
        structured as per MemBrain's requirements.
        Use "membrain data_structure_help" for detailed information
        on the required data structure.
    finetune_learning_rate : float
        Learning rate for fine-tuning the model. This parameter controls the
        step size at each iteration while moving toward a minimum loss. A
        smaller learning rate can lead to a more precise convergence but may
        require more epochs. Adjust based on your dataset size and complexity.
    log_dir : str
        Path to the directory where logs will be stored, by default
        'logs_fine_tune/'.
    batch_size : int
        Number of samples per batch, by default 2.
    num_workers : int
        Number of worker threads for data loading, by default 8.
    max_epochs : int
        Maximum number of fine-tuning epochs, by default 100.
    early_stop_threshold : float
        Threshold for early stopping based on validation loss deviation,
        by default 0.05.
    aug_prob_to_one : bool
        Determines whether to apply very strong data augmentation, by default
        True. If set to False, data augmentation still happens, but not as
        frequently. More data augmentation can lead to better performance, but
        also increases the training time substantially.
    use_surface_dice : bool
        Determines whether to use Surface-Dice loss, by default False.
    surface_dice_weight : float
        Scaling factor for the Surface-Dice loss, by default 1.0.
    surface_dice_tokens : list
        List of tokens to use for the Surface-Dice loss.
    use_deep_supervision : bool
        Determines whether to use deep supervision, by default True.
    project_name : str
        Name of the project for logging purposes, by default
        'membrain-seg_v0_finetune'.
    sub_name : str
        Sub-name for the project, by default '1'.

    Note
    ----
    This command configures and executes a fine-tuning session
    using the provided model checkpoint.
    The actual fine-tuning logic resides in the function '_fine_tune'.
    """
    _fine_tune(
        pretrained_checkpoint_path=pretrained_checkpoint_path,
        finetune_data_dir=finetune_data_dir,
        finetune_learning_rate=finetune_learning_rate,
        log_dir=log_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
        early_stop_threshold=early_stop_threshold,
        aug_prob_to_one=aug_prob_to_one,
        use_deep_supervision=use_deep_supervision,
        use_surf_dice=use_surface_dice,
        surf_dice_weight=surface_dice_weight,
        surf_dice_tokens=surface_dice_tokens,
        project_name=project_name,
        sub_name=sub_name,
    )
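For orientation, a plausible shell invocation of the basic command (the paths are hypothetical; option spellings assume Typer's default underscore-to-dash conversion):

membrain finetune \
  --pretrained-checkpoint-path checkpoints/pretrained_model.ckpt \
  --finetune-data-dir ./finetune_data

where, per the finetune_advanced docstring, finetune_data/ would hold imagesTr/ and labelsTr/ (the new training data) plus imagesVal/ and labelsVal/ (the original data kept for validation).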
ske_cli.py (new file, imported above as .ske_cli):
@@ -0,0 +1,64 @@
import os
from typing import Optional

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import store_tomogram

from ..skeletonize import skeletonization as _skeletonization
from .cli import cli


@cli.command(name="skeletonize", no_args_is_help=True)
def skeletonize(
    label_path: str = Option(..., help="Specifies the path for skeletonization."),
    out_folder: str = Option(
        "./predictions", help="Directory to save the resulting skeletons."
    ),
    batch_size: Optional[int] = Option(
        None,
        help="Optional batch size for processing the tomogram. If not specified, "
        "the entire volume is processed at once. If operating with limited GPU "
        "resources, a batch size of 1,000,000 is recommended.",
    ),
):
    """
    Perform skeletonization on labeled tomograms using a nonmax-suppression
    technique.

    This function reads a labeled tomogram, applies skeletonization using the
    specified batch size, and stores the result as an MRC file in the given
    output directory. If batch_size is None, the entire tomogram is processed
    at once, which might require significant memory. It is recommended to
    specify a batch size if memory constraints are a concern. The maximum
    possible batch size is the product of the tomogram's dimensions
    (Nx * Ny * Nz).

    Parameters
    ----------
    label_path : str
        File path to the tomogram to be skeletonized.
    out_folder : str
        Output folder path for the skeletonized tomogram.
    batch_size : int, optional
        Batch size used to process the tomogram. Defaults to None, which
        processes the entire volume at once. For large volumes, consider a
        value like 1,000,000 for efficient processing without exceeding
        memory limits.

    Examples
    --------
    membrain skeletonize --label-path <path> --out-folder <output-directory>
    --batch-size <batch-size>
    """
    # _skeletonization handles the batching internally; batch_size=None
    # processes the full volume in one pass.
    ske = _skeletonization(label_path=label_path, batch_size=batch_size)

    if not os.path.exists(out_folder):
        os.makedirs(out_folder)

    out_file = os.path.join(
        out_folder,
        os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc",
    )

    store_tomogram(filename=out_file, tomogram=ske)
    print("Skeleton saved to", out_file)
Review comment (on ske_cli.py): As raised in #82, this will lose the information of the header. Can we load the tomogram before calling skeletonization and pass it instead of the label path? load_tomogram gives you a Tomogram instance (membrain-seg/src/membrain_seg/segmentation/dataloading/data_utils.py, line 247 in 0a87963), which could then be passed to store_tomogram instead of ske.
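A minimal sketch of that suggestion, assuming load_tomogram returns a Tomogram whose .data attribute holds the voxel array and whose header store_tomogram can reuse (both per the reviewer's pointer to data_utils.py); the array-based _skeletonization call is hypothetical:

from membrain_seg.segmentation.dataloading.data_utils import (
    load_tomogram,
    store_tomogram,
)

# Load once so the MRC header survives the round trip.
tomogram = load_tomogram(label_path)
# Hypothetical signature change: skeletonize the array rather than
# re-reading the volume from label_path inside _skeletonization.
ske = _skeletonization(tomogram.data, batch_size=batch_size)
# Write the skeleton back out with the original header metadata.
tomogram.data = ske
store_tomogram(filename=out_file, tomogram=tomogram)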