Skip to content

Commit

Permalink
add yaml and modify/test for training
Browse files Browse the repository at this point in the history
  • Loading branch information
yichiac committed Jul 10, 2023
1 parent b3f769d commit 304fd1d
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 11 deletions.
31 changes: 31 additions & 0 deletions conf/agrifieldnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module:
_target_: torchgeo.trainers.SemanticSegmentationTask
model: "unet"
backbone: "resnet18"
weights: null
in_channels: 12
num_classes: 13
loss: "ce"
ignore_index: 0
learning_rate: 1e-3
learning_rate_schedule_patience: 6
weight_decay: 0

datamodule:
_target_: torchgeo.datamodules.AgriFieldNetDataModule
root: "data"
batch_size: 64
patch_size: 256
num_workers: 10

trainer:
_target_: lightning.pytorch.Trainer
min_epochs: 1
max_epochs: 2

program:
seed: 2
output_dir: "output/agrifieldnet"
log_dir: "logs/agrifieldnet"
overwrite: True
experiment_name: agrifieldnet
78 changes: 78 additions & 0 deletions experiments/ssl4eo/run_agrifieldnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Runs the train script with a grid of hyperparameters."""
import itertools
import os
import subprocess
from multiprocessing import Process, Queue

# list of GPU IDs that we want to use, one job will be started for every ID in the list
GPUS = [0]
DRY_RUN = False # if False then print out the commands to be run, if True then run
conf_file_name = "agrifieldnet.yaml"

# Hyperparameter options
model_options = ["unet"]
backbone_options = ["resnet18"]
lr_options = [0.001, 0.0003, 0.0001, 0.00003]
loss_options = ["ce"]
weight_options = [False]
seed_options = [2]
weight_decay_options = [0]


def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
"""Process for each ID in GPUS."""
while not work.empty():
experiment = work.get()
experiment = experiment.replace("GPU", str(gpu_idx))
print(experiment)
if not DRY_RUN:
subprocess.call(experiment.split(" "))
return True


if __name__ == "__main__":
work: "Queue[str]" = Queue()

for model, backbone, lr, loss, weights, weight_decay, seed in itertools.product(
model_options,
backbone_options,
lr_options,
loss_options,
weight_options,
weight_decay_options,
seed_options,
):
if model == "fcn" and not weights:
continue

experiment_name = f"{conf_file_name.split('.')[0]}_{model}_{backbone}_{lr}_{loss}_{weights}_{weight_decay}_{seed}"

config_file = os.path.join("conf", conf_file_name)

command = (
"python train.py"
+ f" config_file={config_file}"
+ f" module.model={model}"
+ f" module.backbone={backbone}"
+ f" module.learning_rate={lr}"
+ f" module.loss={loss}"
+ f" module.weights={weights}"
+ f" program.experiment_name={experiment_name}"
+ f" program.seed={seed}"
+ " trainer.devices=[GPU]"
)
command = command.strip()

work.put(command)

processes = []
for gpu_idx in GPUS:
p = Process(target=do_work, args=(work, gpu_idx))
processes.append(p)
p.start()
for p in processes:
p.join()
7 changes: 7 additions & 0 deletions tests/datasets/test_agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from torch.utils.data import ConcatDataset

from torchgeo.datasets import AgriFieldNet
from torchgeo.datamodules import AgriFieldNetDataModule
from torchgeo.trainers import SemanticSegmentationTask


class Collection:
Expand Down Expand Up @@ -101,6 +103,11 @@ def test_invalid_bands(self) -> None:
with pytest.raises(ValueError, match="is an invalid band name."):
AgriFieldNet(bands=("foo", "bar"))

# def test_trainer(self) -> None:
# model = SemanticSegmentationTask("unet")
# trainer =
# trainer.fit(model, datamodule=AgriFieldNetDataModule())

def test_plot(self, dataset: AgriFieldNet) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
Expand Down
9 changes: 6 additions & 3 deletions torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..datasets import AgriFieldNet
from .geo import NonGeoDataModule
from .utils import dataset_split
from ..samplers.utils import _to_tuple


class AgriFieldNetDataModule(NonGeoDataModule):
Expand All @@ -21,22 +22,24 @@ class AgriFieldNetDataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: int = 256,
num_workers: int = 0,
val_split_pct: float = 0.1,
test_split_pct: float = 0.1,
num_workers: int = 0,
**kwargs: Any,
) -> None:
"""Initialize a new AgriFieldNetDataModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
val_split_pct: Percentage of the dataset to use as a validation set.
test_split_pct: Percentage of the dataset to use as a test set.
num_workers: Number of workers for parallel data loading.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.AgriFieldNetDataModule`.
"""
super().__init__(AgriFieldNet, **kwargs)
super().__init__(AgriFieldNet, batch_size, num_workers, **kwargs)
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
self.test_split_pct = test_split_pct

Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
index: index to return
Returns:
data and label at that index
data, label, and field ids at that index
"""
# if self.split == "train":
# tile_name = self.train_tiles[index]
# else:
# tile_name = self.test_tiles[index]

tile_name = self.train_tiles[index]
print(tile_name)
image = self._load_image_tile(tile_name)
labels, field_ids = self._load_label_tile(tile_name)

Expand Down Expand Up @@ -315,7 +315,8 @@ def __len__(self) -> int:
Returns:
length of the dataset
"""
return len(self.source_image_fns)
# return len(self.source_image_fns)
return len(self.train_label_fns)

def _validate_bands(self, bands: tuple[str, ...]) -> None:
"""Validate list of bands.
Expand Down
6 changes: 4 additions & 2 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ def configure_optimizers(self) -> dict[str, Any]:
Returns:
learning rate dictionary
"""
optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.hyperparams["learning_rate"]
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.hyperparams["learning_rate"],
weight_decay=self.hyperparams["weight_decay"],
)
return {
"optimizer": optimizer,
Expand Down
9 changes: 6 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ def set_up_omegaconf() -> DictConfig:

def main(conf: DictConfig) -> None:
"""Main training loop."""
experiment_name = (
f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}"
)
if conf.program.experiment_name is not None:
experiment_name = conf.program.experiment_name
else:
experiment_name = (
f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}"
)
if os.path.isfile(conf.program.output_dir):
raise NotADirectoryError("`program.output_dir` must be a directory")
os.makedirs(conf.program.output_dir, exist_ok=True)
Expand Down

0 comments on commit 304fd1d

Please sign in to comment.