Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump mistralai to > 1.0.0 in preparation for latest models such as Pixtral #531

Open
wants to merge 4 commits into
base: release/v1.10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mistralai.client import MistralClient
from mirascope.core import mistral
from mistralai import Mistral

client = MistralClient()
client = Mistral(api_key=mistral.load_api_key())


def recommend_book(genre: str) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from mirascope.core import BaseMessageParam, mistral
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
def recommend_book(genre: str) -> list[BaseMessageParam]:
return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")]


@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
async def recommend_book_async(genre: str) -> list[BaseMessageParam]:
return [BaseMessageParam(role="user", content=f"Recommend a {genre} book")]
15 changes: 13 additions & 2 deletions examples/learn/calls/custom_client/mistral/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from mirascope.core import Messages, mistral
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
def recommend_book(genre: str) -> Messages.Type:
return Messages.User(f"Recommend a {genre} book")


@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
async def recommend_book_async(genre: str) -> Messages.Type:
return Messages.User(f"Recommend a {genre} book")
15 changes: 13 additions & 2 deletions examples/learn/calls/custom_client/mistral/shorthand.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from mirascope.core import mistral
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
def recommend_book(genre: str) -> str:
return f"Recommend a {genre} book"


@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
async def recommend_book_async(genre: str) -> str:
return f"Recommend a {genre} book"
15 changes: 13 additions & 2 deletions examples/learn/calls/custom_client/mistral/string_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from mirascope.core import mistral, prompt_template
from mistralai.client import MistralClient
from mistralai import Mistral


@mistral.call("mistral-large-latest", client=MistralClient())
@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
@prompt_template("Recommend a {genre} book")
def recommend_book(genre: str): ...


@mistral.call(
"mistral-large-latest",
client=Mistral(api_key=mistral.load_api_key()),
)
@prompt_template("Recommend a {genre} book")
async def recommend_book_async(genre: str): ...
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ToolChoice
from mirascope.core import mistral
from mistralai import Mistral
from mistralai.models import ToolChoice
from pydantic import BaseModel

client = MistralClient()
client = Mistral(api_key=mistral.load_api_key())


class Book(BaseModel):
Expand Down
13 changes: 11 additions & 2 deletions mirascope/core/mistral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,31 @@

from typing import TypeAlias

from mistralai.models.chat_completion import ChatMessage
from mistralai.models import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)

from ..base import BaseMessageParam
from ._call import mistral_call
from ._call import mistral_call as call
from ._utils._load_api_key import load_api_key
from .call_params import MistralCallParams
from .call_response import MistralCallResponse
from .call_response_chunk import MistralCallResponseChunk
from .dynamic_config import MistralDynamicConfig
from .stream import MistralStream
from .tool import MistralTool

MistralMessageParam: TypeAlias = ChatMessage | BaseMessageParam
MistralMessageParam: TypeAlias = (
AssistantMessage | SystemMessage | ToolMessage | UserMessage | BaseMessageParam
)

__all__ = [
"call",
"load_api_key",
"MistralDynamicConfig",
"MistralCallParams",
"MistralCallResponse",
Expand Down
36 changes: 30 additions & 6 deletions mirascope/core/mistral/_utils/_convert_message_params.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,47 @@
"""Utility for converting `BaseMessageParam` to `ChatMessage`."""

from mistralai.models.chat_completion import ChatMessage
from typing import Any

from mistralai.models import (
AssistantMessage,
SystemMessage,
ToolMessage,
UserMessage,
)

from ...base import BaseMessageParam


def make_message(
role: str,
**kwargs, # noqa: ANN003
) -> AssistantMessage | SystemMessage | ToolMessage | UserMessage:
if role == "assistant":
return AssistantMessage(**kwargs)
elif role == "system":
return SystemMessage(**kwargs)
elif role == "tool":
return ToolMessage(**kwargs)
elif role == "user":
return UserMessage(**kwargs)
raise ValueError(f"Invalid role: {role}")


def convert_message_params(
message_params: list[BaseMessageParam | ChatMessage],
) -> list[ChatMessage]:
message_params: list[
BaseMessageParam | AssistantMessage | SystemMessage | ToolMessage | UserMessage
],
) -> list[AssistantMessage | SystemMessage | ToolMessage | UserMessage]:
converted_message_params = []
for message_param in message_params:
if isinstance(message_param, ChatMessage):
if not isinstance(message_param, BaseMessageParam):
converted_message_params.append(message_param)
elif isinstance(content := message_param.content, str):
converted_message_params.append(ChatMessage(**message_param.model_dump()))
converted_message_params.append(make_message(**message_param.model_dump()))
else:
if len(content) != 1 or content[0].type != "text":
raise ValueError("Mistral currently only supports text parts.")
converted_message_params.append(
ChatMessage(role=message_param.role, content=content[0].text)
make_message(role=message_param.role, content=content[0].text)
)
return converted_message_params
17 changes: 8 additions & 9 deletions mirascope/core/mistral/_utils/_handle_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@

from collections.abc import AsyncGenerator, Generator

from mistralai.models.chat_completion import (
ChatCompletionStreamResponse,
from mistralai.models import (
CompletionEvent,
FunctionCall,
ToolCall,
ToolType,
)

from ..call_response_chunk import MistralCallResponseChunk
from ..tool import MistralTool


def _handle_chunk(
chunk: ChatCompletionStreamResponse,
chunk: CompletionEvent,
current_tool_call: ToolCall,
current_tool_type: type[MistralTool] | None,
tool_types: list[type[MistralTool]] | None,
Expand All @@ -38,7 +37,7 @@ def _handle_chunk(
arguments="",
name=tool_call.function.name if tool_call.function.name else "",
),
type=ToolType.function,
type="function",
)
current_tool_type = None
for tool_type in tool_types:
Expand All @@ -64,12 +63,12 @@ def _handle_chunk(


def handle_stream(
stream: Generator[ChatCompletionStreamResponse, None, None],
stream: Generator[CompletionEvent, None, None],
tool_types: list[type[MistralTool]] | None,
) -> Generator[tuple[MistralCallResponseChunk, MistralTool | None], None, None]:
"""Iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ToolCall(
id="", function=FunctionCall(arguments="", name=""), type=ToolType.function
id="", function=FunctionCall(arguments="", name=""), type="function"
)
current_tool_type = None
for chunk in stream:
Expand All @@ -93,12 +92,12 @@ def handle_stream(


async def handle_stream_async(
stream: AsyncGenerator[ChatCompletionStreamResponse, None],
stream: AsyncGenerator[CompletionEvent, None],
tool_types: list[type[MistralTool]] | None,
) -> AsyncGenerator[tuple[MistralCallResponseChunk, MistralTool | None], None]:
"""Async iterator over the stream and constructs tools as they are streamed."""
current_tool_call = ToolCall(
id="", function=FunctionCall(arguments="", name=""), type=ToolType.function
id="", function=FunctionCall(arguments="", name=""), type="function"
)
current_tool_type = None
async for chunk in stream:
Expand Down
6 changes: 6 additions & 0 deletions mirascope/core/mistral/_utils/_load_api_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os


def load_api_key() -> str:
"""Load the API key from the standard environment variable."""
return os.environ.get("MISTRAL_API_KEY", "")
Loading