From 5fd0b7e28a3f5987ea60b006c25e37907af7ef76 Mon Sep 17 00:00:00 2001 From: Skylar Payne Date: Sun, 29 Sep 2024 10:59:19 -0700 Subject: [PATCH 1/4] Bump mistralai to > 1.0.0 in preparation for latest models such as Pixtral Related to #521 --- .../basic_usage/mistral/official_sdk_call.py | 5 +- .../mistral/base_message_param.py | 17 +++++- .../calls/custom_client/mistral/messages.py | 17 +++++- .../calls/custom_client/mistral/shorthand.py | 17 +++++- .../custom_client/mistral/string_template.py | 17 +++++- .../basic_usage/mistral/official_sdk.py | 8 ++- mirascope/core/mistral/__init__.py | 11 +++- .../mistral/_utils/_convert_message_params.py | 24 ++++++-- .../core/mistral/_utils/_handle_stream.py | 17 +++--- mirascope/core/mistral/_utils/_setup_call.py | 61 ++++++++++--------- mirascope/core/mistral/call_kwargs.py | 9 ++- mirascope/core/mistral/call_params.py | 2 +- mirascope/core/mistral/call_response.py | 23 ++++--- mirascope/core/mistral/call_response_chunk.py | 7 +-- mirascope/core/mistral/dynamic_config.py | 10 ++- mirascope/core/mistral/stream.py | 27 ++++---- mirascope/core/mistral/tool.py | 2 +- pyproject.toml | 2 +- .../_utils/test_convert_message_params.py | 19 ++++-- .../mistral/_utils/test_get_json_output.py | 30 +++++---- .../core/mistral/_utils/test_handle_stream.py | 44 +++++++------ tests/core/mistral/_utils/test_setup_call.py | 50 ++++++++------- tests/core/mistral/test_call_response.py | 33 +++++----- .../core/mistral/test_call_response_chunk.py | 20 +++--- tests/core/mistral/test_stream.py | 55 ++++++++--------- tests/core/mistral/test_tool.py | 4 +- 26 files changed, 307 insertions(+), 224 deletions(-) diff --git a/examples/learn/calls/basic_usage/mistral/official_sdk_call.py b/examples/learn/calls/basic_usage/mistral/official_sdk_call.py index 15a965fdb..b14fb8227 100644 --- a/examples/learn/calls/basic_usage/mistral/official_sdk_call.py +++ b/examples/learn/calls/basic_usage/mistral/official_sdk_call.py @@ -1,6 +1,7 @@ -from mistralai.client import MistralClient +from mistralai import Mistral +import os -client = MistralClient() +client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) def recommend_book(genre: str) -> str: diff --git a/examples/learn/calls/custom_client/mistral/base_message_param.py b/examples/learn/calls/custom_client/mistral/base_message_param.py index 3094488ab..01fa5181c 100644 --- a/examples/learn/calls/custom_client/mistral/base_message_param.py +++ b/examples/learn/calls/custom_client/mistral/base_message_param.py @@ -1,7 +1,20 @@ +import os + from mirascope.core import BaseMessageParam, mistral -from mistralai.client import MistralClient +from mistralai import Mistral -@mistral.call("mistral-large-latest", client=MistralClient()) +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) def recommend_book(genre: str) -> list[BaseMessageParam]: return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")] + + +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) +async def recommend_book_async(genre: str) -> list[BaseMessageParam]: + return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")] diff --git a/examples/learn/calls/custom_client/mistral/messages.py b/examples/learn/calls/custom_client/mistral/messages.py index d305324ac..569ca229a 100644 --- a/examples/learn/calls/custom_client/mistral/messages.py +++ b/examples/learn/calls/custom_client/mistral/messages.py @@ -1,7 +1,20 @@ +import os + 
from mirascope.core import Messages, mistral -from mistralai.client import MistralClient +from mistralai import Mistral -@mistral.call("mistral-large-latest", client=MistralClient()) +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) def recommend_book(genre: str) -> Messages.Type: return Messages.User(f"Recommend a {genre} book") + + +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) +async def recommend_book_async(genre: str) -> Messages.Type: + return Messages.User(f"Recommend a {genre} book") diff --git a/examples/learn/calls/custom_client/mistral/shorthand.py b/examples/learn/calls/custom_client/mistral/shorthand.py index 276702ba3..38571b32b 100644 --- a/examples/learn/calls/custom_client/mistral/shorthand.py +++ b/examples/learn/calls/custom_client/mistral/shorthand.py @@ -1,7 +1,20 @@ +import os + from mirascope.core import mistral -from mistralai.client import MistralClient +from mistralai import Mistral -@mistral.call("mistral-large-latest", client=MistralClient()) +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) def recommend_book(genre: str) -> str: return f"Recommend a {genre} book" + + +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) +async def recommend_book_async(genre: str) -> str: + return f"Recommend a {genre} book" diff --git a/examples/learn/calls/custom_client/mistral/string_template.py b/examples/learn/calls/custom_client/mistral/string_template.py index f63c6427c..8743a1a60 100644 --- a/examples/learn/calls/custom_client/mistral/string_template.py +++ b/examples/learn/calls/custom_client/mistral/string_template.py @@ -1,7 +1,20 @@ +import os + from mirascope.core import mistral, prompt_template -from mistralai.client import MistralClient +from mistralai import Mistral -@mistral.call("mistral-large-latest", client=MistralClient()) +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) @prompt_template("Recommend a {genre} book") def recommend_book(genre: str): ... + + +@mistral.call( + "mistral-large-latest", + client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), +) +@prompt_template("Recommend a {genre} book") +async def recommend_book_async(genre: str): ... 
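The client changes above follow the same pattern everywhere: the 0.x MistralClient/MistralAsyncClient pair is replaced by a single Mistral client whose chat namespace carries both the sync (complete) and async (complete_async) entry points, which is why one client construction now serves both the sync and async decorated functions. A minimal sketch of the migrated call pattern, assuming mistralai>=1.0.0 is installed and MISTRAL_API_KEY is set; the prompt string is illustrative:

    import asyncio
    import os

    from mistralai import Mistral

    # One client now serves both sync and async entry points.
    client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", ""))

    # Synchronous completion (0.x: MistralClient().chat(...)).
    # Note: choices is Optional in 1.x; patch 4 below guards this in the
    # library code, omitted here for brevity.
    response = client.chat.complete(
        model="mistral-large-latest",
        messages=[{"role": "user", "content": "Recommend a fantasy book"}],
    )
    print(response.choices[0].message.content)

    # Asynchronous completion (0.x: MistralAsyncClient().chat(...)).
    async def main() -> None:
        response = await client.chat.complete_async(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": "Recommend a fantasy book"}],
        )
        print(response.choices[0].message.content)

    asyncio.run(main())
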
diff --git a/examples/learn/response_models/basic_usage/mistral/official_sdk.py b/examples/learn/response_models/basic_usage/mistral/official_sdk.py index aa6f717d5..7131d46dd 100644 --- a/examples/learn/response_models/basic_usage/mistral/official_sdk.py +++ b/examples/learn/response_models/basic_usage/mistral/official_sdk.py @@ -1,8 +1,10 @@ -from mistralai.client import MistralClient -from mistralai.models.chat_completion import ToolChoice +import os + +from mistralai.client import Mistral +from mistralai.models import ToolChoice from pydantic import BaseModel -client = MistralClient() +client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) class Book(BaseModel): diff --git a/mirascope/core/mistral/__init__.py b/mirascope/core/mistral/__init__.py index 99cf81737..39ce1b4ed 100644 --- a/mirascope/core/mistral/__init__.py +++ b/mirascope/core/mistral/__init__.py @@ -2,7 +2,12 @@ from typing import TypeAlias -from mistralai.models.chat_completion import ChatMessage +from mistralai.models import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) from ..base import BaseMessageParam from ._call import mistral_call @@ -14,7 +19,9 @@ from .stream import MistralStream from .tool import MistralTool -MistralMessageParam: TypeAlias = ChatMessage | BaseMessageParam +MistralMessageParam: TypeAlias = ( + AssistantMessage | SystemMessage | ToolMessage | UserMessage | BaseMessageParam +) __all__ = [ "call", diff --git a/mirascope/core/mistral/_utils/_convert_message_params.py b/mirascope/core/mistral/_utils/_convert_message_params.py index 2d102b6c6..2b3065850 100644 --- a/mirascope/core/mistral/_utils/_convert_message_params.py +++ b/mirascope/core/mistral/_utils/_convert_message_params.py @@ -1,23 +1,35 @@ """Utility for converting `BaseMessageParam` to `ChatMessage`.""" -from mistralai.models.chat_completion import ChatMessage +from mistralai.models import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) from ...base import BaseMessageParam def convert_message_params( - message_params: list[BaseMessageParam | ChatMessage], -) -> list[ChatMessage]: + message_params: list[ + BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage + ], +) -> list[BaseMessageParam]: converted_message_params = [] for message_param in message_params: - if isinstance(message_param, ChatMessage): + if not isinstance( + message_param, + BaseMessageParam, + ): converted_message_params.append(message_param) elif isinstance(content := message_param.content, str): - converted_message_params.append(ChatMessage(**message_param.model_dump())) + converted_message_params.append( + BaseMessageParam(**message_param.model_dump()) + ) else: if len(content) != 1 or content[0].type != "text": raise ValueError("Mistral currently only supports text parts.") converted_message_params.append( - ChatMessage(role=message_param.role, content=content[0].text) + BaseMessageParam(role=message_param.role, content=content[0].text) ) return converted_message_params diff --git a/mirascope/core/mistral/_utils/_handle_stream.py b/mirascope/core/mistral/_utils/_handle_stream.py index 60994ab37..9a8f643f7 100644 --- a/mirascope/core/mistral/_utils/_handle_stream.py +++ b/mirascope/core/mistral/_utils/_handle_stream.py @@ -2,11 +2,10 @@ from collections.abc import AsyncGenerator, Generator -from mistralai.models.chat_completion import ( - ChatCompletionStreamResponse, +from mistralai.models import ( + CompletionEvent, FunctionCall, ToolCall, - ToolType, ) from ..call_response_chunk 
import MistralCallResponseChunk @@ -14,7 +13,7 @@ def _handle_chunk( - chunk: ChatCompletionStreamResponse, + chunk: CompletionEvent, current_tool_call: ToolCall, current_tool_type: type[MistralTool] | None, tool_types: list[type[MistralTool]] | None, @@ -38,7 +37,7 @@ def _handle_chunk( arguments="", name=tool_call.function.name if tool_call.function.name else "", ), - type=ToolType.function, + type="function", ) current_tool_type = None for tool_type in tool_types: @@ -64,12 +63,12 @@ def _handle_chunk( def handle_stream( - stream: Generator[ChatCompletionStreamResponse, None, None], + stream: Generator[CompletionEvent, None, None], tool_types: list[type[MistralTool]] | None, ) -> Generator[tuple[MistralCallResponseChunk, MistralTool | None], None, None]: """Iterator over the stream and constructs tools as they are streamed.""" current_tool_call = ToolCall( - id="", function=FunctionCall(arguments="", name=""), type=ToolType.function + id="", function=FunctionCall(arguments="", name=""), type="function" ) current_tool_type = None for chunk in stream: @@ -93,12 +92,12 @@ def handle_stream( async def handle_stream_async( - stream: AsyncGenerator[ChatCompletionStreamResponse, None], + stream: AsyncGenerator[CompletionEvent, None], tool_types: list[type[MistralTool]] | None, ) -> AsyncGenerator[tuple[MistralCallResponseChunk, MistralTool | None], None]: """Async iterator over the stream and constructs tools as they are streamed.""" current_tool_call = ToolCall( - id="", function=FunctionCall(arguments="", name=""), type=ToolType.function + id="", function=FunctionCall(arguments="", name=""), type="function" ) current_tool_type = None async for chunk in stream: diff --git a/mirascope/core/mistral/_utils/_setup_call.py b/mirascope/core/mistral/_utils/_setup_call.py index 289b6d0df..10c071fe5 100644 --- a/mirascope/core/mistral/_utils/_setup_call.py +++ b/mirascope/core/mistral/_utils/_setup_call.py @@ -1,23 +1,26 @@ """This module contains the setup_call function for Mistral tools.""" -import inspect +import os from collections.abc import ( Awaitable, Callable, ) from typing import Any, cast, overload -from mistralai.async_client import MistralAsyncClient -from mistralai.client import MistralClient -from mistralai.models.chat_completion import ( +from mistralai import Mistral +from mistralai.models import ( + AssistantMessage, ChatCompletionResponse, - ChatCompletionStreamResponse, - ChatMessage, + CompletionEvent, ResponseFormat, - ResponseFormats, - ToolChoice, + SystemMessage, + ToolChoiceEnum, + ToolMessage, + UserMessage, ) +from mirascope.core.base._utils._protocols import fn_is_async + from ...base import BaseMessageParam, BaseTool, _utils from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn from ..call_kwargs import MistralCallKwargs @@ -31,7 +34,7 @@ def setup_call( *, model: str, - client: MistralAsyncClient | None, + client: Mistral | None, fn: Callable[..., Awaitable[MistralDynamicConfig]], fn_args: dict[str, Any], dynamic_config: MistralDynamicConfig, @@ -40,9 +43,9 @@ def setup_call( call_params: MistralCallParams, extract: bool, ) -> tuple[ - AsyncCreateFn[ChatCompletionResponse, ChatCompletionStreamResponse], + AsyncCreateFn[ChatCompletionResponse, CompletionEvent], str | None, - list[ChatMessage], + list[AssistantMessage | SystemMessage | ToolMessage | UserMessage], list[type[MistralTool]] | None, MistralCallKwargs, ]: ... 
@@ -52,7 +55,7 @@ def setup_call( def setup_call( *, model: str, - client: MistralClient | None, + client: Mistral | None, fn: Callable[..., MistralDynamicConfig], fn_args: dict[str, Any], dynamic_config: MistralDynamicConfig, @@ -61,9 +64,9 @@ def setup_call( call_params: MistralCallParams, extract: bool, ) -> tuple[ - CreateFn[ChatCompletionResponse, ChatCompletionStreamResponse], + CreateFn[ChatCompletionResponse, CompletionEvent], str | None, - list[ChatMessage], + list[AssistantMessage | SystemMessage | ToolMessage | UserMessage], list[type[MistralTool]] | None, MistralCallKwargs, ]: ... @@ -72,7 +75,7 @@ def setup_call( def setup_call( *, model: str, - client: MistralClient | MistralAsyncClient | None, + client: Mistral | None, fn: Callable[..., MistralDynamicConfig | Awaitable[MistralDynamicConfig]], fn_args: dict[str, Any], dynamic_config: MistralDynamicConfig, @@ -81,10 +84,10 @@ def setup_call( call_params: MistralCallParams, extract: bool, ) -> tuple[ - CreateFn[ChatCompletionResponse, ChatCompletionStreamResponse] - | AsyncCreateFn[ChatCompletionResponse, ChatCompletionStreamResponse], + CreateFn[ChatCompletionResponse, CompletionEvent] + | AsyncCreateFn[ChatCompletionResponse, CompletionEvent], str | None, - list[ChatMessage], + list[AssistantMessage | SystemMessage | ToolMessage | UserMessage], list[type[MistralTool]] | None, MistralCallKwargs, ]: @@ -92,31 +95,31 @@ def setup_call( fn, fn_args, dynamic_config, tools, MistralTool, call_params ) call_kwargs = cast(MistralCallKwargs, base_call_kwargs) - messages = cast(list[BaseMessageParam | ChatMessage], messages) + messages = cast( + list[AssistantMessage | SystemMessage | ToolMessage | UserMessage], messages + ) messages = convert_message_params(messages) if json_mode: - call_kwargs["response_format"] = ResponseFormat( - type=ResponseFormats("json_object") - ) + call_kwargs["response_format"] = ResponseFormat(type="json_object") json_mode_content = _utils.json_mode_content( tool_types[0] if tool_types else None ) if messages[-1].role == "user": messages[-1].content += json_mode_content else: - messages.append(ChatMessage(role="user", content=json_mode_content.strip())) + messages.append(UserMessage(content=json_mode_content.strip())) call_kwargs.pop("tools", None) elif extract: assert tool_types, "At least one tool must be provided for extraction." 
- call_kwargs["tool_choice"] = cast(ToolChoice, ToolChoice.any) + call_kwargs["tool_choice"] = cast(ToolChoiceEnum, "any") call_kwargs |= {"model": model, "messages": messages} if client is None: - client = ( - MistralAsyncClient() if inspect.iscoroutinefunction(fn) else MistralClient() + client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) + if fn_is_async(fn): + create_or_stream = get_async_create_fn( + client.chat.complete_async, client.chat.stream_async ) - if isinstance(client, MistralAsyncClient): - create_or_stream = get_async_create_fn(client.chat, client.chat_stream) else: - create_or_stream = get_create_fn(client.chat, client.chat_stream) + create_or_stream = get_create_fn(client.chat.complete, client.chat.stream) return create_or_stream, prompt_template, messages, tool_types, call_kwargs diff --git a/mirascope/core/mistral/call_kwargs.py b/mirascope/core/mistral/call_kwargs.py index 69fd61374..a235c0e83 100644 --- a/mirascope/core/mistral/call_kwargs.py +++ b/mirascope/core/mistral/call_kwargs.py @@ -2,7 +2,12 @@ from typing import Any -from mistralai.models.chat_completion import ChatMessage +from mistralai.models import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) from ..base import BaseCallKwargs from .call_params import MistralCallParams @@ -10,4 +15,4 @@ class MistralCallKwargs(MistralCallParams, BaseCallKwargs[dict[str, Any]]): model: str - messages: list[ChatMessage] + messages: list[AssistantMessage | SystemMessage | ToolMessage | UserMessage] diff --git a/mirascope/core/mistral/call_params.py b/mirascope/core/mistral/call_params.py index 4481c0017..1fd0d3422 100644 --- a/mirascope/core/mistral/call_params.py +++ b/mirascope/core/mistral/call_params.py @@ -2,7 +2,7 @@ from __future__ import annotations -from mistralai.models.chat_completion import ResponseFormat, ToolChoice +from mistralai.models import ResponseFormat, ToolChoice from typing_extensions import NotRequired from ..base import BaseCallParams diff --git a/mirascope/core/mistral/call_response.py b/mirascope/core/mistral/call_response.py index 745b30e24..01218544f 100644 --- a/mirascope/core/mistral/call_response.py +++ b/mirascope/core/mistral/call_response.py @@ -5,8 +5,14 @@ from typing import Any -from mistralai.models.chat_completion import ChatCompletionResponse, ChatMessage -from mistralai.models.common import UsageInfo +from mistralai.models import ( + AssistantMessage, + ChatCompletionResponse, + SystemMessage, + ToolMessage, + UsageInfo, + UserMessage, +) from pydantic import computed_field from ..base import BaseCallResponse @@ -22,9 +28,9 @@ class MistralCallResponse( MistralTool, dict[str, Any], MistralDynamicConfig, - ChatMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, MistralCallParams, - ChatMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, ] ): """A convenience wrapper around the Mistral `ChatCompletion` response. 
@@ -96,7 +102,9 @@ def cost(self) -> float | None: @computed_field @property - def message_param(self) -> ChatMessage: + def message_param( + self, + ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: """Returns the assistants's response as a message parameter.""" return self.response.choices[0].message @@ -136,7 +144,7 @@ def tool(self) -> MistralTool | None: @classmethod def tool_message_params( cls, tools_and_outputs: list[tuple[MistralTool, str]] - ) -> list[ChatMessage]: + ) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]: """Returns the tool message parameters for tool call results. Args: @@ -147,8 +155,7 @@ def tool_message_params( The list of constructed `ChatMessage` parameters. """ return [ - ChatMessage( - role="tool", + ToolMessage( content=output, tool_call_id=tool.tool_call.id, name=tool._name(), diff --git a/mirascope/core/mistral/call_response_chunk.py b/mirascope/core/mistral/call_response_chunk.py index a496a5f95..075110583 100644 --- a/mirascope/core/mistral/call_response_chunk.py +++ b/mirascope/core/mistral/call_response_chunk.py @@ -3,15 +3,12 @@ usage docs: learn/streams.md#handling-streamed-responses """ -from mistralai.models.chat_completion import ChatCompletionStreamResponse, FinishReason -from mistralai.models.common import UsageInfo +from mistralai.models import CompletionChunk, FinishReason, UsageInfo from ..base import BaseCallResponseChunk -class MistralCallResponseChunk( - BaseCallResponseChunk[ChatCompletionStreamResponse, FinishReason] -): +class MistralCallResponseChunk(BaseCallResponseChunk[CompletionChunk, FinishReason]): """A convenience wrapper around the Mistral `ChatCompletionChunk` streamed chunks. When calling the Mistral API using a function decorated with `mistral_call` and diff --git a/mirascope/core/mistral/dynamic_config.py b/mirascope/core/mistral/dynamic_config.py index 3f7f55e07..672c2c15a 100644 --- a/mirascope/core/mistral/dynamic_config.py +++ b/mirascope/core/mistral/dynamic_config.py @@ -1,12 +1,18 @@ """This module defines the function return type for functions as LLM calls.""" -from mistralai.models.chat_completion import ChatMessage +from mistralai.models import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) from ..base import BaseDynamicConfig, BaseMessageParam from .call_params import MistralCallParams MistralDynamicConfig = BaseDynamicConfig[ - ChatMessage | BaseMessageParam, MistralCallParams + AssistantMessage | SystemMessage | ToolMessage | UserMessage | BaseMessageParam, + MistralCallParams, ] """The function return type for functions wrapped with the `mistral_call` decorator. 
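As the dynamic_config change above shows, the 1.x SDK drops the single ChatMessage class in favor of role-specific message models, which is why the four-way AssistantMessage | SystemMessage | ToolMessage | UserMessage union now appears throughout these signatures. A minimal sketch of constructing each role under 1.x, assuming mistralai>=1.0.0; the tool_call_id and content values are placeholders:

    from mistralai.models import (
        AssistantMessage,
        SystemMessage,
        ToolMessage,
        UserMessage,
    )

    # 0.x used ChatMessage(role=..., content=...) for every role; in 1.x each
    # role is its own class with the role field pre-set.
    messages = [
        SystemMessage(content="You are a librarian."),
        UserMessage(content="Recommend a fantasy book"),
        AssistantMessage(content="The Name of the Wind by Patrick Rothfuss"),
    ]

    # Tool results name the originating call explicitly, matching the
    # ToolMessage(...) construction in call_response.py above.
    tool_reply = ToolMessage(
        content='{"title": "The Name of the Wind"}',
        tool_call_id="call_123",  # placeholder id, not from a real response
        name="FormatBook",
    )
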
diff --git a/mirascope/core/mistral/stream.py b/mirascope/core/mistral/stream.py index 85cc8e17c..7fe483766 100644 --- a/mirascope/core/mistral/stream.py +++ b/mirascope/core/mistral/stream.py @@ -5,13 +5,16 @@ from typing import Any -from mistralai.models.chat_completion import ( +from mistralai.models import ( + AssistantMessage, + ChatCompletionChoice, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatMessage, FinishReason, + SystemMessage, + ToolMessage, + UsageInfo, + UserMessage, ) -from mistralai.models.common import UsageInfo from ..base.stream import BaseStream from ._utils import calculate_cost @@ -26,10 +29,10 @@ class MistralStream( BaseStream[ MistralCallResponse, MistralCallResponseChunk, - ChatMessage, - ChatMessage, - ChatMessage, - ChatMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, + AssistantMessage | SystemMessage | ToolMessage | UserMessage, MistralTool, dict[str, Any], MistralDynamicConfig, @@ -66,9 +69,9 @@ def cost(self) -> float | None: def _construct_message_param( self, tool_calls: list | None = None, content: str | None = None - ) -> ChatMessage: - message_param = ChatMessage( - role="assistant", content=content if content else "", tool_calls=tool_calls + ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + message_param = AssistantMessage( + content=content if content else "", tool_calls=tool_calls ) return message_param @@ -90,7 +93,7 @@ def construct_call_response(self) -> MistralCallResponse: completion = ChatCompletionResponse( id=self.id if self.id else "", choices=[ - ChatCompletionResponseChoice( + ChatCompletionChoice( finish_reason=self.finish_reasons[0] if self.finish_reasons else None, diff --git a/mirascope/core/mistral/tool.py b/mirascope/core/mistral/tool.py index 15d9473b7..a789c4724 100644 --- a/mirascope/core/mistral/tool.py +++ b/mirascope/core/mistral/tool.py @@ -8,7 +8,7 @@ from typing import Any import jiter -from mistralai.models.chat_completion import ToolCall +from mistralai.models import ToolCall from pydantic.json_schema import SkipJsonSchema from ..base import BaseTool diff --git a/pyproject.toml b/pyproject.toml index 786490feb..fe3865d0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ hyperdx = ["hyperdx-opentelemetry>=0.1.0,<1"] langfuse = ["langfuse>=2.30.0,<3"] litellm = ["litellm>=1.41.4,<2"] logfire = ["logfire>=0.41.0,<2"] -mistral = ["mistralai>=0.4.2,<1"] +mistral = ["mistralai>=1.0.0,<2"] openai = ["openai>=1.6.0,<2"] opentelemetry = ["opentelemetry-api>=1.22.0,<2", "opentelemetry-sdk>=1.22.0,<2"] vertex = ["google-cloud-aiplatform>=1.38.0,<2"] diff --git a/tests/core/mistral/_utils/test_convert_message_params.py b/tests/core/mistral/_utils/test_convert_message_params.py index 297836eeb..bf35f6c2b 100644 --- a/tests/core/mistral/_utils/test_convert_message_params.py +++ b/tests/core/mistral/_utils/test_convert_message_params.py @@ -1,7 +1,12 @@ """Tests the `mistral._utils.convert_message_params` function.""" import pytest -from mistralai.models.chat_completion import ChatMessage +from mistralai.models import ( + AssistantMessage, + SystemMessage, + ToolMessage, + UserMessage, +) from mirascope.core.base import AudioPart, BaseMessageParam, ImagePart, TextPart from mirascope.core.mistral._utils._convert_message_params import convert_message_params @@ -10,16 +15,18 @@ def test_convert_message_params() -> None: """Tests the 
`convert_message_params` function.""" - message_params: list[BaseMessageParam | ChatMessage] = [ - ChatMessage(role="user", content="Hello"), + message_params: list[ + BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage + ] = [ + UserMessage(content="Hello"), BaseMessageParam(role="user", content="Hello"), BaseMessageParam(role="user", content=[TextPart(type="text", text="Hello")]), ] converted_message_params = convert_message_params(message_params) assert converted_message_params == [ - ChatMessage(role="user", content="Hello"), - ChatMessage(role="user", content="Hello"), - ChatMessage(role="user", content="Hello"), + UserMessage(content="Hello"), + BaseMessageParam(role="user", content="Hello"), + BaseMessageParam(role="user", content="Hello"), ] with pytest.raises( diff --git a/tests/core/mistral/_utils/test_get_json_output.py b/tests/core/mistral/_utils/test_get_json_output.py index 9b7739fac..9b1ca8c35 100644 --- a/tests/core/mistral/_utils/test_get_json_output.py +++ b/tests/core/mistral/_utils/test_get_json_output.py @@ -1,18 +1,17 @@ """Tests the `mistral._utils.get_json_output` module.""" import pytest -from mistralai.models.chat_completion import ( +from mistralai.models import ( + AssistantMessage, + ChatCompletionChoice, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatMessage, + CompletionChunk, + CompletionResponseStreamChoice, DeltaMessage, FunctionCall, ToolCall, - ToolType, + UsageInfo, ) -from mistralai.models.common import UsageInfo from mirascope.core.mistral._utils._get_json_output import get_json_output from mirascope.core.mistral.call_response import MistralCallResponse @@ -21,21 +20,20 @@ def test_get_json_output_call_response() -> None: """Tests the `get_json_output` function with a call response.""" + tool_call = ToolCall( id="id", function=FunctionCall( name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) choices = [ - ChatCompletionResponseChoice( + ChatCompletionChoice( index=0, - message=ChatMessage( - role="assistant", content="json_output", tool_calls=[tool_call] - ), - finish_reason=None, + message=AssistantMessage(content="json_output", tool_calls=[tool_call]), + finish_reason="stop", ) ] completion = ChatCompletionResponse( @@ -82,16 +80,16 @@ def test_get_json_output_call_response_chunk() -> None: arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', name="function", ), - type=ToolType.function, + type="function", ) choices = [ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage(content="json_output", tool_calls=[tool_call]), finish_reason=None, ) ] - chunk = ChatCompletionStreamResponse( + chunk = CompletionChunk( id="id", model="mistral-large-latest", choices=choices, diff --git a/tests/core/mistral/_utils/test_handle_stream.py b/tests/core/mistral/_utils/test_handle_stream.py index 8f18e8594..46d4500c4 100644 --- a/tests/core/mistral/_utils/test_handle_stream.py +++ b/tests/core/mistral/_utils/test_handle_stream.py @@ -1,14 +1,12 @@ """Tests the `mistral._utils.handle_stream` module.""" import pytest -from mistralai.models.chat_completion import ( - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, +from mistralai.models import ( + CompletionChunk, + CompletionResponseStreamChoice, DeltaMessage, - FinishReason, FunctionCall, ToolCall, - ToolType, ) from 
mirascope.core.mistral._utils._handle_stream import ( @@ -29,7 +27,7 @@ def call(self) -> None: @pytest.fixture() -def mock_chunks() -> list[ChatCompletionStreamResponse]: +def mock_chunks() -> list[CompletionChunk]: """Returns a list of mock `ChatCompletionStreamResponse` instances.""" new_tool_call = ToolCall( @@ -38,7 +36,7 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: arguments="", name="FormatBook", ), - type=ToolType.function, + type="function", ) tool_call = ToolCall( id="null", @@ -46,13 +44,13 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', name="FormatBook", ), - type=ToolType.function, + type="function", ) return [ - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage(content="content", tool_calls=None), finish_reason=None, @@ -62,10 +60,10 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, @@ -78,10 +76,10 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, @@ -94,10 +92,10 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, @@ -110,10 +108,10 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, @@ -126,13 +124,13 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage(content=None, tool_calls=None), - finish_reason=FinishReason.tool_calls, + finish_reason="tool_calls", ) ], created=0, @@ -142,7 +140,7 @@ def mock_chunks() -> list[ChatCompletionStreamResponse]: ] -def test_handle_stream(mock_chunks: list[ChatCompletionStreamResponse]) -> None: +def test_handle_stream(mock_chunks: list[CompletionChunk]) -> None: """Tests the `handle_stream` function.""" result = list(handle_stream((c for c in mock_chunks), tool_types=[FormatBook])) @@ -166,7 +164,7 @@ def test_handle_stream(mock_chunks: list[ChatCompletionStreamResponse]) -> None: @pytest.mark.asyncio async def test_handle_stream_async( - mock_chunks: list[ChatCompletionStreamResponse], + mock_chunks: list[CompletionChunk], ) -> None: """Tests the `handle_stream_async` function.""" diff --git a/tests/core/mistral/_utils/test_setup_call.py b/tests/core/mistral/_utils/test_setup_call.py index 3be9fb7ad..a4ab0f232 100644 --- a/tests/core/mistral/_utils/test_setup_call.py 
+++ b/tests/core/mistral/_utils/test_setup_call.py @@ -3,12 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from mistralai.async_client import MistralAsyncClient -from mistralai.models.chat_completion import ( +from mistralai import Chat, Mistral +from mistralai.models import ( + AssistantMessage, ChatCompletionResponse, - ChatCompletionStreamResponse, - ChatMessage, - ToolChoice, + CompletionChunk, + UserMessage, ) from mirascope.core.mistral._utils._setup_call import setup_call @@ -24,11 +24,7 @@ def mock_base_setup_call() -> MagicMock: @patch( - "mirascope.core.mistral._utils._setup_call.MistralClient.chat_stream", - return_value=MagicMock(), -) -@patch( - "mirascope.core.mistral._utils._setup_call.MistralClient.chat", + "mirascope.core.mistral._utils._setup_call.Mistral", return_value=MagicMock(), ) @patch( @@ -39,19 +35,19 @@ def mock_base_setup_call() -> MagicMock: def test_setup_call( mock_utils: MagicMock, mock_convert_message_params: MagicMock, - mock_mistral_chat: MagicMock, - mock_mistral_chat_stream: MagicMock, + mock_mistral: MagicMock, mock_base_setup_call: MagicMock, ) -> None: """Tests the `setup_call` function.""" mock_utils.setup_call = mock_base_setup_call mock_chat_iterator = MagicMock() mock_chat_iterator.__iter__.return_value = ["chat"] - mock_mistral_chat_stream.return_value = mock_chat_iterator + mock_mistral.chat = MagicMock() + mock_mistral.chat.stream.return_value = mock_chat_iterator fn = MagicMock() create, prompt_template, messages, tool_types, call_kwargs = setup_call( model="mistral-large-latest", - client=None, + client=mock_mistral, fn=fn, fn_args={}, dynamic_config=None, @@ -70,9 +66,9 @@ def test_setup_call( ) assert messages == mock_convert_message_params.return_value create(stream=False, **call_kwargs) - mock_mistral_chat.assert_called_once_with(**call_kwargs) + mock_mistral.chat.complete.assert_called_once_with(**call_kwargs) stream = create(stream=True, **call_kwargs) - mock_mistral_chat_stream.assert_called_once_with(**call_kwargs) + mock_mistral.chat.stream.assert_called_once_with(**call_kwargs) assert next(stream) == "chat" # pyright: ignore [reportArgumentType] @@ -91,7 +87,7 @@ async def test_async_setup_call( mock_mistral_chat = AsyncMock(spec=ChatCompletionResponse) mock_mistral_chat.__name__ = "chat" - mock_stream_response = AsyncMock(spec=ChatCompletionStreamResponse) + mock_stream_response = AsyncMock(spec=CompletionChunk) mock_stream_response.text = "chat" class AsyncMockIterator: @@ -109,13 +105,16 @@ async def __anext__(self): mock_iterator = AsyncMockIterator([mock_stream_response]) - mock_client = AsyncMock(spec=MistralAsyncClient, name="mock_client") - mock_client.chat_stream.return_value = mock_iterator - mock_client.chat.return_value = mock_mistral_chat + mock_client = MagicMock(spec=Mistral, name="mock_client") + mock_client.chat = MagicMock(spec=Chat) + mock_client.chat.stream_async = AsyncMock() + mock_client.chat.stream_async.return_value = mock_iterator + mock_client.chat.complete_async = AsyncMock() + mock_client.chat.complete_async.return_value = mock_mistral_chat mock_utils.setup_call = mock_base_setup_call - fn = MagicMock() + fn = AsyncMock() create, prompt_template, messages, tool_types, call_kwargs = setup_call( model="mistral-large-latest", client=mock_client, @@ -136,7 +135,6 @@ async def __anext__(self): mock_base_setup_call.return_value[1] ) - mock_mistral_chat.return_value = MagicMock(spec=ChatCompletionResponse) chat = await create(stream=False, **call_kwargs) stream = await 
create(stream=True, **call_kwargs) result = [] @@ -162,7 +160,7 @@ def test_setup_call_json_mode( mock_json_mode_content = MagicMock() mock_json_mode_content.return_value = "\n\njson_mode_content" mock_utils.json_mode_content = mock_json_mode_content - mock_base_setup_call.return_value[1] = [ChatMessage(role="user", content="test")] + mock_base_setup_call.return_value[1] = [UserMessage(content="test")] mock_base_setup_call.return_value[-1]["tools"] = MagicMock() mock_convert_message_params.side_effect = lambda x: x _, _, messages, _, call_kwargs = setup_call( @@ -180,7 +178,7 @@ def test_setup_call_json_mode( assert "tools" not in call_kwargs mock_base_setup_call.return_value[1] = [ - ChatMessage(role="assistant", content="test"), + AssistantMessage(content="test"), ] _, _, messages, _, call_kwargs = setup_call( model="mistral-large-latest", @@ -193,7 +191,7 @@ def test_setup_call_json_mode( call_params={}, extract=False, ) - assert messages[-1] == ChatMessage(role="user", content="json_mode_content") + assert messages[-1] == UserMessage(content="json_mode_content") @patch( @@ -219,4 +217,4 @@ def test_setup_call_extract( call_params={}, extract=True, ) - assert "tool_choice" in call_kwargs and call_kwargs["tool_choice"] == ToolChoice.any + assert "tool_choice" in call_kwargs and call_kwargs["tool_choice"] == "any" diff --git a/tests/core/mistral/test_call_response.py b/tests/core/mistral/test_call_response.py index 03e904dad..9157f8535 100644 --- a/tests/core/mistral/test_call_response.py +++ b/tests/core/mistral/test_call_response.py @@ -1,15 +1,14 @@ """Tests the `mistral.call_response` module.""" -from mistralai.models.chat_completion import ( +from mistralai.models import ( + AssistantMessage, + ChatCompletionChoice, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatMessage, - FinishReason, FunctionCall, ToolCall, - ToolType, + ToolMessage, + UsageInfo, ) -from mistralai.models.common import UsageInfo from mirascope.core.mistral.call_response import MistralCallResponse from mirascope.core.mistral.tool import MistralTool @@ -18,10 +17,10 @@ def test_mistral_call_response() -> None: """Tests the `MistralCallResponse` class.""" choices = [ - ChatCompletionResponseChoice( + ChatCompletionChoice( index=0, - message=ChatMessage(role="assistant", content="content"), - finish_reason=FinishReason.stop, + message=AssistantMessage(content="content"), + finish_reason="stop", ) ] usage = UsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=2) @@ -56,9 +55,7 @@ def test_mistral_call_response() -> None: assert call_response.input_tokens == 1 assert call_response.output_tokens == 1 assert call_response.cost == 1.2e-5 - assert call_response.message_param == ChatMessage( - role="assistant", content="content" - ) + assert call_response.message_param == AssistantMessage(content="content") assert call_response.tools is None assert call_response.tool is None @@ -79,17 +76,15 @@ def call(self) -> str: name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) completion = ChatCompletionResponse( id="id", choices=[ - ChatCompletionResponseChoice( - finish_reason=FinishReason.stop, + ChatCompletionChoice( + finish_reason="stop", index=0, - message=ChatMessage( - role="assistant", content="content", tool_calls=[tool_call] - ), + message=AssistantMessage(content="content", tool_calls=[tool_call]), ) ], created=0, @@ -120,7 +115,7 @@ def call(self) -> str: output = tool.call() assert output == "The Name of 
the Wind by Patrick Rothfuss" assert call_response.tool_message_params([(tool, output)]) == [ - ChatMessage( + ToolMessage( role="tool", content=output, tool_call_id=tool_call.id, diff --git a/tests/core/mistral/test_call_response_chunk.py b/tests/core/mistral/test_call_response_chunk.py index 586bbba94..933e3928e 100644 --- a/tests/core/mistral/test_call_response_chunk.py +++ b/tests/core/mistral/test_call_response_chunk.py @@ -1,15 +1,13 @@ """Tests the `mistral.call_response_chunk` module.""" -from mistralai.models.chat_completion import ( - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, +from mistralai.models import ( + CompletionChunk, + CompletionResponseStreamChoice, DeltaMessage, - FinishReason, FunctionCall, ToolCall, - ToolType, + UsageInfo, ) -from mistralai.models.common import UsageInfo from mirascope.core.mistral.call_response_chunk import MistralCallResponseChunk @@ -19,17 +17,17 @@ def test_mistral_call_response_chunk() -> None: tool_call = ToolCall( id="id", function=FunctionCall(name="function", arguments='{"key": "value"}'), - type=ToolType.function, + type="function", ) choices = [ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage(content="content", tool_calls=[tool_call]), - finish_reason=FinishReason.stop, + finish_reason="stop", ) ] usage = UsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=2) - chunk = ChatCompletionStreamResponse( + chunk = CompletionChunk( id="id", choices=choices, created=0, @@ -49,7 +47,7 @@ def test_mistral_call_response_chunk() -> None: def test_mistral_call_response_chunk_no_choices_or_usage() -> None: """Tests the `MistralCallResponseChunk` class with None values.""" - chunk = ChatCompletionStreamResponse( + chunk = CompletionChunk( id="id", choices=[], created=0, diff --git a/tests/core/mistral/test_stream.py b/tests/core/mistral/test_stream.py index ad6e2cef1..dbfa0854c 100644 --- a/tests/core/mistral/test_stream.py +++ b/tests/core/mistral/test_stream.py @@ -1,19 +1,17 @@ """Tests the `mistral.stream` module.""" import pytest -from mistralai.models.chat_completion import ( +from mistralai import AssistantMessage +from mistralai.models import ( + ChatCompletionChoice, ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, - ChatMessage, + CompletionChunk, + CompletionResponseStreamChoice, DeltaMessage, - FinishReason, FunctionCall, ToolCall, - ToolType, + UsageInfo, ) -from mistralai.models.common import UsageInfo from mirascope.core.mistral.call_response import MistralCallResponse from mirascope.core.mistral.call_response_chunk import MistralCallResponseChunk @@ -40,14 +38,14 @@ def call(self) -> None: name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) usage = UsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=2) chunks = [ - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( delta=DeltaMessage(content="content", tool_calls=None), index=0, finish_reason=None, @@ -57,10 +55,10 @@ def call(self) -> None: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, @@ -84,7 +82,7 @@ def generator(): tool_call = 
ToolCall( id="id", function=FunctionCall(**tool_calls[0].function.model_dump()), - type=ToolType.function, + type="function", ) yield ( call_response_chunk, @@ -124,12 +122,11 @@ def generator(): name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) ) assert format_book.tool_call is not None - assert stream.message_param == ChatMessage( - role="assistant", + assert stream.message_param == AssistantMessage( content="content", tool_calls=[format_book.tool_call], ) @@ -151,14 +148,14 @@ def call(self) -> None: name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) usage = UsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=2) chunks = [ - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( delta=DeltaMessage(content="content", tool_calls=None), index=0, finish_reason=None, @@ -168,16 +165,16 @@ def call(self) -> None: model="mistral-large-latest", object="chat.completion.chunk", ), - ChatCompletionStreamResponse( + CompletionChunk( id="id", choices=[ - ChatCompletionResponseStreamChoice( + CompletionResponseStreamChoice( index=0, delta=DeltaMessage( content=None, tool_calls=[tool_call_delta], ), - finish_reason=FinishReason.stop, + finish_reason="stop", ) ], created=0, @@ -195,7 +192,7 @@ def generator(): tool_call = ToolCall( id="id", function=FunctionCall(**tool_calls[0].function.model_dump()), - type=ToolType.function, + type="function", ) yield ( call_response_chunk, @@ -227,17 +224,15 @@ def generator(): name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) completion = ChatCompletionResponse( id="id", choices=[ - ChatCompletionResponseChoice( - finish_reason=FinishReason.stop, + ChatCompletionChoice( + finish_reason="stop", index=0, - message=ChatMessage( - role="assistant", content="content", tool_calls=[tool_call] - ), + message=AssistantMessage(content="content", tool_calls=[tool_call]), ) ], created=0, diff --git a/tests/core/mistral/test_tool.py b/tests/core/mistral/test_tool.py index 26719a6f4..e9a547cf7 100644 --- a/tests/core/mistral/test_tool.py +++ b/tests/core/mistral/test_tool.py @@ -1,6 +1,6 @@ """Tests the `mistral.tool` module.""" -from mistralai.models.chat_completion import FunctionCall, ToolCall, ToolType +from mistralai.models import FunctionCall, ToolCall from mirascope.core.base.tool import BaseTool from mirascope.core.mistral.tool import MistralTool @@ -24,7 +24,7 @@ def call(self) -> str: name="FormatBook", arguments='{"title": "The Name of the Wind", "author": "Patrick Rothfuss"}', ), - type=ToolType.function, + type="function", ) tool = FormatBook.from_tool_call(tool_call) From 66023291323b87dceafc129221e84219c0882722 Mon Sep 17 00:00:00 2001 From: Skylar Payne Date: Wed, 9 Oct 2024 15:21:25 -0700 Subject: [PATCH 2/4] Address PR feedback. 
Largely about consistent / correct type annotations --- .../mistral/_utils/_convert_message_params.py | 30 +++++++++++++------ mirascope/core/mistral/_utils/_setup_call.py | 5 ++-- mirascope/core/mistral/call_response.py | 10 +++---- mirascope/core/mistral/stream.py | 8 ++--- .../_utils/test_convert_message_params.py | 4 +-- 5 files changed, 34 insertions(+), 23 deletions(-) diff --git a/mirascope/core/mistral/_utils/_convert_message_params.py b/mirascope/core/mistral/_utils/_convert_message_params.py index 2b3065850..00e2740cf 100644 --- a/mirascope/core/mistral/_utils/_convert_message_params.py +++ b/mirascope/core/mistral/_utils/_convert_message_params.py @@ -1,5 +1,7 @@ """Utility for converting `BaseMessageParam` to `ChatMessage`.""" +from typing import Any + from mistralai.models import ( AssistantMessage, SystemMessage, @@ -10,26 +12,36 @@ from ...base import BaseMessageParam +def make_message( + role: str, + **kwargs, # noqa: ANN003 +) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + if role == "assistant": + return AssistantMessage(**kwargs) + elif role == "system": + return SystemMessage(**kwargs) + elif role == "tool": + return ToolMessage(**kwargs) + elif role == "user": + return UserMessage(**kwargs) + raise ValueError(f"Invalid role: {role}") + + def convert_message_params( message_params: list[ BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage ], -) -> list[BaseMessageParam]: +) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]: converted_message_params = [] for message_param in message_params: - if not isinstance( - message_param, - BaseMessageParam, - ): + if not isinstance(message_param, BaseMessageParam): converted_message_params.append(message_param) elif isinstance(content := message_param.content, str): - converted_message_params.append( - BaseMessageParam(**message_param.model_dump()) - ) + converted_message_params.append(make_message(**message_param.model_dump())) else: if len(content) != 1 or content[0].type != "text": raise ValueError("Mistral currently only supports text parts.") converted_message_params.append( - BaseMessageParam(role=message_param.role, content=content[0].text) + make_message(role=message_param.role, content=content[0].text) ) return converted_message_params diff --git a/mirascope/core/mistral/_utils/_setup_call.py b/mirascope/core/mistral/_utils/_setup_call.py index 10c071fe5..c8bb09689 100644 --- a/mirascope/core/mistral/_utils/_setup_call.py +++ b/mirascope/core/mistral/_utils/_setup_call.py @@ -19,10 +19,9 @@ UserMessage, ) -from mirascope.core.base._utils._protocols import fn_is_async - -from ...base import BaseMessageParam, BaseTool, _utils +from ...base import BaseTool, _utils from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn +from ...base._utils._protocols import fn_is_async from ..call_kwargs import MistralCallKwargs from ..call_params import MistralCallParams from ..dynamic_config import MistralDynamicConfig diff --git a/mirascope/core/mistral/call_response.py b/mirascope/core/mistral/call_response.py index 01218544f..dfe70dbec 100644 --- a/mirascope/core/mistral/call_response.py +++ b/mirascope/core/mistral/call_response.py @@ -3,7 +3,7 @@ usage docs: learn/calls.md#handling-responses """ -from typing import Any +from typing import Any, cast from mistralai.models import ( AssistantMessage, @@ -30,7 +30,7 @@ class MistralCallResponse( MistralDynamicConfig, AssistantMessage | SystemMessage | ToolMessage | UserMessage, 
MistralCallParams, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, + UserMessage, ] ): """A convenience wrapper around the Mistral `ChatCompletion` response. @@ -104,9 +104,9 @@ def cost(self) -> float | None: @property def message_param( self, - ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + ) -> AssistantMessage: """Returns the assistants's response as a message parameter.""" - return self.response.choices[0].message + return cast(AssistantMessage, self.response.choices[0].message) @computed_field @property @@ -144,7 +144,7 @@ def tool(self) -> MistralTool | None: @classmethod def tool_message_params( cls, tools_and_outputs: list[tuple[MistralTool, str]] - ) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]: + ) -> list[ToolMessage]: """Returns the tool message parameters for tool call results. Args: diff --git a/mirascope/core/mistral/stream.py b/mirascope/core/mistral/stream.py index 7fe483766..d2f0af846 100644 --- a/mirascope/core/mistral/stream.py +++ b/mirascope/core/mistral/stream.py @@ -29,9 +29,9 @@ class MistralStream( BaseStream[ MistralCallResponse, MistralCallResponseChunk, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, - AssistantMessage | SystemMessage | ToolMessage | UserMessage, + UserMessage, + AssistantMessage, + ToolMessage, AssistantMessage | SystemMessage | ToolMessage | UserMessage, MistralTool, dict[str, Any], @@ -69,7 +69,7 @@ def cost(self) -> float | None: def _construct_message_param( self, tool_calls: list | None = None, content: str | None = None - ) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage: + ) -> AssistantMessage: message_param = AssistantMessage( content=content if content else "", tool_calls=tool_calls ) diff --git a/tests/core/mistral/_utils/test_convert_message_params.py b/tests/core/mistral/_utils/test_convert_message_params.py index bf35f6c2b..eefbd0f49 100644 --- a/tests/core/mistral/_utils/test_convert_message_params.py +++ b/tests/core/mistral/_utils/test_convert_message_params.py @@ -25,8 +25,8 @@ def test_convert_message_params() -> None: converted_message_params = convert_message_params(message_params) assert converted_message_params == [ UserMessage(content="Hello"), - BaseMessageParam(role="user", content="Hello"), - BaseMessageParam(role="user", content="Hello"), + UserMessage(role="user", content="Hello"), + UserMessage(role="user", content="Hello"), ] with pytest.raises( From 63bcfb872b42ee2612c812bf2ed448a1a00cf3e1 Mon Sep 17 00:00:00 2001 From: Skylar Payne Date: Wed, 9 Oct 2024 15:31:37 -0700 Subject: [PATCH 3/4] Add a util for loading Mistral API key --- .../learn/calls/basic_usage/mistral/official_sdk_call.py | 4 ++-- .../calls/custom_client/mistral/base_message_param.py | 6 ++---- examples/learn/calls/custom_client/mistral/messages.py | 6 ++---- examples/learn/calls/custom_client/mistral/shorthand.py | 6 ++---- .../learn/calls/custom_client/mistral/string_template.py | 6 ++---- .../response_models/basic_usage/mistral/official_sdk.py | 7 +++---- mirascope/core/mistral/__init__.py | 2 ++ mirascope/core/mistral/_utils/_load_api_key.py | 6 ++++++ mirascope/core/mistral/_utils/_setup_call.py | 4 ++-- 9 files changed, 23 insertions(+), 24 deletions(-) create mode 100644 mirascope/core/mistral/_utils/_load_api_key.py diff --git a/examples/learn/calls/basic_usage/mistral/official_sdk_call.py b/examples/learn/calls/basic_usage/mistral/official_sdk_call.py index b14fb8227..bf6495a17 
100644 --- a/examples/learn/calls/basic_usage/mistral/official_sdk_call.py +++ b/examples/learn/calls/basic_usage/mistral/official_sdk_call.py @@ -1,7 +1,7 @@ +from mirascope.core import mistral from mistralai import Mistral -import os -client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) +client = Mistral(api_key=mistral.load_api_key()) def recommend_book(genre: str) -> str: diff --git a/examples/learn/calls/custom_client/mistral/base_message_param.py b/examples/learn/calls/custom_client/mistral/base_message_param.py index 01fa5181c..8d9c5b5d0 100644 --- a/examples/learn/calls/custom_client/mistral/base_message_param.py +++ b/examples/learn/calls/custom_client/mistral/base_message_param.py @@ -1,12 +1,10 @@ -import os - from mirascope.core import BaseMessageParam, mistral from mistralai import Mistral @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) def recommend_book(genre: str) -> list[BaseMessageParam]: return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")] @@ -14,7 +12,7 @@ def recommend_book(genre: str) -> list[BaseMessageParam]: @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) async def recommend_book_async(genre: str) -> list[BaseMessageParam]: return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")] diff --git a/examples/learn/calls/custom_client/mistral/messages.py b/examples/learn/calls/custom_client/mistral/messages.py index 569ca229a..818ed277d 100644 --- a/examples/learn/calls/custom_client/mistral/messages.py +++ b/examples/learn/calls/custom_client/mistral/messages.py @@ -1,12 +1,10 @@ -import os - from mirascope.core import Messages, mistral from mistralai import Mistral @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) def recommend_book(genre: str) -> Messages.Type: return Messages.User(f"Recommend a {genre} book") @@ -14,7 +12,7 @@ def recommend_book(genre: str) -> Messages.Type: @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) async def recommend_book_async(genre: str) -> Messages.Type: return Messages.User(f"Recommend a {genre} book") diff --git a/examples/learn/calls/custom_client/mistral/shorthand.py b/examples/learn/calls/custom_client/mistral/shorthand.py index 38571b32b..f61346ab2 100644 --- a/examples/learn/calls/custom_client/mistral/shorthand.py +++ b/examples/learn/calls/custom_client/mistral/shorthand.py @@ -1,12 +1,10 @@ -import os - from mirascope.core import mistral from mistralai import Mistral @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) def recommend_book(genre: str) -> str: return f"Recommend a {genre} book" @@ -14,7 +12,7 @@ def recommend_book(genre: str) -> str: @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) async def recommend_book_async(genre: str) -> str: return f"Recommend a {genre} book" diff --git a/examples/learn/calls/custom_client/mistral/string_template.py b/examples/learn/calls/custom_client/mistral/string_template.py index 8743a1a60..0fcb2904b 100644 
--- a/examples/learn/calls/custom_client/mistral/string_template.py +++ b/examples/learn/calls/custom_client/mistral/string_template.py @@ -1,12 +1,10 @@ -import os - from mirascope.core import mistral, prompt_template from mistralai import Mistral @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) @prompt_template("Recommend a {genre} book") def recommend_book(genre: str): ... @@ -14,7 +12,7 @@ def recommend_book(genre: str): ... @mistral.call( "mistral-large-latest", - client=Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")), + client=Mistral(api_key=mistral.load_api_key()), ) @prompt_template("Recommend a {genre} book") async def recommend_book_async(genre: str): ... diff --git a/examples/learn/response_models/basic_usage/mistral/official_sdk.py b/examples/learn/response_models/basic_usage/mistral/official_sdk.py index 7131d46dd..102c95139 100644 --- a/examples/learn/response_models/basic_usage/mistral/official_sdk.py +++ b/examples/learn/response_models/basic_usage/mistral/official_sdk.py @@ -1,10 +1,9 @@ -import os - -from mistralai.client import Mistral +from mirascope.core import mistral +from mistralai import Mistral from mistralai.models import ToolChoice from pydantic import BaseModel -client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) +client = Mistral(api_key=mistral.load_api_key()) class Book(BaseModel): diff --git a/mirascope/core/mistral/__init__.py b/mirascope/core/mistral/__init__.py index 39ce1b4ed..b08cd7293 100644 --- a/mirascope/core/mistral/__init__.py +++ b/mirascope/core/mistral/__init__.py @@ -12,6 +12,7 @@ from ..base import BaseMessageParam from ._call import mistral_call from ._call import mistral_call as call +from ._utils._load_api_key import load_api_key from .call_params import MistralCallParams from .call_response import MistralCallResponse from .call_response_chunk import MistralCallResponseChunk @@ -25,6 +26,7 @@ __all__ = [ "call", + "load_api_key", "MistralDynamicConfig", "MistralCallParams", "MistralCallResponse", diff --git a/mirascope/core/mistral/_utils/_load_api_key.py b/mirascope/core/mistral/_utils/_load_api_key.py new file mode 100644 index 000000000..f06f3bcc4 --- /dev/null +++ b/mirascope/core/mistral/_utils/_load_api_key.py @@ -0,0 +1,6 @@ +import os + + +def load_api_key() -> str: + """Load the API key from the standard environment variable.""" + return os.environ.get("MISTRAL_API_KEY", "") diff --git a/mirascope/core/mistral/_utils/_setup_call.py b/mirascope/core/mistral/_utils/_setup_call.py index c8bb09689..b8d96cf84 100644 --- a/mirascope/core/mistral/_utils/_setup_call.py +++ b/mirascope/core/mistral/_utils/_setup_call.py @@ -1,6 +1,5 @@ """This module contains the setup_call function for Mistral tools.""" -import os from collections.abc import ( Awaitable, Callable, @@ -19,6 +18,7 @@ UserMessage, ) +from ... 
import mistral from ...base import BaseTool, _utils from ...base._utils import AsyncCreateFn, CreateFn, get_async_create_fn, get_create_fn from ...base._utils._protocols import fn_is_async @@ -114,7 +114,7 @@ def setup_call( call_kwargs |= {"model": model, "messages": messages} if client is None: - client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY", "")) + client = Mistral(api_key=mistral.load_api_key()) if fn_is_async(fn): create_or_stream = get_async_create_fn( client.chat.complete_async, client.chat.stream_async From c3d61d03677d05191a0eb29820bd95c0543379f7 Mon Sep 17 00:00:00 2001 From: Skylar Payne Date: Tue, 29 Oct 2024 13:48:51 -0700 Subject: [PATCH 4/4] fix type issues --- mirascope/core/mistral/call_response.py | 13 +++++++++---- mirascope/core/mistral/call_response_chunk.py | 7 ++++++- mirascope/core/mistral/stream.py | 7 +++---- mirascope/core/mistral/tool.py | 4 +++- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mirascope/core/mistral/call_response.py b/mirascope/core/mistral/call_response.py index dfe70dbec..5f3b2fd83 100644 --- a/mirascope/core/mistral/call_response.py +++ b/mirascope/core/mistral/call_response.py @@ -5,6 +5,7 @@ from typing import Any, cast +from mistralai import ChatCompletionChoice from mistralai.models import ( AssistantMessage, ChatCompletionResponse, @@ -57,17 +58,21 @@ def recommend_book(genre: str) -> str: _provider = "mistral" + @property + def _response_choices(self) -> list[ChatCompletionChoice]: + return self.response.choices or [] + @property def content(self) -> str: """The content of the chat completion for the 0th choice.""" - return self.response.choices[0].message.content + return self._response_choices[0].message.content or "" @property def finish_reasons(self) -> list[str]: """Returns the finish reasons of the response.""" return [ choice.finish_reason if choice.finish_reason else "" - for choice in self.response.choices + for choice in self._response_choices ] @property @@ -106,7 +111,7 @@ def message_param( self, ) -> AssistantMessage: """Returns the assistants's response as a message parameter.""" - return cast(AssistantMessage, self.response.choices[0].message) + return cast(AssistantMessage, self._response_choices[0].message) @computed_field @property @@ -116,7 +121,7 @@ def tools(self) -> list[MistralTool] | None: Raises: ValidationError: if the tool call doesn't match the tool's schema. 
""" - tool_calls = self.response.choices[0].message.tool_calls + tool_calls = self._response_choices[0].message.tool_calls if not self.tool_types or not tool_calls: return None diff --git a/mirascope/core/mistral/call_response_chunk.py b/mirascope/core/mistral/call_response_chunk.py index 075110583..2c7b6bff5 100644 --- a/mirascope/core/mistral/call_response_chunk.py +++ b/mirascope/core/mistral/call_response_chunk.py @@ -3,6 +3,8 @@ usage docs: learn/streams.md#handling-streamed-responses """ +from typing import cast + from mistralai.models import CompletionChunk, FinishReason, UsageInfo from ..base import BaseCallResponseChunk @@ -39,7 +41,10 @@ def content(self) -> str: delta = None if self.chunk.choices: delta = self.chunk.choices[0].delta - return delta.content if delta is not None and delta.content is not None else "" + + if delta is not None and delta.content is not None: + return cast(str, delta.content) + return "" @property def finish_reasons(self) -> list[FinishReason]: diff --git a/mirascope/core/mistral/stream.py b/mirascope/core/mistral/stream.py index d2f0af846..ff3407b86 100644 --- a/mirascope/core/mistral/stream.py +++ b/mirascope/core/mistral/stream.py @@ -3,7 +3,7 @@ usage docs: learn/streams.md """ -from typing import Any +from typing import Any, cast from mistralai.models import ( AssistantMessage, @@ -90,13 +90,12 @@ def construct_call_response(self) -> MistralCallResponse: completion_tokens=int(self.output_tokens or 0), total_tokens=int(self.input_tokens or 0) + int(self.output_tokens or 0), ) + finish_reason = cast(FinishReason, (self.finish_reasons or [])[0]) completion = ChatCompletionResponse( id=self.id if self.id else "", choices=[ ChatCompletionChoice( - finish_reason=self.finish_reasons[0] - if self.finish_reasons - else None, + finish_reason=finish_reason, index=0, message=self.message_param, ) diff --git a/mirascope/core/mistral/tool.py b/mirascope/core/mistral/tool.py index a789c4724..dec031008 100644 --- a/mirascope/core/mistral/tool.py +++ b/mirascope/core/mistral/tool.py @@ -72,6 +72,8 @@ def from_tool_call(cls, tool_call: ToolCall) -> MistralTool: Args: tool_call: The Mistral tool call from which to construct this tool instance. """ - model_json = jiter.from_json(tool_call.function.arguments.encode()) + model_json = tool_call.function.arguments + if isinstance(model_json, str): + model_json = jiter.from_json(model_json.encode()) model_json["tool_call"] = tool_call.model_dump() return cls.model_validate(model_json)