diff --git a/libs/aws/langchain_aws/retrievers/bedrock.py b/libs/aws/langchain_aws/retrievers/bedrock.py
index f313f7fd..58906e0a 100644
--- a/libs/aws/langchain_aws/retrievers/bedrock.py
+++ b/libs/aws/langchain_aws/retrievers/bedrock.py
@@ -5,8 +5,9 @@
 from botocore.exceptions import UnknownServiceError
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.retrievers import BaseRetriever
+from typing_extensions import Annotated
 
 
 class VectorSearchConfig(BaseModel, extra="allow"):  # type: ignore[call-arg]
@@ -59,6 +60,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
     endpoint_url: Optional[str] = None
     client: Any
     retrieval_config: RetrievalConfig
+    min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]
 
     @root_validator(pre=True)
     def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -103,6 +105,30 @@ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
                 "profile name are valid."
             ) from e
 
+    def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
+        """Drop documents whose score is below ``min_score_confidence``.
+
+        Args:
+            docs: Documents to filter; each score is read from
+                ``metadata["score"]``.
+
+        Returns:
+            ``docs`` unchanged when ``min_score_confidence`` is ``None`` or ``0.0``;
+            otherwise a new list of the documents whose score is present and
+            at least ``min_score_confidence``.
+        """
+        threshold = self.min_score_confidence
+        if not threshold:
+            return docs
+
+        filtered_docs = []
+        for item in docs:
+            # Read the score once; documents without one are dropped.
+            score = item.metadata.get("score")
+            if score is not None and score >= threshold:
+                filtered_docs.append(item)
+        return filtered_docs
+
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
@@ -127,4 +153,4 @@ def _get_relevant_documents(
                 )
             )
 
-        return documents
+        return self._filter_by_score_confidence(docs=documents)
diff --git a/libs/aws/langchain_aws/retrievers/kendra.py b/libs/aws/langchain_aws/retrievers/kendra.py
index b4480cae..5e7b5fe1 100644
--- a/libs/aws/langchain_aws/retrievers/kendra.py
+++ b/libs/aws/langchain_aws/retrievers/kendra.py
@@ -444,7 +444,7 @@ def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
     def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
         """
         Filter out the records that have a score confidence
-        greater than the required threshold.
+        less than the required threshold.
         """
         if not self.min_score_confidence:
             return docs
diff --git a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
index ae48ffef..54eb8bf3 100644
--- a/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
+++ b/libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py
@@ -17,6 +17,7 @@ def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever:
         knowledge_base_id="test-knowledge-base",
         client=mock_client,
         retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}},  # type: ignore[arg-type]
+        min_score_confidence=0.0,
     )
 
 
@@ -78,3 +79,44 @@ def test_get_relevant_documents(retriever, mock_client) -> None:  # type: ignore
         knowledgeBaseId="test-knowledge-base",
         retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}},
     )
+
+
+def test_get_relevant_documents_with_score(retriever, mock_client) -> None:  # type: ignore[no-untyped-def]
+    response = {
+        "retrievalResults": [
+            {
+                "content": {"text": "This is the first result."},
+                "location": "location1",
+                "score": 0.9,
+            },
+            {
+                "content": {"text": "This is the second result."},
+                "location": "location2",
+                "score": 0.8,
+            },
+            {"content": {"text": "This is the third result."}, "location": "location3"},
+            {
+                "content": {"text": "This is the fourth result."},
+                "metadata": {"key1": "value1", "key2": "value2"},
+            },
+        ]
+    }
+    mock_client.retrieve.return_value = response
+
+    query = "test query"
+
+    expected_documents = [
+        Document(
+            page_content="This is the first result.",
+            metadata={"location": "location1", "score": 0.9},
+        ),
+        Document(
+            page_content="This is the second result.",
+            metadata={"location": "location2", "score": 0.8},
+        ),
+    ]
+
+    retriever.min_score_confidence = 0.80
+    documents = retriever.invoke(query)
+
+    assert documents == expected_documents
diff --git a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
index 007ad139..b243db10 100644
--- a/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
+++ b/libs/aws/tests/unit_tests/retrievers/test_bedrock.py
@@ -55,3 +55,31 @@ def test_retriever_invoke(amazon_retriever, mock_client):
     }
     assert documents[2].page_content == "result3"
     assert documents[2].metadata == {"score": 0}
+
+
+def test_retriever_invoke_with_score(amazon_retriever, mock_client):
+    query = "test query"
+    mock_client.retrieve.return_value = {
+        "retrievalResults": [
+            {"content": {"text": "result1"}, "metadata": {"key": "value1"}},
+            {
+                "content": {"text": "result2"},
+                "metadata": {"key": "value2"},
+                "score": 1,
+                "location": "testLocation",
+            },
+            {"content": {"text": "result3"}},
+        ]
+    }
+
+    amazon_retriever.min_score_confidence = 0.6
+    documents = amazon_retriever.invoke(query, run_manager=None)
+
+    assert len(documents) == 1
+    assert isinstance(documents[0], Document)
+    assert documents[0].page_content == "result2"
+    assert documents[0].metadata == {
+        "score": 1,
+        "source_metadata": {"key": "value2"},
+        "location": "testLocation",
+    }