-
Notifications
You must be signed in to change notification settings - Fork 0
/
example_training_semisup_domain_pred_irm_seg.py
174 lines (150 loc) · 8.97 KB
/
example_training_semisup_domain_pred_irm_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# ------------------------------------------------------------------------------
# This file (and its siblings) is explained in more detail in
# example_training_scripts.md
# ------------------------------------------------------------------------------
# 1. Imports
import warnings
from itertools import chain
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from mp.agents.segmentation_semisup_domain_pred_IRM_seg_agent import SegmentationSemisupDomainPredictionIRMAgent
from mp.data.data import Data
from mp.data.datasets.ds_mr_hippocampus_decathlon import DecathlonHippocampus
from mp.data.datasets.ds_mr_hippocampus_dryad import DryadHippocampus
from mp.data.datasets.ds_mr_hippocampus_harp import HarP
from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg3DDataset
from mp.eval.losses.losses_domain_prediction import ConfusionLoss
from mp.eval.losses.losses_irm import IRMv1Loss, VRexLoss, MMRexLoss, ERMWrapper
from mp.eval.losses.losses_segmentation import LossDiceBCE, LossClassWeighted, LossBCE
from mp.eval.result import Result
from mp.experiments.limited_labels_experiment import LimitedLabelsExperiment
from mp.models.domain_prediction.domain_predictor_segmentation import DomainPredictor3D
from mp.models.domain_prediction.unet_with_domain_pred import UNetWithDomainPred
from mp.models.segmentation.unet_fepegar import UNet3D
from mp.utils.early_stopping import EarlyStopping
from mp.utils.load_restore import pkl_dump
from mp.visualization.plot_results import plot_results
warnings.filterwarnings("ignore")
# 2. Define data
# Register the three hippocampus MRI datasets on one shared Data container.
# NOTE: the names bound here (data, harp, dryad, configs, nr_labels,
# label_names) are read by the experiment loops further down, so they must
# stay stable.
data = Data()
decath = DecathlonHippocampus(merge_labels=True)
data.add_dataset(decath)
harp = HarP()
data.add_dataset(harp)
dryad = DryadHippocampus(merge_labels=True)
data.add_dataset(dryad)
# Label count and label names as reported by the Data container; nr_labels is
# later used as the U-Net's output channel count.
nr_labels = data.nr_labels
label_names = data.label_names
# 3. Define configuration
# One dict per experiment; every key is consumed by the training loop below.
configs = [
    {'experiment_name': 'harp(24,6)_dryad_no_aug_vrex', 'device': 'cuda:0',
     'nr_runs': 5, 'cross_validation': True, 'val_ratio': 0.1, 'test_ratio': 0.3,
     'input_shape': (1, 48, 64, 64), 'resize': False, 'augmentation': 'none',
     'class_weights': (0., 1.), 'lr': 1.5e-4, 'batch_sizes_seg': [4, 8], 'batch_sizes_dp': [8, 8],
     # alpha/beta presumably weight the IRM penalty and the domain-confusion
     # term inside the agent -- TODO confirm against the agent implementation.
     "alpha": 10, "beta": 1, "loss": "vrex",  # one of: vrex, mmrex, irmv1, erm
     "eval_interval": 20,
     # Datasets used for training; any registered dataset NOT listed here is
     # treated as test-only by the loop below.
     "train_ds_names": (harp.name, dryad.name),
     # Cap on labelled cases for HarP: (24, 6) -- presumably (train, val)
     # counts; TODO confirm in LimitedLabelsExperiment.
     "limited_labels": {harp.name: (24, 6)},
     },
]
# 4. Pre-split datasets to avoid having the "Repetition k i of j" messages spammed at each experiment's start
# Instantiating every experiment once up front creates its data splits, so the
# per-experiment loop below can reload them (reload_exp=True) without noise.
for cfg in configs:
    experiment = LimitedLabelsExperiment(config=cfg, name=cfg['experiment_name'], notes='', reload_exp=True)
    experiment.set_data_splits(data, limited_datasets=cfg["limited_labels"])
# Main experiment loop: one full training campaign per configuration dict.
for config in configs:
    print(config["experiment_name"])
    device = config['device']
    # NOTE(review): device_name is never used below; kept as-is (it also acts
    # as an early sanity check that the configured CUDA device exists).
    device_name = torch.cuda.get_device_name(device)
    input_shape = config['input_shape']
    train_ds_names = config["train_ds_names"]
    # 5. Create experiment directories
    exp = LimitedLabelsExperiment(config=config, name=config['experiment_name'], notes='', reload_exp=True)
    # 6. Create data splits for each repetition
    exp.set_data_splits(data, limited_datasets=config["limited_labels"])
    # Now repeat for each repetition (config["start"] allows resuming from a
    # later run index; defaults to 0).
    for run_ix in range(config.get("start", 0), config['nr_runs']):
        exp_run = exp.get_run(run_ix)
        # 7. Bring data to Pytorch format
        # Two registries keyed by (dataset_name, split): datasets_seg feeds the
        # segmentation task, datasets_dp (splits whose name contains "_dp")
        # feeds the domain predictor.
        datasets_seg = dict()
        datasets_dp = dict()
        for ds_name, ds in data.datasets.items():
            # 2 cases: either the dataset's name is in train_ds_names
            # In which case, we proceed as usual:
            if ds_name in train_ds_names:
                for split, data_ixs in exp.splits[ds_name][exp_run.run_ix].items():
                    datasets = datasets_dp if "_dp" in split else datasets_seg
                    if len(data_ixs) > 0:  # Sometimes val indexes may be an empty list
                        # Augmentation applies to every split except test.
                        aug = config['augmentation'] if not ('test' in split) else 'none'
                        datasets[(ds_name, split)] = PytorchSeg3DDataset(ds,
                                                                         ix_lst=data_ixs, size=input_shape, aug_key=aug,
                                                                         resize=config['resize'])
            # If it's not the case, then the dataset's purpose is only testing and the whole dataset is the test split
            else:
                datasets_seg[(ds_name, "test")] = PytorchSeg3DDataset(ds,
                                                                      ix_lst=None, size=input_shape, aug_key="none",
                                                                      resize=config['resize'])
        # 8. Build train dataloader, and visualize
        # One loader per training dataset; batch sizes pair positionally with
        # train_ds_names.
        dls_seg = [DataLoader(datasets_seg[name, "train"], batch_size=length, shuffle=True)
                   for name, length in zip(train_ds_names, config['batch_sizes_seg'])]
        dls_dp = [DataLoader(datasets_dp[name, "train_dp"], batch_size=length, shuffle=True)
                  for name, length in zip(train_ds_names, config['batch_sizes_dp'])]
        # 9. Initialize model
        # U-Net for segmentation plus a domain-prediction head with one output
        # per training dataset.
        unet = UNet3D(input_shape, nr_labels)
        unet.to(device)
        domain_predictor = DomainPredictor3D(input_shape, len(train_ds_names), out_channels_first_layer=16)
        domain_predictor.to(device)
        # NOTE(review): the meaning of the trailing (2,) argument is not
        # visible here -- confirm against UNetWithDomainPred's signature.
        model = UNetWithDomainPred(unet, domain_predictor, input_shape, (2,))
        model.to(device)
        # 10. Define loss and optimizer
        # Segmentation loss: class-weighted Dice+BCE wrapped in the configured
        # risk-extrapolation penalty (vrex / mmrex / irmv1) or plain ERM.
        erm_loss = LossClassWeighted(LossDiceBCE(bce_weight=1., smooth=1., device=device),
                                     weights=config["class_weights"], device=device)
        irm_losses = {"vrex": VRexLoss(erm_loss, device=device),
                      "mmrex": MMRexLoss(erm_loss, device=device),
                      "irmv1": IRMv1Loss(erm_loss, device=device),
                      "erm": ERMWrapper(erm_loss, device=device)
                      }
        loss_f_classifier = irm_losses[config["loss"]]
        # Alternative (disabled): wrap the domain predictor's BCE loss in the
        # same IRM-style penalty instead of using plain BCE.
        # erm_loss = LossBCE(device=device)
        # irm_losses = {"vrex": VRexLoss(erm_loss, device=device),
        #               "mmrex": MMRexLoss(erm_loss, device=device),
        #               "irmv1": IRMv1Loss(erm_loss, device=device),
        #               "erm": ERMWrapper(erm_loss, device=device)
        #               }
        # loss_f_domain_predictor = irm_losses[config["loss"]]
        loss_f_domain_predictor = LossBCE(device=device)
        # ConfusionLoss on the encoder -- presumably drives the encoder towards
        # domain-invariant features; confirm in losses_domain_prediction.
        loss_f_encoder = ConfusionLoss(device=device)
        # Positional order (classifier, domain predictor, encoder) -- the agent
        # presumably unpacks these by position; do not reorder.
        losses = [loss_f_classifier, loss_f_domain_predictor, loss_f_encoder]
        # Four optimizers, also passed positionally: stage 1 trains the whole
        # model at config['lr']; the later stage updates (encoder+classifier),
        # the domain predictor, and the encoder separately at 5e-5.
        optimizer_stage1 = optim.Adam(model.parameters(), lr=config['lr'])
        optimizer_model = optim.Adam(chain(model.encoder_parameters(), model.classifier_parameters()), lr=5e-5)
        optimizer_domain_predictor = optim.Adam(model.domain_predictor_parameters(), lr=5e-5)
        optimizer_encoder = optim.Adam(model.encoder_parameters(), lr=5e-5)
        optimizers = [optimizer_stage1, optimizer_model, optimizer_domain_predictor, optimizer_encoder]
        # 11. Train model
        results = Result(name='training_trajectory')
        agent = SegmentationSemisupDomainPredictionIRMAgent(model=model, label_names=label_names, device=device,
                                                            metrics=["ScoreDice"], verbose=True)
        # Stop when mean val Dice stops improving by at least 2e-3; the leading
        # 2 is presumably a patience (in evaluation intervals) -- confirm in
        # EarlyStopping.
        early_stopping = EarlyStopping(2, "Mean_ScoreDice[hippocampus]", [name + "_val" for name in train_ds_names],
                                       metric_min_delta=20e-4)
        # `epochs` holds the training-stage delimitations (see pkl_dump below);
        # they are reused as vertical lines on the result plots.
        epochs = agent.train_with_early_stopping(results, optimizers, losses,
                                                 train_dataloaders_seg=dls_seg,
                                                 train_dataloaders_dp=dls_dp,
                                                 train_dataset_names=train_ds_names,
                                                 early_stopping=early_stopping,
                                                 run_loss_print_interval=config["eval_interval"],
                                                 eval_datasets_seg=datasets_seg,
                                                 eval_datasets_dp=datasets_dp,
                                                 eval_interval=config["eval_interval"],
                                                 save_path=exp_run.paths['states'],
                                                 alpha=config.get("alpha", 1.0),
                                                 beta=config["beta"])
        # Save the stage delimitations in the obj folder
        pkl_dump(epochs, "epochs.pkl", exp_run.paths['obj'])
        # 12. Save and print results for this experiment run
        exp_run.finish(results=results, plot_metrics=['Mean_ScoreDice[hippocampus]'],
                       plot_metrics_args={"axvlines": epochs})
        # Plotting in a separate file, because of seaborn's limited dashes style list
        plot_results(results, save_path=exp_run.paths['results'],
                     save_name="domain_prediction_accuracy.png",
                     measures=["Mean_ScoreAccuracy_DomPred"],
                     axvlines=epochs)