diff --git a/satflow/configs/config.yaml b/satflow/configs/config.yaml index c58c5a48..b34503af 100644 --- a/satflow/configs/config.yaml +++ b/satflow/configs/config.yaml @@ -6,7 +6,7 @@ defaults: - model: convlstm_model.yaml - datamodule: satflow_datamodule.yaml - callbacks: default.yaml # set this to null if you don't want to use callbacks - - logger: tensorboard # set logger here or use command line (e.g. `python run.py logger=wandb`) + - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`) - experiment: null - hparams_search: null @@ -24,7 +24,7 @@ defaults: work_dir: ${hydra:runtime.cwd} # path to folder with data -data_dir: /run/media/jacob/T7/ +data_dir: /run/media/bieker/data/EUMETSAT/ # use `python run.py debug=true` for easy debugging! # this will run 1 train, val and test loop with only 1 batch diff --git a/satflow/configs/datamodule/segmentation_datamodule.yaml b/satflow/configs/datamodule/segmentation_datamodule.yaml index 1a62bef2..f3da5e90 100644 --- a/satflow/configs/datamodule/segmentation_datamodule.yaml +++ b/satflow/configs/datamodule/segmentation_datamodule.yaml @@ -1,31 +1,33 @@ # @package _group_ -_target_: satflow.data.datamodules.SatFlowDataModule +_target_: satflow.data.datamodules.SegFlowDataModule -batch_size: 1 +batch_size: 16 data_dir: ${data_dir} # data_dir is specified in config.yaml shuffle: 0 sources: - train: "satflow-flow-144-tiled-{00001..00105}.tar" - val: "satflow-flow-144-tiled-{00106..00129}.tar" - test: "satflow-flow-144-tiled-{00130..00149}.tar" + train: "satflow-flow-144-tiled-{00001..00105}.tar" # 2020 + val: "satflow-flow-144-tiled-{00106..00129}.tar" # 2021 + test: "satflow-flow-144-tiled-{00130..00149}.tar" # 2021 num_workers: 12 pin_memory: True config: visualize: False - num_timesteps: 6 + num_timesteps: 0 skip_timesteps: 1 - forecast_times: 6 - output_shape: 128 + forecast_times: 2 + output_shape: 256 target_type: "cloudmask" - num_crops: 10 - num_times: 10 + num_crops: 5 + num_times: 20 use_topo: True - use_latlon: True + use_latlon: False use_time: False time_aux: False use_mask: False use_image: False time_as_channels: True + add_pixel_coords: False + add_polar_coords: False bands: [ "HRV", diff --git a/satflow/configs/datamodule/unet_datamodule.yaml b/satflow/configs/datamodule/unet_datamodule.yaml new file mode 100644 index 00000000..e304c88e --- /dev/null +++ b/satflow/configs/datamodule/unet_datamodule.yaml @@ -0,0 +1,32 @@ +# @package _group_ +_target_: satflow.data.datamodules.SatFlowDataModule + +batch_size: 8 +data_dir: ${data_dir} # data_dir is specified in config.yaml +shuffle: 0 +sources: + train: "satflow-flow-144-tiled-{00001..00105}.tar" + val: "satflow-flow-144-tiled-{00106..00129}.tar" + test: "satflow-flow-144-tiled-{00130..00149}.tar" +num_workers: 12 +pin_memory: True +config: + visualize: False + num_timesteps: 10 + skip_timesteps: 1 + forecast_times: 20 + output_shape: 128 + target_type: "cloudmask" + num_crops: 10 + num_times: 15 + use_topo: False + use_latlon: False + use_time: False + time_aux: False + use_mask: True + use_image: False + time_as_channels: True + add_pixel_coords: True + add_polar_coords: False + bands: ["IR016"] + transforms: {} diff --git a/satflow/configs/model/deeplabv3_r50_model.yaml b/satflow/configs/model/deeplabv3_r50_model.yaml index 5afc6048..ec8a9ef4 100644 --- a/satflow/configs/model/deeplabv3_r50_model.yaml +++ b/satflow/configs/model/deeplabv3_r50_model.yaml @@ -1,6 +1,6 @@ # @package _group_ _target_: satflow.models.deeplabv3.DeepLabV3 -forecast_steps: 6 +forecast_steps: 2 input_channels: 12 lr: 0.001 make_vis: False diff --git a/satflow/configs/model/fcn_r50_model.yaml b/satflow/configs/model/fcn_r50_model.yaml index 86e35dd3..a31abc1e 100644 --- a/satflow/configs/model/fcn_r50_model.yaml +++ b/satflow/configs/model/fcn_r50_model.yaml @@ -1,10 +1,9 @@ # @package _group_ _target_: satflow.models.fcn.FCN -forecast_steps: 6 +forecast_steps: 1 input_channels: 12 lr: 0.001 make_vis: False loss: "bce" backbone: "resnet50" pretrained: False -aux_loss: False diff --git a/satflow/configs/trainer/minimal.yaml b/satflow/configs/trainer/minimal.yaml index 19eda1e7..9d66a784 100644 --- a/satflow/configs/trainer/minimal.yaml +++ b/satflow/configs/trainer/minimal.yaml @@ -14,6 +14,10 @@ auto_lr_find: True auto_scale_batch_size: False reload_dataloaders_every_epoch: True +accumulate_grad_batches: 1 +precision: 32 +# stochastic_weight_avg: True + weights_summary: null progress_bar_refresh_rate: 10 # resume_from_checkpoint: null diff --git a/satflow/data/datamodules.py b/satflow/data/datamodules.py index 110743b7..27c3862b 100644 --- a/satflow/data/datamodules.py +++ b/satflow/data/datamodules.py @@ -2,7 +2,7 @@ from torch.utils.data import DataLoader from typing import Optional import webdataset as wds -from satflow.data.datasets import SatFlowDataset, CloudFlowDataset +from satflow.data.datasets import SatFlowDataset, CloudFlowDataset, SegFlowDataset import os @@ -152,3 +152,68 @@ def test_dataloader(self): pin_memory=self.pin_memory, num_workers=self.num_workers, ) + + +class SegFlowDataModule(pl.LightningDataModule): + def __init__( + self, + config: dict, + sources: dict, + batch_size: int = 2, + shuffle: int = 0, + data_dir: str = "./", + num_workers: int = 1, + pin_memory: bool = True, + ): + super().__init__() + self.data_dir = data_dir + self.config = config + self.batch_size = batch_size + self.shuffle = shuffle + self.sources = sources + self.num_workers = num_workers + self.pin_memory = pin_memory + + def prepare_data(self): + # download + pass + + def setup(self, stage: Optional[str] = None): + # Assign train/val datasets for use in dataloaders + if stage == "fit" or stage is None: + train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"])) + val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"])) + if self.shuffle > 0: + # Add shuffling, each sample is still quite large, so too many examples ends up running out of ram + train_dset = train_dset.shuffle(self.shuffle) + self.train_dataset = SegFlowDataset([train_dset], config=self.config, train=True) + self.val_dataset = SegFlowDataset([val_dset], config=self.config, train=False) + + # Assign test dataset for use in dataloader(s) + if stage == "test" or stage is None: + test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"])) + self.test_dataset = SegFlowDataset([test_dset], config=self.config, train=False) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + pin_memory=self.pin_memory, + num_workers=self.num_workers, + ) diff --git a/satflow/data/datasets.py b/satflow/data/datasets.py index e39720b8..b9976aa3 100644 --- a/satflow/data/datasets.py +++ b/satflow/data/datasets.py @@ -122,8 +122,7 @@ def create_pixel_coord_layers(x_dim: int, y_dim: int, with_r: bool = False) -> n xx_channel = xx_channel * 2 - 1 yy_channel = yy_channel * 2 - 1 - ret = np.stack([xx_channel, yy_channel], axis=0) - + ret = np.concatenate((xx_channel, yy_channel), axis=-1) if with_r: rr = np.sqrt(np.square(xx_channel - 0.5) + np.square(yy_channel - 0.5)) ret = np.concatenate([ret, np.expand_dims(rr, axis=0)], axis=0) @@ -582,10 +581,13 @@ def __iter__(self) -> Iterator[T_co]: # but could be interpolated between the previous step and next one by weighting by time difference # Topographic is same of course, just need to resize to 1km x 1km? # grid by taking the mean value of the interior ones - sources = [iter(ds) for ds in self.datasets] while True: + sources = [iter(ds) for ds in self.datasets] for source in sources: - sample = next(source) + try: + sample = next(source) + except StopIteration: + continue timesteps = pickle.loads(sample["time.pyd"]) available_steps = len(timesteps) # number of available timesteps # Check to make sure all timesteps exist diff --git a/satflow/examples/create_webdataset.py b/satflow/examples/create_webdataset.py index dfd27c8a..c375b5f9 100644 --- a/satflow/examples/create_webdataset.py +++ b/satflow/examples/create_webdataset.py @@ -52,7 +52,7 @@ eumetsat_dir = "/run/media/bieker/Round1/EUMETSAT/" -def make_day(data, flow=True, batch=144, tile=True): +def make_day(data, flow=True, batch=4, tile=True): root_dir, shard_num = data # reset optical flow samples flow_sample = {} @@ -78,11 +78,11 @@ def make_day(data, flow=True, batch=144, tile=True): shard_num += 1 interday_frame = 0 if os.path.exists( - f"/run/media/bieker/data/EUMETSAT/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar" + f"/home/bieker/Development/satflow/datasets/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar" ): return sink_flow = wds.TarWriter( - f"/run/media/bieker/data/EUMETSAT/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar", + f"/home/bieker/Development/satflow/datasets/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar", compress=True, ) for root, dirs, files in os.walk(root_dir): @@ -184,6 +184,9 @@ def make_day(data, flow=True, batch=144, tile=True): for i in range(8) ] batch_num += 1 + if batch_num >= 1: + sink_flow.close() + return else: flow_sample["time.pyd"] = datetime_object sink_flow.write(flow_sample) @@ -229,3 +232,4 @@ def make_day(data, flow=True, batch=144, tile=True): # exit() for data in all_dates: make_day(data) + exit() diff --git a/satflow/models/__init__.py b/satflow/models/__init__.py index ade19833..dcab870c 100644 --- a/satflow/models/__init__.py +++ b/satflow/models/__init__.py @@ -2,3 +2,5 @@ from .conv_lstm import EncoderDecoderConvLSTM from .metnet import MetNet from .predrnn import PredRNN +from .deeplabv3 import DeepLabV3 +from .fcn import FCN diff --git a/satflow/models/base.py b/satflow/models/base.py index 33cc2017..f300fbe0 100644 --- a/satflow/models/base.py +++ b/satflow/models/base.py @@ -53,7 +53,10 @@ def create_model(model_name, pretrained=False, checkpoint_path="", **kwargs): global_pool (str): global pool type (default: 'avg') input_channels (int): number of input channels (default: 12) forecast_steps (int): number of steps to forecast (default: 48) + optimizer (str): optimizer (default: 'adam') lr (float): learning rate (default: 0.001) + lr_scheduler (str): learning rate scheduler (default: '') + lr_scheduler_**: lr_scheduler specific arguments **: other kwargs are model specific """ source_name, model_name = split_model_name(model_name) @@ -61,10 +64,11 @@ def create_model(model_name, pretrained=False, checkpoint_path="", **kwargs): # Parameters that aren't supported by all models or are intended to only override model defaults if set # should default to None in command line args/cfg. Remove them if they are present and not set so that # non-supporting models don't break and default args remain in effect. - kwargs = {k: v for k, v in kwargs.items() if v is not None} + lr_kwargs = {k: v for k, v in kwargs.items() if v is not None and "lr_scheduler_" in k} + kwargs = {k: v for k, v in kwargs.items() if v is not None and "lr_scheduler_" not in k} if model_name in REGISTERED_MODELS: - model = get_model(model_name)(pretrained=pretrained, **kwargs) + model = get_model(model_name)(pretrained=pretrained, lr_scheduler=lr_kwargs, **kwargs) else: raise RuntimeError("Unknown model (%s)" % model_name) diff --git a/satflow/models/deeplabv3.py b/satflow/models/deeplabv3.py index 7c0ee1f3..d06a0454 100644 --- a/satflow/models/deeplabv3.py +++ b/satflow/models/deeplabv3.py @@ -9,7 +9,7 @@ @register_model -class DeeplabV3(pl.LightningModule): +class DeepLabV3(pl.LightningModule): def __init__( self, forecast_steps: int = 48, @@ -21,13 +21,14 @@ def __init__( pretrained: bool = False, aux_loss: bool = False, ): - super(DeeplabV3, self).__init__() + super(DeepLabV3, self).__init__() self.lr = lr + self.forecast_steps = forecast_steps assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"] if loss == "mse": self.criterion = F.mse_loss elif loss in ["bce", "binary_crossentropy", "crossentropy"]: - self.criterion = F.nll_loss + self.criterion = F.cross_entropy elif loss in ["focal"]: self.criterion = FocalLoss() else: @@ -50,7 +51,7 @@ def __init__( @classmethod def from_config(cls, config): - return DeeplabV3( + return DeepLabV3( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), hidden_dim=config.get("features", 64), @@ -69,27 +70,40 @@ def configure_optimizers(self): def training_step(self, batch, batch_idx): x, y = batch - y_hat = self(x) - + y_hat = self(x)["out"] + y = y.long() if self.make_vis: if np.random.random() < 0.01: self.visualize(x, y, y_hat, batch_idx) # Generally only care about the center x crop, so the model can take into account the clouds in the area without # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels + + # loss = 0 + # for f_step in range(self.forecast_steps): + # loss += self.criterion(y_hat, y[:,f_step,:,:]) + # loss /= self.forecast_steps loss = self.criterion(y_hat, y) self.log("train/loss", loss, on_step=True) return loss def validation_step(self, batch, batch_idx): x, y = batch - y_hat = self(x) + y = y.long() + y_hat = self(x)["out"] + + # Loss is then the loss for each input timestep in the y_hat, so have to break it up per forecase_step + # val_loss = 0 + # for f_step in range(self.forecast_steps): + # val_loss += self.criterion(y_hat, y[:,f_step,:,:]) + # val_loss /= self.forecast_steps val_loss = self.criterion(y_hat, y) self.log("val/loss", val_loss, on_step=True, on_epoch=True) return val_loss def test_step(self, batch, batch_idx): x, y = batch - y_hat = self(x, self.forecast_steps) + y = y.long() + y_hat = self(x)["out"] loss = self.criterion(y_hat, y) return loss diff --git a/satflow/models/fcn.py b/satflow/models/fcn.py index ebbaec1b..4ac208b8 100644 --- a/satflow/models/fcn.py +++ b/satflow/models/fcn.py @@ -26,7 +26,7 @@ def __init__( if loss == "mse": self.criterion = F.mse_loss elif loss in ["bce", "binary_crossentropy", "crossentropy"]: - self.criterion = F.nll_loss + self.criterion = F.cross_entropy elif loss in ["focal"]: self.criterion = FocalLoss() else: @@ -45,7 +45,7 @@ def __init__( @classmethod def from_config(cls, config): - return DeeplabV3( + return FCN( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), hidden_dim=config.get("features", 64), diff --git a/satflow/models/losses.py b/satflow/models/losses.py index 8259cc63..4ff023bb 100644 --- a/satflow/models/losses.py +++ b/satflow/models/losses.py @@ -98,3 +98,63 @@ def forward(self, logit, target): else: loss = loss.sum() return loss + + +def _unbind_images(x, dim=1): + "only unstack images" + if isinstance(x, torch.Tensor): + if len(x.shape) >= 4: + return x.unbind(dim=dim) + return x + + +class StackUnstack(nn.Module): + "Stack together inputs, apply module, unstack output" + + def __init__(self, module, dim=1): + super().__init__() + self.dim = dim + self.module = module + + @staticmethod + def unbind_images(x, dim=1): + return _unbind_images(x, dim) + + def forward(self, *args): + inputs = [torch.stack(x, dim=self.dim) for x in args] + outputs = self.module(*inputs) + if isinstance(outputs, (tuple, list)): + return [self.unbind_images(output, dim=self.dim) for output in outputs] + else: + return outputs.unbind(dim=self.dim) + + +def StackLoss(loss_func=F.mse_loss, axis=-1): + def _inner_loss(x, y): + x = torch.cat(x, axis) + y = torch.cat(y, axis) + return loss_func(x, y) + + return _inner_loss + + +class MultiImageDice: + "Dice coefficient metric for binary target in segmentation" + + def __init__(self, axis=1): + self.axis = axis + + def reset(self): + self.inter, self.union = 0, 0 + + def accumulate(self, pred, y): + x = torch.cat(pred, -1) + y = torch.cat(y, -1) + pred = x.argmax(dim=self.axis).flatten() + targ = np.flatten(y) + self.inter += (pred * targ).float().sum().item() + self.union += (pred + targ).float().sum().item() + + @property + def value(self): + return 2.0 * self.inter / self.union if self.union > 0 else None diff --git a/satflow/models/unet.py b/satflow/models/unet.py index a31b36e4..41ad7c4c 100644 --- a/satflow/models/unet.py +++ b/satflow/models/unet.py @@ -34,6 +34,7 @@ def __init__( else: raise ValueError(f"loss {loss} not recognized") self.make_vis = make_vis + self.input_channels = input_channels self.model = UNet(forecast_steps, input_channels, num_layers, hidden_dim, bilinear) self.save_hyperparameters() @@ -61,7 +62,7 @@ def training_step(self, batch, batch_idx): y_hat = self(x) if self.make_vis: - if np.random.random() < 0.01: + if np.random.random() < 0.001: self.visualize(x, y, y_hat, batch_idx) # Generally only care about the center x crop, so the model can take into account the clouds in the area without # being penalized for that, but for now, just do general MSE loss, also only care about first 12 channels @@ -95,7 +96,7 @@ def test_step(self, batch, batch_idx): def visualize(self, x, y, y_hat, batch_idx): # the logger you used (in this case tensorboard) - tensorboard = self.logger.experiment + tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time in_image = ( x[0].cpu().detach().numpy() @@ -105,15 +106,17 @@ def visualize(self, x, y, y_hat, batch_idx): if i % self.input_channels == 0: # First one j += 1 tensorboard.add_image( - f"Input_Image_{j}_Channel_{i}", in_slice, global_step=batch_idx + f"Input_Image_{j}_Channel_{i}", + np.expand_dims(in_slice, axis=0), + global_step=batch_idx, ) # Each Channel out_image = y_hat[0].cpu().detach().numpy() for i, out_slice in enumerate(out_image): tensorboard.add_image( - f"Output_Image_{i}", out_slice, global_step=batch_idx + f"Output_Image_{i}", np.expand_dims(out_slice, axis=0), global_step=batch_idx ) # Each Channel out_image = y[0].cpu().detach().numpy() for i, out_slice in enumerate(out_image): tensorboard.add_image( - f"Target_Image_{i}", out_slice, global_step=batch_idx + f"Target_Image_{i}", np.expand_dims(out_slice, axis=0), global_step=batch_idx ) # Each Channel