Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Sep 6, 2024
1 parent 8e120c6 commit 6288bb1
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 14 deletions.
2 changes: 1 addition & 1 deletion ragna/assistants/_ai21labs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _render_prompt(self, prompt: Union[str, list[Message]]) -> Union[str, list]:
else:
messages = prompt
return [
{"role": message.role.value, "content": message.content}
{"text": message.content, "role": message.role.value}
for message in messages
if message.role is not MessageRole.SYSTEM
]
Expand Down
3 changes: 2 additions & 1 deletion ragna/assistants/_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def generate(
Returns:
async streamed inference response string chunks
"""
# See https://ai.google.dev/api/generate-content#v1beta.models.streamGenerateContent
async with self._call_api(
"POST",
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
Expand All @@ -76,7 +77,7 @@ async def generate(
"maxOutputTokens": max_new_tokens,
},
},
parse_kwargs=dict(item="item"), # .candidates.item.content.parts.item.text
parse_kwargs=dict(item="item"),
) as stream:
async for data in stream:
yield data
Expand Down
8 changes: 1 addition & 7 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import abc
from functools import cached_property
from typing import (
Any,
AsyncIterator,
Optional,
Union,
cast,
)
from typing import Any, AsyncIterator, Optional, Union, cast

from ragna.core import Message, MessageRole, Source

Expand Down
10 changes: 5 additions & 5 deletions tests/assistants/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,15 @@ async def generate(self, messages):
content=messages[-1].content,
parse_kwargs=parse_kwargs,
) as stream:
async for chunk in stream:
yield chunk
async for data in stream:
yield data

async def answer(self, messages):
async for chunk in self.generate(messages):
if chunk.get("break"):
async for data in self.generate(messages):
if data.get("break"):
break

yield chunk
yield data


@skip_on_windows
Expand Down

0 comments on commit 6288bb1

Please sign in to comment.