diff --git a/nucliadb/nucliadb/common/cluster/manager.py b/nucliadb/nucliadb/common/cluster/manager.py
index 13851f4152..ce2b0cea01 100644
--- a/nucliadb/nucliadb/common/cluster/manager.py
+++ b/nucliadb/nucliadb/common/cluster/manager.py
@@ -133,18 +133,18 @@ async def get_shards_by_kbid(self, kbid: str) -> list[writer_pb2.ShardObject]:

     async def apply_for_all_shards(
         self,
         kbid: str,
-        aw: Callable[[AbstractIndexNode, str, str], Awaitable[Any]],
+        aw: Callable[[AbstractIndexNode, str], Awaitable[Any]],
         timeout: float,
     ) -> list[Any]:
         shards = await self.get_shards_by_kbid(kbid)
         ops = []

         for shard_obj in shards:
-            node, shard_id, node_id = choose_node(shard_obj)
+            node, shard_id = choose_node(shard_obj)
             if shard_id is None:
                 raise ShardNotFound("Fount a node but not a shard")
-            ops.append(aw(node, shard_id, node_id))
+            ops.append(aw(node, shard_id))

         try:
             results = await asyncio.wait_for(
@@ -512,7 +512,7 @@ def choose_node(
     *,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = False,
-) -> tuple[AbstractIndexNode, str, str]:
+) -> tuple[AbstractIndexNode, str]:
     """Choose an arbitrary node storing `shard` following these rules:
     - nodes containing a shard replica from `target_replicas` are the preferred
     - when enabled, read replica nodes are preferred over primaries
@@ -548,7 +548,7 @@ def choose_node(

     top = ranked_nodes[max(ranked_nodes)]
     selected_node, shard_replica_id = random.choice(top)
-    return selected_node, shard_replica_id, selected_node.id
+    return selected_node, shard_replica_id


 def check_enough_nodes():
diff --git a/nucliadb/nucliadb/ingest/consumer/auditing.py b/nucliadb/nucliadb/ingest/consumer/auditing.py
index 43a89f3b0a..835b3e4fd3 100644
--- a/nucliadb/nucliadb/ingest/consumer/auditing.py
+++ b/nucliadb/nucliadb/ingest/consumer/auditing.py
@@ -119,7 +119,7 @@ async def process_kb(self, kbid: str) -> None:
         total_paragraphs = 0

         for shard_obj in shard_groups:
-            node, shard_id, _ = choose_node(shard_obj)
+            node, shard_id = choose_node(shard_obj)
             shard: nodereader_pb2.Shard = await node.reader.GetShard(
                 nodereader_pb2.GetShardRequest(shard_id=noderesources_pb2.ShardId(id=shard_id))  # type: ignore
             )
diff --git a/nucliadb/nucliadb/ingest/consumer/shard_creator.py b/nucliadb/nucliadb/ingest/consumer/shard_creator.py
index d08a3aecd2..8be1b84875 100644
--- a/nucliadb/nucliadb/ingest/consumer/shard_creator.py
+++ b/nucliadb/nucliadb/ingest/consumer/shard_creator.py
@@ -92,7 +92,7 @@ async def process_kb(self, kbid: str) -> None:
         kb_shards = await self.shard_manager.get_shards_by_kbid_inner(kbid)
         current_shard: writer_pb2.ShardObject = kb_shards.shards[kb_shards.actual]

-        node, shard_id, _ = choose_node(current_shard)
+        node, shard_id = choose_node(current_shard)
         shard: nodereader_pb2.Shard = await node.reader.GetShard(
             nodereader_pb2.GetShardRequest(shard_id=noderesources_pb2.ShardId(id=shard_id))  # type: ignore
         )
diff --git a/nucliadb/nucliadb/ingest/orm/entities.py b/nucliadb/nucliadb/ingest/orm/entities.py
index 66325b1518..88e2deeb4f 100644
--- a/nucliadb/nucliadb/ingest/orm/entities.py
+++ b/nucliadb/nucliadb/ingest/orm/entities.py
@@ -195,7 +195,7 @@ async def get_indexed_entities_group(self, group: str) -> Optional[EntitiesGroup
         shard_manager = get_shard_manager()

         async def do_entities_search(
-            node: AbstractIndexNode, shard_id: str, node_id: str
+            node: AbstractIndexNode, shard_id: str
         ) -> RelationSearchResponse:
             request = RelationSearchRequest(
                 shard_id=shard_id,
@@ -288,7 +288,7 @@ async def get_indexed_entities_groups_names(self) -> set[str]:
         shard_manager = get_shard_manager()

         async def query_indexed_entities_group_names(
-            node: AbstractIndexNode, shard_id: str, node_id: str
+            node: AbstractIndexNode, shard_id: str
         ) -> TypeList:
             return await node.reader.RelationTypes(ShardId(id=shard_id))  # type: ignore

diff --git a/nucliadb/nucliadb/ingest/settings.py b/nucliadb/nucliadb/ingest/settings.py
index 8f2eade954..4b32bab3ae 100644
--- a/nucliadb/nucliadb/ingest/settings.py
+++ b/nucliadb/nucliadb/ingest/settings.py
@@ -49,7 +49,10 @@ class DriverSettings(BaseSettings):
     )
     driver_tikv_url: Optional[list[str]] = Field(
         default=None,
-        description="TiKV PD (Placement Dricer) URL. The URL to the cluster manager of TiKV. Example: tikv-pd.svc:2379",
+        description=(
+            "TiKV PD (Placement Driver) URLs. The URL to the cluster manager of"
+            "TiKV. Example: '[\"tikv-pd.svc:2379\"]'"
+        ),
     )
     driver_local_url: Optional[str] = Field(
         default=None,
diff --git a/nucliadb/nucliadb/ingest/tests/unit/consumer/test_auditing.py b/nucliadb/nucliadb/ingest/tests/unit/consumer/test_auditing.py
index 6276d2487f..c92b8376cb 100644
--- a/nucliadb/nucliadb/ingest/tests/unit/consumer/test_auditing.py
+++ b/nucliadb/nucliadb/ingest/tests/unit/consumer/test_auditing.py
@@ -52,7 +52,7 @@ def shard_manager(reader):
         "nucliadb.ingest.consumer.auditing.get_shard_manager", return_value=nm
     ), patch(
         "nucliadb.ingest.consumer.auditing.choose_node",
-        return_value=(node, "shard_id", None),
+        return_value=(node, "shard_id"),
     ):
         yield nm

diff --git a/nucliadb/nucliadb/ingest/tests/unit/consumer/test_shard_creator.py b/nucliadb/nucliadb/ingest/tests/unit/consumer/test_shard_creator.py
index d3b28ef0b2..0f2da98cdf 100644
--- a/nucliadb/nucliadb/ingest/tests/unit/consumer/test_shard_creator.py
+++ b/nucliadb/nucliadb/ingest/tests/unit/consumer/test_shard_creator.py
@@ -64,7 +64,7 @@ def shard_manager(reader):
         "nucliadb.ingest.consumer.shard_creator.get_shard_manager", return_value=sm
     ), patch(
         "nucliadb.ingest.consumer.shard_creator.choose_node",
-        return_value=(node, "shard_id", None),
+        return_value=(node, "shard_id"),
     ):
         yield sm

diff --git a/nucliadb/nucliadb/search/api/v1/knowledgebox.py b/nucliadb/nucliadb/search/api/v1/knowledgebox.py
index 9aa763ce6e..54c8b15ba6 100644
--- a/nucliadb/nucliadb/search/api/v1/knowledgebox.py
+++ b/nucliadb/nucliadb/search/api/v1/knowledgebox.py
@@ -103,7 +103,7 @@ async def knowledgebox_counters(
     queried_shards = []
     for shard_object in shard_groups:
         try:
-            node, shard_id, _ = choose_node(shard_object)
+            node, shard_id = choose_node(shard_object)
         except KeyError:
             raise HTTPException(
                 status_code=500,
diff --git a/nucliadb/nucliadb/search/api/v1/resource/search.py b/nucliadb/nucliadb/search/api/v1/resource/search.py
index fd8a846006..c66dbc77b9 100644
--- a/nucliadb/nucliadb/search/api/v1/resource/search.py
+++ b/nucliadb/nucliadb/search/api/v1/resource/search.py
@@ -26,7 +26,7 @@
 from nucliadb.models.responses import HTTPClientError
 from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_PREFIX, api
 from nucliadb.search.api.v1.utils import fastapi_query
-from nucliadb.search.requesters.utils import Method, node_query
+from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_paragraphs_results
 from nucliadb.search.search.query import paragraph_query_to_pb
@@ -118,7 +118,7 @@ async def resource_search(
     except InvalidQueryError as exc:
         return HTTPClientError(status_code=412, detail=str(exc))

-    results, incomplete_results, queried_nodes, queried_shards = await node_query(
+    results, incomplete_results, queried_nodes = await node_query(
         kbid, Method.PARAGRAPH, pb_query, shards
     )

@@ -136,7 +136,8 @@ async def resource_search(
     response.status_code = 206 if incomplete_results else 200

     if debug:
-        search_results.nodes = queried_nodes
+        search_results.nodes = debug_nodes_info(queried_nodes)
+        queried_shards = [shard_id for _, shard_id in queried_nodes]
         search_results.shards = queried_shards

     return search_results
diff --git a/nucliadb/nucliadb/search/api/v1/search.py b/nucliadb/nucliadb/search/api/v1/search.py
index 8b2fb46769..0909b204b9 100644
--- a/nucliadb/nucliadb/search/api/v1/search.py
+++ b/nucliadb/nucliadb/search/api/v1/search.py
@@ -31,7 +31,7 @@
 from nucliadb.models.responses import HTTPClientError
 from nucliadb.search.api.v1.router import KB_PREFIX, api
 from nucliadb.search.api.v1.utils import fastapi_query
-from nucliadb.search.requesters.utils import Method, node_query
+from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search.exceptions import InvalidQueryError
 from nucliadb.search.search.merge import merge_results
 from nucliadb.search.search.query import QueryParser
@@ -237,7 +237,7 @@ async def catalog(
     )
     pb_query, _, _ = await query_parser.parse()

-    (results, _, _, _) = await node_query(
+    (results, _, _) = await node_query(
         kbid,
         Method.SEARCH,
         pb_query,
@@ -359,7 +359,7 @@ async def search(
     )
     pb_query, incomplete_results, autofilters = await query_parser.parse()

-    results, query_incomplete_results, queried_nodes, queried_shards = await node_query(
+    results, query_incomplete_results, queried_nodes = await node_query(
         kbid, Method.SEARCH, pb_query, target_shard_replicas=item.shards
     )

@@ -391,8 +391,9 @@ async def search(
         len(search_results.resources),
     )
     if item.debug:
-        search_results.nodes = queried_nodes
+        search_results.nodes = debug_nodes_info(queried_nodes)
+        queried_shards = [shard_id for _, shard_id in queried_nodes]
         search_results.shards = queried_shards

     search_results.autofilters = autofilters
     return search_results, incomplete_results
diff --git a/nucliadb/nucliadb/search/api/v1/suggest.py b/nucliadb/nucliadb/search/api/v1/suggest.py
index bb578277ad..d2eead9d76 100644
--- a/nucliadb/nucliadb/search/api/v1/suggest.py
+++ b/nucliadb/nucliadb/search/api/v1/suggest.py
@@ -148,7 +148,7 @@ async def suggest(
         range_modification_start,
         range_modification_end,
     )
-    results, incomplete_results, _, queried_shards = await node_query(
+    results, incomplete_results, queried_nodes = await node_query(
         kbid, Method.SUGGEST, pb_query
     )

@@ -162,6 +162,8 @@ async def suggest(
     )

     response.status_code = 206 if incomplete_results else 200
+
+    queried_shards = [shard_id for _, shard_id in queried_nodes]
     if debug and queried_shards:
         search_results.shards = queried_shards

diff --git a/nucliadb/nucliadb/search/requesters/utils.py b/nucliadb/nucliadb/search/requesters/utils.py
index d252490388..64c24acdc9 100644
--- a/nucliadb/nucliadb/search/requesters/utils.py
+++ b/nucliadb/nucliadb/search/requesters/utils.py
@@ -37,6 +37,7 @@
 from nucliadb_protos.writer_pb2 import ShardObject as PBShardObject

 from nucliadb.common.cluster import manager as cluster_manager
+from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.common.cluster.exceptions import ShardsNotFound
 from nucliadb.common.cluster.utils import get_shard_manager
 from nucliadb.search import logger
@@ -86,7 +87,7 @@ async def node_query(
     pb_query: SuggestRequest,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = True,
-) -> tuple[list[SuggestResponse], bool, list[tuple[str, str, str]], list[str]]:
+) -> tuple[list[SuggestResponse], bool, list[tuple[AbstractIndexNode, str]]]:
     ...


@@ -97,7 +98,7 @@ async def node_query(
     pb_query: ParagraphSearchRequest,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = True,
-) -> tuple[list[ParagraphSearchResponse], bool, list[tuple[str, str, str]], list[str]]:
+) -> tuple[list[ParagraphSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]:
     ...


@@ -108,7 +109,7 @@ async def node_query(
     pb_query: SearchRequest,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = True,
-) -> tuple[list[SearchResponse], bool, list[tuple[str, str, str]], list[str]]:
+) -> tuple[list[SearchResponse], bool, list[tuple[AbstractIndexNode, str]]]:
     ...


@@ -119,7 +120,7 @@ async def node_query(
     pb_query: RelationSearchRequest,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = True,
-) -> tuple[list[RelationSearchResponse], bool, list[tuple[str, str, str]], list[str]]:
+) -> tuple[list[RelationSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]:
     ...


@@ -129,7 +130,7 @@ async def node_query(
     pb_query: REQUEST_TYPE,
     target_shard_replicas: Optional[list[str]] = None,
     use_read_replica_nodes: bool = True,
-) -> tuple[list[T], bool, list[tuple[str, str, str]], list[str]]:
+) -> tuple[list[T], bool, list[tuple[AbstractIndexNode, str]]]:
     use_read_replica_nodes = use_read_replica_nodes and has_feature(
         const.Features.READ_REPLICA_SEARCHES, context={"kbid": kbid}
     )
@@ -144,13 +145,12 @@ async def node_query(
     )

     ops = []
-    queried_shards = []
     queried_nodes = []
     incomplete_results = False

     for shard_obj in shard_groups:
         try:
-            node, shard_id, node_id = cluster_manager.choose_node(
+            node, shard_id = cluster_manager.choose_node(
                 shard_obj,
                 use_read_replica_nodes=use_read_replica_nodes,
                 target_shard_replicas=target_shard_replicas,
@@ -163,8 +163,7 @@ async def node_query(
         # let's add it ot the query list if has a valid value
         func = METHODS[method]
         ops.append(func(node, shard_id, pb_query))  # type: ignore
-        queried_nodes.append((node.label, shard_id, node_id))
-        queried_shards.append(shard_id)
+        queried_nodes.append((node, shard_id))

     if not ops:
         logger.warning(f"No node found for any of this resources shards {kbid}")
@@ -179,30 +178,39 @@ async def node_query(
             timeout=settings.search_timeout,
         )
     except asyncio.TimeoutError as exc:  # pragma: no cover
-        queried_nodes_details = []
-        for _, shard_id, node_id in queried_nodes:
-            queried_node = cluster_manager.get_index_node(node_id)
-            if queried_node is None:
-                node_address = "unknown"
-            else:
-                node_address = node.address
-            queried_nodes_details.append(
-                {
-                    "id": node_id,
-                    "shard_id": shard_id,
-                    "address": node_address,
-                }
-            )
         logger.warning(
-            "Timeout while querying nodes", extra={"nodes": queried_nodes_details}
+            "Timeout while querying nodes",
+            extra={"nodes": debug_nodes_info(queried_nodes)},
         )
         results = [exc]

     error = validate_node_query_results(results or [])
     if error is not None:
+        if (
+            error.status_code >= 500
+            and use_read_replica_nodes
+            and any([node.is_read_replica() for node, _ in queried_nodes])
+        ):
+            # We had an error querying a secondary node, instead of raising an
+            # error directly, retry query to primaries and hope it works
+            logger.warning(
+                "Query to read replica failed. Trying again with primary",
+                extra={"nodes": debug_nodes_info(queried_nodes)},
+            )
+
+            results, incomplete_results, primary_queried_nodes = await node_query(  # type: ignore
+                kbid,
+                method,
+                pb_query,
+                target_shard_replicas,
+                use_read_replica_nodes=False,
+            )
+            queried_nodes.extend(primary_queried_nodes)
+            return results, incomplete_results, queried_nodes
+
         raise error

-    return results, incomplete_results, queried_nodes, queried_shards
+    return results, incomplete_results, queried_nodes


 def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
@@ -241,3 +249,19 @@ def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
         return HTTPException(status_code=status_code, detail=reason)

     return None
+
+
+def debug_nodes_info(
+    nodes: list[tuple[AbstractIndexNode, str]]
+) -> list[dict[str, str]]:
+    details: list[dict[str, str]] = []
+    for node, shard_id in nodes:
+        info = {
+            "id": node.id,
+            "shard_id": shard_id,
+            "address": node.address,
+        }
+        if node.primary_id:
+            info["primary_id"] = node.primary_id
+        details.append(info)
+    return details
diff --git a/nucliadb/nucliadb/search/search/chat/query.py b/nucliadb/nucliadb/search/search/chat/query.py
index 3a7fa6f882..7feabd62e3 100644
--- a/nucliadb/nucliadb/search/search/chat/query.py
+++ b/nucliadb/nucliadb/search/search/chat/query.py
@@ -183,7 +183,6 @@ async def get_relations_results(
             relations_results,
             _,
             _,
-            _,
         ) = await node_query(
             kbid,
             Method.RELATIONS,
diff --git a/nucliadb/nucliadb/search/search/find.py b/nucliadb/nucliadb/search/search/find.py
index 6fa7ccce60..0ef4fa4302 100644
--- a/nucliadb/nucliadb/search/search/find.py
+++ b/nucliadb/nucliadb/search/search/find.py
@@ -19,7 +19,7 @@
 #
 from time import time

-from nucliadb.search.requesters.utils import Method, node_query
+from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search.find_merge import find_merge_results
 from nucliadb.search.search.query import QueryParser
 from nucliadb.search.search.utils import should_disable_vector_search
@@ -70,7 +70,7 @@ async def find(
         security=item.security,
     )
     pb_query, incomplete_results, autofilters = await query_parser.parse()
-    results, query_incomplete_results, queried_nodes, queried_shards = await node_query(
+    results, query_incomplete_results, queried_nodes = await node_query(
         kbid, Method.SEARCH, pb_query, target_shard_replicas=item.shards
     )
     incomplete_results = incomplete_results or query_incomplete_results
@@ -100,8 +100,9 @@ async def find(
         len(search_results.resources),
     )
     if item.debug:
-        search_results.nodes = queried_nodes
+        search_results.nodes = debug_nodes_info(queried_nodes)
+        queried_shards = [shard_id for _, shard_id in queried_nodes]
         search_results.shards = queried_shards

     search_results.autofilters = autofilters
     return search_results, incomplete_results
diff --git a/nucliadb/nucliadb/search/tests/integration/requesters/test_utils.py b/nucliadb/nucliadb/search/tests/integration/requesters/test_utils.py
index 7fb7ece760..b4a15ac341 100644
--- a/nucliadb/nucliadb/search/tests/integration/requesters/test_utils.py
+++ b/nucliadb/nucliadb/search/tests/integration/requesters/test_utils.py
@@ -50,6 +50,6 @@ async def test_vector_result_metadata(
         ),
     ).parse()

-    results, _, _, _ = await node_query(kbid, Method.SEARCH, pb_query)
+    results, _, _ = await node_query(kbid, Method.SEARCH, pb_query)
     assert len(results[0].vector.documents) > 0
     assert results[0].vector.documents[0].HasField("metadata")
diff --git a/nucliadb/nucliadb/search/tests/unit/search/requesters/test_utils.py b/nucliadb/nucliadb/search/tests/unit/search/requesters/test_utils.py
index d1d1056e50..186da2bbb3 100644
--- a/nucliadb/nucliadb/search/tests/unit/search/requesters/test_utils.py
+++ b/nucliadb/nucliadb/search/tests/unit/search/requesters/test_utils.py
@@ -17,13 +17,146 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock

+import pytest
 from fastapi import HTTPException
 from grpc import StatusCode
 from grpc.aio import AioRpcError  # type: ignore

+from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb.search.requesters import utils
+from nucliadb_protos import nodereader_pb2, writer_pb2
+from nucliadb_utils.utilities import Utility, clean_utility, get_utility, set_utility
+
+
+@pytest.fixture
+def fake_nodes():
+    from nucliadb.common.cluster import manager
+
+    original = manager.INDEX_NODES
+    manager.INDEX_NODES.clear()
+
+    manager.add_index_node(
+        id="node-0",
+        address="nohost",
+        shard_count=0,
+        dummy=True,
+    )
+    manager.add_index_node(
+        id="node-replica-0",
+        address="nohost",
+        shard_count=0,
+        dummy=True,
+        primary_id="node-0",
+    )
+
+    yield (["node-0"], ["node-replica-0"])
+
+    manager.INDEX_NODES = original
+
+
+@pytest.fixture
+def shard_manager():
+    original = get_utility(Utility.SHARD_MANAGER)
+
+    manager = AsyncMock()
+    manager.get_shards_by_kbid = AsyncMock(
+        return_value=[
+            writer_pb2.ShardObject(
+                shard="shard-id",
+                replicas=[
+                    writer_pb2.ShardReplica(
+                        shard=writer_pb2.ShardCreated(id="shard-id"), node="node-0"
+                    )
+                ],
+            )
+        ]
+    )
+
+    set_utility(Utility.SHARD_MANAGER, manager)
+
+    yield manager
+
+    if original is None:
+        clean_utility(Utility.SHARD_MANAGER)
+    else:
+        set_utility(Utility.SHARD_MANAGER, original)
+
+
+@pytest.fixture()
+def search_methods():
+    def fake_search(
+        node: AbstractIndexNode, shard: str, query: nodereader_pb2.SearchRequest
+    ):
+        if node.is_read_replica():
+            raise Exception()
+        return nodereader_pb2.SearchResponse()
+
+    original = utils.METHODS
+    utils.METHODS = {
+        utils.Method.SEARCH: AsyncMock(side_effect=fake_search),
+        utils.Method.PARAGRAPH: AsyncMock(),
+    }
+
+    yield utils.METHODS
+
+    utils.METHODS = original
+
+
+@pytest.mark.asyncio
+async def test_node_query_retries_primary_if_secondary_fails(
+    fake_nodes,
+    shard_manager,
+    search_methods,
+):
+    """Setting up a node and a faulty replica, validate primary is queried if
+    secondary fails.
+
+    """
+    results, incomplete_results, queried_nodes = await utils.node_query(
+        kbid="my-kbid",
+        method=utils.Method.SEARCH,
+        pb_query=Mock(),
+        use_read_replica_nodes=True,
+    )
+    # secondary fails, primary is called
+    assert search_methods[utils.Method.SEARCH].await_count == 2
+    assert len(queried_nodes) == 2
+    assert queried_nodes[0][0].is_read_replica()
+    assert not queried_nodes[1][0].is_read_replica()
+
+    results, incomplete_results, queried_nodes = await utils.node_query(
+        kbid="my-kbid",
+        method=utils.Method.PARAGRAPH,
+        pb_query=Mock(),
+        use_read_replica_nodes=True,
+    )
+    # secondary succeeds, no fallback call to primary
+    assert search_methods[utils.Method.PARAGRAPH].await_count == 1
+    assert len(queried_nodes) == 1
+    assert queried_nodes[0][0].is_read_replica()
+
+
+def test_debug_nodes_info(fake_nodes: tuple[list[str], list[str]]):
+    from nucliadb.common.cluster import manager
+
+    primary = manager.get_index_node(fake_nodes[0][0])
+    assert primary is not None
+    secondary = manager.get_index_node(fake_nodes[1][0])
+    assert secondary is not None
+
+    info = utils.debug_nodes_info([(primary, "shard-a"), (secondary, "shard-b")])
+    assert len(info) == 2
+
+    primary_keys = ["id", "shard_id", "address"]
+    secondary_keys = primary_keys + ["primary_id"]
+
+    for key in primary_keys:
+        assert key in info[0]
+
+    for key in secondary_keys:
+        assert key in info[1]


 def test_validate_node_query_results():
diff --git a/nucliadb/nucliadb/search/tests/unit/search/test_chat_prompt.py b/nucliadb/nucliadb/search/tests/unit/search/test_chat_prompt.py
index b27f5ef44f..2f38dc1546 100644
--- a/nucliadb/nucliadb/search/tests/unit/search/test_chat_prompt.py
+++ b/nucliadb/nucliadb/search/tests/unit/search/test_chat_prompt.py
@@ -32,8 +32,6 @@
 )
 from nucliadb_protos import resources_pb2

-pytestmark = pytest.mark.asyncio
-

 @pytest.fixture()
 def messages():
@@ -79,6 +77,7 @@ def kb(field_obj):
     yield mock


+@pytest.mark.asyncio
 async def test_get_next_conversation_messages(field_obj, messages):
     assert (
         len(
@@ -107,18 +106,21 @@ async def test_get_next_conversation_messages(field_obj, messages):
     ) == [messages[3]]


+@pytest.mark.asyncio
 async def test_find_conversation_message(field_obj, messages):
     assert await chat_prompt.find_conversation_message(
         field_obj=field_obj, mident="3"
     ) == (messages[2], 1, 2)


+@pytest.mark.asyncio
 async def test_get_expanded_conversation_messages(kb, messages):
     assert await chat_prompt.get_expanded_conversation_messages(
         kb=kb, rid="rid", field_id="field_id", mident="3"
     ) == [messages[3]]


+@pytest.mark.asyncio
 async def test_get_expanded_conversation_messages_question(kb, messages):
     assert (
         await chat_prompt.get_expanded_conversation_messages(
@@ -133,6 +135,7 @@ async def test_get_expanded_conversation_messages_question(kb, messages):
     )


+@pytest.mark.asyncio
 async def test_get_expanded_conversation_messages_missing(kb, messages):
     assert (
         await chat_prompt.get_expanded_conversation_messages(
@@ -163,6 +166,7 @@ def _create_find_result(
     )


+@pytest.mark.asyncio
 async def test_default_prompt_context(kb):
     result_text = " ".join(["text"] * 10)
     with patch("nucliadb.search.search.chat.prompt.get_read_only_transaction"), patch(
@@ -215,6 +219,7 @@ def find_results():
     )


+@pytest.mark.asyncio
 async def test_prompt_context_builder_prepends_user_context(
     find_results: KnowledgeboxFindResults,
 ):
diff --git a/nucliadb/nucliadb/tests/integration/common/cluster/test_manager.py b/nucliadb/nucliadb/tests/integration/common/cluster/test_manager.py
index 5e462a7a8d..1f1b726eda 100644
--- a/nucliadb/nucliadb/tests/integration/common/cluster/test_manager.py
+++ b/nucliadb/nucliadb/tests/integration/common/cluster/test_manager.py
@@ -140,8 +140,8 @@ async def test_choose_node(shards, shard_index: int, nodes: set):
     shard = shards.shards[shard_index]
     node_ids = set()
     for i in range(100):
-        _, _, node_id = manager.choose_node(shard)
-        node_ids.add(node_id)
+        node, _ = manager.choose_node(shard)
+        node_ids.add(node.id)

     assert node_ids == nodes, "Random numbers have defeat this test"

@@ -152,16 +152,16 @@ async def test_choose_node_attempts_target_replicas_but_is_not_imperative(shards
     r1 = shard.replicas[1].shard.id
     n1 = shard.replicas[1].node

-    _, replica_id, node_id = manager.choose_node(shard, target_shard_replicas=[r0])
+    node, replica_id = manager.choose_node(shard, target_shard_replicas=[r0])
     assert replica_id == r0
-    assert node_id == n0
+    assert node.id == n0

     # Change the node-0 to a non-existent node id in order to
     # test the target_shard_replicas logic is not imperative
     shard.replicas[0].node = "I-do-not-exist"
-    _, replica_id, node_id = manager.choose_node(shard, target_shard_replicas=[r0])
+    node, replica_id = manager.choose_node(shard, target_shard_replicas=[r0])
     assert replica_id == r1
-    assert node_id == n1
+    assert node.id == n1


 async def test_choose_node_raises_if_no_nodes(shards):
@@ -182,8 +182,8 @@ async def test_apply_for_all_shards(fake_kbid: str, shards, maindb_driver: Drive

     nodes = []

-    async def fun(node: AbstractIndexNode, shard_id: str, node_id: str):
-        nodes.append((shard_id, node_id))
+    async def fun(node: AbstractIndexNode, shard_id: str):
+        nodes.append((shard_id, node.id))

     await shard_manager.apply_for_all_shards(kbid, fun, timeout=10)

diff --git a/nucliadb/nucliadb/tests/unit/common/cluster/test_manager.py b/nucliadb/nucliadb/tests/unit/common/cluster/test_manager.py
index 7bb4e4dc22..6d268eff75 100644
--- a/nucliadb/nucliadb/tests/unit/common/cluster/test_manager.py
+++ b/nucliadb/nucliadb/tests/unit/common/cluster/test_manager.py
@@ -111,7 +111,7 @@ def test_choose_node_with_two_primary_nodes():
     add_index_node("node-0")
     add_index_node("node-1")

-    _, _, node_id = manager.choose_node(
+    node, _ = manager.choose_node(
         writer_pb2.ShardObject(
             replicas=[
                 writer_pb2.ShardReplica(
@@ -120,8 +120,8 @@ def test_choose_node_with_two_primary_nodes():
             ]
         )
     )
-    assert node_id == "node-0"
-    _, _, node_id = manager.choose_node(
+    assert node.id == "node-0"
+    node, _ = manager.choose_node(
         writer_pb2.ShardObject(
             replicas=[
                 writer_pb2.ShardReplica(
@@ -130,7 +130,7 @@ def test_choose_node_with_two_primary_nodes():
             ]
         )
     )
-    assert node_id == "node-1"
+    assert node.id == "node-1"

     manager.INDEX_NODES.clear()

@@ -144,7 +144,7 @@ def test_choose_node_with_two_read_replicas():
     add_read_replica_node("node-replica-0", primary_id="node-0")
     add_read_replica_node("node-replica-1", primary_id="node-1")

-    _, _, node_id = manager.choose_node(
+    node, _ = manager.choose_node(
         writer_pb2.ShardObject(
             replicas=[
                 writer_pb2.ShardReplica(
@@ -154,8 +154,8 @@ def test_choose_node_with_two_read_replicas():
         ),
         use_read_replica_nodes=True,
     )
-    assert node_id == "node-replica-0"
-    _, _, node_id = manager.choose_node(
+    assert node.id == "node-replica-0"
+    node, _ = manager.choose_node(
         writer_pb2.ShardObject(
             replicas=[
                 writer_pb2.ShardReplica(
@@ -165,7 +165,7 @@ def test_choose_node_with_two_read_replicas():
         ),
         use_read_replica_nodes=True,
     )
-    assert node_id == "node-replica-1"
+    assert node.id == "node-replica-1"

     manager.INDEX_NODES.clear()

@@ -201,9 +201,9 @@ def repeated_choose_node(
     node_ids = []

     for _ in range(count):
-        _, shard_id, node_id = manager.choose_node(shard, **kwargs)
+        node, shard_id = manager.choose_node(shard, **kwargs)
         shard_ids.append(shard_id)
-        node_ids.append(node_id)
+        node_ids.append(node.id)

     return shard_ids, node_ids

diff --git a/nucliadb/nucliadb/train/nodes.py b/nucliadb/nucliadb/train/nodes.py
index f98b4b5bca..33871cfdff 100644
--- a/nucliadb/nucliadb/train/nodes.py
+++ b/nucliadb/nucliadb/train/nodes.py
@@ -55,7 +55,7 @@ async def get_reader(self, kbid: str, shard: str) -> tuple[AbstractIndexNode, st
         except StopIteration:
             raise KeyError("Shard not found")

-        node_obj, shard_id, _ = manager.choose_node(shard_object)
+        node_obj, shard_id = manager.choose_node(shard_object)
         return node_obj, shard_id

     async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
diff --git a/nucliadb_models/nucliadb_models/search.py b/nucliadb_models/nucliadb_models/search.py
index eb4dffae39..8463e9bbdc 100644
--- a/nucliadb_models/nucliadb_models/search.py
+++ b/nucliadb_models/nucliadb_models/search.py
@@ -20,7 +20,7 @@
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, List, Optional, Type, TypeVar, Union

 from google.protobuf.json_format import MessageToDict
 from nucliadb_protos.audit_pb2 import ClientType
@@ -244,7 +244,7 @@ class ResourceSearchResults(JsonBaseModel):
     sentences: Optional[Sentences] = None
     paragraphs: Optional[Paragraphs] = None
     relations: Optional[Relations] = None
-    nodes: Optional[List[Tuple[str, str, str]]] = None
+    nodes: Optional[List[Dict[str, str]]] = None
     shards: Optional[List[str]] = None


@@ -256,7 +256,7 @@ class KnowledgeboxSearchResults(JsonBaseModel):
     paragraphs: Optional[Paragraphs] = None
     fulltext: Optional[Resources] = None
     relations: Optional[Relations] = None
-    nodes: Optional[List[Tuple[str, str, str]]] = None
+    nodes: Optional[List[Dict[str, str]]] = None
     shards: Optional[List[str]] = None
     autofilters: List[str] = ModelParamDefaults.applied_autofilters.to_pydantic_field()

@@ -779,7 +779,8 @@ def fields_validator(cls, values):
                     ]
                 )
                 raise ValueError(
-                    f"Field '{field}' does not have a valid field type. Valid field types are: {allowed_field_types_part}."
+                    f"Field '{field}' does not have a valid field type. "
+                    f"Valid field types are: {allowed_field_types_part}."
                 )

         return values
@@ -1009,7 +1010,7 @@ class KnowledgeboxFindResults(JsonBaseModel):
     page_number: int = 0
     page_size: int = 20
     next_page: bool = False
-    nodes: Optional[List[Tuple[str, str, str]]] = None
+    nodes: Optional[List[Dict[str, str]]] = None
     shards: Optional[List[str]] = None
     autofilters: List[str] = ModelParamDefaults.applied_autofilters.to_pydantic_field()
     min_score: float = ModelParamDefaults.min_score.to_pydantic_field()