diff --git a/tower-async-http/CHANGELOG.md b/tower-async-http/CHANGELOG.md index f89d8e9..88ce132 100644 --- a/tower-async-http/CHANGELOG.md +++ b/tower-async-http/CHANGELOG.md @@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.1.2 (July 20, 2023) + +Sync with original `tower-http` codebase from [`0.4.1`](https://github.com/tower-rs/tower-http/releases/tag/tower-http-0.4.1) +to [`0.4.3`](https://github.com/tower-rs/tower-http/releases/tag/tower-http-0.4.3). + +## Added + +- **cors:** Add support for private network preflights ([tower-rs/tower-http#373]) +- **compression:** Implement `Default` for `DecompressionBody` ([tower-rs/tower-http#370]) + +## Changed + +- **compression:** Update to async-compression 0.4 ([tower-rs/tower-http#371]) + +## Fixed + +- **compression:** Override default brotli compression level 11 -> 4 ([tower-rs/tower-http#356]) +- **trace:** Simplify dynamic tracing level application ([tower-rs/tower-http#380]) +- **normalize_path:** Fix path normalization for preceding slashes ([tower-rs/tower-http#359]) + +[tower-rs/tower-http#356]: https://github.com/tower-rs/tower-http/pull/356 +[tower-rs/tower-http#359]: https://github.com/tower-rs/tower-http/pull/359 +[tower-rs/tower-http#370]: https://github.com/tower-rs/tower-http/pull/370 +[tower-rs/tower-http#371]: https://github.com/tower-rs/tower-http/pull/371 +[tower-rs/tower-http#373]: https://github.com/tower-rs/tower-http/pull/373 +[tower-rs/tower-http#380]: https://github.com/tower-rs/tower-http/pull/380 + ## 0.1.1 (July 18, 2023) - Improve, expand and fix documentation; diff --git a/tower-async-http/Cargo.toml b/tower-async-http/Cargo.toml index e26edf4..5ec5a14 100644 --- a/tower-async-http/Cargo.toml +++ b/tower-async-http/Cargo.toml @@ -4,7 +4,7 @@ description = """ Tower Async middleware and utilities for HTTP clients and servers. An "Async Trait" fork from the original Tower Library. """ -version = "0.1.1" +version = "0.1.2" authors = ["Glen De Cauwsemaecker "] edition = "2021" license = "MIT" @@ -25,7 +25,7 @@ tower-async-layer = { version = "0.1", path = "../tower-async-layer" } tower-async-service = { version = "0.1", path = "../tower-async-service" } # optional dependencies -async-compression = { version = "0.3", optional = true, features = ["tokio"] } +async-compression = { version = "0.4", optional = true, features = ["tokio"] } base64 = { version = "0.21", optional = true } http-range-header = "0.3.0" iri-string = { version = "0.7.0", optional = true } diff --git a/tower-async-http/src/auth/add_authorization.rs b/tower-async-http/src/auth/add_authorization.rs index 892d214..8a09c19 100644 --- a/tower-async-http/src/auth/add_authorization.rs +++ b/tower-async-http/src/auth/add_authorization.rs @@ -37,12 +37,14 @@ //! # } //! ``` -use base64::{engine::general_purpose::STANDARD as base64, Engine}; +use base64::Engine as _; use http::{HeaderValue, Request, Response}; use std::convert::TryFrom; use tower_async_layer::Layer; use tower_async_service::Service; +const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; + /// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the /// [`Authorization`] header. /// @@ -67,7 +69,7 @@ impl AddAuthorizationLayer { /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS /// with this method. However use of HTTPS/TLS is not enforced by this middleware. pub fn basic(username: &str, password: &str) -> Self { - let encoded = base64.encode(format!("{}:{}", username, password)); + let encoded = BASE64.encode(format!("{}:{}", username, password)); let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap(); Self { value } } diff --git a/tower-async-http/src/auth/require_authorization.rs b/tower-async-http/src/auth/require_authorization.rs index 333bf75..3cef52b 100644 --- a/tower-async-http/src/auth/require_authorization.rs +++ b/tower-async-http/src/auth/require_authorization.rs @@ -52,7 +52,7 @@ //! Custom validation can be made by implementing [`ValidateRequest`]. use crate::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer}; -use base64::{engine::general_purpose::STANDARD as base64, Engine}; +use base64::Engine as _; use http::{ header::{self, HeaderValue}, Request, Response, StatusCode, @@ -60,6 +60,8 @@ use http::{ use http_body::Body; use std::{fmt, marker::PhantomData}; +const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD; + impl ValidateRequestHeader> { /// Authorize requests using a username and password pair. /// @@ -194,7 +196,7 @@ impl Basic { where ResBody: Body + Default, { - let encoded = base64.encode(format!("{}:{}", username, password)); + let encoded = BASE64.encode(format!("{}:{}", username, password)); let header_value = format!("Basic {}", encoded).parse().unwrap(); Self { header_value, @@ -260,7 +262,7 @@ mod tests { let request = Request::get("/") .header( header::AUTHORIZATION, - format!("Basic {}", base64.encode("foo:bar")), + format!("Basic {}", BASE64.encode("foo:bar")), ) .body(Body::empty()) .unwrap(); @@ -279,7 +281,7 @@ mod tests { let request = Request::get("/") .header( header::AUTHORIZATION, - format!("Basic {}", base64.encode("wrong:credentials")), + format!("Basic {}", BASE64.encode("wrong:credentials")), ) .body(Body::empty()) .unwrap(); @@ -317,7 +319,7 @@ mod tests { let request = Request::get("/") .header( header::AUTHORIZATION, - format!("basic {}", base64.encode("foo:bar")), + format!("basic {}", BASE64.encode("foo:bar")), ) .body(Body::empty()) .unwrap(); @@ -336,7 +338,7 @@ mod tests { let request = Request::get("/") .header( header::AUTHORIZATION, - format!("Basic {}", base64.encode("Foo:bar")), + format!("Basic {}", BASE64.encode("Foo:bar")), ) .body(Body::empty()) .unwrap(); diff --git a/tower-async-http/src/compression/body.rs b/tower-async-http/src/compression/body.rs index b09df53..e2bc1ec 100644 --- a/tower-async-http/src/compression/body.rs +++ b/tower-async-http/src/compression/body.rs @@ -228,7 +228,16 @@ where type Output = BrotliEncoder; fn apply(input: Self::Input, quality: CompressionLevel) -> Self::Output { - BrotliEncoder::with_quality(input, quality.into_async_compression()) + // The brotli crate used under the hood here has a default compression level of 11, + // which is the max for brotli. This causes extremely slow compression times, so we + // manually set a default of 4 here. + // + // This is the same default used by NGINX for on-the-fly brotli compression. + let level = match quality { + CompressionLevel::Default => async_compression::Level::Precise(4), + other => other.into_async_compression(), + }; + BrotliEncoder::with_quality(input, level) } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { diff --git a/tower-async-http/src/compression/layer.rs b/tower-async-http/src/compression/layer.rs index 410ba30..62cf254 100644 --- a/tower-async-http/src/compression/layer.rs +++ b/tower-async-http/src/compression/layer.rs @@ -146,7 +146,10 @@ mod tests { #[tokio::test] async fn accept_encoding_configuration_works() -> Result<(), crate::BoxError> { - let deflate_only_layer = CompressionLayer::new().no_br().no_gzip(); + let deflate_only_layer = CompressionLayer::new() + .quality(CompressionLevel::Best) + .no_br() + .no_gzip(); let mut service = ServiceBuilder::new() // Compress responses based on the `Accept-Encoding` header. @@ -173,7 +176,10 @@ mod tests { let deflate_bytes_len = bytes.len(); - let br_only_layer = CompressionLayer::new().no_gzip().no_deflate(); + let br_only_layer = CompressionLayer::new() + .quality(CompressionLevel::Best) + .no_gzip() + .no_deflate(); let mut service = ServiceBuilder::new() // Compress responses based on the `Accept-Encoding` header. diff --git a/tower-async-http/src/compression_utils.rs b/tower-async-http/src/compression_utils.rs index 5535a67..5ca8d98 100644 --- a/tower-async-http/src/compression_utils.rs +++ b/tower-async-http/src/compression_utils.rs @@ -359,7 +359,9 @@ impl CompressionLevel { CompressionLevel::Fastest => AsyncCompressionLevel::Fastest, CompressionLevel::Best => AsyncCompressionLevel::Best, CompressionLevel::Default => AsyncCompressionLevel::Default, - CompressionLevel::Precise(quality) => AsyncCompressionLevel::Precise(quality), + CompressionLevel::Precise(quality) => { + AsyncCompressionLevel::Precise(quality.try_into().unwrap_or(i32::MAX)) + } } } } diff --git a/tower-async-http/src/cors/allow_credentials.rs b/tower-async-http/src/cors/allow_credentials.rs index cbd4f55..de53ffe 100644 --- a/tower-async-http/src/cors/allow_credentials.rs +++ b/tower-async-http/src/cors/allow_credentials.rs @@ -27,6 +27,8 @@ impl AllowCredentials { /// Allow credentials for some requests, based on a given predicate /// + /// The first argument to the predicate is the request origin. + /// /// See [`CorsLayer::allow_credentials`] for more details. /// /// [`CorsLayer::allow_credentials`]: super::CorsLayer::allow_credentials diff --git a/tower-async-http/src/cors/allow_private_network.rs b/tower-async-http/src/cors/allow_private_network.rs new file mode 100644 index 0000000..b9847bd --- /dev/null +++ b/tower-async-http/src/cors/allow_private_network.rs @@ -0,0 +1,196 @@ +use std::{fmt, sync::Arc}; + +use http::{ + header::{HeaderName, HeaderValue}, + request::Parts as RequestParts, +}; + +/// Holds configuration for how to set the [`Access-Control-Allow-Private-Network`][wicg] header. +/// +/// See [`CorsLayer::allow_private_network`] for more details. +/// +/// [wicg]: https://wicg.github.io/private-network-access/ +/// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network +#[derive(Clone, Default)] +#[must_use] +pub struct AllowPrivateNetwork(AllowPrivateNetworkInner); + +static TRUE: HeaderValue = HeaderValue::from_static("true"); + +impl AllowPrivateNetwork { + /// Allow requests via a more private network than the one used to access the origin + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn yes() -> Self { + Self(AllowPrivateNetworkInner::Yes) + } + + /// Allow requests via private network for some requests, based on a given predicate + /// + /// The first argument to the predicate is the request origin. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [`CorsLayer::allow_private_network`]: super::CorsLayer::allow_private_network + pub fn predicate(f: F) -> Self + where + F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static, + { + Self(AllowPrivateNetworkInner::Predicate(Arc::new(f))) + } + + pub(super) fn to_header( + &self, + origin: Option<&HeaderValue>, + parts: &RequestParts, + ) -> Option<(HeaderName, HeaderValue)> { + #[allow(clippy::declare_interior_mutable_const)] + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + #[allow(clippy::declare_interior_mutable_const)] + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + // Cheapest fallback: allow_private_network hasn't been set + if let AllowPrivateNetworkInner::No = &self.0 { + return None; + } + + // Access-Control-Allow-Private-Network is only relevant if the request + // has the Access-Control-Request-Private-Network header set, else skip + if parts.headers.get(REQUEST_PRIVATE_NETWORK) != Some(&TRUE) { + return None; + } + + let allow_private_network = match &self.0 { + AllowPrivateNetworkInner::Yes => true, + AllowPrivateNetworkInner::No => false, // unreachable, but not harmful + AllowPrivateNetworkInner::Predicate(c) => c(origin?, parts), + }; + + allow_private_network.then(|| (ALLOW_PRIVATE_NETWORK, TRUE.clone())) + } +} + +impl From for AllowPrivateNetwork { + fn from(v: bool) -> Self { + match v { + true => Self(AllowPrivateNetworkInner::Yes), + false => Self(AllowPrivateNetworkInner::No), + } + } +} + +impl fmt::Debug for AllowPrivateNetwork { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.0 { + AllowPrivateNetworkInner::Yes => f.debug_tuple("Yes").finish(), + AllowPrivateNetworkInner::No => f.debug_tuple("No").finish(), + AllowPrivateNetworkInner::Predicate(_) => f.debug_tuple("Predicate").finish(), + } + } +} + +#[derive(Clone)] +enum AllowPrivateNetworkInner { + Yes, + No, + Predicate( + Arc Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>, + ), +} + +impl Default for AllowPrivateNetworkInner { + fn default() -> Self { + Self::No + } +} + +#[cfg(test)] +mod tests { + use super::AllowPrivateNetwork; + use crate::cors::CorsLayer; + + use http::{header::ORIGIN, request::Parts, HeaderName, HeaderValue, Request, Response}; + use hyper::Body; + use tower_async::{BoxError, ServiceBuilder}; + use tower_async_service::Service; + + const REQUEST_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-request-private-network"); + + const ALLOW_PRIVATE_NETWORK: HeaderName = + HeaderName::from_static("access-control-allow-private-network"); + + const TRUE: HeaderValue = HeaderValue::from_static("true"); + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly() { + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(true)) + .service_fn(echo); + + let req = Request::builder() + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .body(Body::empty()) + .unwrap(); + let res = service.call(req).await.unwrap(); + + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder().body(Body::empty()).unwrap(); + let res = service.call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + #[tokio::test] + async fn cors_private_network_header_is_added_correctly_with_predicate() { + let allow_private_network = + AllowPrivateNetwork::predicate(|origin: &HeaderValue, parts: &Parts| { + parts.uri.path() == "/allow-private" && origin == "localhost" + }); + let mut service = ServiceBuilder::new() + .layer(CorsLayer::new().allow_private_network(allow_private_network)) + .service_fn(echo); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.call(req).await.unwrap(); + assert_eq!(res.headers().get(ALLOW_PRIVATE_NETWORK).unwrap(), TRUE); + + let req = Request::builder() + .header(ORIGIN, "localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/other") + .body(Body::empty()) + .unwrap(); + + let res = service.call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + + let req = Request::builder() + .header(ORIGIN, "not-localhost") + .header(REQUEST_PRIVATE_NETWORK, TRUE) + .uri("/allow-private") + .body(Body::empty()) + .unwrap(); + + let res = service.call(req).await.unwrap(); + + assert!(res.headers().get(ALLOW_PRIVATE_NETWORK).is_none()); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/tower-async-http/src/cors/mod.rs b/tower-async-http/src/cors/mod.rs index 769db7c..1ccb7a0 100644 --- a/tower-async-http/src/cors/mod.rs +++ b/tower-async-http/src/cors/mod.rs @@ -60,13 +60,15 @@ mod allow_credentials; mod allow_headers; mod allow_methods; mod allow_origin; +mod allow_private_network; mod expose_headers; mod max_age; mod vary; pub use self::{ allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods, - allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, + allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork, + expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary, }; /// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn]. @@ -81,6 +83,7 @@ pub struct CorsLayer { allow_headers: AllowHeaders, allow_methods: AllowMethods, allow_origin: AllowOrigin, + allow_private_network: AllowPrivateNetwork, expose_headers: ExposeHeaders, max_age: MaxAge, vary: Vary, @@ -103,6 +106,7 @@ impl CorsLayer { allow_headers: Default::default(), allow_methods: Default::default(), allow_origin: Default::default(), + allow_private_network: Default::default(), expose_headers: Default::default(), max_age: Default::default(), vary: Default::default(), @@ -351,6 +355,23 @@ impl CorsLayer { self } + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// ``` + /// use tower_async_http::cors::CorsLayer; + /// + /// let layer = CorsLayer::new().allow_private_network(true); + /// ``` + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network(mut self, allow_private_network: T) -> Self + where + T: Into, + { + self.allow_private_network = allow_private_network.into(); + self + } + /// Set the value(s) of the [`Vary`][mdn] header. /// /// In contrast to the other headers, this one has a non-empty default of @@ -545,6 +566,18 @@ impl Cors { self.map_layer(|layer| layer.expose_headers(headers)) } + /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header. + /// + /// See [`CorsLayer::allow_private_network`] for more details. + /// + /// [wicg]: https://wicg.github.io/private-network-access/ + pub fn allow_private_network(self, allow_private_network: T) -> Self + where + T: Into, + { + self.map_layer(|layer| layer.allow_private_network(allow_private_network)) + } + fn map_layer(mut self, f: F) -> Self where F: FnOnce(CorsLayer) -> CorsLayer, @@ -573,6 +606,7 @@ where headers.extend(self.layer.allow_origin.to_header(origin, &parts)); headers.extend(self.layer.allow_credentials.to_header(origin, &parts)); + headers.extend(self.layer.allow_private_network.to_header(origin, &parts)); let mut vary_headers = self.layer.vary.values(); if let Some(first) = vary_headers.next() { diff --git a/tower-async-http/src/decompression/body.rs b/tower-async-http/src/decompression/body.rs index 5cd7f28..1fa4248 100644 --- a/tower-async-http/src/decompression/body.rs +++ b/tower-async-http/src/decompression/body.rs @@ -36,6 +36,19 @@ pin_project! { } } +impl Default for DecompressionBody +where + B: Body + Default, +{ + fn default() -> Self { + Self { + inner: BodyInner::Identity { + inner: B::default(), + }, + } + } +} + impl DecompressionBody where B: Body, diff --git a/tower-async-http/src/normalize_path.rs b/tower-async-http/src/normalize_path.rs index d96630e..b6927fd 100644 --- a/tower-async-http/src/normalize_path.rs +++ b/tower-async-http/src/normalize_path.rs @@ -99,16 +99,20 @@ where } fn normalize_trailing_slash(uri: &mut Uri) { - if !uri.path().ends_with('/') { + if !uri.path().ends_with('/') && !uri.path().starts_with("//") { return; } - let new_path = uri.path().trim_end_matches('/'); + let new_path = format!("/{}", uri.path().trim_matches('/')); let mut parts = uri.clone().into_parts(); let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { - let new_path = if new_path.is_empty() { "/" } else { new_path }; + let new_path = if new_path.is_empty() { + "/" + } else { + new_path.as_str() + }; let new_path_and_query = if let Some(query) = path_and_query.query() { Cow::Owned(format!("{}?{}", new_path, query)) @@ -202,4 +206,18 @@ mod tests { normalize_trailing_slash(&mut uri); assert_eq!(uri, "/?a=a"); } + + #[test] + fn removes_multiple_preceding_slashes_even_with_query() { + let mut uri = "///foo//?a=a".parse::().unwrap(); + normalize_trailing_slash(&mut uri); + assert_eq!(uri, "/foo?a=a"); + } + + #[test] + fn removes_multiple_preceding_slashes() { + let mut uri = "///foo".parse::().unwrap(); + normalize_trailing_slash(&mut uri); + assert_eq!(uri, "/foo"); + } } diff --git a/tower-async-http/src/timeout/mod.rs b/tower-async-http/src/timeout/mod.rs index 82fa9ea..940958c 100644 --- a/tower-async-http/src/timeout/mod.rs +++ b/tower-async-http/src/timeout/mod.rs @@ -1,4 +1,43 @@ -//! Middleware for setting timeouts on requests and responses. +//! Middleware that applies a timeout to requests. +//! +//! If the request does not complete within the specified timeout it will be aborted and a `408 +//! Request Timeout` response will be sent. +//! +//! # Differences from `tower::timeout` +//! +//! tower's [`Timeout`](tower_async::timeout::Timeout) middleware uses an error to signal timeout, i.e. +//! it changes the error type to [`BoxError`](tower_async::BoxError). For HTTP services that is rarely +//! what you want as returning errors will terminate the connection without sending a response. +//! +//! This middleware won't change the error type and instead return a `408 Request Timeout` +//! response. That means if your service's error type is [`Infallible`] it will still be +//! [`Infallible`] after applying this middleware. +//! +//! # Example +//! +//! ``` +//! use http::{Request, Response}; +//! use hyper::Body; +//! use std::{convert::Infallible, time::Duration}; +//! use tower_async::ServiceBuilder; +//! use tower_async_http::timeout::TimeoutLayer; +//! +//! async fn handle(_: Request) -> Result, Infallible> { +//! // ... +//! # Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let svc = ServiceBuilder::new() +//! // Timeout requests after 30 seconds +//! .layer(TimeoutLayer::new(Duration::from_secs(30))) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! [`Infallible`]: std::convert::Infallible mod service; diff --git a/tower-async-http/src/timeout/service.rs b/tower-async-http/src/timeout/service.rs index 3f41287..0101248 100644 --- a/tower-async-http/src/timeout/service.rs +++ b/tower-async-http/src/timeout/service.rs @@ -1,50 +1,11 @@ -//! Middleware that applies a timeout to requests. -//! -//! If the request does not complete within the specified timeout it will be aborted and a `408 -//! Request Timeout` response will be sent. -//! -//! # Differences from `tower_async::timeout` -//! -//! tower's [`Timeout`](tower_async::timeout::Timeout) middleware uses an error to signal timeout, i.e. -//! it changes the error type to [`BoxError`](tower_async::BoxError). For HTTP services that is rarely -//! what you want as returning errors will terminate the connection without sending a response. -//! -//! This middleware won't change the error type and instead return a `408 Request Timeout` -//! response. That means if your service's error type is [`Infallible`] it will still be -//! [`Infallible`] after applying this middleware. -//! -//! # Example -//! -//! ``` -//! use http::{Request, Response}; -//! use hyper::Body; -//! use std::{convert::Infallible, time::Duration}; -//! use tower_async::ServiceBuilder; -//! use tower_async_http::timeout::TimeoutLayer; -//! -//! async fn handle(_: Request) -> Result, Infallible> { -//! // ... -//! # Ok(Response::new(Body::empty())) -//! } -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { -//! let svc = ServiceBuilder::new() -//! // Timeout requests after 30 seconds -//! .layer(TimeoutLayer::new(Duration::from_secs(30))) -//! .service_fn(handle); -//! # Ok(()) -//! # } -//! ``` -//! -//! [`Infallible`]: std::convert::Infallible - use http::{Request, Response, StatusCode}; use std::time::Duration; use tower_async_layer::Layer; use tower_async_service::Service; /// Layer that applies the [`Timeout`] middleware which apply a timeout to requests. +/// +/// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct TimeoutLayer { timeout: Duration, @@ -69,6 +30,8 @@ impl Layer for TimeoutLayer { /// /// If the request does not complete within the specified timeout it will be aborted and a `408 /// Request Timeout` response will be sent. +/// +/// See the [module docs](super) for an example. #[derive(Debug, Clone, Copy)] pub struct Timeout { inner: S, diff --git a/tower-async-http/src/trace/make_span.rs b/tower-async-http/src/trace/make_span.rs index 904d941..bf558d3 100644 --- a/tower-async-http/src/trace/make_span.rs +++ b/tower-async-http/src/trace/make_span.rs @@ -103,21 +103,11 @@ impl MakeSpan for DefaultMakeSpan { } match self.level { - Level::ERROR => { - make_span!(Level::ERROR) - } - Level::WARN => { - make_span!(Level::WARN) - } - Level::INFO => { - make_span!(Level::INFO) - } - Level::DEBUG => { - make_span!(Level::DEBUG) - } - Level::TRACE => { - make_span!(Level::TRACE) - } + Level::ERROR => make_span!(Level::ERROR), + Level::WARN => make_span!(Level::WARN), + Level::INFO => make_span!(Level::INFO), + Level::DEBUG => make_span!(Level::DEBUG), + Level::TRACE => make_span!(Level::TRACE), } } } diff --git a/tower-async-http/src/trace/mod.rs b/tower-async-http/src/trace/mod.rs index f4e0bd7..29dccba 100644 --- a/tower-async-http/src/trace/mod.rs +++ b/tower-async-http/src/trace/mod.rs @@ -375,6 +375,8 @@ //! [`Body::poll_trailers`]: http_body::Body::poll_trailers //! [`Body::poll_data`]: http_body::Body::poll_data +use std::{fmt, time::Duration}; + use tracing::Level; pub use self::{ @@ -389,6 +391,55 @@ pub use self::{ service::Trace, }; +use crate::LatencyUnit; + +macro_rules! event_dynamic_lvl { + ( $(target: $target:expr,)? $(parent: $parent:expr,)? $lvl:expr, $($tt:tt)* ) => { + match $lvl { + tracing::Level::ERROR => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::ERROR, + $($tt)* + ); + } + tracing::Level::WARN => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::WARN, + $($tt)* + ); + } + tracing::Level::INFO => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::INFO, + $($tt)* + ); + } + tracing::Level::DEBUG => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::DEBUG, + $($tt)* + ); + } + tracing::Level::TRACE => { + tracing::event!( + $(target: $target,)? + $(parent: $parent,)? + tracing::Level::TRACE, + $($tt)* + ); + } + } + }; +} + mod body; mod layer; mod make_span; @@ -402,6 +453,22 @@ mod service; const DEFAULT_MESSAGE_LEVEL: Level = Level::DEBUG; const DEFAULT_ERROR_LEVEL: Level = Level::ERROR; +struct Latency { + unit: LatencyUnit, + duration: Duration, +} + +impl fmt::Display for Latency { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.unit { + LatencyUnit::Seconds => write!(f, "{} s", self.duration.as_secs_f64()), + LatencyUnit::Millis => write!(f, "{} ms", self.duration.as_millis()), + LatencyUnit::Micros => write!(f, "{} μs", self.duration.as_micros()), + LatencyUnit::Nanos => write!(f, "{} ns", self.duration.as_nanos()), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/tower-async-http/src/trace/on_eos.rs b/tower-async-http/src/trace/on_eos.rs index bebe6db..ab90fc9 100644 --- a/tower-async-http/src/trace/on_eos.rs +++ b/tower-async-http/src/trace/on_eos.rs @@ -1,4 +1,4 @@ -use super::DEFAULT_MESSAGE_LEVEL; +use super::{Latency, DEFAULT_MESSAGE_LEVEL}; use crate::{classify::grpc_errors_as_failures::ParsedGrpcStatus, LatencyUnit}; use http::header::HeaderMap; use std::time::Duration; @@ -83,88 +83,12 @@ impl DefaultOnEos { } } -// Repeating this pattern match for each case is tedious. So we do it with a quick and -// dirty macro. -// -// Tracing requires all these parts to be declared statically. You cannot easily build -// events dynamically. -#[allow(unused_macros)] -macro_rules! log_pattern_match { - ( - $this:expr, $stream_duration:expr, $status:expr, [$($level:ident),*] - ) => { - match ($this.level, $this.latency_unit, $status) { - $( - (Level::$level, LatencyUnit::Seconds, None) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} s", $stream_duration.as_secs_f64()), - "end of stream" - ); - } - (Level::$level, LatencyUnit::Seconds, Some(status)) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} s", $stream_duration.as_secs_f64()), - status = status, - "end of stream" - ); - } - - (Level::$level, LatencyUnit::Millis, None) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} ms", $stream_duration.as_millis()), - "end of stream" - ); - } - (Level::$level, LatencyUnit::Millis, Some(status)) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} ms", $stream_duration.as_millis()), - status = status, - "end of stream" - ); - } - - (Level::$level, LatencyUnit::Micros, None) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} μs", $stream_duration.as_micros()), - "end of stream" - ); - } - (Level::$level, LatencyUnit::Micros, Some(status)) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} μs", $stream_duration.as_micros()), - status = status, - "end of stream" - ); - } - - (Level::$level, LatencyUnit::Nanos, None) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} ns", $stream_duration.as_nanos()), - "end of stream" - ); - } - (Level::$level, LatencyUnit::Nanos, Some(status)) => { - tracing::event!( - Level::$level, - stream_duration = format_args!("{} ns", $stream_duration.as_nanos()), - status = status, - "end of stream" - ); - } - )* - } - }; -} - impl OnEos for DefaultOnEos { fn on_eos(self, trailers: Option<&HeaderMap>, stream_duration: Duration, _span: &Span) { + let stream_duration = Latency { + unit: self.latency_unit, + duration: stream_duration, + }; let status = trailers.and_then(|trailers| { match crate::classify::grpc_errors_as_failures::classify_grpc_metadata( trailers, @@ -178,11 +102,6 @@ impl OnEos for DefaultOnEos { } }); - log_pattern_match!( - self, - stream_duration, - status, - [ERROR, WARN, INFO, DEBUG, TRACE] - ); + event_dynamic_lvl!(self.level, %stream_duration, status, "end of stream"); } } diff --git a/tower-async-http/src/trace/on_failure.rs b/tower-async-http/src/trace/on_failure.rs index 42fb100..7dfa186 100644 --- a/tower-async-http/src/trace/on_failure.rs +++ b/tower-async-http/src/trace/on_failure.rs @@ -1,4 +1,4 @@ -use super::DEFAULT_ERROR_LEVEL; +use super::{Latency, DEFAULT_ERROR_LEVEL}; use crate::LatencyUnit; use std::{fmt, time::Duration}; use tracing::{Level, Span}; @@ -81,64 +81,20 @@ impl DefaultOnFailure { } } -// Repeating this pattern match for each case is tedious. So we do it with a quick and -// dirty macro. -// -// Tracing requires all these parts to be declared statically. You cannot easily build -// events dynamically. -macro_rules! log_pattern_match { - ( - $this:expr, $failure_classification:expr, $latency:expr, [$($level:ident),*] - ) => { - match ($this.level, $this.latency_unit) { - $( - (Level::$level, LatencyUnit::Seconds) => { - tracing::event!( - Level::$level, - classification = tracing::field::display($failure_classification), - latency = format_args!("{} s", $latency.as_secs_f64()), - "response failed" - ); - } - (Level::$level, LatencyUnit::Millis) => { - tracing::event!( - Level::$level, - classification = tracing::field::display($failure_classification), - latency = format_args!("{} ms", $latency.as_millis()), - "response failed" - ); - } - (Level::$level, LatencyUnit::Micros) => { - tracing::event!( - Level::$level, - classification = tracing::field::display($failure_classification), - latency = format_args!("{} μs", $latency.as_micros()), - "response failed" - ); - } - (Level::$level, LatencyUnit::Nanos) => { - tracing::event!( - Level::$level, - classification = tracing::field::display($failure_classification), - latency = format_args!("{} ns", $latency.as_nanos()), - "response failed" - ); - } - )* - } - }; -} - impl OnFailure for DefaultOnFailure where FailureClass: fmt::Display, { fn on_failure(&mut self, failure_classification: FailureClass, latency: Duration, _: &Span) { - log_pattern_match!( - self, - &failure_classification, - latency, - [ERROR, WARN, INFO, DEBUG, TRACE] + let latency = Latency { + unit: self.latency_unit, + duration: latency, + }; + event_dynamic_lvl!( + self.level, + classification = %failure_classification, + %latency, + "response failed" ); } } diff --git a/tower-async-http/src/trace/on_request.rs b/tower-async-http/src/trace/on_request.rs index 0e343ae..07de189 100644 --- a/tower-async-http/src/trace/on_request.rs +++ b/tower-async-http/src/trace/on_request.rs @@ -77,22 +77,6 @@ impl DefaultOnRequest { impl OnRequest for DefaultOnRequest { fn on_request(&mut self, _: &Request, _: &Span) { - match self.level { - Level::ERROR => { - tracing::event!(Level::ERROR, "started processing request"); - } - Level::WARN => { - tracing::event!(Level::WARN, "started processing request"); - } - Level::INFO => { - tracing::event!(Level::INFO, "started processing request"); - } - Level::DEBUG => { - tracing::event!(Level::DEBUG, "started processing request"); - } - Level::TRACE => { - tracing::event!(Level::TRACE, "started processing request"); - } - } + event_dynamic_lvl!(self.level, "started processing request"); } } diff --git a/tower-async-http/src/trace/on_response.rs b/tower-async-http/src/trace/on_response.rs index 573de2e..c6ece84 100644 --- a/tower-async-http/src/trace/on_response.rs +++ b/tower-async-http/src/trace/on_response.rs @@ -1,4 +1,4 @@ -use super::DEFAULT_MESSAGE_LEVEL; +use super::{Latency, DEFAULT_MESSAGE_LEVEL}; use crate::LatencyUnit; use http::Response; use std::time::Duration; @@ -101,162 +101,22 @@ impl DefaultOnResponse { } } -// Repeating this pattern match for each case is tedious. So we do it with a quick and -// dirty macro. -// -// Tracing requires all these parts to be declared statically. You cannot easily build -// events dynamically. -#[allow(unused_macros)] -macro_rules! log_pattern_match { - ( - $this:expr, $res:expr, $latency:expr, $include_headers:expr, [$($level:ident),*] - ) => { - match ($this.level, $include_headers, $this.latency_unit, status($res)) { - $( - (Level::$level, true, LatencyUnit::Seconds, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} s", $latency.as_secs_f64()), - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Seconds, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} s", $latency.as_secs_f64()), - "finished processing request" - ); - } - (Level::$level, true, LatencyUnit::Seconds, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} s", $latency.as_secs_f64()), - status = status, - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Seconds, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} s", $latency.as_secs_f64()), - status = status, - "finished processing request" - ); - } - - (Level::$level, true, LatencyUnit::Millis, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ms", $latency.as_millis()), - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Millis, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ms", $latency.as_millis()), - "finished processing request" - ); - } - (Level::$level, true, LatencyUnit::Millis, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ms", $latency.as_millis()), - status = status, - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Millis, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ms", $latency.as_millis()), - status = status, - "finished processing request" - ); - } - - (Level::$level, true, LatencyUnit::Micros, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} μs", $latency.as_micros()), - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Micros, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} μs", $latency.as_micros()), - "finished processing request" - ); - } - (Level::$level, true, LatencyUnit::Micros, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} μs", $latency.as_micros()), - status = status, - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Micros, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} μs", $latency.as_micros()), - status = status, - "finished processing request" - ); - } - - (Level::$level, true, LatencyUnit::Nanos, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ns", $latency.as_nanos()), - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Nanos, None) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ns", $latency.as_nanos()), - "finished processing request" - ); - } - (Level::$level, true, LatencyUnit::Nanos, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ns", $latency.as_nanos()), - status = status, - response_headers = ?$res.headers(), - "finished processing request" - ); - } - (Level::$level, false, LatencyUnit::Nanos, Some(status)) => { - tracing::event!( - Level::$level, - latency = format_args!("{} ns", $latency.as_nanos()), - status = status, - "finished processing request" - ); - } - )* - } - }; -} - impl OnResponse for DefaultOnResponse { fn on_response(self, response: &Response, latency: Duration, _: &Span) { - log_pattern_match!( - self, - response, - latency, - self.include_headers, - [ERROR, WARN, INFO, DEBUG, TRACE] + let latency = Latency { + unit: self.latency_unit, + duration: latency, + }; + let response_headers = self + .include_headers + .then(|| tracing::field::debug(response.headers())); + + event_dynamic_lvl!( + self.level, + %latency, + status = status(response), + response_headers, + "finished processing request" ); } }