diff --git a/src/error.rs b/src/error.rs index 065b428..99cd506 100644 --- a/src/error.rs +++ b/src/error.rs @@ -75,6 +75,9 @@ pub enum Error { #[error("invalid options provided for {0}")] InvalidOptionsProvided(String), + #[error("invalid update request")] + InvalidUpdateRequest, + #[error(transparent)] FromUtf8Error(#[from] std::string::FromUtf8Error), @@ -189,6 +192,14 @@ impl IntoResponse for Error { }], vec![], ), + Error::InvalidUpdateRequest => crate::handlers::Response::new_failure( + StatusCode::BAD_REQUEST, + vec![ResponseError { + name: "appendTags/removeTags".to_string(), + message: "cannot append and remove the same tag".to_string(), + }], + vec![], + ), e => crate::handlers::Response::new_failure(StatusCode::INTERNAL_SERVER_ERROR, vec![ ResponseError { name: "unknown_error".to_string(), diff --git a/src/handlers/get_registration.rs b/src/handlers/get_registration.rs index 06957f1..9fba58b 100644 --- a/src/handlers/get_registration.rs +++ b/src/handlers/get_registration.rs @@ -42,7 +42,9 @@ pub async fn handler( .await; Ok(Json(RegisterPayload { - tags: registration.tags, + tags: Some(registration.tags), + append_tags: None, + remove_tags: None, relay_url: registration.relay_url, })) } diff --git a/src/handlers/register.rs b/src/handlers/register.rs index eb0f4ef..6a6f719 100644 --- a/src/handlers/register.rs +++ b/src/handlers/register.rs @@ -1,7 +1,7 @@ use { crate::{ auth::AuthBearer, - error, + error::{self, Error}, handlers::Response, increment_counter, log::prelude::*, @@ -13,13 +13,15 @@ use { jwt::{JwtBasicClaims, VerifyableClaims}, }, serde::{Deserialize, Serialize}, - std::sync::Arc, + std::{collections::HashSet, sync::Arc}, }; #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] #[serde(rename_all = "camelCase")] pub struct RegisterPayload { - pub tags: Vec>, + pub tags: Option>>, + pub append_tags: Option>>, + pub remove_tags: Option>>, pub relay_url: Arc, } @@ -32,24 +34,91 @@ pub async fn handler( claims.verify_basic(&state.auth_aud, None)?; let client_id = ClientId::from(claims.iss); + increment_counter!(state.metrics, register); + + if let Some(tags) = body.tags { + increment_counter!(state.metrics, registration_overwrite); + + let tags = tags.into_iter().collect::>(); + overwrite_registration(&state, client_id.clone(), tags, body.relay_url).await?; + } else { + increment_counter!(state.metrics, registration_update); + + let append_tags = body + .append_tags + .map(|tags| tags.into_iter().collect::>()); + let remove_tags = body + .remove_tags + .map(|tags| tags.into_iter().collect::>()); + + update_registration( + &state, + client_id.clone(), + append_tags, + remove_tags, + body.relay_url, + ) + .await?; + } + + Ok(Response::default()) +} + +async fn overwrite_registration( + state: &Arc, + client_id: ClientId, + tags: HashSet>, + relay_url: Arc, +) -> error::Result { state .registration_store .upsert_registration( client_id.value(), - body.tags.iter().map(AsRef::as_ref).collect(), - body.relay_url.as_ref(), + tags.iter().map(AsRef::as_ref).collect(), + relay_url.as_ref(), ) .await?; state .registration_cache .insert(client_id.into_value(), CachedRegistration { - tags: body.tags.clone(), - relay_url: body.relay_url.clone(), + tags: tags.into_iter().collect::>(), + relay_url, }) .await; - increment_counter!(state.metrics, register); - Ok(Response::default()) } + +async fn update_registration( + state: &Arc, + client_id: ClientId, + append_tags: Option>>, + remove_tags: Option>>, + relay_url: Arc, +) -> error::Result { + let append_tags = append_tags.unwrap_or_default(); + let remove_tags = remove_tags.unwrap_or_default(); + + if remove_tags.intersection(&append_tags).count() > 0 { + return Err(Error::InvalidUpdateRequest); + } + + let registration = state + .registration_store + .get_registration(client_id.as_ref()) + .await?; + + let tags = registration + .tags + .into_iter() + .collect::>() + .difference(&remove_tags) + .cloned() + .collect::>() + .union(&append_tags) + .cloned() + .collect(); + + overwrite_registration(state, client_id, tags, relay_url).await +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 3035d03..9b38b70 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -24,6 +24,8 @@ pub struct Metrics { pub served_items: Counter, pub register: Counter, + pub registration_overwrite: Counter, + pub registration_update: Counter, pub cached_registrations: Counter, pub fetched_registrations: Counter, pub registration_cache_invalidation: Counter, @@ -71,7 +73,17 @@ impl Metrics { let register = meter .u64_counter("register") - .with_description("The number of calls to the register method") + .with_description("The total number of calls to the register method") + .init(); + + let registration_overwrite = meter + .u64_counter("register") + .with_description("The number of calls to the register method in overwrite mode") + .init(); + + let registration_update = meter + .u64_counter("register") + .with_description("The number of calls to the register method in update mode") .init(); let cached_registrations = meter @@ -96,6 +108,8 @@ impl Metrics { get_queries, served_items, register, + registration_overwrite, + registration_update, cached_registrations, fetched_registrations, registration_cache_invalidation, diff --git a/tests/registration/mod.rs b/tests/registration/mod.rs index 127c278..8c16133 100644 --- a/tests/registration/mod.rs +++ b/tests/registration/mod.rs @@ -8,11 +8,13 @@ use { #[test_context(ServerContext)] #[tokio::test] -async fn test_register(ctx: &mut ServerContext) { +async fn test_register_new(ctx: &mut ServerContext) { let (jwt, client_id) = get_client_jwt(); let payload = RegisterPayload { - tags: vec![Arc::from("4000"), Arc::from("5***")], + tags: Some(vec![Arc::from("4000"), Arc::from("5***")]), + append_tags: None, + remove_tags: None, relay_url: Arc::from(TEST_RELAY_URL), }; @@ -40,6 +42,233 @@ async fn test_register(ctx: &mut ServerContext) { .is_some()) } +#[test_context(ServerContext)] +#[tokio::test] +async fn test_register(ctx: &mut ServerContext) { + let (jwt, client_id) = get_client_jwt(); + let relay_url: Arc = Arc::from(TEST_RELAY_URL); + + struct TestCase { + name: &'static str, + start: Vec>, + overwrite: Option>>, + append: Option>>, + remove: Option>>, + expected: Vec>, + } + + let tests = vec![ + TestCase { + name: "Overwrite", + start: vec![Arc::from("4000")], + overwrite: Some(vec![Arc::from("4001"), Arc::from("4002")]), + append: None, + remove: None, + expected: vec![Arc::from("4001"), Arc::from("4002")], + }, + TestCase { + name: "Update, Add tags", + start: vec![Arc::from("4000")], + overwrite: None, + append: Some(vec![Arc::from("4001"), Arc::from("4002")]), + remove: None, + expected: vec![Arc::from("4000"), Arc::from("4001"), Arc::from("4002")], + }, + TestCase { + name: "Update, Add existing tags", + start: vec![Arc::from("4000")], + overwrite: None, + append: Some(vec![Arc::from("4000"), Arc::from("4001")]), + remove: None, + expected: vec![Arc::from("4000"), Arc::from("4001")], + }, + TestCase { + name: "Update, Remove tags", + start: vec![Arc::from("4000"), Arc::from("4001"), Arc::from("4002")], + overwrite: None, + append: None, + remove: Some(vec![Arc::from("4001"), Arc::from("4002")]), + expected: vec![Arc::from("4000")], + }, + TestCase { + name: "Update, Remove missing tags", + start: vec![Arc::from("4000"), Arc::from("4001"), Arc::from("4002")], + overwrite: None, + append: None, + remove: Some(vec![Arc::from("5000"), Arc::from("4001")]), + expected: vec![Arc::from("4000"), Arc::from("4002")], + }, + TestCase { + name: "Overwrite + Update, Update has remove tag from overwrite", + start: vec![Arc::from("4000")], + overwrite: Some(vec![Arc::from("5000")]), + append: Some(vec![Arc::from("4001"), Arc::from("4002")]), + remove: Some(vec![Arc::from("5000")]), + expected: vec![Arc::from("5000")], + }, + TestCase { + name: "Empty", + start: vec![Arc::from("4000")], + overwrite: None, + append: None, + remove: None, + expected: vec![Arc::from("4000")], + }, + ]; + + for test in tests.iter() { + ctx.server + .registration_store + .registrations + .insert(client_id.to_string(), Registration { + id: None, + client_id: client_id.clone().into_value(), + tags: test.start.clone(), + relay_url: relay_url.clone(), + }) + .await; + + let payload = RegisterPayload { + tags: test.overwrite.clone(), + append_tags: test.append.clone(), + remove_tags: test.remove.clone(), + relay_url: relay_url.clone(), + }; + + let client = reqwest::Client::new(); + let response = client + .post(format!("http://{}/register", ctx.server.public_addr)) + .json(&payload) + .header(http::header::AUTHORIZATION, format!("Bearer {jwt}")) + .send() + .await + .expect("Call failed"); + + assert!( + response.status().is_success(), + "{:?} - Response was not successful: {:?} - {:?}", + test.name, + response.status(), + response.text().await + ); + + let registration = ctx + .server + .registration_store + .registrations + .get(client_id.value().as_ref()); + + assert!( + registration.is_some(), + "{:?} - Registration was not found in store", + test.name + ); + + let mut registration = registration.unwrap(); + registration.tags.sort(); + assert_eq!( + registration.tags, + test.expected.clone(), + "{:?} - Tags did not match expected", + test.name + ); + } +} + +#[test_context(ServerContext)] +#[tokio::test] +async fn test_register_update_bad_update(ctx: &mut ServerContext) { + let (jwt, client_id) = get_client_jwt(); + + let payload = RegisterPayload { + tags: None, + append_tags: Some(vec![Arc::from("5000")]), + remove_tags: Some(vec![Arc::from("5000")]), + relay_url: Arc::from(TEST_RELAY_URL), + }; + + let client = reqwest::Client::new(); + let response = client + .post(format!("http://{}/register", ctx.server.public_addr)) + .json(&payload) + .header(http::header::AUTHORIZATION, format!("Bearer {jwt}")) + .send() + .await + .expect("Call failed"); + + assert!( + response.status().is_client_error(), + "Response status was invalid: {:?} - {:?}", + response.status(), + response.text().await + ); + + let registration = ctx + .server + .registration_store + .registrations + .get(client_id.value().as_ref()); + + assert!( + registration.is_none(), + "Registration was found in store when it should not" + ); +} + +#[test_context(ServerContext)] +#[tokio::test] +async fn test_register_update_bad_update_with_overwrite(ctx: &mut ServerContext) { + let (jwt, client_id) = get_client_jwt(); + + let tags = vec![Arc::from("4000")]; + let payload = RegisterPayload { + tags: Some(tags.clone()), + append_tags: Some(vec![Arc::from("5000")]), + remove_tags: Some(vec![Arc::from("5000")]), + relay_url: Arc::from(TEST_RELAY_URL), + }; + + let client = reqwest::Client::new(); + let response = client + .post(format!("http://{}/register", ctx.server.public_addr)) + .json(&payload) + .header(http::header::AUTHORIZATION, format!("Bearer {jwt}")) + .send() + .await + .expect("Call failed"); + + assert!( + response.status().is_success(), + "Response was unsuccessful: {:?} - {:?}", + response.status(), + response.text().await + ); + + assert!(ctx + .server + .registration_store + .registrations + .get(client_id.value().as_ref()) + .is_some()); + + let registration = ctx + .server + .registration_store + .registrations + .get(client_id.value().as_ref()); + + assert!( + registration.is_some(), + "Registration was not found in store" + ); + + assert_eq!( + registration.unwrap().tags, + tags.clone(), + "Tags did not match expected" + ); +} + #[test_context(ServerContext)] #[tokio::test] async fn test_get_registration(ctx: &mut ServerContext) { @@ -84,6 +313,6 @@ async fn test_get_registration(ctx: &mut ServerContext) { assert_eq!(allowed_origins.to_str().unwrap(), "*"); let payload: RegisterPayload = response.json().await.unwrap(); - assert_eq!(payload.tags, tags); + assert_eq!(payload.tags.unwrap(), tags); assert_eq!(payload.relay_url.as_ref(), TEST_RELAY_URL); }