diff --git a/nucliadb/tests/nucliadb/integration/search/test_filters.py b/nucliadb/tests/nucliadb/integration/search/test_filters.py index 01ef6b5e86..a9ffa3afde 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_filters.py +++ b/nucliadb/tests/nucliadb/integration/search/test_filters.py @@ -24,7 +24,7 @@ from nucliadb.common.context import ApplicationContext from nucliadb.tests.vectors import V1, V2, Q from nucliadb_models.labels import Label, LabelSetKind -from nucliadb_models.search import MinScore +from nucliadb_models.search import MinScore, SearchOptions from nucliadb_protos.resources_pb2 import ( Classification, ExtractedTextWrapper, @@ -351,7 +351,7 @@ async def _test_filtering(nucliadb_reader: AsyncClient, kbid: str, filters): json=dict( query="", filters=filters, - features=["paragraph", "vector"], + features=[SearchOptions.KEYWORD, SearchOptions.SEMANTIC], vector=Q, min_score=MinScore(semantic=-1).model_dump(), ), diff --git a/nucliadb/tests/nucliadb/integration/search/test_search.py b/nucliadb/tests/nucliadb/integration/search/test_search.py index 0a6af55e8f..e6320122a8 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_search.py +++ b/nucliadb/tests/nucliadb/integration/search/test_search.py @@ -35,6 +35,7 @@ from nucliadb.ingest.consumer import shard_creator from nucliadb.search.predict import SendToPredictError from nucliadb.tests.vectors import V1 +from nucliadb_models.search import SearchOptions from nucliadb_protos import resources_pb2 as rpb from nucliadb_protos.audit_pb2 import AuditRequest, ClientType from nucliadb_protos.utils_pb2 import RelationNode @@ -973,8 +974,8 @@ async def test_search_pagination( page_size = 5 for feature, result_key in [ - ("paragraph", "paragraphs"), - ("document", "fulltext"), + (SearchOptions.KEYWORD.value, "paragraphs"), + (SearchOptions.FULLTEXT.value, "fulltext"), ]: total_pages = math.floor(total / page_size) for page_number in range(0, total_pages): @@ -1069,7 +1070,7 @@ async def 
test_resource_search_pagination( f"/kb/{kbid}/resource/{rid}/search", params={ "query": query, - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], "page_number": page_number, "page_size": page_size, }, @@ -1083,7 +1084,7 @@ async def test_resource_search_pagination( f"/kb/{kbid}/resource/{rid}/search", params={ "query": query, - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], "page_number": page_number + 1, "page_size": page_size, }, @@ -1109,7 +1110,7 @@ async def test_search_endpoints_handle_predict_errors( resp = await nucliadb_reader.post( f"/kb/{kbid}/{endpoint}", json={ - "features": ["vector"], + "features": [SearchOptions.SEMANTIC], "query": "something", }, ) diff --git a/nucliadb/tests/nucliadb/integration/search/test_search_date_ranges_filter.py b/nucliadb/tests/nucliadb/integration/search/test_search_date_ranges_filter.py index 22a540d044..e96d1f0f6d 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_search_date_ranges_filter.py +++ b/nucliadb/tests/nucliadb/integration/search/test_search_date_ranges_filter.py @@ -23,6 +23,7 @@ from httpx import AsyncClient from nucliadb.tests.vectors import V1 +from nucliadb_models.search import SearchOptions from tests.nucliadb.integration.search.test_search import get_resource_with_a_sentence from tests.utils import inject_message @@ -76,8 +77,8 @@ async def resource(nucliadb_grpc, knowledgebox): @pytest.mark.parametrize( "feature", [ - "paragraph", - "vector", + SearchOptions.KEYWORD, + SearchOptions.SEMANTIC, ], ) async def test_search_with_date_range_filters_nucliadb_dates( @@ -133,8 +134,8 @@ async def test_search_with_date_range_filters_nucliadb_dates( @pytest.mark.parametrize( "feature", [ - "paragraph", - "vector", + SearchOptions.KEYWORD, + SearchOptions.SEMANTIC, ], ) async def test_search_with_date_range_filters_origin_dates( @@ -188,7 +189,7 @@ async def _test_find_date_ranges( found, ): payload = {"query": "Ramon", "features": features} - if "vector" in features: 
+ if SearchOptions.SEMANTIC in features: payload["vector"] = V1 if creation_start is not None: payload["range_creation_start"] = creation_start.isoformat() diff --git a/nucliadb/tests/nucliadb/integration/search/test_search_sorting.py b/nucliadb/tests/nucliadb/integration/search/test_search_sorting.py index e660c4a357..e259b25856 100644 --- a/nucliadb/tests/nucliadb/integration/search/test_search_sorting.py +++ b/nucliadb/tests/nucliadb/integration/search/test_search_sorting.py @@ -22,6 +22,8 @@ import pytest from httpx import AsyncClient +from nucliadb_models.search import SearchOptions + @pytest.mark.asyncio async def test_search_sort_by_score( @@ -201,7 +203,7 @@ async def test_list_all_resources_by_creation_and_modification_dates_with_empty_ f"/kb/{kbid}/search", params={ "query": "", - "features": ["document"], + "features": [SearchOptions.FULLTEXT.value], "fields": ["a/title"], "page_number": page_number, "page_size": page_size, diff --git a/nucliadb/tests/nucliadb/integration/test_api.py b/nucliadb/tests/nucliadb/integration/test_api.py index 53c8652af1..01f8aa9ee5 100644 --- a/nucliadb/tests/nucliadb/integration/test_api.py +++ b/nucliadb/tests/nucliadb/integration/test_api.py @@ -30,6 +30,7 @@ ) from nucliadb_models import common, metadata from nucliadb_models.resource import Resource +from nucliadb_models.search import SearchOptions from nucliadb_protos import resources_pb2 as rpb from nucliadb_protos import writer_pb2 as wpb from nucliadb_protos.dataset_pb2 import TaskType, TrainSet @@ -920,7 +921,7 @@ async def test_pagination_limits( f"/kb/kbid/find", json={ "query": "foo", - "features": ["vector"], + "features": [SearchOptions.SEMANTIC], "page_size": 1000, }, ) @@ -933,7 +934,7 @@ async def test_pagination_limits( f"/kb/kbid/find", json={ "query": "foo", - "features": ["vector"], + "features": [SearchOptions.SEMANTIC], "page_number": 30, "page_size": 100, }, diff --git a/nucliadb/tests/nucliadb/integration/test_deletion.py 
b/nucliadb/tests/nucliadb/integration/test_deletion.py index a193ba2db2..c93905fe3d 100644 --- a/nucliadb/tests/nucliadb/integration/test_deletion.py +++ b/nucliadb/tests/nucliadb/integration/test_deletion.py @@ -24,6 +24,7 @@ import pytest from httpx import AsyncClient +from nucliadb_models.search import SearchOptions from nucliadb_protos.resources_pb2 import ( ExtractedTextWrapper, ExtractedVectorsWrapper, @@ -141,7 +142,7 @@ class FieldData: f"/kb/{knowledgebox}/find", json={ "query": "Original", - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], "min_score": {"bm25": 0.0}, }, timeout=None, @@ -155,7 +156,7 @@ class FieldData: f"/kb/{knowledgebox}/find", json={ "query": "Extracted", - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], }, timeout=None, ) @@ -233,7 +234,7 @@ class FieldData: f"/kb/{knowledgebox}/find", json={ "query": "Extracted", - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], "min_score": {"bm25": 0.0}, }, timeout=None, @@ -252,7 +253,7 @@ class FieldData: f"/kb/{knowledgebox}/find", json={ "query": "Modified", - "features": ["paragraph"], + "features": [SearchOptions.KEYWORD], }, timeout=None, ) diff --git a/nucliadb/tests/nucliadb/integration/test_find.py b/nucliadb/tests/nucliadb/integration/test_find.py index 1855ec0d86..8694ee0cb1 100644 --- a/nucliadb/tests/nucliadb/integration/test_find.py +++ b/nucliadb/tests/nucliadb/integration/test_find.py @@ -24,6 +24,7 @@ import pytest from httpx import AsyncClient +from nucliadb_models.search import SearchOptions from nucliadb_protos.writer_pb2_grpc import WriterStub from nucliadb_utils.exceptions import LimitsExceededError @@ -105,14 +106,14 @@ async def test_find_does_not_support_fulltext_search( knowledgebox, ): resp = await nucliadb_reader.get( - f"/kb/{knowledgebox}/find?query=title&features=document&features=paragraph", + f"/kb/{knowledgebox}/find?query=title&features=fulltext&features=keyword", ) assert resp.status_code == 422 assert 
"fulltext search not supported" in resp.json()["detail"][0]["msg"] resp = await nucliadb_reader.post( f"/kb/{knowledgebox}/find", - json={"query": "title", "features": ["document", "paragraph"]}, + json={"query": "title", "features": [SearchOptions.FULLTEXT, SearchOptions.KEYWORD]}, ) assert resp.status_code == 422 assert "fulltext search not supported" in resp.json()["detail"][0]["msg"] @@ -244,7 +245,7 @@ async def test_story_7286( f"/kb/{knowledgebox}/find", json={ "query": "title", - "features": ["paragraph", "vector", "relations"], + "features": [SearchOptions.KEYWORD, SearchOptions.SEMANTIC, SearchOptions.RELATIONS], "shards": [], "highlight": True, "autofilter": False, diff --git a/nucliadb/tests/nucliadb/integration/test_matryoshka_embeddings.py b/nucliadb/tests/nucliadb/integration/test_matryoshka_embeddings.py index c9bacd9252..2cd001a400 100644 --- a/nucliadb/tests/nucliadb/integration/test_matryoshka_embeddings.py +++ b/nucliadb/tests/nucliadb/integration/test_matryoshka_embeddings.py @@ -25,6 +25,7 @@ from nucliadb.common.maindb.driver import Driver from nucliadb.learning_proxy import LearningConfiguration +from nucliadb_models.search import SearchOptions from nucliadb_protos import knowledgebox_pb2, resources_pb2, utils_pb2, writer_pb2 from nucliadb_protos.writer_pb2_grpc import WriterStub from tests.utils import inject_message @@ -127,7 +128,7 @@ async def test_matryoshka_embeddings( f"/kb/{kbid}/search", params={ "query": "matryoshka", - "features": ["vector"], + "features": [SearchOptions.SEMANTIC.value], "min_score": 0.99999, "with_duplicates": True, }, diff --git a/nucliadb/tests/nucliadb/integration/test_synonyms.py b/nucliadb/tests/nucliadb/integration/test_synonyms.py index e25e80b07f..a56183f428 100644 --- a/nucliadb/tests/nucliadb/integration/test_synonyms.py +++ b/nucliadb/tests/nucliadb/integration/test_synonyms.py @@ -19,6 +19,8 @@ # import pytest +from nucliadb_models.search import SearchOptions + @pytest.mark.asyncio async def 
test_custom_synonyms_api( @@ -197,7 +199,7 @@ async def test_search_errors_if_vectors_or_relations_requested( resp = await nucliadb_reader.post( f"/kb/{kbid}/search", json=dict( - features=["paragraph", "vector", "relations"], + features=[SearchOptions.KEYWORD, SearchOptions.SEMANTIC, SearchOptions.RELATIONS], query="planet", with_synonyms=True, ), diff --git a/nucliadb_models/src/nucliadb_models/search.py b/nucliadb_models/src/nucliadb_models/search.py index c9f7c730f7..7af29e47c1 100644 --- a/nucliadb_models/src/nucliadb_models/search.py +++ b/nucliadb_models/src/nucliadb_models/search.py @@ -105,37 +105,12 @@ class SearchOptions(str, Enum): RELATIONS = "relations" SEMANTIC = "semantic" - # DEPRECATED: use keyword, fulltext and semantic instead - PARAGRAPH = "paragraph" - DOCUMENT = "document" - VECTOR = "vector" - - def normalized(self): - if self.value == SearchOptions.PARAGRAPH: - return SearchOptions.KEYWORD - elif self.value == SearchOptions.DOCUMENT: - return SearchOptions.FULLTEXT - elif self.value == SearchOptions.VECTOR: - return SearchOptions.SEMANTIC - return self - class ChatOptions(str, Enum): KEYWORD = "keyword" RELATIONS = "relations" SEMANTIC = "semantic" - # DEPRECATED: use keyword, and semantic instead - VECTORS = "vectors" - PARAGRAPHS = "paragraphs" - - def normalized(self): - if self.value == ChatOptions.PARAGRAPHS: - return ChatOptions.KEYWORD - elif self.value == ChatOptions.VECTORS: - return ChatOptions.SEMANTIC - return self - class SuggestOptions(str, Enum): PARAGRAPH = "paragraph" @@ -201,7 +176,7 @@ class Sentences(BaseModel): page_size: int = 20 min_score: float = Field( title="Minimum score", - description="Minimum similarity score used to filter vector index search. Results with a lower score have been ignored.", # noqa + description="Minimum similarity score used to filter vector index search. 
Results with a lower score have been ignored.", # noqa: E501 ) @@ -228,7 +203,7 @@ class Paragraphs(BaseModel): next_page: bool = False min_score: float = Field( title="Minimum score", - description="Minimum bm25 score used to filter bm25 index search. Results with a lower score have been ignored.", # noqa + description="Minimum bm25 score used to filter bm25 index search. Results with a lower score have been ignored.", # noqa: E501 ) @@ -250,7 +225,7 @@ class Resources(BaseModel): next_page: bool = False min_score: float = Field( title="Minimum score", - description="Minimum bm25 score used to filter bm25 index search. Results with a lower score have been ignored.", # noqa + description="Minimum bm25 score used to filter bm25 index search. Results with a lower score have been ignored.", # noqa: E501 ) @@ -560,7 +535,7 @@ class SearchParamDefaults: with_synonyms = ParamDefault( default=False, title="With custom synonyms", - description="Whether to return matches for custom knowledge box synonyms of the query terms. Note: only supported for `paragraph` and `document` search options.", # noqa: E501 + description="Whether to return matches for custom knowledge box synonyms of the query terms. Note: only supported for `keyword` and `fulltext` search options.", # noqa: E501 ) sort_order = ParamDefault( default=SortOrder.DESC, @@ -586,7 +561,7 @@ class SearchParamDefaults: search_features = ParamDefault( default=None, title="Search features", - description="List of search features to use. Each value corresponds to a lookup into on of the different indexes. `document`, `paragraph` and `vector` are deprecated, please use `fulltext`, `keyword` and `semantic` instead", # noqa + description="List of search features to use. 
Each value corresponds to a lookup into one of the different indexes", ) rank_fusion = ParamDefault( default=RankFusionName.LEGACY, @@ -601,7 +576,7 @@ debug = ParamDefault( default=False, title="Debug mode", - description="If set, the response will include some extra metadata for debugging purposes, like the list of queried nodes.", # noqa + description="If set, the response will include some extra metadata for debugging purposes, like the list of queried nodes.", # noqa: E501 ) show = ParamDefault( default=[ResourceProperties.BASIC], @@ -622,27 +597,27 @@ range_creation_start = ParamDefault( default=None, title="Resource creation range start", - description="Resources created before this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa + description="Resources created before this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa: E501 ) range_creation_end = ParamDefault( default=None, title="Resource creation range end", - description="Resources created after this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa + description="Resources created after this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa: E501 ) range_modification_start = ParamDefault( default=None, title="Resource modification range start", - description="Resources modified before this date will be filtered out of search results. 
Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa: E501 ) range_modification_end = ParamDefault( default=None, title="Resource modification range end", - description="Resources modified after this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa + description="Resources modified after this date will be filtered out of search results. Datetime are represented as a str in ISO 8601 format, like: 2008-09-15T15:53:00+05:00.", # noqa: E501 ) vector = ParamDefault( default=None, title="Search Vector", - description="The vector to perform the search with. If not provided, NucliaDB will use Nuclia Predict API to create the vector off from the query.", # noqa + description="The vector to perform the search with. If not provided, NucliaDB will use Nuclia Predict API to create the vector off from the query.", # noqa: E501 ) vectorset = ParamDefault( default=None, @@ -652,12 +627,12 @@ class SearchParamDefaults: chat_context = ParamDefault( default=None, title="Chat history", - description="Use to rephrase the new LLM query by taking into account the chat conversation history", # noqa + description="Use to rephrase the new LLM query by taking into account the chat conversation history", # noqa: E501 ) chat_features = ParamDefault( default=[ChatOptions.SEMANTIC, ChatOptions.KEYWORD], title="Chat features", - description="Features enabled for the chat endpoint. Semantic search is done if `semantic` (or `vectors`) is included. If `keyword` (or `paragraphs`) is included, the results will include matching paragraphs from the bm25 index. If `relations` is included, a graph of entities related to the answer is returned. `paragraphs` and `vectors` are deprecated, please use `keyword` and `semantic` instead", # noqa + description="Features enabled for the chat endpoint. Semantic search is done if `semantic` is included. 
If `keyword` is included, the results will include matching paragraphs from the bm25 index. If `relations` is included, a graph of entities related to the answer is returned. `paragraphs` and `vectors` are deprecated, please use `keyword` and `semantic` instead", # noqa: E501 ) suggest_features = ParamDefault( default=[ @@ -670,17 +645,17 @@ class SearchParamDefaults: security = ParamDefault( default=None, title="Security", - description="Security metadata for the request. If not provided, the search request is done without the security lookup phase.", # noqa + description="Security metadata for the request. If not provided, the search request is done without the security lookup phase.", # noqa: E501 ) security_groups = ParamDefault( default=[], title="Security groups", - description="List of security groups to filter search results for. Only resources matching the query and containing the specified security groups will be returned. If empty, all resources will be considered for the search.", # noqa + description="List of security groups to filter search results for. Only resources matching the query and containing the specified security groups will be returned. If empty, all resources will be considered for the search.", # noqa: E501 ) rephrase = ParamDefault( default=False, title="Rephrase query consuming LLMs", - description="Rephrase query consuming LLMs - it will make the query slower", # noqa + description="Rephrase query consuming LLMs - it will make the query slower", # noqa: E501 ) prefer_markdown = ParamDefault( default=False, @@ -803,7 +778,7 @@ class BaseSearchRequest(AuditMetadataBase): min_score: Optional[Union[float, MinScore]] = Field( default=None, title="Minimum score", - description="Minimum score to filter search results. Results with a lower score will be ignored. Accepts either a float or a dictionary with the minimum scores for the bm25 and vector indexes. 
If a float is provided, it is interpreted as the minimum score for vector index search.", # noqa + description="Minimum score to filter search results. Results with a lower score will be ignored. Accepts either a float or a dictionary with the minimum scores for the bm25 and vector indexes. If a float is provided, it is interpreted as the minimum score for vector index search.", # noqa: E501 ) range_creation_start: Optional[DateTime] = ( SearchParamDefaults.range_creation_start.to_pydantic_field() @@ -864,11 +839,6 @@ class BaseSearchRequest(AuditMetadataBase): ], ) - @field_validator("features", mode="after") - @classmethod - def normalize_features(cls, features: list[SearchOptions]): - return [feature.normalized() for feature in features] - @model_validator(mode="after") def top_k_overwrites_pagination(self): """This method adds support for `top_k` attribute, overwriting @@ -953,7 +923,7 @@ class ChatModel(BaseModel): ) query_context_order: Optional[dict[str, int]] = Field( default=None, - description="The order of the query context elements. This is used to sort the context elements by relevance before sending them to the generative model", # noqa + description="The order of the query context elements. This is used to sort the context elements by relevance before sending them to the generative model", # noqa: E501 ) chat_history: list[ChatContextMessage] = Field( default=[], description="The chat conversation history" @@ -968,7 +938,7 @@ class ChatModel(BaseModel): citations: bool = Field(default=False, description="Whether to include the citations in the answer") citation_threshold: Optional[float] = Field( default=None, - description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. 
If not provided, Nuclia's default threshold is used.", # noqa + description="If citations is True, this sets the similarity threshold (0 to 1) for paragraphs to be included as citations. Lower values result in more citations. If not provided, Nuclia's default threshold is used.", # noqa: E501 ge=0.0, le=1.0, ) @@ -1048,7 +1018,7 @@ class FieldExtensionStrategy(RagStrategy): name: Literal["field_extension"] = "field_extension" fields: list[str] = Field( title="Fields", - description="List of field ids to extend the context with. It will try to extend the retrieval context with the specified fields in the matching resources. The field ids have to be in the format `{field_type}/{field_name}`, like 'a/title', 'a/summary' for title and summary fields or 't/amend' for a text field named 'amend'.", # noqa + description="List of field ids to extend the context with. It will try to extend the retrieval context with the specified fields in the matching resources. The field ids have to be in the format `{field_type}/{field_name}`, like 'a/title', 'a/summary' for title and summary fields or 't/amend' for a text field named 'amend'.", # noqa: E501 min_length=1, ) @@ -1249,7 +1219,7 @@ class CustomPrompt(BaseModel): system: Optional[str] = Field( default=None, title="System prompt", - description="System prompt given to the generative model responsible of generating the answer. This can help customize the behavior of the model when generating the answer. If not specified, the default model provider's prompt is used.", # noqa + description="System prompt given to the generative model responsible of generating the answer. This can help customize the behavior of the model when generating the answer. 
If not specified, the default model provider's prompt is used.", # noqa: E501 min_length=1, examples=[ "You are a medical assistant, use medical terminology", @@ -1261,7 +1231,7 @@ class CustomPrompt(BaseModel): user: Optional[str] = Field( default=None, title="User prompt", - description="User prompt given to the generative model responsible of generating the answer. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa + description="User prompt given to the generative model responsible of generating the answer. Use the words {context} and {question} in brackets where you want those fields to be placed, in case you want them in your prompt. Context will be the data returned by the retrieval step and question will be the user's query.", # noqa: E501 min_length=1, examples=[ "Taking into account our previous conversation, and this context: {context} answer this {question}", @@ -1313,7 +1283,7 @@ class AskRequest(AuditMetadataBase): "List of keyword filter expressions to apply to the retrieval step. " "The text block search will only be performed on the documents that contain the specified keywords. " "The filters are case-insensitive, and only alphanumeric characters and spaces are allowed. " - "Filtering examples can be found here: https://docs.nuclia.dev/docs/rag/advanced/search/#filters" # noqa + "Filtering examples can be found here: https://docs.nuclia.dev/docs/rag/advanced/search/#filters" # noqa: E501 ), examples=[ ["NLP", "BERT"], @@ -1325,7 +1295,7 @@ class AskRequest(AuditMetadataBase): min_score: Optional[Union[float, MinScore]] = Field( default=None, title="Minimum score", - description="Minimum score to filter search results. Results with a lower score will be ignored. Accepts either a float or a dictionary with the minimum scores for the bm25 and vector indexes. 
If a float is provided, it is interpreted as the minimum score for vector index search.", # noqa + description="Minimum score to filter search results. Results with a lower score will be ignored. Accepts either a float or a dictionary with the minimum scores for the bm25 and vector indexes. If a float is provided, it is interpreted as the minimum score for vector index search.", # noqa: E501 ) features: list[ChatOptions] = SearchParamDefaults.chat_features.to_pydantic_field() range_creation_start: Optional[DateTime] = ( @@ -1355,7 +1325,7 @@ class AskRequest(AuditMetadataBase): prompt: Optional[Union[str, CustomPrompt]] = Field( default=None, title="Prompts", - description="Use to customize the prompts given to the generative model. Both system and user prompts can be customized. If a string is provided, it is interpreted as the user prompt.", # noqa + description="Use to customize the prompts given to the generative model. Both system and user prompts can be customized. If a string is provided, it is interpreted as the user prompt.", # noqa: E501 ) rank_fusion: SkipJsonSchema[Union[RankFusionName, RankFusion]] = ( SearchParamDefaults.rank_fusion.to_pydantic_field() @@ -1440,7 +1410,7 @@ class AskRequest(AuditMetadataBase): max_tokens: Optional[Union[int, MaxTokens]] = Field( default=None, title="Maximum LLM tokens to use for the request", - description="Use to limit the amount of tokens used in the LLM context and/or for generating the answer. If not provided, the default maximum tokens of the generative model will be used. If an integer is provided, it is interpreted as the maximum tokens for the answer.", # noqa + description="Use to limit the amount of tokens used in the LLM context and/or for generating the answer. If not provided, the default maximum tokens of the generative model will be used. 
If an integer is provided, it is interpreted as the maximum tokens for the answer.", # noqa: E501 ) rephrase: bool = Field( @@ -1499,11 +1469,6 @@ def validate_rag_strategies(cls, rag_strategies: list[RagStrategies]) -> list[Ra ) return rag_strategies - @field_validator("features", mode="after") - @classmethod - def normalize_features(cls, features: list[ChatOptions]): - return [feature.normalized() for feature in features] - # Alias (for backwards compatiblity with testbed) class ChatRequest(AskRequest): @@ -1599,7 +1564,7 @@ class FindRequest(BaseSearchRequest): "List of keyword filter expressions to apply to the retrieval step. " "The text block search will only be performed on the documents that contain the specified keywords. " "The filters are case-insensitive, and only alphanumeric characters and spaces are allowed. " - "Filtering examples can be found here: https://docs.nuclia.dev/docs/rag/advanced/search/#filters" # noqa + "Filtering examples can be found here: https://docs.nuclia.dev/docs/rag/advanced/search/#filters" # noqa: E501 ), examples=[ ["NLP", "BERT"], @@ -1708,7 +1673,7 @@ class KnowledgeboxFindResults(JsonBaseModel): best_matches: list[str] = Field( default=[], title="Best matches", - description="List of ids of best matching paragraphs. The list is sorted by decreasing relevance (most relevant first).", # noqa + description="List of ids of best matching paragraphs. The list is sorted by decreasing relevance (most relevant first).", # noqa: E501 ) @@ -1722,7 +1687,7 @@ def to_proto(self) -> int: class FeedbackRequest(BaseModel): ident: str = Field( title="Request identifier", - description="Id of the request to provide feedback for. This id is returned in the response header `Nuclia-Learning-Id` of the chat endpoint.", # noqa + description="Id of the request to provide feedback for. 
This id is returned in the response header `Nuclia-Learning-Id` of the chat endpoint.", # noqa: E501 ) good: bool = Field(title="Good", description="Whether the result was good or not") task: FeedbackTasks = Field( @@ -1817,11 +1782,11 @@ class SyncAskResponse(BaseModel): answer_json: Optional[dict[str, Any]] = Field( default=None, title="Answer JSON", - description="The generative JSON answer to the query. This is returned only if the answer_json_schema parameter is provided in the request.", # noqa + description="The generative JSON answer to the query. This is returned only if the answer_json_schema parameter is provided in the request.", # noqa: E501 ) status: str = Field( title="Status", - description="The status of the query execution. It can be 'success', 'error' or 'no_context'", # noqa + description="The status of the query execution. It can be 'success', 'error' or 'no_context'", # noqa: E501 ) retrieval_results: KnowledgeboxFindResults = Field( title="Retrieval results", @@ -1840,7 +1805,7 @@ class SyncAskResponse(BaseModel): learning_id: str = Field( default="", title="Learning id", - description="The id of the learning request. This id can be used to provide feedback on the learning process.", # noqa + description="The id of the learning request. This id can be used to provide feedback on the learning process.", # noqa: E501 ) relations: Optional[Relations] = Field( default=None, @@ -1860,7 +1825,7 @@ class SyncAskResponse(BaseModel): metadata: Optional[SyncAskMetadata] = Field( default=None, title="Metadata", - description="Metadata of the query execution. This includes the number of tokens used in the LLM context and answer, and the timings of the generative model.", # noqa + description="Metadata of the query execution. 
This includes the number of tokens used in the LLM context and answer, and the timings of the generative model.", # noqa: E501 ) error_details: Optional[str] = Field( default=None, diff --git a/nucliadb_models/tests/test_search.py b/nucliadb_models/tests/test_search.py index 7fe7bdd89f..5050637046 100644 --- a/nucliadb_models/tests/test_search.py +++ b/nucliadb_models/tests/test_search.py @@ -84,62 +84,11 @@ def test_base_search_request_top_k(): assert request.page_size == 100 -def test_search_request_features_normalization(): - request = search.SearchRequest( - features=[ - search.SearchOptions.VECTOR, - search.SearchOptions.PARAGRAPH, - search.SearchOptions.DOCUMENT, - search.SearchOptions.RELATIONS, - ] - ) - assert request.features == [ - search.SearchOptions.SEMANTIC, - search.SearchOptions.KEYWORD, - search.SearchOptions.FULLTEXT, - search.SearchOptions.RELATIONS, - ] - - -def test_find_request_features_normalization(): - request = search.FindRequest( - features=[ - search.SearchOptions.VECTOR, - search.SearchOptions.PARAGRAPH, - search.SearchOptions.RELATIONS, - ] - ) - assert request.features == [ - search.SearchOptions.SEMANTIC, - search.SearchOptions.KEYWORD, - search.SearchOptions.RELATIONS, - ] - - def test_find_request_fulltext_feature_not_allowed(): - with pytest.raises(ValidationError): - search.FindRequest(features=[search.SearchOptions.DOCUMENT]) - with pytest.raises(ValidationError): search.FindRequest(features=[search.SearchOptions.FULLTEXT]) -def test_chat_request_features_normalization(): - request = search.AskRequest( - query="my-query", - features=[ - search.ChatOptions.VECTORS, - search.ChatOptions.PARAGRAPHS, - search.ChatOptions.RELATIONS, - ], - ) - assert request.features == [ - search.ChatOptions.SEMANTIC, - search.ChatOptions.KEYWORD, - search.ChatOptions.RELATIONS, - ] - - # Rank fusion diff --git a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py index c0a57f31ee..1cea1ef279 100644 --- 
a/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py +++ b/nucliadb_sdk/src/nucliadb_sdk/v2/docstrings.py @@ -200,7 +200,7 @@ class Docstring(BaseModel): description="Search on the full text index", code=""">>> from nucliadb_sdk import * >>> sdk = NucliaDBSDK(api_key="api-key") ->>> resp = sdk.search(kbid="mykbid", query="Site Reliability", features=["document"]) +>>> resp = sdk.search(kbid="mykbid", query="Site Reliability", features=["fulltext"]) >>> rid = resp.fulltext.results[0].rid >>> resp.resources[rid].title 'The Site Reliability Workbook.pdf' diff --git a/nucliadb_sdk/tests/test_search.py b/nucliadb_sdk/tests/test_search.py index 10aded2c7c..142a631d4c 100644 --- a/nucliadb_sdk/tests/test_search.py +++ b/nucliadb_sdk/tests/test_search.py @@ -22,6 +22,7 @@ import nucliadb_sdk from nucliadb_models.resource import KnowledgeBoxObj +from nucliadb_models.search import SearchOptions TESTING_IN_CI = os.environ.get("CI") == "true" @@ -252,7 +253,7 @@ def test_search_resource(kb: KnowledgeBoxObj, sdk: nucliadb_sdk.NucliaDB): results = sdk.search( kbid=kb.uuid, - features=["document"], + features=[SearchOptions.FULLTEXT], faceted=["/classification.labels"], page_size=0, ) @@ -266,7 +267,7 @@ def test_search_resource(kb: KnowledgeBoxObj, sdk: nucliadb_sdk.NucliaDB): resources = sdk.search( kbid=kb.uuid, - features=["document"], + features=[SearchOptions.FULLTEXT], faceted=["/classification.labels/emoji"], page_size=0, )