Skip to content

Commit

Permalink
Merge pull request #49 from ihmaws/dev/ihm/add-score-filtering-for-be…
Browse files Browse the repository at this point in the history
…drock-kb

Add min_score_confidence support for the Bedrock KB retriever
  • Loading branch information
3coins authored May 21, 2024
2 parents 4a88b7f + 351f17a commit 3bc0c39
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 3 deletions.
23 changes: 21 additions & 2 deletions libs/aws/langchain_aws/retrievers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from botocore.exceptions import UnknownServiceError
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from typing_extensions import Annotated


class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
Expand Down Expand Up @@ -59,6 +60,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
endpoint_url: Optional[str] = None
client: Any
retrieval_config: RetrievalConfig
min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]

@root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand Down Expand Up @@ -103,6 +105,23 @@ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"profile name are valid."
) from e

def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
    """Filter out documents whose score confidence is below the threshold.

    Args:
        docs: Retrieved documents; each document's ``metadata`` may carry
            a ``"score"`` entry.

    Returns:
        The documents whose score is present and is greater than or equal
        to ``self.min_score_confidence``. When the threshold is unset
        (``None``) or ``0.0`` (both falsy), ``docs`` is returned unchanged,
        including documents that carry no score at all.
    """
    # A falsy threshold (None or 0.0) disables filtering entirely; this
    # mirrors the Kendra retriever's behavior.
    if not self.min_score_confidence:
        return docs
    filtered_docs = []
    for doc in docs:
        # Look the score up once (the original looked it up twice);
        # documents without a score are dropped when a threshold is set.
        score = doc.metadata.get("score")
        if score is not None and score >= self.min_score_confidence:
            filtered_docs.append(doc)
    return filtered_docs

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
Expand All @@ -127,4 +146,4 @@ def _get_relevant_documents(
)
)

return documents
return self._filter_by_score_confidence(docs=documents)
2 changes: 1 addition & 1 deletion libs/aws/langchain_aws/retrievers/kendra.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
"""
Filter out the records that have a score confidence
greater than the required threshold.
less than the required threshold.
"""
if not self.min_score_confidence:
return docs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever:
knowledge_base_id="test-knowledge-base",
client=mock_client,
retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type]
min_score_confidence=0.0,
)


Expand Down Expand Up @@ -78,3 +79,44 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore
knowledgeBaseId="test-knowledge-base",
retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}},
)


def test_get_relevant_documents_with_score(retriever, mock_client) -> None:  # type: ignore[no-untyped-def]
    """Results below the score threshold, or lacking a score, are filtered."""
    retrieval_results = [
        {
            "content": {"text": "This is the first result."},
            "location": "location1",
            "score": 0.9,
        },
        {
            "content": {"text": "This is the second result."},
            "location": "location2",
            "score": 0.8,
        },
        {"content": {"text": "This is the third result."}, "location": "location3"},
        {
            "content": {"text": "This is the fourth result."},
            "metadata": {"key1": "value1", "key2": "value2"},
        },
    ]
    mock_client.retrieve.return_value = {"retrievalResults": retrieval_results}

    # With a 0.80 threshold, only the two scored results at or above it
    # should survive; the unscored third and fourth results are dropped.
    retriever.min_score_confidence = 0.80
    documents = retriever.invoke("test query")

    assert documents == [
        Document(
            page_content="This is the first result.",
            metadata={"location": "location1", "score": 0.9},
        ),
        Document(
            page_content="This is the second result.",
            metadata={"location": "location2", "score": 0.8},
        ),
    ]
28 changes: 28 additions & 0 deletions libs/aws/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,31 @@ def test_retriever_invoke(amazon_retriever, mock_client):
}
assert documents[2].page_content == "result3"
assert documents[2].metadata == {"score": 0}


def test_retriever_invoke_with_score(amazon_retriever, mock_client):
    """Only results meeting min_score_confidence come back from invoke()."""
    retrieval_results = [
        {"content": {"text": "result1"}, "metadata": {"key": "value1"}},
        {
            "content": {"text": "result2"},
            "metadata": {"key": "value2"},
            "score": 1,
            "location": "testLocation",
        },
        {"content": {"text": "result3"}},
    ]
    mock_client.retrieve.return_value = {"retrievalResults": retrieval_results}

    amazon_retriever.min_score_confidence = 0.6
    documents = amazon_retriever.invoke("test query", run_manager=None)

    # Results without an explicit score appear to default to 0 (see the
    # sibling test asserting metadata == {"score": 0}), so only "result2"
    # clears the 0.6 threshold.
    assert len(documents) == 1
    doc = documents[0]
    assert isinstance(doc, Document)
    assert doc.page_content == "result2"
    assert doc.metadata == {
        "score": 1,
        "source_metadata": {"key": "value2"},
        "location": "testLocation",
    }

0 comments on commit 3bc0c39

Please sign in to comment.