diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index c27327fb2..b60f75264 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -561,6 +561,7 @@ mod tests { use crate::MessageContent; use either::Either; use indexmap::IndexMap; + use serde_json::Value; macro_rules! hashmap { (@single $($x:tt)*) => (()); @@ -572,7 +573,7 @@ mod tests { let _cap = hashmap!(@count $($key),*); let mut _map = ::indexmap::IndexMap::with_capacity(_cap); $( - let _ = _map.insert($key, $value); + let _ = _map.insert($key, Value::String($value)); )* _map } @@ -677,7 +678,7 @@ mod tests { ]; let mut inputs = Vec::new(); for [role, content] in messages { - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left(role.to_string())); message.insert("content".to_string(), Either::Left(content.to_string())); @@ -711,7 +712,7 @@ mod tests { let mut inputs = Vec::new(); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("system".to_string())); message.insert( @@ -723,7 +724,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( @@ -740,7 +741,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("assistant".to_string())); message.insert( @@ -752,7 +753,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( @@ -769,7 +770,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("assistant".to_string())); message.insert( @@ -781,7 +782,7 @@ mod tests { ); inputs.push(message); - let mut message: IndexMap>>> = + let mut message: IndexMap>>> = IndexMap::new(); message.insert("role".to_string(), Either::Left("user".to_string())); message.insert( diff --git a/mistralrs-core/src/pipeline/processing.rs b/mistralrs-core/src/pipeline/processing.rs index e1034b84a..bb19e3f46 100644 --- a/mistralrs-core/src/pipeline/processing.rs +++ b/mistralrs-core/src/pipeline/processing.rs @@ -84,8 +84,13 @@ pub(crate) fn apply_chat_template( 'outer: for content_row in rv { for (content_k, content_v) in content_row { if content_k == "text" { - new_message.insert(k, Either::Left(content_v)); - break 'outer; + if let Some(content_str) = content_v.as_str() { + new_message.insert( + k, + Either::Left(content_str.to_string()), + ); + break 'outer; + } } } } @@ -149,6 +154,6 @@ impl Processor for BasicProcessor { &[] } fn template_action(&self) -> MessagesAction { - MessagesAction::FlattenOnlyText + MessagesAction::Keep } } diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index bbd75dd7c..716a93147 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -2,6 +2,7 @@ use either::Either; use indexmap::IndexMap; use mistralrs_quant::IsqType; use serde::{Deserialize, Serialize}; +use serde_json::Value; use crate::{ response::Response, @@ -28,7 +29,7 @@ pub enum ImageGenerationResponseFormat { B64Json, } -pub type MessageContent = Either>>; +pub type MessageContent = Either>>; #[derive(Clone, Debug)] /// Message or messages for a [`Request`]. diff --git a/mistralrs-core/src/tools/response.rs b/mistralrs-core/src/tools/response.rs index 4eecd5caf..1e082a900 100644 --- a/mistralrs-core/src/tools/response.rs +++ b/mistralrs-core/src/tools/response.rs @@ -6,6 +6,14 @@ pub enum ToolCallType { Function, } +impl std::fmt::Display for ToolCallType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ToolCallType::Function => write!(f, "function"), + } + } +} + #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)] #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 4d86714be..aea3f8ec8 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -5,6 +5,7 @@ use anymoe::{AnyMoeConfig, AnyMoeExpertType}; use either::Either; use indexmap::IndexMap; use requests::{ChatCompletionRequest, CompletionRequest, ToolChoice}; +use serde_json::Value; use std::{ cell::RefCell, collections::HashMap, @@ -683,7 +684,7 @@ impl Runner { Either::Left(content) => { let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert( "role".to_string(), @@ -760,7 +761,7 @@ impl Runner { } let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert( "role".to_string(), @@ -774,11 +775,13 @@ impl Runner { let mut content_map = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), "image".to_string()); + content_image_map + .insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), "text".to_string()); - content_text_map.insert("text".to_string(), content); + content_text_map + .insert("type".to_string(), Value::String("text".to_string())); + content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map); message_map @@ -808,7 +811,7 @@ impl Runner { let mut messages = Vec::new(); let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left("user".to_string())); message_map.insert("content".to_string(), Either::Left(prompt.to_string())); diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index 5ba1745de..e10092e68 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -1,3 +1,4 @@ +use serde_json::Value; use std::{ collections::HashMap, env, @@ -173,7 +174,7 @@ async fn parse_request( Either::Left(content) => { let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left(message.role)); message_map @@ -234,7 +235,7 @@ async fn parse_request( } let mut message_map: IndexMap< String, - Either>>, + Either>>, > = IndexMap::new(); message_map.insert("role".to_string(), Either::Left(message.role)); let (content, url) = if items[0] == "text" { @@ -243,13 +244,15 @@ async fn parse_request( get_content_and_url(1, 0, image_messages)? }; - let mut content_map = Vec::new(); + let mut content_map: Vec> = Vec::new(); let mut content_image_map = IndexMap::new(); - content_image_map.insert("type".to_string(), "image".to_string()); + content_image_map + .insert("type".to_string(), Value::String("image".to_string())); content_map.push(content_image_map); let mut content_text_map = IndexMap::new(); - content_text_map.insert("type".to_string(), "text".to_string()); - content_text_map.insert("text".to_string(), content); + content_text_map + .insert("type".to_string(), Value::String("text".to_string())); + content_text_map.insert("text".to_string(), Value::String(content)); content_map.push(content_text_map); message_map.insert("content".to_string(), Either::Right(content_map)); @@ -276,7 +279,7 @@ async fn parse_request( } Either::Right(prompt) => { let mut messages = Vec::new(); - let mut message_map: IndexMap>>> = + let mut message_map: IndexMap>>> = IndexMap::new(); message_map.insert("role".to_string(), Either::Left("user".to_string())); message_map.insert("content".to_string(), Either::Left(prompt)); diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index 43d8e364e..2e6b5300d 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -7,6 +7,7 @@ use mistralrs_core::{ }; use once_cell::sync::Lazy; use regex::Regex; +use serde_json::Value; use std::{ io::{self, Write}, sync::{atomic::Ordering, Arc, Mutex}, @@ -221,7 +222,7 @@ async fn text_interactive_mode(mistralrs: Arc, throughput: bool) { println!(); info!("Average T/s: {}", toks as f64 / time); } - let mut assistant_message: IndexMap>>> = + let mut assistant_message: IndexMap>>> = IndexMap::new(); assistant_message.insert("role".to_string(), Either::Left("assistant".to_string())); assistant_message.insert("content".to_string(), Either::Left(assistant_output)); @@ -431,7 +432,7 @@ async fn vision_interactive_mode(mistralrs: Arc, throughput: bool) { println!(); info!("Average T/s: {}", toks as f64 / time); } - let mut assistant_message: IndexMap>>> = + let mut assistant_message: IndexMap>>> = IndexMap::new(); assistant_message.insert("role".to_string(), Either::Left("assistant".to_string())); assistant_message.insert("content".to_string(), Either::Left(assistant_output)); diff --git a/mistralrs/examples/lower_level/tools/main.rs b/mistralrs/examples/lower_level/tools/main.rs index 6e532b3fc..4fc8b1b08 100644 --- a/mistralrs/examples/lower_level/tools/main.rs +++ b/mistralrs/examples/lower_level/tools/main.rs @@ -126,20 +126,26 @@ fn main() -> anyhow::Result<()> { // Add tool call message from assistant so it knows what it called messages.push(IndexMap::from([ ("role".to_string(), Either::Left("assistant".to_string())), + ("content".to_string(), Either::Left("".to_string())), ( - "content".to_string(), - Either::Left( - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), - ), + "tool_calls".to_string(), + Either::Right(Vec![IndexMap::from([ + ("id".to_string(), tool_call.id), + ( + "function".to_string(), + json!({ + "name": called.function.name, + "arguments": called.function.arguments, + }) + ), + ("type".to_string(), "function".to_string()), + ])]), ), ])); // Add message from the tool messages.push(IndexMap::from([ ("role".to_string(), Either::Left("tool".to_string())), + ("tool_call_id".to_string(), Either::Left(tool_call.id)), ("content".to_string(), Either::Left(result)), ])); diff --git a/mistralrs/examples/tools/main.rs b/mistralrs/examples/tools/main.rs index 10b102047..02c002ca0 100644 --- a/mistralrs/examples/tools/main.rs +++ b/mistralrs/examples/tools/main.rs @@ -66,15 +66,12 @@ async fn main() -> Result<()> { // Add tool call message from assistant so it knows what it called // Then, add message from the tool messages = messages - .add_message( + .add_message_with_tool_call( TextMessageRole::Assistant, - json!({ - "name": called.function.name, - "parameters": called.function.arguments, - }) - .to_string(), + String::new(), + vec![called.clone()], ) - .add_message(TextMessageRole::Tool, result) + .add_tool_message(result, called.id.clone()) .set_tool_choice(ToolChoice::None); let response = model.send_chat_request(messages.clone()).await?; diff --git a/mistralrs/src/messages.rs b/mistralrs/src/messages.rs index e141c03a4..7aa512211 100644 --- a/mistralrs/src/messages.rs +++ b/mistralrs/src/messages.rs @@ -4,6 +4,7 @@ use super::*; use either::Either; use image::DynamicImage; use indexmap::IndexMap; +use serde_json::{json, Value}; /// A type which can be used as a request. pub trait RequestLike { @@ -207,10 +208,10 @@ impl VisionMessages { ( "content".to_string(), Either::Right(vec![ - IndexMap::from([("type".to_string(), "image".to_string())]), + IndexMap::from([("type".to_string(), Value::String("image".to_string()))]), IndexMap::from([ - ("type".to_string(), "text".to_string()), - ("content".to_string(), text.to_string()), + ("type".to_string(), Value::String("text".to_string())), + ("content".to_string(), Value::String(text.to_string())), ]), ]), ), @@ -361,6 +362,10 @@ impl RequestBuilder { } } + /// Add a message to the request. + /// + /// For messages with tool calls, use [`Self::add_message_with_tool_call`]. + /// For messages with tool outputs, use [`Self::add_tool_message`]. pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self { self.messages.push(IndexMap::from([ ("role".to_string(), Either::Left(role.to_string())), @@ -369,6 +374,55 @@ impl RequestBuilder { self } + /// Add a message with the output of a tool call. + pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self { + self.messages.push(IndexMap::from([ + ( + "role".to_string(), + Either::Left(TextMessageRole::Tool.to_string()), + ), + ( + "content".to_string(), + Either::Left(tool_content.to_string()), + ), + ( + "tool_call_id".to_string(), + Either::Left(tool_id.to_string()), + ), + ])); + self + } + + pub fn add_message_with_tool_call( + mut self, + role: TextMessageRole, + text: impl ToString, + tool_calls: Vec, + ) -> Self { + let tool_messages = tool_calls + .iter() + .map(|t| { + IndexMap::from([ + ("id".to_string(), Value::String(t.id.clone())), + ("type".to_string(), Value::String(t.tp.to_string())), + ( + "function".to_string(), + json!({ + "name": t.function.name, + "arguments": t.function.arguments, + }), + ), + ]) + }) + .collect(); + self.messages.push(IndexMap::from([ + ("role".to_string(), Either::Left(role.to_string())), + ("content".to_string(), Either::Left(text.to_string())), + ("function".to_string(), Either::Right(tool_messages)), + ])); + self + } + pub fn add_image_message( mut self, role: TextMessageRole,