From 2e2b9c76d9eed80698124488a822bcbe563fdc46 Mon Sep 17 00:00:00 2001 From: Manuel Rech <63170478+manuelrech@users.noreply.github.com> Date: Fri, 3 Nov 2023 12:15:02 +1100 Subject: [PATCH] Keep also original query - multi_query.py (#12696) When you use a MultiQuery it might be useful to use the original query as well as the newly generated ones to maximise the chances of retrieving the correct document. I haven't created an issue, as it seems a very small and easy change. --------- Co-authored-by: Bagatur --- libs/langchain/langchain/retrievers/multi_query.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index 5abdddc2e2c78..3ac2beaa445be 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ b/libs/langchain/langchain/retrievers/multi_query.py @@ -61,6 +61,8 @@ class MultiQueryRetriever(BaseRetriever): llm_chain: LLMChain verbose: bool = True parser_key: str = "lines" + include_original: bool = False + """Whether to include the original query in the list of generated queries.""" @classmethod def from_llm( @@ -69,12 +71,15 @@ def from_llm( llm: BaseLLM, prompt: PromptTemplate = DEFAULT_QUERY_PROMPT, parser_key: str = "lines", + include_original: bool = False, ) -> "MultiQueryRetriever": """Initialize from llm using default template. Args: retriever: retriever to query documents from llm: llm for query generation using DEFAULT_QUERY_PROMPT + include_original: Whether to include the original query in the list of + generated queries. 
Returns: MultiQueryRetriever @@ -85,6 +90,7 @@ def from_llm( retriever=retriever, llm_chain=llm_chain, parser_key=parser_key, + include_original=include_original, ) async def _aget_relevant_documents( @@ -102,6 +108,8 @@ async def _aget_relevant_documents( Unique union of relevant documents from all generated queries """ queries = await self.agenerate_queries(query, run_manager) + if self.include_original: + queries.append(query) documents = await self.aretrieve_documents(queries, run_manager) return self.unique_union(documents) @@ -160,6 +168,8 @@ def _get_relevant_documents( Unique union of relevant documents from all generated queries """ queries = self.generate_queries(query, run_manager) + if self.include_original: + queries.append(query) documents = self.retrieve_documents(queries, run_manager) return self.unique_union(documents)