diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 2bab6a9624f314..6b3f38baa24288 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -48,6 +48,7 @@ pub enum AssistantProviderContentV1 {
     Ollama {
         default_model: Option<OllamaModel>,
         api_url: Option<String>,
+        bearer_token: Option<String>,
         low_speed_timeout_in_seconds: Option<u64>,
     },
 }
@@ -138,6 +139,7 @@ impl AssistantSettingsContent {
                     ),
                     AssistantProviderContentV1::Ollama {
                         api_url,
+                        bearer_token,
                         low_speed_timeout_in_seconds,
                         ..
                     } => update_settings_file::<AllLanguageModelSettings>(
@@ -147,6 +149,7 @@ impl AssistantSettingsContent {
                             if content.ollama.is_none() {
                                 content.ollama = Some(OllamaSettingsContent {
                                     api_url,
+                                    bearer_token,
                                     low_speed_timeout_in_seconds,
                                     available_models: None,
                                 });
@@ -313,17 +316,24 @@ impl AssistantSettingsContent {
                         });
                     }
                     "ollama" => {
-                        let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
-                            Some(AssistantProviderContentV1::Ollama {
-                                api_url,
-                                low_speed_timeout_in_seconds,
-                                ..
-                            }) => (api_url.clone(), *low_speed_timeout_in_seconds),
-                            _ => (None, None),
-                        };
+                        let (api_url, bearer_token, low_speed_timeout_in_seconds) =
+                            match &settings.provider {
+                                Some(AssistantProviderContentV1::Ollama {
+                                    api_url,
+                                    bearer_token,
+                                    low_speed_timeout_in_seconds,
+                                    ..
+                                }) => (
+                                    api_url.clone(),
+                                    bearer_token.clone(),
+                                    *low_speed_timeout_in_seconds,
+                                ),
+                                _ => (None, None, None),
+                            };
                         settings.provider = Some(AssistantProviderContentV1::Ollama {
                             default_model: Some(ollama::Model::new(&model, None, None)),
                             api_url,
+                            bearer_token,
                             low_speed_timeout_in_seconds,
                         });
                     }
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index a29ff3cf6a7a1a..00a1f4bb2c41cb 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -30,6 +30,7 @@ const PROVIDER_NAME: &str = "Ollama";
 #[derive(Default, Debug, Clone, PartialEq)]
 pub struct OllamaSettings {
     pub api_url: String,
+    pub bearer_token: Option<String>,
     pub low_speed_timeout: Option<Duration>,
     pub available_models: Vec<AvailableModel>,
 }
@@ -66,10 +67,11 @@ impl State {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
         let api_url = settings.api_url.clone();
+        let bearer_token = settings.bearer_token.clone();
 
         // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
         cx.spawn(|this, mut cx| async move {
-            let models = get_models(http_client.as_ref(), &api_url, None).await?;
+            let models = get_models(http_client.as_ref(), &api_url, bearer_token, None).await?;
 
             let mut models: Vec<ollama::Model> = models
                 .into_iter()
@@ -180,8 +182,9 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
         let api_url = settings.api_url.clone();
+        let bearer_token = settings.bearer_token.clone();
         let id = model.id().0.to_string();
-        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
+        cx.spawn(|_| async move { preload_model(http_client, &api_url, bearer_token, &id).await })
             .detach_and_log_err(cx);
     }
 
@@ -248,14 +251,15 @@
     ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
         let http_client = self.http_client.clone();
 
-        let Ok(api_url) = cx.update(|cx| {
+        let Ok((api_url, bearer_token)) = cx.update(|cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-            settings.api_url.clone()
+            (settings.api_url.clone(), settings.bearer_token.clone())
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
+        async move { ollama::complete(http_client.as_ref(), &api_url, bearer_token, request).await }
+            .boxed()
     }
 }
 
@@ -309,17 +313,26 @@ impl LanguageModel for OllamaLanguageModel {
         let request = self.to_ollama_request(request);
 
         let http_client = self.http_client.clone();
-        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
+        let Ok((api_url, bearer_token, low_speed_timeout)) = cx.update(|cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-            (settings.api_url.clone(), settings.low_speed_timeout)
+            (
+                settings.api_url.clone(),
+                settings.bearer_token.clone(),
+                settings.low_speed_timeout,
+            )
         }) else {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
         let future = self.request_limiter.stream(async move {
-            let response =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
-                    .await?;
+            let response = stream_chat_completion(
+                http_client.as_ref(),
+                &api_url,
+                bearer_token,
+                request,
+                low_speed_timeout,
+            )
+            .await?;
             let stream = response
                 .filter_map(|response| async move {
                     match response {
diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs
index 2bf8deb04238c2..a0e8010e70a795 100644
--- a/crates/language_model/src/settings.rs
+++ b/crates/language_model/src/settings.rs
@@ -153,6 +153,7 @@ pub struct AnthropicSettingsContentV1 {
 #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
 pub struct OllamaSettingsContent {
     pub api_url: Option<String>,
+    pub bearer_token: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
     pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
 }
@@ -291,6 +292,9 @@ impl settings::Settings for AllLanguageModelSettings {
                 &mut settings.ollama.api_url,
                 value.ollama.as_ref().and_then(|s| s.api_url.clone()),
             );
+            if let Some(bearer_token) = value.ollama.as_ref().and_then(|s| s.bearer_token.clone()) {
+                settings.ollama.bearer_token = Some(bearer_token);
+            }
             if let Some(low_speed_timeout_in_seconds) = value
                 .ollama
                 .as_ref()
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index a38b9e7a564512..b5da7430f75469 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -229,14 +229,19 @@ pub struct ModelDetails {
 pub async fn complete(
     client: &dyn HttpClient,
     api_url: &str,
+    bearer_token: Option<String>,
     request: ChatRequest,
 ) -> Result<ChatResponseDelta> {
     let uri = format!("{api_url}/api/chat");
-    let request_builder = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json");
 
+    if let Some(bearer_token) = bearer_token {
+        request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
+    }
+
     let serialized_request = serde_json::to_string(&request)?;
     let request = request_builder.body(AsyncBody::from(serialized_request))?;
 
@@ -261,6 +266,7 @@ pub async fn complete(
 pub async fn stream_chat_completion(
     client: &dyn HttpClient,
     api_url: &str,
+    bearer_token: Option<String>,
     request: ChatRequest,
     low_speed_timeout: Option<Duration>,
 ) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
@@ -270,6 +276,10 @@ pub async fn stream_chat_completion(
         .uri(uri)
         .header("Content-Type", "application/json");
 
+    if let Some(bearer_token) = bearer_token {
+        request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
+    }
+
     if let Some(low_speed_timeout) = low_speed_timeout {
         request_builder = request_builder.read_timeout(low_speed_timeout);
     }
@@ -305,14 +315,19 @@ pub async fn stream_chat_completion(
 pub async fn get_models(
     client: &dyn HttpClient,
     api_url: &str,
+    bearer_token: Option<String>,
     _: Option<Duration>,
 ) -> Result<Vec<LocalModelListing>> {
     let uri = format!("{api_url}/api/tags");
-    let request_builder = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::GET)
         .uri(uri)
         .header("Accept", "application/json");
 
+    if let Some(bearer_token) = bearer_token {
+        request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
+    }
+
     let request = request_builder.body(AsyncBody::default())?;
 
     let mut response = client.send(request).await?;
@@ -335,18 +350,28 @@ pub async fn get_models(
 }
 
 /// Sends an empty request to Ollama to trigger loading the model
-pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
+pub async fn preload_model(
+    client: Arc<dyn HttpClient>,
+    api_url: &str,
+    bearer_token: Option<String>,
+    model: &str,
+) -> Result<()> {
     let uri = format!("{api_url}/api/generate");
-    let request = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
-        .header("Content-Type", "application/json")
-        .body(AsyncBody::from(serde_json::to_string(
-            &serde_json::json!({
-                "model": model,
-                "keep_alive": "15m",
-            }),
-        )?))?;
+        .header("Content-Type", "application/json");
+
+    if let Some(bearer_token) = bearer_token {
+        request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
+    }
+
+    let request = request_builder.body(AsyncBody::from(serde_json::to_string(
+        &serde_json::json!({
+            "model": model,
+            "keep_alive": "15m",
+        }),
+    )?))?;
 
     let mut response = client.send(request).await?;
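
With this change, every request the Ollama provider issues against /api/chat, /api/tags, and /api/generate carries an `Authorization: Bearer <token>` header whenever a token is configured. As a rough sketch of how the new `bearer_token` field would be set in settings.json (assuming the `language_models` settings namespace used for the existing Ollama options; the URL and token values below are placeholders):

    {
      "language_models": {
        "ollama": {
          "api_url": "https://ollama.example.com",
          "bearer_token": "<token>"
        }
      }
    }

Leaving `bearer_token` unset keeps the previous behavior, since the header is only attached when the option is `Some`.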