Skip to content

Commit

Permalink
Remove *Search gRPC methods and cleanup search (#2626)
Browse files Browse the repository at this point in the history
* Convert ParagraphSearchRequest to SearchRequest

* Return appropriate paragraph response

* No more RelationSearch either

* Not ready to replace paragraph search request

* Don't call RelationSearch anymore (py)

* Remove RelationSearch

* Let's try again with paragraph search :)

* Typo in type

* Replace last direct call to ParagraphSearch (tests)

* No more ParagraphSearch (grpc)

* No more DocumentSearch

* No more VectorSearch

* Remove test using document search

* Update protos

* Fix test

* Cleanup query_paragraph_shard

* Remove ignored params from resource search endpoint

* Fix test

* Fix unit tests

* Cleanup relations_shard

* Remove paragraph and relation search methods
  • Loading branch information
jotare authored Nov 13, 2024
1 parent 785e912 commit ab7e349
Show file tree
Hide file tree
Showing 25 changed files with 194 additions and 986 deletions.
4 changes: 0 additions & 4 deletions nidx/nidx_protos/src/nidx.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ service NidxApi {
}

service NidxSearcher {
rpc DocumentSearch(nodereader.DocumentSearchRequest) returns (nodereader.DocumentSearchResponse) {}
rpc ParagraphSearch(nodereader.ParagraphSearchRequest) returns (nodereader.ParagraphSearchResponse) {}
rpc VectorSearch(nodereader.VectorSearchRequest) returns (nodereader.VectorSearchResponse) {}
rpc RelationSearch(nodereader.RelationSearchRequest) returns (nodereader.RelationSearchResponse) {}
rpc DocumentIds(noderesources.ShardId) returns (nodereader.IdCollection) {}
rpc ParagraphIds(noderesources.ShardId) returns (nodereader.IdCollection) {}
rpc VectorIds(noderesources.VectorSetID) returns (nodereader.IdCollection) {}
Expand Down
25 changes: 0 additions & 25 deletions nidx/src/searcher/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,6 @@ impl SearchServer {

#[tonic::async_trait]
impl NidxSearcher for SearchServer {
async fn document_search(
&self,
_request: Request<DocumentSearchRequest>,
) -> Result<Response<DocumentSearchResponse>> {
todo!()
}

async fn paragraph_search(
&self,
_request: Request<ParagraphSearchRequest>,
) -> Result<Response<ParagraphSearchResponse>> {
todo!()
}

async fn vector_search(&self, _request: Request<VectorSearchRequest>) -> Result<Response<VectorSearchResponse>> {
todo!()
}

async fn relation_search(
&self,
_request: Request<RelationSearchRequest>,
) -> Result<Response<RelationSearchResponse>> {
todo!()
}

async fn document_ids(&self, _request: Request<noderesources::ShardId>) -> Result<Response<IdCollection>> {
todo!()
}
Expand Down
6 changes: 0 additions & 6 deletions nucliadb/src/nucliadb/common/cluster/grpc_node_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from nucliadb_protos.nodereader_pb2 import (
EdgeList,
RelationEdge,
RelationSearchResponse,
)
from nucliadb_protos.noderesources_pb2 import (
EmptyResponse,
Expand Down Expand Up @@ -90,11 +89,6 @@ async def GetShard(self, data): # pragma: no cover
self.calls.setdefault("GetShard", []).append(data)
return NodeResourcesShard(shard_id="shard", fields=2, paragraphs=2, sentences=2)

async def RelationSearch(self, data): # pragma: no cover
self.calls.setdefault("RelationSearch", []).append(data)
result = RelationSearchResponse()
return result

async def RelationEdges(self, data): # pragma: no cover
self.calls.setdefault("RelationEdges", []).append(data)
result = EdgeList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,6 @@ async def Search(self, request: SearchRequest, retry: bool = False) -> SearchRes
else:
raise

async def ParagraphSearch(self, request: ParagraphSearchRequest) -> ParagraphSearchResponse:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
self.executor, self.reader.paragraph_search, request.SerializeToString()
)
pb_bytes = bytes(result)
pb = ParagraphSearchResponse()
pb.ParseFromString(pb_bytes)
return pb

async def RelationSearch(self, request: RelationSearchRequest) -> RelationSearchResponse:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
self.executor, self.reader.relation_search, request.SerializeToString()
)
pb_bytes = bytes(result)
pb = RelationSearchResponse()
pb.ParseFromString(pb_bytes)
return pb

async def GetShard(self, request: GetShardRequest) -> NodeResourcesShard:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
Expand Down
10 changes: 5 additions & 5 deletions nucliadb/src/nucliadb/ingest/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
Faceted,
RelationNodeFilter,
RelationPrefixSearchRequest,
RelationSearchRequest,
RelationSearchResponse,
SearchRequest,
SearchResponse,
Expand Down Expand Up @@ -208,16 +207,17 @@ async def get_indexed_entities_group(self, group: str) -> Optional[EntitiesGroup
shard_manager = get_shard_manager()

async def do_entities_search(node: AbstractIndexNode, shard_id: str) -> RelationSearchResponse:
request = RelationSearchRequest(
shard_id=shard_id,
prefix=RelationPrefixSearchRequest(
request = SearchRequest(
shard=shard_id,
relation_prefix=RelationPrefixSearchRequest(
prefix="",
node_filters=[
RelationNodeFilter(node_type=RelationNode.NodeType.ENTITY, node_subtype=group)
],
),
)
return await node.reader.RelationSearch(request) # type: ignore
response = await node.reader.Search(request) # type: ignore
return response.relation

results = await shard_manager.apply_for_all_shards(
self.kbid,
Expand Down
18 changes: 2 additions & 16 deletions nucliadb/src/nucliadb/search/api/v1/resource/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,10 @@
from nucliadb.search.search.exceptions import InvalidQueryError
from nucliadb.search.search.merge import merge_paragraphs_results
from nucliadb.search.search.query import paragraph_query_to_pb
from nucliadb_models.common import FieldTypeName
from nucliadb_models.resource import ExtractedDataTypeName, NucliaDBRoles
from nucliadb_models.resource import NucliaDBRoles
from nucliadb_models.search import (
NucliaDBClientType,
ResourceProperties,
ResourceSearchResults,
SearchOptions,
SearchParamDefaults,
SortField,
SortOrder,
Expand Down Expand Up @@ -79,13 +76,6 @@ async def resource_search(
SearchParamDefaults.range_modification_end
),
highlight: bool = fastapi_query(SearchParamDefaults.highlight),
show: list[ResourceProperties] = fastapi_query(
SearchParamDefaults.show, default=list(ResourceProperties)
),
field_type_filter: list[FieldTypeName] = fastapi_query(
SearchParamDefaults.field_type_filter, alias="field_type"
),
extracted: list[ExtractedDataTypeName] = fastapi_query(SearchParamDefaults.extracted),
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
debug: bool = fastapi_query(SearchParamDefaults.debug),
shards: list[str] = fastapi_query(SearchParamDefaults.shards),
Expand All @@ -97,7 +87,6 @@ async def resource_search(
try:
pb_query = await paragraph_query_to_pb(
kbid,
[SearchOptions.KEYWORD],
rid,
query,
fields,
Expand All @@ -116,7 +105,7 @@ async def resource_search(
return HTTPClientError(status_code=412, detail=str(exc))

results, incomplete_results, queried_nodes = await node_query(
kbid, Method.PARAGRAPH, pb_query, shards
kbid, Method.SEARCH, pb_query, shards
)

# We need to merge
Expand All @@ -125,9 +114,6 @@ async def resource_search(
count=page_size,
page=page_number,
kbid=kbid,
show=show,
field_type_filter=field_type_filter,
extracted=extracted,
highlight_split=highlight,
min_score=0.0,
)
Expand Down
44 changes: 4 additions & 40 deletions nucliadb/src/nucliadb/search/requesters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import asyncio
import json
from enum import Enum
from enum import Enum, auto
from typing import Any, Optional, Sequence, TypeVar, Union, overload

from fastapi import HTTPException
Expand All @@ -34,17 +34,11 @@
from nucliadb.common.nidx import get_nidx_fake_node
from nucliadb.search import logger
from nucliadb.search.search.shards import (
query_paragraph_shard,
query_shard,
relations_shard,
suggest_shard,
)
from nucliadb.search.settings import settings
from nucliadb_protos.nodereader_pb2 import (
ParagraphSearchRequest,
ParagraphSearchResponse,
RelationSearchRequest,
RelationSearchResponse,
SearchRequest,
SearchResponse,
SuggestRequest,
Expand All @@ -57,27 +51,21 @@


class Method(Enum):
SEARCH = 1
PARAGRAPH = 2
SUGGEST = 3
RELATIONS = 4
SEARCH = auto()
SUGGEST = auto()


METHODS = {
Method.SEARCH: query_shard,
Method.PARAGRAPH: query_paragraph_shard,
Method.SUGGEST: suggest_shard,
Method.RELATIONS: relations_shard,
}

REQUEST_TYPE = Union[SuggestRequest, ParagraphSearchRequest, SearchRequest, RelationSearchRequest]
REQUEST_TYPE = Union[SuggestRequest, SearchRequest]

T = TypeVar(
"T",
SuggestResponse,
ParagraphSearchResponse,
SearchResponse,
RelationSearchResponse,
)


Expand All @@ -93,18 +81,6 @@ async def node_query(
) -> tuple[list[SuggestResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...


@overload
async def node_query(
kbid: str,
method: Method,
pb_query: ParagraphSearchRequest,
target_shard_replicas: Optional[list[str]] = None,
use_read_replica_nodes: bool = True,
timeout: Optional[float] = None,
retry_on_primary: bool = True,
) -> tuple[list[ParagraphSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...


@overload
async def node_query(
kbid: str,
Expand All @@ -117,18 +93,6 @@ async def node_query(
) -> tuple[list[SearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...


@overload
async def node_query(
kbid: str,
method: Method,
pb_query: RelationSearchRequest,
target_shard_replicas: Optional[list[str]] = None,
use_read_replica_nodes: bool = True,
timeout: Optional[float] = None,
retry_on_primary: bool = True,
) -> tuple[list[RelationSearchResponse], bool, list[tuple[AbstractIndexNode, str]]]: ...


async def node_query(
kbid: str,
method: Method,
Expand Down
19 changes: 10 additions & 9 deletions nucliadb/src/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
parse_rephrase_prompt,
)
from nucliadb_protos import audit_pb2
from nucliadb_protos.nodereader_pb2 import RelationSearchRequest, RelationSearchResponse
from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
from nucliadb_telemetry.errors import capture_exception
from nucliadb_utils.utilities import get_audit

Expand Down Expand Up @@ -214,25 +214,26 @@ async def get_relations_results(
try:
predict = get_predict()
detected_entities = await predict.detect_entities(kbid, text_answer)
relation_request = RelationSearchRequest()
relation_request.subgraph.entry_points.extend(detected_entities)
relation_request.subgraph.depth = 1
request = SearchRequest()
request.relation_subgraph.entry_points.extend(detected_entities)
request.relation_subgraph.depth = 1

relations_results: list[RelationSearchResponse]
results: list[SearchResponse]
(
relations_results,
results,
_,
_,
) = await node_query(
kbid,
Method.RELATIONS,
relation_request,
Method.SEARCH,
request,
target_shard_replicas=target_shard_replicas,
timeout=timeout,
use_read_replica_nodes=True,
retry_on_primary=False,
)
return await merge_relations_results(relations_results, relation_request.subgraph)
relations_results: list[RelationSearchResponse] = [result.relation for result in results]
return await merge_relations_results(relations_results, request.relation_subgraph)
except Exception as exc:
capture_exception(exc)
logger.exception("Error getting relations results")
Expand Down
11 changes: 4 additions & 7 deletions nucliadb/src/nucliadb/search/search/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ async def merge_paragraph_results(
highlight: bool,
sort: SortOptions,
min_score: float,
):
) -> Paragraphs:
raw_paragraph_list: list[tuple[ParagraphResult, SortValue]] = []
facets: dict[str, Any] = {}
query = None
Expand Down Expand Up @@ -545,19 +545,16 @@ async def merge_results(


async def merge_paragraphs_results(
paragraph_responses: list[ParagraphSearchResponse],
responses: list[SearchResponse],
count: int,
page: int,
kbid: str,
show: list[ResourceProperties],
field_type_filter: list[FieldTypeName],
extracted: list[ExtractedDataTypeName],
highlight_split: bool,
min_score: float,
) -> ResourceSearchResults:
paragraphs = []
for result in paragraph_responses:
paragraphs.append(result)
for result in responses:
paragraphs.append(result.paragraph)

api_results = ResourceSearchResults()

Expand Down
Loading

0 comments on commit ab7e349

Please sign in to comment.