diff --git a/mirascope/core/mistral/_utils/_convert_message_params.py b/mirascope/core/mistral/_utils/_convert_message_params.py index 2b3065850..00e2740cf 100644 --- a/mirascope/core/mistral/_utils/_convert_message_params.py +++ b/mirascope/core/mistral/_utils/_convert_message_params.py @@ -1,5 +1,7 @@ """Utility for converting `BaseMessageParam` to `ChatMessage`.""" +from typing import Any + from mistralai.models import ( AssistantMessage, SystemMessage, @@ -10,26 +12,36 @@ from ...base import BaseMessageParam +def make_message( + role: str, + **kwargs, # noqa: ANN003 +) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + if role == "assistant": + return AssistantMessage(**kwargs) + elif role == "system": + return SystemMessage(**kwargs) + elif role == "tool": + return ToolMessage(**kwargs) + elif role == "user": + return UserMessage(**kwargs) + raise ValueError(f"Invalid role: {role}") + + def convert_message_params( message_params: list[ BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage ], -) -> list[BaseMessageParam]: +) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]: converted_message_params = [] for message_param in message_params: - if not isinstance( - message_param, - BaseMessageParam, - ): + if not isinstance(message_param, BaseMessageParam): converted_message_params.append(message_param) elif isinstance(content := message_param.content, str): - converted_message_params.append( - BaseMessageParam(**message_param.model_dump()) - ) + converted_message_params.append(make_message(**message_param.model_dump())) else: if len(content) != 1 or content[0].type != "text": raise ValueError("Mistral currently only supports text parts.") converted_message_params.append( - BaseMessageParam(role=message_param.role, content=content[0].text) + make_message(role=message_param.role, content=content[0].text) ) return converted_message_params diff --git a/mirascope/core/mistral/_utils/_setup_call.py b/mirascope/core/mistral/_utils/_setup_call.py index 10c071fe5..c8bb09689 100644 --- a/mirascope/core/mistral/_utils/_setup_call.py +++ b/mirascope/core/mistral/_utils/_setup_call.py @@ -19,10 +19,9 @@ UserMessage, ) -from mirascope.core.base._utils._protocols import fn_is_async - -from ...base import BaseMessageParam, BaseTool, _utils +from ...base import BaseTool, _utils from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn +from ...base._utils._protocols import fn_is_async from ..call_kwargs import MistralCallKwargs from ..call_params import MistralCallParams from ..dynamic_config import MistralDynamicConfig diff --git a/mirascope/core/mistral/call_response.py b/mirascope/core/mistral/call_response.py index 01218544f..dfe70dbec 100644 --- a/mirascope/core/mistral/call_response.py +++ b/mirascope/core/mistral/call_response.py @@ -3,7 +3,7 @@ usage docs: learn/calls.md#handling-responses """ -from typing import Any +from typing import Any, cast from mistralai.models import ( AssistantMessage, @@ -30,7 +30,7 @@ class MistralCallResponse( MistralDynamicConfig, AssistantMessage | SystemMessage | ToolMessage | UserMessage, MistralCallParams, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, + UserMessage, ] ): """A convenience wrapper around the Mistral `ChatCompletion` response. @@ -104,9 +104,9 @@ def cost(self) -> float | None: @property def message_param( self, - ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + ) -> AssistantMessage: """Returns the assistants's response as a message parameter.""" - return self.response.choices[0].message + return cast(AssistantMessage, self.response.choices[0].message) @computed_field @property @@ -144,7 +144,7 @@ def tool(self) -> MistralTool | None: @classmethod def tool_message_params( cls, tools_and_outputs: list[tuple[MistralTool, str]] - ) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]: + ) -> list[ToolMessage]: """Returns the tool message parameters for tool call results. Args: diff --git a/mirascope/core/mistral/stream.py b/mirascope/core/mistral/stream.py index 7fe483766..d2f0af846 100644 --- a/mirascope/core/mistral/stream.py +++ b/mirascope/core/mistral/stream.py @@ -29,9 +29,9 @@ class MistralStream( BaseStream[ MistralCallResponse, MistralCallResponseChunk, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, + UserMessage, + AssistantMessage, + ToolMessage, AssistantMessage | SystemMessage | ToolMessage | UserMessage, MistralTool, dict[str, Any], @@ -69,7 +69,7 @@ def cost(self) -> float | None: def _construct_message_param( self, tool_calls: list | None = None, content: str | None = None - ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + ) -> AssistantMessage: message_param = AssistantMessage( content=content if content else "", tool_calls=tool_calls ) diff --git a/tests/core/mistral/_utils/test_convert_message_params.py b/tests/core/mistral/_utils/test_convert_message_params.py index bf35f6c2b..eefbd0f49 100644 --- a/tests/core/mistral/_utils/test_convert_message_params.py +++ b/tests/core/mistral/_utils/test_convert_message_params.py @@ -25,8 +25,8 @@ def test_convert_message_params() -> None: converted_message_params = convert_message_params(message_params) assert converted_message_params == [ UserMessage(content="Hello"), - BaseMessageParam(role="user", content="Hello"), - BaseMessageParam(role="user", content="Hello"), + UserMessage(role="user", content="Hello"), + UserMessage(role="user", content="Hello"), ] with pytest.raises(