diff --git a/examples/http_connect_proxy.rs b/examples/http_connect_proxy.rs index 4d0964a6..2c41a6e5 100644 --- a/examples/http_connect_proxy.rs +++ b/examples/http_connect_proxy.rs @@ -71,7 +71,7 @@ use rama::{ service::web::{extract::Path, match_service}, Body, IntoResponse, Request, RequestContext, Response, StatusCode, }, - net::stream::layer::http::BodyLimitLayer, + net::{stream::layer::http::BodyLimitLayer, user::Basic}, rt::Executor, service::{ context::Extensions, layer::HijackLayer, service_fn, Context, Service, ServiceBuilder, @@ -116,7 +116,7 @@ async fn main() { .layer(TraceLayer::new_for_http()) // See [`ProxyAuthLayer::with_labels`] for more information, // e.g. can also be used to extract upstream proxy filters - .layer(ProxyAuthLayer::basic(("john", "secret")).with_labels::<(PriorityUsernameLabelParser, UsernameOpaqueLabelParser)>()) + .layer(ProxyAuthLayer::new(Basic::new("john", "secret")).with_labels::<(PriorityUsernameLabelParser, UsernameOpaqueLabelParser)>()) // example of how one might insert an API layer into their proxy .layer(HijackLayer::new( DomainMatcher::new("echo.example.internal"), diff --git a/examples/https_connect_proxy.rs b/examples/https_connect_proxy.rs index 23d38bad..00f8f09d 100644 --- a/examples/https_connect_proxy.rs +++ b/examples/https_connect_proxy.rs @@ -35,7 +35,7 @@ use rama::{ server::HttpServer, Body, IntoResponse, Request, RequestContext, Response, StatusCode, }, - net::stream::layer::http::BodyLimitLayer, + net::{stream::layer::http::BodyLimitLayer, user::Basic}, rt::Executor, service::{service_fn, Context, Service, ServiceBuilder}, tcp::{server::TcpListener, utils::is_connection_error}, @@ -133,7 +133,7 @@ async fn main() { .layer(TraceLayer::new_for_http()) // See [`ProxyAuthLayer::with_labels`] for more information, // e.g. can also be used to extract upstream proxy filters - .layer(ProxyAuthLayer::basic(("john", "secret"))) + .layer(ProxyAuthLayer::new(Basic::new("john", "secret"))) .layer(UpgradeLayer::new( MethodMatcher::CONNECT, service_fn(http_connect_accept), diff --git a/src/http/layer/proxy_auth/auth.rs b/src/http/layer/proxy_auth/auth.rs index e863ace9..a750bd73 100644 --- a/src/http/layer/proxy_auth/auth.rs +++ b/src/http/layer/proxy_auth/auth.rs @@ -1,8 +1,6 @@ use crate::{ - http::headers::{ - authorization::{Basic, Credentials}, - Authorization, - }, + http::headers::authorization::Credentials, + net::user::{Basic, UserId}, service::context::Extensions, utils::username::{parse_username, UsernameLabelParser, DEFAULT_USERNAME_LABEL_SEPARATOR}, }; @@ -47,7 +45,7 @@ where impl ProxyAuthoritySync for Basic { fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool { if self == credentials { - ext.insert(self.clone()); + ext.insert(UserId::Username(self.username().to_owned())); true } else { false @@ -75,7 +73,7 @@ impl ProxyAuthoritySync for Basic { Err(err) => { tracing::trace!("failed to parse username: {:?}", err); return if self == credentials { - ext.insert(self.clone()); + ext.insert(UserId::Username(username.to_owned())); true } else { false @@ -88,105 +86,7 @@ impl ProxyAuthoritySync for Basic { } ext.extend(parser_ext); - ext.insert(Authorization::basic(username.as_str(), password).0); - true - } -} - -impl ProxyAuthoritySync for (&'static str, &'static str) { - fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool { - if self.0 == credentials.username() && self.1 == credentials.password() { - ext.insert(Authorization::basic(self.0, self.1).0); - true - } else { - false - } - } -} - -impl ProxyAuthoritySync for (&'static str, &'static str) { - fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool { - let username = credentials.username(); - let password = credentials.password(); - - if password != self.1 { - return false; - } - - let mut parser_ext = Extensions::new(); - let username = match parse_username( - &mut parser_ext, - T::default(), - username, - DEFAULT_USERNAME_LABEL_SEPARATOR, - ) { - Ok(t) => t, - Err(err) => { - tracing::trace!("failed to parse username: {:?}", err); - return if self.0 == credentials.username() && self.1 == credentials.password() { - ext.insert(Authorization::basic(self.0, self.1).0); - true - } else { - false - }; - } - }; - - if username != self.0 { - return false; - } - - ext.extend(parser_ext); - ext.insert(Authorization::basic(username.as_str(), password).0); - true - } -} - -impl ProxyAuthoritySync for (String, String) { - fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool { - if self.0 == credentials.username() && self.1 == credentials.password() { - ext.insert(Authorization::basic(self.0.as_str(), self.1.as_str()).0); - true - } else { - false - } - } -} - -impl ProxyAuthoritySync for (String, String) { - fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool { - let username = credentials.username(); - let password = credentials.password(); - - if password != self.1 { - return false; - } - - let mut parser_ext = Extensions::new(); - let username = match parse_username( - &mut parser_ext, - T::default(), - username, - DEFAULT_USERNAME_LABEL_SEPARATOR, - ) { - Ok(t) => t, - Err(err) => { - tracing::trace!("failed to parse username: {:?}", err); - return if self.0 == credentials.username() && self.1 == credentials.password() { - ext.insert(Authorization::basic(self.0.as_str(), self.1.as_str()).0); - true - } else { - false - }; - } - }; - - if username != self.0 { - return false; - } - - ext.extend(parser_ext); - ext.insert(Authorization::basic(username.as_str(), password).0); + ext.insert(UserId::Username(username)); true } } @@ -215,38 +115,35 @@ where mod test { use super::*; use crate::{ - http::headers::{authorization::Basic, Authorization}, + net::user::Basic, proxy::{ProxyFilter, ProxyFilterUsernameParser}, utils::username::{UsernameLabels, UsernameOpaqueLabelParser}, }; #[tokio::test] async fn basic_authorization() { - let Authorization(auth) = Authorization::basic("Aladdin", "open sesame"); - let auths = vec![Authorization::basic("foo", "bar").0, auth.clone()]; - let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone()) + let auth = Basic::new("Aladdin", "open sesame"); + let auths = vec![Basic::new("foo", "bar"), auth.clone()]; + let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth) .await .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); + let user: &UserId = ext.get().unwrap(); + assert_eq!(user, "Aladdin"); } #[tokio::test] async fn basic_authorization_with_filter_found() { - let Authorization(auth) = Authorization::basic("john", "secret"); - let auths = vec![ - Authorization::basic("foo", "bar").0, - Authorization::basic("john", "secret").0, - ]; + let auths = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")]; let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized( &auths, - Authorization::basic("john-country-us", "secret").0, + Basic::new("john-country-us", "secret"), ) .await .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); + + let c: &UserId = ext.get().unwrap(); + assert_eq!(c, "john"); let filter: &ProxyFilter = ext.get().unwrap(); assert_eq!(filter.country, Some(vec!["us".into()])); @@ -254,20 +151,17 @@ mod test { #[tokio::test] async fn basic_authorization_with_labels_found() { - let Authorization(auth) = Authorization::basic("john", "secret"); - let auths = vec![ - Authorization::basic("foo", "bar").0, - Authorization::basic("john", "secret").0, - ]; + let auths = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")]; let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized( &auths, - Authorization::basic("john-green-red", "secret").0, + Basic::new("john-green-red", "secret"), ) .await .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); + + let c: &UserId = ext.get().unwrap(); + assert_eq!(c, "john"); let labels: &UsernameLabels = ext.get().unwrap(); assert_eq!(&labels.0, &vec!["green".to_owned(), "red".to_owned()]); @@ -275,79 +169,31 @@ mod test { #[tokio::test] async fn basic_authorization_with_filter_not_found() { - let Authorization(auth) = Authorization::basic("john", "secret"); - let auths = vec![ - Authorization::basic("foo", "bar").0, - Authorization::basic("john", "secret").0, - ]; + let auth = Basic::new("john", "secret"); + let auths = vec![Basic::new("foo", "bar"), auth.clone()]; - let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized(&auths, auth.clone()) + let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized(&auths, auth) .await .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); + + let c: &UserId = ext.get().unwrap(); + assert_eq!(c, "john"); assert!(ext.get::().is_none()); } #[tokio::test] async fn basic_authorization_with_labels_not_found() { - let Authorization(auth) = Authorization::basic("john", "secret"); - let auths = vec![ - Authorization::basic("foo", "bar").0, - Authorization::basic("john", "secret").0, - ]; + let auth = Basic::new("john", "secret"); + let auths = vec![Basic::new("foo", "bar"), auth.clone()]; - let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized(&auths, auth.clone()) + let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized(&auths, auth) .await .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); - assert!(ext.get::().is_none()); - } - - #[tokio::test] - async fn basic_authorization_tuple() { - let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")]; - let Authorization(auth) = Authorization::basic("Aladdin", "open sesame"); - let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone()) - .await - .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); - } + let c: &UserId = ext.get().unwrap(); + assert_eq!(c, "john"); - #[tokio::test] - async fn basic_authorization_tuple_no_auth_username() { - let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")]; - let Authorization(auth) = Authorization::basic("bax", "qux"); - assert!(ProxyAuthority::<_, ()>::authorized(&auths, auth.clone()) - .await - .is_none()); - } - - #[tokio::test] - async fn basic_authorization_tuple_no_auth_password() { - let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")]; - let Authorization(auth) = Authorization::basic("baz", "quc"); - assert!(ProxyAuthority::<_, ()>::authorized(&auths, auth.clone()) - .await - .is_none()) - } - - #[tokio::test] - async fn basic_authorization_tuple_string() { - let auths = vec![ - ("foo".to_owned(), "bar".to_owned()), - ("Aladdin".to_owned(), "open sesame".to_owned()), - ("baz".to_owned(), "qux".to_owned()), - ]; - let Authorization(auth) = Authorization::basic("Aladdin", "open sesame"); - let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone()) - .await - .unwrap(); - let c: &Basic = ext.get().unwrap(); - assert_eq!(&auth, c); + assert!(ext.get::().is_none()); } } diff --git a/src/http/layer/proxy_auth/mod.rs b/src/http/layer/proxy_auth/mod.rs index 77968add..0a0ede6a 100644 --- a/src/http/layer/proxy_auth/mod.rs +++ b/src/http/layer/proxy_auth/mod.rs @@ -53,16 +53,6 @@ impl ProxyAuthLayer { } } -impl ProxyAuthLayer { - /// Creates a new [`ProxyAuthLayer`] with the default [`Basic`] credentials. - pub fn basic(proxy_auth: A) -> Self { - ProxyAuthLayer { - proxy_auth, - _phantom: PhantomData, - } - } -} - impl Layer for ProxyAuthLayer where A: ProxyAuthority + Clone, diff --git a/src/net/user/credentials.rs b/src/net/user/credentials.rs new file mode 100644 index 00000000..4ed50c4b --- /dev/null +++ b/src/net/user/credentials.rs @@ -0,0 +1,179 @@ +use std::borrow::Cow; + +use crate::http; +use crate::http::headers::authorization; +use base64::engine::general_purpose::STANDARD as ENGINE; +use base64::Engine; + +#[derive(Debug, Clone)] +/// Basic credentials. +pub struct Basic { + data: BasicData, +} + +#[derive(Debug, Clone)] +enum BasicData { + Username(Cow<'static, str>), + Pair { + username: Cow<'static, str>, + password: Cow<'static, str>, + }, + Decoded { + decoded: String, + colon_pos: usize, + }, +} + +impl Basic { + /// Creates a new [`Basic`] credential. + pub fn new( + username: impl Into>, + password: impl Into>, + ) -> Self { + let data = BasicData::Pair { + username: username.into(), + password: password.into(), + }; + Basic { data } + } + + /// Creates a new [`Basic`] credential with only a username. + pub fn unprotected(username: impl Into>) -> Self { + let data: BasicData = BasicData::Username(username.into()); + Basic { data } + } + + /// View the decoded username. + pub fn username(&self) -> &str { + match &self.data { + BasicData::Username(username) => username, + BasicData::Pair { username, .. } => username, + BasicData::Decoded { decoded, colon_pos } => &decoded[..*colon_pos], + } + } + + /// View the decoded password. + pub fn password(&self) -> &str { + match &self.data { + BasicData::Username(_) => "", + BasicData::Pair { password, .. } => password, + BasicData::Decoded { decoded, colon_pos } => &decoded[*colon_pos + 1..], + } + } +} + +impl PartialEq for Basic { + fn eq(&self, other: &Basic) -> bool { + self.username() == other.username() && self.password() == other.password() + } +} + +impl Eq for Basic {} + +impl authorization::Credentials for Basic { + const SCHEME: &'static str = "Basic"; + + fn decode(value: &http::HeaderValue) -> Option { + debug_assert!( + value.as_bytes()[..Self::SCHEME.len()].eq_ignore_ascii_case(Self::SCHEME.as_bytes()), + "HeaderValue to decode should start with \"Basic ..\", received = {:?}", + value, + ); + + let bytes = &value.as_bytes()["Basic ".len()..]; + let non_space_pos = bytes.iter().position(|b| *b != b' ')?; + let bytes = &bytes[non_space_pos..]; + + let bytes = ENGINE.decode(bytes).ok()?; + + let decoded = String::from_utf8(bytes).ok()?; + + let colon_pos = decoded.find(':')?; + + let data = BasicData::Decoded { decoded, colon_pos }; + Some(Basic { data }) + } + + fn encode(&self) -> http::HeaderValue { + let mut encoded = String::from("Basic "); + + match &self.data { + BasicData::Username(username) => { + let decoded = format!("{username}:"); + ENGINE.encode_string(&decoded, &mut encoded); + } + BasicData::Pair { username, password } => { + let decoded = format!("{username}:{password}"); + ENGINE.encode_string(&decoded, &mut encoded); + } + BasicData::Decoded { decoded, .. } => { + ENGINE.encode_string(decoded, &mut encoded); + } + } + + let bytes = bytes::Bytes::from(encoded); + http::HeaderValue::from_maybe_shared(bytes) + .expect("base64 encoding is always a valid HeaderValue") + } +} + +#[cfg(test)] +mod tests { + use ::http::HeaderValue; + use headers::authorization::Credentials; + + use super::*; + + #[test] + fn basic_encode() { + let auth = Basic::new("Aladdin", "open sesame"); + let value = auth.encode(); + + assert_eq!(value, "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",); + } + + #[test] + fn basic_encode_no_password() { + let auth = Basic::unprotected("Aladdin"); + let value = auth.encode(); + + assert_eq!(value, "Basic QWxhZGRpbjo=",); + } + + #[test] + fn basic_decode() { + let auth = Basic::decode(&HeaderValue::from_static( + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + )) + .unwrap(); + assert_eq!(auth.username(), "Aladdin"); + assert_eq!(auth.password(), "open sesame"); + } + + #[test] + fn basic_decode_case_insensitive() { + let auth = Basic::decode(&HeaderValue::from_static( + "basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + )) + .unwrap(); + assert_eq!(auth.username(), "Aladdin"); + assert_eq!(auth.password(), "open sesame"); + } + + #[test] + fn basic_decode_extra_whitespaces() { + let auth = Basic::decode(&HeaderValue::from_static( + "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==", + )) + .unwrap(); + assert_eq!(auth.username(), "Aladdin"); + assert_eq!(auth.password(), "open sesame"); + } + + #[test] + fn basic_decode_no_password() { + let auth = Basic::decode(&HeaderValue::from_static("Basic QWxhZGRpbjo=")).unwrap(); + assert_eq!(auth.username(), "Aladdin"); + assert_eq!(auth.password(), ""); + } +} diff --git a/src/net/user/mod.rs b/src/net/user/mod.rs index 13ede5ac..dcfe85a9 100644 --- a/src/net/user/mod.rs +++ b/src/net/user/mod.rs @@ -5,3 +5,7 @@ mod id; #[doc(inline)] pub use id::UserId; + +mod credentials; +#[doc(inline)] +pub use credentials::Basic;