Implement a more relaxed vectorset querying (#2361)
* Implement a more relaxed vectorset querying

* Add test to validate search
jotare authored Jul 31, 2024
1 parent 6d901ec commit 563cfa1
Showing 2 changed files with 292 additions and 9 deletions.
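In short, the relaxed behaviour works as follows: when a search request carries no `vectorset`, the shard falls back to the default vectors index if one exists, or to the only vectorset when exactly one is configured; with several vectorsets and no default, the query is rejected. A minimal Python sketch of this resolution rule (it mirrors the Rust change below; `resolve_vector_reader` and the constant's value are hypothetical, for illustration only):

DEFAULT_VECTORS_INDEX_NAME = "__default__"  # placeholder name; the real value lives in the Rust crate

def resolve_vector_reader(readers: dict, vectorset: str):
    if vectorset:
        # Explicit vectorset requested: look it up directly
        return readers[vectorset]
    if DEFAULT_VECTORS_INDEX_NAME in readers:
        # Legacy shards keep a default vectors index
        return readers[DEFAULT_VECTORS_INDEX_NAME]
    if len(readers) == 1:
        # No default, but a single vectorset exists: consider it the default
        return next(iter(readers.values()))
    raise ValueError("Query without vectorset but shard has multiple vector indexes")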
272 changes: 271 additions & 1 deletion nucliadb/tests/nucliadb/integration/test_vectorsets.py
@@ -18,15 +18,40 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

import functools
import random
import uuid
from typing import Any, Optional
from unittest.mock import AsyncMock, patch

import pytest
from httpx import AsyncClient
from pytest_mock import MockerFixture

from nucliadb.common.cluster import manager
from nucliadb.common.cluster.base import AbstractIndexNode
from nucliadb.common.maindb.driver import Driver
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
from nucliadb.search.predict import DummyPredictEngine
from nucliadb.search.requesters import utils
from nucliadb_protos import (
nodereader_pb2,
resources_pb2,
utils_pb2,
writer_pb2,
)
from nucliadb_protos.knowledgebox_pb2 import SemanticModelMetadata
from nucliadb_protos.writer_pb2_grpc import WriterStub
from nucliadb_utils.indexing import IndexingUtility
from nucliadb_utils.storages.storage import Storage
from nucliadb_utils.utilities import (
Utility,
clean_utility,
set_utility,
)
from tests.ingest.fixtures import make_extracted_text
from tests.nucliadb.knowledgeboxes.vectorsets import KbSpecs
from tests.utils import inject_message

DEFAULT_VECTOR_DIMENSION = 512
VECTORSET_DIMENSION = 12
@@ -134,3 +159,248 @@ async def mock_node_query(kbid: str, method, pb_query: nodereader_pb2.SearchRequest
)
assert resp.status_code == 200
assert calls[-1].vectorset == expected


async def test_querying_kb_with_vectorsets(
mocker: MockerFixture,
storage: Storage,
maindb_driver: Driver,
shard_manager,
learning_config,
dummy_indexing_utility,
nucliadb_grpc: WriterStub,
nucliadb_reader: AsyncClient,
dummy_predict: DummyPredictEngine,
):
"""This tests validates a KB with 1 or 2 vectorsets have functional search
using or not `vectorset` parameter in search. The point here is not the
result, but checking the index response.
"""
query: tuple[Any, Optional[nodereader_pb2.SearchResponse], Optional[Exception]] = (None, None, None)

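# Wraps the real query_shard to capture, via a spy on node.reader.Search, the
# request sent to the index node together with its result or raised error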
async def query_shard_wrapper(
node: AbstractIndexNode, shard: str, pb_query: nodereader_pb2.SearchRequest
):
nonlocal query

from nucliadb.search.search.shards import query_shard

# this avoids problems with spying an object twice
if not hasattr(node.reader.Search, "spy_return"):
spy = mocker.spy(node.reader, "Search")
else:
spy = node.reader.Search # type: ignore

try:
result = await query_shard(node, shard, pb_query)
except Exception as exc:
query = (spy, None, exc)
raise
else:
query = (spy, result, None)
return result

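# Wraps the dummy predict `query` so the returned sentence embedding has the
# dimension of the vectorset under test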
def predict_query_wrapper(original, dimension):
@functools.wraps(original)
async def inner(*args, **kwargs):
query_info = await original(*args, **kwargs)
query_info.sentence.data = [1.0] * dimension
return query_info

return inner

# KB with one vectorset

kbid = KnowledgeBox.new_unique_kbid()
kbslug = "kb-with-one-vectorset"
kbid, _ = await KnowledgeBox.create(
maindb_driver,
kbid=kbid,
slug=kbslug,
semantic_models={
"model": SemanticModelMetadata(
similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768
),
},
)
rid = uuid.uuid4().hex
field_id = "my-field"
bm = create_broker_message_with_vectorsets(kbid, rid, field_id, [("model", 768)])
await inject_message(nucliadb_grpc, bm)

with (
patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True),
):
with (
patch.object(
dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768)
),
):
resp = await nucliadb_reader.post(
f"/kb/{kbid}/find",
json={
"query": "foo",
},
)
assert resp.status_code == 200

node_search_spy, result, error = query
assert error is None

request = node_search_spy.call_args[0][0]
assert request.vectorset == ""
assert len(request.vector) == 768

resp = await nucliadb_reader.post(
f"/kb/{kbid}/find",
json={
"query": "foo",
"vectorset": "model",
},
)
assert resp.status_code == 200

node_search_spy, result, error = query
assert error is None

request = node_search_spy.call_args[0][0]
assert request.vectorset == "model"
assert len(request.vector) == 768

# KB with 2 vectorsets

kbid = KnowledgeBox.new_unique_kbid()
kbslug = "kb-with-vectorsets"
kbid, _ = await KnowledgeBox.create(
maindb_driver,
kbid=kbid,
slug=kbslug,
semantic_models={
"model-A": SemanticModelMetadata(
similarity_function=utils_pb2.VectorSimilarity.COSINE, vector_dimension=768
),
"model-B": SemanticModelMetadata(
similarity_function=utils_pb2.VectorSimilarity.DOT, vector_dimension=1024
),
},
)
rid = uuid.uuid4().hex
field_id = "my-field"
bm = create_broker_message_with_vectorsets(
kbid, rid, field_id, [("model-A", 768), ("model-B", 1024)]
)
await inject_message(nucliadb_grpc, bm)

with (
patch.dict(utils.METHODS, {utils.Method.SEARCH: query_shard_wrapper}, clear=True),
):
with (
patch.object(
dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 768)
),
):
resp = await nucliadb_reader.post(
f"/kb/{kbid}/find",
json={
"query": "foo",
"vectorset": "model-A",
},
)
assert resp.status_code == 200

node_search_spy, result, error = query
assert error is None

request = node_search_spy.call_args[0][0]
assert request.vectorset == "model-A"
assert len(request.vector) == 768

with (
patch.object(
dummy_predict, "query", side_effect=predict_query_wrapper(dummy_predict.query, 1024)
),
):
resp = await nucliadb_reader.post(
f"/kb/{kbid}/find",
json={
"query": "foo",
"vectorset": "model-B",
},
)
assert resp.status_code == 200

node_search_spy, result, error = query
assert error is None

request = node_search_spy.call_args[0][0]
assert request.vectorset == "model-B"
assert len(request.vector) == 1024

resp = await nucliadb_reader.get(
f"/kb/{kbid}/find",
params={
"query": "foo",
},
)
assert resp.status_code == 500
node_search_spy, result, error = query
request = node_search_spy.call_args[0][0]
assert result is None
assert request.vectorset == ""
assert "Query without vectorset but shard has multiple vector indexes" in str(error)


@pytest.fixture(scope="function")
def dummy_predict():
predict = DummyPredictEngine()
set_utility(Utility.PREDICT, predict)
yield predict
clean_utility(Utility.PREDICT)


# TODO: replace with the one in ndbfixtures when it's ready
@pytest.fixture(scope="function")
async def dummy_indexing_utility():
# As it's a dummy utility, we don't need to provide real NATS servers or
# credentials
indexing_utility = IndexingUtility(nats_creds=None, nats_servers=[], dummy=True)
await indexing_utility.initialize()
set_utility(Utility.INDEXING, indexing_utility)

yield

clean_utility(Utility.INDEXING)
await indexing_utility.finalize()


def create_broker_message_with_vectorsets(
kbid: str,
rid: str,
field_id: str,
vectorsets: list[tuple[str, int]],
):
bm = writer_pb2.BrokerMessage(kbid=kbid, uuid=rid, type=writer_pb2.BrokerMessage.AUTOCOMMIT)

body = "Lorem ipsum dolor sit amet..."
bm.texts[field_id].body = body

bm.extracted_text.append(make_extracted_text(field_id, body))

for vectorset_id, vectorset_dimension in vectorsets:
# custom vectorset
field_vectors = resources_pb2.ExtractedVectorsWrapper()
field_vectors.field.field = field_id
field_vectors.field.field_type = resources_pb2.FieldType.TEXT
field_vectors.vectorset_id = vectorset_id
for i in range(0, 100, 10):
field_vectors.vectors.vectors.vectors.append(
utils_pb2.Vector(
start=i,
end=i + 10,
vector=[random.random()] * vectorset_dimension,
)
)
bm.field_vectors.append(field_vectors)
return bm
29 changes: 21 additions & 8 deletions nucliadb_node/src/shards/shard_reader.rs
@@ -201,10 +201,16 @@ impl ShardReader {
let info = info_span!(parent: &span, "vector count");
let vector_task = || {
run_with_telemetry(info, || {
let vector_readers = read_rw_lock(&self.vector_readers);
if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) {
return reader.count();
}

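// No default index: aggregate the counts of every vectorset instead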
let mut count = 0;
for reader in vector_readers.values() {
count += reader.count()?;
}
Ok(count)
})
};

@@ -673,10 +679,17 @@
) -> NodeResult<VectorSearchResponse> {
let vectorset = &request.vector_set;
if vectorset.is_empty() {
let vector_readers = read_rw_lock(&self.vector_readers);
if let Some(reader) = vector_readers.get(DEFAULT_VECTORS_INDEX_NAME) {
reader.search(request, context)
} else if vector_readers.len() == 1 {
// No default vectorset but only one exists; consider it the default
let reader = vector_readers.values().next().unwrap();
reader.search(request, context)
} else {
Err(node_error!("Query without vectorset but shard has multiple vector indexes"))
}
} else {
let vector_readers = read_rw_lock(&self.vector_readers);
let reader = vector_readers.get(vectorset);
