On-the-fly data loading during training (#85)
* draw random crop from each patch

* add option for on-the-fly dataloading for large datasets

* remove redundant dataloader init

* add on-the-fly dataloading to training summary
LorenzLamm authored Sep 21, 2024
1 parent 9e79ca1 commit 215b3d5
Showing 5 changed files with 85 additions and 21 deletions.
9 changes: 9 additions & 0 deletions src/membrain_seg/segmentation/cli/train_cli.py
@@ -36,6 +36,7 @@ def train(
log_dir = "./logs"
batch_size = 2
num_workers = 1
on_the_fly_dataloading = False
max_epochs = 1000
aug_prob_to_one = True
use_deep_supervision = True
@@ -47,6 +48,7 @@
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
on_the_fly_dataloading=on_the_fly_dataloading,
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
@@ -76,6 +78,10 @@ def train_advanced(
8,
help="Number of worker threads for loading data",
),
on_the_fly_dataloading: bool = Option( # noqa: B008
False,
help="Whether to load data on the fly. This is useful for large datasets.",
),
max_epochs: int = Option( # noqa: B008
1000,
help="Maximum number of epochs for training",
@@ -131,6 +137,8 @@ def train_advanced(
Number of samples per batch, by default 2.
num_workers : int
Number of worker threads for data loading, by default 1.
on_the_fly_dataloading : bool
Determines whether to load data on the fly, by default False.
max_epochs : int
Maximum number of training epochs, by default 1000.
aug_prob_to_one : bool
@@ -162,6 +170,7 @@ def train_advanced(
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
on_the_fly_dataloading=on_the_fly_dataloading,
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
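The flag stays pinned to `False` on the plain `train` command and becomes configurable on `train_advanced`. As a rough illustration, the advanced path reduces to a call like the sketch below; the `data_dir` value is a placeholder, and the import path follows `src/membrain_seg/segmentation/train.py` further down in this diff.

```python
# Sketch only: mirrors the call train_advanced forwards, with the new flag on.
from membrain_seg.segmentation.train import train

train(
    data_dir="/data/membrain_patches",  # placeholder dataset root
    batch_size=2,
    num_workers=8,
    on_the_fly_dataloading=True,        # stream samples instead of preloading
    max_epochs=1000,
    aug_prob_to_one=True,
    use_deep_supervision=True,
)
```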
62 changes: 44 additions & 18 deletions src/membrain_seg/segmentation/dataloading/memseg_dataset.py
@@ -59,6 +59,7 @@ def __init__(
train: bool = False,
aug_prob_to_one: bool = False,
patch_size: int = 160,
on_the_fly_loading: bool = False,
) -> None:
"""
Constructs all the necessary attributes for the CryoETMemSegDataset object.
@@ -76,12 +77,16 @@
to one or not.
patch_size : int, default 160
The size of the patches to be extracted from the images.
on_the_fly_loading : bool, default False
A flag indicating whether the data should be loaded on the fly or not.
"""
self.train = train
self.img_folder, self.label_folder = img_folder, label_folder
self.patch_size = patch_size
self.on_the_fly_loading = on_the_fly_loading
self.initialize_imgs_paths()
self.load_data()
if not self.on_the_fly_loading:
self.load_data()
self.transforms = (
get_training_transforms(prob_to_one=aug_prob_to_one)
if self.train
@@ -104,13 +109,20 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
Dict[str, np.ndarray]
A dictionary containing an image and its corresponding label.
"""
idx_dict = {
"image": np.expand_dims(self.imgs[idx], 0),
"label": np.expand_dims(self.labels[idx], 0),
}
if self.on_the_fly_loading:
idx_dict = self.load_data_sample(idx)
idx_dict["image"] = np.expand_dims(idx_dict["image"], 0)
idx_dict["label"] = np.expand_dims(idx_dict["label"], 0)
ds_label = idx_dict["dataset"]
else:
idx_dict = {
"image": np.expand_dims(self.imgs[idx], 0),
"label": np.expand_dims(self.labels[idx], 0),
}
ds_label = self.dataset_labels[idx]
idx_dict = self.get_random_crop(idx_dict)
idx_dict = self.transforms(idx_dict)
idx_dict["dataset"] = self.dataset_labels[idx]
idx_dict["dataset"] = ds_label # transforms remove the dataset token
return idx_dict

def __len__(self) -> int:
@@ -228,6 +240,27 @@ def get_random_crop(self, idx_dict: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
), f"Image shape is {img.shape} instead of {self.patch_size}"
return {"image": img, "label": label}

def load_data_sample(self, idx: int) -> Dict[str, np.ndarray]:
"""
Loads a single image-label pair from the dataset.

Parameters
----------
idx : int
The index of the sample to be loaded.

Returns
-------
Dict[str, np.ndarray]
A dictionary containing an image and its corresponding label.
"""
label = read_nifti(self.data_paths[idx][1])
img = read_nifti(self.data_paths[idx][0])
label = np.transpose(label, (1, 2, 0))
img = np.transpose(img, (1, 2, 0))
ds_token = get_dataset_token(self.data_paths[idx][0])
return {"image": img, "label": label, "dataset": ds_token}

def load_data(self) -> None:
"""
Loads image-label pairs into memory from the specified directories.
@@ -240,18 +273,11 @@ def load_data(self) -> None:
self.imgs = []
self.labels = []
self.dataset_labels = []
for entry in tqdm(self.data_paths):
label = read_nifti(
entry[1]
) # TODO: Change this to be applicable to .mrc images
img = read_nifti(entry[0])
label = np.transpose(
label, (1, 2, 0)
) # TODO: Needed? Probably no? z-axis should not matter
img = np.transpose(img, (1, 2, 0))
self.imgs.append(img)
self.labels.append(label)
self.dataset_labels.append(get_dataset_token(entry[0]))
for entry_num in tqdm(range(len(self.data_paths))):
sample_dict = self.load_data_sample(entry_num)
self.imgs.append(sample_dict["image"])
self.labels.append(sample_dict["label"])
self.dataset_labels.append(sample_dict["dataset"])

def initialize_imgs_paths(self) -> None:
"""
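The dataset now supports two modes: eager mode preloads every volume in `__init__` via `load_data()`, while on-the-fly mode defers each read to `__getitem__` via `load_data_sample()`. A minimal sketch of both, with placeholder paths:

```python
# Sketch only: the two loading modes of CryoETMemSegDataset.
from membrain_seg.segmentation.dataloading.memseg_dataset import (
    CryoETMemSegDataset,
)

# Eager (default): all image/label volumes are read into memory up front.
eager_ds = CryoETMemSegDataset(
    img_folder="/data/train/imgs",      # placeholder
    label_folder="/data/train/labels",  # placeholder
    train=True,
)

# On-the-fly: __init__ skips load_data(); each indexing operation reads one
# image/label pair from disk and attaches its dataset token.
lazy_ds = CryoETMemSegDataset(
    img_folder="/data/train/imgs",
    label_folder="/data/train/labels",
    train=True,
    on_the_fly_loading=True,
)

sample = lazy_ds[0]  # dict with "image", "label", and "dataset" keys
```

Either way, `__getitem__` still applies `get_random_crop` and the transforms, so downstream consumers see identical sample dictionaries.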
21 changes: 18 additions & 3 deletions src/membrain_seg/segmentation/dataloading/memseg_pl_datamodule.py
@@ -41,7 +41,14 @@ class MemBrainSegDataModule(pl.LightningDataModule):
The test dataset.
"""

def __init__(self, data_dir, batch_size, num_workers, aug_prob_to_one=False):
def __init__(
self,
data_dir,
batch_size,
num_workers,
on_the_fly_dataloading=False,
aug_prob_to_one=False,
):
"""Initialization of data paths and data loaders.
The data_dir should have the following structure:
@@ -72,6 +79,7 @@ def __init__(self, data_dir, batch_size, num_workers, aug_prob_to_one=False):
self.batch_size = batch_size
self.num_workers = num_workers
self.aug_prob_to_one = aug_prob_to_one
self.on_the_fly_dataloading = on_the_fly_dataloading

def setup(self, stage: Optional[str] = None):
"""
@@ -91,14 +99,21 @@ def setup(self, stage: Optional[str] = None):
label_folder=self.train_lab_dir,
train=True,
aug_prob_to_one=self.aug_prob_to_one,
on_the_fly_loading=self.on_the_fly_dataloading,
)
self.val_dataset = CryoETMemSegDataset(
img_folder=self.val_img_dir, label_folder=self.val_lab_dir, train=False
img_folder=self.val_img_dir,
label_folder=self.val_lab_dir,
train=False,
on_the_fly_loading=self.on_the_fly_dataloading,
)

if stage in (None, "test"):
self.test_dataset = CryoETMemSegDataset(
self.data_dir, test=True, transform=self.transform
self.data_dir,
test=True,
transform=self.transform,
on_the_fly_loading=self.on_the_fly_dataloading,
) # TODO: How to do prediction?

def train_dataloader(self) -> DataLoader:
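At the Lightning level, the datamodule simply forwards the flag to the train and validation datasets it builds in `setup`. A minimal usage sketch, assuming a placeholder `data_dir` laid out as the class docstring requires:

```python
# Sketch only: enabling on-the-fly loading through the datamodule.
from membrain_seg.segmentation.dataloading.memseg_pl_datamodule import (
    MemBrainSegDataModule,
)

dm = MemBrainSegDataModule(
    data_dir="/data/membrain_patches",  # placeholder
    batch_size=2,
    num_workers=8,
    on_the_fly_dataloading=True,        # passed to train and val datasets
)
dm.setup(stage="fit")
train_loader = dm.train_dataloader()
```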
5 changes: 5 additions & 0 deletions src/membrain_seg/segmentation/train.py
@@ -22,6 +22,7 @@ def train(
log_dir: str = "logs/",
batch_size: int = 2,
num_workers: int = 8,
on_the_fly_dataloading: bool = False,
max_epochs: int = 1000,
aug_prob_to_one: bool = False,
use_deep_supervision: bool = False,
@@ -48,6 +49,8 @@
Number of samples per batch of input data.
num_workers : int, optional
Number of subprocesses to use for data loading.
on_the_fly_dataloading : bool, optional
If True, data is loaded on the fly.
max_epochs : int, optional
Maximum number of epochs to train for.
aug_prob_to_one : bool, optional
@@ -74,6 +77,7 @@
log_dir=log_dir,
batch_size=batch_size,
num_workers=num_workers,
on_the_fly_dataloading=on_the_fly_dataloading,
max_epochs=max_epochs,
aug_prob_to_one=aug_prob_to_one,
use_deep_supervision=use_deep_supervision,
@@ -88,6 +92,7 @@
data_dir=data_dir,
batch_size=batch_size,
num_workers=num_workers,
on_the_fly_dataloading=on_the_fly_dataloading,
aug_prob_to_one=aug_prob_to_one,
)

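To see why the flag matters for large datasets, a back-of-envelope estimate (assuming float32 volumes of roughly the default 160³ patch size, one image plus one label per sample):

```python
# Sketch only: approximate resident memory for eager loading.
n_samples = 500                              # hypothetical dataset size
bytes_per_volume = 160 ** 3 * 4              # ~16.4 MB per float32 volume
total_gb = n_samples * 2 * bytes_per_volume / 1e9  # image + label
print(f"~{total_gb:.1f} GB held in RAM")     # ~16.4 GB for 500 samples
```

With `on_the_fly_dataloading=True`, only the samples currently being processed are resident, at the cost of per-item disk reads in `__getitem__`.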
9 changes: 9 additions & 0 deletions src/membrain_seg/segmentation/training/training_param_summary.py
@@ -3,6 +3,7 @@ def print_training_parameters(
log_dir: str = "logs/",
batch_size: int = 2,
num_workers: int = 8,
on_the_fly_dataloading: bool = False,
max_epochs: int = 1000,
aug_prob_to_one: bool = False,
use_deep_supervision: bool = False,
@@ -25,6 +26,8 @@
Number of samples per batch of input data.
num_workers : int, optional
Number of subprocesses to use for data loading.
on_the_fly_dataloading : bool, optional
If True, data is loaded on the fly.
max_epochs : int, optional
Maximum number of epochs to train for.
aug_prob_to_one : bool, optional
@@ -68,6 +71,12 @@
"loading.".format(num_workers)
)
print("————————————————————————————————————————————————————————")
on_the_fly_status = "Enabled" if on_the_fly_dataloading else "Disabled"
print(
"On-the-Fly Data Loading:\n {} \n If enabled, data is loaded on "
"the fly.".format(on_the_fly_status)
)
print("————————————————————————————————————————————————————————")
print(f"Max Epochs:\n {max_epochs} \n Maximum number of training epochs.")
print("————————————————————————————————————————————————————————")
aug_status = "Enabled" if aug_prob_to_one else "Disabled"
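The summary printer gains a matching block. A small sketch of what it emits when the flag is set, derived directly from the print calls above (spacing approximate):

```python
# Sketch only: the status string and output of the new summary block.
on_the_fly_dataloading = True
on_the_fly_status = "Enabled" if on_the_fly_dataloading else "Disabled"
print(
    "On-the-Fly Data Loading:\n {} \n If enabled, data is loaded on "
    "the fly.".format(on_the_fly_status)
)
# On-the-Fly Data Loading:
#  Enabled
#  If enabled, data is loaded on the fly.
```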
