From aef0b237513d2a274d59c5e222b9de3105c24734 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 21:36:48 +0000 Subject: [PATCH 1/6] fix SeCo transforms --- torchgeo/models/resnet.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index bd785e85f85..4d275ad71a1 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -8,6 +8,7 @@ import kornia.augmentation as K import timm import torch +from kornia.contrib import Lambda from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum @@ -37,6 +38,8 @@ K.CenterCrop(224), K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), + Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=["image"], ) From 4233b0b6864543e5a9d49b204cd78365bad0bded Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 21:39:20 +0000 Subject: [PATCH 2/6] fix SeCo transforms --- torchgeo/models/resnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 4d275ad71a1..c45c2f180ed 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -39,6 +39,7 @@ K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)), + Lambda(lambda x: x.to(torch.uint8)) K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=["image"], From 1228885d1dea64b6b95d9abaab8251fa0851edb6 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 21:39:34 +0000 Subject: [PATCH 3/6] fix SeCo transforms x 2 --- torchgeo/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index c45c2f180ed..83e8d0c364a 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -39,7 +39,7 @@ K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)), - Lambda(lambda x: x.to(torch.uint8)) + Lambda(lambda x: x.to(torch.uint8)), K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=["image"], From e552b1a82c9b85235117e3cd4ac0df721ccc9303 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 21:59:19 +0000 Subject: [PATCH 4/6] convert back to float or else kornia complains --- torchgeo/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 83e8d0c364a..71c7a2f38eb 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -39,7 +39,7 @@ K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)), - Lambda(lambda x: x.to(torch.uint8)), + Lambda(lambda x: x.to(torch.uint8).to(torch.float)), K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=["image"], From 7bb32ccab50c9989781d3025fb45af3f819b3735 Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 22:20:47 +0000 Subject: [PATCH 5/6] fix mypy --- torchgeo/models/resnet.py | 2 +- torchgeo/transforms/transforms.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 71c7a2f38eb..ee0c73fb326 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -39,7 +39,7 @@ K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)), - Lambda(lambda x: x.to(torch.uint8).to(torch.float)), + Lambda(lambda x: x.to(torch.uint8).to(torch.float)), # type: ignore[no-any-return] K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.Normalize(mean=_mean, std=_std), data_keys=["image"], diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index c45f20127e5..b0b03b56092 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,6 +8,7 @@ import kornia.augmentation as K import torch from einops import rearrange +from kornia.contrib import Lambda from kornia.geometry import crop_by_indices from torch import Tensor from torch.nn.modules import Module @@ -23,7 +24,7 @@ class AugmentationSequential(Module): def __init__( self, - *args: Union[K.base._AugmentationBase, K.ImageSequential], + *args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda], data_keys: list[str], ) -> None: """Initialize a new augmentation sequential instance. From b63051b88e40ab4a5fe28afd6a29cf31adf27eaf Mon Sep 17 00:00:00 2001 From: isaaccorley <22203655+isaaccorley@users.noreply.github.com> Date: Thu, 11 May 2023 22:58:11 +0000 Subject: [PATCH 6/6] fix mypy try #2 --- torchgeo/transforms/transforms.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index b0b03b56092..8a03a59c6a8 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -45,7 +45,9 @@ def __init__( else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys) + self.augs = K.AugmentationSequential( + *args, data_keys=keys # type: ignore[arg-type] + ) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Perform augmentations and update data dict.