-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
194 changes: 194 additions & 0 deletions
194
tests/models/paligemma/colpali_2/test_modeling_colpali_2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import logging | ||
from typing import Generator, List, cast | ||
|
||
import pytest | ||
import torch | ||
from PIL import Image | ||
from transformers import BatchFeature | ||
from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig | ||
|
||
from colpali_engine.models import ColPali2, ColPali2Config, ColPali2Processor | ||
from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def colpali_2_config() -> Generator[ColPali2Config, None, None]: | ||
yield ColPali2Config( | ||
vlm_config=cast( | ||
PaliGemmaConfig, | ||
PaliGemmaConfig.from_pretrained("google/paligemma-3b-mix-448"), | ||
), | ||
single_vector_projector_dim=128, | ||
single_vector_pool_strategy="mean", | ||
multi_vector_projector_dim=128, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def colpali_2_from_config(colpali_2_config: ColPali2Config) -> Generator[ColPali2, None, None]: | ||
device = get_torch_device("auto") | ||
logger.info(f"Device used: {device}") | ||
|
||
yield ColPali2(config=colpali_2_config) | ||
tear_down_torch() | ||
|
||
|
||
@pytest.skip("No model available in the hub yet") | ||
@pytest.fixture(scope="module") | ||
def colpali_2_model_path() -> str: | ||
raise NotImplementedError("Please provide the path to the model in the hub") | ||
|
||
|
||
@pytest.skip("No model available in the hub yet") | ||
@pytest.fixture(scope="module") | ||
def colpali_2_from_pretrained(colpali_2_model_path: str) -> Generator[ColPali2, None, None]: | ||
device = get_torch_device("auto") | ||
logger.info(f"Device used: {device}") | ||
|
||
yield cast( | ||
ColPali2, | ||
ColPali2.from_pretrained( | ||
colpali_2_model_path, | ||
torch_dtype=torch.bfloat16, | ||
device_map=device, | ||
), | ||
) | ||
tear_down_torch() | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def processor() -> Generator[ColPali2Processor, None, None]: | ||
yield cast(ColPali2Processor, ColPali2Processor.from_pretrained("google/paligemma-3b-mix-448")) | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def images() -> Generator[List[Image.Image], None, None]: | ||
yield [ | ||
Image.new("RGB", (32, 32), color="white"), | ||
Image.new("RGB", (16, 16), color="black"), | ||
] | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def queries() -> Generator[List[str], None, None]: | ||
yield [ | ||
"Does Manu like to play football?", | ||
"Are Benjamin, Antoine, Merve, and Jo friends?", | ||
"Is byaldi a dish or an awesome repository for RAG?", | ||
] | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def batch_queries(processor: ColPali2Processor, queries: List[str]) -> Generator[BatchFeature, None, None]: | ||
yield processor.process_queries(queries) | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def batch_images(processor: ColPali2Processor, images: List[Image.Image]) -> Generator[BatchFeature, None, None]: | ||
yield processor.process_images(images) | ||
|
||
|
||
class TestLoadColPali2: | ||
""" | ||
Test the different ways to load ColPali2. | ||
""" | ||
|
||
@pytest.mark.slow | ||
def test_load_colpali_2_from_config(self, colpali_2_config: ColPali2Config): | ||
device = get_torch_device("auto") | ||
logger.info(f"Device used: {device}") | ||
|
||
model = ColPali2(config=colpali_2_config) | ||
|
||
assert isinstance(model, ColPali2) | ||
assert model.single_vector_projector_dim == colpali_2_config.single_vector_projector_dim | ||
assert model.multi_vector_pooler.pooling_strategy == colpali_2_config.single_vector_pool_strategy | ||
assert model.multi_vector_projector_dim == colpali_2_config.multi_vector_projector_dim | ||
|
||
tear_down_torch() | ||
|
||
@pytest.mark.slow | ||
def test_load_colpali_2_from_pretrained(self, colpali_2_from_config: ColPali2): | ||
assert isinstance(colpali_2_from_config, ColPali2) | ||
|
||
|
||
class TestForwardSingleVector: | ||
""" | ||
Test the forward pass of ColPali2 for single-vector embeddings. | ||
""" | ||
|
||
@pytest.mark.slow | ||
def test_colpali_2_forward_images( | ||
self, | ||
colpali_2_from_config: ColPali2, | ||
batch_images: BatchFeature, | ||
): | ||
# Forward pass | ||
with torch.no_grad(): | ||
outputs = colpali_2_from_config.forward_single_vector(**batch_images) | ||
|
||
# Assertions | ||
assert isinstance(outputs, torch.Tensor) | ||
assert outputs.dim() == 2 | ||
batch_size, emb_dim = outputs.shape | ||
assert batch_size == len(batch_images["input_ids"]) | ||
assert emb_dim == colpali_2_from_config.single_vector_projector_dim | ||
|
||
@pytest.mark.slow | ||
def test_colpali_2_forward_queries( | ||
self, | ||
colpali_2_from_config: ColPali2, | ||
batch_queries: BatchFeature, | ||
): | ||
# Forward pass | ||
with torch.no_grad(): | ||
outputs = colpali_2_from_config.forward_single_vector(**batch_queries) | ||
|
||
# Assertions | ||
assert isinstance(outputs, torch.Tensor) | ||
assert outputs.dim() == 2 | ||
batch_size, emb_dim = outputs.shape | ||
assert batch_size == len(batch_queries["input_ids"]) | ||
assert emb_dim == colpali_2_from_config.single_vector_projector_dim | ||
|
||
|
||
class TestForwardMultiVector: | ||
""" | ||
Test the forward pass of ColPali2 for multi-vector embeddings. | ||
""" | ||
|
||
@pytest.mark.slow | ||
def test_colpali_2_forward_images( | ||
self, | ||
colpali_2_from_config: ColPali2, | ||
batch_images: BatchFeature, | ||
): | ||
# Forward pass | ||
with torch.no_grad(): | ||
outputs = colpali_2_from_config.forward_multi_vector(**batch_images) | ||
|
||
# Assertions | ||
assert isinstance(outputs, torch.Tensor) | ||
assert outputs.dim() == 3 | ||
batch_size, n_visual_tokens, emb_dim = outputs.shape | ||
assert batch_size == len(batch_images["input_ids"]) | ||
assert emb_dim == colpali_2_from_config.multi_vector_projector_dim | ||
|
||
@pytest.mark.slow | ||
def test_colpali_2_forward_queries( | ||
self, | ||
colpali_2_from_config: ColPali2, | ||
batch_queries: BatchFeature, | ||
): | ||
# Forward pass | ||
with torch.no_grad(): | ||
outputs = colpali_2_from_config.forward_multi_vector(**batch_queries) | ||
|
||
# Assertions | ||
assert isinstance(outputs, torch.Tensor) | ||
assert outputs.dim() == 3 | ||
batch_size, n_query_tokens, emb_dim = outputs.shape | ||
assert batch_size == len(batch_queries["input_ids"]) | ||
assert emb_dim == colpali_2_from_config.multi_vector_projector_dim |
50 changes: 50 additions & 0 deletions
50
tests/models/paligemma/colpali_2/test_processing_colpali_2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Generator, cast | ||
|
||
import pytest | ||
import torch | ||
from PIL import Image | ||
|
||
from colpali_engine.models import ColPali2Processor | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def colpali_model_path() -> str: | ||
return "google/paligemma-3b-mix-448" | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def processor_from_pretrained(colpali_model_path: str) -> Generator[ColPali2Processor, None, None]: | ||
yield cast(ColPali2Processor, ColPali2Processor.from_pretrained(colpali_model_path)) | ||
|
||
|
||
def test_load_processor_from_pretrained(processor_from_pretrained: ColPali2Processor): | ||
assert isinstance(processor_from_pretrained, ColPali2Processor) | ||
|
||
|
||
def test_process_images(processor_from_pretrained: ColPali2Processor): | ||
# Create a dummy image | ||
image = Image.new("RGB", (16, 16), color="black") | ||
images = [image] | ||
|
||
# Process the image | ||
batch_feature = processor_from_pretrained.process_images(images) | ||
|
||
# Assertions | ||
assert "pixel_values" in batch_feature | ||
assert batch_feature["pixel_values"].shape == torch.Size([1, 3, 448, 448]) | ||
|
||
|
||
def test_process_queries(processor_from_pretrained: ColPali2Processor): | ||
queries = [ | ||
"Does Manu like to play football?", | ||
"Are Benjamin, Antoine, Merve, and Jo friends?", | ||
"Is byaldi a dish or a nice repository for RAG?", | ||
] | ||
|
||
# Process the queries | ||
batch_encoding = processor_from_pretrained.process_queries(queries) | ||
|
||
# Assertions | ||
assert "input_ids" in batch_encoding | ||
assert isinstance(batch_encoding["input_ids"], torch.Tensor) | ||
assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) |