Simplify transfer learning by modifying get_model() #8631

Open
david-csnmedia opened this issue Sep 3, 2024 · 1 comment

Comments

@david-csnmedia

🚀 The feature

Currently, torchvision.models.get_model() does not let you build a model with a different number of classes while keeping the existing pre-trained backbone weights for certain model types (namely image classification models such as EfficientNet).

Could something like this be incorporated into get_model(), or could a separate function be added to support it?

model = torchvision.models.get_model(self.model_type, weights=self.weights_backbone)

# Replace the final layer of the classifier so its out_features matches num_classes.
# We have to do this after get_model() so we can retain the pre-trained weights while
# modifying the model architecture for our use case.

classifier_layer = model.classifier
last_layer_index = len(classifier_layer) - 1

original_linear_layer = classifier_layer[last_layer_index]

new_linear_layer = torch.nn.Linear(in_features=original_linear_layer.in_features, out_features=self.num_classes)
classifier_layer[last_layer_index] = new_linear_layer
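
The snippet above could be wrapped in a small helper. This is a minimal sketch, assuming the model exposes an `nn.Sequential` named `classifier` whose last entry is an `nn.Linear` (the layout the snippet relies on). `replace_classifier_head` and `TinyNet` are hypothetical names used only for illustration and are not part of torchvision. `TinyNet` stands in for a pre-trained model so the example runs without downloading weights.

```python
import torch
from torch import nn


def replace_classifier_head(model: nn.Module, num_classes: int) -> nn.Module:
    """Swap the last Linear layer of model.classifier for a freshly initialised one.

    Hypothetical helper: assumes model.classifier is an nn.Sequential
    ending in an nn.Linear.
    """
    classifier = model.classifier
    old_head = classifier[-1]
    if not isinstance(old_head, nn.Linear):
        raise TypeError(
            f"expected nn.Linear as the last classifier layer, got {type(old_head).__name__}"
        )
    # Keep in_features so the new head still fits the backbone output.
    classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    return model


class TinyNet(nn.Module):
    """Stand-in for a pre-trained model with a 1000-class head."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(8, 1000))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = TinyNet()
# Snapshot the backbone weights to confirm the head swap leaves them untouched.
backbone_before = {k: v.clone() for k, v in model.features.state_dict().items()}

model = replace_classifier_head(model, num_classes=5)
logits = model(torch.randn(2, 3, 32, 32))
```

With torchvision, the same helper would be called on the result of `torchvision.models.get_model("efficientnet_b0", weights="DEFAULT")`, so the pre-trained weights are loaded before the head is replaced. Only the new `nn.Linear` starts from random initialisation.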

Motivation, pitch

Raising an error about a mismatch with the backbone weights does not point users toward a workable approach.

Alternatives

No response

Additional context

No response

@NicolasHug
Member

Hi @david-csnmedia,

That kind of model surgery is probably too specific to each model to be reliably implemented within get_model().
Note that some model builders accept a num_classes argument.
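
To illustrate why this is model-specific: the attribute that holds the classification head differs between architectures (for example, torchvision's ResNet uses `fc` while EfficientNet uses a `classifier` Sequential). A rough sketch of locating the head, assuming only those two layouts; `find_head` and `HEAD_ATTRIBUTES` are hypothetical names, and the bare `nn.Module` objects are stand-ins for real models:

```python
from typing import Tuple

from torch import nn

# Assumption: only these two layouts are handled. Other architectures
# keep their head elsewhere and would need their own entry.
HEAD_ATTRIBUTES = ("fc", "classifier")


def find_head(model: nn.Module) -> Tuple[str, nn.Linear]:
    """Return the attribute name and final Linear layer of the model's head."""
    for name in HEAD_ATTRIBUTES:
        head = getattr(model, name, None)
        # A Sequential head is expected to end in the Linear layer.
        if isinstance(head, nn.Sequential) and len(head) > 0:
            head = head[-1]
        if isinstance(head, nn.Linear):
            return name, head
    raise ValueError(f"no supported head found on {type(model).__name__}")


# Stand-ins mimicking the two layouts.
resnet_like = nn.Module()
resnet_like.fc = nn.Linear(16, 1000)

effnet_like = nn.Module()
effnet_like.classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(8, 1000))

resnet_name, resnet_head = find_head(resnet_like)
effnet_name, effnet_head = find_head(effnet_like)
```

A model with neither layout raises `ValueError`, which is the kind of per-architecture gap a generic implementation inside get_model() would have to cover.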
