From fe7ad44c3e349fd57d96e823fb8a7970255897c5 Mon Sep 17 00:00:00 2001 From: Renu Rozera Date: Tue, 7 May 2024 14:47:38 -0700 Subject: [PATCH 1/2] Add source metadata from API response to Document metadata --- libs/aws/README.md | 4 +- libs/aws/langchain_aws/retrievers/bedrock.py | 13 +++-- .../test_amazon_knowledgebases_retriever.py | 16 +++++- .../unit_tests/retrievers/test_bedrock.py | 54 +++++++++++++++++++ 4 files changed, 79 insertions(+), 8 deletions(-) create mode 100644 libs/aws/tests/unit_tests/retrievers/test_bedrock.py diff --git a/libs/aws/README.md b/libs/aws/README.md index 62d0bb59..1ad6b4e1 100644 --- a/libs/aws/README.md +++ b/libs/aws/README.md @@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever( retriever.get_relevant_documents(query="What is the meaning of life?") ``` -`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases. +`AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases. ```python from langchain_aws import AmazonKnowledgeBasesRetriever @@ -65,4 +65,4 @@ retriever = AmazonKnowledgeBasesRetriever( ) retriever.get_relevant_documents(query="What is the meaning of life?") -``` \ No newline at end of file +``` diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py index 068b904c..f313f7fd 100644 --- a/libs/aws/langchain_aws/retrievers/bedrock.py +++ b/libs/aws/langchain_aws/retrievers/bedrock.py @@ -114,13 +114,16 @@ def _get_relevant_documents( results = response["retrievalResults"] documents = [] for result in results: + content = result["content"]["text"] + result.pop("content") + if "score" not in result: + result["score"] = 0 + if "metadata" in result: + result["source_metadata"] = result.pop("metadata") documents.append( Document( - page_content=result["content"]["text"], - metadata={ - "location": result["location"], - "score": result["score"] if "score" in result else 0, - }, + page_content=content, + 
metadata=result, ) ) diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py index f66a158f..ae48ffef 100644 --- a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py +++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py @@ -34,6 +34,10 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore "score": 0.8, }, {"content": {"text": "This is the third result."}, "location": "location3"}, + { + "content": {"text": "This is the fourth result."}, + "metadata": {"key1": "value1", "key2": "value2"}, + }, ] } mock_client.retrieve.return_value = response @@ -53,9 +57,19 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore page_content="This is the third result.", metadata={"location": "location3", "score": 0.0}, ), + Document( + page_content="This is the fourth result.", + metadata={ + "score": 0.0, + "source_metadata": { + "key1": "value1", + "key2": "value2", + }, + }, + ), ] - documents = retriever.get_relevant_documents(query) + documents = retriever.invoke(query) assert documents == expected_documents diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py new file mode 100644 index 00000000..b205ee53 --- /dev/null +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -0,0 +1,54 @@ +from unittest.mock import MagicMock + +import pytest +from langchain_community.retrievers import AmazonKnowledgeBasesRetriever +from langchain_core.documents import Document + + +@pytest.fixture +def mock_client(): + return MagicMock() + + +@pytest.fixture +def mock_retriever_config(): + return {"vectorSearchConfiguration": {"numberOfResults": 4}} + + +@pytest.fixture +def amazon_retriever(mock_client, mock_retriever_config): + return AmazonKnowledgeBasesRetriever( + 
knowledge_base_id="test_kb_id", + retrieval_config=mock_retriever_config, + client=mock_client, + ) + + +def test_retriever_invoke(amazon_retriever, mock_client): + query = "test query" + mock_client.retrieve.return_value = { + "retrievalResults": [ + {"content": {"text": "result1"}, "metadata": {"key": "value1"}}, + { + "content": {"text": "result2"}, + "metadata": {"key": "value2"}, + "score": 1, + "location": "testLocation", + }, + {"content": {"text": "result3"}}, + ] + } + documents = amazon_retriever.invoke(query, run_manager=None) + + assert len(documents) == 3 + assert isinstance(documents[0], Document) + assert documents[0].page_content == "result1" + assert documents[0].metadata == {"score": 0, "source_metadata": {"key": "value1"}} + assert documents[1].page_content == "result2" + assert documents[1].metadata == { + "score": 1, + "source_metadata": {"key": "value2"}, + "location": "testLocation", + } + assert documents[2].page_content == "result3" + assert documents[2].metadata == {"score": 0} From 6b490424f5ffdf10ad315b838d7c781650d6b1d8 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 7 May 2024 18:01:28 -0700 Subject: [PATCH 2/2] Fixed linting. --- libs/aws/tests/unit_tests/retrievers/test_bedrock.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py index b205ee53..007ad139 100644 --- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py @@ -1,9 +1,12 @@ +# type: ignore + from unittest.mock import MagicMock import pytest -from langchain_community.retrievers import AmazonKnowledgeBasesRetriever from langchain_core.documents import Document + +from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever + @pytest.fixture def mock_client():