diff --git a/nucliadb/tests/nucliadb/integration/test_vectorsets.py b/nucliadb/tests/nucliadb/integration/test_vectorsets.py index 482d2fab7b..d9767cecbd 100644 --- a/nucliadb/tests/nucliadb/integration/test_vectorsets.py +++ b/nucliadb/tests/nucliadb/integration/test_vectorsets.py @@ -18,15 +18,40 @@ # along with this program. If not, see . # +import functools +import random +import uuid +from typing import Any, Optional from unittest.mock import AsyncMock, patch import pytest from httpx import AsyncClient +from pytest_mock import MockerFixture from nucliadb.common.cluster import manager -from nucliadb_protos import nodereader_pb2 +from nucliadb.common.cluster.base import AbstractIndexNode +from nucliadb.common.maindb.driver import Driver +from nucliadb.ingest.orm.knowledgebox import KnowledgeBox +from nucliadb.search.predict import DummyPredictEngine +from nucliadb.search.requesters import utils +from nucliadb_protos import ( + nodereader_pb2, + resources_pb2, + utils_pb2, + writer_pb2, +) +from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata from nucliadb_protos.writer_pb2_grpc import WriterStub +from nucliadb_utils.indexing import IndexingUtility +from nucliadb_utils.storages.storage import Storage +from nucliadb_utils.utilities import ( + Utility, + clean_utility, + set_utility, +) +from tests.ingest.fixtures import make_extracted_text from tests.nucliadb.knowledgeboxes.vectorsets import KbSpecs +from tests.utils import inject_message DEFAULT_VECTOR_DIMENSION = 512 VECTORSET_DIMENSION = 12 @@ -134,3 +159,248 @@ async def mock_node_query(kbid: str, method, pb_query: nodereader_pb2.SearchRequ ) assert resp.status_code == 200 assert calls[-1].vectorset == expected + + +async def test_querying_kb_with_vectorsets( + mocker: MockerFixture, + storage: Storage, + maindb_driver: Driver, + shard_manager, + learning_config, + dummy_indexing_utility, + nucliadb_grpc: WriterStub, + nucliadb_reader: AsyncClient, + dummy_predict: DummyPredictEngine, +): + """This tests validates a KB with 1 or 2 vectorsets have functional search + using or not `vectorset` parameter in search. The point here is not the + result, but checking the index response. + + """ + query: tuple[Any, Optional[nodereader_pb2.SearchResponse], Optional[Exception]] = (None, None, None) + + async def query_shard_wrapper( + node: AbstractIndexNode, shard: str, pb_query: nodereader_pb2.SearchRequest + ): + nonlocal query + + from nucliadb.search.search.shards import query_shard + + # this avoids problems with spying an object twice + if not hasattr(node.reader.Search, "spy_return"): + spy = mocker.spy(node.reader, "Search") + else: + spy = node.reader.Search # type: ignore + + try: + result = await query_shard(node, shard, pb_query) + except Exception as exc: + query = (spy, None, exc) + raise + else: + query = (spy, result, None) + return result + + def predict_query_wrapper(original, dimension): + @functools.wraps(original) + async def inner(*args, **kwargs): + query_info = await original(*args, **kwargs) + query_info.sentence.data = [1.0] * dimension + return query_info + + return inner + + # KB with one vectorset + + kbid = KnowledgeBox.new_unique_kbid() + kbslug = "kb-with-one-vectorset" + kbid, _ = await KnowledgeBox.create( + maindb_driver, + kbid=kbid, + slug=kbslug, + semantic_models={ + "model": SemanticModelMetadata( + similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768 + ), + }, + ) + rid = uuid.uuid4().hex + field_id = "my-field" + bm = create_broker_message_with_vectorsets(kbid, rid, field_id, [("model", 768)]) + await inject_message(nucliadb_grpc, bm) + + with ( + patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True), + ): + with ( + patch.object( + dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768) + ), + ): + resp = await nucliadb_reader.post( + f"/kb/{kbid}/find", + json={ + "query": "foo", + }, + ) + assert resp.status_code == 200 + + node_search_spy, result, error = query + assert error is None + + request = node_search_spy.call_args[0][0] + assert request.vectorset == "" + assert len(request.vector) == 768 + + resp = await nucliadb_reader.post( + f"/kb/{kbid}/find", + json={ + "query": "foo", + "vectorset": "model", + }, + ) + assert resp.status_code == 200 + + node_search_spy, result, error = query + assert error is None + + request = node_search_spy.call_args[0][0] + assert request.vectorset == "model" + assert len(request.vector) == 768 + + # KB with 2 vectorsets + + kbid = KnowledgeBox.new_unique_kbid() + kbslug = "kb-with-vectorsets" + kbid, _ = await KnowledgeBox.create( + maindb_driver, + kbid=kbid, + slug=kbslug, + semantic_models={ + "model-A": SemanticModelMetadata( + similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768 + ), + "model-B": SemanticModelMetadata( + similarity_function=utils_pb2.VectorSimilarity.DOT, vector_dimension=1024 + ), + }, + ) + rid = uuid.uuid4().hex + field_id = "my-field" + bm = create_broker_message_with_vectorsets( + kbid, rid, field_id, [("model-A", 768), ("model-B", 1024)] + ) + await inject_message(nucliadb_grpc, bm) + + with ( + patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True), + ): + with ( + patch.object( + dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768) + ), + ): + resp = await nucliadb_reader.post( + f"/kb/{kbid}/find", + json={ + "query": "foo", + "vectorset": "model-A", + }, + ) + assert resp.status_code == 200 + + node_search_spy, result, error = query + assert error is None + + request = node_search_spy.call_args[0][0] + assert request.vectorset == "model-A" + assert len(request.vector) == 768 + + with ( + patch.object( + dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 1024) + ), + ): + resp = await nucliadb_reader.post( + f"/kb/{kbid}/find", + json={ + "query": "foo", + "vectorset": "model-B", + }, + ) + assert resp.status_code == 200 + + node_search_spy, result, error = query + assert error is None + + request = node_search_spy.call_args[0][0] + assert request.vectorset == "model-B" + assert len(request.vector) == 1024 + + resp = await nucliadb_reader.get( + f"/kb/{kbid}/find", + params={ + "query": "foo", + }, + ) + assert resp.status_code == 500 + node_search_spy, result, error = query + request = node_search_spy.call_args[0][0] + assert result is None + assert request.vectorset == "" + assert "Query without vectorset but shard has multiple vector indexes" in str(error) + + +@pytest.fixture(scope="function") +def dummy_predict(): + predict = DummyPredictEngine() + set_utility(Utility.PREDICT, predict) + yield predict + clean_utility(Utility.PREDICT) + + +# +# TODO: replace for the one in ndbfixtures when it's ready +@pytest.fixture(scope="function") +async def dummy_indexing_utility(): + # as it's a dummy utility, we don't need to provide real nats servers or + # creds + indexing_utility = IndexingUtility(nats_creds=None, nats_servers=[], dummy=True) + await indexing_utility.initialize() + set_utility(Utility.INDEXING, indexing_utility) + + yield + + clean_utility(Utility.INDEXING) + await indexing_utility.finalize() + + +def create_broker_message_with_vectorsets( + kbid: str, + rid: str, + field_id: str, + vectorsets: list[tuple[str, int]], +): + bm = writer_pb2.BrokerMessage(kbid=kbid, uuid=rid, type=writer_pb2.BrokerMessage.AUTOCOMMIT) + + body = "Lorem ipsum dolor sit amet..." + bm.texts[field_id].body = body + + bm.extracted_text.append(make_extracted_text(field_id, body)) + + for vectorset_id, vectorset_dimension in vectorsets: + # custom vectorset + field_vectors = resources_pb2.ExtractedVectorsWrapper() + field_vectors.field.field = field_id + field_vectors.field.field_type = resources_pb2.FieldType.TEXT + field_vectors.vectorset_id = vectorset_id + for i in range(0, 100, 10): + field_vectors.vectors.vectors.vectors.append( + utils_pb2.Vector( + start=i, + end=i + 10, + vector=[random.random()] * vectorset_dimension, + ) + ) + bm.field_vectors.append(field_vectors) + return bm diff --git a/nucliadb_node/src/shards/shard_reader.rs b/nucliadb_node/src/shards/shard_reader.rs index 054d3dea82..f45ae1990f 100644 --- a/nucliadb_node/src/shards/shard_reader.rs +++ b/nucliadb_node/src/shards/shard_reader.rs @@ -201,10 +201,16 @@ impl ShardReader { let info = info_span!(parent: &span, "vector count"); let vector_task = || { run_with_telemetry(info, || { - read_rw_lock(&self.vector_readers) - .get(DEFAULT_VECTORS_INDEX_NAME) - .expect("Default vectors index should never be deleted (yet)") - .count() + let vector_readers = read_rw_lock(&self.vector_readers); + if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) { + return reader.count(); + } + + let mut count = 0; + for reader in vector_readers.values() { + count += reader.count()?; + } + Ok(count) }) }; @@ -673,10 +679,17 @@ impl ShardReader { ) -> NodeResult { let vectorset = &request.vector_set; if vectorset.is_empty() { - read_rw_lock(&self.vector_readers) - .get(DEFAULT_VECTORS_INDEX_NAME) - .expect("Default vectors index should never be deleted (yet)") - .search(request, context) + let vector_readers = read_rw_lock(&self.vector_readers); + if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) { + reader.search(request, context) + } else if vector_readers.len() == 1 { + // no default vectorset but only one exist, consider it the + // default + let reader = vector_readers.values().next().unwrap(); + reader.search(request, context) + } else { + Err(node_error!("Query without vectorset but shard has multiple vector indexes")) + } } else { let vector_readers = read_rw_lock(&self.vector_readers); let reader = vector_readers.get(vectorset);