Skip to content

Commit

Permalink
Add endpoints to configure cache
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Nov 7, 2024
1 parent 53be351 commit 7066b68
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 9 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ small-fixed-array = { version = "0.4.0", features = ["serde"] }
memchr = "2.7.4"
aformat = "0.1.4"
mini-moka = { version = "0.10.3", features = ["sync"] }
arc-swap = "1.7.1"

[dependencies.tracing-subscriber]
version = "0.3"
Expand Down
78 changes: 69 additions & 9 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@ use std::{
fmt::Display,
str::FromStr,
sync::{
atomic::{AtomicBool, Ordering},
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, OnceLock,
},
time::{Duration, Instant},
};

use axum::{http::header::HeaderValue, response::Response, routing::get, Json};
use arc_swap::ArcSwap;
use axum::{
http::header::HeaderValue,
response::Response,
routing::{get, post},
Json,
};
use bytes::Bytes;
use mini_moka::sync::Cache;
use reqwest::StatusCode;
use serde_json::to_value;
use sha2::{
digest::{consts::U32, generic_array::GenericArray},
Expand Down Expand Up @@ -117,6 +125,44 @@ async fn get_translation_languages() -> ResponseResult<Json<Vec<(FixedString, Fi
}
}

/// Snapshot of the audio cache counters, serialised as the JSON body returned
/// by `get_cache_info`.
#[derive(serde::Serialize)]
struct CacheInfo {
    /// Value of the cache's `hits` counter when the snapshot was taken.
    hits: u64,
    /// Value of the cache's `misses` counter when the snapshot was taken.
    misses: u64,
    /// Sum of `hits` and `misses`.
    total: u64,
}

/// Handler reporting the hit/miss counters of the currently installed audio cache.
///
/// The two counters are read with separate relaxed loads, so the response is a
/// best-effort snapshot rather than an atomic view of both values.
///
/// # Panics
///
/// Panics if `STATE.get()` returns `None`.
async fn get_cache_info() -> Json<CacheInfo> {
    let state = STATE.get().unwrap();
    let audio_cache = state.cache.load();

    let hits = audio_cache.hits.load(Ordering::Relaxed);
    let misses = audio_cache.misses.load(Ordering::Relaxed);
    let total = hits + misses;

    Json(CacheInfo {
        hits,
        misses,
        total,
    })
}

/// JSON request body accepted by `refresh_cache`.
#[derive(serde::Deserialize)]
struct RefreshCache {
    /// Capacity handed to `Cache::new` when building the replacement cache.
    new_capacity: u64,
}

/// Handler that replaces the shared audio cache with a fresh, empty one.
///
/// The new cache is created with the requested capacity and both counters
/// reset to zero, then stored into the shared state in place of the old one.
/// Always responds with `200 OK`.
///
/// NOTE(review): no authorisation check is visible in this handler; confirm
/// whether access is restricted elsewhere (e.g. by middleware or deployment).
///
/// # Panics
///
/// Panics if `STATE.get()` returns `None`.
async fn refresh_cache(
    Json(RefreshCache { new_capacity }): Json<RefreshCache>,
) -> reqwest::StatusCode {
    let shared = STATE.get().unwrap();

    // Build the replacement up front so the swap below is a single store.
    let fresh_cache = AudioCache {
        inner: Cache::new(new_capacity),
        misses: AtomicU64::new(0),
        hits: AtomicU64::new(0),
    };
    shared.cache.store(Arc::new(fresh_cache));

    StatusCode::OK
}

#[derive(serde::Deserialize, Debug)]
struct GetTTS {
text: FixedString,
Expand Down Expand Up @@ -195,13 +241,17 @@ async fn get_tts(
);

let cache_hash = sha2::Sha256::digest(&cache_key);
if let Some(cached_audio) = state.cache.get(&cache_hash) {
let audio_cache = state.cache.load();
if let Some(cached_audio) = audio_cache.inner.get(&cache_hash) {
audio_cache.hits.fetch_add(1, Ordering::Relaxed);

mode.check_length(&cached_audio, payload.max_length)?;

tracing::debug!("Used cached TTS for {cache_key}");
return Ok(mode.into_response(cached_audio, None));
}

audio_cache.misses.fetch_add(1, Ordering::Relaxed);
cache_hash
};

Expand Down Expand Up @@ -263,7 +313,7 @@ async fn get_tts(
);

tracing::debug!("Cached {} kb of audio", (audio.len() as f64) / 1024.0);
state.cache.insert(cache_hash, audio.clone());
state.cache.load().inner.insert(cache_hash, audio.clone());
};

mode.check_length(&audio, payload.max_length)?;
Expand Down Expand Up @@ -374,12 +424,18 @@ impl serde::Serialize for TTSMode {
}
}

/// Audio cache bundled with its own hit/miss counters.
///
/// `refresh_cache` replaces the whole value at once, so the counters always
/// describe the cache instance they sit next to.
struct AudioCache {
    /// Cached audio bytes; `get_tts` looks entries up by a SHA-256 digest of
    /// the cache key.
    inner: Cache<AudioCacheDigest, Bytes>,
    /// Incremented by `get_tts` when a lookup finds no cached audio.
    misses: AtomicU64,
    /// Incremented by `get_tts` when a lookup returns cached audio.
    hits: AtomicU64,
}

struct State {
auth_key: Option<FixedString<u8>>,
translation_key: Option<FixedString<u8>>,
reqwest: reqwest::Client,

cache: mini_moka::sync::Cache<AudioCacheDigest, Bytes>,
cache: ArcSwap<AudioCache>,

polly: polly::State,
gtts: tokio::sync::RwLock<gtts::State>,
Expand Down Expand Up @@ -423,12 +479,14 @@ async fn main() -> Result<()> {
.and_then(|c| c.parse().ok())
.unwrap_or(1000);

let cache = mini_moka::sync::Cache::builder()
.max_capacity(max_cap)
.build();
let cache = Cache::builder().max_capacity(max_cap).build();

tracing::info!("Initialised audio cache with max capacity: {max_cap}");
cache
ArcSwap::from_pointee(AudioCache {
inner: cache,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
})
},

auth_key: std::env::var("AUTH_KEY").ok().map(str_to_fixedstring),
Expand All @@ -442,6 +500,8 @@ async fn main() -> Result<()> {
let app = axum::Router::new()
.route("/tts", get(get_tts))
.route("/voices", get(get_voices))
.route("/cache", get(get_cache_info))
.route("/cache", post(refresh_cache))
.route("/translation_languages", get(get_translation_languages))
.route(
"/modes",
Expand Down

0 comments on commit 7066b68

Please sign in to comment.