Skip to content

Commit

Permalink
add LanceDB as SourceStorage
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Oct 9, 2023
1 parent c2963c3 commit 0805fec
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 7 deletions.
22 changes: 15 additions & 7 deletions examples/python_api/python_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
{
"data": {
"text/plain": [
"'0.1.dev21+g7f8ec2d.d20230925071852'"
"'0.1.dev29+g03721c7.d20231009084220'"
]
},
"execution_count": 2,
Expand All @@ -44,7 +44,7 @@
{
"data": {
"text/plain": [
"Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=<class 'ragna.core.LocalDocument'>, upload_token_secret='245be1b4c5656eefec1ac16d4a856f189e9e35d35aa014eb7426573e5b86c03f', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': <class 'ragna.source_storage._demo.RagnaDemoSourceStorage'>}, registered_assistant_classes={'Ragna/DemoAssistant': <class 'ragna.assistant._demo.RagnaDemoAssistant'>})"
"Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=<class 'ragna.core.LocalDocument'>, upload_token_secret='d9d5c32fcb2d4f3a3a36cb5d95c8147ab2e9c664ad2a3976f3e2eb9ef80b53c9', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': <class 'ragna.source_storage._demo.RagnaDemoSourceStorage'>}, registered_assistant_classes={'Ragna/DemoAssistant': <class 'ragna.assistant._demo.RagnaDemoAssistant'>})"
]
},
"execution_count": 3,
Expand Down Expand Up @@ -98,7 +98,11 @@
" OpenaiGpt35Turbo16kAssistant,\n",
" OpenaiGpt4Assistant,\n",
")\n",
"from ragna.source_storage import ChromaSourceStorage, RagnaDemoSourceStorage\n",
"from ragna.source_storage import (\n",
" ChromaSourceStorage,\n",
" RagnaDemoSourceStorage,\n",
" LanceDBSourceStorage,\n",
")\n",
"\n",
"rag = Rag(demo_config)\n",
"\n",
Expand All @@ -123,8 +127,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): <coroutine object answer_prompt at 0x7f73bfd15f40>,\n",
" ('Chroma', 'OpenAI/gpt-4'): <coroutine object answer_prompt at 0x7f740774a040>}\n"
"{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): <coroutine object answer_prompt at 0x7f690f91d640>,\n",
" ('Chroma', 'OpenAI/gpt-4'): <coroutine object answer_prompt at 0x7f690f7c04c0>,\n",
" ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): <coroutine object answer_prompt at 0x7f690f7c0840>,\n",
" ('LanceDB', 'OpenAI/gpt-4'): <coroutine object answer_prompt at 0x7f690f7c09c0>}\n"
]
}
],
Expand All @@ -133,7 +139,7 @@
"import asyncio\n",
"from pprint import pprint\n",
"\n",
"source_storages = [ChromaSourceStorage]\n",
"source_storages = [ChromaSourceStorage, LanceDBSourceStorage]\n",
"assistants = [OpenaiGpt35Turbo16kAssistant, OpenaiGpt4Assistant]\n",
"\n",
"\n",
Expand Down Expand Up @@ -168,7 +174,9 @@
"output_type": "stream",
"text": [
"{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source RAG (Response Analysis Graph) orchestration app. It is designed to help users create conversational AI applications by providing a framework for managing and orchestrating the flow of conversations. Ragna allows developers to define conversation flows, handle user inputs, and generate dynamic responses based on predefined rules and logic. It is built on top of the Rasa framework and provides additional features and functionalities to simplify the development process.,\n",
" ('Chroma', 'OpenAI/gpt-4'): Ragna is an open-source RAG orchestration app.}\n"
" ('Chroma', 'OpenAI/gpt-4'): Ragna is an open-source RAG orchestration app.,\n",
" ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source rag orchestration app. It is a software application that allows users to create and arrange musical compositions using ragtime music. It is designed to be accessible and customizable for musicians and composers.,\n",
" ('LanceDB', 'OpenAI/gpt-4'): Ragna is an open-source rag orchestration app.}\n"
]
}
],
Expand Down
1 change: 1 addition & 0 deletions ragna/source_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ._chroma import ChromaSourceStorage
from ._demo import RagnaDemoSourceStorage
from ._lancedb import LanceDBSourceStorage
117 changes: 117 additions & 0 deletions ragna/source_storage/_lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from ragna.core import (
Document,
PackageRequirement,
RagnaId,
Requirement,
Source,
SourceStorage,
)

from ragna.utils import chunk_pages, page_numbers_to_str, take_sources_up_to_max_tokens


class LanceDBSourceStorage(SourceStorage):
    """Source storage backed by a local LanceDB vector database.

    Document chunks are embedded with a sentence-transformers model and stored
    in one LanceDB table per chat. Retrieval embeds the prompt and performs a
    vector similarity search against the stored chunk embeddings.
    """

    @classmethod
    def display_name(cls) -> str:
        return "LanceDB"

    @classmethod
    def requirements(cls) -> list[Requirement]:
        return [
            PackageRequirement("lancedb>=0.2"),
            # FIXME: re-add this after https://github.com/apache/arrow/issues/38167 is
            # resolved.
            # PackageRequirement("pyarrow"),
            PackageRequirement("sentence_transformers"),
        ]

    # Name of the table column that holds the embedded chunk text.
    _VECTOR_COLUMN_NAME = "embedded_text"

    def __init__(self, config):
        super().__init__(config)

        # Imported lazily so the class can be registered even when the optional
        # dependencies from requirements() are not installed.
        import lancedb
        import pyarrow as pa
        from sentence_transformers import SentenceTransformer

        self._db = lancedb.connect(config.local_cache_root / "lancedb")
        self._model = SentenceTransformer("paraphrase-albert-small-v2")
        self._schema = pa.schema(
            [
                pa.field("document_id", pa.string()),
                pa.field("document_name", pa.string()),
                pa.field("page_numbers", pa.string()),
                pa.field("text", pa.string()),
                pa.field(
                    self._VECTOR_COLUMN_NAME,
                    # FIX: use the public accessor for the sentence embedding
                    # size rather than reaching into the model's last module,
                    # which is only coincidentally the same value.
                    pa.list_(
                        pa.float32(),
                        self._model.get_sentence_embedding_dimension(),
                    ),
                ),
                pa.field("num_tokens", pa.int32()),
            ]
        )

    def _embed(self, batch):
        # NOTE(review): currently unused within this file — store() and
        # retrieve() call self._model.encode directly. Kept for backwards
        # compatibility with potential external callers.
        return [self._model.encode(sentence) for sentence in batch]

    def store(
        self,
        documents: list[Document],
        *,
        chat_id: RagnaId,
        chunk_size: int = 500,
        chunk_overlap: int = 250,
    ) -> None:
        """Chunk, embed, and store *documents* in a per-chat LanceDB table.

        Args:
            documents: Documents to store.
            chat_id: Chat the documents belong to; used as the table name.
            chunk_size: Maximum number of tokens per chunk.
            chunk_overlap: Number of tokens shared between consecutive chunks.
        """
        table = self._db.create_table(name=str(chat_id), schema=self._schema)

        for document in documents:
            for chunk in chunk_pages(
                document.extract_pages(),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                tokenizer=self._model.tokenizer,
            ):
                table.add(
                    [
                        {
                            "document_id": str(document.id),
                            "document_name": document.name,
                            "page_numbers": page_numbers_to_str(chunk.page_numbers),
                            "text": chunk.text,
                            self._VECTOR_COLUMN_NAME: self._model.encode(chunk.text),
                            "num_tokens": chunk.num_tokens,
                        }
                    ]
                )

    def retrieve(
        self,
        prompt: str,
        *,
        chat_id: RagnaId,
        chunk_size: int = 500,
        num_tokens: int = 1024,
    ) -> list[Source]:
        """Return sources relevant to *prompt*, up to *num_tokens* total tokens.

        Args:
            prompt: Query text; embedded and matched against stored chunks.
            chat_id: Chat whose table is searched.
            chunk_size: Chunk size used when the documents were stored; used to
                estimate how many results to request.
            num_tokens: Maximum combined token count of the returned sources.
        """
        table = self._db.open_table(str(chat_id))

        # We cannot retrieve sources by a maximum number of tokens. Thus, we
        # estimate how many sources we have to query. We overestimate by a
        # factor of two to avoid retrieving too few sources and needing to
        # query again.
        limit = int(num_tokens * 2 / chunk_size)

        # FIX: the prompt was previously ignored — `table.search()` was called
        # without a query vector, so the results were unrelated to the prompt.
        # Embed the prompt and search against the vector column explicitly.
        results = (
            table.search(
                self._model.encode(prompt),
                vector_column_name=self._VECTOR_COLUMN_NAME,
            )
            .limit(limit)
            .to_arrow()
        )

        return list(
            take_sources_up_to_max_tokens(
                (
                    Source(
                        id=RagnaId.make(),
                        document_id=RagnaId(result["document_id"]),
                        document_name=result["document_name"],
                        location=result["page_numbers"],
                        content=result["text"],
                        num_tokens=result["num_tokens"],
                    )
                    for result in results.to_pylist()
                ),
                max_tokens=num_tokens,
            )
        )

0 comments on commit 0805fec

Please sign in to comment.