Skip to content

Commit

Permalink
ollama: allow specifying a bearer token as API key (zed-industries#19491
Browse files Browse the repository at this point in the history
)

Release Notes:

- Added OLLAMA_API_KEY and UI to set up key for the ollama provider to allow
  for authenticated API endpoints
  • Loading branch information
kov committed Oct 21, 2024
1 parent 92c29be commit a429f96
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 27 deletions.
257 changes: 241 additions & 16 deletions crates/language_model/src/provider/ollama.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use anyhow::{anyhow, bail, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
View, WhiteSpace,
};
use http_client::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Expand All @@ -10,7 +14,8 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, Indicator, Tooltip};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
Expand All @@ -30,6 +35,7 @@ const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub bearer_token: Option<String>,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<AvailableModel>,
}
Expand All @@ -54,22 +60,63 @@ pub struct OllamaLanguageModelProvider {
pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
api_key: Option<String>,
api_key_from_env: bool,
_subscription: Subscription,
}

const OLLAMA_API_KEY_VAR: &str = "OLLAMA_API_KEY";

impl State {
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}

fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let delete_credentials = cx.delete_credentials(&settings.api_url);
cx.spawn(|this, mut cx| async move {
delete_credentials.await.log_err();
this.update(&mut cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
this.available_models.clear();
cx.notify();
})
})
}

fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let write_credentials =
cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());

cx.spawn(|this, mut cx| async move {
write_credentials.await?;
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
cx.notify();
})
})
}

fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();

// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OLLAMA_API_KEY_VAR) {
(Some(api_key), true)
} else {
if let Some((_, bytes)) = cx.update(|cx| cx.read_credentials(&api_url))?.await? {
(Some(String::from_utf8(bytes)?), false)
} else {
(None, false)
}
};

let models = get_models(http_client.as_ref(), &api_url, api_key.clone(), None).await?;

let mut models: Vec<ollama::Model> = models
.into_iter()
Expand All @@ -83,6 +130,8 @@ impl State {
models.sort_by(|a, b| a.name.cmp(&b.name));

this.update(&mut cx, |this, cx| {
this.api_key = api_key;
this.api_key_from_env = from_env;
this.available_models = models;
cx.notify();
})
Expand All @@ -105,6 +154,8 @@ impl OllamaLanguageModelProvider {
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
api_key: None,
api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
this.fetch_models(cx).detach();
cx.notify();
Expand Down Expand Up @@ -169,6 +220,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
Expand All @@ -180,8 +232,9 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let bearer_token = settings.bearer_token.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
cx.spawn(|_| async move { preload_model(http_client, &api_url, bearer_token, &id).await })
.detach_and_log_err(cx);
}

Expand All @@ -199,13 +252,25 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
}

fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).ollama.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
this.api_key = None;
this.api_key_from_env = false;
this.available_models.clear();
cx.notify();
})
})
}
}

pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
Expand Down Expand Up @@ -248,14 +313,15 @@ impl OllamaLanguageModel {
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
let http_client = self.http_client.clone();

let Ok(api_url) = cx.update(|cx| {
let Ok((api_url, api_key)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
(settings.api_url.clone(), state.api_key.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
async move { ollama::complete(http_client.as_ref(), &api_url, api_key, request).await }
.boxed()
}
}

Expand Down Expand Up @@ -309,17 +375,26 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);

let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let Ok((api_url, bearer_token, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
(
settings.api_url.clone(),
settings.bearer_token.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let response = stream_chat_completion(
http_client.as_ref(),
&api_url,
bearer_token,
request,
low_speed_timeout,
)
.await?;
let stream = response
.filter_map(|response| async move {
match response {
Expand Down Expand Up @@ -391,8 +466,10 @@ impl LanguageModel for OllamaLanguageModel {
}

struct ConfigurationView {
api_key_editor: View<Editor>,
state: gpui::Model<State>,
loading_models_task: Option<Task<()>>,
load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
Expand All @@ -414,12 +491,100 @@ impl ConfigurationView {
}
}));

let load_credentials_task = Some(cx.spawn({
let state = state.clone();
|this, mut cx| async move {
if let Some(task) = state
.update(&mut cx, |state, cx| state.authenticate(cx))
.log_err()
{
// We don't log an error, because "not signed in" is also an error.
let _ = task.await;
}
this.update(&mut cx, |this, cx| {
this.load_credentials_task = None;
cx.notify();
})
.log_err();
}
}));

Self {
api_key_editor: cx.new_view(|cx| {
let mut editor = Editor::single_line(cx);
editor.set_placeholder_text("ollama bearer...", cx);
editor
}),
state,
loading_models_task,
load_credentials_task,
}
}

fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
let api_key = self.api_key_editor.read(cx).text(cx);
if api_key.is_empty() {
return;
}

let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
state
.update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
.await
})
.detach_and_log_err(cx);

cx.notify();
}

fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
self.api_key_editor
.update(cx, |editor, cx| editor.set_text("", cx));

let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
state
.update(&mut cx, |state, cx| state.reset_api_key(cx))?
.await
})
.detach_and_log_err(cx);

cx.notify();
}

fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let settings = ThemeSettings::get_global(cx);
let text_style = TextStyle {
color: cx.theme().colors().text,
font_family: settings.ui_font.family.clone(),
font_features: settings.ui_font.features.clone(),
font_fallbacks: settings.ui_font.fallbacks.clone(),
font_size: rems(0.875).into(),
font_weight: settings.ui_font.weight,
font_style: FontStyle::Normal,
line_height: relative(1.3),
background_color: None,
underline: None,
strikethrough: None,
white_space: WhiteSpace::Normal,
truncate: None,
};
EditorElement::new(
&self.api_key_editor,
EditorStyle {
background: cx.theme().colors().editor_background,
local_player: cx.theme().players().local(),
text: text_style,
..Default::default()
},
)
}

fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
!self.state.read(cx).is_authenticated()
}

fn retry_connection(&self, cx: &mut WindowContext) {
self.state
.update(cx, |state, cx| state.fetch_models(cx))
Expand All @@ -438,7 +603,9 @@ impl Render for ConfigurationView {
let mut inline_code_bg = cx.theme().colors().editor_background;
inline_code_bg.fade_out(0.5);

if self.loading_models_task.is_some() {
let mut container = div();

let child = if self.loading_models_task.is_some() {
div().child(Label::new("Loading models...")).into_any()
} else {
v_flex()
Expand Down Expand Up @@ -533,6 +700,64 @@ impl Render for ConfigurationView {
}),
)
.into_any()
}
};

container = container.child(child);

let env_var_set = self.state.read(cx).api_key_from_env;

let child = if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
v_flex()
.size_full()
.on_action(cx.listener(Self::save_api_key))
.child(Label::new(format!("If you host your ollama instance behind a bearer token, you can set it here")))
.child(
h_flex()
.w_full()
.my_2()
.px_2()
.py_1()
.bg(cx.theme().colors().editor_background)
.rounded_md()
.child(self.render_api_key_editor(cx)),
)
.child(
Label::new(
format!("You can also assign the {OLLAMA_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small),
)
.into_any()
} else {
h_flex()
.size_full()
.justify_between()
.child(
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
.child(Label::new(if env_var_set {
format!("API key set in {OLLAMA_API_KEY_VAR} environment variable.")
} else {
"API key configured.".to_string()
})),
)
.child(
Button::new("reset-key", "Reset key")
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
.disabled(env_var_set)
.when(env_var_set, |this| {
this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {OLLAMA_API_KEY_VAR} environment variable."), cx))
})
.on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
)
.into_any()
};

container.child(child).into_any()
}
}
Loading

0 comments on commit a429f96

Please sign in to comment.