From 24785d57eb2578bb77d6972080712626cf5f7991 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jul 2024 10:25:46 +0200 Subject: [PATCH 1/3] refactor document upload --- ragna/_cli/core.py | 16 +-- ragna/core/__init__.py | 1 - ragna/core/_document.py | 113 ++++++++------------- ragna/core/_rag.py | 22 ++--- ragna/deploy/_api.py | 109 ++++++--------------- ragna/deploy/_core.py | 5 +- ragna/deploy/_database.py | 119 ++++++++--------------- ragna/deploy/_engine.py | 101 ++++++++++++++----- ragna/deploy/_schemas.py | 22 +++-- scripts/add_chats.py | 95 +++++++++--------- tests/deploy/api/test_batch_endpoints.py | 86 ---------------- tests/deploy/api/test_components.py | 19 +--- tests/deploy/api/test_e2e.py | 31 +++--- 13 files changed, 282 insertions(+), 457 deletions(-) delete mode 100644 tests/deploy/api/test_batch_endpoints.py diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py index 2e9030df..34638e4f 100644 --- a/ragna/_cli/core.py +++ b/ragna/_cli/core.py @@ -1,7 +1,6 @@ from pathlib import Path from typing import Annotated, Optional -import httpx import rich import typer import uvicorn @@ -74,13 +73,13 @@ def deploy( *, config: ConfigOption = "./ragna.toml", # type: ignore[assignment] api: Annotated[ - Optional[bool], + bool, typer.Option( "--api/--no-api", help="Deploy the Ragna REST API.", show_default="True if UI is not deployed and otherwise check availability", ), - ] = None, + ] = True, ui: Annotated[ bool, typer.Option( @@ -101,19 +100,8 @@ def deploy( typer.Option(help="Open a browser when Ragna is deployed."), ] = None, ) -> None: - def api_available() -> bool: - try: - return httpx.get(f"{config._url}/health").is_success - except httpx.ConnectError: - return False - - if api is None: - api = not api_available() if ui else True - if not (api or ui): raise Exception - elif ui and not api and not api_available(): - raise Exception if open_browser is None: open_browser = ui diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 0f4b4bdf..44449775 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -34,7 +34,6 @@ from ._document import ( Document, DocumentHandler, - DocumentUploadParameters, DocxDocumentHandler, LocalDocument, Page, diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 436344b4..31a5ba67 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -2,26 +2,25 @@ import abc import io -import os -import secrets -import time import uuid +from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union - -import jwt +from typing import ( + Any, + AsyncIterator, + Iterator, + Optional, + Type, + TypeVar, + Union, +) + +import aiofiles from pydantic import BaseModel -from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin - -if TYPE_CHECKING: - from ragna.deploy import Config - +import ragna -class DocumentUploadParameters(BaseModel): - method: str - url: str - data: dict +from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin class Document(RequirementsMixin, abc.ABC): @@ -62,16 +61,6 @@ def get_handler(name: str) -> DocumentHandler: return handler - @classmethod - @abc.abstractmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - pass - - @abc.abstractmethod - def is_readable(self) -> bool: ... - @abc.abstractmethod def read(self) -> bytes: ... @@ -88,12 +77,25 @@ class LocalDocument(Document): [ragna.core.LocalDocument.from_path][]. """ + def __init__( + self, + *, + id: Optional[uuid.UUID] = None, + name: str, + metadata: dict[str, Any], + handler: Optional[DocumentHandler] = None, + ): + super().__init__(id=id, name=name, metadata=metadata, handler=handler) + if "path" not in self.metadata: + metadata["path"] = str(ragna.local_root() / "documents" / str(self.id)) + @classmethod def from_path( cls, path: Union[str, Path], *, id: Optional[uuid.UUID] = None, + name: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, handler: Optional[DocumentHandler] = None, ) -> LocalDocument: @@ -102,6 +104,7 @@ def from_path( Args: path: Local path to the file. id: ID of the document. If omitted, one is generated. + name: Name of the document. If omitted, defaults to the name of the `path`. metadata: Optional metadata of the document. handler: Document handler. If omitted, a builtin handler is selected based on the suffix of the `path`. @@ -118,60 +121,30 @@ def from_path( ) path = Path(path).expanduser().resolve() + if name is None: + name = path.name metadata["path"] = str(path) - return cls(id=id, name=path.name, metadata=metadata, handler=handler) + return cls(id=id, name=name, metadata=metadata, handler=handler) - @property + @cached_property def path(self) -> Path: return Path(self.metadata["path"]) - def is_readable(self) -> bool: - return self.path.exists() + async def _write(self, stream: AsyncIterator[bytes]) -> None: + if self.path.exists(): + raise RagnaException("ADDME") - def read(self) -> bytes: - with open(self.path, "rb") as stream: - return stream.read() - - _JWT_SECRET = os.environ.get( - "RAGNA_API_DOCUMENT_UPLOAD_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" + async with aiofiles.open(self.path, "wb") as file: + async for content in stream: + await file.write(content) - @classmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - url = f"{config._url}/api/document" - data = { - "token": jwt.encode( - payload={ - "user": user, - "id": str(id), - "exp": time.time() + 5 * 60, - }, - key=cls._JWT_SECRET, - algorithm=cls._JWT_ALGORITHM, - ) - } - metadata = {"path": str(config.local_root / "documents" / str(id))} - return metadata, DocumentUploadParameters(method="PUT", url=url, data=data) + def read(self) -> bytes: + if not self.path.is_file(): + raise RagnaException("ADDME") - @classmethod - def decode_upload_token(cls, token: str) -> tuple[str, uuid.UUID]: - try: - payload = jwt.decode( - token, key=cls._JWT_SECRET, algorithms=[cls._JWT_ALGORITHM] - ) - except jwt.InvalidSignatureError: - raise RagnaException( - "Token invalid", http_status_code=401, http_detail=RagnaException.EVENT - ) - except jwt.ExpiredSignatureError: - raise RagnaException( - "Token expired", http_status_code=401, http_detail=RagnaException.EVENT - ) - return payload["user"], uuid.UUID(payload["id"]) + with open(self.path, "rb") as file: + return file.read() class Page(BaseModel): diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 98c32a42..c3da0c76 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -286,20 +286,14 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: return answer def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: - documents_ = [] - for document in documents: - if not isinstance(document, Document): - document = LocalDocument.from_path(document) - - if not document.is_readable(): - raise RagnaException( - "Document not readable", - document=document, - http_status_code=404, - ) - - documents_.append(document) - return documents_ + return [ + ( + document + if isinstance(document, Document) + else LocalDocument.from_path(document) + ) + for document in documents + ] def _unpack_chat_params( self, params: dict[str, Any] diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index d3194064..4de2737c 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -1,29 +1,23 @@ import uuid -from typing import Annotated, AsyncIterator, cast +from typing import Annotated, AsyncIterator -import aiofiles import pydantic from fastapi import ( APIRouter, Body, Depends, - Form, - HTTPException, UploadFile, ) from fastapi.responses import StreamingResponse -import ragna -import ragna.core from ragna._compat import anext from ragna.core._utils import default_user -from ragna.deploy import Config from . import _schemas as schemas from ._engine import Engine -def make_router(config: Config, engine: Engine) -> APIRouter: +def make_router(engine: Engine) -> APIRouter: router = APIRouter(tags=["API"]) def get_user() -> str: @@ -31,77 +25,32 @@ def get_user() -> str: UserDependency = Annotated[str, Depends(get_user)] - # TODO: the document endpoints do not go through the engine, because they'll change - # quite drastically when the UI no longer depends on the API - - _database = engine._database - - @router.post("/document") - async def create_document_upload_info( - user: UserDependency, - name: Annotated[str, Body(..., embed=True)], - ) -> schemas.DocumentUpload: - with _database.get_session() as session: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document.metadata = metadata - _database.add_document( - session, user=user, document=document, metadata=metadata - ) - return schemas.DocumentUpload(parameters=parameters, document=document) - - # TODO: Add UI support and documentation for this endpoint (#406) @router.post("/documents") - async def create_documents_upload_info( - user: UserDependency, - names: Annotated[list[str], Body(..., embed=True)], - ) -> list[schemas.DocumentUpload]: - with _database.get_session() as session: - document_metadata_collection = [] - document_upload_collection = [] - for name in names: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document.metadata = metadata - document_metadata_collection.append((document, metadata)) - document_upload_collection.append( - schemas.DocumentUpload(parameters=parameters, document=document) - ) - - _database.add_documents( - session, - user=user, - document_metadata_collection=document_metadata_collection, - ) - return document_upload_collection - - # TODO: Add new endpoint for batch uploading documents (#407) - @router.put("/document") - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - with _database.get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) - document = _database.get_document(session, user=user, id=id) - - core_document = cast( - ragna.core.LocalDocument, engine._to_core.document(document) - ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(core_document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return document + def register_documents( + user: UserDependency, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + return engine.register_documents( + user=user, document_registrations=document_registrations + ) + + @router.put("/documents") + async def upload_documents( + user: UserDependency, documents: list[UploadFile] + ) -> None: + def make_content_stream(file: UploadFile) -> AsyncIterator[bytes]: + async def content_stream() -> AsyncIterator[bytes]: + while content := await file.read(16 * 1024): + yield content + + return content_stream() + + await engine.store_documents( + user=user, + ids_and_streams=[ + (uuid.UUID(document.filename), make_content_stream(document)) + for document in documents + ], + ) @router.get("/components") def get_components(_: UserDependency) -> schemas.Components: @@ -110,9 +59,9 @@ def get_components(_: UserDependency) -> schemas.Components: @router.post("/chats") async def create_chat( user: UserDependency, - chat_metadata: schemas.ChatMetadata, + chat_creation: schemas.ChatCreation, ) -> schemas.Chat: - return engine.create_chat(user=user, chat_metadata=chat_metadata) + return engine.create_chat(user=user, chat_creation=chat_creation) @router.get("/chats") async def get_chats(user: UserDependency) -> list[schemas.Chat]: diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py index 67f067e0..6df4b71b 100644 --- a/ragna/deploy/_core.py +++ b/ragna/deploy/_core.py @@ -27,7 +27,6 @@ def make_app( ignore_unavailable_components: bool, open_browser: bool, ) -> FastAPI: - ragna.local_root(config.local_root) set_redirect_root_path(config.root_path) lifespan: Optional[Callable[[FastAPI], AsyncContextManager]] @@ -38,7 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: def target() -> None: client = httpx.Client(base_url=config._url) - def server_available(): + def server_available() -> bool: try: return client.get("/health").is_success except httpx.ConnectError: @@ -77,7 +76,7 @@ def server_available(): ) if api: - app.include_router(make_api_router(config, engine), prefix="/api") + app.include_router(make_api_router(engine), prefix="/api") if ui: panel_app = make_ui_app(config=config) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 323ccd21..529fa3b6 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -1,7 +1,7 @@ from __future__ import annotations import uuid -from typing import Any, Optional +from typing import Any, Collection, Optional from urllib.parse import urlsplit from sqlalchemy import create_engine, select @@ -42,83 +42,50 @@ def _get_user(self, session: Session, *, username: str) -> orm.User: return user - def add_document( - self, - session: Session, - *, - user: str, - document: schemas.Document, - metadata: dict[str, Any], - ) -> None: - session.add( - orm.Document( - id=document.id, - user_id=self._get_user(session, username=user).id, - name=document.name, - metadata_=metadata, - ) - ) - session.commit() - def add_documents( self, session: Session, *, user: str, - document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], + documents: list[schemas.Document], ) -> None: - """ - Add multiple documents to the database. - - This function allows adding multiple documents at once by calling `add_all`. This is - important when there is non-negligible latency attached to each database operation. - """ - documents = [ - orm.Document( - id=document.id, - user_id=self._get_user(session, username=user).id, - name=document.name, - metadata_=metadata, - ) - for document, metadata in document_metadata_collection - ] - session.add_all(documents) + user_id = self._get_user(session, username=user).id + session.add_all( + [self._to_orm.document(document, user_id=user_id) for document in documents] + ) session.commit() - def get_document( - self, session: Session, *, user: str, id: uuid.UUID - ) -> schemas.Document: - document = session.execute( - select(orm.Document).where( - (orm.Document.user_id == self._get_user(session, username=user).id) - & (orm.Document.id == id) - ) - ).scalar_one_or_none() - return self._to_schema.document(document) - - def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: - document_ids = {document.id for document in chat.metadata.documents} - # FIXME also check if the user is allowed to access the documents? + def _get_orm_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[orm.Document]: + # FIXME also check if the user is allowed to access the documents + # FIXME: maybe just take the user id to avoid getting it twice in add_chat? documents = ( - session.execute( - select(orm.Document).where(orm.Document.id.in_(document_ids)) - ) + session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) .scalars() .all() ) - if len(documents) != len(document_ids): + if len(documents) != len(ids): raise RagnaException( - str(document_ids - {document.id for document in documents}) + str(set(ids) - {document.id for document in documents}) ) + return documents # type: ignore[no-any-return] + + def get_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[schemas.Document]: + return [ + self._to_schema.document(document) + for document in self._get_orm_documents(session, user=user, ids=ids) + ] + + def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: orm_chat = self._to_orm.chat( - chat, - user_id=self._get_user(session, username=user).id, - # We have to pass the documents here, because SQLAlchemy does not allow a - # second instance of orm.Document with the same primary key in the session. - documents=documents, + chat, user_id=self._get_user(session, username=user).id ) - session.add(orm_chat) + # We need to merge and not add here, because the documents are already in the DB + session.merge(orm_chat) session.commit() def _select_chat(self, *, eager: bool = False) -> Any: @@ -213,21 +180,17 @@ def chat( chat: schemas.Chat, *, user_id: uuid.UUID, - documents: Optional[list[orm.Document]] = None, ) -> orm.Chat: - if documents is None: - documents = [ - self.document(document, user_id=user_id) - for document in chat.metadata.documents - ] return orm.Chat( id=chat.id, user_id=user_id, - name=chat.metadata.name, - documents=documents, - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, - params=chat.metadata.params, + name=chat.name, + documents=[ + self.document(document, user_id=user_id) for document in chat.documents + ], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, messages=[ self.message(message, chat_id=chat.id) for message in chat.messages ], @@ -262,13 +225,11 @@ def message(self, message: orm.Message) -> schemas.Message: def chat(self, chat: orm.Chat) -> schemas.Chat: return schemas.Chat( id=chat.id, - metadata=schemas.ChatMetadata( - name=chat.name, - documents=[self.document(document) for document in chat.documents], - source_storage=chat.source_storage, - assistant=chat.assistant, - params=chat.params, - ), + name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, messages=[self.message(message) for message in chat.messages], prepared=chat.prepared, ) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 847f7a93..f732df1e 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,8 +1,13 @@ import uuid -from typing import Any, AsyncIterator, Optional, Type +from typing import Any, AsyncIterator, Optional, Type, cast +from fastapi import status as http_status_code + +import ragna from ragna import Rag, core from ragna._compat import aiter, anext +from ragna._utils import make_directory +from ragna.core import RagnaException from ragna.core._rag import SpecialChatParams from ragna.deploy import Config @@ -13,15 +18,20 @@ class Engine: def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None: self._config = config + ragna.local_root(config.local_root) + self._documents_root = make_directory(config.local_root / "documents") + self.supports_store_documents = issubclass( + self._config.document, ragna.core.LocalDocument + ) self._database = Database(url=config.database_url) - self._rag: Rag = Rag( + self._rag = Rag( # type: ignore[var-annotated] config=config, ignore_unavailable_components=ignore_unavailable_components, ) - self._to_core = SchemaToCoreConverter(config=config, rag=self._rag) + self._to_core = SchemaToCoreConverter(self._rag) self._to_schema = CoreToSchemaConverter() def _get_component_json_schema( @@ -56,12 +66,59 @@ def get_components(self) -> schemas.Components: ], ) + def register_documents( + self, *, user: str, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + # We create core.Document's first, because they might update the metadata + core_documents = [ + self._config.document( + name=registration.name, metadata=registration.metadata + ) + for registration in document_registrations + ] + documents = [self._to_schema.document(document) for document in core_documents] + + with self._database.get_session() as session: + self._database.add_documents(session, user=user, documents=documents) + + return documents + + async def store_documents( + self, + *, + user: str, + ids_and_streams: list[tuple[uuid.UUID, AsyncIterator[bytes]]], + ) -> None: + if not self.supports_store_documents: + raise RagnaException( + "Ragna configuration does not support local upload", + http_status_code=http_status_code.HTTP_400_BAD_REQUEST, + ) + + ids, streams = zip(*ids_and_streams) + + with self._database.get_session() as session: + documents = self._database.get_documents(session, user=user, ids=ids) + + for document, stream in zip(documents, streams): + core_document = cast( + ragna.core.LocalDocument, self._to_core.document(document) + ) + await core_document._write(stream) + def create_chat( - self, *, user: str, chat_metadata: schemas.ChatMetadata + self, *, user: str, chat_creation: schemas.ChatCreation ) -> schemas.Chat: - chat = schemas.Chat(metadata=chat_metadata) + params = chat_creation.model_dump() + document_ids = params.pop("document_ids") + with self._database.get_session() as session: + documents = self._database.get_documents( + session, user=user, ids=document_ids + ) + + chat = schemas.Chat(documents=documents, **params) - # Although we don't need the actual core.Chat here, this just performs the input + # Although we don't need the actual core.Chat here, this performs the input # validation. self._to_core.chat(chat, user=user) @@ -117,12 +174,12 @@ def delete_chat(self, *, user: str, id: uuid.UUID) -> None: class SchemaToCoreConverter: - def __init__(self, *, config: Config, rag: Rag) -> None: - self._config = config + def __init__(self, rag: Rag) -> None: self._rag = rag def document(self, document: schemas.Document) -> core.Document: - return self._config.document( + # FIXME: config + return core.LocalDocument( id=document.id, name=document.name, metadata=document.metadata, @@ -146,13 +203,13 @@ def message(self, message: schemas.Message) -> core.Message: def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat: core_chat = self._rag.chat( - documents=[self.document(document) for document in chat.metadata.documents], - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, user=user, chat_id=chat.id, - chat_name=chat.metadata.name, - **chat.metadata.params, + chat_name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + **chat.params, ) core_chat._messages = [self.message(message) for message in chat.messages] core_chat._prepared = chat.prepared @@ -182,7 +239,9 @@ def message( ) -> schemas.Message: return schemas.Message( id=message.id, - content=content_override or message.content, + content=( + content_override if content_override is not None else message.content + ), role=message.role, sources=[self.source(source) for source in message.sources], timestamp=message.timestamp, @@ -193,13 +252,11 @@ def chat(self, chat: core.Chat) -> schemas.Chat: del params["user"] return schemas.Chat( id=params.pop("chat_id"), - metadata=schemas.ChatMetadata( - name=params.pop("chat_name"), - source_storage=chat.source_storage.display_name(), - assistant=chat.assistant.display_name(), - params=params, - documents=[self.document(document) for document in chat.documents], - ), + name=params.pop("chat_name"), + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), + params=params, messages=[self.message(message) for message in chat._messages], prepared=chat._prepared, ) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 55ae333f..cc5490b7 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -15,15 +15,15 @@ class Components(BaseModel): assistants: list[dict[str, Any]] -class Document(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4) +class DocumentRegistration(BaseModel): name: str metadata: dict[str, Any] = Field(default_factory=dict) -class DocumentUpload(BaseModel): - parameters: ragna.core.DocumentUploadParameters - document: Document +class Document(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + metadata: dict[str, Any] class Source(BaseModel): @@ -43,16 +43,20 @@ class Message(BaseModel): timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) -class ChatMetadata(BaseModel): +class ChatCreation(BaseModel): name: str + document_ids: list[uuid.UUID] source_storage: str assistant: str - params: dict - documents: list[Document] + params: dict[str, Any] = Field(default_factory=dict) class Chat(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) - metadata: ChatMetadata + name: str + documents: list[Document] + source_storage: str + assistant: str + params: dict[str, Any] messages: list[Message] = Field(default_factory=list) prepared: bool = False diff --git a/scripts/add_chats.py b/scripts/add_chats.py index b8c15194..5f550289 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -1,71 +1,70 @@ import datetime import json -import os import httpx -from ragna.core._utils import default_user - def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") - client.get("/").raise_for_status() + client.get("/health").raise_for_status() + + # ## authentication + # + # username = default_user() + # token = ( + # client.post( + # "/token", + # data={ + # "username": username, + # "password": os.environ.get( + # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + # ), + # }, + # ) + # .raise_for_status() + # .json() + # ) + # client.headers["Authorization"] = f"Bearer {token}" + + print() - ## authentication + ## documents - username = default_user() - token = ( + documents = ( client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, + "/api/documents", json=[{"name": f"document{i}.txt"} for i in range(5)] ) .raise_for_status() .json() ) - client.headers["Authorization"] = f"Bearer {token}" - ## documents - - documents = [] - for i in range(5): - name = f"document{i}.txt" - document_upload = ( - client.post("/document", json={"name": name}).raise_for_status().json() - ) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": f"Content of {name}".encode()}, - ).raise_for_status() - documents.append(document_upload["document"]) + client.put( + "/api/documents", + files=[ + ("documents", (document["id"], f"Content of {document['name']}".encode())) + for document in documents + ], + ).raise_for_status() ## chat 1 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Test chat", - "documents": documents[:2], + "document_ids": [document["id"] for document in documents[:2]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() @@ -73,55 +72,53 @@ def main(): chat = ( client.post( - "/chats", + "/api/chats", json={ "name": f"Chat {datetime.datetime.now():%x %X}", - "documents": documents[2:4], + "document_ids": [document["id"] for document in documents[2:]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() for _ in range(3): client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna? Please, I need to know!"}, ).raise_for_status() - ## chat 3 + # ## chat 3 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": ( "Really long chat name that likely needs to be truncated somehow. " "If you can read this, truncating failed :boom:" ), - "documents": [documents[i] for i in [0, 2, 4]], + "document_ids": [documents[i]["id"] for i in [0, 2, 4]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Ok, in that case show me some pretty markdown!"}, ).raise_for_status() - chats = client.get("/chats").raise_for_status().json() + chats = client.get("/api/chats").raise_for_status().json() print(json.dumps(chats)) diff --git a/tests/deploy/api/test_batch_endpoints.py b/tests/deploy/api/test_batch_endpoints.py deleted file mode 100644 index 2736df24..00000000 --- a/tests/deploy/api/test_batch_endpoints.py +++ /dev/null @@ -1,86 +0,0 @@ -from fastapi import status -from fastapi.testclient import TestClient - -from ragna.deploy import Config - -from .utils import authenticate, make_api_app - - -def test_batch_sequential_upload_equivalence(tmp_local_root): - "Check that uploading documents sequentially and in batch gives the same result" - config = Config(local_root=tmp_local_root) - - document_root = config.local_root / "documents" - document_root.mkdir() - document_path1 = document_root / "test1.txt" - with open(document_path1, "w") as file: - file.write("!\n") - document_path2 = document_root / "test2.txt" - with open(document_path2, "w") as file: - file.write("?\n") - - with TestClient( - make_api_app(config=Config(), ignore_unavailable_components=False) - ) as client: - authenticate(client) - - document1_upload = ( - client.post("/api/document", json={"name": document_path1.name}) - .raise_for_status() - .json() - ) - document2_upload = ( - client.post("/api/document", json={"name": document_path2.name}) - .raise_for_status() - .json() - ) - - documents_upload = ( - client.post( - "/api/documents", - json={"names": [document_path1.name, document_path2.name]}, - ) - .raise_for_status() - .json() - ) - - assert ( - document1_upload["parameters"]["url"] - == documents_upload[0]["parameters"]["url"] - ) - assert ( - document2_upload["parameters"]["url"] - == documents_upload[1]["parameters"]["url"] - ) - - assert ( - document1_upload["document"]["name"] - == documents_upload[0]["document"]["name"] - ) - assert ( - document2_upload["document"]["name"] - == documents_upload[1]["document"]["name"] - ) - - # assuming that if test passes for first document it will also pass for the other - with open(document_path1, "rb") as file: - response_sequential_upload1 = client.request( - document1_upload["parameters"]["method"], - document1_upload["parameters"]["url"], - data=document1_upload["parameters"]["data"], - files={"file": file}, - ) - response_batch_upload1 = client.request( - documents_upload[0]["parameters"]["method"], - documents_upload[0]["parameters"]["url"], - data=documents_upload[0]["parameters"]["data"], - files={"file": file}, - ) - - assert response_sequential_upload1.status_code == status.HTTP_200_OK - assert response_batch_upload1.status_code == status.HTTP_200_OK - - assert ( - response_sequential_upload1.json()["name"] - == response_batch_upload1.json()["name"] - ) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index b459f12e..0d44790c 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -67,31 +67,22 @@ def test_unknown_component(tmp_local_root): ) as client: authenticate(client) - document_upload = ( - client.post("/api/document", json={"name": document_path.name}) + document = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() - .json() + .json()[0] ) - document = document_upload["document"] - assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) response = client.post( "/api/chats", json={ "name": "test-chat", + "document_ids": [document["id"]], "source_storage": "unknown_source_storage", "assistant": "unknown_assistant", - "params": {}, - "documents": [document], }, ) diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index c1a80ad5..61632251 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -43,26 +43,21 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): assert client.get("/api/chats").raise_for_status().json() == [] - document_upload = ( - client.post("/api/document", json={"name": document_path.name}) + documents = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() .json() ) - document = document_upload["document"] + assert len(documents) == 1 + document = documents[0] assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) components = client.get("/api/components").raise_for_status().json() - documents = components["documents"] - assert set(documents) == config.document.supported_suffixes() + supported_documents = components["documents"] + assert set(supported_documents) == config.document.supported_suffixes() source_storages = [ json_schema["title"] for json_schema in components["source_storages"] ] @@ -77,15 +72,19 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): source_storage = source_storages[0] assistant = assistants[0] - chat_metadata = { + chat_creation = { "name": "test-chat", + "document_ids": [document["id"]], "source_storage": source_storage, "assistant": assistant, "params": {"multiple_answer_chunks": multiple_answer_chunks}, - "documents": [document], } - chat = client.post("/api/chats", json=chat_metadata).raise_for_status().json() - assert chat["metadata"] == chat_metadata + chat = client.post("/api/chats", json=chat_creation).raise_for_status().json() + for field in ["name", "source_storage", "assistant", "params"]: + assert chat[field] == chat_creation[field] + assert [document["id"] for document in chat["documents"]] == chat_creation[ + "document_ids" + ] assert not chat["prepared"] assert chat["messages"] == [] From e45b305dbfbb3b8501fa8f61bf385000753233d5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 8 Jul 2024 10:56:50 +0200 Subject: [PATCH 2/3] cleanup --- ragna/_cli/core.py | 16 ++++++++++++++-- ragna/core/_document.py | 10 +--------- ragna/deploy/_engine.py | 8 ++++---- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py index 34638e4f..2e9030df 100644 --- a/ragna/_cli/core.py +++ b/ragna/_cli/core.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Annotated, Optional +import httpx import rich import typer import uvicorn @@ -73,13 +74,13 @@ def deploy( *, config: ConfigOption = "./ragna.toml", # type: ignore[assignment] api: Annotated[ - bool, + Optional[bool], typer.Option( "--api/--no-api", help="Deploy the Ragna REST API.", show_default="True if UI is not deployed and otherwise check availability", ), - ] = True, + ] = None, ui: Annotated[ bool, typer.Option( @@ -100,8 +101,19 @@ def deploy( typer.Option(help="Open a browser when Ragna is deployed."), ] = None, ) -> None: + def api_available() -> bool: + try: + return httpx.get(f"{config._url}/health").is_success + except httpx.ConnectError: + return False + + if api is None: + api = not api_available() if ui else True + if not (api or ui): raise Exception + elif ui and not api and not api_available(): + raise Exception if open_browser is None: open_browser = ui diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 31a5ba67..e6dbe8a2 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -5,15 +5,7 @@ import uuid from functools import cached_property from pathlib import Path -from typing import ( - Any, - AsyncIterator, - Iterator, - Optional, - Type, - TypeVar, - Union, -) +from typing import Any, AsyncIterator, Iterator, Optional, Type, TypeVar, Union import aiofiles from pydantic import BaseModel diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index f732df1e..2209a61f 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -31,7 +31,7 @@ def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> No ignore_unavailable_components=ignore_unavailable_components, ) - self._to_core = SchemaToCoreConverter(self._rag) + self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) self._to_schema = CoreToSchemaConverter() def _get_component_json_schema( @@ -174,12 +174,12 @@ def delete_chat(self, *, user: str, id: uuid.UUID) -> None: class SchemaToCoreConverter: - def __init__(self, rag: Rag) -> None: + def __init__(self, *, config: Config, rag: Rag) -> None: + self._config = config self._rag = rag def document(self, document: schemas.Document) -> core.Document: - # FIXME: config - return core.LocalDocument( + return self._config.document( id=document.id, name=document.name, metadata=document.metadata, From eb25b28093d213cd4cca7b2479f1e4edfc89781c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 9 Jul 2024 17:09:18 +0200 Subject: [PATCH 3/3] fix error messages --- ragna/core/_document.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index e6dbe8a2..7a1cef7f 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -125,7 +125,9 @@ def path(self) -> Path: async def _write(self, stream: AsyncIterator[bytes]) -> None: if self.path.exists(): - raise RagnaException("ADDME") + raise RagnaException( + "File already exists", path=self.path, http_detail=RagnaException.EVENT + ) async with aiofiles.open(self.path, "wb") as file: async for content in stream: @@ -133,7 +135,9 @@ async def _write(self, stream: AsyncIterator[bytes]) -> None: def read(self) -> bytes: if not self.path.is_file(): - raise RagnaException("ADDME") + raise RagnaException( + "File does not exist", path=self.path, http_detail=RagnaException.EVENT + ) with open(self.path, "rb") as file: return file.read()