diff --git a/src/grpc/auth_header_interceptor.rs b/src/grpc/auth_header_interceptor.rs deleted file mode 100644 index 16ca6bf0..00000000 --- a/src/grpc/auth_header_interceptor.rs +++ /dev/null @@ -1,17 +0,0 @@ -#[derive(Clone)] -pub struct AuthHeaderInterceptor { - pub auth_key: String, -} - -impl tonic::service::Interceptor for AuthHeaderInterceptor { - fn call( - &mut self, - mut request: tonic::Request<()>, - ) -> Result, tonic::Status> { - request.metadata_mut().insert( - "authorization", - tonic::metadata::AsciiMetadataValue::from_str(self.auth_key.as_str()).unwrap(), - ); - Ok(request) - } -} diff --git a/src/grpc/cache_header_interceptor.rs b/src/grpc/cache_header_interceptor.rs deleted file mode 100644 index 545972eb..00000000 --- a/src/grpc/cache_header_interceptor.rs +++ /dev/null @@ -1,23 +0,0 @@ -#[derive(Clone)] -pub struct CacheHeaderInterceptor { - pub auth_key: String, -} - -impl tonic::service::Interceptor for CacheHeaderInterceptor { - fn call( - &mut self, - mut request: tonic::Request<()>, - ) -> Result, tonic::Status> { - request.metadata_mut().insert( - "authorization", - tonic::metadata::AsciiMetadataValue::from_str(self.auth_key.as_str()).unwrap(), - ); - // for reasons unknown, tonic seems to be stripping out the content-type. So we need to add this as - // a workaround so that the requests are successful - request.metadata_mut().insert( - "content-type", - tonic::metadata::AsciiMetadataValue::from_str("application/grpc").unwrap(), - ); - Ok(request) - } -} diff --git a/src/grpc/header_interceptor.rs b/src/grpc/header_interceptor.rs new file mode 100644 index 00000000..7b306e6a --- /dev/null +++ b/src/grpc/header_interceptor.rs @@ -0,0 +1,35 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; + +const AUTHORIZATION: &str = "authorization"; +const AGENT: &str = "agent"; + +#[derive(Clone)] +pub struct HeaderInterceptor { + pub header: HashMap, +} + +impl tonic::service::Interceptor for HeaderInterceptor { + fn call( + &mut self, + mut request: tonic::Request<()>, + ) -> Result, tonic::Status> { + static ARE_ONLY_ONCE_HEADER_SENT: AtomicBool = AtomicBool::new(false); + for (key, value) in self.header.iter() { + if *key == *AUTHORIZATION { + request.metadata_mut().insert( + tonic::metadata::AsciiMetadataKey::from_static(AUTHORIZATION), + tonic::metadata::AsciiMetadataValue::from_str(value).unwrap(), + ); + } + if !ARE_ONLY_ONCE_HEADER_SENT.load(Ordering::Relaxed) && *key == *AGENT { + request.metadata_mut().insert( + tonic::metadata::AsciiMetadataKey::from_static(AGENT), + tonic::metadata::AsciiMetadataValue::from_str(value).unwrap(), + ); + ARE_ONLY_ONCE_HEADER_SENT.store(true, Ordering::Relaxed); + } + } + Ok(request) + } +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 6604ac30..7efaa74e 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -1,2 +1 @@ -pub mod auth_header_interceptor; -pub mod cache_header_interceptor; +pub mod header_interceptor; diff --git a/src/simple_cache_client.rs b/src/simple_cache_client.rs index 75e35131..4700736c 100644 --- a/src/simple_cache_client.rs +++ b/src/simple_cache_client.rs @@ -1,4 +1,5 @@ use serde_json::Value; +use std::collections::HashMap; use std::convert::TryFrom; use std::num::NonZeroU64; use tonic::{ @@ -8,13 +9,12 @@ use tonic::{ }; use crate::endpoint_resolver::MomentoEndpointsResolver; -use crate::grpc::cache_header_interceptor::CacheHeaderInterceptor; +use crate::grpc::header_interceptor::HeaderInterceptor; use crate::{ generated::control_client::{ scs_control_client::ScsControlClient, CreateCacheRequest, CreateSigningKeyRequest, DeleteCacheRequest, ListCachesRequest, RevokeSigningKeyRequest, }, - grpc::auth_header_interceptor::AuthHeaderInterceptor, response::{ create_signing_key_response::MomentoCreateSigningKeyResponse, error::MomentoError, @@ -31,6 +31,11 @@ use crate::{ generated::cache_client::{scs_client::ScsClient, ECacheResult, GetRequest, SetRequest}, response::cache_get_response::MomentoGetResponse, }; + +const AUTHORIZATION_NAME: &str = "authorization"; +const AGENT_NAME: &str = "agent"; +const VERSION: &str = env!("CARGO_PKG_VERSION"); + pub trait MomentoRequest { fn into_bytes(self) -> Vec; } @@ -104,18 +109,25 @@ impl SimpleCacheClientBuilder { } pub fn build(self) -> SimpleCacheClient { + let agent_value = format!("rust:{}", VERSION); + let mut control_interceptor_hashmap = HashMap::new(); + control_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), self.auth_token.clone()); + control_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value.clone()); let control_interceptor = InterceptedService::new( self.control_channel, - AuthHeaderInterceptor { - auth_key: self.auth_token.clone(), + HeaderInterceptor { + header: control_interceptor_hashmap, }, ); let control_client = ScsControlClient::new(control_interceptor); + let mut data_interceptor_hashmap = HashMap::new(); + data_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), self.auth_token.clone()); + data_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value); let data_interceptor = InterceptedService::new( self.data_channel, - CacheHeaderInterceptor { - auth_key: self.auth_token.clone(), + HeaderInterceptor { + header: data_interceptor_hashmap, }, ); let data_client = ScsClient::new(data_interceptor); @@ -131,8 +143,8 @@ impl SimpleCacheClientBuilder { pub struct SimpleCacheClient { data_endpoint: String, - control_client: ScsControlClient>, - data_client: ScsClient>, + control_client: ScsControlClient>, + data_client: ScsClient>, item_default_ttl_seconds: NonZeroU64, } @@ -160,16 +172,66 @@ impl SimpleCacheClient { default_ttl_seconds: NonZeroU64, ) -> Result { let data_endpoint = utils::get_claims(&auth_token).c; + let agent_value = format!("rust:{}", VERSION); + let momento_endpoints = MomentoEndpointsResolver::resolve(&auth_token, &None); + let control_client = SimpleCacheClient::build_control_client( + auth_token.clone(), + momento_endpoints.control_endpoint, + agent_value.clone(), + ) + .await; + let data_client = SimpleCacheClient::build_data_client( + auth_token.clone(), + momento_endpoints.data_endpoint, + agent_value, + ) + .await; + + let simple_cache_client = Self { + data_endpoint, + control_client: control_client.unwrap(), + data_client: data_client.unwrap(), + item_default_ttl_seconds: default_ttl_seconds, + }; + Ok(simple_cache_client) + } + /// Returns an instance of a Momento client. + /// This is specifically used for Momento CLI to initialize a Momento client. + /// + /// # Arguments + /// + /// * `auth_token` - Momento Token + /// * `item_default_ttl_seconds` - Default TTL for items put into a cache. + /// # Examples + /// + /// ``` + /// # tokio_test::block_on(async { + /// use momento::simple_cache_client::SimpleCacheClient; + /// use std::env; + /// use std::num::NonZeroU64; + /// let auth_token = env::var("TEST_AUTH_TOKEN").expect("TEST_AUTH_TOKEN must be set"); + /// let default_ttl = 30; + /// let momento = SimpleCacheClient::new_momento_cli(auth_token, NonZeroU64::new(default_ttl).unwrap()).await; + /// # }) + /// ``` + pub async fn new_momento_cli( + auth_token: String, + default_ttl_seconds: NonZeroU64, + ) -> Result { + let data_endpoint = utils::get_claims(&auth_token).c; + let agent_value = format!("momento-cli:{}", VERSION); let momento_endpoints = MomentoEndpointsResolver::resolve(&auth_token, &None); let control_client = SimpleCacheClient::build_control_client( auth_token.clone(), momento_endpoints.control_endpoint, + agent_value.clone(), ) .await; let data_client = SimpleCacheClient::build_data_client( auth_token.clone(), momento_endpoints.data_endpoint, + agent_value, ) .await; @@ -185,7 +247,8 @@ impl SimpleCacheClient { async fn build_control_client( auth_token: String, endpoint: String, - ) -> Result>, MomentoError> + agent_value: String, + ) -> Result>, MomentoError> { let uri = Uri::try_from(endpoint)?; let channel = Channel::builder(uri) @@ -193,10 +256,13 @@ impl SimpleCacheClient { .unwrap() .connect_lazy(); + let mut control_interceptor_hashmap = HashMap::new(); + control_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), auth_token); + control_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value); let interceptor = InterceptedService::new( channel, - AuthHeaderInterceptor { - auth_key: auth_token, + HeaderInterceptor { + header: control_interceptor_hashmap, }, ); let client = ScsControlClient::new(interceptor); @@ -206,17 +272,21 @@ impl SimpleCacheClient { async fn build_data_client( auth_token: String, endpoint: String, - ) -> Result>, MomentoError> { + agent_value: String, + ) -> Result>, MomentoError> { let uri = Uri::try_from(endpoint)?; let channel = Channel::builder(uri) .tls_config(ClientTlsConfig::default()) .unwrap() .connect_lazy(); + let mut data_interceptor_hashmap = HashMap::new(); + data_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), auth_token); + data_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value); let interceptor = InterceptedService::new( channel, - CacheHeaderInterceptor { - auth_key: auth_token, + HeaderInterceptor { + header: data_interceptor_hashmap, }, ); let client = ScsClient::new(interceptor);