diff --git a/Cargo.lock b/Cargo.lock index e24ef56..840a3e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,12 @@ version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74f37166d7d48a0284b99dd824694c26119c700b53bf0d1540cdb147dbdaaf13" +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayvec" version = "0.7.6" @@ -2370,6 +2376,7 @@ version = "0.1.0" dependencies = [ "aformat", "anyhow", + "arc-swap", "aws-config", "aws-sdk-polly", "axum", diff --git a/Cargo.toml b/Cargo.toml index 7e4b87b..4b32152 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ small-fixed-array = { version = "0.4.0", features = ["serde"] } memchr = "2.7.4" aformat = "0.1.4" mini-moka = { version = "0.10.3", features = ["sync"] } +arc-swap = "1.7.1" [dependencies.tracing-subscriber] version = "0.3" diff --git a/src/main.rs b/src/main.rs index c2f41a3..2158d19 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,14 +11,22 @@ use std::{ fmt::Display, str::FromStr, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, OnceLock, }, time::{Duration, Instant}, }; -use axum::{http::header::HeaderValue, response::Response, routing::get, Json}; +use arc_swap::ArcSwap; +use axum::{ + http::header::HeaderValue, + response::Response, + routing::{get, post}, + Json, +}; use bytes::Bytes; +use mini_moka::sync::Cache; +use reqwest::StatusCode; use serde_json::to_value; use sha2::{ digest::{consts::U32, generic_array::GenericArray}, @@ -117,6 +125,44 @@ async fn get_translation_languages() -> ResponseResult Json { + let cache = STATE.get().unwrap().cache.load(); + let hits = cache.hits.load(Ordering::Relaxed); + let misses = cache.misses.load(Ordering::Relaxed); + + Json(CacheInfo { + hits, + misses, + total: hits + misses, + }) +} + +#[derive(serde::Deserialize)] +struct RefreshCache { + new_capacity: u64, +} + +async fn refresh_cache( + Json(RefreshCache { new_capacity }): Json, +) -> reqwest::StatusCode { + let state = STATE.get().unwrap(); + + state.cache.store(Arc::new(AudioCache { + inner: Cache::new(new_capacity), + misses: AtomicU64::new(0), + hits: AtomicU64::new(0), + })); + + StatusCode::OK +} + #[derive(serde::Deserialize, Debug)] struct GetTTS { text: FixedString, @@ -195,13 +241,17 @@ async fn get_tts( ); let cache_hash = sha2::Sha256::digest(&cache_key); - if let Some(cached_audio) = state.cache.get(&cache_hash) { + let audio_cache = state.cache.load(); + if let Some(cached_audio) = audio_cache.inner.get(&cache_hash) { + audio_cache.hits.fetch_add(1, Ordering::Relaxed); + mode.check_length(&cached_audio, payload.max_length)?; tracing::debug!("Used cached TTS for {cache_key}"); return Ok(mode.into_response(cached_audio, None)); } + audio_cache.misses.fetch_add(1, Ordering::Relaxed); cache_hash }; @@ -263,7 +313,7 @@ async fn get_tts( ); tracing::debug!("Cached {} kb of audio", (audio.len() as f64) / 1024.0); - state.cache.insert(cache_hash, audio.clone()); + state.cache.load().inner.insert(cache_hash, audio.clone()); }; mode.check_length(&audio, payload.max_length)?; @@ -374,12 +424,18 @@ impl serde::Serialize for TTSMode { } } +struct AudioCache { + inner: Cache, + misses: AtomicU64, + hits: AtomicU64, +} + struct State { auth_key: Option>, translation_key: Option>, reqwest: reqwest::Client, - cache: mini_moka::sync::Cache, + cache: ArcSwap, polly: polly::State, gtts: tokio::sync::RwLock, @@ -423,12 +479,14 @@ async fn main() -> Result<()> { .and_then(|c| c.parse().ok()) .unwrap_or(1000); - let cache = mini_moka::sync::Cache::builder() - .max_capacity(max_cap) - .build(); + let cache = Cache::builder().max_capacity(max_cap).build(); tracing::info!("Initialised audio cache with max capacity: {max_cap}"); - cache + ArcSwap::from_pointee(AudioCache { + inner: cache, + hits: AtomicU64::new(0), + misses: AtomicU64::new(0), + }) }, auth_key: std::env::var("AUTH_KEY").ok().map(str_to_fixedstring), @@ -442,6 +500,8 @@ async fn main() -> Result<()> { let app = axum::Router::new() .route("/tts", get(get_tts)) .route("/voices", get(get_voices)) + .route("/cache", get(get_cache_info)) + .route("/cache", post(refresh_cache)) .route("/translation_languages", get(get_translation_languages)) .route( "/modes",