* add skeletonization code
* Second commit
* Second commit
* Second commit
* Second commit
* Third commit
* Third commit
* Fourth commit
* Fourth commit
* Fix data type warning and absolute value error
* Add finetune function
* Modify finetune function
* Add Fine-tuning.md file

---------

Co-authored-by: Hanyi11 <[email protected]>
Co-authored-by: Hanyi Zhang <[email protected]>
Co-authored-by: Hanyi Zhang <[email protected]>
Co-authored-by: Hanyi Zhang <[email protected]>
1 parent 0a87963 · commit 800f6ce · 8 changed files with 617 additions and 20 deletions.
@@ -0,0 +1,65 @@
# Fine-tuning
MemBrain-seg is built to function optimally out-of-the-box, eliminating the need for most users to train the model themselves.

However, if your tomograms differ significantly from the images used in our training dataset, fine-tuning the model on your own data may enhance performance. In this case, it can make sense to [annotate](./Annotations.md) several patches extracted from your tomogram, and fine-tune the pretrained MemBrain-seg model using your corrected data.

Here are some steps you can follow in order to fine-tune MemBrain-seg:

# Step 1: Prepare your fine-tuning dataset
MemBrain-seg assumes a specific data structure for creating the fine-tuning dataloaders. The fine-tuning dataset can be a smaller or corrected subset of your tomograms:

```bash
data_dir/
├── imagesTr/          # Directory containing training images
│   ├── img1.nii.gz    # Image file (currently requires nii.gz format)
│   ├── img2.nii.gz    # Image file
│   └── ...
├── imagesVal/         # Directory containing validation images
│   ├── img3.nii.gz    # Image file
│   ├── img4.nii.gz    # Image file
│   └── ...
├── labelsTr/          # Directory containing training labels
│   ├── img1.nii.gz    # Label file (currently requires nii.gz format)
│   ├── img2.nii.gz    # Label file
│   └── ...
└── labelsVal/         # Directory containing validation labels
    ├── img3.nii.gz    # Label file
    ├── img4.nii.gz    # Label file
    └── ...
```

The data_dir argument is then passed to the fine-tuning procedure (see [Step 2](#step-2-perform-fine-tuning)).

To fine-tune the pretrained model on your own tomograms, you need to add some corrected patches extracted from these tomograms to improve the network's performance on them.

You can find some instructions here: [How to create training annotations from your own tomogram?](./Annotations.md)
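
Once you have corrected patches and their labels, they need to end up as .nii.gz files in the folder structure shown above. The snippet below is only a minimal sketch of how this could look; it assumes `numpy` and `nibabel` are installed, and the directory path, array names, and file names are placeholders, not part of MemBrain-seg:

```python
import os

import nibabel as nib
import numpy as np

data_dir = "./finetune_data"  # hypothetical path; adjust to your setup

# Create the directory layout expected by MemBrain-seg.
for sub in ("imagesTr", "labelsTr", "imagesVal", "labelsVal"):
    os.makedirs(os.path.join(data_dir, sub), exist_ok=True)


def save_patch(volume: np.ndarray, out_path: str, dtype=np.float32) -> None:
    """Save a 3D patch as .nii.gz with an identity affine."""
    nib.save(nib.Nifti1Image(volume.astype(dtype), np.eye(4)), out_path)


# 'patch' and 'corrected_labels' stand in for your own corrected data.
patch = np.random.rand(160, 160, 160)         # tomogram patch (placeholder)
corrected_labels = np.zeros((160, 160, 160))  # membrane annotation (placeholder)

# Image and label files must share the same name; validation patches go into
# imagesVal/ and labelsVal/ in the same way.
save_patch(patch, os.path.join(data_dir, "imagesTr", "patch1.nii.gz"))
save_patch(corrected_labels, os.path.join(data_dir, "labelsTr", "patch1.nii.gz"), dtype=np.uint8)
```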

# Step 2: Perform fine-tuning
Fine-tuning starts from a pretrained model checkpoint. After activating your virtual Python environment, you can type:
```
membrain finetune
```
to receive help with the input arguments. You will see that the two parameters you need to provide are the --pretrained-checkpoint-path and the --finetune-data-dir arguments:

```
membrain finetune --pretrained-checkpoint-path <path-to-the-pretrained-checkpoint> --finetune-data-dir <path-to-your-finetuning-data>
```
This command fine-tunes the pretrained MemBrain-seg model using your fine-tuning dataset. Be sure to point to the correct checkpoint path containing the pretrained weights, as well as the fine-tuning data directory.

This is exactly the folder you prepared in [Step 1](#step-1-prepare-your-fine-tuning-dataset).

Running this command should start the fine-tuning process and store the fine-tuned model in the ./finetuned_checkpoints folder.
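
If you later want to segment your tomograms with the fine-tuned weights instead of the default model, you can point the segmentation command to the new checkpoint. The option names and file names below are assumptions for illustration; check `membrain segment --help` for the exact arguments of your installed version:

```
membrain segment \
  --tomogram-path <path-to-your-tomogram> \
  --ckpt-path ./finetuned_checkpoints/<your-finetuned-checkpoint>.ckpt
```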

**Note:** Fine-tuning can take up to 24 hours. We therefore recommend that you perform training on a device with a CUDA-enabled GPU.

# Advanced settings
In case you feel fancy and would like to adjust some of the default settings of MemBrain-seg, you can also use the following command to get access to more customizable options:
```
membrain finetune_advanced
```
This will display all available options that can be activated or deactivated. For example, when fine-tuning, you might want to lower the learning rate compared to training from scratch to prevent the model from "forgetting" the knowledge it learned during pretraining. For more in-depth adjustments, you will need to dig into MemBrain-seg's code or contact us.
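
As an illustration, a run with a lower learning rate and fewer epochs could look roughly like this. The option names follow the parameters listed by the help command, and the paths are placeholders:

```
membrain finetune_advanced \
  --pretrained-checkpoint-path <path-to-the-pretrained-checkpoint> \
  --finetune-data-dir <path-to-your-finetuning-data> \
  --finetune-learning-rate 1e-6 \
  --max-epochs 50
```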

# Contact
If any problems come up when running the code or anything else is unclear, do not hesitate to contact us ([email protected]). We are more than happy to help.
@@ -0,0 +1,236 @@
from typing import List, Optional

from typer import Option
from typing_extensions import Annotated

from ..finetune import fine_tune as _fine_tune
from .cli import OPTION_PROMPT_KWARGS as PKWARGS
from .cli import cli


@cli.command(name="finetune", no_args_is_help=True) | ||
def finetune( | ||
pretrained_checkpoint_path: str = Option( # noqa: B008 | ||
..., | ||
help="Path to the checkpoint of the pre-trained model.", | ||
**PKWARGS, | ||
), | ||
finetune_data_dir: str = Option( # noqa: B008 | ||
..., | ||
help='Path to the directory containing the new data for fine-tuning. \ | ||
Following the same required structure as the train function. \ | ||
To learn more about the required\ | ||
data structure, type "membrain data_structure_help"', | ||
**PKWARGS, | ||
), | ||
): | ||
""" | ||
Initiates fine-tuning of a pre-trained model on new datasets | ||
and validation on original datasets. | ||
This function fine-tunes a pre-trained model on new datasets provided by the user. | ||
The directory specified by `finetune_data_dir` should be structured according to the | ||
requirements for the training function. | ||
For more details, use "membrain data_structure_help". | ||
Parameters | ||
---------- | ||
pretrained_checkpoint_path : str | ||
Path to the checkpoint file of the pre-trained model. | ||
finetune_data_dir : str | ||
Directory containing the new dataset for fine-tuning, | ||
structured as per the MemBrain's requirement. | ||
Use "membrain data_structure_help" for detailed information | ||
on the required data structure. | ||
Note | ||
---- | ||
This command configures and executes a fine-tuning session | ||
using the provided model checkpoint. | ||
The actual fine-tuning logic resides in the function '_fine_tune'. | ||
""" | ||
    # Fixed defaults for the basic finetune command; use "finetune_advanced"
    # to adjust these values.
    finetune_learning_rate = 1e-5
    log_dir = "logs_finetune/"
    batch_size = 2
    num_workers = 8
    max_epochs = 100
    early_stop_threshold = 0.05
    aug_prob_to_one = True
    use_deep_supervision = True
    project_name = "membrain-seg_finetune"
    sub_name = "1"

||
_fine_tune( | ||
pretrained_checkpoint_path=pretrained_checkpoint_path, | ||
finetune_data_dir=finetune_data_dir, | ||
finetune_learning_rate=finetune_learning_rate, | ||
log_dir=log_dir, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
max_epochs=max_epochs, | ||
early_stop_threshold=early_stop_threshold, | ||
aug_prob_to_one=aug_prob_to_one, | ||
use_deep_supervision=use_deep_supervision, | ||
project_name=project_name, | ||
sub_name=sub_name, | ||
) | ||
|
||
|
||
@cli.command(name="finetune_advanced", no_args_is_help=True) | ||
def finetune_advanced( | ||
pretrained_checkpoint_path: str = Option( # noqa: B008 | ||
..., | ||
help="Path to the checkpoint of the pre-trained model.", | ||
**PKWARGS, | ||
), | ||
finetune_data_dir: str = Option( # noqa: B008 | ||
..., | ||
help='Path to the directory containing the new data for fine-tuning. \ | ||
Following the same required structure as the train function. \ | ||
To learn more about the required\ | ||
data structure, type "membrain data_structure_help"', | ||
**PKWARGS, | ||
), | ||
finetune_learning_rate: float = Option( # noqa: B008 | ||
1e-5, | ||
help="Learning rate for fine-tuning the model. This parameter controls the \ | ||
step size at each iteration while moving toward a minimum loss. \ | ||
A smaller learning rate can lead to a more precise convergence but may \ | ||
require more epochs. Adjust based on your dataset size and complexity.", | ||
), | ||
log_dir: str = Option( # noqa: B008 | ||
"logs_fine_tune/", | ||
help="Log directory path. Finetuning logs will be stored here.", | ||
), | ||
batch_size: int = Option( # noqa: B008 | ||
2, | ||
help="Batch size for training.", | ||
), | ||
num_workers: int = Option( # noqa: B008 | ||
8, | ||
help="Number of worker threads for data loading.", | ||
), | ||
max_epochs: int = Option( # noqa: B008 | ||
100, | ||
help="Maximum number of epochs for fine-tuning.", | ||
), | ||
early_stop_threshold: float = Option( # noqa: B008 | ||
0.05, | ||
help="Threshold for early stopping based on validation loss deviation.", | ||
), | ||
aug_prob_to_one: bool = Option( # noqa: B008 | ||
True, | ||
help='Whether to augment with a probability of one. This helps with the \ | ||
model\'s generalization,\ | ||
but also severely increases training time.\ | ||
Pass "True" or "False".', | ||
), | ||
use_surface_dice: bool = Option( # noqa: B008 | ||
False, help='Whether to use Surface-Dice as a loss. Pass "True" or "False".' | ||
), | ||
surface_dice_weight: float = Option( # noqa: B008 | ||
1.0, help="Scaling factor for the Surface-Dice loss. " | ||
), | ||
    surface_dice_tokens: Annotated[
        Optional[List[str]],
        Option(
            help='List of tokens to \
            use for the Surface-Dice loss. \
            Pass tokens separately:\
            For example, finetune_advanced --surface_dice_tokens "ds1" \
            --surface_dice_tokens "ds2"'
        ),
    ] = None,
    use_deep_supervision: bool = Option(  # noqa: B008
        True, help='Whether to use deep supervision. Pass "True" or "False".'
    ),
    project_name: str = Option(  # noqa: B008
        "membrain-seg_v0_finetune",
        help="Project name. This helps to find your model again.",
    ),
    sub_name: str = Option(  # noqa: B008
        "1",
        help="Subproject name. For multiple runs in the same project,\
        please specify sub_names.",
    ),
):
""" | ||
Initiates fine-tuning of a pre-trained model on new datasets | ||
and validation on original datasets with more advanced options. | ||
This function finetunes a pre-trained U-Net model on new data provided by the user. | ||
The `finetune_data_dir` should contain the following directories: | ||
- `imagesTr` and `labelsTr` for the user's own new training data. | ||
- `imagesVal` and `labelsVal` for the old data, which will be used | ||
for validation to ensure that the fine-tuned model's performance | ||
is not significantly worse on the original training data than the | ||
pre-trained model. | ||
Parameters | ||
---------- | ||
pretrained_checkpoint_path : str | ||
Path to the checkpoint file of the pre-trained model. | ||
finetune_data_dir : str | ||
Directory containing the new dataset for fine-tuning, | ||
structured as per the MemBrain's requirement. | ||
Use "membrain data_structure_help" for detailed information | ||
on the required data structure. | ||
finetune_learning_rate : float | ||
Learning rate for fine-tuning the model. This parameter controls the step size | ||
at each iteration while moving toward a minimum loss. A smaller learning rate | ||
can lead to a more precise convergence but may require more epochs. | ||
Adjust based on your dataset size and complexity. | ||
log_dir : str | ||
Path to the directory where logs will be stored, by default 'logs_fine_tune/'. | ||
batch_size : int | ||
Number of samples per batch, by default 2. | ||
num_workers : int | ||
Number of worker threads for data loading, by default 8. | ||
max_epochs : int | ||
Maximum number of fine-tuning epochs, by default 100. | ||
early_stop_threshold : float | ||
Threshold for early stopping based on validation loss deviation, | ||
by default 0.05. | ||
aug_prob_to_one : bool | ||
Determines whether to apply very strong data augmentation, by default True. | ||
If set to False, data augmentation still happens, but not as frequently. | ||
More data augmentation can lead to better performance, but also increases the | ||
training time substantially. | ||
use_surface_dice : bool | ||
Determines whether to use Surface-Dice loss, by default False. | ||
surface_dice_weight : float | ||
Scaling factor for the Surface-Dice loss, by default 1.0. | ||
surface_dice_tokens : list | ||
List of tokens to use for the Surface-Dice loss. | ||
use_deep_supervision : bool | ||
Determines whether to use deep supervision, by default True. | ||
project_name : str | ||
Name of the project for logging purposes, by default 'membrain-seg_v0_finetune'. | ||
sub_name : str | ||
Sub-name for the project, by default '1'. | ||
Note | ||
---- | ||
This command configures and executes a fine-tuning session | ||
using the provided model checkpoint. | ||
The actual fine-tuning logic resides in the function '_fine_tune'. | ||
""" | ||
    _fine_tune(
        pretrained_checkpoint_path=pretrained_checkpoint_path,
        finetune_data_dir=finetune_data_dir,
        finetune_learning_rate=finetune_learning_rate,
        log_dir=log_dir,
        batch_size=batch_size,
        num_workers=num_workers,
        max_epochs=max_epochs,
        early_stop_threshold=early_stop_threshold,
        aug_prob_to_one=aug_prob_to_one,
        use_deep_supervision=use_deep_supervision,
        use_surf_dice=use_surface_dice,
        surf_dice_weight=surface_dice_weight,
        surf_dice_tokens=surface_dice_tokens,
        project_name=project_name,
        sub_name=sub_name,
    )