🚀 The feature
Currently, torchvision.models.get_model() doesn't let you build a model architecture with a different number of classes while keeping the existing pre-trained backbone weights for certain model types (namely image classification models like EfficientNet).
Could something like this be incorporated into the get_model() method, or could another method be created to accommodate?
model = torchvision.models.get_model(self.model_type, weights=self.weights_backbone)

# Fix the in/out features of the final layer of the classifier to match num_classes.
# We have to do this after get_model() so we can retain the pre-trained weights, but
# modify the model architecture for our use case.
classifier_layer = model.classifier
last_layer_index = len(classifier_layer) - 1
original_linear_layer = classifier_layer[last_layer_index]
new_linear_layer = torch.nn.Linear(in_features=original_linear_layer.in_features, out_features=self.num_classes)
classifier_layer[last_layer_index] = new_linear_layer
Motivation, pitch
Raising an error about the backbone weights having a mismatch guides users in a direction that isn't helpful.
Alternatives
No response
Additional context
No response
That kind of model surgery is probably too specific to each model to be implemented reliably within get_model().
Note that some model builders allow num_classes to be passed.