Skip to content

Commit

Permalink
ollama: allow specifying a bearer token (#19491)
Browse files Browse the repository at this point in the history
  • Loading branch information
kov committed Oct 20, 2024
1 parent 92c29be commit 1a39f6f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 29 deletions.
26 changes: 18 additions & 8 deletions crates/assistant/src/assistant_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub enum AssistantProviderContentV1 {
Ollama {
default_model: Option<OllamaModel>,
api_url: Option<String>,
bearer_token: Option<String>,
low_speed_timeout_in_seconds: Option<u64>,
},
}
Expand Down Expand Up @@ -138,6 +139,7 @@ impl AssistantSettingsContent {
),
AssistantProviderContentV1::Ollama {
api_url,
bearer_token,
low_speed_timeout_in_seconds,
..
} => update_settings_file::<AllLanguageModelSettings>(
Expand All @@ -147,6 +149,7 @@ impl AssistantSettingsContent {
if content.ollama.is_none() {
content.ollama = Some(OllamaSettingsContent {
api_url,
bearer_token,
low_speed_timeout_in_seconds,
available_models: None,
});
Expand Down Expand Up @@ -313,17 +316,24 @@ impl AssistantSettingsContent {
});
}
"ollama" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
low_speed_timeout_in_seconds,
..
}) => (api_url.clone(), *low_speed_timeout_in_seconds),
_ => (None, None),
};
let (api_url, bearer_token, low_speed_timeout_in_seconds) =
match &settings.provider {
Some(AssistantProviderContentV1::Ollama {
api_url,
bearer_token,
low_speed_timeout_in_seconds,
..
}) => (
api_url.clone(),
bearer_token.clone(),
*low_speed_timeout_in_seconds,
),
_ => (None, None, None),
};
settings.provider = Some(AssistantProviderContentV1::Ollama {
default_model: Some(ollama::Model::new(&model, None, None)),
api_url,
bearer_token,
low_speed_timeout_in_seconds,
});
}
Expand Down
33 changes: 23 additions & 10 deletions crates/language_model/src/provider/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const PROVIDER_NAME: &str = "Ollama";
/// Resolved runtime settings for the Ollama language-model provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
/// Base URL of the Ollama HTTP API; requests are issued against
/// `{api_url}/api/...` endpoints (chat, tags, generate).
pub api_url: String,
/// Optional bearer token sent as an `Authorization: Bearer <token>`
/// header on every request — e.g. for an Ollama instance behind an
/// authenticating reverse proxy.
pub bearer_token: Option<String>,
/// Read timeout applied to streaming chat requests; a stalled
/// transfer longer than this aborts the stream.
pub low_speed_timeout: Option<Duration>,
/// Models configured by the user, merged with those discovered from
/// the server's model listing.
pub available_models: Vec<AvailableModel>,
}
Expand Down Expand Up @@ -66,10 +67,11 @@ impl State {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let bearer_token = settings.bearer_token.clone();

// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let models = get_models(http_client.as_ref(), &api_url, bearer_token, None).await?;

let mut models: Vec<ollama::Model> = models
.into_iter()
Expand Down Expand Up @@ -180,8 +182,9 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let bearer_token = settings.bearer_token.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
cx.spawn(|_| async move { preload_model(http_client, &api_url, bearer_token, &id).await })
.detach_and_log_err(cx);
}

Expand Down Expand Up @@ -248,14 +251,15 @@ impl OllamaLanguageModel {
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
let http_client = self.http_client.clone();

let Ok(api_url) = cx.update(|cx| {
let Ok((api_url, bearer_token)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
(settings.api_url.clone(), settings.bearer_token.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
async move { ollama::complete(http_client.as_ref(), &api_url, bearer_token, request).await }
.boxed()
}
}

Expand Down Expand Up @@ -309,17 +313,26 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);

let http_client = self.http_client.clone();
let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
let Ok((api_url, bearer_token, low_speed_timeout)) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
(settings.api_url.clone(), settings.low_speed_timeout)
(
settings.api_url.clone(),
settings.bearer_token.clone(),
settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};

let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let response = stream_chat_completion(
http_client.as_ref(),
&api_url,
bearer_token,
request,
low_speed_timeout,
)
.await?;
let stream = response
.filter_map(|response| async move {
match response {
Expand Down
4 changes: 4 additions & 0 deletions crates/language_model/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub struct AnthropicSettingsContentV1 {
/// Settings-file (serde/JSON-schema) representation of the Ollama provider
/// settings. Every field is optional so a partial user configuration can be
/// merged over the defaults when global settings are assembled.
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct OllamaSettingsContent {
/// Override for the Ollama server's base API URL.
pub api_url: Option<String>,
/// Bearer token to authenticate requests to the Ollama server
/// (forwarded as an `Authorization: Bearer …` header).
pub bearer_token: Option<String>,
/// Streaming read timeout, expressed in whole seconds.
pub low_speed_timeout_in_seconds: Option<u64>,
/// Extra models to expose in addition to those the server reports.
pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
}
Expand Down Expand Up @@ -291,6 +292,9 @@ impl settings::Settings for AllLanguageModelSettings {
&mut settings.ollama.api_url,
value.ollama.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(bearer_token) = value.ollama.as_ref().and_then(|s| s.bearer_token.clone()) {
settings.ollama.bearer_token = Some(bearer_token);
}
if let Some(low_speed_timeout_in_seconds) = value
.ollama
.as_ref()
Expand Down
47 changes: 36 additions & 11 deletions crates/ollama/src/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,19 @@ pub struct ModelDetails {
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
bearer_token: Option<String>,
request: ChatRequest,
) -> Result<ChatResponseDelta> {
let uri = format!("{api_url}/api/chat");
let request_builder = HttpRequest::builder()
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");

if let Some(bearer_token) = bearer_token {
request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
}

let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;

Expand All @@ -261,6 +266,7 @@ pub async fn complete(
pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
bearer_token: Option<String>,
request: ChatRequest,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
Expand All @@ -270,6 +276,10 @@ pub async fn stream_chat_completion(
.uri(uri)
.header("Content-Type", "application/json");

if let Some(bearer_token) = bearer_token {
request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
}

if let Some(low_speed_timeout) = low_speed_timeout {
request_builder = request_builder.read_timeout(low_speed_timeout);
}
Expand Down Expand Up @@ -305,14 +315,19 @@ pub async fn stream_chat_completion(
pub async fn get_models(
client: &dyn HttpClient,
api_url: &str,
bearer_token: Option<String>,
_: Option<Duration>,
) -> Result<Vec<LocalModelListing>> {
let uri = format!("{api_url}/api/tags");
let request_builder = HttpRequest::builder()
let mut request_builder = HttpRequest::builder()
.method(Method::GET)
.uri(uri)
.header("Accept", "application/json");

if let Some(bearer_token) = bearer_token {
request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
}

let request = request_builder.body(AsyncBody::default())?;

let mut response = client.send(request).await?;
Expand All @@ -335,18 +350,28 @@ pub async fn get_models(
}

/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
pub async fn preload_model(
client: Arc<dyn HttpClient>,
api_url: &str,
bearer_token: Option<String>,
model: &str,
) -> Result<()> {
let uri = format!("{api_url}/api/generate");
let request = HttpRequest::builder()
let mut request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(
&serde_json::json!({
"model": model,
"keep_alive": "15m",
}),
)?))?;
.header("Content-Type", "application/json");

if let Some(bearer_token) = bearer_token {
request_builder = request_builder.header("Authorization", format!("Bearer {bearer_token}"))
}

let request = request_builder.body(AsyncBody::from(serde_json::to_string(
&serde_json::json!({
"model": model,
"keep_alive": "15m",
}),
)?))?;

let mut response = client.send(request).await?;

Expand Down

0 comments on commit 1a39f6f

Please sign in to comment.