fix SeCo transforms #1324

Draft · wants to merge 6 commits into `main`
4 changes: 4 additions & 0 deletions torchgeo/models/resnet.py
@@ -8,6 +8,7 @@
import kornia.augmentation as K
import timm
import torch
from kornia.contrib import Lambda
from timm.models import ResNet
from torchvision.models._api import Weights, WeightsEnum

@@ -37,6 +38,9 @@
K.CenterCrop(224),
K.Normalize(mean=_min, std=_max - _min),
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)),
Collaborator:
I don't see a lot of value in this. All we're doing is losing information, right? We're already scaling to the same range, do we really care if there are values less/greater than what they trained on?

Collaborator Author:

To be clear, I'm not agreeing with this, but these are the transforms used for SeCo. @calebrob6 and I have done some KNN experiments with SeCo and it's actually very sensitive to these transforms (significant changes in downstream performance if you don't use these). https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py#L51
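As a small illustration (values are hypothetical, not from the PR), the `Lambda` clamp above simply clips anything outside `[0, 255]`, which is exactly where the information loss the reviewer mentions comes from:

```python
import torch

# Illustrative pixel values: one below 0, one in range, one above 255.
x = torch.tensor([-12.0, 80.0, 300.0])

# The clamp discards everything outside [0, 255]; -12 and 300 become
# indistinguishable from 0 and 255 respectively.
clamped = torch.clamp(x, min=0.0, max=255.0)
```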

Lambda(lambda x: x.to(torch.uint8).to(torch.float)), # type: ignore[no-any-return]
Collaborator:
Does this actually do anything?

Collaborator Author:
Yes, it converts from float to uint8 to reduce the bit resolution. The conversion back to float is so that the other kornia augmentations don't complain about receiving a uint8 tensor.
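A minimal sketch of the round-trip described here (the pixel values are illustrative): casting to uint8 truncates the fractional part, and casting back to float keeps downstream kornia ops happy.

```python
import torch

# Hypothetical pixel values after rescaling to the [0, 255] range.
x = torch.tensor([0.4, 127.9, 255.0])

# Truncate to uint8 (dropping the fractional part), then cast back to
# float so later kornia augmentations still receive a float tensor.
quantized = x.to(torch.uint8).to(torch.float)
# 0.4 -> 0.0, 127.9 -> 127.0, 255.0 -> 255.0
```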

Collaborator:
This feels like the same situation as the clip above. It reduces information from the image and makes the job harder for downstream tasks. This may help during SSL because image comparison is harder, but I have no a priori reason to believe that this would help with downstream tasks. If you perform an ablation study and find that it helps, I'm fine with adding it.

Collaborator Author:
I agree with you that it's bad to throw away information, but we should make another transform function if we want to remove certain pieces. Our current SeCo transform isn't correct, and by attaching it to the weights we're advertising it as the official one.

K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
Collaborator (@adamjstewart, May 11, 2023):
> divide by 255

Where do you see this?

Collaborator Author (@isaaccorley, May 11, 2023):
Collaborator:
I see a multiply by 255, I don't see a divide by 255

Collaborator Author:
`transforms.ToTensor()`, right before `imagenet_normalization()`

Collaborator:
Good catch. In that case, I would remove all 4 lines, so we neither multiply by 255 nor divide by 255. This will make the transform faster.
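A sketch of why the pair of `Normalize` calls cancels, written in plain torch and assuming `K.Normalize(mean, std)` computes `(x - mean) / std` per channel:

```python
import torch

x = torch.rand(3, 4, 4)  # pixel values already scaled into [0, 1]

# K.Normalize(mean=0, std=1/255) divides by 1/255, i.e. multiplies by 255.
scaled_up = (x - 0.0) / (1.0 / 255.0)

# K.Normalize(mean=0, std=255) then divides by 255 again.
restored = (scaled_up - 0.0) / 255.0

# Up to float rounding the round-trip is the identity, which is why
# dropping both calls leaves the result unchanged and saves two passes.
```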

K.Normalize(mean=_mean, std=_std),
data_keys=["image"],
)
7 changes: 5 additions & 2 deletions torchgeo/transforms/transforms.py
@@ -8,6 +8,7 @@
import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.contrib import Lambda
from kornia.geometry import crop_by_indices
from torch import Tensor
from torch.nn.modules import Module
@@ -23,7 +24,7 @@ class AugmentationSequential(Module):

def __init__(
self,
*args: Union[K.base._AugmentationBase, K.ImageSequential],
*args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda],
data_keys: list[str],
) -> None:
"""Initialize a new augmentation sequential instance.
@@ -44,7 +45,9 @@ def __init__(
else:
keys.append(key)

self.augs = K.AugmentationSequential(*args, data_keys=keys)
self.augs = K.AugmentationSequential(
*args, data_keys=keys # type: ignore[arg-type]
)

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Perform augmentations and update data dict.