This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Model updates/CoordConv/Anti-aliased training #38

Closed · wants to merge 7 commits
4 changes: 2 additions & 2 deletions satflow/configs/config.yaml
@@ -6,7 +6,7 @@ defaults:
  - model: convlstm_model.yaml
  - datamodule: satflow_datamodule.yaml
  - callbacks: default.yaml # set this to null if you don't want to use callbacks
  - logger: tensorboard # set logger here or use command line (e.g. `python run.py logger=wandb`)
  - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)

  - experiment: null
  - hparams_search: null
@@ -24,7 +24,7 @@ defaults:
work_dir: ${hydra:runtime.cwd}

# path to folder with data
data_dir: /run/media/jacob/T7/
data_dir: /run/media/bieker/data/EUMETSAT/

# use `python run.py debug=true` for easy debugging!
# this will run 1 train, val and test loop with only 1 batch
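With the logger default switched to null, nothing is logged unless a logger group is selected explicitly. A minimal sketch of composing this config with overrides from Python, assuming Hydra 1.1+ and this repo's layout (the override values are illustrative):

```python
from hydra import compose, initialize

# Mirrors `python run.py logger=wandb data_dir=/path/to/data` from the comments above.
with initialize(config_path="satflow/configs"):
    cfg = compose(config_name="config", overrides=["logger=wandb", "data_dir=/path/to/data"])
print(cfg.data_dir)
```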
24 changes: 13 additions & 11 deletions satflow/configs/datamodule/segmentation_datamodule.yaml
@@ -1,31 +1,33 @@
# @package _group_
_target_: satflow.data.datamodules.SatFlowDataModule
_target_: satflow.data.datamodules.SegFlowDataModule

batch_size: 1
batch_size: 16
data_dir: ${data_dir} # data_dir is specified in config.yaml
shuffle: 0
sources:
  train: "satflow-flow-144-tiled-{00001..00105}.tar"
  val: "satflow-flow-144-tiled-{00106..00129}.tar"
  test: "satflow-flow-144-tiled-{00130..00149}.tar"
  train: "satflow-flow-144-tiled-{00001..00105}.tar" # 2020
  val: "satflow-flow-144-tiled-{00106..00129}.tar" # 2021
  test: "satflow-flow-144-tiled-{00130..00149}.tar" # 2021
num_workers: 12
pin_memory: True
config:
  visualize: False
  num_timesteps: 6
  num_timesteps: 0
  skip_timesteps: 1
  forecast_times: 6
  output_shape: 128
  forecast_times: 2
  output_shape: 256
  target_type: "cloudmask"
  num_crops: 10
  num_times: 10
  num_crops: 5
  num_times: 20
  use_topo: True
  use_latlon: True
  use_latlon: False
  use_time: False
  time_aux: False
  use_mask: False
  use_image: False
  time_as_channels: True
  add_pixel_coords: False
  add_polar_coords: False
  bands:
    [
      "HRV",
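Since `_target_` now points at `SegFlowDataModule`, Hydra can build the datamodule straight from this file. A sketch of the instantiation step, reusing the `cfg` composed above:

```python
from hydra.utils import instantiate

# _target_ names the class; the remaining keys (batch_size, sources, config, ...)
# are passed through as keyword arguments.
datamodule = instantiate(cfg.datamodule)
```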
32 changes: 32 additions & 0 deletions satflow/configs/datamodule/unet_datamodule.yaml
@@ -0,0 +1,32 @@
# @package _group_
_target_: satflow.data.datamodules.SatFlowDataModule

batch_size: 8
data_dir: ${data_dir} # data_dir is specified in config.yaml
shuffle: 0
sources:
  train: "satflow-flow-144-tiled-{00001..00105}.tar"
  val: "satflow-flow-144-tiled-{00106..00129}.tar"
  test: "satflow-flow-144-tiled-{00130..00149}.tar"
num_workers: 12
pin_memory: True
config:
  visualize: False
  num_timesteps: 10
  skip_timesteps: 1
  forecast_times: 20
  output_shape: 128
  target_type: "cloudmask"
  num_crops: 10
  num_times: 15
  use_topo: False
  use_latlon: False
  use_time: False
  time_aux: False
  use_mask: True
  use_image: False
  time_as_channels: True
  add_pixel_coords: True
  add_polar_coords: False
  bands: ["IR016"]
  transforms: {}
2 changes: 1 addition & 1 deletion satflow/configs/model/deeplabv3_r50_model.yaml
@@ -1,6 +1,6 @@
# @package _group_
_target_: satflow.models.deeplabv3.DeepLabV3
forecast_steps: 6
forecast_steps: 2
input_channels: 12
lr: 0.001
make_vis: False
3 changes: 1 addition & 2 deletions satflow/configs/model/fcn_r50_model.yaml
@@ -1,10 +1,9 @@
# @package _group_
_target_: satflow.models.fcn.FCN
forecast_steps: 6
forecast_steps: 1
input_channels: 12
lr: 0.001
make_vis: False
loss: "bce"
backbone: "resnet50"
pretrained: False
aux_loss: False
4 changes: 4 additions & 0 deletions satflow/configs/trainer/minimal.yaml
@@ -14,6 +14,10 @@ auto_lr_find: True
auto_scale_batch_size: False
reload_dataloaders_every_epoch: True

accumulate_grad_batches: 1
precision: 32
# stochastic_weight_avg: True

weights_summary: null
progress_bar_refresh_rate: 10
# resume_from_checkpoint: null
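The new trainer keys map one-to-one onto `pytorch_lightning.Trainer` arguments in the 1.x API; a sketch of the equivalent construction, with stochastic weight averaging left commented out as in the YAML:

```python
import pytorch_lightning as pl

# Equivalent of the YAML above: no gradient accumulation beyond a single
# batch, full fp32 precision. Argument names are from the PL 1.x Trainer.
trainer = pl.Trainer(
    auto_lr_find=True,
    auto_scale_batch_size=False,
    reload_dataloaders_every_epoch=True,
    accumulate_grad_batches=1,
    precision=32,
    # stochastic_weight_avg=True,
    weights_summary=None,
    progress_bar_refresh_rate=10,
)
```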
67 changes: 66 additions & 1 deletion satflow/data/datamodules.py
@@ -2,7 +2,7 @@
from torch.utils.data import DataLoader
from typing import Optional
import webdataset as wds
from satflow.data.datasets import SatFlowDataset, CloudFlowDataset
from satflow.data.datasets import SatFlowDataset, CloudFlowDataset, SegFlowDataset
import os


@@ -152,3 +152,68 @@ def test_dataloader(self):
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )


class SegFlowDataModule(pl.LightningDataModule):
    def __init__(
        self,
        config: dict,
        sources: dict,
        batch_size: int = 2,
        shuffle: int = 0,
        data_dir: str = "./",
        num_workers: int = 1,
        pin_memory: bool = True,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.config = config
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sources = sources
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def prepare_data(self):
        # download
        pass

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            train_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["train"]))
            val_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["val"]))
            if self.shuffle > 0:
                # Add shuffling; each sample is still quite large, so too big a shuffle buffer runs out of RAM
                train_dset = train_dset.shuffle(self.shuffle)
            self.train_dataset = SegFlowDataset([train_dset], config=self.config, train=True)
            self.val_dataset = SegFlowDataset([val_dset], config=self.config, train=False)

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            test_dset = wds.WebDataset(os.path.join(self.data_dir, self.sources["test"]))
            self.test_dataset = SegFlowDataset([test_dset], config=self.config, train=False)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )
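A usage sketch for the new datamodule, assuming shards matching the `sources` patterns exist under `data_dir` (the `config` dict is trimmed here; the full set of keys is in segmentation_datamodule.yaml above):

```python
import pytorch_lightning as pl
from satflow.data.datamodules import SegFlowDataModule

dm = SegFlowDataModule(
    config={"visualize": False, "num_timesteps": 0, "forecast_times": 2},  # trimmed
    sources={
        "train": "satflow-flow-144-tiled-{00001..00105}.tar",
        "val": "satflow-flow-144-tiled-{00106..00129}.tar",
        "test": "satflow-flow-144-tiled-{00130..00149}.tar",
    },
    batch_size=16,
    data_dir="/run/media/bieker/data/EUMETSAT/",
)
# model = ...  # any LightningModule from satflow.models
# pl.Trainer(max_epochs=1).fit(model, datamodule=dm)
```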
10 changes: 6 additions & 4 deletions satflow/data/datasets.py
@@ -122,8 +122,7 @@ def create_pixel_coord_layers(x_dim: int, y_dim: int, with_r: bool = False) -> np.ndarray:

    xx_channel = xx_channel * 2 - 1
    yy_channel = yy_channel * 2 - 1
    ret = np.stack([xx_channel, yy_channel], axis=0)

    ret = np.concatenate((xx_channel, yy_channel), axis=-1)
    if with_r:
        rr = np.sqrt(np.square(xx_channel - 0.5) + np.square(yy_channel - 0.5))
        ret = np.concatenate([ret, np.expand_dims(rr, axis=0)], axis=0)
@@ -582,10 +581,13 @@ def __iter__(self) -> Iterator[T_co]:
        # but could be interpolated between the previous step and next one by weighting by time difference
        # Topographic is same of course, just need to resize to 1km x 1km?
        # grid by taking the mean value of the interior ones
        sources = [iter(ds) for ds in self.datasets]
        while True:
            sources = [iter(ds) for ds in self.datasets]
            for source in sources:
                sample = next(source)
                try:
                    sample = next(source)
                except StopIteration:
                    continue
                timesteps = pickle.loads(sample["time.pyd"])
                available_steps = len(timesteps)  # number of available timesteps
                # Check to make sure all timesteps exist
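The `create_pixel_coord_layers` change swaps the channel stack for a concatenation along the last axis. The intent, following CoordConv (Liu et al., 2018), is two normalized coordinate channels plus an optional radial one; a minimal standalone sketch in channels-first layout (the layout is an assumption, independent of the axis choice in the diff):

```python
import numpy as np

def pixel_coord_channels(x_dim: int, y_dim: int, with_r: bool = False) -> np.ndarray:
    # Normalized x/y coordinate grids, rescaled from [0, 1] to [-1, 1]
    xx = np.tile(np.linspace(0.0, 1.0, x_dim)[None, :], (y_dim, 1))
    yy = np.tile(np.linspace(0.0, 1.0, y_dim)[:, None], (1, x_dim))
    xx, yy = xx * 2 - 1, yy * 2 - 1
    ret = np.stack([xx, yy], axis=0)  # shape (2, y_dim, x_dim)
    if with_r:
        # Optional radial-distance channel from the image centre
        rr = np.sqrt(xx**2 + yy**2)
        ret = np.concatenate([ret, rr[None]], axis=0)  # shape (3, y_dim, x_dim)
    return ret
```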
10 changes: 7 additions & 3 deletions satflow/examples/create_webdataset.py
@@ -52,7 +52,7 @@
eumetsat_dir = "/run/media/bieker/Round1/EUMETSAT/"


def make_day(data, flow=True, batch=144, tile=True):
def make_day(data, flow=True, batch=4, tile=True):
    root_dir, shard_num = data
    # reset optical flow samples
    flow_sample = {}
@@ -78,11 +78,11 @@ def make_day(data, flow=True, batch=4, tile=True):
    shard_num += 1
    interday_frame = 0
    if os.path.exists(
        f"/run/media/bieker/data/EUMETSAT/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar"
        f"/home/bieker/Development/satflow/datasets/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar"
    ):
        return
    sink_flow = wds.TarWriter(
        f"/run/media/bieker/data/EUMETSAT/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar",
        f"/home/bieker/Development/satflow/datasets/satflow{'-' if not flow else '-flow'}{'-' if not flow and batch > 0 else f'-{batch}-'}{'tiled-' if tile else ''}{shard_num:05d}.tar",
        compress=True,
    )
    for root, dirs, files in os.walk(root_dir):
@@ -184,6 +184,9 @@ def make_day(data, flow=True, batch=4, tile=True):
            for i in range(8)
        ]
        batch_num += 1
        if batch_num >= 1:
            sink_flow.close()
            return
    else:
        flow_sample["time.pyd"] = datetime_object
        sink_flow.write(flow_sample)
@@ -229,3 +232,4 @@ def make_day(data, flow=True, batch=4, tile=True):
# exit()
for data in all_dates:
    make_day(data)
exit()
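A minimal sketch of the shard-writing pattern this script uses, with hypothetical sample contents (the `.npy`/`.pyd` extensions tell webdataset how to encode each entry):

```python
import numpy as np
import webdataset as wds

sink = wds.TarWriter("satflow-flow-example-00000.tar", compress=True)
for i in range(4):
    sink.write(
        {
            "__key__": f"sample{i:06d}",              # one tar record per key
            "0.npy": np.zeros((128, 128), np.uint8),  # hypothetical frame
            "time.pyd": [i],                          # pickled timestep metadata
        }
    )
sink.close()
```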
2 changes: 2 additions & 0 deletions satflow/models/__init__.py
@@ -2,3 +2,5 @@
from .conv_lstm import EncoderDecoderConvLSTM
from .metnet import MetNet
from .predrnn import PredRNN
from .deeplabv3 import DeepLabV3
from .fcn import FCN
8 changes: 6 additions & 2 deletions satflow/models/base.py
@@ -53,18 +53,22 @@ def create_model(model_name, pretrained=False, checkpoint_path="", **kwargs):
        global_pool (str): global pool type (default: 'avg')
        input_channels (int): number of input channels (default: 12)
        forecast_steps (int): number of steps to forecast (default: 48)
        optimizer (str): optimizer (default: 'adam')
        lr (float): learning rate (default: 0.001)
        lr_scheduler (str): learning rate scheduler (default: '')
        lr_scheduler_**: lr_scheduler specific arguments
        **: other kwargs are model specific
    """
    source_name, model_name = split_model_name(model_name)

    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    lr_kwargs = {k: v for k, v in kwargs.items() if v is not None and "lr_scheduler_" in k}
    kwargs = {k: v for k, v in kwargs.items() if v is not None and "lr_scheduler_" not in k}

    if model_name in REGISTERED_MODELS:
        model = get_model(model_name)(pretrained=pretrained, **kwargs)
        model = get_model(model_name)(pretrained=pretrained, lr_scheduler=lr_kwargs, **kwargs)
    else:
        raise RuntimeError("Unknown model (%s)" % model_name)

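`create_model` now peels off every `lr_scheduler_`-prefixed kwarg into a separate dict that reaches the model as `lr_scheduler`, so scheduler settings no longer collide with model kwargs. A hedged usage sketch (the model name and scheduler argument are illustrative):

```python
from satflow.models.base import create_model

# "lr_scheduler_gamma" ends up in lr_scheduler={"lr_scheduler_gamma": 0.9};
# note the filter keeps the prefix on the key. Everything else stays a model kwarg.
model = create_model(
    "deeplabv3",             # must be a registered model name
    forecast_steps=2,
    input_channels=12,
    lr=1e-3,
    lr_scheduler_gamma=0.9,  # hypothetical scheduler argument
)
```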
30 changes: 22 additions & 8 deletions satflow/models/deeplabv3.py
@@ -9,7 +9,7 @@


@register_model
class DeeplabV3(pl.LightningModule):
class DeepLabV3(pl.LightningModule):
    def __init__(
        self,
        forecast_steps: int = 48,
@@ -21,13 +21,14 @@ def __init__(
        pretrained: bool = False,
        aux_loss: bool = False,
    ):
        super(DeeplabV3, self).__init__()
        super(DeepLabV3, self).__init__()
        self.lr = lr
        self.forecast_steps = forecast_steps
        assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"]
        if loss == "mse":
            self.criterion = F.mse_loss
        elif loss in ["bce", "binary_crossentropy", "crossentropy"]:
            self.criterion = F.nll_loss
            self.criterion = F.cross_entropy
        elif loss in ["focal"]:
            self.criterion = FocalLoss()
        else:
@@ -50,7 +51,7 @@

    @classmethod
    def from_config(cls, config):
        return DeeplabV3(
        return DeepLabV3(
            forecast_steps=config.get("forecast_steps", 12),
            input_channels=config.get("in_channels", 12),
            hidden_dim=config.get("features", 64),
@@ -69,27 +70,40 @@ def configure_optimizers(self):

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        y_hat = self(x)["out"]
        y = y.long()
        if self.make_vis:
            if np.random.random() < 0.01:
                self.visualize(x, y, y_hat, batch_idx)
        # Generally we only care about the center crop, so the model can take the surrounding clouds into account
        # without being penalized for them, but for now just use a plain MSE loss over the first 12 channels

        # loss = 0
        # for f_step in range(self.forecast_steps):
        #     loss += self.criterion(y_hat, y[:, f_step, :, :])
        # loss /= self.forecast_steps
        loss = self.criterion(y_hat, y)
        self.log("train/loss", loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y = y.long()
        y_hat = self(x)["out"]

        # The loss covers every input timestep in y_hat, so it has to be broken up per forecast_step
        # val_loss = 0
        # for f_step in range(self.forecast_steps):
        #     val_loss += self.criterion(y_hat, y[:, f_step, :, :])
        # val_loss /= self.forecast_steps
        val_loss = self.criterion(y_hat, y)
        self.log("val/loss", val_loss, on_step=True, on_epoch=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x, self.forecast_steps)
        y = y.long()
        y_hat = self(x)["out"]
        loss = self.criterion(y_hat, y)
        return loss

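Two of the fixes above are easy to trip over: torchvision's segmentation models return an `OrderedDict`, so the logits live under `"out"`, and `F.cross_entropy` wants raw logits plus integer class targets (hence `y.long()`), unlike `F.nll_loss`, which expects log-probabilities. A small self-contained sketch with torchvision's stock model:

```python
import torch
import torch.nn.functional as F
from torchvision.models.segmentation import deeplabv3_resnet50

model = deeplabv3_resnet50(pretrained=False, num_classes=2)
x = torch.randn(2, 3, 128, 128)          # torchvision default: 3 input channels
y = torch.randint(0, 2, (2, 128, 128))   # integer class map, not one-hot

logits = model(x)["out"]                  # OrderedDict -> take the "out" head
loss = F.cross_entropy(logits, y.long())  # applies log_softmax internally
```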
4 changes: 2 additions & 2 deletions satflow/models/fcn.py
@@ -26,7 +26,7 @@ def __init__(
if loss == "mse":
self.criterion = F.mse_loss
elif loss in ["bce", "binary_crossentropy", "crossentropy"]:
self.criterion = F.nll_loss
self.criterion = F.cross_entropy
elif loss in ["focal"]:
self.criterion = FocalLoss()
else:
@@ -45,7 +45,7 @@ def __init__(

    @classmethod
    def from_config(cls, config):
        return DeeplabV3(
        return FCN(
            forecast_steps=config.get("forecast_steps", 12),
            input_channels=config.get("in_channels", 12),
            hidden_dim=config.get("features", 64),