From 68eddbe9f5a778010f4d10cf3df95dcd6404f05f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20K=C3=A4nzig?= <36882833+nkaenzig@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:17:54 +0100 Subject: [PATCH] Add `TorchHubModel` model wrapper (#721) --- src/eva/core/models/wrappers/__init__.py | 4 +- src/eva/core/models/wrappers/from_torchhub.py | 87 +++++++++++++++++++ .../models/networks/backbones/__init__.py | 4 +- .../networks/backbones/torchhub/__init__.py | 5 ++ .../networks/backbones/torchhub/backbones.py | 61 +++++++++++++ .../decoders/segmentation/decoder2d.py | 2 +- src/eva/vision/models/wrappers/__init__.py | 2 +- .../core/models/wrappers/test_from_torchub.py | 76 ++++++++++++++++ 8 files changed, 236 insertions(+), 5 deletions(-) create mode 100644 src/eva/core/models/wrappers/from_torchhub.py create mode 100644 src/eva/vision/models/networks/backbones/torchhub/__init__.py create mode 100644 src/eva/vision/models/networks/backbones/torchhub/backbones.py create mode 100644 tests/eva/core/models/wrappers/test_from_torchub.py diff --git a/src/eva/core/models/wrappers/__init__.py b/src/eva/core/models/wrappers/__init__.py index 95ab6101..979577bd 100644 --- a/src/eva/core/models/wrappers/__init__.py +++ b/src/eva/core/models/wrappers/__init__.py @@ -2,12 +2,14 @@ from eva.core.models.wrappers.base import BaseModel from eva.core.models.wrappers.from_function import ModelFromFunction +from eva.core.models.wrappers.from_torchhub import TorchHubModel from eva.core.models.wrappers.huggingface import HuggingFaceModel from eva.core.models.wrappers.onnx import ONNXModel __all__ = [ "BaseModel", - "ModelFromFunction", "HuggingFaceModel", + "ModelFromFunction", "ONNXModel", + "TorchHubModel", ] diff --git a/src/eva/core/models/wrappers/from_torchhub.py b/src/eva/core/models/wrappers/from_torchhub.py new file mode 100644 index 00000000..cb424d01 --- /dev/null +++ b/src/eva/core/models/wrappers/from_torchhub.py @@ -0,0 +1,87 @@ +"""Model wrapper for torch.hub models.""" + +from typing import Any, Callable, Dict, Tuple + +import torch +import torch.nn as nn +from typing_extensions import override + +from eva.core.models import wrappers +from eva.core.models.wrappers import _utils + + +class TorchHubModel(wrappers.BaseModel): + """Model wrapper for `torch.hub` models.""" + + def __init__( + self, + model_name: str, + repo_or_dir: str, + pretrained: bool = True, + checkpoint_path: str = "", + out_indices: int | Tuple[int, ...] | None = None, + norm: bool = False, + trust_repo: bool = True, + model_kwargs: Dict[str, Any] | None = None, + tensor_transforms: Callable | None = None, + ) -> None: + """Initializes the encoder. + + Args: + model_name: Name of model to instantiate. + repo_or_dir: The torch.hub repository or local directory to load the model from. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + checkpoint_path: Path of checkpoint to load. + out_indices: Returns last n blocks if `int`, all if `None`, select + matching indices if sequence. + norm: Wether to apply norm layer to all intermediate features. Only + used when `out_indices` is not `None`. + trust_repo: If set to `False`, a prompt will ask the user whether the + repo should be trusted. + model_kwargs: Extra model arguments. + tensor_transforms: The transforms to apply to the output tensor + produced by the model. + """ + super().__init__(tensor_transforms=tensor_transforms) + + self._model_name = model_name + self._repo_or_dir = repo_or_dir + self._pretrained = pretrained + self._checkpoint_path = checkpoint_path + self._out_indices = out_indices + self._norm = norm + self._trust_repo = trust_repo + self._model_kwargs = model_kwargs or {} + + self.load_model() + + @override + def load_model(self) -> None: + """Builds and loads the torch.hub model.""" + self._model: nn.Module = torch.hub.load( + repo_or_dir=self._repo_or_dir, + model=self._model_name, + trust_repo=self._trust_repo, + pretrained=self._pretrained, + **self._model_kwargs, + ) # type: ignore + + if self._checkpoint_path: + _utils.load_model_weights(self._model, self._checkpoint_path) + + TorchHubModel.__name__ = self._model_name + + @override + def model_forward(self, tensor: torch.Tensor) -> torch.Tensor: + if self._out_indices is not None: + if not hasattr(self._model, "get_intermediate_layers"): + raise ValueError( + "Only models with `get_intermediate_layers` are supported " + "when using `out_indices`." + ) + + return self._model.get_intermediate_layers( + tensor, self._out_indices, reshape=True, return_class_token=False, norm=self._norm + ) + + return self._model(tensor) diff --git a/src/eva/vision/models/networks/backbones/__init__.py b/src/eva/vision/models/networks/backbones/__init__.py index 0fdf2963..1ef7bc85 100644 --- a/src/eva/vision/models/networks/backbones/__init__.py +++ b/src/eva/vision/models/networks/backbones/__init__.py @@ -1,6 +1,6 @@ """Vision Model Backbones API.""" -from eva.vision.models.networks.backbones import pathology, timm, universal +from eva.vision.models.networks.backbones import pathology, timm, torchhub, universal from eva.vision.models.networks.backbones.registry import BackboneModelRegistry, register_model -__all__ = ["pathology", "timm", "universal", "BackboneModelRegistry", "register_model"] +__all__ = ["pathology", "timm", "torchhub", "universal", "BackboneModelRegistry", "register_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/__init__.py b/src/eva/vision/models/networks/backbones/torchhub/__init__.py new file mode 100644 index 00000000..6acd9797 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/__init__.py @@ -0,0 +1,5 @@ +"""torch.hub backbones API.""" + +from eva.vision.models.networks.backbones.torchhub.backbones import torch_hub_model + +__all__ = ["torch_hub_model"] diff --git a/src/eva/vision/models/networks/backbones/torchhub/backbones.py b/src/eva/vision/models/networks/backbones/torchhub/backbones.py new file mode 100644 index 00000000..d1503a80 --- /dev/null +++ b/src/eva/vision/models/networks/backbones/torchhub/backbones.py @@ -0,0 +1,61 @@ +"""torch.hub backbones.""" + +import functools +from typing import Tuple + +import torch +from loguru import logger +from torch import nn + +from eva.core.models import wrappers +from eva.vision.models.networks.backbones.registry import BackboneModelRegistry + +HUB_REPOS = ["facebookresearch/dinov2:main", "kaiko-ai/towards_large_pathology_fms"] +"""List of torch.hub repositories for which to add the models to the registry.""" + + +def torch_hub_model( + model_name: str, + repo_or_dir: str, + checkpoint_path: str | None = None, + pretrained: bool = False, + out_indices: int | Tuple[int, ...] | None = None, + **kwargs, +) -> nn.Module: + """Initializes any ViT model from torch.hub with weights from a specified checkpoint. + + Args: + model_name: The name of the model to load. + repo_or_dir: The torch.hub repository or local directory to load the model from. + checkpoint_path: The path to the checkpoint file. + pretrained: If set to `True`, load pretrained model weights if available. + out_indices: Whether and which multi-level patch embeddings to return. + **kwargs: Additional arguments to pass to the model + + Returns: + The VIT model instance. + """ + logger.info( + f"Loading torch.hub model {model_name} from {repo_or_dir}" + + (f"using checkpoint {checkpoint_path}" if checkpoint_path else "") + ) + + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + pretrained=pretrained, + checkpoint_path=checkpoint_path or "", + out_indices=out_indices, + model_kwargs=kwargs, + ) + + +BackboneModelRegistry._registry.update( + { + f"torchhub/{repo}:{model_name}": functools.partial( + torch_hub_model, model_name=model_name, repo_or_dir=repo + ) + for repo in HUB_REPOS + for model_name in torch.hub.list(repo, verbose=False) + } +) diff --git a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py index c43b351c..ce242713 100644 --- a/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py +++ b/src/eva/vision/models/networks/decoders/segmentation/decoder2d.py @@ -52,7 +52,7 @@ def _forward_features(self, features: torch.Tensor | List[torch.Tensor]) -> torc """ if isinstance(features, torch.Tensor): features = [features] - if not isinstance(features, list) or features[0].ndim != 4: + if not isinstance(features, (list, tuple)) or features[0].ndim != 4: raise ValueError( "Input features should be a list of four (4) dimensional inputs of " "shape (batch_size, hidden_size, n_patches_height, n_patches_width)." diff --git a/src/eva/vision/models/wrappers/__init__.py b/src/eva/vision/models/wrappers/__init__.py index 14d63b68..d2f84de4 100644 --- a/src/eva/vision/models/wrappers/__init__.py +++ b/src/eva/vision/models/wrappers/__init__.py @@ -3,4 +3,4 @@ from eva.vision.models.wrappers.from_registry import ModelFromRegistry from eva.vision.models.wrappers.from_timm import TimmModel -__all__ = ["TimmModel", "ModelFromRegistry"] +__all__ = ["ModelFromRegistry", "TimmModel"] diff --git a/tests/eva/core/models/wrappers/test_from_torchub.py b/tests/eva/core/models/wrappers/test_from_torchub.py new file mode 100644 index 00000000..bf275234 --- /dev/null +++ b/tests/eva/core/models/wrappers/test_from_torchub.py @@ -0,0 +1,76 @@ +"""TorchHubModel tests.""" + +from typing import Any, Dict, Tuple + +import pytest +import torch + +from eva.core.models import wrappers + + +@pytest.mark.parametrize( + "model_name, repo_or_dir, out_indices, model_kwargs, " + "input_tensor, expected_len, expected_shape", + [ + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + None, + None, + torch.Tensor(2, 3, 224, 224), + None, + torch.Size([2, 384]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 1, + None, + torch.Tensor(2, 3, 224, 224), + 1, + torch.Size([2, 384, 16, 16]), + ), + ( + "dinov2_vits14", + "facebookresearch/dinov2:main", + 3, + None, + torch.Tensor(2, 3, 224, 224), + 3, + torch.Size([2, 384, 16, 16]), + ), + ], +) +def test_torchhub_model( + torchhub_model: wrappers.TorchHubModel, + input_tensor: torch.Tensor, + expected_len: int | None, + expected_shape: torch.Size, +) -> None: + """Tests the torch.hub model wrapper.""" + outputs = torchhub_model(input_tensor) + if torchhub_model._out_indices is not None: + assert isinstance(outputs, list) or isinstance(outputs, tuple) + assert len(outputs) == expected_len + assert isinstance(outputs[0], torch.Tensor) + assert outputs[0].shape == expected_shape + else: + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == expected_shape + + +@pytest.fixture(scope="function") +def torchhub_model( + model_name: str, + repo_or_dir: str, + out_indices: int | Tuple[int, ...] | None, + model_kwargs: Dict[str, Any] | None, +) -> wrappers.TorchHubModel: + """TorchHubModel fixture.""" + return wrappers.TorchHubModel( + model_name=model_name, + repo_or_dir=repo_or_dir, + out_indices=out_indices, + model_kwargs=model_kwargs, + pretrained=False, + )