Skip to content

Commit

Permalink
chore: Add new_momento_cli to use in Momento CLI and add agent me…
Browse files Browse the repository at this point in the history
…tadata for SDK (#39)
  • Loading branch information
poppoerika authored Apr 12, 2022
1 parent 97c379d commit fa6fe5b
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 56 deletions.
17 changes: 0 additions & 17 deletions src/grpc/auth_header_interceptor.rs

This file was deleted.

23 changes: 0 additions & 23 deletions src/grpc/cache_header_interceptor.rs

This file was deleted.

35 changes: 35 additions & 0 deletions src/grpc/header_interceptor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

const AUTHORIZATION: &str = "authorization";
const AGENT: &str = "agent";

#[derive(Clone)]
pub struct HeaderInterceptor {
pub header: HashMap<String, String>,
}

impl tonic::service::Interceptor for HeaderInterceptor {
fn call(
&mut self,
mut request: tonic::Request<()>,
) -> Result<tonic::Request<()>, tonic::Status> {
static ARE_ONLY_ONCE_HEADER_SENT: AtomicBool = AtomicBool::new(false);
for (key, value) in self.header.iter() {
if *key == *AUTHORIZATION {
request.metadata_mut().insert(
tonic::metadata::AsciiMetadataKey::from_static(AUTHORIZATION),
tonic::metadata::AsciiMetadataValue::from_str(value).unwrap(),
);
}
if !ARE_ONLY_ONCE_HEADER_SENT.load(Ordering::Relaxed) && *key == *AGENT {
request.metadata_mut().insert(
tonic::metadata::AsciiMetadataKey::from_static(AGENT),
tonic::metadata::AsciiMetadataValue::from_str(value).unwrap(),
);
ARE_ONLY_ONCE_HEADER_SENT.store(true, Ordering::Relaxed);
}
}
Ok(request)
}
}
3 changes: 1 addition & 2 deletions src/grpc/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pub mod auth_header_interceptor;
pub mod cache_header_interceptor;
pub mod header_interceptor;
98 changes: 84 additions & 14 deletions src/simple_cache_client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde_json::Value;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::num::NonZeroU64;
use tonic::{
Expand All @@ -8,13 +9,12 @@ use tonic::{
};

use crate::endpoint_resolver::MomentoEndpointsResolver;
use crate::grpc::cache_header_interceptor::CacheHeaderInterceptor;
use crate::grpc::header_interceptor::HeaderInterceptor;
use crate::{
generated::control_client::{
scs_control_client::ScsControlClient, CreateCacheRequest, CreateSigningKeyRequest,
DeleteCacheRequest, ListCachesRequest, RevokeSigningKeyRequest,
},
grpc::auth_header_interceptor::AuthHeaderInterceptor,
response::{
create_signing_key_response::MomentoCreateSigningKeyResponse,
error::MomentoError,
Expand All @@ -31,6 +31,11 @@ use crate::{
generated::cache_client::{scs_client::ScsClient, ECacheResult, GetRequest, SetRequest},
response::cache_get_response::MomentoGetResponse,
};

const AUTHORIZATION_NAME: &str = "authorization";
const AGENT_NAME: &str = "agent";
const VERSION: &str = env!("CARGO_PKG_VERSION");

pub trait MomentoRequest {
fn into_bytes(self) -> Vec<u8>;
}
Expand Down Expand Up @@ -104,18 +109,25 @@ impl SimpleCacheClientBuilder {
}

pub fn build(self) -> SimpleCacheClient {
let agent_value = format!("rust:{}", VERSION);
let mut control_interceptor_hashmap = HashMap::new();
control_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), self.auth_token.clone());
control_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value.clone());
let control_interceptor = InterceptedService::new(
self.control_channel,
AuthHeaderInterceptor {
auth_key: self.auth_token.clone(),
HeaderInterceptor {
header: control_interceptor_hashmap,
},
);
let control_client = ScsControlClient::new(control_interceptor);

let mut data_interceptor_hashmap = HashMap::new();
data_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), self.auth_token.clone());
data_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value);
let data_interceptor = InterceptedService::new(
self.data_channel,
CacheHeaderInterceptor {
auth_key: self.auth_token.clone(),
HeaderInterceptor {
header: data_interceptor_hashmap,
},
);
let data_client = ScsClient::new(data_interceptor);
Expand All @@ -131,8 +143,8 @@ impl SimpleCacheClientBuilder {

pub struct SimpleCacheClient {
data_endpoint: String,
control_client: ScsControlClient<InterceptedService<Channel, AuthHeaderInterceptor>>,
data_client: ScsClient<InterceptedService<Channel, CacheHeaderInterceptor>>,
control_client: ScsControlClient<InterceptedService<Channel, HeaderInterceptor>>,
data_client: ScsClient<InterceptedService<Channel, HeaderInterceptor>>,
item_default_ttl_seconds: NonZeroU64,
}

Expand Down Expand Up @@ -160,16 +172,66 @@ impl SimpleCacheClient {
default_ttl_seconds: NonZeroU64,
) -> Result<Self, MomentoError> {
let data_endpoint = utils::get_claims(&auth_token).c;
let agent_value = format!("rust:{}", VERSION);
let momento_endpoints = MomentoEndpointsResolver::resolve(&auth_token, &None);
let control_client = SimpleCacheClient::build_control_client(
auth_token.clone(),
momento_endpoints.control_endpoint,
agent_value.clone(),
)
.await;
let data_client = SimpleCacheClient::build_data_client(
auth_token.clone(),
momento_endpoints.data_endpoint,
agent_value,
)
.await;

let simple_cache_client = Self {
data_endpoint,
control_client: control_client.unwrap(),
data_client: data_client.unwrap(),
item_default_ttl_seconds: default_ttl_seconds,
};
Ok(simple_cache_client)
}

/// Returns an instance of a Momento client.
/// This is specifically used for Momento CLI to initialize a Momento client.
///
/// # Arguments
///
/// * `auth_token` - Momento Token
/// * `item_default_ttl_seconds` - Default TTL for items put into a cache.
/// # Examples
///
/// ```
/// # tokio_test::block_on(async {
/// use momento::simple_cache_client::SimpleCacheClient;
/// use std::env;
/// use std::num::NonZeroU64;
/// let auth_token = env::var("TEST_AUTH_TOKEN").expect("TEST_AUTH_TOKEN must be set");
/// let default_ttl = 30;
/// let momento = SimpleCacheClient::new_momento_cli(auth_token, NonZeroU64::new(default_ttl).unwrap()).await;
/// # })
/// ```
pub async fn new_momento_cli(
auth_token: String,
default_ttl_seconds: NonZeroU64,
) -> Result<Self, MomentoError> {
let data_endpoint = utils::get_claims(&auth_token).c;
let agent_value = format!("momento-cli:{}", VERSION);
let momento_endpoints = MomentoEndpointsResolver::resolve(&auth_token, &None);
let control_client = SimpleCacheClient::build_control_client(
auth_token.clone(),
momento_endpoints.control_endpoint,
agent_value.clone(),
)
.await;
let data_client = SimpleCacheClient::build_data_client(
auth_token.clone(),
momento_endpoints.data_endpoint,
agent_value,
)
.await;

Expand All @@ -185,18 +247,22 @@ impl SimpleCacheClient {
async fn build_control_client(
auth_token: String,
endpoint: String,
) -> Result<ScsControlClient<InterceptedService<Channel, AuthHeaderInterceptor>>, MomentoError>
agent_value: String,
) -> Result<ScsControlClient<InterceptedService<Channel, HeaderInterceptor>>, MomentoError>
{
let uri = Uri::try_from(endpoint)?;
let channel = Channel::builder(uri)
.tls_config(ClientTlsConfig::default())
.unwrap()
.connect_lazy();

let mut control_interceptor_hashmap = HashMap::new();
control_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), auth_token);
control_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value);
let interceptor = InterceptedService::new(
channel,
AuthHeaderInterceptor {
auth_key: auth_token,
HeaderInterceptor {
header: control_interceptor_hashmap,
},
);
let client = ScsControlClient::new(interceptor);
Expand All @@ -206,17 +272,21 @@ impl SimpleCacheClient {
async fn build_data_client(
auth_token: String,
endpoint: String,
) -> Result<ScsClient<InterceptedService<Channel, CacheHeaderInterceptor>>, MomentoError> {
agent_value: String,
) -> Result<ScsClient<InterceptedService<Channel, HeaderInterceptor>>, MomentoError> {
let uri = Uri::try_from(endpoint)?;
let channel = Channel::builder(uri)
.tls_config(ClientTlsConfig::default())
.unwrap()
.connect_lazy();

let mut data_interceptor_hashmap = HashMap::new();
data_interceptor_hashmap.insert(AUTHORIZATION_NAME.to_string(), auth_token);
data_interceptor_hashmap.insert(AGENT_NAME.to_string(), agent_value);
let interceptor = InterceptedService::new(
channel,
CacheHeaderInterceptor {
auth_key: auth_token,
HeaderInterceptor {
header: data_interceptor_hashmap,
},
);
let client = ScsClient::new(interceptor);
Expand Down

0 comments on commit fa6fe5b

Please sign in to comment.