Skip to content

Commit

Permalink
Add Virchow2 model to registry (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored Oct 23, 2024
1 parent 095ef7b commit 19d0043
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 32 deletions.
17 changes: 16 additions & 1 deletion docs/user-guide/advanced/replicate_evaluations.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ IN_FEATURES=1024 \
eva predict_fit --config configs/vision/pathology/offline/<task>.yaml
```

### Virchow2 (paige.ai) - DINOv2 ViT-H14 (3.1M Slides) [[8]](#references)
To evaluate [paige.ai's](https://www.paige.ai/) FM with DINOv2 ViT-H14 backbone, pretrained on
a proprietary dataset of 3.1M slides, available for download on
[HuggingFace](https://huggingface.co/paige-ai/Virchow2), run:

```
MODEL_NAME=paige/virchow2 \
NORMALIZE_MEAN="[0.485,0.456,0.406]" \
NORMALIZE_STD="[0.229,0.224,0.225]" \
IN_FEATURES=1280 \
eva predict_fit --config configs/vision/pathology/offline/<task>.yaml
```


## References

Expand All @@ -219,4 +232,6 @@ eva predict_fit --config configs/vision/pathology/offline/<task>.yaml

[6]: Xu, Hanwen, et al. "A whole-slide foundation model for digital pathology from real-world data." Nature (2024): 1-8.

[7]: Nechaev, Dmitry, Alexey Pchelnikov, and Ekaterina Ivanova. "Hibou: A Family of Foundational Vision Transformers for Pathology." arXiv preprint arXiv:2406.05074 (2024).
[7]: Nechaev, Dmitry, Alexey Pchelnikov, and Ekaterina Ivanova. "Hibou: A Family of Foundational Vision Transformers for Pathology." arXiv preprint arXiv:2406.05074 (2024).

[8]: Zimmermann, Eric, et al. "Virchow 2: Scaling Self-Supervised Mixed Magnification Models in Pathology." arXiv preprint arXiv:2408.00738 (2024).
26 changes: 17 additions & 9 deletions src/eva/core/models/transforms/extract_cls_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@
class ExtractCLSFeatures:
"""Extracts the CLS token from a ViT model output."""

def __init__(self, cls_index: int = 0) -> None:
def __init__(
self, cls_index: int = 0, num_register_tokens: int = 0, include_patch_tokens: bool = False
) -> None:
"""Initializes the transformation.
Args:
cls_index: The index of the CLS token in the output tensor.
num_register_tokens: The number of register tokens in the model output.
include_patch_tokens: Whether to concat the mean aggregated patch tokens with
the cls token.
"""
self._cls_index = cls_index
self._num_register_tokens = num_register_tokens
self._include_patch_tokens = include_patch_tokens

def __call__(
self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling
Expand All @@ -23,11 +30,12 @@ def __call__(
Args:
tensor: The tensor representing the model output.
"""
if isinstance(tensor, torch.Tensor):
transformed_tensor = tensor[:, self._cls_index, :]
elif isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
transformed_tensor = tensor.last_hidden_state[:, self._cls_index, :]
else:
raise ValueError(f"Unsupported type {type(tensor)}")

return transformed_tensor
if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
tensor = tensor.last_hidden_state

cls_token = tensor[:, self._cls_index, :]
if self._include_patch_tokens:
patch_tokens = tensor[:, 1 + self._num_register_tokens :, :]
return torch.cat([cls_token, patch_tokens.mean(1)], dim=-1)

return cls_token
34 changes: 23 additions & 11 deletions src/eva/core/models/transforms/extract_patch_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,23 @@
class ExtractPatchFeatures:
"""Extracts the patch features from a ViT model output."""

def __init__(self, ignore_remaining_dims: bool = False) -> None:
def __init__(
self,
has_cls_token: bool = True,
num_register_tokens: int = 0,
ignore_remaining_dims: bool = False,
) -> None:
"""Initializes the transformation.
Args:
has_cls_token: If set to `True`, the model output is expected to have
a classification token.
num_register_tokens: The number of register tokens in the model output.
ignore_remaining_dims: If set to `True`, ignore the remaining dimensions
of the patch grid if it is not a square number.
"""
self._has_cls_token = has_cls_token
self._num_register_tokens = num_register_tokens
self._ignore_remaining_dims = ignore_remaining_dims

def __call__(
Expand All @@ -31,17 +41,19 @@ def __call__(
A tensor (batch_size, hidden_size, n_patches_height, n_patches_width)
representing the model output.
"""
num_skip = int(self._has_cls_token) + self._num_register_tokens
if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling):
features = tensor.last_hidden_state[:, 1:, :].permute(0, 2, 1)
batch_size, hidden_size, patch_grid = features.shape
height = width = int(math.sqrt(patch_grid))
if height * width != patch_grid:
if self._ignore_remaining_dims:
features = features[:, :, -height * width :]
else:
raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
patch_embeddings = features.view(batch_size, hidden_size, height, width)
features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1)
else:
raise ValueError(f"Unsupported type {type(tensor)}")
features = tensor[:, num_skip:, :].permute(0, 2, 1)

batch_size, hidden_size, patch_grid = features.shape
height = width = int(math.sqrt(patch_grid))
if height * width != patch_grid:
if self._ignore_remaining_dims:
features = features[:, :, -height * width :]
else:
raise ValueError(f"Patch grid size must be a square number {patch_grid}.")
patch_embeddings = features.view(batch_size, hidden_size, height, width)

return [patch_embeddings]
12 changes: 12 additions & 0 deletions src/eva/vision/models/networks/backbones/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Utilities for backbone networks."""

import os
from typing import Any, Dict, Tuple

import huggingface_hub
from torch import nn

from eva import models
Expand Down Expand Up @@ -37,3 +39,13 @@ def load_hugingface_model(
tensor_transforms=tensor_transforms,
model_kwargs=model_kwargs,
)


def huggingface_login(hf_token: str | None = None):
token = hf_token or os.environ.get("HF_TOKEN")
if not token:
raise ValueError(
"Please provide a HuggingFace token to download the model. "
"You can either pass it as an argument or set the env variable HF_TOKEN."
)
huggingface_hub.login(token=token)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from eva.vision.models.networks.backbones.pathology.lunit import lunit_vits8, lunit_vits16
from eva.vision.models.networks.backbones.pathology.mahmood import mahmood_uni
from eva.vision.models.networks.backbones.pathology.owkin import owkin_phikon
from eva.vision.models.networks.backbones.pathology.paige import paige_virchow2

__all__ = [
"kaiko_vitb16",
Expand All @@ -28,4 +29,5 @@
"prov_gigapath",
"histai_hibou_b",
"histai_hibou_l",
"paige_virchow2",
]
10 changes: 8 additions & 2 deletions src/eva/vision/models/networks/backbones/pathology/histai.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
"""Initializes the hibou-B pathology FM by hist.ai (https://huggingface.co/histai/hibou-B).
Uses a customized implementation of the DINOv2 architecture from the transformers
library to add support for registers, which requires the trust_remote_code=True flag.
Args:
out_indices: Whether and which multi-level patch embeddings to return.
Currently only out_indices=1 is supported.
Expand All @@ -23,14 +26,17 @@ def histai_hibou_b(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
model_name="histai/hibou-B",
out_indices=out_indices,
model_kwargs={"trust_remote_code": True},
transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
)


@register_model("pathology/histai_hibou_l")
def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Module:
"""Initializes the hibou-L pathology FM by hist.ai (https://huggingface.co/histai/hibou-L).
Uses a customized implementation of the DINOv2 architecture from the transformers
library to add support for registers, which requires the trust_remote_code=True flag.
Args:
out_indices: Whether and which multi-level patch embeddings to return.
Currently only out_indices=1 is supported.
Expand All @@ -42,5 +48,5 @@ def histai_hibou_l(out_indices: int | Tuple[int, ...] | None = None) -> nn.Modul
model_name="histai/hibou-L",
out_indices=out_indices,
model_kwargs={"trust_remote_code": True},
transform_args={"ignore_remaining_dims": True} if out_indices is not None else None,
transform_args={"num_register_tokens": 4} if out_indices is not None else None,
)
11 changes: 2 additions & 9 deletions src/eva/vision/models/networks/backbones/pathology/mahmood.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch import nn

from eva.vision.models import wrappers
from eva.vision.models.networks.backbones import _utils
from eva.vision.models.networks.backbones.registry import register_model


Expand All @@ -31,19 +32,11 @@ def mahmood_uni(
Returns:
The model instance.
"""
token = hf_token or os.environ.get("HF_TOKEN")
if not token:
raise ValueError(
"Please provide a HuggingFace token to download the model. "
"You can either pass it as an argument or set the env variable HF_TOKEN."
)

checkpoint_path = os.path.join(download_dir, "pytorch_model.bin")

if not os.path.exists(checkpoint_path):
logger.info(f"Downloading the model checkpoint to {download_dir} ...")
os.makedirs(download_dir, exist_ok=True)
huggingface_hub.login(token=token)
_utils.huggingface_login(hf_token)
huggingface_hub.hf_hub_download(
"MahmoodLab/UNI",
filename="pytorch_model.bin",
Expand Down
51 changes: 51 additions & 0 deletions src/eva/vision/models/networks/backbones/pathology/paige.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Pathology FMs from paige.ai.
Source: https://huggingface.co/paige-ai/
"""

from typing import Tuple

import timm
import torch.nn as nn

from eva.core.models import transforms
from eva.vision.models import wrappers
from eva.vision.models.networks.backbones import _utils
from eva.vision.models.networks.backbones.registry import register_model


@register_model("paige/virchow2")
def paige_virchow2(
    dynamic_img_size: bool = True,
    out_indices: int | Tuple[int, ...] | None = None,
    hf_token: str | None = None,
    include_patch_tokens: bool = False,
) -> nn.Module:
    """Initializes the Virchow2 pathology FM by paige.ai.

    Args:
        dynamic_img_size: Support different input image sizes by allowing to change
            the grid size (interpolate absolute and/or RoPE position embeddings)
            in the forward pass.
        out_indices: Whether and which multi-level patch embeddings to return.
            If `None`, the CLS-token transform is applied to the model output.
        hf_token: HuggingFace token to download the model.
        include_patch_tokens: Whether to concatenate the mean aggregated patch
            tokens with the cls token.

    Returns:
        The model instance.
    """
    _utils.huggingface_login(hf_token)
    return wrappers.TimmModel(
        model_name="hf-hub:paige-ai/Virchow2",
        out_indices=out_indices,
        pretrained=True,
        model_kwargs={
            "dynamic_img_size": dynamic_img_size,
            "mlp_layer": timm.layers.SwiGLUPacked,
            "act_layer": nn.SiLU,
        },
        tensor_transforms=(
            transforms.ExtractCLSFeatures(
                include_patch_tokens=include_patch_tokens,
                # Per the Virchow2 model card, output tokens 1-4 are register
                # tokens; they must be skipped so the patch-token mean does not
                # include them.
                num_register_tokens=4,
            )
            if out_indices is None
            else None
        ),
    )
5 changes: 5 additions & 0 deletions tests/eva/core/models/wrappers/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
[
("hf-internal-testing/tiny-random-ViTModel", None, (16, 226, 32)),
("hf-internal-testing/tiny-random-ViTModel", transforms.ExtractCLSFeatures(), (16, 32)),
(
"hf-internal-testing/tiny-random-ViTModel",
transforms.ExtractCLSFeatures(include_patch_tokens=True),
(16, 64),
),
],
)
def test_huggingface_model(
Expand Down

0 comments on commit 19d0043

Please sign in to comment.