Skip to content

Commit

Permalink
introduce AllowAnonymous trait
Browse files Browse the repository at this point in the history
  • Loading branch information
calebbourg committed Nov 7, 2024
1 parent c9dd4ff commit 2b28e8b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
37 changes: 36 additions & 1 deletion rama-http/src/layer/auth/require_authorization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,16 @@ use base64::Engine as _;
use std::{fmt, marker::PhantomData};

use crate::layer::validate_request::{
ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
AllowAnonymous, ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
};
use crate::{
header::{self, HeaderValue},
Request, Response, StatusCode,
};
use rama_core::Context;

use rama_net::user::UserId;

const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;

impl<S, ResBody> ValidateRequestHeader<S, Basic<ResBody>> {
Expand Down Expand Up @@ -135,6 +137,7 @@ impl<ResBody> ValidateRequestHeaderLayer<Bearer<ResBody>> {
/// See [`ValidateRequestHeader::bearer`] for more details.
pub struct Bearer<ResBody> {
header_value: HeaderValue,
allow_anonymous: bool,
_ty: PhantomData<fn() -> ResBody>,
}

Expand All @@ -147,6 +150,7 @@ impl<ResBody> Bearer<ResBody> {
header_value: format!("Bearer {}", token)
.parse()
.expect("token is not a valid header value"),
allow_anonymous: false,
_ty: PhantomData,
}
}
Expand All @@ -156,6 +160,7 @@ impl<ResBody> Clone for Bearer<ResBody> {
fn clone(&self) -> Self {
Self {
header_value: self.header_value.clone(),
allow_anonymous: self.allow_anonymous,
_ty: PhantomData,
}
}
Expand All @@ -165,10 +170,17 @@ impl<ResBody> fmt::Debug for Bearer<ResBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Bearer")
.field("header_value", &self.header_value)
.field("allow_anonymous", &self.allow_anonymous)
.finish()
}
}

impl AllowAnonymous for Bearer<()> {
fn allow_anonymous(&mut self, allow_anonymous: bool) {
self.allow_anonymous = allow_anonymous;
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Bearer<ResBody>
where
ResBody: Default + Send + 'static,
Expand All @@ -185,6 +197,12 @@ where
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
_ => {
if self.allow_anonymous {
let mut ctx = ctx.clone();
ctx.insert(UserId::Anonymous);

return Ok((ctx, request));
}
let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
Err(res)
Expand All @@ -198,6 +216,7 @@ where
/// See [`ValidateRequestHeader::basic`] for more details.
pub struct Basic<ResBody> {
header_value: HeaderValue,
allow_anonymous: bool,
_ty: PhantomData<fn() -> ResBody>,
}

Expand All @@ -210,6 +229,7 @@ impl<ResBody> Basic<ResBody> {
let header_value = format!("Basic {}", encoded).parse().unwrap();
Self {
header_value,
allow_anonymous: false,
_ty: PhantomData,
}
}
Expand All @@ -219,6 +239,7 @@ impl<ResBody> Clone for Basic<ResBody> {
fn clone(&self) -> Self {
Self {
header_value: self.header_value.clone(),
allow_anonymous: self.allow_anonymous,
_ty: PhantomData,
}
}
Expand All @@ -228,10 +249,17 @@ impl<ResBody> fmt::Debug for Basic<ResBody> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Basic")
.field("header_value", &self.header_value)
.field("allow_anonymous", &self.allow_anonymous)
.finish()
}
}

impl AllowAnonymous for Basic<()> {
fn allow_anonymous(&mut self, allow_anonymous: bool) {
self.allow_anonymous = allow_anonymous;
}
}

impl<S, B, ResBody> ValidateRequest<S, B> for Basic<ResBody>
where
ResBody: Default + Send + 'static,
Expand All @@ -248,6 +276,13 @@ where
match request.headers().get(header::AUTHORIZATION) {
Some(actual) if actual == self.header_value => Ok((ctx, request)),
_ => {
if self.allow_anonymous {
let mut ctx = ctx.clone();
ctx.insert(UserId::Anonymous);

return Ok((ctx, request));
}

let mut res = Response::new(ResBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res.headers_mut()
Expand Down
4 changes: 3 additions & 1 deletion rama-http/src/layer/validate_request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,6 @@ pub use validate::ValidateRequest;
#[doc(inline)]
pub use validate_fn::{BoxValidateRequestFn, ValidateRequestFn};
#[doc(inline)]
pub use validate_request_header::{ValidateRequestHeader, ValidateRequestHeaderLayer};
pub use validate_request_header::{
AllowAnonymous, ValidateRequestHeader, ValidateRequestHeaderLayer,
};
36 changes: 36 additions & 0 deletions rama-http/src/layer/validate_request/validate_request_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use rama_core::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;

pub trait AllowAnonymous {
fn allow_anonymous(&mut self, allow_anonymous: bool);
}

/// Layer that applies [`ValidateRequestHeader`] which validates all requests.
///
/// See the [module docs](crate::layer::validate_request) for an example.
Expand Down Expand Up @@ -85,6 +89,22 @@ where
}
}

impl<T> ValidateRequestHeaderLayer<T>
where
T: AllowAnonymous,
{
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous(allow_anonymous);
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous(allow_anonymous);
self
}
}

/// Middleware that validates requests.
///
/// See the [module docs](crate::layer::validate_request) for an example.
Expand Down Expand Up @@ -157,6 +177,22 @@ impl<S, F, A> ValidateRequestHeader<S, BoxValidateRequestFn<F, A>> {
}
}

impl<S, T> ValidateRequestHeader<S, T>
where
T: AllowAnonymous,
{
pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
self.validate.allow_anonymous(allow_anonymous);
self
}

/// Allow anonymous requests.
pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
self.validate.allow_anonymous(allow_anonymous);
self
}
}

impl<ReqBody, ResBody, State, S, V> Service<State, Request<ReqBody>> for ValidateRequestHeader<S, V>
where
ReqBody: Send + 'static,
Expand Down

0 comments on commit 2b28e8b

Please sign in to comment.