Skip to content

Commit

Permalink
Merge pull request #167 from muzarski/self-identity
Browse files Browse the repository at this point in the history
config: set self-identity
  • Loading branch information
dkropachev authored Sep 11, 2024
2 parents b279526 + e9e3cd4 commit 9387c06
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
88 changes: 87 additions & 1 deletion scylla-rust-wrapper/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::retry_policy::CassRetryPolicy;
use crate::retry_policy::RetryPolicy::*;
use crate::ssl::CassSsl;
use crate::types::*;
use crate::uuid::CassUuid;
use openssl::ssl::SslContextBuilder;
use openssl_sys::SSL_CTX_up_ref;
use scylla::execution_profile::ExecutionProfileBuilder;
Expand All @@ -16,6 +17,7 @@ use scylla::load_balancing::{DefaultPolicyBuilder, LoadBalancingPolicy};
use scylla::retry_policy::RetryPolicy;
use scylla::speculative_execution::SimpleSpeculativeExecutionPolicy;
use scylla::statement::{Consistency, SerialConsistency};
use scylla::transport::SelfIdentity;
use scylla::{SessionBuilder, SessionConfig};
use std::collections::HashMap;
use std::convert::TryInto;
Expand All @@ -32,6 +34,9 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_compression_types.rs"));
const DEFAULT_CONSISTENCY: Consistency = Consistency::LocalOne;
const DEFAULT_REQUEST_TIMEOUT_MILLIS: u64 = 12000;

const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");

#[derive(Clone, Debug)]
pub(crate) struct LoadBalancingConfig {
pub(crate) token_awareness_enabled: bool,
Expand Down Expand Up @@ -85,6 +90,8 @@ pub struct CassCluster {
use_beta_protocol_version: bool,
auth_username: Option<String>,
auth_password: Option<String>,

client_id: Option<uuid::Uuid>,
}

impl CassCluster {
Expand All @@ -106,6 +113,11 @@ impl CassCluster {
pub(crate) fn get_contact_points(&self) -> &[String] {
&self.contact_points
}

#[inline]
pub(crate) fn get_client_id(&self) -> Option<uuid::Uuid> {
self.client_id
}
}

pub struct CassCustomPayload;
Expand Down Expand Up @@ -140,8 +152,13 @@ pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster {
.consistency(DEFAULT_CONSISTENCY)
.request_timeout(Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MILLIS)));

// Set DRIVER_NAME and DRIVER_VERSION of cpp-rust driver.
let custom_identity = SelfIdentity::new()
.with_custom_driver_name(DRIVER_NAME)
.with_custom_driver_version(DRIVER_VERSION);

Box::into_raw(Box::new(CassCluster {
session_builder: SessionBuilder::new(),
session_builder: SessionBuilder::new().custom_identity(custom_identity),
port: 9042,
contact_points: Vec::new(),
// Per DataStax documentation: Without additional configuration the C/C++ driver
Expand All @@ -152,6 +169,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> *mut CassCluster {
default_execution_profile_builder,
execution_profile_map: Default::default(),
load_balancing_config: Default::default(),
client_id: None,
}))
}

Expand Down Expand Up @@ -219,6 +237,74 @@ pub unsafe extern "C" fn cass_cluster_set_use_randomized_contact_points(
CassError::CASS_OK
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_application_name(
cluster_raw: *mut CassCluster,
app_name: *const c_char,
) {
cass_cluster_set_application_name_n(cluster_raw, app_name, strlen(app_name))
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_application_name_n(
cluster_raw: *mut CassCluster,
app_name: *const c_char,
app_name_len: size_t,
) {
let cluster = ptr_to_ref_mut(cluster_raw);
let app_name = ptr_to_cstr_n(app_name, app_name_len).unwrap().to_string();

cluster
.session_builder
.config
.identity
.set_application_name(app_name)
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_application_version(
cluster_raw: *mut CassCluster,
app_version: *const c_char,
) {
cass_cluster_set_application_version_n(cluster_raw, app_version, strlen(app_version))
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_application_version_n(
cluster_raw: *mut CassCluster,
app_version: *const c_char,
app_version_len: size_t,
) {
let cluster = ptr_to_ref_mut(cluster_raw);
let app_version = ptr_to_cstr_n(app_version, app_version_len)
.unwrap()
.to_string();

cluster
.session_builder
.config
.identity
.set_application_version(app_version);
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_client_id(
cluster_raw: *mut CassCluster,
client_id: CassUuid,
) {
let cluster = ptr_to_ref_mut(cluster_raw);

let client_uuid: uuid::Uuid = client_id.into();
let client_uuid_str = client_uuid.to_string();

cluster.client_id = Some(client_uuid);
cluster
.session_builder
.config
.identity
.set_client_id(client_uuid_str)
}

#[no_mangle]
pub unsafe extern "C" fn cass_cluster_set_use_schema(
cluster_raw: *mut CassCluster,
Expand Down
16 changes: 16 additions & 0 deletions scylla-rust-wrapper/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::query_result::{CassResult, CassResultData, CassRow, CassValue, Collec
use crate::statement::CassStatement;
use crate::statement::Statement;
use crate::types::{cass_uint64_t, size_t};
use crate::uuid::CassUuid;
use scylla::frame::response::result::{CqlValue, Row};
use scylla::frame::types::Consistency;
use scylla::query::Query;
Expand All @@ -31,6 +32,7 @@ use tokio::sync::RwLock;
pub struct CassSessionInner {
session: Session,
exec_profile_map: HashMap<ExecProfileName, ExecutionProfileHandle>,
client_id: uuid::Uuid,
}

impl CassSessionInner {
Expand Down Expand Up @@ -82,6 +84,10 @@ impl CassSessionInner {
session_opt,
session_builder,
exec_profile_map,
cluster
.get_client_id()
// If user did not set a client id, generate a random uuid v4.
.unwrap_or_else(uuid::Uuid::new_v4),
keyspace,
))
}
Expand All @@ -90,6 +96,7 @@ impl CassSessionInner {
session_opt: &RwLock<Option<CassSessionInner>>,
session_builder_fut: impl Future<Output = SessionBuilder>,
exec_profile_builder_map: HashMap<ExecProfileName, CassExecProfile>,
client_id: uuid::Uuid,
keyspace: Option<String>,
) -> CassFutureResult {
// This can sleep for a long time, but only if someone connects/closes session
Expand Down Expand Up @@ -119,6 +126,7 @@ impl CassSessionInner {
*session_guard = Some(CassSessionInner {
session,
exec_profile_map,
client_id,
});
Ok(CassResultValue::Empty)
}
Expand Down Expand Up @@ -562,6 +570,14 @@ pub unsafe extern "C" fn cass_session_close(session: *mut CassSession) -> *const
})
}

#[no_mangle]
pub unsafe extern "C" fn cass_session_get_client_id(session: *const CassSession) -> CassUuid {
let cass_session = ptr_to_ref(session);

let client_id: uuid::Uuid = cass_session.blocking_read().as_ref().unwrap().client_id;
client_id.into()
}

#[no_mangle]
pub unsafe extern "C" fn cass_session_get_schema_meta(
session: *const CassSession,
Expand Down

0 comments on commit 9387c06

Please sign in to comment.