diff --git a/haystack/pipelines/standard_pipelines.py b/haystack/pipelines/standard_pipelines.py index 203cf5e29d..de18bd8ac5 100644 --- a/haystack/pipelines/standard_pipelines.py +++ b/haystack/pipelines/standard_pipelines.py @@ -717,27 +717,43 @@ def __init__(self, document_store: BaseDocumentStore): self.pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Query"]) self.document_store = document_store - def run(self, document_ids: List[str], top_k: int = 5): + def run( + self, + document_ids: List[str], + filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, + top_k: int = 5, + index: Optional[str] = None, + ): """ :param document_ids: document ids + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions :param top_k: How many documents id to return against single document + :param index: Optionally specify the name of the index to query the document from. If None, the DocumentStore's default index (self.index) will be used. 
""" similar_documents: list = [] self.document_store.return_embedding = True # type: ignore - for document in self.document_store.get_documents_by_id(ids=document_ids): + for document in self.document_store.get_documents_by_id(ids=document_ids, index=index): similar_documents.append( self.document_store.query_by_embedding( - query_emb=document.embedding, return_embedding=False, top_k=top_k + query_emb=document.embedding, filters=filters, return_embedding=False, top_k=top_k, index=index ) ) self.document_store.return_embedding = False # type: ignore return similar_documents - def run_batch(self, document_ids: List[str], top_k: int = 5): # type: ignore + def run_batch( # type: ignore + self, + document_ids: List[str], + filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, + top_k: int = 5, + index: Optional[str] = None, + ): """ :param document_ids: document ids + :param filters: Optional filters to narrow down the search space to documents whose metadata fulfill certain conditions :param top_k: How many documents id to return against single document + :param index: Optionally specify the name of index to query the document from. If None, the DocumentStore's default index (self.index) will be used. 
""" - return self.run(document_ids=document_ids, top_k=top_k) + return self.run(document_ids=document_ids, filters=filters, top_k=top_k, index=index) diff --git a/test/pipelines/test_standard_pipelines.py b/test/pipelines/test_standard_pipelines.py index 21034f22e7..958f52361a 100644 --- a/test/pipelines/test_standard_pipelines.py +++ b/test/pipelines/test_standard_pipelines.py @@ -200,6 +200,39 @@ def test_most_similar_documents_pipeline(retriever, document_store): assert isinstance(document.content, str) +@pytest.mark.parametrize( + "retriever,document_store", [("embedding", "milvus1"), ("embedding", "elasticsearch")], indirect=True +) +def test_most_similar_documents_pipeline_with_filters(retriever, document_store): + documents = [ + {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}}, + {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}}, + {"content": "Sample text for document-3", "meta": {"source": "wiki3"}}, + {"content": "Sample text for document-4", "meta": {"source": "wiki4"}}, + {"content": "Sample text for document-5", "meta": {"source": "wiki5"}}, + ] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + docs_id: list = ["a", "b"] + filters = {"source": ["wiki3", "wiki4", "wiki5"]} + pipeline = MostSimilarDocumentsPipeline(document_store=document_store) + list_of_documents = pipeline.run(document_ids=docs_id, filters=filters) + + assert len(list_of_documents[0]) > 1 + assert isinstance(list_of_documents, list) + assert len(list_of_documents) == len(docs_id) + + for another_list in list_of_documents: + assert isinstance(another_list, list) + for document in another_list: + assert isinstance(document, Document) + assert isinstance(document.id, str) + assert isinstance(document.content, str) + assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"] + + @pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True) def 
test_most_similar_documents_pipeline_batch(retriever, document_store): documents = [ @@ -229,6 +262,37 @@ def test_most_similar_documents_pipeline_batch(retriever, document_store): assert isinstance(document.content, str) +@pytest.mark.parametrize("retriever,document_store", [("embedding", "memory")], indirect=True) +def test_most_similar_documents_pipeline_with_filters_batch(retriever, document_store): + documents = [ + {"id": "a", "content": "Sample text for document-1", "meta": {"source": "wiki1"}}, + {"id": "b", "content": "Sample text for document-2", "meta": {"source": "wiki2"}}, + {"content": "Sample text for document-3", "meta": {"source": "wiki3"}}, + {"content": "Sample text for document-4", "meta": {"source": "wiki4"}}, + {"content": "Sample text for document-5", "meta": {"source": "wiki5"}}, + ] + + document_store.write_documents(documents) + document_store.update_embeddings(retriever) + + docs_id: list = ["a", "b"] + filters = {"source": ["wiki3", "wiki4", "wiki5"]} + pipeline = MostSimilarDocumentsPipeline(document_store=document_store) + list_of_documents = pipeline.run_batch(document_ids=docs_id, filters=filters) + + assert len(list_of_documents[0]) > 1 + assert isinstance(list_of_documents, list) + assert len(list_of_documents) == len(docs_id) + + for another_list in list_of_documents: + assert isinstance(another_list, list) + for document in another_list: + assert isinstance(document, Document) + assert isinstance(document.id, str) + assert isinstance(document.content, str) + assert document.meta["source"] in ["wiki3", "wiki4", "wiki5"] + + @pytest.mark.integration @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) def test_most_similar_documents_pipeline_save(tmpdir, document_store_with_docs):