fix SeCo transforms #1324
base: main
Conversation
I would actually disagree with all changes in this PR. You might be able to convince me to keep the clamp if you can show evidence that it actually improves downstream accuracy.
```diff
@@ -37,6 +38,9 @@
     K.CenterCrop(224),
     K.Normalize(mean=_min, std=_max - _min),
     K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
+    Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)),
```
I don't see a lot of value in this. All we're doing is losing information, right? We're already scaling to the same range; do we really care if there are values less/greater than what they trained on?
To be clear, I'm not agreeing with this, but these are the transforms used for SeCo. @calebrob6 and I have done some KNN experiments with SeCo and it's actually very sensitive to these transforms (significant changes in downstream performance if you don't use these). https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py#L51
```diff
@@ -37,6 +38,9 @@
     K.CenterCrop(224),
     K.Normalize(mean=_min, std=_max - _min),
     K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
+    Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)),
+    Lambda(lambda x: x.to(torch.uint8).to(torch.float)),  # type: ignore[no-any-return]
```
Does this actually do anything?
Yes, it converts from float to uint8 to reduce bit resolution. The conversion back to float is so that the other kornia augmentations don't complain about it being uint8.
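For illustration, a minimal sketch of that round trip (the sample values are hypothetical):

```python
import torch

# float -> uint8 -> float truncates sub-integer detail, reducing bit resolution
x = torch.tensor([0.4, 1.7, 254.9, 300.0])
x = torch.clamp(x, min=0.0, max=255.0)  # clip to the 8-bit range first
x = x.to(torch.uint8).to(torch.float)   # quantize, then back to float for kornia
print(x)  # tensor([  0.,   1., 254., 255.])
```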
This feels like the same situation as the clip above. It reduces information from the image and makes the job harder for downstream tasks. This may help during SSL because image comparison is harder, but I have no a priori reason to believe that this would help with downstream tasks. If you perform an ablation study and find that it helps, I'm fine with adding it.
I agree with you that it's bad to throw away information, but we should make another transform function if we want to remove certain pieces. Our current SeCo transform isn't correct, and that's what we're advertising by attaching it to the weights.
```diff
@@ -37,6 +38,9 @@
     K.CenterCrop(224),
     K.Normalize(mean=_min, std=_max - _min),
     K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),
+    Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)),
+    Lambda(lambda x: x.to(torch.uint8).to(torch.float)),  # type: ignore[no-any-return]
+    K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
```
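For context, a minimal sketch of the net effect of this full chain on a raw input, with hypothetical `_min`/`_max` stand-ins for the real per-band statistics:

```python
import torch

_min, _max = torch.tensor(0.0), torch.tensor(4000.0)  # hypothetical band statistics
x = torch.rand(3, 224, 224) * 4000                    # hypothetical raw input

x = (x - _min) / (_max - _min)         # K.Normalize(mean=_min, std=_max - _min)
x = x * 255                            # K.Normalize(mean=0, std=1/255)
x = torch.clamp(x, 0.0, 255.0)         # Lambda clamp
x = x.to(torch.uint8).to(torch.float)  # Lambda quantize
x = x / 255                            # K.Normalize(mean=0, std=255)
print(x.min(), x.max())                # values back in [0, 1], at 8-bit resolution
```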
divide by 255
Where do you see this?
I see a multiply by 255; I don't see a divide by 255.
transforms.ToTensor(), right before imagenet_normalization()
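For reference, torchvision's `ToTensor` scales a uint8 image to a float tensor in [0, 1], which is the implicit divide by 255 being referenced (the array below is a stand-in):

```python
import numpy as np
from torchvision import transforms

img = np.full((2, 2, 3), 255, dtype=np.uint8)  # hypothetical 8-bit HWC image
t = transforms.ToTensor()(img)                 # CHW float tensor, scaled by 1/255
print(t.max())  # tensor(1.)
```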
Good catch. In that case, I would remove all 4 lines, so we neither multiply by 255 nor divide by 255. This will make the transform faster.
Whether or not you agree with the changes, our current SeCo transforms are incorrect. I'm just taking what they have in their script.
I don't actually know how to fix the mypy errors other than ignoring them.
In my mind, the transforms supplied with the weights aren't designed to reproduce SeCo pre-training; they're designed to use the SeCo weights for downstream tasks. If these changes don't improve performance on downstream tasks and only make the transforms slower, or even harm performance on downstream tasks, then we shouldn't add them. That's why an ablation study is needed to ensure that these changes are beneficial.
In that case, the transforms are completely wrong, because SeCo doesn't use these for downstream tasks at all. They only divide EuroSAT by 255, they use different BigEarthNet mean and std, and they don't resize to 224x224. I would argue that you would need an ablation study to justify not using the same transforms as those used during training, since that's the less intuitive choice. The model is trained to take 8-bit inputs, so I would expect performance to degrade if you feed it 12-bit inputs.
Let's hold off on this PR until after the deadline so we have time to do a proper ablation study on no transforms, current transforms, and proposed transforms.
I'm good with this; I just wanted to use it for comparisons to SeCo.
Added this to the 0.4.2 milestone so we don't forget about it. At the bare minimum, we should remove the multiply by 255 stuff. The rest is more controversial and will need ablation studies that we don't have time to work on until after the deadline.
Is this still a work in progress?
@adamjstewart I can finish this today. It fell off my radar.
No rush. If we can squeeze it into 0.5.2, great. If not, we can continue to punt it. Just trying to finish up old PRs.
I'm no longer opposed to the use of …
The current SeCo transforms do the following:
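(Reconstructed from the unchanged context lines in the diff hunks above; the surrounding transform container is omitted.)

```python
K.CenterCrop(224),
K.Normalize(mean=_min, std=_max - _min),                       # scale to [0, 1]
K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)),  # multiply by 255
```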
However, in the script they actually do the following:
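(A sketch based on the linked seco_dataset.py and the comments above; the exact composition in the SeCo script may differ.)

```python
Lambda(lambda x: torch.clamp(x, min=0.0, max=255.0)),  # clip to the 8-bit range
Lambda(lambda x: x.to(torch.uint8)),                   # quantize to uint8
transforms.ToTensor(),                                 # divide by 255, float in [0, 1]
imagenet_normalization(),                              # ImageNet mean/std
```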