Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor document registering and upload #441

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ragna/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from ._document import (
Document,
DocumentHandler,
DocumentUploadParameters,
DocxDocumentHandler,
LocalDocument,
Page,
Expand Down
103 changes: 34 additions & 69 deletions ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,17 @@

import abc
import io
import os
import secrets
import time
import uuid
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union
from typing import Any, AsyncIterator, Iterator, Optional, Type, TypeVar, Union

import jwt
import aiofiles
from pydantic import BaseModel

from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin

if TYPE_CHECKING:
from ragna.deploy import Config
import ragna


class DocumentUploadParameters(BaseModel):
method: str
url: str
data: dict
from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin


class Document(RequirementsMixin, abc.ABC):
Expand Down Expand Up @@ -62,16 +53,6 @@ def get_handler(name: str) -> DocumentHandler:

return handler

@classmethod
@abc.abstractmethod
async def get_upload_info(
cls, *, config: Config, user: str, id: uuid.UUID, name: str
) -> tuple[dict[str, Any], DocumentUploadParameters]:
pass

@abc.abstractmethod
def is_readable(self) -> bool: ...

@abc.abstractmethod
def read(self) -> bytes: ...

Expand All @@ -88,12 +69,25 @@ class LocalDocument(Document):
[ragna.core.LocalDocument.from_path][].
"""

def __init__(
self,
*,
id: Optional[uuid.UUID] = None,
name: str,
metadata: dict[str, Any],
handler: Optional[DocumentHandler] = None,
):
super().__init__(id=id, name=name, metadata=metadata, handler=handler)
if "path" not in self.metadata:
metadata["path"] = str(ragna.local_root() / "documents" / str(self.id))

@classmethod
def from_path(
cls,
path: Union[str, Path],
*,
id: Optional[uuid.UUID] = None,
name: Optional[str] = None,
metadata: Optional[dict[str, Any]] = None,
handler: Optional[DocumentHandler] = None,
) -> LocalDocument:
Expand All @@ -102,6 +96,7 @@ def from_path(
Args:
path: Local path to the file.
id: ID of the document. If omitted, one is generated.
name: Name of the document. If omitted, defaults to the name of the `path`.
metadata: Optional metadata of the document.
handler: Document handler. If omitted, a builtin handler is selected based
on the suffix of the `path`.
Expand All @@ -118,60 +113,30 @@ def from_path(
)

path = Path(path).expanduser().resolve()
if name is None:
name = path.name
metadata["path"] = str(path)

return cls(id=id, name=path.name, metadata=metadata, handler=handler)
return cls(id=id, name=name, metadata=metadata, handler=handler)

@property
@cached_property
def path(self) -> Path:
return Path(self.metadata["path"])

def is_readable(self) -> bool:
return self.path.exists()
async def _write(self, stream: AsyncIterator[bytes]) -> None:
if self.path.exists():
raise RagnaException("ADDME")
pmeier marked this conversation as resolved.
Show resolved Hide resolved

def read(self) -> bytes:
with open(self.path, "rb") as stream:
return stream.read()

_JWT_SECRET = os.environ.get(
"RAGNA_API_DOCUMENT_UPLOAD_SECRET", secrets.token_urlsafe(32)[:32]
)
_JWT_ALGORITHM = "HS256"
async with aiofiles.open(self.path, "wb") as file:
async for content in stream:
await file.write(content)

@classmethod
async def get_upload_info(
cls, *, config: Config, user: str, id: uuid.UUID, name: str
) -> tuple[dict[str, Any], DocumentUploadParameters]:
url = f"{config._url}/api/document"
data = {
"token": jwt.encode(
payload={
"user": user,
"id": str(id),
"exp": time.time() + 5 * 60,
},
key=cls._JWT_SECRET,
algorithm=cls._JWT_ALGORITHM,
)
}
metadata = {"path": str(config.local_root / "documents" / str(id))}
return metadata, DocumentUploadParameters(method="PUT", url=url, data=data)
def read(self) -> bytes:
if not self.path.is_file():
raise RagnaException("ADDME")
pmeier marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def decode_upload_token(cls, token: str) -> tuple[str, uuid.UUID]:
try:
payload = jwt.decode(
token, key=cls._JWT_SECRET, algorithms=[cls._JWT_ALGORITHM]
)
except jwt.InvalidSignatureError:
raise RagnaException(
"Token invalid", http_status_code=401, http_detail=RagnaException.EVENT
)
except jwt.ExpiredSignatureError:
raise RagnaException(
"Token expired", http_status_code=401, http_detail=RagnaException.EVENT
)
return payload["user"], uuid.UUID(payload["id"])
with open(self.path, "rb") as file:
return file.read()


class Page(BaseModel):
Expand Down
22 changes: 8 additions & 14 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,20 +286,14 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
return answer

def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:
documents_ = []
for document in documents:
if not isinstance(document, Document):
document = LocalDocument.from_path(document)

if not document.is_readable():
raise RagnaException(
"Document not readable",
document=document,
http_status_code=404,
)

documents_.append(document)
return documents_
return [
(
document
if isinstance(document, Document)
else LocalDocument.from_path(document)
)
for document in documents
]

def _unpack_chat_params(
self, params: dict[str, Any]
Expand Down
109 changes: 29 additions & 80 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,56 @@
import uuid
from typing import Annotated, AsyncIterator, cast
from typing import Annotated, AsyncIterator

import aiofiles
import pydantic
from fastapi import (
APIRouter,
Body,
Depends,
Form,
HTTPException,
UploadFile,
)
from fastapi.responses import StreamingResponse

import ragna
import ragna.core
from ragna._compat import anext
from ragna.core._utils import default_user
from ragna.deploy import Config

from . import _schemas as schemas
from ._engine import Engine


def make_router(config: Config, engine: Engine) -> APIRouter:
def make_router(engine: Engine) -> APIRouter:
router = APIRouter(tags=["API"])

def get_user() -> str:
return default_user()

UserDependency = Annotated[str, Depends(get_user)]

# TODO: the document endpoints do not go through the engine, because they'll change
# quite drastically when the UI no longer depends on the API

_database = engine._database

@router.post("/document")
async def create_document_upload_info(
user: UserDependency,
name: Annotated[str, Body(..., embed=True)],
) -> schemas.DocumentUpload:
with _database.get_session() as session:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
_database.add_document(
session, user=user, document=document, metadata=metadata
)
return schemas.DocumentUpload(parameters=parameters, document=document)

# TODO: Add UI support and documentation for this endpoint (#406)
@router.post("/documents")
async def create_documents_upload_info(
user: UserDependency,
names: Annotated[list[str], Body(..., embed=True)],
) -> list[schemas.DocumentUpload]:
with _database.get_session() as session:
document_metadata_collection = []
document_upload_collection = []
for name in names:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
document_metadata_collection.append((document, metadata))
document_upload_collection.append(
schemas.DocumentUpload(parameters=parameters, document=document)
)

_database.add_documents(
session,
user=user,
document_metadata_collection=document_metadata_collection,
)
return document_upload_collection

# TODO: Add new endpoint for batch uploading documents (#407)
@router.put("/document")
async def upload_document(
token: Annotated[str, Form()], file: UploadFile
) -> schemas.Document:
if not issubclass(config.document, ragna.core.LocalDocument):
raise HTTPException(
status_code=400,
detail="Ragna configuration does not support local upload",
)
with _database.get_session() as session:
user, id = ragna.core.LocalDocument.decode_upload_token(token)
document = _database.get_document(session, user=user, id=id)

core_document = cast(
ragna.core.LocalDocument, engine._to_core.document(document)
)
core_document.path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(core_document.path, "wb") as document_file:
while content := await file.read(1024):
await document_file.write(content)

return document
def register_documents(
user: UserDependency, document_registrations: list[schemas.DocumentRegistration]
) -> list[schemas.Document]:
return engine.register_documents(
user=user, document_registrations=document_registrations
)

@router.put("/documents")
async def upload_documents(
user: UserDependency, documents: list[UploadFile]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Does this load all files into memory on the server?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not. The UploadFile from FastAPI actually wraps a SpooledTemporaryFile, i.e. it will only be kept in memory for small files and otherwise temporarily be stored on disk.

) -> None:
def make_content_stream(file: UploadFile) -> AsyncIterator[bytes]:
async def content_stream() -> AsyncIterator[bytes]:
while content := await file.read(16 * 1024):
yield content

return content_stream()

await engine.store_documents(
user=user,
ids_and_streams=[
(uuid.UUID(document.filename), make_content_stream(document))
for document in documents
],
)

@router.get("/components")
def get_components(_: UserDependency) -> schemas.Components:
Expand All @@ -110,9 +59,9 @@ def get_components(_: UserDependency) -> schemas.Components:
@router.post("/chats")
async def create_chat(
user: UserDependency,
chat_metadata: schemas.ChatMetadata,
chat_creation: schemas.ChatCreation,
) -> schemas.Chat:
return engine.create_chat(user=user, chat_metadata=chat_metadata)
return engine.create_chat(user=user, chat_creation=chat_creation)

@router.get("/chats")
async def get_chats(user: UserDependency) -> list[schemas.Chat]:
Expand Down
5 changes: 2 additions & 3 deletions ragna/deploy/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def make_app(
ignore_unavailable_components: bool,
open_browser: bool,
) -> FastAPI:
ragna.local_root(config.local_root)
set_redirect_root_path(config.root_path)

lifespan: Optional[Callable[[FastAPI], AsyncContextManager]]
Expand All @@ -38,7 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
def target() -> None:
client = httpx.Client(base_url=config._url)

def server_available():
def server_available() -> bool:
try:
return client.get("/health").is_success
except httpx.ConnectError:
Expand Down Expand Up @@ -77,7 +76,7 @@ def server_available():
)

if api:
app.include_router(make_api_router(config, engine), prefix="/api")
app.include_router(make_api_router(engine), prefix="/api")

if ui:
panel_app = make_ui_app(config=config)
Expand Down
Loading
Loading