diff --git a/src/app.rs b/src/app.rs index c6d95c1c..67248f37 100644 --- a/src/app.rs +++ b/src/app.rs @@ -6,6 +6,7 @@ use crate::{ api::{Api, ApiError, ApiInner, ApiVersion}, + dispatch::{self, DispatchError, Trie}, healthcheck::{HealthCheck, HealthStatus}, http, method::Method, @@ -16,9 +17,9 @@ use crate::{ Html, StatusCode, }; use async_std::sync::Arc; +use derive_more::From; use futures::future::{BoxFuture, FutureExt}; use include_dir::{include_dir, Dir}; -use itertools::Itertools; use lazy_static::lazy_static; use maud::{html, PreEscaped}; use semver::Version; @@ -26,10 +27,7 @@ use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use snafu::{ResultExt, Snafu}; use std::{ - collections::{ - btree_map::{BTreeMap, Entry as BTreeEntry}, - hash_map::{Entry as HashEntry, HashMap}, - }, + collections::btree_map::BTreeMap, convert::Infallible, env, fmt::Display, @@ -58,24 +56,23 @@ pub use tide::listener::{Listener, ToListener}; /// use by any given API module may differ, depending on the supported version of the API. #[derive(Debug)] pub struct App { - // Map from base URL, major version to API. - pub(crate) apis: HashMap>>, + pub(crate) modules: Trie>, pub(crate) state: Arc, app_version: Option, } /// An error encountered while building an [App]. -#[derive(Clone, Debug, Snafu, PartialEq, Eq)] +#[derive(Clone, Debug, From, Snafu, PartialEq, Eq)] pub enum AppError { Api { source: ApiError }, - ModuleAlreadyExists, + Dispatch { source: DispatchError }, } impl App { /// Create a new [App] with a given state. pub fn with_state(state: State) -> Self { Self { - apis: HashMap::new(), + modules: Default::default(), state: Arc::new(state), app_version: None, } @@ -158,20 +155,8 @@ impl App { } }; - match self.apis.entry(base_url.to_string()) { - HashEntry::Occupied(mut e) => match e.get_mut().entry(major_version) { - BTreeEntry::Occupied(_) => { - return Err(AppError::ModuleAlreadyExists); - } - BTreeEntry::Vacant(e) => { - e.insert(api); - } - }, - HashEntry::Vacant(e) => { - e.insert([(major_version, api)].into()); - } - } - + self.modules + .insert(dispatch::split(base_url), major_version, api)?; Ok(self) } @@ -212,12 +197,17 @@ impl App { app_version: self.app_version.clone(), disco_version: env!("CARGO_PKG_VERSION").parse().unwrap(), modules: self - .apis + .modules .iter() - .map(|(name, versions)| { + .map(|module| { ( - name.clone(), - versions.values().rev().map(|api| api.version()).collect(), + module.path(), + module + .versions + .values() + .rev() + .map(|api| api.version()) + .collect(), ) }) .collect(), @@ -231,19 +221,22 @@ impl App { /// (due to type erasure) but can be queried using [module_health](Self::module_health) or by /// hitting the endpoint `GET /:module/healthcheck`. pub async fn health(&self, req: RequestParams, state: &State) -> AppHealth { - let mut modules = BTreeMap::>::new(); + let mut modules_health = BTreeMap::>::new(); let mut status = HealthStatus::Available; - for (name, versions) in &self.apis { - let module = modules.entry(name.clone()).or_default(); - for (version, api) in versions { + for module in &self.modules { + let versions_health = modules_health.entry(module.path()).or_default(); + for (version, api) in &module.versions { let health = StatusCode::from(api.health(req.clone(), state).await.status()); if health != StatusCode::Ok { status = HealthStatus::Unhealthy; } - module.insert(*version, health); + versions_health.insert(*version, health); } } - AppHealth { status, modules } + AppHealth { + status, + modules: modules_health, + } } /// Check the health of the named module. @@ -264,10 +257,10 @@ impl App { module: &str, major_version: Option, ) -> Option { - let versions = self.apis.get(module)?; + let module = self.modules.get(dispatch::split(module))?; let api = match major_version { - Some(v) => versions.get(&v)?, - None => versions.last_key_value()?.1, + Some(v) => module.versions.get(&v)?, + None => module.versions.last_key_value()?.1, }; Some(api.health(req, state).await) } @@ -317,35 +310,41 @@ where .allow_credentials(true), ); - for (name, versions) in &state.apis { - Self::register_api(&mut server, name.clone(), versions)?; + for module in &state.modules { + Self::register_api(&mut server, module.prefix.clone(), &module.versions)?; } - // Register app-level automatic routes: `healthcheck` and `version`. - server - .at("healthcheck") - .get(move |req: tide::Request>| async move { - let state = req.state().clone(); - let app_state = &*state.state; - let req = request_params(req, &[]).await?; - let accept = req.accept()?; - let res = state.health(req, app_state).await; - Ok(health_check_response::<_, VER>(&accept, res)) - }); - server - .at("version") - .get(move |req: tide::Request>| async move { - let accept = RequestParams::accept_from_headers(&req)?; - respond_with(&accept, req.state().version(), bind_version) - .map_err(|err| Error::from_route_error::(err).into_tide_error()) - }); - - // Serve documentation at the root URL for discoverability - server - .at("/") - .all(move |req: tide::Request>| async move { - Ok(tide::Response::from(Self::top_level_docs(req))) - }); + // Register app-level routes summarizing the status and documentation of all the registered + // modules. We skip this step if this is a singleton app with only one module registered at + // the root URL, as these app-level endpoints would conflict with the (probably more + // specific) API-level status endpoints. + if !state.modules.is_singleton() { + // Register app-level automatic routes: `healthcheck` and `version`. + server + .at("healthcheck") + .get(move |req: tide::Request>| async move { + let state = req.state().clone(); + let app_state = &*state.state; + let req = request_params(req, &[]).await?; + let accept = req.accept()?; + let res = state.health(req, app_state).await; + Ok(health_check_response::<_, VER>(&accept, res)) + }); + server + .at("version") + .get(move |req: tide::Request>| async move { + let accept = RequestParams::accept_from_headers(&req)?; + respond_with(&accept, req.state().version(), bind_version) + .map_err(|err| Error::from_route_error::(err).into_tide_error()) + }); + + // Serve documentation at the root URL for discoverability + server + .at("/") + .all(move |req: tide::Request>| async move { + Ok(tide::Response::from(Self::top_level_docs(req))) + }); + } server.listen(listener).await } @@ -353,22 +352,22 @@ where fn list_apis(&self) -> Html { html! { ul { - @for (name, versions) in &self.apis { + @for module in &self.modules { li { // Link to the alias for the latest version as the primary link. - a href=(format!("/{}", name)) {(name)} + a href=(format!("/{}", module.path())) {(module.path())} // Add a superscript link (link a footnote) for each specific supported // version, linking to documentation for that specific version. - @for version in versions.keys().rev() { + @for version in module.versions.keys().rev() { sup { - a href=(format!("/v{version}/{name}")) { + a href=(format!("/v{version}/{}", module.path())) { (format!("[v{version}]")) } } } " " // Take the description of the latest supported version. - (PreEscaped(versions.last_key_value().unwrap().1.short_description())) + (PreEscaped(module.versions.last_key_value().unwrap().1.short_description())) } } } @@ -377,7 +376,7 @@ where fn register_api( server: &mut tide::Server>, - prefix: String, + prefix: Vec, versions: &BTreeMap>, ) -> io::Result<()> { for (version, api) in versions { @@ -388,7 +387,7 @@ where fn register_api_version( server: &mut tide::Server>, - prefix: &String, + prefix: &[String], version: u64, api: &ApiInner, ) -> io::Result<()> { @@ -400,11 +399,16 @@ where server .at("/public") .at(&format!("v{version}")) - .at(prefix) + .at(&prefix.join("/")) .serve_dir(api.public().unwrap_or_else(|| &DEFAULT_PUBLIC_PATH))?; // Register routes for this API. - let mut api_endpoint = server.at(&format!("/v{version}/{prefix}")); + let mut version_endpoint = server.at(&format!("/v{version}")); + let mut api_endpoint = if prefix.is_empty() { + version_endpoint + } else { + version_endpoint.at(&prefix.join("/")) + }; api_endpoint.with(AddErrorBody::new(api.error_handler())); for (path, routes) in api.routes_by_path() { let mut endpoint = api_endpoint.at(path); @@ -418,7 +422,7 @@ where // If there is a socket route with this pattern, add the socket middleware to // all endpoints registered under this pattern, so that any request with any // method that has the socket upgrade headers will trigger a WebSockets upgrade. - Self::register_socket(prefix.to_owned(), version, &mut endpoint, socket_route); + Self::register_socket(prefix.to_vec(), version, &mut endpoint, socket_route); } if let Some(metrics_route) = routes .iter() @@ -428,13 +432,13 @@ where // all endpoints registered under this pattern, so that a request to this path // with the right headers will return metrics instead of going through the // normal method-based dispatching. - Self::register_metrics(prefix.to_owned(), version, &mut endpoint, metrics_route); + Self::register_metrics(prefix.to_vec(), version, &mut endpoint, metrics_route); } // Register the HTTP routes. for route in routes { if let Method::Http(method) = route.method() { - Self::register_route(prefix.to_owned(), version, &mut endpoint, route, method); + Self::register_route(prefix.to_vec(), version, &mut endpoint, route, method); } } } @@ -442,26 +446,26 @@ where // Register automatic routes for this API: documentation, `healthcheck` and `version`. Serve // documentation at the root of the API (with or without a trailing slash). for path in ["", "/"] { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at(path) .all(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; Ok(api.documentation()) } }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("*path") .all(move |req: tide::Request>| { let prefix = prefix.clone(); async move { // The request did not match any route. Serve documentation for the API. - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; let docs = html! { "No route matches /" (req.param("path")?) br{} @@ -474,13 +478,13 @@ where }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("healthcheck") .get(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().clone().apis[&prefix][&version]; + let api = &req.state().clone().modules[&prefix].versions[&version]; let state = req.state().clone(); Ok(api .health(request_params(req, &[]).await?, &state.state) @@ -489,13 +493,13 @@ where }); } { - let prefix = prefix.clone(); + let prefix = prefix.to_vec(); api_endpoint .at("version") .get(move |req: tide::Request>| { let prefix = prefix.clone(); async move { - let api = &req.state().apis[&prefix][&version]; + let api = &req.state().modules[&prefix].versions[&version]; let accept = RequestParams::accept_from_headers(&req)?; api.version_handler()(&accept, api.version()) .map_err(|err| Error::from_route_error(err).into_tide_error()) @@ -507,7 +511,7 @@ where } fn register_route( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -518,7 +522,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route @@ -534,7 +538,7 @@ where } fn register_metrics( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -560,7 +564,7 @@ where } fn register_socket( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -576,7 +580,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route @@ -608,7 +612,7 @@ where } fn register_fallback( - api: String, + api: Vec, version: u64, endpoint: &mut tide::Route>, route: &Route, @@ -618,7 +622,7 @@ where let name = name.clone(); let api = api.clone(); async move { - let route = &req.state().clone().apis[&api][&version][&name]; + let route = &req.state().clone().modules[&api].versions[&version][&name]; route .default_handler() .map_err(|err| match err { @@ -637,12 +641,13 @@ where next: tide::Next>, ) -> BoxFuture { async move { - let Some(mut path) = req.url().path_segments() else { + let Some(path) = req.url().path_segments() else { // If we can't parse the path, we can't run this middleware. Do our best by // continuing the request processing lifecycle. return Ok(next.run(req).await); }; - let Some(seg1) = path.next() else { + let path = path.collect::>(); + let Some(seg1) = path.first() else { // This is the root URL, with no path segments. Nothing for this middleware to do. return Ok(next.run(req).await); }; @@ -651,32 +656,25 @@ where return Ok(next.run(req).await); } - // The first segment is either a version identifier or an API identifier (implicitly - // requesting the latest version of the API). We handle these cases differently. + // The first segment is either a version identifier or (part of) an API identifier + // (implicitly requesting the latest version of the API). We handle these cases + // differently. if let Some(version) = seg1.strip_prefix('v').and_then(|n| n.parse().ok()) { // If the version identifier is present, we probably don't need a redirect. However, // we still check if this is a valid version for the request API. If not, we will // serve documentation listing the available versions. - let Some(api) = path.next() else { - // A version identifier with no API is an error, serve documentation. - return Ok(Self::top_level_error( - req, - StatusCode::BadRequest, - "illegal version prefix without API specifier", - )); - }; - let Some(versions) = req.state().apis.get(api) else { - let message = format!("No API matches /{api}"); + let Some(module) = req.state().modules.search(&path[1..]) else { + let message = format!("No API matches /{}", path[1..].join("/")); return Ok(Self::top_level_error(req, StatusCode::NotFound, message)); }; - if versions.get(&version).is_none() { + if module.versions.get(&version).is_none() { // This version is not supported, list suported versions. return Ok(html! { "Unsupported version v" (version) ". Supported versions are:" ul { - @for v in versions.keys().rev() { + @for v in module.versions.keys().rev() { li { - a href=(format!("/v{v}/{api}")) { "v" (v) } + a href=(format!("/v{v}/{}", module.path())) { "v" (v) } } } } @@ -688,20 +686,21 @@ where // successfully by the route handlers for this API. Ok(next.run(req).await) } else { - // If the first path segment is not a version prefix, it is either the name of an - // API or one of the magic top-level endpoints (version, healthcheck), implicitly - // requesting the latest version. Validate the API and then redirect. - if ["version", "healthcheck"].contains(&seg1) { + // If the first path segment is not a version prefix, then the path is either the + // name of an API (implicitly requesting the latest version) or one of the magic + // top-level endpoints (version, healthcheck). Validate the API and then redirect. + if !req.state().modules.is_singleton() && ["version", "healthcheck"].contains(seg1) + { return Ok(next.run(req).await); } - let Some(versions) = req.state().apis.get(seg1) else { - let message = format!("No API matches /{seg1}"); + let Some(module) = req.state().modules.search(&path) else { + let message = format!("No API matches /{}", path.join("/")); return Ok(Self::top_level_error(req, StatusCode::NotFound, message)); }; - let latest_version = *versions.last_key_value().unwrap().0; + let latest_version = *module.versions.last_key_value().unwrap().0; let path = path.join("/"); - Ok(tide::Redirect::permanent(format!("/v{latest_version}/{seg1}/{path}")).into()) + Ok(tide::Redirect::permanent(format!("/v{latest_version}/{path}")).into()) } } .boxed() @@ -779,6 +778,7 @@ pub struct AppVersion { /// Note that if anything goes wrong during module registration (for example, there is already an /// incompatible module registered with the same name), the drop implementation may panic. To handle /// errors without panicking, call [`register`](Self::register) explicitly. +#[derive(Debug)] pub struct Module<'a, State, Error, ModuleError, ModuleVersion> where State: Send + Sync + 'static, @@ -1118,7 +1118,14 @@ mod test { .module::("mod", v1_toml) .unwrap(); api.with_version("1.1.1".parse().unwrap()); - assert_eq!(api.register().unwrap_err(), AppError::ModuleAlreadyExists); + assert_eq!( + api.register().unwrap_err(), + DispatchError::ModuleAlreadyExists { + prefix: "mod".into(), + version: 1, + } + .into() + ); } { let mut v3 = app @@ -1396,10 +1403,7 @@ mod test { .text() .await .unwrap(); - assert!( - docs.contains("illegal version prefix without API specifier"), - "{docs}" - ); + assert!(docs.contains("No API matches /"), "{docs}"); assert!(docs.contains(&expected_list_item), "{docs}"); } @@ -1455,6 +1459,8 @@ mod test { #[async_std::test] async fn test_format_versions() { + setup_test(); + // Register two modules with different binary format versions, each in turn different from // the app-level version. Each module has two endpoints, one which always succeeds and one // which always fails, so we can test error serialization. @@ -1610,4 +1616,191 @@ mod test { check_err::(&client, "mod02/err").await; check_err::(&client, "mod03/err").await; } + + #[async_std::test] + async fn test_api_prefix() { + setup_test(); + + // It is illegal to register two API modules where one is a prefix (in terms of route + // segments) of another. + for (api1, api2) in [ + ("", "api"), + ("api", ""), + ("path", "path/sub"), + ("path/sub", "path"), + ] { + tracing::info!(api1, api2, "test case"); + let (prefix, conflict) = if api1.len() < api2.len() { + (api1.to_string(), api2.to_string()) + } else { + (api2.to_string(), api1.to_string()) + }; + + let mut app = App::<_, ServerError>::with_state(()); + let toml = toml! { + route = {} + }; + app.module::(api1, toml.clone()) + .unwrap() + .register() + .unwrap(); + assert_eq!( + app.module::(api2, toml) + .unwrap() + .register() + .unwrap_err(), + DispatchError::ConflictingModules { prefix, conflict }.into() + ); + } + } + + #[async_std::test] + async fn test_singleton_api() { + setup_test(); + + // If there is only one API, it should be possible to register it with an empty prefix. + let toml = toml! { + [route.test] + PATH = ["/test"] + }; + let mut app = App::<_, ServerError>::with_state(()); + let mut api = app.module::("", toml).unwrap(); + api.with_version("0.1.0".parse().unwrap()) + .get("test", |_, _| async move { Ok("response") }.boxed()) + .unwrap(); + api.register().unwrap(); + + let port = pick_unused_port().unwrap(); + spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance())); + let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await; + + // Test an endpoint. + let res = client.get("/test").send().await.unwrap(); + assert_eq!( + res.status(), + StatusCode::Ok, + "{}", + res.text().await.unwrap() + ); + assert_eq!(res.json::().await.unwrap(), "response"); + + // Test healthcheck and version endpoints. Since these would ordinarily conflict with the + // app-level healthcheck and version endpoints for an API with no prefix, we only get the + // API-level endpoints, so that a singleton API behaves like a normal API, while app-level + // stuff is reserved for non-trivial applications with more than one API. + let res = client.get("/healthcheck").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + HealthStatus::Available + ); + + let res = client.get("/version").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }, + ); + } + + #[async_std::test] + async fn test_multi_segment() { + setup_test(); + + let toml = toml! { + [route.test] + PATH = ["/test"] + }; + let mut app = App::<_, ServerError>::with_state(()); + + for name in ["a", "b"] { + let path = format!("api/{name}"); + let mut api = app + .module::(&path, toml.clone()) + .unwrap(); + api.with_version("0.1.0".parse().unwrap()) + .get("test", move |_, _| async move { Ok(name) }.boxed()) + .unwrap(); + api.register().unwrap(); + } + + let port = pick_unused_port().unwrap(); + spawn(app.serve(format!("0.0.0.0:{port}"), StaticVer01::instance())); + let client = Client::new(format!("http://localhost:{port}").parse().unwrap()).await; + + for api in ["a", "b"] { + tracing::info!(api, "testing api"); + + // Test an endpoint. + let res = client.get(&format!("api/{api}/test")).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!(res.json::().await.unwrap(), api); + + // Test healthcheck. + let res = client + .get(&format!("api/{api}/healthcheck")) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + HealthStatus::Available + ); + + // Test version. + let res = client + .get(&format!("api/{api}/version")) + .send() + .await + .unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap().api_version.unwrap(), + "0.1.0".parse().unwrap() + ); + } + + // Test app-level healthcheck. + let res = client.get("healthcheck").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap(), + AppHealth { + status: HealthStatus::Available, + modules: [ + ("api/a".into(), [(0, StatusCode::Ok)].into()), + ("api/b".into(), [(0, StatusCode::Ok)].into()), + ] + .into() + } + ); + + // Test app-level version. + let res = client.get("version").send().await.unwrap(); + assert_eq!(res.status(), StatusCode::Ok); + assert_eq!( + res.json::().await.unwrap().modules, + [ + ( + "api/a".into(), + vec![ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }] + ), + ( + "api/b".into(), + vec![ApiVersion { + api_version: Some("0.1.0".parse().unwrap()), + spec_version: "0.1.0".parse().unwrap(), + }] + ), + ] + .into() + ); + } } diff --git a/src/dispatch.rs b/src/dispatch.rs new file mode 100644 index 00000000..49cac178 --- /dev/null +++ b/src/dispatch.rs @@ -0,0 +1,353 @@ +use itertools::Itertools; +use snafu::Snafu; +use std::{ + collections::{btree_map::Entry, BTreeMap}, + ops::Index, +}; + +pub use crate::join; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct Module { + pub(crate) prefix: Vec, + pub(crate) versions: BTreeMap, +} + +impl Module { + fn new(prefix: Vec) -> Self { + Self { + prefix, + versions: Default::default(), + } + } + + pub(crate) fn path(&self) -> String { + self.prefix.join("/") + } +} + +#[derive(Clone, Debug, Snafu, PartialEq, Eq)] +pub enum DispatchError { + #[snafu(display("duplicate module {prefix} v{version}"))] + ModuleAlreadyExists { prefix: String, version: u64 }, + #[snafu(display("module {prefix} cannot be a prefix of module {conflict}"))] + ConflictingModules { prefix: String, conflict: String }, +} + +/// Mapping from route prefixes to APIs. +#[derive(Debug)] +pub(crate) enum Trie { + Branch { + /// The route prefix represented by this node. + prefix: Vec, + /// APIs with this prefix, indexed by the next route segment. + children: BTreeMap>, + }, + Leaf { + /// APIs available at this prefix, sorted by version. + module: Module, + }, +} + +impl Default for Trie { + fn default() -> Self { + Self::Branch { + prefix: vec![], + children: Default::default(), + } + } +} + +impl Trie { + /// Whether this is a singleton [`Trie`]. + /// + /// A singleton [`Trie`] is one with only one module, registered under the empty prefix. Note + /// that any [`Trie`] with a module with an empty prefix must be singleton, because no other + /// modules would be permitted: the empty prefix is a prefix of every other module path. + pub(crate) fn is_singleton(&self) -> bool { + matches!(self, Self::Leaf { .. }) + } + + /// Insert a new API with a certain version under the given prefix. + pub(crate) fn insert( + &mut self, + prefix: I, + version: u64, + api: Api, + ) -> Result<(), DispatchError> + where + I: IntoIterator, + I::Item: Into, + { + let mut prefix = prefix.into_iter().map(|segment| segment.into()); + + // Traverse to a leaf matching `prefix`. + let mut curr = self; + while let Some(segment) = prefix.next() { + // If there are more segments in the prefix, we must be at a branch. + match curr { + Self::Branch { prefix, children } => { + // Move to the child associated with the next path segment, inserting an empty + // child if this is the first module we've seen that has this path as a prefix. + curr = children.entry(segment.clone()).or_insert_with(|| { + let mut prefix = prefix.clone(); + prefix.push(segment); + Box::new(Trie::Branch { + prefix, + children: Default::default(), + }) + }); + } + Self::Leaf { module } => { + // If there is a leaf here, then there is already a module registered which is a + // prefix of the new module. This is not allowed. + return Err(DispatchError::ConflictingModules { + prefix: module.path(), + conflict: join!(&module.path(), &segment, &prefix.join("/")), + }); + } + } + } + + // If we have reached the end of the prefix, we must be at either a leaf or a temporary + // empty branch that we can turn into a leaf. + if let Self::Branch { prefix, children } = curr { + if children.is_empty() { + *curr = Self::Leaf { + module: Module::new(prefix.clone()), + }; + } else { + // If we have a non-trival branch at the end of the desired prefix, there is already + // a module registered for which `prefix` is a strict prefix of the registered path. + // This is not allowed. To give a useful error message, follow the existing trie + // down to a leaf so we can give an example of a module which conflicts with this + // prefix. + let prefix = prefix.join("/"); + let conflict = loop { + match curr { + Self::Branch { children, .. } => { + curr = children + .values_mut() + .next() + .expect("malformed dispatch trie: empty branch"); + } + Self::Leaf { module } => { + break module.path(); + } + } + }; + return Err(DispatchError::ConflictingModules { prefix, conflict }); + } + } + let Self::Leaf { module } = curr else { + unreachable!(); + }; + + // Insert the new API, as long as there isn't already an API with the same version in this + // module. + let Entry::Vacant(e) = module.versions.entry(version) else { + return Err(DispatchError::ModuleAlreadyExists { + prefix: module.path(), + version, + }); + }; + e.insert(api); + Ok(()) + } + + /// Get the module named by `prefix`. + /// + /// This function is similar to [`search`](Self::search), except the given `prefix` must exactly + /// match the prefix under which a module is registered. + pub(crate) fn get(&self, prefix: I) -> Option<&Module> + where + I: IntoIterator, + I::Item: AsRef, + { + let mut iter = prefix.into_iter(); + let module = self.traverse(&mut iter)?; + // Check for exact match. + if iter.next().is_some() { + None + } else { + Some(module) + } + } + + /// Get the supported versions of the API identified by the given request path. + /// + /// If a prefix of `path` uniquely identifies a registered module, the module (with all + /// supported versions) is returned. + pub(crate) fn search(&self, path: I) -> Option<&Module> + where + I: IntoIterator, + I::Item: AsRef, + { + self.traverse(&mut path.into_iter()) + } + + /// Iterate over registered modules and their supported versions. + pub(crate) fn iter(&self) -> Iter { + Iter { stack: vec![self] } + } + + /// Internal implementation of `get` and `search`. + /// + /// Returns the matching module and advances the iterator past all the segments used in the + /// match. + fn traverse(&self, iter: &mut I) -> Option<&Module> + where + I: Iterator, + I::Item: AsRef, + { + let mut curr = self; + loop { + match curr { + Self::Branch { children, .. } => { + // Traverse to the next child based on the next segment in the path. + let segment = iter.next()?; + curr = children.get(segment.as_ref())?; + } + Self::Leaf { module } => return Some(module), + } + } + } +} + +pub(crate) struct Iter<'a, Api> { + stack: Vec<&'a Trie>, +} + +impl<'a, Api> Iterator for Iter<'a, Api> { + type Item = &'a Module; + + fn next(&mut self) -> Option { + loop { + match self.stack.pop()? { + Trie::Branch { children, .. } => { + // Push children onto the stack and start visiting them. We add them in reverse + // order so that we will visit the lexicographically first children first. + self.stack + .extend(children.values().rev().map(|boxed| &**boxed)); + } + Trie::Leaf { module } => return Some(module), + } + } + } +} + +impl<'a, Api> IntoIterator for &'a Trie { + type IntoIter = Iter<'a, Api>; + type Item = &'a Module; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl Index for Trie +where + I: IntoIterator, + I::Item: AsRef, +{ + type Output = Module; + + fn index(&self, index: I) -> &Self::Output { + self.get(index).unwrap() + } +} + +/// Split a path prefix into its segments. +/// +/// Leading and trailing slashes are ignored. That is, `/prefix/` yields only the single segment +/// `prefix`, with no preceding or following empty segments. +pub(crate) fn split(s: &str) -> impl '_ + Iterator { + s.split('/').filter(|seg| !seg.is_empty()) +} + +/// Join two path strings, ensuring there are no leading or trailing slashes. +pub(crate) fn join(s1: &str, s2: &str) -> String { + let s1 = s1.strip_prefix('/').unwrap_or(s1); + let s1 = s1.strip_suffix('/').unwrap_or(s1); + let s2 = s2.strip_prefix('/').unwrap_or(s2); + let s2 = s2.strip_suffix('/').unwrap_or(s2); + if s1.is_empty() { + s2.to_string() + } else if s2.is_empty() { + s1.to_string() + } else { + format!("{s1}/{s2}") + } +} + +#[macro_export] +macro_rules! join { + () => { String::new() }; + ($s:expr) => { $s }; + ($head:expr$(, $($tail:expr),*)?) => { + $crate::dispatch::join($head, &$crate::join!($($($tail),*)?)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_empty_trie() { + let t = Trie::<()>::default(); + assert_eq!(t.iter().next(), None); + assert_eq!(t.get(["mod"]), None); + } + + #[test] + fn test_branch_trie() { + let mut t = Trie::default(); + + let mod_a = Module { + prefix: vec!["mod".into(), "a".into()], + versions: [(0, 0)].into(), + }; + let mod_b = Module { + prefix: vec!["mod".into(), "b".into()], + versions: [(1, 1)].into(), + }; + + t.insert(["mod", "a"], 0, 0).unwrap(); + t.insert(["mod", "b"], 1, 1).unwrap(); + + assert_eq!(t.iter().collect::>(), [&mod_a, &mod_b]); + + assert_eq!(t.search(["mod", "a", "route"]), Some(&mod_a)); + assert_eq!(t.get(["mod", "a"]), Some(&mod_a)); + assert_eq!(t.get(["mod", "a", "route"]), None); + + assert_eq!(t.search(["mod", "b", "route"]), Some(&mod_b)); + assert_eq!(t.get(["mod", "b"]), Some(&mod_b)); + assert_eq!(t.get(["mod", "b", "route"]), None); + + // Cannot register a module which is a prefix or suffix of the already registered modules. + t.insert(["mod"], 0, 0).unwrap_err(); + t.insert(Vec::::new(), 0, 0).unwrap_err(); + t.insert(["mod", "a", "b"], 0, 0).unwrap_err(); + } + + #[test] + fn test_null_prefix() { + let mut t = Trie::default(); + + let module = Module { + prefix: vec![], + versions: [(0, 0)].into(), + }; + t.insert(Vec::::new(), 0, 0).unwrap(); + + assert_eq!(t.iter().collect::>(), [&module]); + assert_eq!(t.search(["anything"]), Some(&module)); + assert_eq!(t.get(Vec::::new()), Some(&module)); + assert_eq!(t.get(["anything"]), None); + + // Any other module has the null module as a prefix and is thus not allowed. + t.insert(["anything"], 1, 1).unwrap_err(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 92f7aa73..b9d562fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -289,6 +289,7 @@ pub mod socket; pub mod status; pub mod testing; +mod dispatch; mod middleware; mod route; diff --git a/src/middleware.rs b/src/middleware.rs index 2040169e..1a306efa 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -99,12 +99,12 @@ where pub(crate) struct MetricsMiddleware { route: String, - api: String, + api: Vec, api_version: u64, } impl MetricsMiddleware { - pub(crate) fn new(route: String, api: String, api_version: u64) -> Self { + pub(crate) fn new(route: String, api: Vec, api_version: u64) -> Self { Self { route, api, @@ -148,7 +148,7 @@ where } // This is a metrics request, abort the rest of the dispatching chain and run the // metrics handler. - let route = &req.state().clone().apis[&api][&version][&route]; + let route = &req.state().clone().modules[&api].versions[&version][&route]; let state = &*req.state().clone().state; let req = request_params(req, route.params()).await?; route