Skip to content
This repository has been archived by the owner on Feb 11, 2024. It is now read-only.

Commit

Permalink
feat: add update method to register endpoint (#61)
Browse files Browse the repository at this point in the history
* feat: add update method to register endpoint

* fix: use `HashSet` in tags updates
  • Loading branch information
Xavier Basty authored Jul 13, 2023
1 parent 02b668f commit ac132bc
Show file tree
Hide file tree
Showing 5 changed files with 339 additions and 14 deletions.
11 changes: 11 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ pub enum Error {
#[error("invalid options provided for {0}")]
InvalidOptionsProvided(String),

#[error("invalid update request")]
InvalidUpdateRequest,

#[error(transparent)]
FromUtf8Error(#[from] std::string::FromUtf8Error),

Expand Down Expand Up @@ -189,6 +192,14 @@ impl IntoResponse for Error {
}],
vec![],
),
Error::InvalidUpdateRequest => crate::handlers::Response::new_failure(
StatusCode::BAD_REQUEST,
vec![ResponseError {
name: "appendTags/removeTags".to_string(),
message: "cannot append and remove the same tag".to_string(),
}],
vec![],
),
e => crate::handlers::Response::new_failure(StatusCode::INTERNAL_SERVER_ERROR, vec![
ResponseError {
name: "unknown_error".to_string(),
Expand Down
4 changes: 3 additions & 1 deletion src/handlers/get_registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ pub async fn handler(
.await;

Ok(Json(RegisterPayload {
tags: registration.tags,
tags: Some(registration.tags),
append_tags: None,
remove_tags: None,
relay_url: registration.relay_url,
}))
}
87 changes: 78 additions & 9 deletions src/handlers/register.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use {
crate::{
auth::AuthBearer,
error,
error::{self, Error},
handlers::Response,
increment_counter,
log::prelude::*,
Expand All @@ -13,13 +13,15 @@ use {
jwt::{JwtBasicClaims, VerifyableClaims},
},
serde::{Deserialize, Serialize},
std::sync::Arc,
std::{collections::HashSet, sync::Arc},
};

#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPayload {
pub tags: Vec<Arc<str>>,
pub tags: Option<Vec<Arc<str>>>,
pub append_tags: Option<Vec<Arc<str>>>,
pub remove_tags: Option<Vec<Arc<str>>>,
pub relay_url: Arc<str>,
}

Expand All @@ -32,24 +34,91 @@ pub async fn handler(
claims.verify_basic(&state.auth_aud, None)?;
let client_id = ClientId::from(claims.iss);

increment_counter!(state.metrics, register);

if let Some(tags) = body.tags {
increment_counter!(state.metrics, registration_overwrite);

let tags = tags.into_iter().collect::<HashSet<_>>();
overwrite_registration(&state, client_id.clone(), tags, body.relay_url).await?;
} else {
increment_counter!(state.metrics, registration_update);

let append_tags = body
.append_tags
.map(|tags| tags.into_iter().collect::<HashSet<_>>());
let remove_tags = body
.remove_tags
.map(|tags| tags.into_iter().collect::<HashSet<_>>());

update_registration(
&state,
client_id.clone(),
append_tags,
remove_tags,
body.relay_url,
)
.await?;
}

Ok(Response::default())
}

async fn overwrite_registration(
state: &Arc<AppState>,
client_id: ClientId,
tags: HashSet<Arc<str>>,
relay_url: Arc<str>,
) -> error::Result<Response> {
state
.registration_store
.upsert_registration(
client_id.value(),
body.tags.iter().map(AsRef::as_ref).collect(),
body.relay_url.as_ref(),
tags.iter().map(AsRef::as_ref).collect(),
relay_url.as_ref(),
)
.await?;

state
.registration_cache
.insert(client_id.into_value(), CachedRegistration {
tags: body.tags.clone(),
relay_url: body.relay_url.clone(),
tags: tags.into_iter().collect::<Vec<_>>(),
relay_url,
})
.await;

increment_counter!(state.metrics, register);

Ok(Response::default())
}

async fn update_registration(
state: &Arc<AppState>,
client_id: ClientId,
append_tags: Option<HashSet<Arc<str>>>,
remove_tags: Option<HashSet<Arc<str>>>,
relay_url: Arc<str>,
) -> error::Result<Response> {
let append_tags = append_tags.unwrap_or_default();
let remove_tags = remove_tags.unwrap_or_default();

if remove_tags.intersection(&append_tags).count() > 0 {
return Err(Error::InvalidUpdateRequest);
}

let registration = state
.registration_store
.get_registration(client_id.as_ref())
.await?;

let tags = registration
.tags
.into_iter()
.collect::<HashSet<_>>()
.difference(&remove_tags)
.cloned()
.collect::<HashSet<_>>()
.union(&append_tags)
.cloned()
.collect();

overwrite_registration(state, client_id, tags, relay_url).await
}
16 changes: 15 additions & 1 deletion src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub struct Metrics {
pub served_items: Counter<u64>,

pub register: Counter<u64>,
pub registration_overwrite: Counter<u64>,
pub registration_update: Counter<u64>,
pub cached_registrations: Counter<u64>,
pub fetched_registrations: Counter<u64>,
pub registration_cache_invalidation: Counter<u64>,
Expand Down Expand Up @@ -71,7 +73,17 @@ impl Metrics {

let register = meter
.u64_counter("register")
.with_description("The number of calls to the register method")
.with_description("The total number of calls to the register method")
.init();

let registration_overwrite = meter
.u64_counter("register")
.with_description("The number of calls to the register method in overwrite mode")
.init();

let registration_update = meter
.u64_counter("register")
.with_description("The number of calls to the register method in update mode")
.init();

let cached_registrations = meter
Expand All @@ -96,6 +108,8 @@ impl Metrics {
get_queries,
served_items,
register,
registration_overwrite,
registration_update,
cached_registrations,
fetched_registrations,
registration_cache_invalidation,
Expand Down
Loading

0 comments on commit ac132bc

Please sign in to comment.