From 304fd1daf4a4e79fa9b41081e216cb31c5065a3e Mon Sep 17 00:00:00 2001 From: Yi-Chia Chang Date: Mon, 10 Jul 2023 15:14:59 -0400 Subject: [PATCH] add yaml and modify/test for training --- conf/agrifieldnet.yaml | 31 ++++++++++ experiments/ssl4eo/run_agrifieldnet.py | 78 ++++++++++++++++++++++++++ tests/datasets/test_agrifieldnet.py | 7 +++ torchgeo/datamodules/agrifieldnet.py | 9 ++- torchgeo/datasets/agrifieldnet.py | 7 ++- torchgeo/trainers/segmentation.py | 6 +- train.py | 9 ++- 7 files changed, 136 insertions(+), 11 deletions(-) create mode 100644 conf/agrifieldnet.yaml create mode 100644 experiments/ssl4eo/run_agrifieldnet.py diff --git a/conf/agrifieldnet.yaml b/conf/agrifieldnet.yaml new file mode 100644 index 00000000000..983c9b42fc3 --- /dev/null +++ b/conf/agrifieldnet.yaml @@ -0,0 +1,31 @@ +module: + _target_: torchgeo.trainers.SemanticSegmentationTask + model: "unet" + backbone: "resnet18" + weights: null + in_channels: 12 + num_classes: 13 + loss: "ce" + ignore_index: 0 + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weight_decay: 0 + +datamodule: + _target_: torchgeo.datamodules.AgriFieldNetDataModule + root: "data" + batch_size: 64 + patch_size: 256 + num_workers: 10 + +trainer: + _target_: lightning.pytorch.Trainer + min_epochs: 1 + max_epochs: 2 + +program: + seed: 2 + output_dir: "output/agrifieldnet" + log_dir: "logs/agrifieldnet" + overwrite: True + experiment_name: agrifieldnet diff --git a/experiments/ssl4eo/run_agrifieldnet.py b/experiments/ssl4eo/run_agrifieldnet.py new file mode 100644 index 00000000000..5b8c2d16e78 --- /dev/null +++ b/experiments/ssl4eo/run_agrifieldnet.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+
+"""Runs the train script with a grid of hyperparameters."""
+import itertools
+import os
+import subprocess
+from multiprocessing import Process, Queue
+
+# list of GPU IDs that we want to use, one job will be started for every ID in the list
+GPUS = [0]
+DRY_RUN = False  # if True then only print the commands, if False then also run them
+conf_file_name = "agrifieldnet.yaml"
+
+# Hyperparameter options
+model_options = ["unet"]
+backbone_options = ["resnet18"]
+lr_options = [0.001, 0.0003, 0.0001, 0.00003]
+loss_options = ["ce"]
+weight_options = [False]
+seed_options = [2]
+weight_decay_options = [0]
+
+
+def do_work(work: "Queue[str]", gpu_idx: int) -> bool:
+    """Run queued commands on one GPU until the queue reports empty.
+
+    Every "GPU" placeholder in a command is replaced by ``gpu_idx`` before the
+    command is printed and, unless ``DRY_RUN`` is set, executed. Returns True.
+    """
+    while not work.empty():  # NOTE(review): Queue.empty() is documented as unreliable
+        experiment = work.get().replace("GPU", str(gpu_idx))
+        print(experiment)
+        if not DRY_RUN:
+            subprocess.call(experiment.split(" "))
+    return True
+
+
+if __name__ == "__main__":
+    work: "Queue[str]" = Queue()
+    config_file = os.path.join("conf", conf_file_name)  # same for every run
+
+    for model, backbone, lr, loss, weights, weight_decay, seed in itertools.product(
+        model_options,
+        backbone_options,
+        lr_options,
+        loss_options,
+        weight_options,
+        weight_decay_options,
+        seed_options,
+    ):
+        if model == "fcn" and not weights:
+            continue
+
+        experiment_name = f"{conf_file_name.split('.')[0]}_{model}_{backbone}_{lr}_{loss}_{weights}_{weight_decay}_{seed}"
+
+        command = (
+            "python train.py"
+            + f" config_file={config_file}"
+            + f" module.model={model}"
+            + f" module.backbone={backbone}"
+            + f" module.learning_rate={lr}"
+            + f" module.loss={loss}"
+            + f" module.weights={weights}"
+            + f" module.weight_decay={weight_decay}"
+            + f" program.experiment_name={experiment_name}"
+            + f" program.seed={seed}"
+            + " trainer.devices=[GPU]"
+        )
+
+        work.put(command)
+
+    processes = [Process(target=do_work, args=(work, gpu_idx)) for gpu_idx in GPUS]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
diff --git 
a/tests/datasets/test_agrifieldnet.py b/tests/datasets/test_agrifieldnet.py index c78877de24a..3b388ef8f4e 100644 --- a/tests/datasets/test_agrifieldnet.py +++ b/tests/datasets/test_agrifieldnet.py @@ -16,6 +16,8 @@ from torch.utils.data import ConcatDataset from torchgeo.datasets import AgriFieldNet +from torchgeo.datamodules import AgriFieldNetDataModule +from torchgeo.trainers import SemanticSegmentationTask class Collection: @@ -101,6 +103,11 @@ def test_invalid_bands(self) -> None: with pytest.raises(ValueError, match="is an invalid band name."): AgriFieldNet(bands=("foo", "bar")) + # def test_trainer(self) -> None: + # model = SemanticSegmentationTask("unet") + # trainer = + # trainer.fit(model, datamodule=AgriFieldNetDataModule()) + def test_plot(self, dataset: AgriFieldNet) -> None: dataset.plot(dataset[0], suptitle="Test") plt.close() diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index 412380c36c4..b564b771830 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -8,6 +8,7 @@ from ..datasets import AgriFieldNet from .geo import NonGeoDataModule from .utils import dataset_split +from ..samplers.utils import _to_tuple class AgriFieldNetDataModule(NonGeoDataModule): @@ -21,22 +22,24 @@ class AgriFieldNetDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, + patch_size: int = 256, + num_workers: int = 0, val_split_pct: float = 0.1, test_split_pct: float = 0.1, - num_workers: int = 0, **kwargs: Any, ) -> None: """Initialize a new AgriFieldNetDataModule instance. Args: batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. val_split_pct: Percentage of the dataset to use as a validation set. test_split_pct: Percentage of the dataset to use as a test set. - num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.AgriFieldNetDataModule`. 
""" - super().__init__(AgriFieldNet, **kwargs) + super().__init__(AgriFieldNet, batch_size, num_workers, **kwargs) + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 4b485fd3f06..76e8b06067a 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -185,14 +185,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: index: index to return Returns: - data and label at that index + data, label, and field ids at that index """ # if self.split == "train": # tile_name = self.train_tiles[index] # else: # tile_name = self.test_tiles[index] - tile_name = self.train_tiles[index] + print(tile_name) image = self._load_image_tile(tile_name) labels, field_ids = self._load_label_tile(tile_name) @@ -315,7 +315,8 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.source_image_fns) + # return len(self.source_image_fns) + return len(self.train_label_fns) def _validate_bands(self, bands: tuple[str, ...]) -> None: """Validate list of bands. 
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index e0497de1b9c..674720b508b 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -326,8 +326,10 @@ def configure_optimizers(self) -> dict[str, Any]: Returns: learning rate dictionary """ - optimizer = torch.optim.Adam( - self.model.parameters(), lr=self.hyperparams["learning_rate"] + optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.hyperparams["learning_rate"], + weight_decay=self.hyperparams["weight_decay"], ) return { "optimizer": optimizer, diff --git a/train.py b/train.py index 2284402d581..07de7f39769 100755 --- a/train.py +++ b/train.py @@ -59,9 +59,12 @@ def set_up_omegaconf() -> DictConfig: def main(conf: DictConfig) -> None: """Main training loop.""" - experiment_name = ( - f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}" - ) + if conf.program.experiment_name is not None: + experiment_name = conf.program.experiment_name + else: + experiment_name = ( + f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}" + ) if os.path.isfile(conf.program.output_dir): raise NotADirectoryError("`program.output_dir` must be a directory") os.makedirs(conf.program.output_dir, exist_ok=True)