Skip to content

Commit

Permalink
[6/N] refactor async_generate with request (#578)
Browse files Browse the repository at this point in the history
* [6/N] refactor async_generate with request

* resolve comments

---------

Co-authored-by: guoli-yin <[email protected]>
  • Loading branch information
gyin94 and guoli-yin authored Jul 9, 2024
1 parent fea9ba8 commit e13d41a
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 48 deletions.
19 changes: 9 additions & 10 deletions axlearn/open_api/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

# isort: off
from axlearn.open_api.common import BaseClient, ClientRateLimitError, ValidationError
Expand Down Expand Up @@ -39,32 +39,31 @@ def _create_client(self) -> AsyncAnthropic:
async def async_generate(
self,
*,
messages: Optional[List[Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
request: Dict[str, Any],
**kwargs,
) -> str:
"""Generates response asynchronously from the client.
Args:
messages: OpenAI requests style messages.
tools: OpenAI tools definitions.
prompt: OpenAI prompt style.
request: OpenAI style request.
**kwargs: API request keyword arguments.
Returns:
Response in string format.
Raises:
ClientRateLimitError: Hits rate limiting for retries.
ValidationError: Field messages must be in request.
"""
if "messages" not in request:
raise ValidationError("Field messages must be in request.")
cfg: AnthropicClient.Config = self.config
client: AsyncAnthropic = self._client
request_kwargs = copy.deepcopy(kwargs)
anthropic_tools = None
if tools is not None:
anthropic_tools = _convert_openai_tools_to_anthropic(tools=tools)
anthropic_messages = _convert_openai_messages_to_anthropic(messages=messages)
if request.get("tools", None) is not None:
anthropic_tools = _convert_openai_tools_to_anthropic(tools=request["tools"])
anthropic_messages = _convert_openai_messages_to_anthropic(messages=request["messages"])
try:
# A temporary solution to encourage claude models to generate parallel tool calls.
if request_kwargs is not None and request_kwargs.get(
Expand Down
6 changes: 4 additions & 2 deletions axlearn/open_api/anthropic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ async def test_async_generate(self, mock_convert_tools, mock_convert_messages):
mock_response.model_dump_json.return_value = json.dumps({"response": "test_response"})
client._client.messages.create = AsyncMock(return_value=mock_response)
result = await client.async_generate(
messages=[{"role": "user", "content": "Hello"}],
tools=[{"name": "tool1"}],
request={
"messages": [{"role": "user", "content": "Hello"}],
"tools": [{"name": "tool1"}],
},
add_system_parallel_tools=True,
)
# Assert the expected result.
Expand Down
13 changes: 4 additions & 9 deletions axlearn/open_api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,16 @@ def parse_generation(cls, response: Dict[str, Any]) -> Sequence[ChatCompletionMe
async def async_generate(
self,
*,
messages: Optional[Sequence[Dict[str, Any]]] = None,
tools: Optional[Sequence[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
request: Dict[str, Any],
**kwargs,
) -> str:
"""Generates response asynchronously from the client.
Args:
client: Endpoint client.
messages: OpenAI requests style messages. Ref:
request: OpenAI style request. Ref:
https://github.com/openai/openai-python/blob/50371bf3151ebb1a43017abfe205d4d9b2e5faac/src/openai/resources/chat/completions.py#L237
https://github.com/openai/openai-python/blob/f3e6e634a86d5789ab1274ae27f43adc842f4ba8/src/openai/types/chat/chat_completion_message.py#L25
tools: OpenAI tools definitions.
prompt: OpenAI prompt style.
**kwargs: API request keyword arguments.
Returns:
Expand Down Expand Up @@ -188,9 +185,7 @@ async def _async_generate_from_request(
while True:
try:
response = await client.async_generate(
messages=request.get("messages", None),
tools=request.get("tools", None),
prompt=request.get("prompt", None),
request=request,
**kwargs,
)
break
Expand Down
40 changes: 28 additions & 12 deletions axlearn/open_api/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,27 @@ def _create_client(self) -> GenerativeModel:
async def async_generate(
self,
*,
messages: Optional[List[Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
request: Dict[str, Any],
**kwargs,
) -> str:
"""Generates response asynchronously from the client.
Args:
messages: OpenAI requests style messages.
tools: OpenAI tools definitions.
prompt: OpenAI prompt style.
request: OpenAI style request.
**kwargs: API request keyword arguments.
Returns:
Response in string format.
"""
contents = _convert_openai_messages_to_gemini(messages=messages)
if tools is not None:
gemini_tools = _convert_openai_tools_to_gemini(tools=tools)
Raises:
ValidationError: Field messages must be in request.
"""
if "messages" not in request:
raise ValidationError("Field messages must be in request.")
_format_request(request=request)
contents = _convert_openai_messages_to_gemini(messages=request["messages"])
if request.get("tools", None) is not None:
gemini_tools = _convert_openai_tools_to_gemini(tools=request["tools"])
else:
gemini_tools = None
client: GenerativeModel = self._client
Expand Down Expand Up @@ -176,8 +177,6 @@ def _aggregate_tool_role_messages(messages: List[Dict[str, Any]]) -> List[Dict[s
if message["role"] != "tool":
aggregated_messages.append(message)
continue
# Reduce tool name length which is smaller than OpenAI models.
message = _format_tool_message(message=message)
if len(aggregated_messages) > 0 and aggregated_messages[-1]["role"] == "tool":
tool_messages: list = aggregated_messages[-1]["tool_messages"]
aggregated_messages[-1]["tool_messages"] = tool_messages.append(message)
Expand All @@ -188,6 +187,23 @@ def _aggregate_tool_role_messages(messages: List[Dict[str, Any]]) -> List[Dict[s
return aggregated_messages


def _format_request(request: Dict[str, Any]):
    """Formats an OpenAI style request in place to follow Gemini request rules.

    Gemini permits shorter tool (function) names than OpenAI models, so every
    function name in "tools" is truncated to its trailing
    `_max_tool_name_length` characters, and `_format_tool_message` applies the
    matching rewrite to tool references inside messages.

    Args:
        request: OpenAI style request dict. Mutated in place; the optional
            keys "messages", "target_message" and "tools" are rewritten when
            present and not None.
    """
    # Guard with .get(...) is not None rather than `in`: a key explicitly set
    # to None would otherwise raise a TypeError when iterated below. This
    # mirrors the `request.get("tools", None) is not None` checks used by the
    # async_generate implementations in this package.
    if request.get("messages") is not None:
        request["messages"] = [
            _format_tool_message(message=message) for message in request["messages"]
        ]
    if request.get("target_message") is not None:
        request["target_message"] = _format_tool_message(message=request["target_message"])
    if request.get("tools") is not None:
        # Each tool dict is mutated in place; keep only the last
        # `_max_tool_name_length` characters of the function name.
        for tool in request["tools"]:
            tool["function"]["name"] = tool["function"]["name"][-_max_tool_name_length:]


def _convert_openai_messages_to_gemini(messages: List[Dict[str, Any]]) -> List[Content]:
"""Converts OpenAI messages to Gemini Content.
Expand Down
6 changes: 4 additions & 2 deletions axlearn/open_api/gemini_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ async def test_async_generate(self, mock_convert_tools, mock_convert_messages):
client._client.generate_content_async = AsyncMock(return_value=mock_response)

result = await client.async_generate(
messages=[{"role": "user", "content": "Hello"}],
tools=[{"name": "tool1"}],
request={
"messages": [{"role": "user", "content": "Hello"}],
"tools": [{"type": "function", "function": {"name": "func1"}}],
},
temperature=0.7,
top_k=50,
top_p=0.9,
Expand Down
21 changes: 9 additions & 12 deletions axlearn/open_api/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import os
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

# isort: off
from axlearn.open_api.common import BaseClient, ClientRateLimitError, ValidationError
Expand Down Expand Up @@ -47,17 +47,13 @@ def _create_client(self) -> AsyncOpenAI:
async def async_generate(
self,
*,
messages: Optional[List[Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
prompt: Optional[str] = None,
request: Dict[str, Any],
**kwargs,
) -> str:
"""Generates response asynchronously from the client.
Args:
messages: OpenAI requests style messages.
tools: OpenAI tools definitions.
prompt: OpenAI prompt style.
request: OpenAI style request.
**kwargs: API request keyword arguments.
Returns:
Expand All @@ -69,20 +65,21 @@ async def async_generate(
"""
cfg: OpenAIClient.Config = self.config
client: AsyncOpenAI = self._client
assert prompt is not None or messages is not None, ValidationError(
"Either prompt or messages must be not None."
)
prompt = request.get("prompt", None)
messages = request.get("messages", None)
if prompt is None and messages is None:
raise ValidationError("Both prompt and messages are None.")
try:
if prompt is not None:
response: Completion = await client.completions.create(
prompt=prompt,
prompt=request["prompt"],
extra_body=cfg.extra_body,
**kwargs,
)
else:
response: ChatCompletion = await client.chat.completions.create(
messages=messages,
tools=tools,
tools=request.get("tools", None),
extra_body=cfg.extra_body,
**kwargs,
)
Expand Down
19 changes: 18 additions & 1 deletion axlearn/open_api/openai_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,29 @@


# pylint: disable=wrong-import-position
from axlearn.open_api.common import ClientRateLimitError, Generator
from axlearn.open_api.common import ClientRateLimitError, Generator, ValidationError
from axlearn.open_api.openai import OpenAIClient

# pylint: enable=wrong-import-position


class TestOpenAIClient(unittest.IsolatedAsyncioTestCase):
    """Unit tests for class OpenAIClient."""

    def setUp(self):
        # Instantiate a real client config, then swap the underlying SDK
        # client for a mock so no network traffic can occur.
        config = OpenAIClient.default_config().set(model="gpt-3.5-turbo")
        self.client: OpenAIClient = config.instantiate()
        self.client._client = AsyncMock()

    async def test_async_generate_raises_validation_error(self):
        # A request carrying neither "prompt" nor "messages" must be rejected
        # with a ValidationError before any API call is attempted.
        empty_request = {}
        with self.assertRaises(ValidationError) as raised:
            await self.client.async_generate(request=empty_request)
        self.assertEqual("Both prompt and messages are None.", str(raised.exception))


class TestOpenAIAsyncGenerateFromRequests(unittest.IsolatedAsyncioTestCase):
"""Unit test for async_generate_from_requests."""

Expand Down

0 comments on commit e13d41a

Please sign in to comment.