Code for using the ImageNet pretrained model #146

Open
LightingMc opened this issue Jan 23, 2024 · 4 comments
LightingMc commented Jan 23, 2024

I thought this would be helpful for other people. I had trouble getting the ResNet used in this repo to run properly, but the provided weights work well with PyTorch's default ResNet.

Loading weights

import torch

state_dict = torch.load("supcon_official.pth", map_location="cpu")

Renaming the keys: strip the "module." and "encoder." prefixes so the names match PyTorch's resnet50.

state_dict = state_dict['model']

# Remove "module." from every key.
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("module.", "")
    new_state_dict[k] = v
state_dict = new_state_dict

# Remove "encoder." from every key.
new_state_dict = {}
for k, v in state_dict.items():
    k = k.replace("encoder.", "")
    new_state_dict[k] = v
state_dict = new_state_dict
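
The two renaming loops above could also be folded into a single pass. This is only a minimal sketch: `strip_prefixes` is a hypothetical helper name, and the toy dictionary uses made-up keys with placeholder integers instead of tensors.

```python
def strip_prefixes(state_dict, prefixes=("module.", "encoder.")):
    """Return a new dict with every occurrence of each prefix removed from the keys."""
    cleaned = {}
    for key, value in state_dict.items():
        for prefix in prefixes:
            # Same str.replace call as in the loops above, applied per prefix.
            key = key.replace(prefix, "")
        cleaned[key] = value
    return cleaned


# Toy checkpoint (hypothetical keys, integers standing in for tensors).
toy = {"module.encoder.conv1.weight": 1, "module.head.0.weight": 2}
print(strip_prefixes(toy))  # {'conv1.weight': 1, 'head.0.weight': 2}
```

With the real checkpoint this would be called as `state_dict = strip_prefixes(state_dict['model'])`.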

Using the standard PyTorch (torchvision) resnet50

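from torch import nn
from torchvision.models import resnet50

model = resnet50()
# Replace the classifier with a pass-through so the model returns the pooled features.
model.fc = nn.Identity()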

The projection head is not needed, so drop its weights

state_dict.pop("head.0.weight", None)
state_dict.pop("head.0.bias", None)
state_dict.pop("head.2.weight", None)
state_dict.pop("head.2.bias", None)
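
As an alternative to popping each key by name, the head entries could be filtered out by prefix. A small sketch, assuming every projection-head key starts with "head." after the renaming step; `drop_head` is a hypothetical helper and the toy values are placeholders, not tensors.

```python
def drop_head(state_dict, head_prefix="head."):
    """Return a copy of state_dict without the keys that start with head_prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith(head_prefix)}


# Toy example with placeholder values.
toy = {"conv1.weight": 1, "head.0.weight": 2, "head.2.bias": 3}
print(drop_head(toy))  # {'conv1.weight': 1}
```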

This should do the trick

model.load_state_dict(state_dict, strict=True)
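
With strict=True, load_state_dict raises an error if the key sets do not match exactly. If that happens, comparing the two key sets first can show which names are off. A minimal sketch; `diff_keys` is a hypothetical helper and the key lists below are toy values.

```python
def diff_keys(expected_keys, provided_keys):
    """Return (missing, unexpected) key lists, both sorted."""
    expected, provided = set(expected_keys), set(provided_keys)
    missing = sorted(expected - provided)      # the model wants these, the checkpoint lacks them
    unexpected = sorted(provided - expected)   # the checkpoint has these, the model does not
    return missing, unexpected


# Toy key lists for illustration only.
expected = ["conv1.weight", "bn1.weight"]
provided = ["conv1.weight", "head.0.weight"]
print(diff_keys(expected, provided))  # (['bn1.weight'], ['head.0.weight'])
```

With the real objects this would be `diff_keys(model.state_dict().keys(), state_dict.keys())`; two empty lists mean the strict load should find matching keys.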

@DruncBread

I have this problem too. Thanks for the tip!
I wonder whether there is any difference between the author's ResNet and PyTorch's default ResNet.
I hope the performance is the same.

@HobbitLong
Owner

Hi,

Sorry for the confusion. The ResNet (the nn.Module file) in this repo is only for CIFAR input, i.e., 32x32. The ImageNet weights we provide here are for 224x224 input, so they can only be loaded with the official PyTorch definition of ResNet, which takes 224x224 input.
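
To illustrate why the two definitions are not interchangeable, here is a rough sketch of the feature-map size after each stem. The torchvision stem (7x7 conv with stride 2 and padding 3, then a 3x3 max-pool with stride 2 and padding 1) is the standard one; the CIFAR-style stem (3x3 conv, stride 1, padding 1, no max-pool) is an assumption about this repo's ResNet, and the function names are hypothetical.

```python
def out_size(n, kernel, stride, padding):
    """Spatial size after a conv or pooling layer (floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


def imagenet_stem(n):
    # torchvision ResNet stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1).
    n = out_size(n, kernel=7, stride=2, padding=3)
    return out_size(n, kernel=3, stride=2, padding=1)


def cifar_stem(n):
    # Assumed CIFAR-style stem: 3x3 conv (stride 1, padding 1), no max-pool.
    return out_size(n, kernel=3, stride=1, padding=1)


print(imagenet_stem(224))  # 56
print(imagenet_stem(32))   # 8
print(cifar_stem(32))      # 32
```

Under that assumption the first conv kernels also differ in shape (7x7 versus 3x3), which would be enough on its own to stop the ImageNet weights from loading into the CIFAR definition.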

Historically, we first released this repo for CIFAR-10/100, so we defined the ResNet only for 32x32 input. Later on, I trained a SupCon ImageNet model with my other code base and shared the weights in this repo, which caused this confusion.

@LightingMc
Author

@HobbitLong that's what I thought as well, hahaha. The code was for CIFAR, but the weights were for ImageNet.

@HobbitLong
Owner

Glad you figured it out much earlier, and thank you for sharing it!
