diff --git a/nucliadb/tests/nucliadb/integration/test_vectorsets.py b/nucliadb/tests/nucliadb/integration/test_vectorsets.py
index 482d2fab7b..d9767cecbd 100644
--- a/nucliadb/tests/nucliadb/integration/test_vectorsets.py
+++ b/nucliadb/tests/nucliadb/integration/test_vectorsets.py
@@ -18,15 +18,40 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
#
+import functools
+import random
+import uuid
+from typing import Any, Optional
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
+from pytest_mock import MockerFixture
from nucliadb.common.cluster import manager
-from nucliadb_protos import nodereader_pb2
+from nucliadb.common.cluster.base import AbstractIndexNode
+from nucliadb.common.maindb.driver import Driver
+from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
+from nucliadb.search.predict import DummyPredictEngine
+from nucliadb.search.requesters import utils
+from nucliadb_protos import (
+    nodereader_pb2,
+    resources_pb2,
+    utils_pb2,
+    writer_pb2,
+)
+from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata
from nucliadb_protos.writer_pb2_grpc import WriterStub
+from nucliadb_utils.indexing import IndexingUtility
+from nucliadb_utils.storages.storage import Storage
+from nucliadb_utils.utilities import (
+    Utility,
+    clean_utility,
+    set_utility,
+)
+from tests.ingest.fixtures import make_extracted_text
from tests.nucliadb.knowledgeboxes.vectorsets import KbSpecs
+from tests.utils import inject_message
DEFAULT_VECTOR_DIMENSION = 512
VECTORSET_DIMENSION = 12
@@ -134,3 +159,248 @@ async def mock_node_query(kbid: str, method, pb_query: nodereader_pb2.SearchRequ
     )
     assert resp.status_code == 200
     assert calls[-1].vectorset == expected
+
+
+async def test_querying_kb_with_vectorsets(
+    mocker: MockerFixture,
+    storage: Storage,
+    maindb_driver: Driver,
+    shard_manager,
+    learning_config,
+    dummy_indexing_utility,
+    nucliadb_grpc: WriterStub,
+    nucliadb_reader: AsyncClient,
+    dummy_predict: DummyPredictEngine,
+):
+    """Validate that a KB with 1 or 2 vectorsets can be searched with and
+    without the `vectorset` parameter. The point here is not the search
+    results, but the requests and responses exchanged with the index.
+    """
+    query: tuple[Any, Optional[nodereader_pb2.SearchResponse], Optional[Exception]] = (None, None, None)
+
+    async def query_shard_wrapper(
+        node: AbstractIndexNode, shard: str, pb_query: nodereader_pb2.SearchRequest
+    ):
+        nonlocal query
+
+        from nucliadb.search.search.shards import query_shard
+
+        # this avoids problems with spying an object twice
+        if not hasattr(node.reader.Search, "spy_return"):
+            spy = mocker.spy(node.reader, "Search")
+        else:
+            spy = node.reader.Search  # type: ignore
+
+        try:
+            result = await query_shard(node, shard, pb_query)
+        except Exception as exc:
+            query = (spy, None, exc)
+            raise
+        else:
+            query = (spy, result, None)
+            return result
+
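+    # DummyPredictEngine returns a fixed-size query vector, so we wrap it to
+    # force the dimension the vectorset under test expects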
+    def predict_query_wrapper(original, dimension):
+        @functools.wraps(original)
+        async def inner(*args, **kwargs):
+            query_info = await original(*args, **kwargs)
+            query_info.sentence.data = [1.0] * dimension
+            return query_info
+
+        return inner
+
+    # KB with one vectorset
+
+    kbid = KnowledgeBox.new_unique_kbid()
+    kbslug = "kb-with-one-vectorset"
+    kbid, _ = await KnowledgeBox.create(
+        maindb_driver,
+        kbid=kbid,
+        slug=kbslug,
+        semantic_models={
+            "model": SemanticModelMetadata(
+                similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768
+            ),
+        },
+    )
+    rid = uuid.uuid4().hex
+    field_id = "my-field"
+    bm = create_broker_message_with_vectorsets(kbid, rid, field_id, [("model", 768)])
+    await inject_message(nucliadb_grpc, bm)
+
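+    # route searches through our wrapper so we can inspect the node requests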
+    with (
+        patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True),
+    ):
+        with (
+            patch.object(
+                dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768)
+            ),
+        ):
+            resp = await nucliadb_reader.post(
+                f"/kb/{kbid}/find",
+                json={
+                    "query": "foo",
+                },
+            )
+            assert resp.status_code == 200
+
+            node_search_spy, result, error = query
+            assert error is None
+
+            request = node_search_spy.call_args[0][0]
+            assert request.vectorset == ""
+            assert len(request.vector) == 768
+
+            resp = await nucliadb_reader.post(
+                f"/kb/{kbid}/find",
+                json={
+                    "query": "foo",
+                    "vectorset": "model",
+                },
+            )
+            assert resp.status_code == 200
+
+            node_search_spy, result, error = query
+            assert error is None
+
+            request = node_search_spy.call_args[0][0]
+            assert request.vectorset == "model"
+            assert len(request.vector) == 768
+
+    # KB with 2 vectorsets
+
+    kbid = KnowledgeBox.new_unique_kbid()
+    kbslug = "kb-with-vectorsets"
+    kbid, _ = await KnowledgeBox.create(
+        maindb_driver,
+        kbid=kbid,
+        slug=kbslug,
+        semantic_models={
+            "model-A": SemanticModelMetadata(
+                similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768
+            ),
+            "model-B": SemanticModelMetadata(
+                similarity_function=utils_pb2.VectorSimilarity.DOT, vector_dimension=1024
+            ),
+        },
+    )
+    rid = uuid.uuid4().hex
+    field_id = "my-field"
+    bm = create_broker_message_with_vectorsets(
+        kbid, rid, field_id, [("model-A", 768), ("model-B", 1024)]
+    )
+    await inject_message(nucliadb_grpc, bm)
+
+    with (
+        patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True),
+    ):
+        with (
+            patch.object(
+                dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768)
+            ),
+        ):
+            resp = await nucliadb_reader.post(
+                f"/kb/{kbid}/find",
+                json={
+                    "query": "foo",
+                    "vectorset": "model-A",
+                },
+            )
+            assert resp.status_code == 200
+
+            node_search_spy, result, error = query
+            assert error is None
+
+            request = node_search_spy.call_args[0][0]
+            assert request.vectorset == "model-A"
+            assert len(request.vector) == 768
+
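+        # switch the patched predict dimension to match model-B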
+        with (
+            patch.object(
+                dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 1024)
+            ),
+        ):
+            resp = await nucliadb_reader.post(
+                f"/kb/{kbid}/find",
+                json={
+                    "query": "foo",
+                    "vectorset": "model-B",
+                },
+            )
+            assert resp.status_code == 200
+
+            node_search_spy, result, error = query
+            assert error is None
+
+            request = node_search_spy.call_args[0][0]
+            assert request.vectorset == "model-B"
+            assert len(request.vector) == 1024
+
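+            # without a vectorset, the shard can't choose among multiple
+            # vector indexes and the node returns an error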
+            resp = await nucliadb_reader.get(
+                f"/kb/{kbid}/find",
+                params={
+                    "query": "foo",
+                },
+            )
+            assert resp.status_code == 500
+            node_search_spy, result, error = query
+            request = node_search_spy.call_args[0][0]
+            assert result is None
+            assert request.vectorset == ""
+            assert "Query without vectorset but shard has multiple vector indexes" in str(error)
+
+
+@pytest.fixture(scope="function")
+def dummy_predict():
+    predict = DummyPredictEngine()
+    set_utility(Utility.PREDICT, predict)
+    yield predict
+    clean_utility(Utility.PREDICT)
+
+
+# TODO: replace with the one in ndbfixtures when it's ready
+@pytest.fixture(scope="function")
+async def dummy_indexing_utility():
+    # as it's a dummy utility, we don't need to provide real NATS servers or
+    # creds
+    indexing_utility = IndexingUtility(nats_creds=None, nats_servers=[], dummy=True)
+    await indexing_utility.initialize()
+    set_utility(Utility.INDEXING, indexing_utility)
+
+    yield
+
+    clean_utility(Utility.INDEXING)
+    await indexing_utility.finalize()
+
+
+def create_broker_message_with_vectorsets(
+    kbid: str,
+    rid: str,
+    field_id: str,
+    vectorsets: list[tuple[str, int]],
+):
+    bm = writer_pb2.BrokerMessage(kbid=kbid, uuid=rid, type=writer_pb2.BrokerMessage.AUTOCOMMIT)
+
+    body = "Lorem ipsum dolor sit amet..."
+    bm.texts[field_id].body = body
+
+    bm.extracted_text.append(make_extracted_text(field_id, body))
+
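+    # each vectorset gets 10 extracted vectors (one per 10-character span of
+    # the text) with its own dimension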
+    for vectorset_id, vectorset_dimension in vectorsets:
+        # custom vectorset
+        field_vectors = resources_pb2.ExtractedVectorsWrapper()
+        field_vectors.field.field = field_id
+        field_vectors.field.field_type = resources_pb2.FieldType.TEXT
+        field_vectors.vectorset_id = vectorset_id
+        for i in range(0, 100, 10):
+            field_vectors.vectors.vectors.vectors.append(
+                utils_pb2.Vector(
+                    start=i,
+                    end=i + 10,
+                    vector=[random.random()] * vectorset_dimension,
+                )
+            )
+        bm.field_vectors.append(field_vectors)
+    return bm
diff --git a/nucliadb_node/src/shards/shard_reader.rs b/nucliadb_node/src/shards/shard_reader.rs
index 054d3dea82..f45ae1990f 100644
--- a/nucliadb_node/src/shards/shard_reader.rs
+++ b/nucliadb_node/src/shards/shard_reader.rs
@@ -201,10 +201,16 @@ impl ShardReader {
         let info = info_span!(parent: &span, "vector count");
         let vector_task = || {
             run_with_telemetry(info, || {
-                read_rw_lock(&self.vector_readers)
-                    .get(DEFAULT_VECTORS_INDEX_NAME)
-                    .expect("Default vectors index should never be deleted (yet)")
-                    .count()
+                let vector_readers = read_rw_lock(&self.vector_readers);
+                if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) {
+                    return reader.count();
+                }
+
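+                // fallback: without a default index, report the sum of all
+                // vector index counts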
+                let mut count = 0;
+                for reader in vector_readers.values() {
+                    count += reader.count()?;
+                }
+                Ok(count)
             })
         };
@@ -673,10 +679,17 @@ impl ShardReader {
     ) -> NodeResult<VectorSearchResponse> {
         let vectorset = &request.vector_set;
         if vectorset.is_empty() {
-            read_rw_lock(&self.vector_readers)
-                .get(DEFAULT_VECTORS_INDEX_NAME)
-                .expect("Default vectors index should never be deleted (yet)")
-                .search(request, context)
+            let vector_readers = read_rw_lock(&self.vector_readers);
+            if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) {
+                reader.search(request, context)
+            } else if vector_readers.len() == 1 {
+                // no default vectorset, but exactly one exists: consider it
+                // the default
+                let reader = vector_readers.values().next().unwrap();
+                reader.search(request, context)
+            } else {
+                Err(node_error!("Query without vectorset but shard has multiple vector indexes"))
+            }
         } else {
             let vector_readers = read_rw_lock(&self.vector_readers);
             let reader = vector_readers.get(vectorset);