How to load the torch models without a reliable network connection? #47

Open
elenacliu opened this issue Sep 12, 2023 · 4 comments

Comments

@elenacliu

Thank you for your great work!

To extract DINO features, the code has to load the model from GitHub:

if 'dino' in model_type:
    model = torch.hub.load('facebookresearch/dino:main', model_type)
else:  # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
    temp_model = timm.create_model(model_type, pretrained=True)
    model_type_dict = {
        'vit_small_patch16_224': 'dino_vits16',
        'vit_small_patch8_224': 'dino_vits8',
        'vit_base_patch16_224': 'dino_vitb16',
        'vit_base_patch8_224': 'dino_vitb8'
    }
    model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
    temp_state_dict = temp_model.state_dict()
    del temp_state_dict['head.weight']
    del temp_state_dict['head.bias']
    model.load_state_dict(temp_state_dict)
return model

However, I cannot find a way to do this when no network is available. Is there an alternative method?
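One possible workaround, sketched here under assumptions about torch.hub's default cache layout: torch.hub.load can read a repository from a local directory (source='local'), and torch.hub.load_state_dict_from_url skips the download when a file with the same name already exists in its checkpoints directory. The helper names below (hub_dir, cached_repo_dir, cached_checkpoint_path, load_dino_offline) are illustrative, not part of any library, and torch.hub.set_dir can override the location computed here.

```python
import os
from urllib.parse import urlparse


def hub_dir(env=None):
    """Default torch.hub directory: $TORCH_HOME/hub, where TORCH_HOME falls
    back to $XDG_CACHE_HOME/torch and then to ~/.cache/torch."""
    env = os.environ if env is None else env
    cache_home = env.get("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = env.get("TORCH_HOME", os.path.join(cache_home, "torch"))
    return os.path.join(os.path.expanduser(torch_home), "hub")


def cached_repo_dir(repo="facebookresearch/dino", ref="main", env=None):
    """Directory where torch.hub keeps its copy of 'owner/name:ref'."""
    owner, name = repo.split("/")
    return os.path.join(hub_dir(env), "_".join([owner, name, ref.replace("/", "_")]))


def cached_checkpoint_path(url, env=None):
    """File that load_state_dict_from_url looks for before downloading."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(hub_dir(env), "checkpoints", filename)


def load_dino_offline(repo_dir, model_type):
    """Hypothetical wrapper: load a DINO entrypoint from a local repo copy."""
    import torch  # imported lazily so the path helpers work without torch

    return torch.hub.load(repo_dir, model_type, source="local")
```

Under these assumptions, one would clone facebookresearch/dino on a machine with network access, copy it to the path returned by cached_repo_dir() on the server (or to any directory passed as repo_dir), and place the downloaded .pth file at cached_checkpoint_path(url), where url is the weight URL listed for that model in the repository's hubconf.py. The timm branch of the original code calls timm.create_model(..., pretrained=True), which downloads weights separately and is not covered by this sketch.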

@elenacliu
Author

I tried downloading the corresponding weights from GitHub to my local machine, copying them to the server with scp, and modifying your code as follows:

from torchvision.models import vit_b_16

if 'dino' in model_type:
    model_path = '/path/to/ckpt/dino_vitbase16_pretrain_full_checkpoint.pth'
    if os.path.exists(model_path):
        model = vit_b_16(pretrained=True)
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict)
        model.eval()
    else:
        model = torch.hub.load('facebookresearch/dino:main', model_type)

But the backbone does not seem to match the weights: some keys are reported as missing.

image
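A likely cause, offered as an assumption rather than a confirmed diagnosis: torchvision's vit_b_16 is a different implementation from the VisionTransformer class in the DINO repository, so the parameter names differ. In addition, the DINO full-checkpoint files appear to hold several nested entries (such as student and teacher) whose keys carry prefixes like module. and backbone. instead of a flat state dict. The sketch below uses only plain dictionaries to show how one might unwrap such a checkpoint and compare key sets before calling load_state_dict. The function names and the default 'teacher' key are illustrative.

```python
def unwrap_checkpoint(checkpoint, key="teacher", prefixes=("module.", "backbone.")):
    """Select a nested state dict if `key` is present, then strip wrapper
    prefixes from every parameter name."""
    state_dict = checkpoint.get(key, checkpoint)
    cleaned = {}
    for name, value in state_dict.items():
        for prefix in prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected): names the model expects but the
    checkpoint lacks, and names the checkpoint has but the model does not."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected
```

With torch available, checkpoint would come from torch.load(model_path, map_location='cpu') and model_keys from model.state_dict().keys(), where model is built from the DINO repository's own class rather than from torchvision. If the only unexpected names belong to a projection head, passing strict=False to load_state_dict is one way to ignore them.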

@elenacliu
Author

I also tried torch.hub.load with source='local':

model_dir = '~/.cache/torch/checkpoints'  # I have a hubconf.py and the weights dino_vitbase16_pretrain_full_checkpoint.pth under the directory
model_name = 'vit_b_16'
model = torch.hub.load(model_dir, model=model_name, source='local')

This also raised an error in dino_extractor.py:

patch_size = model.patch_embed.patch_size

AttributeError: 'VisionTransformer' object has no attribute 'patch_embed'
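The AttributeError suggests that the returned object was not DINO's VisionTransformer, which dino_extractor.py expects to expose patch_embed. A name like vit_b_16 matches torchvision's model rather than the DINO entrypoints such as dino_vitb16 that appear in the mapping earlier in this thread. As a quick check that needs no network, the hypothetical helper below lists the public top-level functions defined in a local hubconf.py. It only approximates what torch.hub exposes, since it ignores callables that hubconf.py imports from other modules.

```python
import ast
import os


def list_hubconf_functions(repo_dir):
    """Return the public top-level function names defined in
    <repo_dir>/hubconf.py, without importing the file."""
    path = os.path.join(os.path.expanduser(repo_dir), "hubconf.py")
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)
    return sorted(
        node.name
        for node in tree.body
        if isinstance(node, ast.FunctionDef) and not node.name.startswith("_")
    )
```

If the name requested from torch.hub.load (for example dino_vitb16) is not in the returned list, the directory probably does not contain the DINO repository's hubconf.py.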

@kxwangzju

Hi, I am having the same problem. Have you fixed that?

@elenacliu
Author

Sorry, but I haven't fixed it.
