Skip to content

Commit

Permalink
Merge branch 'main' into text_fastfields
Browse files Browse the repository at this point in the history
  • Loading branch information
javitonino authored Sep 19, 2024
2 parents 1241ca9 + 292bf3e commit 88e6a95
Show file tree
Hide file tree
Showing 28 changed files with 452 additions and 111 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: Swatinem/rust-cache@v2
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.12.5"
- name: Restore venv
uses: actions/cache/restore@v4
with:
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ jobs:
push-docker:
name: Build and push nucliadb docker image
runs-on: ubuntu-latest
needs: build_wheels

steps:
- name: Checkout
Expand Down
8 changes: 8 additions & 0 deletions nucliadb/src/nucliadb/search/api/v1/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ async def ask_knowledgebox_endpoint(
)


def default(obj):
    """Convert sets to lists when dumping json.

    Intended as the ``default=`` hook for ``json.dumps``: the encoder calls
    it only for objects it cannot serialize natively.

    :param obj: The unserializable object encountered by the JSON encoder.
    :return: A JSON-serializable representation (a ``list`` for sets).
    :raises TypeError: For any non-set object, mirroring the json module's
        contract so the encoder surfaces a descriptive failure instead of a
        bare ``TypeError`` with no message.
    """
    # frozenset is accepted too: it is equally unserializable by json and
    # has the same natural list representation.
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


@handled_ask_exceptions
async def create_ask_response(
kbid: str,
Expand All @@ -87,6 +94,7 @@ async def create_ask_response(
origin=origin,
resource=resource,
)

headers = {
"NUCLIA-LEARNING-ID": ask_result.nuclia_learning_id or "unknown",
"Access-Control-Expose-Headers": "NUCLIA-LEARNING-ID",
Expand Down
1 change: 1 addition & 0 deletions nucliadb/src/nucliadb/search/api/v1/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ async def search(
autofilter=item.autofilter,
security=item.security,
rephrase=item.rephrase,
rephrase_prompt=item.rephrase_prompt,
)
pb_query, incomplete_results, autofilters = await query_parser.parse()

Expand Down
34 changes: 28 additions & 6 deletions nucliadb/src/nucliadb/search/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class RephraseMissingContextError(Exception):

DUMMY_REPHRASE_QUERY = "This is a rephrased query"
DUMMY_LEARNING_ID = "00"
DUMMY_LEARNING_MODEL = "chatgpt"


PUBLIC_PREDICT = "/api/v1/predict"
Expand All @@ -91,6 +92,7 @@ class RephraseMissingContextError(Exception):
FEEDBACK = "/feedback"

NUCLIA_LEARNING_ID_HEADER = "NUCLIA-LEARNING-ID"
NUCLIA_LEARNING_MODEL_HEADER = "NUCLIA-LEARNING-MODEL"


predict_observer = metrics.Observer(
Expand Down Expand Up @@ -325,7 +327,7 @@ async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:
@predict_observer.wrap({"type": "chat_ndjson"})
async def chat_query_ndjson(
self, kbid: str, item: ChatModel
) -> tuple[str, AsyncIterator[GenerativeChunk]]:
) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
"""
Chat query using the new stream format
Format specs: https://github.com/ndjson/ndjson-spec
Expand All @@ -350,7 +352,8 @@ async def chat_query_ndjson(
)
await self.check_response(resp, expected_status=200)
ident = resp.headers.get(NUCLIA_LEARNING_ID_HEADER)
return ident, get_chat_ndjson_generator(resp)
model = resp.headers.get(NUCLIA_LEARNING_MODEL_HEADER)
return ident, model, get_chat_ndjson_generator(resp)

@predict_observer.wrap({"type": "query"})
async def query(
Expand All @@ -359,8 +362,24 @@ async def query(
sentence: str,
semantic_model: Optional[str] = None,
generative_model: Optional[str] = None,
rephrase: Optional[bool] = False,
rephrase: bool = False,
rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
"""
Query endpoint: returns information to be used by NucliaDB at retrieval time, for instance:
- The embeddings
- The entities
- The stop words
- The semantic threshold
- etc.
:param kbid: KnowledgeBox ID
:param sentence: The query sentence
:param semantic_model: The semantic model to use to generate the embeddings
:param generative_model: The generative model that will be used to generate the answer
:param rephrase: If the query should be rephrased before calculating the embeddings for a better retrieval
:param rephrase_prompt: Custom prompt to use for rephrasing
"""
try:
self.check_nua_key_is_configured_for_onprem()
except NUAKeyMissingError:
Expand All @@ -372,6 +391,8 @@ async def query(
"text": sentence,
"rephrase": str(rephrase),
}
if rephrase_prompt is not None:
params["rephrase_prompt"] = rephrase_prompt
if semantic_model is not None:
params["semantic_models"] = [semantic_model]
if generative_model is not None:
Expand Down Expand Up @@ -473,22 +494,23 @@ async def rephrase_query(self, kbid: str, item: RephraseModel) -> str:

async def chat_query_ndjson(
self, kbid: str, item: ChatModel
) -> tuple[str, AsyncIterator[GenerativeChunk]]:
) -> tuple[str, str, AsyncIterator[GenerativeChunk]]:
self.calls.append(("chat_query_ndjson", item))

async def generate():
for item in self.ndjson_answer:
yield GenerativeChunk.model_validate_json(item)

return (DUMMY_LEARNING_ID, generate())
return (DUMMY_LEARNING_ID, DUMMY_LEARNING_MODEL, generate())

async def query(
self,
kbid: str,
sentence: str,
semantic_model: Optional[str] = None,
generative_model: Optional[str] = None,
rephrase: Optional[bool] = False,
rephrase: bool = False,
rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
self.calls.append(("query", sentence))

Expand Down
97 changes: 93 additions & 4 deletions nucliadb/src/nucliadb/search/search/chat/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,29 @@
CitationsAskResponseItem,
DebugAskResponseItem,
ErrorAskResponseItem,
FindRequest,
JSONAskResponseItem,
KnowledgeboxFindResults,
MetadataAskResponseItem,
MinScore,
NucliaDBClientType,
PrequeriesAskResponseItem,
PreQueriesStrategy,
PreQuery,
PreQueryResult,
PromptContext,
PromptContextOrder,
RagStrategyName,
Relations,
RelationsAskResponseItem,
RetrievalAskResponseItem,
SearchOptions,
StatusAskResponseItem,
SyncAskMetadata,
SyncAskResponse,
UserPrompt,
parse_custom_prompt,
parse_rephrase_prompt,
)
from nucliadb_telemetry import errors
from nucliadb_utils.exceptions import LimitsExceededError
Expand Down Expand Up @@ -132,7 +136,7 @@ def status_code(self) -> AnswerStatusCode:

@property
def status_error_details(self) -> Optional[str]:
if self._status is None:
if self._status is None: # pragma: no cover
return None
return self._status.details

Expand Down Expand Up @@ -459,9 +463,8 @@ async def ask(
prompt_context_images,
) = await prompt_context_builder.build()

custom_prompt = parse_custom_prompt(ask_request)

# Make the chat request to the predict API
custom_prompt = parse_custom_prompt(ask_request)
chat_model = ChatModel(
user_id=user_id,
system=custom_prompt.system,
Expand All @@ -472,14 +475,19 @@ async def ask(
question=user_query,
truncate=True,
citations=ask_request.citations,
citation_threshold=ask_request.citation_threshold,
generative_model=ask_request.generative_model,
max_tokens=query_parser.get_max_tokens_answer(),
query_context_images=prompt_context_images,
json_schema=ask_request.answer_json_schema,
)
with metrics.time("stream_start"):
predict = get_predict()
nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(kbid, chat_model)
(
nuclia_learning_id,
nuclia_learning_model,
predict_answer_stream,
) = await predict.chat_query_ndjson(kbid, chat_model)

auditor = ChatAuditor(
kbid=kbid,
Expand All @@ -492,6 +500,7 @@ async def ask(
learning_id=nuclia_learning_id,
query_context=prompt_context,
query_context_order=prompt_context_order,
model=nuclia_learning_model,
)
return AskResult(
kbid=kbid,
Expand Down Expand Up @@ -670,6 +679,8 @@ async def retrieval_in_resource(
)

prequeries = parse_prequeries(ask_request)
if prequeries is None and ask_request.answer_json_schema is not None and main_query == "":
prequeries = calculate_prequeries_for_json_schema(ask_request)

# Make sure the retrieval is scoped to the resource if provided
ask_request.resource_filters = [resource]
Expand Down Expand Up @@ -703,3 +714,81 @@ async def retrieval_in_resource(
query_parser=query_parser,
main_query_weight=prequeries.main_query_weight if prequeries is not None else 1.0,
)


def calculate_prequeries_for_json_schema(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
    """
    Build a PreQueriesStrategy containing one retrieval query per property of
    the JSON schema in ``ask_request.answer_json_schema``, all equally weighted.

    This supports the use-case of asking for a structured answer over a corpus
    too large to send to the generative model in one shot: each schema property
    (its name, plus its description when present) becomes its own find query,
    so the retrieval gathers context relevant to every field of the answer.

    For a schema with properties ``title`` ("The title of the book") and
    ``author`` ("The author of the book"), two prequeries are produced with
    weight 1.0 each, querying "title: The title of the book" and
    "author: The author of the book" respectively.

    Side effect: on success, ``ask_request.rag_strategies`` is replaced with
    the single generated strategy.

    :param ask_request: The ask request carrying the JSON schema and the
        retrieval options (features, min_score, vectorset, rephrase, security).
    :return: The generated strategy, or ``None`` when the schema declares no
        properties.
    """
    schema = ask_request.answer_json_schema or {}
    properties = schema.get("parameters", {}).get("properties", {})
    if len(properties) == 0:  # pragma: no cover
        return None

    # Map the requested chat features onto the equivalent search features,
    # preserving the semantic-then-keyword order.
    features = [
        search_option
        for chat_option, search_option in (
            (ChatOptions.SEMANTIC, SearchOptions.SEMANTIC),
            (ChatOptions.KEYWORD, SearchOptions.KEYWORD),
        )
        if chat_option in ask_request.features
    ]

    prequeries: list[PreQuery] = []
    for prop_name, prop_def in properties.items():
        description = prop_def.get("description")
        query = f"{prop_name}: {description}" if description else prop_name
        find_request = FindRequest(
            query=query,
            features=features,
            filters=[],
            keyword_filters=[],
            page_number=0,
            page_size=10,
            min_score=ask_request.min_score,
            vectorset=ask_request.vectorset,
            highlight=False,
            debug=False,
            show=[],
            with_duplicates=False,
            with_synonyms=False,
            resource_filters=[],  # to be filled with the resource filter
            rephrase=ask_request.rephrase,
            rephrase_prompt=parse_rephrase_prompt(ask_request),
            security=ask_request.security,
            autofilter=False,
        )
        prequeries.append(PreQuery(request=find_request, weight=1.0))

    strategy = PreQueriesStrategy(queries=prequeries)
    ask_request.rag_strategies = [strategy]
    return strategy
Loading

0 comments on commit 88e6a95

Please sign in to comment.