This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

How to load ijepa checkpoints? #50

Open
namrahrehman opened this issue Oct 10, 2023 · 6 comments

Comments

@namrahrehman

I am trying to use this model to classify CIFAR-10 in Google Colab. To study its layers, I cloned this repo and tried to load the model as follows:

import torch
from vision_transformer import vit_huge

# Initialize the ViT-H model with the specified patch size
model = vit_huge(patch_size=14, num_classes=1000)  # Adjust num_classes if needed

# Load the state dictionary from the file
state_dict = torch.load('/content/drive/MyDrive/IN1K-vit.h.14-300e.pth.tar')

# Load the state dictionary into the model
model.load_state_dict(state_dict)

# Print the layers/modules of the model for inspection
def print_model_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Module):
            module_name = prefix + "." + name if prefix else name
            print(module_name)
            print_model_layers(module, prefix=module_name)

print_model_layers(model)

but I get the following error:

`RuntimeError Traceback (most recent call last)
in <cell line: 6>()
4
5 # Load the state dictionary into the model
----> 6 model.load_state_dict(state_dict)
7
8 # Print the layers/modules of the model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
2039
2040 if len(error_msgs) > 0:
-> 2041 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2042 self.__class__.__name__, "\n\t".join(error_msgs)))
2043 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "blocks.4.norm1.weight", "blocks.4.norm1.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "bl...
Unexpected key(s) in state_dict: "encoder", "predictor", "opt", "scaler", "target_encoder", "epoch", "loss", "batch_size", "world_size", "lr".`

I don't understand which ViT from vision_transformer.py I am supposed to use for the checkpoint (IN1K-vit.h.14-300e.pth.tar), because using vit_huge gives the error above.
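The unexpected keys in the error (`encoder`, `predictor`, `target_encoder`, `opt`, ...) suggest the file is a training checkpoint that wraps several state dicts, not a bare state dict for one model. A minimal sketch for checking which entry actually fits the model, assuming the nested weights may carry a DDP `module.` prefix; `matching_entries` is a hypothetical helper, not part of the repo:

```python
from collections.abc import Mapping
from typing import List, Set


def matching_entries(ckpt: Mapping, model_keys: Set[str], prefix: str = "module.") -> List[str]:
    """Return the names of top-level checkpoint entries that look like
    a state dict for a model with the given parameter names.

    An entry matches when it is itself a mapping whose keys, after removing
    a leading `prefix` (e.g. the one DDP adds), cover every model key.
    """
    matches = []
    for name, value in ckpt.items():
        if not isinstance(value, Mapping):
            continue  # skip scalars such as epoch, loss, lr
        keys = {
            k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k
            for k in value
        }
        if model_keys <= keys:
            matches.append(name)
    return matches
```

In the Colab session above this could be called as `matching_entries(torch.load(path, map_location="cpu"), set(model.state_dict()))`. If it lists the encoder entries, `vit_huge` is probably the right architecture and only the loading step needs to change.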

@CUN-bjy

CUN-bjy commented Oct 17, 2023

It seems that the checkpoint was saved from a DDP-wrapped model (its keys carry a `module.` prefix) and bundles several components, but you tried to load the whole file into a bare encoder.
This may be the solution:

import torch

ckpt = torch.load(load_path, map_location=torch.device('cpu'))
pretrained_dict = ckpt['encoder']

# -- loading encoder: drop the DDP "module." prefix from each key
for k, v in pretrained_dict.items():
    encoder.state_dict()[k[len("module."):]].copy_(v)
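An alternative to copying tensor by tensor is to build a prefix-free dict and hand it to `load_state_dict`. A minimal sketch, assuming every key starts with the DDP `module.` prefix or has no prefix at all; `strip_prefix` is a hypothetical helper:

```python
from typing import Any, Dict, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a new state dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged, so this is
    also safe on checkpoints that were not saved from a DDP-wrapped model.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
```

Usage would be `encoder.load_state_dict(strip_prefix(ckpt['encoder']))`. Because `load_state_dict` is strict by default, it raises if an encoder parameter is missing from the checkpoint, which the copy loop would not notice. PyTorch also provides `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which strips a prefix in place.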

@namrahrehman
Author

Thank you so much @CUN-bjy, it worked!
However, I couldn't classify the images due to limited computing resources.
Thanks again for your help!

@lazarosgogos

Hello everyone.
Does anyone have insights into how to take a pretrained checkpoint and use it for a downstream task? That is, load the pretrained weights into a model, freeze them, and use the model to train a classifier on another dataset (e.g. CIFAR-10).

Also, what is the complete answer to the question posed above? For example, where does the encoder variable come from? A complete code snippet would be of great help.

I've figured out that the steps for loading the checkpoint are as follows:

  • Take the state_dict
  • Initialize the corresponding ViT (e.g. ViT-H with the init_model function from src.helper.py)
  • Initialize an optimizer with init_opt
  • Then? Which parts of the IJEPA architecture are needed to utilize the embeddings in some other task as described earlier?

This is for my own research as an undergrad.

Thank you in advance!
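On the "Then?" step: for a frozen-feature classifier, only one encoder's weights should be needed. In I-JEPA the predictor serves the pretraining objective, and `opt`, `scaler`, `epoch`, `lr`, etc. are training state, so `init_opt` should not be required. A minimal sketch, assuming the checkpoint layout shown in the first comment's error; the helper name and the default of using `target_encoder` (as suggested later in this thread) are assumptions:

```python
from typing import Any, Dict, List, Mapping, Tuple


def split_for_downstream(
    ckpt: Mapping[str, Any], use_target_encoder: bool = True, prefix: str = "module."
) -> Tuple[Dict[str, Any], List[str]]:
    """Split an I-JEPA-style training checkpoint into
    (backbone state dict, names of entries not needed downstream)."""
    key = "target_encoder" if use_target_encoder else "encoder"
    if key not in ckpt:
        raise KeyError(f"checkpoint has no {key!r} entry; found {sorted(ckpt)}")
    # Strip the DDP prefix so the keys match a plain (unwrapped) ViT.
    backbone = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt[key].items()
    }
    # Everything else (other encoder, predictor, optimizer state, counters) is ignored.
    ignored = [name for name in ckpt if name != key]
    return backbone, ignored
```

After that, one would build the matching ViT (e.g. `vit_huge(patch_size=14)` for this file), call `load_state_dict(backbone)` on it, and freeze it with `requires_grad_(False)`.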

@lange4531

It would be great to have a solution here if someone has managed to get this working!

@VimukthiRandika1997

VimukthiRandika1997 commented Apr 10, 2024

You can take the pretrained target encoder and fine-tune it on your custom dataset. However, fine-tuning would be costly given the size of the encoder: it has 32 blocks, and ViT-based models require a lot of data to be tuned for the task at hand. The GPU requirements are also higher. One alternative is to train an MLP (1 layer, 2 layers, ..., N layers) on top of the encoder for the task of interest.
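A minimal PyTorch sketch of the "MLP on top of the encoder" idea, assuming the encoder maps images to patch tokens of shape [B, N, D] and that mean-pooling the tokens is an acceptable image feature (the encoder here appears to have no class token, since none is listed among the missing keys in the first comment). The class name and the `hidden_dim` option are illustrative:

```python
from typing import Optional

import torch
import torch.nn as nn


class FrozenEncoderClassifier(nn.Module):
    """Frozen pretrained encoder + trainable head.

    With hidden_dim=None the head is a single linear layer (a linear probe);
    otherwise it is a 2-layer MLP.
    """

    def __init__(self, encoder: nn.Module, embed_dim: int, num_classes: int,
                 hidden_dim: Optional[int] = None):
        super().__init__()
        self.encoder = encoder
        self.encoder.requires_grad_(False)  # pretrained weights stay fixed
        self.encoder.eval()
        if hidden_dim is None:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, num_classes),
            )

    def train(self, mode: bool = True):
        super().train(mode)
        self.encoder.eval()  # keep the frozen encoder in eval mode (dropout off)
        return self

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            tokens = self.encoder(images)  # [B, N, D] patch tokens
        features = tokens.mean(dim=1)      # [B, D] mean-pooled image feature
        return self.head(features)         # [B, num_classes] logits
```

For the ViT-H/14 checkpoint discussed here, `embed_dim` would be 1280, and only the head needs an optimizer, e.g. `torch.optim.AdamW(clf.head.parameters(), lr=1e-3)` (the learning rate is a placeholder). CIFAR-10 images would presumably also have to be resized to the resolution the encoder was pretrained at.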

Possible downstream tasks include image similarity, classification, etc. Feature extraction is the main component, so you can use it almost anywhere!
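For the image-similarity use case, a minimal sketch under the same assumptions (the encoder returns [B, N, D] patch tokens, and their mean serves as the image embedding); the function names are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def embed(encoder: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Mean-pool patch tokens [B, N, D] into unit-length embeddings [B, D]."""
    encoder.eval()
    tokens = encoder(images)
    return F.normalize(tokens.mean(dim=1), dim=-1)


def most_similar(query: torch.Tensor, gallery: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Indices [Q, k] of the gallery rows most cosine-similar to each query row.

    Both inputs are expected to be unit-normalised, so the dot product
    equals the cosine similarity.
    """
    sims = query @ gallery.T  # [Q, G]
    return sims.topk(k, dim=-1).indices
```

Normalising once in `embed` keeps retrieval to a single matrix multiply, which matters when the gallery is large.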

@FalsoMoralista

FalsoMoralista commented Apr 11, 2024

I have developed fine-tuning code for I-JEPA here, closely based on ViT-MAE, in order to reproduce the experiments conducted here. Right now it seems to work, as the loss is decreasing, but I'm not getting much reduction in test error, so I am currently investigating that. If you need help, contact me on Discord (falsomoralista).
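The ViT-MAE fine-tuning recipe uses layer-wise learning-rate decay: blocks closer to the input get smaller learning rates than the head. Whether the code mentioned above does this is not stated; the following is a minimal sketch of the idea, assuming the parameter names of the I-JEPA ViT seen in the first comment's error (`pos_embed`, `patch_embed.*`, `blocks.<i>.*`). The helper names and the decay value are illustrative:

```python
from typing import Dict, Iterable


def layer_id_for_vit(name: str, num_layers: int) -> int:
    """Depth index of a ViT parameter: 0 for the embeddings, i + 1 for
    blocks.<i>, and num_layers for everything after the blocks (final norm, head)."""
    if name in ("cls_token", "pos_embed") or name.startswith("patch_embed"):
        return 0
    if name.startswith("blocks."):
        return int(name.split(".")[1]) + 1
    return num_layers


def layer_lr_scales(param_names: Iterable[str], num_blocks: int,
                    layer_decay: float = 0.75) -> Dict[str, float]:
    """Learning-rate multiplier per parameter: layer_decay ** (num_layers - layer_id)."""
    num_layers = num_blocks + 1
    return {
        name: layer_decay ** (num_layers - layer_id_for_vit(name, num_layers))
        for name in param_names
    }
```

Each scale would multiply the base learning rate in an optimizer parameter group, e.g. `{"params": [p], "lr": base_lr * scales[name]}` for each `name, p` in `model.named_parameters()`; for the 32-block ViT-H, `num_blocks` would be 32.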

Labels
None yet
Projects
None yet
Development

No branches or pull requests

6 participants