Skip to content

Commit

Permalink
Address PR feedback. Largely about consistent / correct type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
skylarbpayne committed Oct 29, 2024
1 parent 5fd0b7e commit 6602329
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 23 deletions.
30 changes: 21 additions & 9 deletions mirascope/core/mistral/_utils/_convert_message_params.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utility for converting `BaseMessageParam` to `ChatMessage`."""

from typing import Any

from mistralai.models import (
AssistantMessage,
SystemMessage,
Expand All @@ -10,26 +12,36 @@
from ...base import BaseMessageParam


def make_message(
role: str,
**kwargs, # noqa: ANN003
) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage:
if role == "assistant":
return AssistantMessage(**kwargs)
elif role == "system":
return SystemMessage(**kwargs)
elif role == "tool":
return ToolMessage(**kwargs)
elif role == "user":
return UserMessage(**kwargs)
raise ValueError(f"Invalid role: {role}")


def convert_message_params(
message_params: list[
BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage
],
) -> list[BaseMessageParam]:
) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]:
converted_message_params = []
for message_param in message_params:
if not isinstance(
message_param,
BaseMessageParam,
):
if not isinstance(message_param, BaseMessageParam):
converted_message_params.append(message_param)
elif isinstance(content := message_param.content, str):
converted_message_params.append(
BaseMessageParam(**message_param.model_dump())
)
converted_message_params.append(make_message(**message_param.model_dump()))
else:
if len(content) != 1 or content[0].type != "text":
raise ValueError("Mistral currently only supports text parts.")
converted_message_params.append(
BaseMessageParam(role=message_param.role, content=content[0].text)
make_message(role=message_param.role, content=content[0].text)
)
return converted_message_params
5 changes: 2 additions & 3 deletions mirascope/core/mistral/_utils/_setup_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
UserMessage,
)

from mirascope.core.base._utils._protocols import fn_is_async

from ...base import BaseMessageParam, BaseTool, _utils
from ...base import BaseTool, _utils
from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn
from ...base._utils._protocols import fn_is_async
from ..call_kwargs import MistralCallKwargs
from ..call_params import MistralCallParams
from ..dynamic_config import MistralDynamicConfig
Expand Down
10 changes: 5 additions & 5 deletions mirascope/core/mistral/call_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
usage docs: learn/calls.md#handling-responses
"""

from typing import Any
from typing import Any, cast

from mistralai.models import (
AssistantMessage,
Expand All @@ -30,7 +30,7 @@ class MistralCallResponse(
MistralDynamicConfig,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
MistralCallParams,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
UserMessage,
]
):
"""A convenience wrapper around the Mistral `ChatCompletion` response.
Expand Down Expand Up @@ -104,9 +104,9 @@ def cost(self) -> float | None:
@property
def message_param(
self,
) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage:
) -> AssistantMessage:
"""Returns the assistants's response as a message parameter."""
return self.response.choices[0].message
return cast(AssistantMessage, self.response.choices[0].message)

@computed_field
@property
Expand Down Expand Up @@ -144,7 +144,7 @@ def tool(self) -> MistralTool | None:
@classmethod
def tool_message_params(
cls, tools_and_outputs: list[tuple[MistralTool, str]]
) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]:
) -> list[ToolMessage]:
"""Returns the tool message parameters for tool call results.
Args:
Expand Down
8 changes: 4 additions & 4 deletions mirascope/core/mistral/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ class MistralStream(
BaseStream[
MistralCallResponse,
MistralCallResponseChunk,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
UserMessage,
AssistantMessage,
ToolMessage,
AssistantMessage | SystemMessage | ToolMessage | UserMessage,
MistralTool,
dict[str, Any],
Expand Down Expand Up @@ -69,7 +69,7 @@ def cost(self) -> float | None:

def _construct_message_param(
self, tool_calls: list | None = None, content: str | None = None
) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage:
) -> AssistantMessage:
message_param = AssistantMessage(
content=content if content else "", tool_calls=tool_calls
)
Expand Down
4 changes: 2 additions & 2 deletions tests/core/mistral/_utils/test_convert_message_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def test_convert_message_params() -> None:
converted_message_params = convert_message_params(message_params)
assert converted_message_params == [
UserMessage(content="Hello"),
BaseMessageParam(role="user", content="Hello"),
BaseMessageParam(role="user", content="Hello"),
UserMessage(role="user", content="Hello"),
UserMessage(role="user", content="Hello"),
]

with pytest.raises(
Expand Down

0 comments on commit 6602329

Please sign in to comment.