Skip to content

Commit

Permalink
simplify authentication usage and code
Browse files Browse the repository at this point in the history
+ expose as UserId
instead of what we did before
  • Loading branch information
glendc committed Jun 10, 2024
1 parent dbac49e commit 4fe872a
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 201 deletions.
4 changes: 2 additions & 2 deletions examples/http_connect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ use rama::{
service::web::{extract::Path, match_service},
Body, IntoResponse, Request, RequestContext, Response, StatusCode,
},
net::stream::layer::http::BodyLimitLayer,
net::{stream::layer::http::BodyLimitLayer, user::Basic},
rt::Executor,
service::{
context::Extensions, layer::HijackLayer, service_fn, Context, Service, ServiceBuilder,
Expand Down Expand Up @@ -116,7 +116,7 @@ async fn main() {
.layer(TraceLayer::new_for_http())
// See [`ProxyAuthLayer::with_labels`] for more information,
// e.g. can also be used to extract upstream proxy filters
.layer(ProxyAuthLayer::basic(("john", "secret")).with_labels::<(PriorityUsernameLabelParser, UsernameOpaqueLabelParser)>())
.layer(ProxyAuthLayer::new(Basic::new("john", "secret")).with_labels::<(PriorityUsernameLabelParser, UsernameOpaqueLabelParser)>())
// example of how one might insert an API layer into their proxy
.layer(HijackLayer::new(
DomainMatcher::new("echo.example.internal"),
Expand Down
4 changes: 2 additions & 2 deletions examples/https_connect_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use rama::{
server::HttpServer,
Body, IntoResponse, Request, RequestContext, Response, StatusCode,
},
net::stream::layer::http::BodyLimitLayer,
net::{stream::layer::http::BodyLimitLayer, user::Basic},
rt::Executor,
service::{service_fn, Context, Service, ServiceBuilder},
tcp::{server::TcpListener, utils::is_connection_error},
Expand Down Expand Up @@ -133,7 +133,7 @@ async fn main() {
.layer(TraceLayer::new_for_http())
// See [`ProxyAuthLayer::with_labels`] for more information,
// e.g. can also be used to extract upstream proxy filters
.layer(ProxyAuthLayer::basic(("john", "secret")))
.layer(ProxyAuthLayer::new(Basic::new("john", "secret")))
.layer(UpgradeLayer::new(
MethodMatcher::CONNECT,
service_fn(http_connect_accept),
Expand Down
220 changes: 33 additions & 187 deletions src/http/layer/proxy_auth/auth.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::{
http::headers::{
authorization::{Basic, Credentials},
Authorization,
},
http::headers::authorization::Credentials,
net::user::{Basic, UserId},
service::context::Extensions,
utils::username::{parse_username, UsernameLabelParser, DEFAULT_USERNAME_LABEL_SEPARATOR},
};
Expand Down Expand Up @@ -47,7 +45,7 @@ where
impl ProxyAuthoritySync<Basic, ()> for Basic {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
if self == credentials {
ext.insert(self.clone());
ext.insert(UserId::Username(self.username().to_owned()));
true
} else {
false
Expand Down Expand Up @@ -75,7 +73,7 @@ impl<T: UsernameLabelParser> ProxyAuthoritySync<Basic, T> for Basic {
Err(err) => {
tracing::trace!("failed to parse username: {:?}", err);
return if self == credentials {
ext.insert(self.clone());
ext.insert(UserId::Username(username.to_owned()));
true
} else {
false
Expand All @@ -88,105 +86,7 @@ impl<T: UsernameLabelParser> ProxyAuthoritySync<Basic, T> for Basic {
}

ext.extend(parser_ext);
ext.insert(Authorization::basic(username.as_str(), password).0);
true
}
}

impl ProxyAuthoritySync<Basic, ()> for (&'static str, &'static str) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0, self.1).0);
true
} else {
false
}
}
}

impl<T: UsernameLabelParser> ProxyAuthoritySync<Basic, T> for (&'static str, &'static str) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
let username = credentials.username();
let password = credentials.password();

if password != self.1 {
return false;
}

let mut parser_ext = Extensions::new();
let username = match parse_username(
&mut parser_ext,
T::default(),
username,
DEFAULT_USERNAME_LABEL_SEPARATOR,
) {
Ok(t) => t,
Err(err) => {
tracing::trace!("failed to parse username: {:?}", err);
return if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0, self.1).0);
true
} else {
false
};
}
};

if username != self.0 {
return false;
}

ext.extend(parser_ext);
ext.insert(Authorization::basic(username.as_str(), password).0);
true
}
}

impl ProxyAuthoritySync<Basic, ()> for (String, String) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0.as_str(), self.1.as_str()).0);
true
} else {
false
}
}
}

impl<T: UsernameLabelParser> ProxyAuthoritySync<Basic, T> for (String, String) {
fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
let username = credentials.username();
let password = credentials.password();

if password != self.1 {
return false;
}

let mut parser_ext = Extensions::new();
let username = match parse_username(
&mut parser_ext,
T::default(),
username,
DEFAULT_USERNAME_LABEL_SEPARATOR,
) {
Ok(t) => t,
Err(err) => {
tracing::trace!("failed to parse username: {:?}", err);
return if self.0 == credentials.username() && self.1 == credentials.password() {
ext.insert(Authorization::basic(self.0.as_str(), self.1.as_str()).0);
true
} else {
false
};
}
};

if username != self.0 {
return false;
}

ext.extend(parser_ext);
ext.insert(Authorization::basic(username.as_str(), password).0);
ext.insert(UserId::Username(username));
true
}
}
Expand Down Expand Up @@ -215,139 +115,85 @@ where
mod test {
use super::*;
use crate::{
http::headers::{authorization::Basic, Authorization},
net::user::Basic,
proxy::{ProxyFilter, ProxyFilterUsernameParser},
utils::username::{UsernameLabels, UsernameOpaqueLabelParser},
};

#[tokio::test]
async fn basic_authorization() {
let Authorization(auth) = Authorization::basic("Aladdin", "open sesame");
let auths = vec![Authorization::basic("foo", "bar").0, auth.clone()];
let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone())
let auth = Basic::new("Aladdin", "open sesame");
let auths = vec![Basic::new("foo", "bar"), auth.clone()];
let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);
let user: &UserId = ext.get().unwrap();
assert_eq!(user, "Aladdin");
}

#[tokio::test]
async fn basic_authorization_with_filter_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];
let auths = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")];

let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized(
&auths,
Authorization::basic("john-country-us", "secret").0,
Basic::new("john-country-us", "secret"),
)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

let c: &UserId = ext.get().unwrap();
assert_eq!(c, "john");

let filter: &ProxyFilter = ext.get().unwrap();
assert_eq!(filter.country, Some(vec!["us".into()]));
}

#[tokio::test]
async fn basic_authorization_with_labels_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];
let auths = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")];

let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized(
&auths,
Authorization::basic("john-green-red", "secret").0,
Basic::new("john-green-red", "secret"),
)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

let c: &UserId = ext.get().unwrap();
assert_eq!(c, "john");

let labels: &UsernameLabels = ext.get().unwrap();
assert_eq!(&labels.0, &vec!["green".to_owned(), "red".to_owned()]);
}

#[tokio::test]
async fn basic_authorization_with_filter_not_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];
let auth = Basic::new("john", "secret");
let auths = vec![Basic::new("foo", "bar"), auth.clone()];

let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized(&auths, auth.clone())
let ext = ProxyAuthority::<_, ProxyFilterUsernameParser>::authorized(&auths, auth)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

let c: &UserId = ext.get().unwrap();
assert_eq!(c, "john");

assert!(ext.get::<ProxyFilter>().is_none());
}

#[tokio::test]
async fn basic_authorization_with_labels_not_found() {
let Authorization(auth) = Authorization::basic("john", "secret");
let auths = vec![
Authorization::basic("foo", "bar").0,
Authorization::basic("john", "secret").0,
];
let auth = Basic::new("john", "secret");
let auths = vec![Basic::new("foo", "bar"), auth.clone()];

let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized(&auths, auth.clone())
let ext = ProxyAuthority::<_, UsernameOpaqueLabelParser>::authorized(&auths, auth)
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);

assert!(ext.get::<UsernameLabels>().is_none());
}

#[tokio::test]
async fn basic_authorization_tuple() {
let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")];
let Authorization(auth) = Authorization::basic("Aladdin", "open sesame");
let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone())
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);
}
let c: &UserId = ext.get().unwrap();
assert_eq!(c, "john");

#[tokio::test]
async fn basic_authorization_tuple_no_auth_username() {
let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")];
let Authorization(auth) = Authorization::basic("bax", "qux");
assert!(ProxyAuthority::<_, ()>::authorized(&auths, auth.clone())
.await
.is_none());
}

#[tokio::test]
async fn basic_authorization_tuple_no_auth_password() {
let auths = vec![("foo", "bar"), ("Aladdin", "open sesame"), ("baz", "qux")];
let Authorization(auth) = Authorization::basic("baz", "quc");
assert!(ProxyAuthority::<_, ()>::authorized(&auths, auth.clone())
.await
.is_none())
}

#[tokio::test]
async fn basic_authorization_tuple_string() {
let auths = vec![
("foo".to_owned(), "bar".to_owned()),
("Aladdin".to_owned(), "open sesame".to_owned()),
("baz".to_owned(), "qux".to_owned()),
];
let Authorization(auth) = Authorization::basic("Aladdin", "open sesame");
let ext = ProxyAuthority::<_, ()>::authorized(&auths, auth.clone())
.await
.unwrap();
let c: &Basic = ext.get().unwrap();
assert_eq!(&auth, c);
assert!(ext.get::<UsernameLabels>().is_none());
}
}
10 changes: 0 additions & 10 deletions src/http/layer/proxy_auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@ impl<A, C, L> ProxyAuthLayer<A, C, L> {
}
}

impl<A> ProxyAuthLayer<A, Basic, ()> {
/// Creates a new [`ProxyAuthLayer`] with the default [`Basic`] credentials.
pub fn basic(proxy_auth: A) -> Self {
ProxyAuthLayer {
proxy_auth,
_phantom: PhantomData,
}
}
}

impl<A, C, L, S> Layer<S> for ProxyAuthLayer<A, C, L>
where
A: ProxyAuthority<C, L> + Clone,
Expand Down
Loading

0 comments on commit 4fe872a

Please sign in to comment.