diff --git a/scylla-rust-wrapper/src/cluster.rs b/scylla-rust-wrapper/src/cluster.rs index 2499aa1a..27156d05 100644 --- a/scylla-rust-wrapper/src/cluster.rs +++ b/scylla-rust-wrapper/src/cluster.rs @@ -41,7 +41,7 @@ const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION"); pub(crate) struct LoadBalancingConfig { pub(crate) token_awareness_enabled: bool, pub(crate) token_aware_shuffling_replicas_enabled: bool, - pub(crate) load_balancing_kind: LoadBalancingKind, + pub(crate) load_balancing_kind: Option, pub(crate) latency_awareness_enabled: bool, pub(crate) latency_awareness_builder: LatencyAwarenessBuilder, } @@ -49,13 +49,18 @@ impl LoadBalancingConfig { // This is `async` to prevent running this function from beyond tokio context, // as it results in panic due to DefaultPolicyBuilder::build() spawning a tokio task. pub(crate) async fn build(self) -> Arc { + let load_balancing_kind = self + .load_balancing_kind + // Round robin is chosen by default for cluster wide LBP. + .unwrap_or(LoadBalancingKind::RoundRobin); + let mut builder = DefaultPolicyBuilder::new().token_aware(self.token_awareness_enabled); if self.token_awareness_enabled { // Cpp-driver enables shuffling replicas only if token aware routing is enabled. builder = builder.enable_shuffling_replicas(self.token_aware_shuffling_replicas_enabled); } - if let LoadBalancingKind::DcAware { local_dc } = self.load_balancing_kind { + if let LoadBalancingKind::DcAware { local_dc } = load_balancing_kind { builder = builder.prefer_datacenter(local_dc).permit_dc_failover(true) } if self.latency_awareness_enabled { @@ -69,7 +74,7 @@ impl Default for LoadBalancingConfig { Self { token_awareness_enabled: true, token_aware_shuffling_replicas_enabled: true, - load_balancing_kind: LoadBalancingKind::RoundRobin, + load_balancing_kind: None, latency_awareness_enabled: false, latency_awareness_builder: Default::default(), } @@ -416,7 +421,7 @@ pub unsafe extern "C" fn cass_cluster_set_credentials_n( #[no_mangle] pub unsafe extern "C" fn cass_cluster_set_load_balance_round_robin(cluster_raw: *mut CassCluster) { let cluster = ptr_to_ref_mut(cluster_raw); - cluster.load_balancing_config.load_balancing_kind = LoadBalancingKind::RoundRobin; + cluster.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); } #[no_mangle] @@ -455,7 +460,7 @@ pub(crate) unsafe fn set_load_balance_dc_aware_n( .unwrap() .to_string(); - load_balancing_config.load_balancing_kind = LoadBalancingKind::DcAware { local_dc }; + load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::DcAware { local_dc }); CassError::CASS_OK } @@ -810,10 +815,7 @@ mod tests { /* Test valid configurations */ let cluster = ptr_to_ref(cluster_raw); { - assert_matches!( - cluster.load_balancing_config.load_balancing_kind, - LoadBalancingKind::RoundRobin - ); + assert_matches!(cluster.load_balancing_config.load_balancing_kind, None); assert!(cluster.load_balancing_config.token_awareness_enabled); assert!(!cluster.load_balancing_config.latency_awareness_enabled); } @@ -842,7 +844,7 @@ mod tests { let load_balancing_kind = &cluster.load_balancing_config.load_balancing_kind; match load_balancing_kind { - LoadBalancingKind::DcAware { local_dc } => { + Some(LoadBalancingKind::DcAware { local_dc }) => { assert_eq!(local_dc, "eu") } _ => panic!("Expected preferred dc"), diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index a5b7ca8f..52995dc8 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -8,7 +8,7 @@ use std::time::Duration; use scylla::execution_profile::{ ExecutionProfile, ExecutionProfileBuilder, ExecutionProfileHandle, }; -use scylla::load_balancing::LatencyAwarenessBuilder; +use scylla::load_balancing::{LatencyAwarenessBuilder, LoadBalancingPolicy}; use scylla::retry_policy::RetryPolicy; use scylla::speculative_execution::SimpleSpeculativeExecutionPolicy; use scylla::statement::Consistency; @@ -42,10 +42,19 @@ impl CassExecProfile { } } - pub(crate) async fn build(self) -> ExecutionProfile { - self.inner - .load_balancing_policy(self.load_balancing_config.build().await) - .build() + pub(crate) async fn build( + self, + cluster_default_lbp: Arc, + ) -> ExecutionProfile { + let load_balacing = if let Some(_) = self.load_balancing_config.load_balancing_kind { + self.load_balancing_config.build().await + } else { + // If load balancing config does not have LB kind defined, + // we make use of cluster's LBP. + cluster_default_lbp + }; + + self.inner.load_balancing_policy(load_balacing).build() } } @@ -353,7 +362,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( profile: *mut CassExecProfile, ) -> CassError { let profile_builder = ptr_to_ref_mut(profile); - profile_builder.load_balancing_config.load_balancing_kind = LoadBalancingKind::RoundRobin; + profile_builder.load_balancing_config.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); CassError::CASS_OK } @@ -473,10 +482,7 @@ mod tests { /* Test valid configurations */ let profile = ptr_to_ref(profile_raw); { - assert_matches!( - profile.load_balancing_config.load_balancing_kind, - LoadBalancingKind::RoundRobin - ); + assert_matches!(profile.load_balancing_config.load_balancing_kind, None); assert!(profile.load_balancing_config.token_awareness_enabled); assert!(!profile.load_balancing_config.latency_awareness_enabled); } @@ -505,7 +511,7 @@ mod tests { let load_balancing_kind = &profile.load_balancing_config.load_balancing_kind; match load_balancing_kind { - LoadBalancingKind::DcAware { local_dc } => { + Some(LoadBalancingKind::DcAware { local_dc }) => { assert_eq!(local_dc, "eu") } _ => panic!("Expected preferred dc"), diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 5d4f214f..14c6b483 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -112,13 +112,14 @@ impl CassSessionInner { "Already connecting, closing, or connected".msg(), )); } + + let (mut session_builder, default_lbp) = session_builder_and_default_lbp_fut.await; + let mut exec_profile_map = HashMap::with_capacity(exec_profile_builder_map.len()); for (name, builder) in exec_profile_builder_map { - exec_profile_map.insert(name, builder.build().await.into_handle()); + exec_profile_map.insert(name, builder.build(default_lbp.clone()).await.into_handle()); } - // TODO: pass default_lbp to exec profiles above. - let (mut session_builder, _default_lbp) = session_builder_and_default_lbp_fut.await; if let Some(keyspace) = keyspace { session_builder = session_builder.use_keyspace(keyspace, false); }