Add Fine-tuning.md file
Hanyi11 committed Sep 17, 2024
1 parent 85261ec commit d8db3a1
Showing 5 changed files with 83 additions and 20 deletions.
65 changes: 65 additions & 0 deletions docs/Usage/Fine-tuning.md
@@ -0,0 +1,65 @@
# Fine-tuning
MemBrain-seg is built to function optimally out-of-the-box, eliminating the need for most users to train the model themselves.

However, if your tomograms differ significantly from the images used in our training dataset, fine-tuning the model on your own data may enhance performance. In this case, it can make sense to [annotate](./Annotations.md) several patches extracted from your tomograms and fine-tune the pretrained MemBrain-seg model using this corrected data.

Here are some steps you can follow in order to fine-tune MemBrain-seg:

# Step 1: Prepare your fine-tuning dataset
MemBrain-seg assumes a specific directory structure for creating the fine-tuning dataloaders. The fine-tuning data itself can be a small set of (corrected) patches extracted from your tomograms:

```bash
data_dir/
├── imagesTr/           # Directory containing training images
│   ├── img1.nii.gz     # Image file (currently requires nii.gz format)
│   ├── img2.nii.gz     # Image file
│   └── ...
├── imagesVal/          # Directory containing validation images
│   ├── img3.nii.gz     # Image file
│   ├── img4.nii.gz     # Image file
│   └── ...
├── labelsTr/           # Directory containing training labels
│   ├── img1.nii.gz     # Label file (currently requires nii.gz format)
│   ├── img2.nii.gz     # Label file
│   └── ...
└── labelsVal/          # Directory containing validation labels
    ├── img3.nii.gz     # Label file
    ├── img4.nii.gz     # Label file
    └── ...
```

The `data_dir` argument is then passed to the fine-tuning procedure (see [Step 2](#step-2-perform-fine-tuning)).

To adapt the pretrained model to your own tomograms, you need to add some corrected patches extracted from them, so that the network learns the characteristics of your data.

You can find some instructions here: [How to create training annotations from your own tomogram?](./Annotations.md)
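
If your corrected patches are stored as `.mrc` files, a small conversion script can bring them into the required `.nii.gz` layout. This is a minimal sketch, assuming the `mrcfile` and `nibabel` packages are installed; the file names are placeholders, and whether an identity affine is sufficient depends on your use case:

```python
import mrcfile
import nibabel as nib
import numpy as np

def mrc_to_nifti(mrc_path: str, nifti_path: str) -> None:
    """Convert a single .mrc patch to the .nii.gz format expected by MemBrain-seg."""
    with mrcfile.open(mrc_path, permissive=True) as mrc:
        data = np.asarray(mrc.data, dtype=np.float32)
    # Use an identity affine here; the dataloader reads the raw voxel data.
    nib.save(nib.Nifti1Image(data, affine=np.eye(4)), nifti_path)

# Hypothetical file names; label patches are converted the same way
# (an integer dtype may be preferable for labels).
mrc_to_nifti("corrections/patch1.mrc", "data_dir/imagesTr/img1.nii.gz")
mrc_to_nifti("corrections/patch1_label.mrc", "data_dir/labelsTr/img1.nii.gz")
```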

# Step 2: Perform fine-tuning
Fine-tuning starts from a pretrained model checkpoint. After activating your virtual Python environment, you can type:
```bash
membrain finetune
```
to receive help with the input arguments. You will see that the two parameters you need to provide are the `--pretrained-checkpoint-path` and the `--finetune-data-dir` arguments:

```bash
membrain finetune --pretrained-checkpoint-path <path-to-the-pretrained-checkpoint> --finetune-data-dir <path-to-your-finetuning-data>
```
This command fine-tunes the pretrained MemBrain-seg model using your fine-tuning dataset. Be sure to point to the correct checkpoint path containing the pretrained weights, as well as the fine-tuning data directory.

This is exactly the folder you prepared in [Step 1](#step-1-prepare-your-fine-tuning-dataset).

Running this command should start the fine-tuning process and store the fine-tuned model in the `./finetuned_checkpoints` folder.

**Note:** Fine-tuning can take up to 24 hours. We therefore recommend that you perform training on a device with a CUDA-enabled GPU.
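
Before launching a long run, it can be worth verifying that PyTorch (which MemBrain-seg builds on) can see your GPU. A generic check, independent of the MemBrain-seg CLI:

```python
import torch

if torch.cuda.is_available():
    # Report the first visible CUDA device.
    print(f"CUDA device found: {torch.cuda.get_device_name(0)}")
else:
    print("No CUDA device found; fine-tuning would run on the CPU and be much slower.")
```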


# Advanced settings
In case you feel fancy and would like to adjust some of MemBrain-seg's default settings, you can use the following command to access more customizable options:
```bash
membrain finetune_advanced
```
This will display all available options that can be activated or deactivated. For example, when fine-tuning, you might want to lower the learning rate compared to training from scratch to prevent the model from "forgetting" the knowledge it learned during pretraining. For more in-depth adjustments, you will need to dig into MemBrain-seg's code or contact us.


# Contact
If any problems come up when running the code, or anything else is unclear, do not hesitate to contact us ([email protected]). We are more than happy to help.
13 changes: 10 additions & 3 deletions src/membrain_seg/segmentation/cli/ske_cli.py
@@ -2,7 +2,10 @@

from typer import Option

from membrain_seg.segmentation.dataloading.data_utils import store_tomogram
from membrain_seg.segmentation.dataloading.data_utils import (
load_tomogram,
store_tomogram,
)

from ..skeletonize import skeletonization as _skeletonization
from .cli import cli
@@ -50,7 +53,11 @@ def skeletonize(
--batch-size <batch-size>
"""
# Assuming _skeletonization function is already defined and can handle batch_size
ske = _skeletonization(label_path=label_path, batch_size=batch_size)
segmentation = load_tomogram(label_path)
ske = _skeletonization(segmentation=segmentation.data, batch_size=batch_size)

# Update the segmentation data with the skeletonized output while preserving the original header and voxel_size
segmentation.data = ske

if not os.path.exists(out_folder):
os.makedirs(out_folder)
@@ -60,5 +67,5 @@ def skeletonize(
os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc",
)

store_tomogram(filename=out_file, tomogram=ske)
store_tomogram(filename=out_file, tomogram=segmentation)
print("Skeleton saved to ", out_file)
5 changes: 0 additions & 5 deletions src/membrain_seg/segmentation/finetune.py
@@ -7,7 +7,6 @@
)
from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.training.optim_utils import (
PrintLearningRate,
ToleranceCallback,
)
from membrain_seg.segmentation.training.training_param_summary import (
@@ -158,9 +157,6 @@ def fine_tune(
# Monitor learning rate changes
lr_monitor = LearningRateMonitor(logging_interval="epoch", log_momentum=True)

# Print the current learning rate at the start of each epoch
print_lr_cb = PrintLearningRate()

# Initialize the trainer with specified precision, logger, and callbacks
trainer = pl.Trainer(
precision="16-mixed",
@@ -169,7 +165,6 @@
checkpoint_callback_train_loss,
checkpoint_callback_regular,
lr_monitor,
print_lr_cb,
tolerance_callback,
],
max_epochs=max_epochs,
18 changes: 7 additions & 11 deletions src/membrain_seg/segmentation/skeletonize.py
@@ -12,7 +12,6 @@
import scipy.ndimage as ndimage
import torch

from membrain_seg.segmentation.dataloading.data_utils import load_tomogram
from membrain_seg.segmentation.skeletonization.diff3d import (
compute_gradients,
compute_hessian,
Expand All @@ -24,18 +23,19 @@
from membrain_seg.segmentation.training.surface_dice import apply_gaussian_filter


def skeletonization(label_path: str, batch_size: int) -> np.ndarray:
def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray:
"""
Perform skeletonization on a tomogram segmentation.
This function reads a segmentation file (label_path). It performs skeletonization on
the segmentation where the non-zero labels represent the structures of interest.
The resultan skeleton is saved with '_skel' appended after the filename.
This function skeletonizes the input segmentation, where the non-zero labels
represent the structures of interest. The resulting skeleton is saved with
'_skel' appended to the filename.
Parameters
----------
label_path : str
Path to the input tomogram segmentation file.
segmentation : ndarray
Tomogram segmentation as a numpy array, where non-zero values represent
the structures of interest.
batch_size : int
The number of elements to process in one batch during eigen decomposition.
Useful for managing memory usage.
@@ -58,10 +58,6 @@ def skeletonization(label_path: str, batch_size: int) -> np.ndarray:
--batch-size 1000000
This command runs the skeletonization process from the command line.
"""
# Read original segmentation
segmentation = load_tomogram(label_path)
segmentation = segmentation.data

# Convert non-zero segmentation values to 1.0
labels = (segmentation > 0) * 1.0

2 changes: 1 addition & 1 deletion src/membrain_seg/segmentation/training/optim_utils.py
@@ -375,7 +375,7 @@ def on_validation_epoch_end(
return []

# Check if the metric value deviates beyond the threshold
if abs(metric_value - self.baseline_value) > self.threshold:
if metric_value - self.baseline_value > self.threshold:
print(
f"Stopping training as {self.metric_name} "
f"deviates too far from the baseline value."
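
The change above makes the tolerance check one-sided: training now stops only when the metric exceeds the baseline by more than the threshold, so improvements below the baseline no longer trigger a stop (assuming the monitored metric is one where larger values are worse, such as a validation loss). A standalone illustration of the new logic, with illustrative names:

```python
def should_stop(metric_value: float, baseline_value: float, threshold: float) -> bool:
    """One-sided tolerance check: stop only on upward deviation from the baseline."""
    # The previous abs(...) version would also stop when the metric
    # improved far below the baseline.
    return metric_value - baseline_value > threshold

assert should_stop(1.5, 1.0, 0.3)      # degraded well beyond the threshold: stop
assert not should_stop(0.2, 1.0, 0.3)  # improved below the baseline: keep training
```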

1 comment on commit d8db3a1

@Hanyi11 (Contributor, Author) commented:

Main Changes:

  1. Fixed the issue of losing header information by loading the tomogram before calling the skeletonization function. The `tomo.data` is now passed into `skeletonization`, and the resulting skeleton replaces the original `tomo.data` for saving.
  2. Removed the PrintLearningRate callback from the fine-tuning function.
  3. Added the Fine-tuning.md file.
