diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 7deb55c2..21a3742a 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -53,10 +53,9 @@ runs: mamba env update --file environment-dev.yml git checkout -- environment-dev.yml - - name: Install redis-server if necessary - if: (steps.cache.outputs.cache-hit != 'true') && (runner.os != 'Windows') + - name: Install playwright shell: bash -el {0} - run: mamba install --yes --channel conda-forge redis-server + run: playwright install - name: Install ragna shell: bash -el {0} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b74f3907..5ff4147d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,10 +68,76 @@ jobs: - name: Run unit tests id: tests - run: pytest --junit-xml=test-results.xml --durations=25 + run: | + pytest \ + --ignore tests/deploy/ui \ + --junit-xml=test-results.xml \ + --durations=25 - name: Surface failing tests if: steps.tests.outcome != 'success' uses: pmeier/pytest-results-action@v0.3.0 with: path: test-results.xml + + pytest-ui: + strategy: + matrix: + os: + - ubuntu-latest + - windows-latest + - macos-latest + browser: + - chromium + - firefox + python-version: + - "3.9" + - "3.10" + - "3.11" + exclude: + - python-version: "3.10" + os: windows-latest + - python-version: "3.11" + os: windows-latest + - python-version: "3.10" + os: macos-latest + - python-version: "3.11" + os: macos-latest + include: + - browser: webkit + os: macos-latest + python-version: "3.9" + + fail-fast: false + + runs-on: ${{ matrix.os }} + + defaults: + run: + shell: bash -el {0} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup environment + uses: ./.github/actions/setup-env + with: + python-version: ${{ matrix.python-version }} + + - name: Run unit tests + id: tests + run: | + pytest tests/deploy/ui \ + --browser ${{ matrix.browser }} \ + --video=retain-on-failure + + - name: Upload playwright video + if: failure() + uses: actions/upload-artifact@v4 + with: + name: + playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ github.run_id }} + path: test-results diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index f80f499e..e67ba38a 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -31,13 +31,21 @@ # - [ragna.assistants.Gpt4][] # - [llamafile](https://github.com/Mozilla-Ocho/llamafile) # - [ragna.assistants.LlamafileAssistant][] +# - [Ollama](https://ollama.com/) +# - [ragna.assistants.OllamaGemma2B][] +# - [ragna.assistants.OllamaLlama2][] +# - [ragna.assistants.OllamaLlava][] +# - [ragna.assistants.OllamaMistral][] +# - [ragna.assistants.OllamaMixtral][] +# - [ragna.assistants.OllamaOrcaMini][] +# - [ragna.assistants.OllamaPhi2][] from ragna import assistants class DemoStreamingAssistant(assistants.RagnaDemoAssistant): - def answer(self, prompt, sources): - content = next(super().answer(prompt, sources)) + def answer(self, messages): + content = next(super().answer(messages)) for chunk in content.split(" "): yield f"{chunk} " diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py index 3b6af0a2..4e7674da 100644 --- a/docs/tutorials/gallery_custom_components.py +++ b/docs/tutorials/gallery_custom_components.py @@ -30,7 +30,7 @@ import uuid -from ragna.core import Document, Source, SourceStorage +from ragna.core import 
Document, Source, SourceStorage, Message class TutorialSourceStorage(SourceStorage): @@ -61,9 +61,9 @@ def retrieve( # %% # ### Assistant # -# [ragna.core.Assistant][]s are objects that take a user prompt and relevant -# [ragna.core.Source][]s and generate a response form that. Usually, assistants are -# LLMs. +# [ragna.core.Assistant][]s are objects that take the chat history as list of +# [ragna.core.Message][]s and their relevant [ragna.core.Source][]s and generate a +# response from that. Usually, assistants are LLMs. # # In this tutorial, we define a minimal `TutorialAssistant` that is similar to # [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user @@ -82,8 +82,11 @@ def retrieve( class TutorialAssistant(Assistant): - def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: + def answer(self, messages: list[Message]) -> Iterator[str]: print(f"Running {type(self).__name__}().answer()") + # For simplicity, we only deal with the last message here, i.e. the latest user + # prompt. + prompt, sources = (message := messages[-1]).content, message.sources yield ( f"To answer the user prompt '{prompt}', " f"I was given {len(sources)} source(s)." @@ -254,8 +257,7 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: class ElaborateTutorialAssistant(Assistant): def answer( self, - prompt: str, - sources: list[Source], + messages: list[Message], *, my_required_parameter: int, my_optional_parameter: str = "foo", @@ -393,9 +395,7 @@ def answer( class AsyncAssistant(Assistant): - async def answer( - self, prompt: str, sources: list[Source] - ) -> AsyncIterator[str]: + async def answer(self, messages: list[Message]) -> AsyncIterator[str]: print(f"Running {type(self).__name__}().answer()") start = time.perf_counter() await asyncio.sleep(0.3) diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py index eb11410f..b7d17dc3 100644 --- a/docs/tutorials/gallery_python_api.py +++ b/docs/tutorials/gallery_python_api.py @@ -87,6 +87,14 @@ # - [ragna.assistants.Jurassic2Ultra][] # - [llamafile](https://github.com/Mozilla-Ocho/llamafile) # - [ragna.assistants.LlamafileAssistant][] +# - [Ollama](https://ollama.com/) +# - [ragna.assistants.OllamaGemma2B][] +# - [ragna.assistants.OllamaLlama2][] +# - [ragna.assistants.OllamaLlava][] +# - [ragna.assistants.OllamaMistral][] +# - [ragna.assistants.OllamaMixtral][] +# - [ragna.assistants.OllamaOrcaMini][] +# - [ragna.assistants.OllamaPhi2][] # # !!! 
note # diff --git a/environment-dev.yml b/environment-dev.yml index 2a7b6a03..9ee47a1b 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -10,6 +10,7 @@ dependencies: - pytest >=6 - pytest-mock - pytest-asyncio + - pytest-playwright - mypy ==1.10.0 - pre-commit - types-aiofiles diff --git a/pyproject.toml b/pyproject.toml index 9ceb1569..ccdc84e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,13 +27,12 @@ dependencies = [ "httpx", "importlib_metadata>=4.6; python_version<'3.10'", "packaging", - "panel==1.4.2", + "panel==1.4.4", "pydantic>=2", "pydantic-core", "pydantic-settings>=2", "PyJWT", "python-multipart", - "redis", "questionary", "rich", "sqlalchemy>=2", diff --git a/ragna-docker.toml b/ragna-docker.toml index 1874fb3c..239ef8a2 100644 --- a/ragna-docker.toml +++ b/ragna-docker.toml @@ -3,22 +3,28 @@ authentication = "ragna.deploy.RagnaDemoAuthentication" document = "ragna.core.LocalDocument" source_storages = [ "ragna.source_storages.Chroma", - "ragna.source_storages.RagnaDemoSourceStorage", "ragna.source_storages.LanceDB" ] assistants = [ - "ragna.assistants.Jurassic2Ultra", - "ragna.assistants.Claude", - "ragna.assistants.ClaudeInstant", + "ragna.assistants.ClaudeHaiku", + "ragna.assistants.ClaudeOpus", + "ragna.assistants.ClaudeSonnet", "ragna.assistants.Command", "ragna.assistants.CommandLight", - "ragna.assistants.RagnaDemoAssistant", "ragna.assistants.GeminiPro", "ragna.assistants.GeminiUltra", - "ragna.assistants.Mpt7bInstruct", - "ragna.assistants.Mpt30bInstruct", - "ragna.assistants.Gpt4", + "ragna.assistants.OllamaGemma2B", + "ragna.assistants.OllamaPhi2", + "ragna.assistants.OllamaLlama2", + "ragna.assistants.OllamaLlava", + "ragna.assistants.OllamaMistral", + "ragna.assistants.OllamaMixtral", + "ragna.assistants.OllamaOrcaMini", "ragna.assistants.Gpt35Turbo16k", + "ragna.assistants.Gpt4", + "ragna.assistants.Jurassic2Ultra", + "ragna.assistants.LlamafileAssistant", + "ragna.assistants.RagnaDemoAssistant", ] [api] diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py index d583e7a0..bcf5ead6 100644 --- a/ragna/assistants/__init__.py +++ b/ragna/assistants/__init__.py @@ -6,6 +6,13 @@ "CommandLight", "GeminiPro", "GeminiUltra", + "OllamaGemma2B", + "OllamaPhi2", + "OllamaLlama2", + "OllamaLlava", + "OllamaMistral", + "OllamaMixtral", + "OllamaOrcaMini", "Gpt35Turbo16k", "Gpt4", "Jurassic2Ultra", @@ -19,6 +26,15 @@ from ._demo import RagnaDemoAssistant from ._google import GeminiPro, GeminiUltra from ._llamafile import LlamafileAssistant +from ._ollama import ( + OllamaGemma2B, + OllamaLlama2, + OllamaLlava, + OllamaMistral, + OllamaMixtral, + OllamaOrcaMini, + OllamaPhi2, +) from ._openai import Gpt4, Gpt35Turbo16k # isort: split diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 1c61a213..3d6da65d 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -1,12 +1,13 @@ from typing import AsyncIterator, cast -from ragna.core import Source +from ragna.core import Message, Source from ._http_api import HttpApiAssistant class Ai21LabsAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "AI21_API_KEY" + _STREAMING_PROTOCOL = None _MODEL_TYPE: str @classmethod @@ -22,12 +23,14 @@ def _make_system_content(self, sources: list[Source]) -> str: return instruction + "\n\n".join(source.content for source in sources) async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See 
https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response - response = await self._client.post( + prompt, sources = (message := messages[-1]).content, message.sources + async for data in self._call_api( + "POST", f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat", headers={ "accept": "application/json", @@ -46,10 +49,8 @@ async def answer( ], "system": self._make_system_content(sources), }, - ) - await self._assert_api_call_is_success(response) - - yield cast(str, response.json()["outputs"][0]["text"]) + ): + yield cast(str, data["outputs"][0]["text"]) # The Jurassic2Mid assistant receives a 500 internal service error from the remote diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 37f132b5..5a618f66 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,12 +1,13 @@ from typing import AsyncIterator, cast -from ragna.core import PackageRequirement, RagnaException, Requirement, Source +from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class AnthropicAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE _MODEL: str @classmethod @@ -36,11 +37,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: ) async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming - async for data in self._stream_sse( + prompt, sources = (message := messages[-1]).content, message.sources + async for data in self._call_api( "POST", "https://api.anthropic.com/v1/messages", headers={ diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index b47737f8..f3920770 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,12 +1,13 @@ from typing import AsyncIterator, cast -from ragna.core import RagnaException, Source +from ragna.core import Message, RagnaException, Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class CohereAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "COHERE_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL _MODEL: str @classmethod @@ -24,12 +25,13 @@ def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag - async for event in self._stream_jsonl( + prompt, sources = (message := messages[-1]).content, message.sources + async for event in self._call_api( "POST", "https://api.cohere.ai/v1/chat", headers={ diff --git a/ragna/assistants/_demo.py b/ragna/assistants/_demo.py index cf27a893..f9bf644e 100644 --- a/ragna/assistants/_demo.py +++ b/ragna/assistants/_demo.py @@ 
-1,8 +1,7 @@ -import re import textwrap from typing import Iterator -from ragna.core import Assistant, Source +from ragna.core import Assistant, Message, MessageRole class RagnaDemoAssistant(Assistant): @@ -22,11 +21,11 @@ class RagnaDemoAssistant(Assistant): def display_name(cls) -> str: return "Ragna/DemoAssistant" - def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: - if re.search("markdown", prompt, re.IGNORECASE): + def answer(self, messages: list[Message]) -> Iterator[str]: + if "markdown" in messages[-1].content.lower(): yield self._markdown_answer() else: - yield self._default_answer(prompt, sources) + yield self._default_answer(messages) def _markdown_answer(self) -> str: return textwrap.dedent( @@ -39,7 +38,8 @@ def _markdown_answer(self) -> str: """ ).strip() - def _default_answer(self, prompt: str, sources: list[Source]) -> str: + def _default_answer(self, messages: list[Message]) -> str: + prompt, sources = (message := messages[-1]).content, message.sources sources_display = [] for source in sources: source_display = f"- {source.document_name}" @@ -50,13 +50,16 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str: if len(sources) > 3: sources_display.append("[...]") + n_messages = len([m for m in messages if m.role == MessageRole.USER]) return ( textwrap.dedent( """ - I'm a demo assistant and can be used to try Ragnas workflow. + I'm a demo assistant and can be used to try Ragna's workflow. I will only mirror back my inputs. + + So far I have received {n_messages} messages. - Your prompt was: + Your last prompt was: > {prompt} @@ -66,5 +69,10 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str: """ ) .strip() - .format(name=str(self), prompt=prompt, sources="\n".join(sources_display)) + .format( + name=str(self), + n_messages=n_messages, + prompt=prompt, + sources="\n".join(sources_display), + ) ) diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 8e1caf1e..7069565a 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,38 +1,15 @@ from typing import AsyncIterator -from ragna._compat import anext -from ragna.core import PackageRequirement, Requirement, Source +from ragna.core import Message, Source -from ._http_api import HttpApiAssistant - - -# ijson does not support reading from an (async) iterator, but only from file-like -# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects. -# See https://github.com/ICRAR/ijson/issues/44 for details. -# ijson actually doesn't care about most of the file interface and only requires the -# read() method to be present. -class AsyncIteratorReader: - def __init__(self, ait: AsyncIterator[bytes]) -> None: - self._ait = ait - - async def read(self, n: int) -> bytes: - # n is usually used to indicate how many bytes to read, but since we want to - # return a chunk as soon as it is available, we ignore the value of n. The only - # exception is n == 0, which is used by ijson to probe the return type and - # set up decoding. 
- if n == 0: - return b"" - return await anext(self._ait, b"") # type: ignore[call-arg] +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class GoogleAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSON _MODEL: str - @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [PackageRequirement("ijson")] - @classmethod def display_name(cls) -> str: return f"Google/{cls._MODEL}" @@ -49,11 +26,10 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: ) async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + self, messages: list[Message], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - import ijson - - async with self._client.stream( + prompt, sources = (message := messages[-1]).content, message.sources + async for chunk in self._call_api( "POST", f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent", params={"key": self._api_key}, @@ -64,7 +40,10 @@ async def answer( ], # https://ai.google.dev/docs/safety_setting_gemini "safetySettings": [ - {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"} + { + "category": f"HARM_CATEGORY_{category}", + "threshold": "BLOCK_NONE", + } for category in [ "HARASSMENT", "HATE_SPEECH", @@ -78,14 +57,9 @@ async def answer( "maxOutputTokens": max_new_tokens, }, }, - ) as response: - await self._assert_api_call_is_success(response) - - async for chunk in ijson.items( - AsyncIteratorReader(response.aiter_bytes(1024)), - "item.candidates.item.content.parts.item.text", - ): - yield chunk + parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"), + ): + yield chunk class GeminiPro(GoogleAssistant): diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py index 1151a62a..d6f48a26 100644 --- a/ragna/assistants/_http_api.py +++ b/ragna/assistants/_http_api.py @@ -1,65 +1,83 @@ import contextlib +import enum import json import os from typing import Any, AsyncIterator, Optional import httpx -from httpx import Response import ragna -from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement +from ragna._compat import anext +from ragna.core import ( + Assistant, + EnvVarRequirement, + PackageRequirement, + RagnaException, + Requirement, +) -class HttpApiAssistant(Assistant): - _API_KEY_ENV_VAR: Optional[str] +class HttpStreamingProtocol(enum.Enum): + SSE = enum.auto() + JSONL = enum.auto() + JSON = enum.auto() - @classmethod - def requirements(cls) -> list[Requirement]: - requirements: list[Requirement] = ( - [EnvVarRequirement(cls._API_KEY_ENV_VAR)] - if cls._API_KEY_ENV_VAR is not None - else [] - ) - requirements.extend(cls._extra_requirements()) - return requirements +class HttpApiCaller: @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [] + def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]: + streaming_requirements: dict[HttpStreamingProtocol, list[Requirement]] = { + HttpStreamingProtocol.SSE: [PackageRequirement("httpx_sse")], + HttpStreamingProtocol.JSON: [PackageRequirement("ijson")], + } + return streaming_requirements.get(protocol, []) - def __init__(self) -> None: - self._client = httpx.AsyncClient( - headers={"User-Agent": f"{ragna.__version__}/{self}"}, - timeout=60, - ) - self._api_key: Optional[str] = ( - os.environ[self._API_KEY_ENV_VAR] - if self._API_KEY_ENV_VAR is not None - else None - ) - - async def _assert_api_call_is_success(self, 
response: Response) -> None: - if response.is_success: - return + def __init__( + self, + client: httpx.AsyncClient, + protocol: Optional[HttpStreamingProtocol] = None, + ) -> None: + self._client = client + self._protocol = protocol - content = await response.aread() - with contextlib.suppress(Exception): - content = json.loads(content) + def __call__( + self, + method: str, + url: str, + *, + parse_kwargs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + if self._protocol is None: + call_method = self._no_stream + else: + call_method = { + HttpStreamingProtocol.SSE: self._stream_sse, + HttpStreamingProtocol.JSONL: self._stream_jsonl, + HttpStreamingProtocol.JSON: self._stream_json, + }[self._protocol] + return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs) - raise RagnaException( - "API call failed", - request_method=response.request.method, - request_url=str(response.request.url), - response_status_code=response.status_code, - response_content=content, - ) + async def _no_stream( + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: + response = await self._client.request(method, url, **kwargs) + await self._assert_api_call_is_success(response) + yield response.json() async def _stream_sse( self, method: str, url: str, + *, + parse_kwargs: dict[str, Any], **kwargs: Any, - ) -> AsyncIterator[dict[str, Any]]: + ) -> AsyncIterator[Any]: import httpx_sse async with httpx_sse.aconnect_sse( @@ -71,10 +89,103 @@ async def _stream_sse( yield json.loads(sse.data) async def _stream_jsonl( - self, method: str, url: str, **kwargs: Any - ) -> AsyncIterator[dict[str, Any]]: + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: async with self._client.stream(method, url, **kwargs) as response: await self._assert_api_call_is_success(response) async for chunk in response.aiter_lines(): yield json.loads(chunk) + + # ijson does not support reading from an (async) iterator, but only from file-like + # objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects. + # See https://github.com/ICRAR/ijson/issues/44 for details. + # ijson actually doesn't care about most of the file interface and only requires the + # read() method to be present. + class _AsyncIteratorReader: + def __init__(self, ait: AsyncIterator[bytes]) -> None: + self._ait = ait + + async def read(self, n: int) -> bytes: + # n is usually used to indicate how many bytes to read, but since we want to + # return a chunk as soon as it is available, we ignore the value of n. The + # only exception is n == 0, which is used by ijson to probe the return type + # and set up decoding. 
+ if n == 0: + return b"" + return await anext(self._ait, b"") # type: ignore[call-arg] + + async def _stream_json( + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: + import ijson + + item = parse_kwargs["item"] + chunk_size = parse_kwargs.get("chunk_size", 16) + + async with self._client.stream(method, url, **kwargs) as response: + await self._assert_api_call_is_success(response) + + async for chunk in ijson.items( + self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item + ): + yield chunk + + async def _assert_api_call_is_success(self, response: httpx.Response) -> None: + if response.is_success: + return + + content = await response.aread() + with contextlib.suppress(Exception): + content = json.loads(content) + + raise RagnaException( + "API call failed", + request_method=response.request.method, + request_url=str(response.request.url), + response_status_code=response.status_code, + response_content=content, + ) + + +class HttpApiAssistant(Assistant): + _API_KEY_ENV_VAR: Optional[str] + _STREAMING_PROTOCOL: Optional[HttpStreamingProtocol] + + @classmethod + def requirements(cls) -> list[Requirement]: + requirements: list[Requirement] = ( + [EnvVarRequirement(cls._API_KEY_ENV_VAR)] + if cls._API_KEY_ENV_VAR is not None + else [] + ) + if cls._STREAMING_PROTOCOL is not None: + requirements.extend(HttpApiCaller.requirements(cls._STREAMING_PROTOCOL)) + requirements.extend(cls._extra_requirements()) + return requirements + + @classmethod + def _extra_requirements(cls) -> list[Requirement]: + return [] + + def __init__(self) -> None: + self._client = httpx.AsyncClient( + headers={"User-Agent": f"{ragna.__version__}/{self}"}, + timeout=60, + ) + self._api_key: Optional[str] = ( + os.environ[self._API_KEY_ENV_VAR] + if self._API_KEY_ENV_VAR is not None + else None + ) + self._call_api = HttpApiCaller(self._client, self._STREAMING_PROTOCOL) diff --git a/ragna/assistants/_llamafile.py b/ragna/assistants/_llamafile.py index 3e78a625..5c7cc1da 100644 --- a/ragna/assistants/_llamafile.py +++ b/ragna/assistants/_llamafile.py @@ -1,9 +1,11 @@ import os +from functools import cached_property -from ._openai import OpenaiCompliantHttpApiAssistant +from ._http_api import HttpStreamingProtocol +from ._openai import OpenaiLikeHttpApiAssistant -class LlamafileAssistant(OpenaiCompliantHttpApiAssistant): +class LlamafileAssistant(OpenaiLikeHttpApiAssistant): """[llamafile](https://github.com/Mozilla-Ocho/llamafile) To use this assistant, start the llamafile server manually. 
By default, the server @@ -16,10 +18,14 @@ class LlamafileAssistant(OpenaiCompliantHttpApiAssistant): """ _API_KEY_ENV_VAR = None - _STREAMING_METHOD = "sse" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE _MODEL = None - @property + @classmethod + def display_name(cls) -> str: + return "llamafile" + + @cached_property def _url(self) -> str: base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080") return f"{base_url}/v1/chat/completions" diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py new file mode 100644 index 00000000..7a92b998 --- /dev/null +++ b/ragna/assistants/_ollama.py @@ -0,0 +1,84 @@ +import os +from functools import cached_property +from typing import AsyncIterator, cast + +from ragna.core import Message, RagnaException + +from ._http_api import HttpStreamingProtocol +from ._openai import OpenaiLikeHttpApiAssistant + + +class OllamaAssistant(OpenaiLikeHttpApiAssistant): + """[Ollama](https://ollama.com/) + + To use this assistant, start the Ollama server manually. By default, the server + is expected at `http://localhost:11434`. This can be changed with the + `RAGNA_OLLAMA_BASE_URL` environment variable. + """ + + _API_KEY_ENV_VAR = None + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL + _MODEL: str + + @classmethod + def display_name(cls) -> str: + return f"Ollama/{cls._MODEL}" + + @cached_property + def _url(self) -> str: + base_url = os.environ.get("RAGNA_OLLAMA_BASE_URL", "http://localhost:11434") + return f"{base_url}/api/chat" + + async def answer( + self, messages: list[Message], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + prompt, sources = (message := messages[-1]).content, message.sources + async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens): + # Modeled after + # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62 + if "error" in data: + raise RagnaException(data["error"]) + if not data["done"]: + yield cast(str, data["message"]["content"]) + + +class OllamaGemma2B(OllamaAssistant): + """[Gemma:2B](https://ollama.com/library/gemma)""" + + _MODEL = "gemma:2b" + + +class OllamaLlama2(OllamaAssistant): + """[Llama 2](https://ollama.com/library/llama2)""" + + _MODEL = "llama2" + + +class OllamaLlava(OllamaAssistant): + """[Llava](https://ollama.com/library/llava)""" + + _MODEL = "llava" + + +class OllamaMistral(OllamaAssistant): + """[Mistral](https://ollama.com/library/mistral)""" + + _MODEL = "mistral" + + +class OllamaMixtral(OllamaAssistant): + """[Mixtral](https://ollama.com/library/mixtral)""" + + _MODEL = "mixtral" + + +class OllamaOrcaMini(OllamaAssistant): + """[Orca Mini](https://ollama.com/library/orca-mini)""" + + _MODEL = "orca-mini" + + +class OllamaPhi2(OllamaAssistant): + """[Phi-2](https://ollama.com/library/phi)""" + + _MODEL = "phi" diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 37957be2..b004b595 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,23 +1,15 @@ import abc -from typing import Any, AsyncIterator, Literal, Optional, cast +from functools import cached_property +from typing import Any, AsyncIterator, Optional, cast -from ragna.core import PackageRequirement, RagnaException, Requirement, Source +from ragna.core import Message, Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol -class OpenaiCompliantHttpApiAssistant(HttpApiAssistant): - _STREAMING_METHOD: 
Literal["sse", "jsonl"] +class OpenaiLikeHttpApiAssistant(HttpApiAssistant): _MODEL: Optional[str] - @classmethod - def requirements(cls) -> list[Requirement]: - requirements = super().requirements() - requirements.extend( - {"sse": [PackageRequirement("httpx_sse")]}.get(cls._STREAMING_METHOD, []) - ) - return requirements - @property @abc.abstractmethod def _url(self) -> str: ... @@ -32,23 +24,8 @@ def _make_system_content(self, sources: list[Source]) -> str: return instruction + "\n\n".join(source.content for source in sources) def _stream( - self, - method: str, - url: str, - **kwargs: Any, + self, prompt: str, sources: list[Source], *, max_new_tokens: int ) -> AsyncIterator[dict[str, Any]]: - stream = { - "sse": self._stream_sse, - "jsonl": self._stream_jsonl, - }.get(self._STREAMING_METHOD) - if stream is None: - raise RagnaException - - return stream(method, url, **kwargs) - - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 - ) -> AsyncIterator[str]: # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming headers = { @@ -75,7 +52,13 @@ async def answer( if self._MODEL is not None: json_["model"] = self._MODEL - async for data in self._stream("POST", self._url, headers=headers, json=json_): + return self._call_api("POST", self._url, headers=headers, json=json_) + + async def answer( + self, messages: list[Message], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + prompt, sources = (message := messages[-1]).content, message.sources + async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens): choice = data["choices"][0] if choice["finish_reason"] is not None: break @@ -83,15 +66,15 @@ async def answer( yield cast(str, choice["delta"]["content"]) -class OpenaiAssistant(OpenaiCompliantHttpApiAssistant): +class OpenaiAssistant(OpenaiLikeHttpApiAssistant): _API_KEY_ENV_VAR = "OPENAI_API_KEY" - _STREAMING_METHOD = "sse" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE @classmethod def display_name(cls) -> str: return f"OpenAI/{cls._MODEL}" - @property + @cached_property def _url(self) -> str: return "https://api.openai.com/v1/chat/completions" diff --git a/ragna/core/_components.py b/ragna/core/_components.py index 5d233fba..e592698c 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -52,6 +52,16 @@ def __repr__(self) -> str: def _protocol_models( cls, ) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]: + # This method dynamically builds a pydantic.BaseModel for the extra parameters + # of each method that is listed in the __ragna_protocol_methods__ class + # variable. These models are used by ragna.core.Chat._unpack_chat_params to + # validate and distribute the **params passed by the user. + + # Walk up the MRO until we find the __ragna_protocol_methods__ variable. This is + # the indicator that we found the protocol class. We use this as a reference for + # which params of a protocol method are part of the protocol (think positional + # parameters) and which are requested by the concrete class (think keyword + # parameters). 
protocol_cls, protocol_methods = next( (cls_, cls_.__ragna_protocol_methods__) # type: ignore[attr-defined] for cls_ in cls.__mro__ @@ -65,10 +75,14 @@ def _protocol_models( method = getattr(cls, method_name) params = iter(inspect.signature(method).parameters.values()) annotations = get_type_hints(method) + # Skip over the protocol parameters in order for the model below to only + # comprise concrete parameters. + for _ in range(num_protocol_params): next(params) - models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload] + models[(cls, method_name)] = pydantic.create_model( + # type: ignore[call-overload] f"{cls.__name__}.{method_name}", **{ param.name: ( @@ -138,7 +152,7 @@ def retrieve(self, metadata_filter: MetadataFilter, prompt: str) -> list[Source] ... -class MessageRole(enum.Enum): +class MessageRole(str, enum.Enum): """Message role Attributes: @@ -229,12 +243,12 @@ class Assistant(Component, abc.ABC): __ragna_protocol_methods__ = ["answer"] @abc.abstractmethod - def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: - """Answer a prompt given some sources. + def answer(self, messages: list[Message]) -> Iterator[str]: + """Answer a prompt given the chat history. Args: - prompt: Prompt to be answered. - sources: Sources to use when answering answer the prompt. + messages: List of messages in the chat history. The last item is the current + user prompt and has the relevant sources attached to it. Returns: Answer. diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 58fe24b5..16b629dc 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -1,8 +1,11 @@ from __future__ import annotations +import contextlib import datetime import inspect +import itertools import uuid +from collections import defaultdict from typing import ( Any, AsyncIterator, @@ -19,6 +22,7 @@ ) import pydantic +import pydantic_core from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from ._components import Assistant, Component, Message, MessageRole, SourceStorage @@ -217,14 +221,15 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: detail=RagnaException.EVENT, ) - self._messages.append(Message(content=prompt, role=MessageRole.USER)) - sources = await self._run( self.source_storage.retrieve, self.metadata_filter, prompt ) + question = Message(content=prompt, role=MessageRole.USER, sources=sources) + self._messages.append(question) + answer = Message( - content=self._run_gen(self.assistant.answer, prompt, sources), + content=self._run_gen(self.assistant.answer, self._messages.copy()), role=MessageRole.ASSISTANT, sources=sources, ) @@ -266,6 +271,15 @@ def _parse_input( def _unpack_chat_params( self, params: dict[str, Any] ) -> dict[Callable, dict[str, Any]]: + # This method does two things: + # 1. Validate the **params against the signatures of the protocol methods of the + # used components. This makes sure that + # - No parameter is passed that isn't used by at least one component + # - No parameter is missing that is needed by at least one component + # - No parameter is passed in the wrong type + # 2. Prepare the distribution of the parameters to the protocol method that + # requested them. The actual distribution happens in self._run and + # self._run_gen, but is only a dictionary lookup by then. 
component_models = { getattr(component, name): model for component in [self.source_storage, self.assistant] @@ -273,20 +287,104 @@ def _unpack_chat_params( } ChatModel = merge_models( - str(self.params["chat_id"]), + f"{self.__module__}.{type(self).__name__}-{self.params['chat_id']}", SpecialChatParams, *component_models.values(), config=pydantic.ConfigDict(extra="forbid"), ) - chat_params = ChatModel.model_validate(params, strict=True).model_dump( - exclude_none=True - ) + with self._format_validation_error(ChatModel): + chat_model = ChatModel.model_validate(params, strict=True) + + chat_params = chat_model.model_dump(exclude_none=True) return { fn: model(**chat_params).model_dump() for fn, model in component_models.items() } + @contextlib.contextmanager + def _format_validation_error( + self, model_cls: type[pydantic.BaseModel] + ) -> Iterator[None]: + try: + yield + except pydantic.ValidationError as validation_error: + errors = defaultdict(list) + for error in validation_error.errors(): + errors[error["type"]].append(error) + + parts = [ + f"Validating the Chat parameters resulted in {validation_error.error_count()} errors:", + "", + ] + + def format_error( + error: pydantic_core.ErrorDetails, + *, + annotation: bool = False, + value: bool = False, + ) -> str: + param = cast(str, error["loc"][0]) + + formatted_error = f"- {param}" + if annotation: + annotation_ = cast( + type, model_cls.model_fields[param].annotation + ).__name__ + formatted_error += f": {annotation_}" + + if value: + value_ = error["input"] + formatted_error += ( + f" = {value_!r}" if annotation else f"={value_!r}" + ) + + return formatted_error + + if "extra_forbidden" in errors: + parts.extend( + [ + "The following parameters are unknown:", + "", + *[ + format_error(error, value=True) + for error in errors["extra_forbidden"] + ], + "", + ] + ) + + if "missing" in errors: + parts.extend( + [ + "The following parameters are missing:", + "", + *[ + format_error(error, annotation=True) + for error in errors["missing"] + ], + "", + ] + ) + + type_errors = ["string_type", "int_type", "float_type", "bool_type"] + if any(type_error in errors for type_error in type_errors): + parts.extend( + [ + "The following parameters have the wrong type:", + "", + *[ + format_error(error, annotation=True, value=True) + for error in itertools.chain.from_iterable( + errors[type_error] for type_error in type_errors + ) + ], + "", + ] + ) + + raise RagnaException("\n".join(parts)) + async def _run( self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any ) -> T: diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 5d008f4e..3540f011 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -233,11 +233,7 @@ def schema_to_core_chat( chat_name=chat.metadata.name, **chat.metadata.params, ) - # FIXME: We need to reconstruct the previous messages here. Right now this is - # not needed, because the chat itself never accesses past messages. However, - # if we implement a chat history feature, i.e. passing past messages to - # the assistant, this becomes crucial. 
- core_chat._messages = [] + core_chat._messages = [message.to_core() for message in chat.messages] core_chat._prepared = chat.prepared return core_chat @@ -291,12 +287,15 @@ async def answer( ) -> schemas.Message: with get_session() as session: chat = database.get_chat(session, user=user, id=id) - chat.messages.append( - schemas.Message(content=prompt, role=ragna.core.MessageRole.USER) - ) core_chat = schema_to_core_chat(session, user=user, chat=chat) core_answer = await core_chat.answer(prompt, stream=stream) + sources = [schemas.Source.from_core(source) for source in core_answer.sources] + chat.messages.append( + schemas.Message( + content=prompt, role=ragna.core.MessageRole.USER, sources=sources + ) + ) if stream: @@ -307,10 +306,7 @@ async def message_chunks() -> AsyncIterator[BaseModel]: answer = schemas.Message( content=content_chunk, role=core_answer.role, - sources=[ - schemas.Source.from_core(source) - for source in core_answer.sources - ], + sources=sources, ) yield answer diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py index 53957a74..37471c69 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_api/schemas.py @@ -26,6 +26,16 @@ def from_core(cls, document: ragna.core.Document) -> Document: name=document.name, ) + def to_core(self) -> ragna.core.Document: + return ragna.core.LocalDocument( + id=self.id, + name=self.name, + # TEMP: setting an empty metadata dict for now. + # Will be resolved as part of the "managed ragna" work: + # https://github.com/Quansight/ragna/issues/256 + metadata={}, + ) + class DocumentUpload(BaseModel): parameters: ragna.core.DocumentUploadParameters @@ -50,6 +60,15 @@ def from_core(cls, source: ragna.core.Source) -> Source: num_tokens=source.num_tokens, ) + def to_core(self) -> ragna.core.Source: + return ragna.core.Source( + id=self.id, + document=self.document.to_core(), + location=self.location, + content=self.content, + num_tokens=self.num_tokens, + ) + class Message(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) @@ -66,6 +85,13 @@ def from_core(cls, message: ragna.core.Message) -> Message: sources=[Source.from_core(source) for source in message.sources], ) + def to_core(self) -> ragna.core.Message: + return ragna.core.Message( + content=self.content, + role=self.role, + sources=[source.to_core() for source in self.sources], + ) + class ChatMetadata(BaseModel): name: str diff --git a/ragna/deploy/_ui/css/modal_welcome/button.css b/ragna/deploy/_ui/css/modal_welcome/button.css deleted file mode 100644 index 5c98c041..00000000 --- a/ragna/deploy/_ui/css/modal_welcome/button.css +++ /dev/null @@ -1,4 +0,0 @@ -:host(.modal_welcome_close_button) { - width: 35%; - margin-left: 60%; -} diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index ab8bc1c0..379acef5 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -1,3 +1,5 @@ +from datetime import datetime + import panel as pn import param @@ -59,6 +61,14 @@ def refresh(self): @pn.depends("refresh_counter", "chats", "current_chat_id", on_init=True) def __panel__(self): + epoch = datetime(1970, 1, 1) + self.chats.sort( + key=lambda chat: ( + epoch if not chat["messages"] else chat["messages"][-1]["timestamp"] + ), + reverse=True, + ) + self.chat_buttons = [] for chat in self.chats: button = pn.widgets.Button( diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py index c8610e7b..4ba5ba94 100644 --- a/ragna/deploy/_ui/main_page.py +++ b/ragna/deploy/_ui/main_page.py @@ -3,12 
+3,9 @@ import panel as pn import param -from . import js -from . import styles as ui from .central_view import CentralView from .left_sidebar import LeftSidebar from .modal_configuration import ModalConfiguration -from .modal_welcome import ModalWelcome from .right_sidebar import RightSidebar @@ -71,14 +68,6 @@ def open_modal(self): self.template.modal.objects[0].objects = [self.modal] self.template.open_modal() - def open_welcome_modal(self, event): - self.modal = ModalWelcome( - close_button_callback=lambda: self.template.close_modal(), - ) - - self.template.modal.objects[0].objects = [self.modal] - self.template.open_modal() - async def open_new_chat(self, new_chat_id): # called after creating a new chat. self.current_chat_id = new_chat_id @@ -111,59 +100,9 @@ def update_subviews_current_chat_id(self, avoid_senders=[]): def __panel__(self): asyncio.ensure_future(self.refresh_data()) - objects = [self.left_sidebar, self.central_view, self.right_sidebar] - - if self.chats is not None and len(self.chats) == 0: - """I haven't found a better way to open the modal when the pages load, - than simulating a click on the "New chat" button. - - calling self.template.open_modal() doesn't work - - calling self.on_click_new_chat doesn't work either - - trying to schedule a call to on_click_new_chat with pn.state.schedule_task - could have worked but my tests were yielding an unstable result. - """ - - new_chat_button_name = "open welcome modal" - open_welcome_modal = pn.widgets.Button( - name=new_chat_button_name, - button_type="primary", - ) - open_welcome_modal.on_click(self.open_welcome_modal) - - hack_open_modal = pn.pane.HTML( - """ - - """.replace( - "{new_chat_btn_name}", new_chat_button_name - ).strip(), - # This is not really styling per say, it's just a way to hide from the page the HTML item of this hack. - # It's not worth moving this to a separate file. - stylesheets=[ - ui.css( - ":host", - {"position": "absolute", "z-index": "-999"}, - ) - ], - ) - - objects.append( - pn.Row( - open_welcome_modal, - pn.pane.HTML(js.SHADOWROOT_INDEXING), - hack_open_modal, - visible=False, - ) - ) - - main_page = pn.Row( - *objects, + return pn.Row( + self.left_sidebar, + self.central_view, + self.right_sidebar, css_classes=["main_page_main_row"], ) - - return main_page diff --git a/ragna/deploy/_ui/modal_welcome.py b/ragna/deploy/_ui/modal_welcome.py deleted file mode 100644 index 71b6ad7f..00000000 --- a/ragna/deploy/_ui/modal_welcome.py +++ /dev/null @@ -1,42 +0,0 @@ -import panel as pn -import param - -from . import js -from . import styles as ui - - -class ModalWelcome(pn.viewable.Viewer): - close_button_callback = param.Callable() - - def __init__(self, **params): - super().__init__(**params) - - def did_click_on_close_button(self, event): - if self.close_button_callback is not None: - self.close_button_callback() - - def __panel__(self): - close_button = pn.widgets.Button( - name="Okay, let's go", - button_type="primary", - css_classes=["modal_welcome_close_button"], - ) - close_button.on_click(self.did_click_on_close_button) - - return pn.Column( - pn.pane.HTML( - f"""""" - + """
-                Welcome !
-
-                Ragna is a RAG Orchestration Framework.
-                With its UI, select and configure LLMs, upload documents, and chat with the LLM.
-
-                Use Ragna UI out-of-the-box, as a daily-life interface with your favorite AI,
-                or as a reference to build custom web applications.
- """ - ), - close_button, - width=ui.WELCOME_MODAL_WIDTH, - height=ui.WELCOME_MODAL_HEIGHT, - sizing_mode="fixed", - ) diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 7f994eeb..213e6e1a 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -46,7 +46,6 @@ "right_sidebar": [pn.widgets.Button, pn.Column, pn.pane.Markdown], "left_sidebar": [pn.widgets.Button, pn.pane.HTML, pn.Column], "main_page": [pn.Row], - "modal_welcome": [pn.widgets.Button], "modal_configuration": [ pn.widgets.IntSlider, pn.layout.Card, @@ -103,9 +102,6 @@ def css(selector: Union[str, Iterable[str]], declarations: dict[str, str]) -> st CONFIG_MODAL_MAX_HEIGHT = 850 CONFIG_MODAL_WIDTH = 800 -WELCOME_MODAL_HEIGHT = 275 -WELCOME_MODAL_WIDTH = 530 - CSS_VARS = css( ":root", diff --git a/ragna/source_storages/_chroma.py b/ragna/source_storages/_chroma.py index 6b49251a..ad4651cc 100644 --- a/ragna/source_storages/_chroma.py +++ b/ragna/source_storages/_chroma.py @@ -123,6 +123,7 @@ def retrieve( ) -> list[Source]: collection = self._get_collection() + include = ["distances", "metadatas", "documents"] result = collection.query( query_texts=prompt, where=self._translate_metadata_filter(metadata_filter), @@ -141,22 +142,19 @@ def retrieve( max(int(num_tokens * 2 / chunk_size), 100), collection.count(), ), - include=["distances", "metadatas", "documents"], + include=include, # type: ignore[arg-type] ) num_results = len(result["ids"][0]) - result = { - key: [None] * num_results if value is None else value[0] # type: ignore[index] - for key, value in result.items() - } + result = {key: result[key][0] for key in ["ids", *include]} # type: ignore[literal-required] # dict of lists -> list of dicts results = [ - {key[:-1]: value[idx] for key, value in result.items()} + {key: value[idx] for key, value in result.items()} for idx in range(num_results) ] # That should be the default, but let's make extra sure here - results = sorted(results, key=lambda r: r["distance"]) + results = sorted(results, key=lambda r: r["distances"]) # TODO: we should have some functionality here to remove results with a high # distance to keep only "valid" sources. 
However, there are two issues: @@ -167,13 +165,13 @@ def retrieve( return self._take_sources_up_to_max_tokens( ( Source( - id=result["id"], + id=result["ids"], # FIXME: We no longer have access to the document here # maybe reflect the same in the demo component - document_id=result["metadata"]["document_id"], - location=result["metadata"]["page_numbers"], - content=result["document"], - num_tokens=result["metadata"]["num_tokens"], + document_id=result["metadatas"]["document_id"], + location=result["metadatas"]["page_numbers"], + content=result["documents"], + num_tokens=result["metadatas"]["num_tokens"], ) for result in results ), diff --git a/requirements-docker.lock b/requirements-docker.lock index 21d52e6c..b954077d 100644 --- a/requirements-docker.lock +++ b/requirements-docker.lock @@ -213,9 +213,9 @@ pandas==2.1.4 # via # bokeh # panel -panel==1.4.2 +panel==1.4.4 # via Ragna (pyproject.toml) -param==2.0.1 +param==2.1.1 # via # panel # pyviz-comms @@ -304,8 +304,6 @@ questionary==2.0.1 # via Ragna (pyproject.toml) ratelimiter==1.2.0.post0 # via lancedb -redis==5.0.1 - # via Ragna (pyproject.toml) regex==2023.12.25 # via tiktoken requests==2.31.0 diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 02b964b5..de852b0b 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -5,7 +5,7 @@ from ragna import assistants from ragna._compat import anext from ragna.assistants._http_api import HttpApiAssistant -from ragna.core import RagnaException +from ragna.core import Message, RagnaException from tests.utils import skip_on_windows HTTP_API_ASSISTANTS = [ @@ -25,7 +25,8 @@ async def test_api_call_error_smoke(mocker, assistant): mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"}) - chunks = assistant().answer(prompt="?", sources=[]) + messages = [Message(content="?", sources=[])] + chunks = assistant().answer(messages) with pytest.raises(RagnaException, match="API call failed"): await anext(chunks) diff --git a/tests/core/test_rag.py b/tests/core/test_rag.py index 77050146..b1766a6f 100644 --- a/tests/core/test_rag.py +++ b/tests/core/test_rag.py @@ -1,8 +1,7 @@ -import pydantic import pytest from ragna import Rag, assistants, source_storages -from ragna.core import LocalDocument +from ragna.core import Assistant, LocalDocument, RagnaException @pytest.fixture() @@ -14,20 +13,82 @@ def demo_document(tmp_path, request): class TestChat: - def chat(self, documents, **params): + def chat( + self, + documents, + source_storage=source_storages.RagnaDemoSourceStorage, + assistant=assistants.RagnaDemoAssistant, + **params, + ): return Rag().chat( input=documents, - source_storage=source_storages.RagnaDemoSourceStorage, - assistant=assistants.RagnaDemoAssistant, + source_storage=source_storage, + assistant=assistant, **params, ) - def test_extra_params(self, demo_document): - with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"): + def test_params_validation_unknown(self, demo_document): + params = { + "bool_param": False, + "int_param": 1, + "float_param": 0.5, + "string_param": "arbitrary_value", + } + with pytest.raises(RagnaException, match="unknown") as exc_info: + self.chat(documents=[demo_document], **params) + + msg = str(exc_info.value) + for param, value in params.items(): + assert f"{param}={value!r}" in msg + + def test_params_validation_missing(self, demo_document): + class ValidationAssistant(Assistant): + def answer( + self, + messages, + bool_param: bool, + int_param: int, + float_param: float, + 
string_param: str, + ): + pass + + with pytest.raises(RagnaException, match="missing") as exc_info: + self.chat(documents=[demo_document], assistant=ValidationAssistant) + + msg = str(exc_info.value) + for param, annotation in ValidationAssistant.answer.__annotations__.items(): + assert f"{param}: {annotation.__name__}" in msg + + def test_params_validation_wrong_type(self, demo_document): + class ValidationAssistant(Assistant): + def answer( + self, + messages, + bool_param: bool, + int_param: int, + float_param: float, + string_param: str, + ): + pass + + params = { + "bool_param": 1, + "int_param": 0.5, + "float_param": "arbitrary_value", + "string_param": False, + } + + with pytest.raises(RagnaException, match="wrong type") as exc_info: self.chat( - documents=[demo_document], not_supported_parameter="arbitrary_value" + documents=[demo_document], assistant=ValidationAssistant, **params ) + msg = str(exc_info.value) + for param, value in params.items(): + annotation = ValidationAssistant.answer.__annotations__[param] + assert f"{param}: {annotation.__name__} = {value!r}" in msg + def test_document_path(self, demo_document): chat = self.chat(documents=[demo_document.path]) diff --git a/tests/deploy/api/test_batch_endpoints.py b/tests/deploy/api/test_batch_endpoints.py index 94740750..3c85c77c 100644 --- a/tests/deploy/api/test_batch_endpoints.py +++ b/tests/deploy/api/test_batch_endpoints.py @@ -3,8 +3,7 @@ from ragna.deploy import Config from ragna.deploy._api import app - -from .utils import authenticate +from tests.deploy.utils import authenticate_with_api def test_batch_sequential_upload_equivalence(tmp_local_root): @@ -23,7 +22,7 @@ def test_batch_sequential_upload_equivalence(tmp_local_root): with TestClient( app(config=Config(), ignore_unavailable_components=False) ) as client: - authenticate(client) + authenticate_with_api(client) document1_upload = ( client.post("/document", json={"name": document_path1.name}) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 65f02209..b7fe464c 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -6,8 +6,7 @@ from ragna.core import RagnaException from ragna.deploy import Config from ragna.deploy._api import app - -from .utils import authenticate +from tests.deploy.utils import authenticate_with_api @pytest.mark.parametrize("ignore_unavailable_components", [True, False]) @@ -27,7 +26,7 @@ def test_ignore_unavailable_components(ignore_unavailable_components): ignore_unavailable_components=ignore_unavailable_components, ) ) as client: - authenticate(client) + authenticate_with_api(client) components = client.get("/components").raise_for_status().json() assert [assistant["title"] for assistant in components["assistants"]] == [ @@ -66,7 +65,7 @@ def test_unknown_component(tmp_local_root): with TestClient( app(config=Config(), ignore_unavailable_components=False) ) as client: - authenticate(client) + authenticate_with_api(client) document_upload = ( client.post("/document", json={"name": document_path.name}) diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index 41b154db..4abbf7cf 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -1,31 +1,15 @@ import json -import time import pytest from fastapi.testclient import TestClient -from ragna.assistants import RagnaDemoAssistant from ragna.deploy import Config from ragna.deploy._api import app - -from .utils import authenticate - - -class TestAssistant(RagnaDemoAssistant): 
- def answer(self, prompt, sources, *, multiple_answer_chunks: bool): - # Simulate a "real" assistant through a small delay. See - # https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440 - # for why this is needed. - time.sleep(1e-3) - content = next(super().answer(prompt, sources)) - - if multiple_answer_chunks: - for chunk in content.split(" "): - yield f"{chunk} " - else: - yield content +from tests.deploy.utils import TestAssistant, authenticate_with_api +from tests.utils import skip_on_windows +@skip_on_windows @pytest.mark.parametrize("multiple_answer_chunks", [True, False]) @pytest.mark.parametrize("stream_answer", [True, False]) def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): @@ -38,7 +22,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): file.write("!\n") with TestClient(app(config=config, ignore_unavailable_components=False)) as client: - authenticate(client) + authenticate_with_api(client) assert client.get("/chats").raise_for_status().json() == [] @@ -125,12 +109,12 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): chat = client.get(f"/chats/{chat['id']}").raise_for_status().json() assert len(chat["messages"]) == 3 + assert chat["messages"][-1] == message assert ( chat["messages"][-2]["role"] == "user" - and chat["messages"][-2]["sources"] == [] + and chat["messages"][-2]["sources"] == message["sources"] and chat["messages"][-2]["content"] == prompt ) - assert chat["messages"][-1] == message client.delete(f"/chats/{chat['id']}").raise_for_status() assert client.get("/chats").raise_for_status().json() == [] diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py deleted file mode 100644 index abcf1411..00000000 --- a/tests/deploy/api/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -from fastapi.testclient import TestClient - -from ragna.core._utils import default_user - - -def authenticate(client: TestClient) -> None: - username = default_user() - token = ( - client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, - ) - .raise_for_status() - .json() - ) - client.headers["Authorization"] = f"Bearer {token}" diff --git a/tests/deploy/ui/test_ui.py b/tests/deploy/ui/test_ui.py new file mode 100644 index 00000000..85699278 --- /dev/null +++ b/tests/deploy/ui/test_ui.py @@ -0,0 +1,157 @@ +import socket +import subprocess +import sys +import time + +import httpx +import panel as pn +import pytest +from playwright.sync_api import Page, expect + +from ragna._utils import timeout_after +from ragna.deploy import Config +from tests.deploy.utils import TestAssistant + + +def get_available_port(): + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def config( + tmp_local_root, +): + config = Config( + local_root=tmp_local_root, + assistants=[TestAssistant], + ui=dict(port=get_available_port()), + api=dict(port=get_available_port()), + ) + path = tmp_local_root / "ragna.toml" + config.to_file(path) + return config + + +class Server: + def __init__(self, config): + self.config = config + self.base_url = f"http://{config.ui.hostname}:{config.ui.port}" + + def server_up(self): + try: + return httpx.get(self.base_url).is_success + except httpx.ConnectError: + return False + + @timeout_after(60) + def start(self): + self.proc = subprocess.Popen( + [ + sys.executable, + "-m", + "ragna", + "ui", + "--config", + self.config.local_root / "ragna.toml", + "--start-api", + 
"--ignore-unavailable-components", + "--no-open-browser", + ], + stdout=sys.stdout, + stderr=sys.stderr, + ) + + while not self.server_up(): + time.sleep(1) + + def stop(self): + self.proc.kill() + pn.state.kill_all_servers() + + def __enter__(self): + self.start() + return self + + def __exit__(self, *args): + self.stop() + + +def test_health(config, page: Page) -> None: + with Server(config) as server: + health_url = f"{server.base_url}/health" + response = page.goto(health_url) + assert response.ok + + +def test_start_chat(config, page: Page) -> None: + with Server(config) as server: + # Index page, no auth + index_url = server.base_url + page.goto(index_url) + expect(page.get_by_role("button", name="Sign In")).to_be_visible() + + # Authorize with no credentials + page.get_by_role("button", name="Sign In").click() + expect(page.get_by_role("button", name=" New Chat")).to_be_visible() + + # expect auth token to be set + cookies = page.context.cookies() + assert len(cookies) == 1 + cookie = cookies[0] + assert cookie.get("name") == "auth_token" + auth_token = cookie.get("value") + assert auth_token is not None + + # New page button + new_chat_button = page.get_by_role("button", name=" New Chat") + expect(new_chat_button).to_be_visible() + new_chat_button.click() + + document_root = config.local_root / "documents" + document_root.mkdir() + document_name = "test.txt" + document_path = document_root / document_name + with open(document_path, "w") as file: + file.write("!\n") + + # File upload selector + with page.expect_file_chooser() as fc_info: + page.locator(".fileUpload").click() + file_chooser = fc_info.value + file_chooser.set_files(document_path) + + # Upload document and expect to see it listed + file_list = page.locator(".fileListContainer") + expect(file_list.first).to_have_text(str(document_name)) + + chat_dialog = page.get_by_role("dialog") + expect(chat_dialog).to_be_visible() + start_chat_button = page.get_by_role("button", name="Start Conversation") + expect(start_chat_button).to_be_visible() + time.sleep(0.5) # hack while waiting for button to be fully clickable + start_chat_button.click(delay=5) + + chat_box_row = page.locator(".chat-interface-input-row") + expect(chat_box_row).to_be_visible() + + chat_box = chat_box_row.get_by_role("textbox") + expect(chat_box).to_be_visible() + + # Document should be in the database + chats_url = f"http://{config.api.hostname}:{config.api.port}/chats" + chats = httpx.get( + chats_url, headers={"Authorization": f"Bearer {auth_token}"} + ).json() + assert len(chats) == 1 + chat = chats[0] + chat_documents = chat["metadata"]["documents"] + assert len(chat_documents) == 1 + assert chat_documents[0]["name"] == document_name + + chat_box.fill("Tell me about the documents") + + chat_button = chat_box_row.get_by_role("button") + expect(chat_button).to_be_visible() + chat_button.click() diff --git a/tests/deploy/utils.py b/tests/deploy/utils.py new file mode 100644 index 00000000..48b8c2ae --- /dev/null +++ b/tests/deploy/utils.py @@ -0,0 +1,44 @@ +import os +import time + +from fastapi.testclient import TestClient + +from ragna.assistants import RagnaDemoAssistant +from ragna.core._utils import default_user + + +class TestAssistant(RagnaDemoAssistant): + def answer(self, messages, *, multiple_answer_chunks: bool = True): + # Simulate a "real" assistant through a small delay. See + # https://github.com/Quansight/ragna/pull/401#issuecomment-2095851440 + # for why this is needed. 
+ # + # Note: multiple_answer_chunks is given a default value here to satisfy + # the tests in deploy/ui/test_ui.py. This can be removed if TestAssistant + # is ever removed from that file. + time.sleep(1e-3) + content = next(super().answer(messages)) + + if multiple_answer_chunks: + for chunk in content.split(" "): + yield f"{chunk} " + else: + yield content + + +def authenticate_with_api(client: TestClient) -> None: + username = default_user() + token = ( + client.post( + "/token", + data={ + "username": username, + "password": os.environ.get( + "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + ), + }, + ) + .raise_for_status() + .json() + ) + client.headers["Authorization"] = f"Bearer {token}"
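
Usage note (not part of the patch): a minimal sketch of how a third-party assistant adapts to the new answer(messages) protocol introduced above. The class name MyCustomAssistant is hypothetical; Assistant, Message, and MessageRole are the ragna.core objects shown in the diff, and the last-message unpacking mirrors the pattern used by the built-in assistants in this patch.

from typing import Iterator

from ragna.core import Assistant, Message, MessageRole


class MyCustomAssistant(Assistant):
    # Hypothetical example class, not part of the patch.
    def answer(self, messages: list[Message]) -> Iterator[str]:
        # The last item in the chat history is the current user prompt and
        # carries the retrieved sources.
        prompt, sources = (message := messages[-1]).content, message.sources
        # The full history is now available as well, e.g. to count user turns.
        n_user_turns = sum(m.role == MessageRole.USER for m in messages)
        yield (
            f"Turn {n_user_turns}: answering {prompt!r} "
            f"with {len(sources)} source(s)."
        )

Such a class can be passed as the assistant argument of Rag().chat(), exactly like the built-in assistants exercised in tests/core/test_rag.py above.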