Skip to content

Commit

Permalink
fix demo assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 22, 2024
1 parent bae4fc4 commit 6dd7ede
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 38 deletions.
2 changes: 1 addition & 1 deletion ragna/assistants/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _markdown_answer(self) -> str:
def _default_answer(self, prompt: str, sources: list[Source]) -> str:
sources_display = []
for source in sources:
source_display = f"- {source.document.name}"
source_display = f"- {source.document_name}"
if source.location:
source_display += f", {source.location}"
source_display += f": {textwrap.shorten(source.content, width=100)}"
Expand Down
1 change: 1 addition & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class Source(pydantic.BaseModel):

id: str
document_id: uuid.UUID
document_name: str
location: str
content: str
num_tokens: int
Expand Down
13 changes: 6 additions & 7 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def _load_component(

def chat(
self,
*,
input: Any,
*,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
**params: Any,
Expand Down Expand Up @@ -153,7 +153,8 @@ def __init__(
) -> None:
self._rag = rag

self.documents, self.metadata_filter = self._parse_input(input)
self.documents, self.metadata_filter, self._prepared = self._parse_input(input)

self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
Expand All @@ -165,7 +166,6 @@ def __init__(
self.params = params
self._unpacked_params = self._unpack_chat_params(params)

self._prepared = False
self._messages: list[Message] = []

async def prepare(self) -> Message:
Expand Down Expand Up @@ -237,10 +237,9 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:

def _parse_input(
self, input: Iterable[Any]
) -> tuple[Optional[list[Document]], MetadataFilter]:
) -> tuple[Optional[list[Document]], MetadataFilter, bool]:
if isinstance(input, MetadataFilter):
self._prepared = True
return None, input
return None, input, True

documents = []
for document in input:
Expand All @@ -262,7 +261,7 @@ def _parse_input(
for document in documents
]
)
return documents, metadata_filter
return documents, metadata_filter, False

def _unpack_chat_params(
self, params: dict[str, Any]
Expand Down
75 changes: 45 additions & 30 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import textwrap
import uuid
from typing import Any

from ragna.core import (
Document,
Expand Down Expand Up @@ -28,21 +29,23 @@ def display_name(cls) -> str:
return "Ragna/DemoSourceStorage"

def __init__(self) -> None:
self._storage: list[Source] = []
self._storage: list[dict[str, Any]] = []

def store(self, documents: list[Document]) -> None:
self._storage.extend(
[
Source(
id=str(uuid.uuid4()),
document=document,
location=(
dict(
document_id=document.id,
document_name=document.name,
**document.metadata,
__id__=str(uuid.uuid4()),
__location__=(
f"page {page.number}"
if (page := next(document.extract_pages())).number
else ""
),
content=(content := textwrap.shorten(page.text, width=100)),
num_tokens=len(content.split()),
__content__=(content := textwrap.shorten(page.text, width=100)),
__num_tokens__=len(content.split()),
)
for document in documents
]
Expand All @@ -61,39 +64,51 @@ def store(self, documents: list[Document]) -> None:

def _apply_filter(
self, metadata_filter: MetadataFilter
) -> list[tuple[int, Source]]:
) -> list[tuple[int, dict[str, Any]]]:
if metadata_filter.operator is MetadataOperator.RAW:
raise RagnaException
elif metadata_filter.operator in {MetadataOperator.AND, MetadataOperator.OR}:
return sorted(
functools.reduce(
(
set.intersection
if metadata_filter.operator is MetadataOperator.AND
else set.union
),
(set(self._apply_filter(child)) for child in metadata_filter.value),
idcs_groups = []
rows_map = {}
for child in metadata_filter.value:
idcs_group = set()
for idx, row in self._apply_filter(child):
idcs_group.add(idx)
if idx not in rows_map:
rows_map[idx] = row
idcs_groups.append(idcs_group)
idcs = functools.reduce(
(
set.intersection
if metadata_filter.operator is MetadataOperator.AND
else set.union
),
key=lambda source_with_idx: source_with_idx[0],
idcs_groups,
)
return [(idx, rows_map[idx]) for idx in sorted(idcs)]
else:
sources_with_idx = []
for idx, source in enumerate(self._storage):
if metadata_filter.key == "document_id":
value = source.document.id
elif metadata_filter.key == "document_name":
value = source.document.name
else:
value = source.document.metadata.get(metadata_filter.key)
if value is None:
continue
rows_with_idx = []
for idx, row in enumerate(self._storage):
value = row.get(metadata_filter.key)
if value is None:
continue

if self._METADATA_OPERATOR_MAP[metadata_filter.operator](
value, metadata_filter.value
):
sources_with_idx.append((idx, source))
rows_with_idx.append((idx, row))

return sources_with_idx
return rows_with_idx

def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source]:
return [source for _, source in self._apply_filter(metadata_filter)]
return [
Source(
id=row["__id__"],
document_id=row["document_id"],
document_name=row["document_name"],
location=row["__location__"],
content=row["__content__"],
num_tokens=row["__num_tokens__"],
)
for _, row in self._apply_filter(metadata_filter)
]

0 comments on commit 6dd7ede

Please sign in to comment.