Skip to content

Commit

Permalink
cleanup examples
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Oct 16, 2023
1 parent a3ad3d5 commit a98d7e0
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 181 deletions.
52 changes: 19 additions & 33 deletions examples/local_llm/local_llm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,19 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "3020f7f6-8a50-4eaf-b9bf-b320fda5d67f",
"metadata": {},
"outputs": [],
"source": [
"from ragna.core import PackageRequirement\n",
"\n",
"for requirement in [\n",
" PackageRequirement(\"torch\"),\n",
" PackageRequirement(\"optimum\"),\n",
" PackageRequirement(\"auto-gptq\"),\n",
"]:\n",
" assert requirement.is_available(), requirement\n",
"\n",
"import torch\n",
"\n",
"assert torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "68349291-a5dd-41cb-b686-0f9101018ef0",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"platform/c++/implementation/internal.cpp:205:reinit_singlethreaded(): Reinitialising as single-threaded.\n"
]
}
],
"source": [
"from ragna.core import Assistant\n",
"from ragna.core import PackageRequirement, Assistant, Source\n",
"\n",
"\n",
"class AiroborosAssistant(Assistant):\n",
Expand Down Expand Up @@ -71,11 +58,13 @@
" )\n",
"\n",
" @property\n",
" def max_input_size(self):\n",
" def max_input_size(self) -> int:\n",
" # FIXME\n",
" return 1024\n",
"\n",
" def answer(self, prompt, sources):\n",
" def answer(\n",
" self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256\n",
" ) -> str:\n",
" template = \"\"\"\n",
" A chat about the content of documents.\n",
" Only use the content listed below to answer any questions from the user.\n",
Expand All @@ -96,11 +85,8 @@
" ).input_ids.cuda()\n",
" output_ids = self.model.generate(\n",
" inputs=input_ids,\n",
" temperature=0.7,\n",
" do_sample=True,\n",
" top_p=0.95,\n",
" top_k=40,\n",
" max_new_tokens=512,\n",
" do_sample=False,\n",
" max_new_tokens=max_new_tokens,\n",
" )\n",
" output = self.tokenizer.decode(output_ids[0])\n",
" return output.rsplit(\"ASSISTANT:\", 1)[-1].replace(\"</s>\", \"\").strip()\n",
Expand All @@ -111,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "7f4a5263-6d5c-4505-bcbe-d74ce96144ed",
"metadata": {},
"outputs": [
Expand All @@ -127,13 +113,13 @@
"output_type": "stream",
"text": [
"User: What is Ragna?\n",
"Assistant: Ragna is an OSS app for RAG workflows that offers a Python and REST API as well as web UI.\n"
"Assistant: Ragna is an open-source application for RAG workflows. It offers a Python and REST API as well as a web UI.\n"
]
}
],
"source": [
"from ragna.core import Rag\n",
"from ragna.source_storage import RagnaDemoSourceStorage\n",
"from ragna.source_storages import RagnaDemoSourceStorage\n",
"\n",
"rag = Rag()\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/rest_api/rest_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"USER = \"Ragna\"\n",
"\n",
"response = await client.get(\"/chats\", params={\"user\": USER})\n",
"pprint(response.json())"
"pprint(response.json(), sort_dicts=False)"
]
},
{
Expand Down Expand Up @@ -239,7 +239,7 @@
"response = await client.get(\"/document\", params={\"user\": USER, \"name\": path.name})\n",
"document_info = response.json()\n",
"document = document_info[\"document\"]\n",
"pprint(document_info)"
"pprint(document_info, sort_dicts=False)"
]
},
{
Expand Down
16 changes: 16 additions & 0 deletions examples/s3_documents/ragna-s3.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Ragna configuration used by the S3 documents example.
# NOTE(review): absolute path inside one user's home directory; it is
# machine-specific and will not exist elsewhere — confirm it should be
# committed as-is rather than left to a default.
local_cache_root = "/home/philip/.cache/ragna"

[rag]
# NOTE(review): "memory" presumably selects an in-process queue — confirm.
queue_url = "memory"
# Dotted path to the custom document class of this example. Presumably
# resolved by import, so the `ragna_s3_document` module would have to be
# importable when this config is loaded — TODO confirm.
document = "ragna_s3_document.S3Document"
source_storages = ["ragna.source_storages.RagnaDemoSourceStorage"]
assistants = ["ragna.assistants.RagnaDemoAssistant"]

[api]
url = "http://127.0.0.1:31476"
# NOTE(review): "memory" presumably means a non-persistent database — confirm.
database_url = "memory"
# NOTE(review): a literal secret is checked into the repository here. Treat
# this value as public/example-only and replace it before any real use.
upload_token_secret = "-34DVeiUKh1CiZLpz0io3c5ZniUIQKlQ"
# Presumably seconds: the S3 example passes `config.api.upload_token_ttl` as
# `ExpiresIn` when generating the presigned POST — confirm the unit.
upload_token_ttl = 300

[ui]
url = "http://127.0.0.1:31477"
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
import os

import uuid
from typing import Any

from ragna.assistants import RagnaDemoAssistant

from ragna.core import Config, Document, PackageRequirement, RagnaException
from ragna.source_storages import RagnaDemoSourceStorage
from ragna.core import (
Config,
Document,
EnvVarRequirement,
PackageRequirement,
RagnaException,
Requirement,
)


class S3Document(Document):
# Declares what this document class depends on, as a list of requirement
# objects: the `boto3` package (imported lazily in `_session`) and four AWS
# environment variables (credentials, region and bucket name).
# NOTE(review): how the caller evaluates these requirements is not visible
# here — presumably each one is checked for availability; confirm.
@classmethod
def requirements(cls) -> list[Requirement]:
return [
PackageRequirement("boto3"),
EnvVarRequirement("AWS_ACCESS_KEY_ID"),
EnvVarRequirement("AWS_SECRET_ACCESS_KEY"),
EnvVarRequirement("AWS_REGION"),
EnvVarRequirement("AWS_S3_BUCKET"),
]

@classmethod
def _session(cls):
import boto3
Expand All @@ -34,7 +47,7 @@ async def get_upload_info(
response = s3.generate_presigned_post(
Bucket=bucket,
Key=str(id),
ExpiresIn=config.upload_token_ttl,
ExpiresIn=config.api.upload_token_ttl,
)

url = response["url"]
Expand All @@ -43,7 +56,7 @@ async def get_upload_info(

return url, data, metadata

def is_available(self) -> bool:
def is_readable(self) -> bool:
session = self._session()
s3 = session.resource("s3")

Expand All @@ -63,11 +76,3 @@ def read(self) -> bytes:
session = self._session()
s3 = session.resource("s3")
return s3.Object(self.metadata["bucket"], str(self.id)).get()["Body"].read()


config = Config(
state_database_url="sqlite://",
document_class=S3Document,
)
config.register_component(RagnaDemoSourceStorage)
config.register_component(RagnaDemoAssistant)
Loading

0 comments on commit a98d7e0

Please sign in to comment.