Skip to content

Commit

Permalink
Merge pull request #40 from rozerarenu/main
Browse files Browse the repository at this point in the history
Add source metadata from API response to Document metadata
  • Loading branch information
3coins authored May 8, 2024
2 parents 123c720 + 6b49042 commit 4b2761e
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 8 deletions.
4 changes: 2 additions & 2 deletions libs/aws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ retriever = AmazonKendraRetriever(
retriever.get_relevant_documents(query="What is the meaning of life?")
```

`AmazonKnowlegeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases.
`AmazonKnowledgeBasesRetriever` class provides a retriever to connect with Amazon Knowledge Bases.

```python
from langchain_aws import AmazonKnowledgeBasesRetriever
Expand All @@ -65,4 +65,4 @@ retriever = AmazonKnowledgeBasesRetriever(
)

retriever.get_relevant_documents(query="What is the meaning of life?")
```
```
13 changes: 8 additions & 5 deletions libs/aws/langchain_aws/retrievers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,16 @@ def _get_relevant_documents(
results = response["retrievalResults"]
documents = []
for result in results:
content = result["content"]["text"]
result.pop("content")
if "score" not in result:
result["score"] = 0
if "metadata" in result:
result["source_metadata"] = result.pop("metadata")
documents.append(
Document(
page_content=result["content"]["text"],
metadata={
"location": result["location"],
"score": result["score"] if "score" in result else 0,
},
page_content=content,
metadata=result,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore
"score": 0.8,
},
{"content": {"text": "This is the third result."}, "location": "location3"},
{
"content": {"text": "This is the fourth result."},
"metadata": {"key1": "value1", "key2": "value2"},
},
]
}
mock_client.retrieve.return_value = response
Expand All @@ -53,9 +57,19 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore
page_content="This is the third result.",
metadata={"location": "location3", "score": 0.0},
),
Document(
page_content="This is the fourth result.",
metadata={
"score": 0.0,
"source_metadata": {
"key1": "value1",
"key2": "value2",
},
},
),
]

documents = retriever.get_relevant_documents(query)
documents = retriever.invoke(query)

assert documents == expected_documents

Expand Down
57 changes: 57 additions & 0 deletions libs/aws/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# type: ignore

from unittest.mock import MagicMock

import pytest
from langchain_core.documents import Document

from langchain_aws.retrievers import AmazonKnowledgeBasesRetriever


@pytest.fixture
def mock_client():
    """Provide a fresh ``MagicMock`` to use as the retriever's client."""
    client = MagicMock()
    return client


@pytest.fixture
def mock_retriever_config():
    """Return a retrieval config asking the vector search for four results."""
    vector_search = {"numberOfResults": 4}
    return {"vectorSearchConfiguration": vector_search}


@pytest.fixture
def amazon_retriever(mock_client, mock_retriever_config):
    """Build an ``AmazonKnowledgeBasesRetriever`` backed by the mocked client.

    The retriever is constructed with a fixed knowledge base id, the
    ``mock_retriever_config`` fixture as its retrieval configuration and the
    ``mock_client`` fixture as its client.
    """
    retriever_kwargs = {
        "knowledge_base_id": "test_kb_id",
        "retrieval_config": mock_retriever_config,
        "client": mock_client,
    }
    return AmazonKnowledgeBasesRetriever(**retriever_kwargs)


def test_retriever_invoke(amazon_retriever, mock_client):
    """Check how ``invoke`` turns stubbed retrieval results into documents.

    Three results are stubbed on ``mock_client.retrieve``: one with only
    metadata, one with metadata, score and location, and one with content
    only. The test expects the content text as ``page_content``, the
    result's ``metadata`` under the ``source_metadata`` key, and a ``score``
    of 0 when the result carries none.
    """
    query = "test query"
    retrieval_results = [
        {"content": {"text": "result1"}, "metadata": {"key": "value1"}},
        {
            "content": {"text": "result2"},
            "metadata": {"key": "value2"},
            "score": 1,
            "location": "testLocation",
        },
        {"content": {"text": "result3"}},
    ]
    mock_client.retrieve.return_value = {"retrievalResults": retrieval_results}

    documents = amazon_retriever.invoke(query, run_manager=None)

    # (page_content, metadata) expected for each document, in response order.
    expected = [
        ("result1", {"score": 0, "source_metadata": {"key": "value1"}}),
        (
            "result2",
            {
                "score": 1,
                "source_metadata": {"key": "value2"},
                "location": "testLocation",
            },
        ),
        ("result3", {"score": 0}),
    ]

    assert len(documents) == 3
    assert isinstance(documents[0], Document)
    for document, (page_content, metadata) in zip(documents, expected):
        assert document.page_content == page_content
        assert document.metadata == metadata

0 comments on commit 4b2761e

Please sign in to comment.