
Merge branch 'corpus-dev' into metadata-translate
pmeier committed Jul 25, 2024
2 parents 6dd7ede + 1b51d27 commit 4faa1bb
Showing 37 changed files with 911 additions and 386 deletions.
5 changes: 2 additions & 3 deletions .github/actions/setup-env/action.yml
@@ -53,10 +53,9 @@ runs:
           mamba env update --file environment-dev.yml
           git checkout -- environment-dev.yml
-    - name: Install redis-server if necessary
-      if: (steps.cache.outputs.cache-hit != 'true') && (runner.os != 'Windows')
+    - name: Install playwright
       shell: bash -el {0}
-      run: mamba install --yes --channel conda-forge redis-server
+      run: playwright install
 
     - name: Install ragna
       shell: bash -el {0}
68 changes: 67 additions & 1 deletion .github/workflows/test.yml
@@ -68,10 +68,76 @@ jobs:
 
       - name: Run unit tests
         id: tests
-        run: pytest --junit-xml=test-results.xml --durations=25
+        run: |
+          pytest \
+            --ignore tests/deploy/ui \
+            --junit-xml=test-results.xml \
+            --durations=25
       - name: Surface failing tests
         if: steps.tests.outcome != 'success'
         uses: pmeier/[email protected]
         with:
           path: test-results.xml

+  pytest-ui:
+    strategy:
+      matrix:
+        os:
+          - ubuntu-latest
+          - windows-latest
+          - macos-latest
+        browser:
+          - chromium
+          - firefox
+        python-version:
+          - "3.9"
+          - "3.10"
+          - "3.11"
+        exclude:
+          - python-version: "3.10"
+            os: windows-latest
+          - python-version: "3.11"
+            os: windows-latest
+          - python-version: "3.10"
+            os: macos-latest
+          - python-version: "3.11"
+            os: macos-latest
+        include:
+          - browser: webkit
+            os: macos-latest
+            python-version: "3.9"
+
+      fail-fast: false
+
+    runs-on: ${{ matrix.os }}
+
+    defaults:
+      run:
+        shell: bash -el {0}
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Setup environment
+        uses: ./.github/actions/setup-env
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run unit tests
+        id: tests
+        run: |
+          pytest tests/deploy/ui \
+            --browser ${{ matrix.browser }} \
+            --video=retain-on-failure
+      - name: Upload playwright video
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name:
+            playwright-${{ matrix.os }}-${{ matrix.python-version}}-${{ github.run_id }}
+          path: test-results
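
For orientation, a minimal pytest-playwright test of the kind this job would collect from tests/deploy/ui might look as follows. This is a sketch, not code from this commit; the file name, URL, and assertion are assumptions.

# tests/deploy/ui/test_smoke.py -- hypothetical example, not part of this commit
from playwright.sync_api import Page


def test_ui_loads(page: Page) -> None:
    # pytest-playwright injects the `page` fixture for the browser selected via
    # --browser; --video=retain-on-failure records each test and keeps the video
    # only when the test fails, which is what the upload step above collects.
    page.goto("http://localhost:31477")  # assumed address of a running ragna UI
    assert page.title()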
12 changes: 10 additions & 2 deletions docs/examples/gallery_streaming.py
@@ -31,13 +31,21 @@
 # - [ragna.assistants.Gpt4][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 # - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+# - [ragna.assistants.OllamaGemma2B][]
+# - [ragna.assistants.OllamaLlama2][]
+# - [ragna.assistants.OllamaLlava][]
+# - [ragna.assistants.OllamaMistral][]
+# - [ragna.assistants.OllamaMixtral][]
+# - [ragna.assistants.OllamaOrcaMini][]
+# - [ragna.assistants.OllamaPhi2][]
 
 from ragna import assistants
 
 
 class DemoStreamingAssistant(assistants.RagnaDemoAssistant):
-    def answer(self, prompt, sources):
-        content = next(super().answer(prompt, sources))
+    def answer(self, messages):
+        content = next(super().answer(messages))
         for chunk in content.split(" "):
             yield f"{chunk} "
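
To drive the assistant above by hand, a rough usage sketch (assuming ragna.core.Message can be constructed from a plain string and defaults to an empty source list):

from ragna.core import Message

assistant = DemoStreamingAssistant()
# Hypothetical driver, not part of this commit: print chunks as they stream in.
for chunk in assistant.answer([Message(content="What is Ragna?")]):
    print(chunk, end="", flush=True)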

20 changes: 10 additions & 10 deletions docs/tutorials/gallery_custom_components.py
@@ -30,7 +30,7 @@
 
 import uuid
 
-from ragna.core import Document, Source, SourceStorage
+from ragna.core import Document, Source, SourceStorage, Message
 
 
 class TutorialSourceStorage(SourceStorage):
@@ -61,9 +61,9 @@ def retrieve(
 # %%
 # ### Assistant
 #
-# [ragna.core.Assistant][]s are objects that take a user prompt and relevant
-# [ragna.core.Source][]s and generate a response form that. Usually, assistants are
-# LLMs.
+# [ragna.core.Assistant][]s are objects that take the chat history as list of
+# [ragna.core.Message][]s and their relevant [ragna.core.Source][]s and generate a
+# response from that. Usually, assistants are LLMs.
 #
 # In this tutorial, we define a minimal `TutorialAssistant` that is similar to
 # [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user
@@ -82,8 +82,11 @@ def retrieve(
 
 
 class TutorialAssistant(Assistant):
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
+    def answer(self, messages: list[Message]) -> Iterator[str]:
         print(f"Running {type(self).__name__}().answer()")
+        # For simplicity, we only deal with the last message here, i.e. the latest user
+        # prompt.
+        prompt, sources = (message := messages[-1]).content, message.sources
         yield (
             f"To answer the user prompt '{prompt}', "
             f"I was given {len(sources)} source(s)."
@@ -254,8 +257,7 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
 class ElaborateTutorialAssistant(Assistant):
     def answer(
         self,
-        prompt: str,
-        sources: list[Source],
+        messages: list[Message],
         *,
         my_required_parameter: int,
         my_optional_parameter: str = "foo",
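
Keyword-only parameters like these are surfaced as chat-level parameters. A hypothetical call, assuming they are passed through when the chat is created, as the tutorial does elsewhere:

from ragna import Rag

chat = Rag().chat(
    documents=["ragna.txt"],  # assumed document path
    source_storage=TutorialSourceStorage,
    assistant=ElaborateTutorialAssistant,
    my_required_parameter=3,
    my_optional_parameter="bar",
)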
@@ -393,9 +395,7 @@ def answer(
 
 
 class AsyncAssistant(Assistant):
-    async def answer(
-        self, prompt: str, sources: list[Source]
-    ) -> AsyncIterator[str]:
+    async def answer(self, messages: list[Message]) -> AsyncIterator[str]:
         print(f"Running {type(self).__name__}().answer()")
         start = time.perf_counter()
         await asyncio.sleep(0.3)
8 changes: 8 additions & 0 deletions docs/tutorials/gallery_python_api.py
@@ -87,6 +87,14 @@
 # - [ragna.assistants.Jurassic2Ultra][]
 # - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
 # - [ragna.assistants.LlamafileAssistant][]
+# - [Ollama](https://ollama.com/)
+# - [ragna.assistants.OllamaGemma2B][]
+# - [ragna.assistants.OllamaLlama2][]
+# - [ragna.assistants.OllamaLlava][]
+# - [ragna.assistants.OllamaMistral][]
+# - [ragna.assistants.OllamaMixtral][]
+# - [ragna.assistants.OllamaOrcaMini][]
+# - [ragna.assistants.OllamaPhi2][]
 #
 # !!! note
 #
1 change: 1 addition & 0 deletions environment-dev.yml
@@ -10,6 +10,7 @@ dependencies:
   - pytest >=6
   - pytest-mock
   - pytest-asyncio
+  - pytest-playwright
   - mypy ==1.10.0
   - pre-commit
   - types-aiofiles
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -27,13 +27,12 @@ dependencies = [
     "httpx",
     "importlib_metadata>=4.6; python_version<'3.10'",
     "packaging",
-    "panel==1.4.2",
+    "panel==1.4.4",
     "pydantic>=2",
     "pydantic-core",
     "pydantic-settings>=2",
     "PyJWT",
     "python-multipart",
-    "redis",
     "questionary",
     "rich",
     "sqlalchemy>=2",
22 changes: 14 additions & 8 deletions ragna-docker.toml
@@ -3,22 +3,28 @@ authentication = "ragna.deploy.RagnaDemoAuthentication"
 document = "ragna.core.LocalDocument"
 source_storages = [
     "ragna.source_storages.Chroma",
-    "ragna.source_storages.RagnaDemoSourceStorage",
     "ragna.source_storages.LanceDB"
 ]
 assistants = [
-    "ragna.assistants.Jurassic2Ultra",
-    "ragna.assistants.Claude",
-    "ragna.assistants.ClaudeInstant",
+    "ragna.assistants.ClaudeHaiku",
+    "ragna.assistants.ClaudeOpus",
+    "ragna.assistants.ClaudeSonnet",
     "ragna.assistants.Command",
     "ragna.assistants.CommandLight",
-    "ragna.assistants.RagnaDemoAssistant",
     "ragna.assistants.GeminiPro",
     "ragna.assistants.GeminiUltra",
-    "ragna.assistants.Mpt7bInstruct",
-    "ragna.assistants.Mpt30bInstruct",
-    "ragna.assistants.Gpt4",
+    "ragna.assistants.OllamaGemma2B",
+    "ragna.assistants.OllamaPhi2",
+    "ragna.assistants.OllamaLlama2",
+    "ragna.assistants.OllamaLlava",
+    "ragna.assistants.OllamaMistral",
+    "ragna.assistants.OllamaMixtral",
+    "ragna.assistants.OllamaOrcaMini",
     "ragna.assistants.Gpt35Turbo16k",
+    "ragna.assistants.Gpt4",
+    "ragna.assistants.Jurassic2Ultra",
+    "ragna.assistants.LlamafileAssistant",
+    "ragna.assistants.RagnaDemoAssistant",
 ]
 
 [api]
16 changes: 16 additions & 0 deletions ragna/assistants/__init__.py
@@ -6,6 +6,13 @@
     "CommandLight",
     "GeminiPro",
     "GeminiUltra",
+    "OllamaGemma2B",
+    "OllamaPhi2",
+    "OllamaLlama2",
+    "OllamaLlava",
+    "OllamaMistral",
+    "OllamaMixtral",
+    "OllamaOrcaMini",
     "Gpt35Turbo16k",
     "Gpt4",
     "Jurassic2Ultra",
@@ -19,6 +26,15 @@
 from ._demo import RagnaDemoAssistant
 from ._google import GeminiPro, GeminiUltra
 from ._llamafile import LlamafileAssistant
+from ._ollama import (
+    OllamaGemma2B,
+    OllamaLlama2,
+    OllamaLlava,
+    OllamaMistral,
+    OllamaMixtral,
+    OllamaOrcaMini,
+    OllamaPhi2,
+)
 from ._openai import Gpt4, Gpt35Turbo16k
 
 # isort: split
15 changes: 8 additions & 7 deletions ragna/assistants/_ai21labs.py
@@ -1,12 +1,13 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant
 
 
 class Ai21LabsAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "AI21_API_KEY"
+    _STREAMING_PROTOCOL = None
     _MODEL_TYPE: str
 
     @classmethod
@@ -22,12 +23,14 @@ def _make_system_content(self, sources: list[Source]) -> str:
         return instruction + "\n\n".join(source.content for source in sources)
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
-        response = await self._client.post(
+        prompt, sources = (message := messages[-1]).content, message.sources
+        async for data in self._call_api(
+            "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
             headers={
                 "accept": "application/json",
@@ -46,10 +49,8 @@ async def answer(
                 ],
                 "system": self._make_system_content(sources),
             },
-        )
-        await self._assert_api_call_is_success(response)
-
-        yield cast(str, response.json()["outputs"][0]["text"])
+        ):
+            yield cast(str, data["outputs"][0]["text"])
 
 
 # The Jurassic2Mid assistant receives a 500 internal service error from the remote
10 changes: 6 additions & 4 deletions ragna/assistants/_anthropic.py
@@ -1,12 +1,13 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import PackageRequirement, RagnaException, Requirement, Source
+from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
 
-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
 
 class AnthropicAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
     _MODEL: str
 
     @classmethod
@@ -36,11 +37,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
         )
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        async for data in self._stream_sse(
+        prompt, sources = (message := messages[-1]).content, message.sources
+        async for data in self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
             headers={
10 changes: 6 additions & 4 deletions ragna/assistants/_cohere.py
@@ -1,12 +1,13 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import RagnaException, Source
+from ragna.core import Message, RagnaException, Source
 
-from ._http_api import HttpApiAssistant
+from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
 
 class CohereAssistant(HttpApiAssistant):
     _API_KEY_ENV_VAR = "COHERE_API_KEY"
+    _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
     _MODEL: str
 
     @classmethod
@@ -24,12 +25,13 @@ def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        async for event in self._stream_jsonl(
+        prompt, sources = (message := messages[-1]).content, message.sources
+        async for event in self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
             headers={
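
Across the three assistant files above the refactor follows one pattern: each class declares a _STREAMING_PROTOCOL (SSE for Anthropic, JSONL for Cohere, None for AI21 Labs), and the per-protocol helpers _stream_sse/_stream_jsonl as well as the bare self._client.post(...) call are unified behind a single _call_api(method, url, ...) async generator. A rough sketch of that shape; the enum body and the dispatch are assumptions pieced together from these hunks, not code from this commit:

import enum
from typing import Any, AsyncIterator, Optional


class HttpStreamingProtocol(enum.Enum):
    # Only these two members are visible in this diff.
    SSE = enum.auto()
    JSONL = enum.auto()


class HttpApiAssistant:
    _API_KEY_ENV_VAR: str
    _STREAMING_PROTOCOL: Optional[HttpStreamingProtocol]

    async def _call_api(
        self, method: str, url: str, **kwargs: Any
    ) -> AsyncIterator[Any]:
        # Hypothetical dispatch: SSE and JSONL stream parsed chunks as they
        # arrive; None performs a single request and yields the parsed JSON once.
        raise NotImplementedError
        yield  # unreachable, but makes this an async generator as call sites expect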
