Add TorchHubModel model wrapper (#721)
nkaenzig authored Nov 29, 2024
1 parent a69f9c2 commit 68eddbe
Showing 8 changed files with 236 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/eva/core/models/wrappers/__init__.py
@@ -2,12 +2,14 @@

from eva.core.models.wrappers.base import BaseModel
from eva.core.models.wrappers.from_function import ModelFromFunction
from eva.core.models.wrappers.from_torchhub import TorchHubModel
from eva.core.models.wrappers.huggingface import HuggingFaceModel
from eva.core.models.wrappers.onnx import ONNXModel

__all__ = [
"BaseModel",
"ModelFromFunction",
"HuggingFaceModel",
"ModelFromFunction",
"ONNXModel",
"TorchHubModel",
]
87 changes: 87 additions & 0 deletions src/eva/core/models/wrappers/from_torchhub.py
@@ -0,0 +1,87 @@
"""Model wrapper for torch.hub models."""

from typing import Any, Callable, Dict, Tuple

import torch
import torch.nn as nn
from typing_extensions import override

from eva.core.models import wrappers
from eva.core.models.wrappers import _utils


class TorchHubModel(wrappers.BaseModel):
"""Model wrapper for `torch.hub` models."""

def __init__(
self,
model_name: str,
repo_or_dir: str,
pretrained: bool = True,
checkpoint_path: str = "",
out_indices: int | Tuple[int, ...] | None = None,
norm: bool = False,
trust_repo: bool = True,
model_kwargs: Dict[str, Any] | None = None,
tensor_transforms: Callable | None = None,
) -> None:
"""Initializes the encoder.
Args:
model_name: Name of model to instantiate.
repo_or_dir: The torch.hub repository or local directory to load the model from.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
checkpoint_path: Path of checkpoint to load.
out_indices: Returns last n blocks if `int`, all if `None`, select
matching indices if sequence.
norm: Whether to apply a norm layer to all intermediate features. Only
used when `out_indices` is not `None`.
trust_repo: If set to `False`, a prompt will ask the user whether the
repo should be trusted.
model_kwargs: Extra model arguments.
tensor_transforms: The transforms to apply to the output tensor
produced by the model.
"""
super().__init__(tensor_transforms=tensor_transforms)

self._model_name = model_name
self._repo_or_dir = repo_or_dir
self._pretrained = pretrained
self._checkpoint_path = checkpoint_path
self._out_indices = out_indices
self._norm = norm
self._trust_repo = trust_repo
self._model_kwargs = model_kwargs or {}

self.load_model()

@override
def load_model(self) -> None:
"""Builds and loads the torch.hub model."""
self._model: nn.Module = torch.hub.load(
repo_or_dir=self._repo_or_dir,
model=self._model_name,
trust_repo=self._trust_repo,
pretrained=self._pretrained,
**self._model_kwargs,
) # type: ignore

if self._checkpoint_path:
_utils.load_model_weights(self._model, self._checkpoint_path)

TorchHubModel.__name__ = self._model_name

@override
def model_forward(self, tensor: torch.Tensor) -> torch.Tensor:
if self._out_indices is not None:
if not hasattr(self._model, "get_intermediate_layers"):
raise ValueError(
"Only models with `get_intermediate_layers` are supported "
"when using `out_indices`."
)

return self._model.get_intermediate_layers(
tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm
)

return self._model(tensor)
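
For reference, a minimal usage sketch of the new wrapper (not part of the commit; it assumes network access to torch.hub and uses the dinov2_vits14 entry point that the tests below also exercise):

import torch

from eva.core.models import wrappers

# Wrap a DINOv2 ViT-S/14 backbone from torch.hub and request the reshaped patch
# embeddings of the last block via out_indices.
model = wrappers.TorchHubModel(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2:main",
    pretrained=False,
    out_indices=1,
)
features = model(torch.randn(2, 3, 224, 224))
print(features[0].shape)  # expected: torch.Size([2, 384, 16, 16])
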
4 changes: 2 additions & 2 deletions src/eva/vision/models/networks/backbones/__init__.py
@@ -1,6 +1,6 @@
"""Vision Model Backbones API."""

from eva.vision.models.networks.backbones import pathology, timm, universal
from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model

__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"]
__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"]
5 changes: 5 additions & 0 deletions src/eva/vision/models/networks/backbones/torchhub/__init__.py
@@ -0,0 +1,5 @@
"""torch.hub backbones API."""

from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model

__all__ = ["torch_hub_model"]
61 changes: 61 additions & 0 deletions src/eva/vision/models/networks/backbones/torchhub/backbones.py
@@ -0,0 +1,61 @@
"""torch.hub backbones."""

import functools
from typing import Tuple

import torch
from loguru import logger
from torch import nn

from eva.core.models import wrappers
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry

HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"]
"""List of torch.hub repositories for which to add the models to the registry."""


def torch_hub_model(
model_name: str,
repo_or_dir: str,
checkpoint_path: str | None = None,
pretrained: bool = False,
out_indices: int | Tuple[int, ...] | None = None,
**kwargs,
) -> nn.Module:
"""Initializes any ViT model from torch.hub with weights from a specified checkpoint.
Args:
model_name: The name of the model to load.
repo_or_dir: The torch.hub repository or local directory to load the model from.
checkpoint_path: The path to the checkpoint file.
pretrained: If set to `True`, load pretrained model weights if available.
out_indices: Whether and which multi-level patch embeddings to return.
**kwargs: Additional arguments to pass to the model.

Returns:
The ViT model instance.
"""
logger.info(
f"Loading torch.hub model {model_name} from {repo_or_dir}"
+ (f"using checkpoint {checkpoint_path}" if checkpoint_path else "")
)

return wrappers.TorchHubModel(
model_name=model_name,
repo_or_dir=repo_or_dir,
pretrained=pretrained,
checkpoint_path=checkpoint_path or "",
out_indices=out_indices,
model_kwargs=kwargs,
)


BackboneModelRegistry._registry.update(
{
f"torchhub/{repo}:{model_name}": functools.partial(
torch_hub_model, model_name=model_name, repo_or_dir=repo
)
for repo in HUB_REPOS
for model_name in torch.hub.list(repo, verbose=False)
}
)
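
A hedged usage sketch for the helper and the registry entries it creates (not part of the commit; importing the module triggers torch.hub.list, so it assumes network access, and the lookup below goes through the private _registry mapping exactly as the module itself does):

from eva.vision.models.networks.backbones import torchhub
from eva.vision.models.networks.backbones.registry import BackboneModelRegistry

# Build a backbone directly via the helper.
backbone = torchhub.torch_hub_model(
    model_name="dinov2_vits14",
    repo_or_dir="facebookresearch/dinov2:main",
    out_indices=1,
)

# The registry keys follow the pattern "torchhub/{repo}:{model_name}".
builder = BackboneModelRegistry._registry["torchhub/facebookresearch/dinov2:main:dinov2_vits14"]
backbone = builder()
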
@@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc
"""
if isinstance(features, torch.Tensor):
features = [features]
if not isinstance(features, list) or features[0].ndim != 4:
if not isinstance(features, (list, tuple)) or features[0].ndim != 4:
raise ValueError(
"Input features should be a list of four (4) dimensional inputs of "
"shape (batch_size, hidden_size, n_patches_height, n_patches_width)."
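
Context for this one-line change, as a short sketch (not part of the commit): when out_indices is set, TorchHubModel.model_forward returns the output of get_intermediate_layers, which for the DINOv2 models is a tuple of 4D feature maps rather than a list, so the decoder's input check now accepts both:

import torch

# Shape/type contract the relaxed check covers: a sequence of
# (batch_size, hidden_size, n_patches_height, n_patches_width) feature maps.
features = tuple(torch.randn(2, 384, 16, 16) for _ in range(3))
assert isinstance(features, (list, tuple)) and features[0].ndim == 4
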
2 changes: 1 addition & 1 deletion src/eva/vision/models/wrappers/__init__.py
@@ -3,4 +3,4 @@
from eva.vision.models.wrappers.from_registry import ModelFromRegistry
from eva.vision.models.wrappers.from_timm import TimmModel

__all__ = ["TimmModel", "ModelFromRegistry"]
__all__ = ["ModelFromRegistry", "TimmModel"]
76 changes: 76 additions & 0 deletions tests/eva/core/models/wrappers/test_from_torchub.py
@@ -0,0 +1,76 @@
"""TorchHubModel tests."""

from typing import Any, Dict, Tuple

import pytest
import torch

from eva.core.models import wrappers


@pytest.mark.parametrize(
"model_name, repo_or_dir, out_indices, model_kwargs, "
"input_tensor, expected_len, expected_shape",
[
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
None,
None,
torch.Tensor(2, 3, 224, 224),
None,
torch.Size([2, 384]),
),
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
1,
None,
torch.Tensor(2, 3, 224, 224),
1,
torch.Size([2, 384, 16, 16]),
),
(
"dinov2_vits14",
"facebookresearch/dinov2:main",
3,
None,
torch.Tensor(2, 3, 224, 224),
3,
torch.Size([2, 384, 16, 16]),
),
],
)
def test_torchhub_model(
torchhub_model: wrappers.TorchHubModel,
input_tensor: torch.Tensor,
expected_len: int | None,
expected_shape: torch.Size,
) -> None:
"""Tests the torch.hub model wrapper."""
outputs = torchhub_model(input_tensor)
if torchhub_model._out_indices is not None:
assert isinstance(outputs, list) or isinstance(outputs, tuple)
assert len(outputs) == expected_len
assert isinstance(outputs[0], torch.Tensor)
assert outputs[0].shape == expected_shape
else:
assert isinstance(outputs, torch.Tensor)
assert outputs.shape == expected_shape


@pytest.fixture(scope="function")
def torchhub_model(
model_name: str,
repo_or_dir: str,
out_indices: int | Tuple[int, ...] | None,
model_kwargs: Dict[str, Any] | None,
) -> wrappers.TorchHubModel:
"""TorchHubModel fixture."""
return wrappers.TorchHubModel(
model_name=model_name,
repo_or_dir=repo_or_dir,
out_indices=out_indices,
model_kwargs=model_kwargs,
pretrained=False,
)
