diff --git a/scylla-rust-wrapper/src/exec_profile.rs b/scylla-rust-wrapper/src/exec_profile.rs index f3003dc7..098bc6a3 100644 --- a/scylla-rust-wrapper/src/exec_profile.rs +++ b/scylla-rust-wrapper/src/exec_profile.rs @@ -8,7 +8,7 @@ use std::time::Duration; use scylla::execution_profile::{ ExecutionProfile, ExecutionProfileBuilder, ExecutionProfileHandle, }; -use scylla::load_balancing::LatencyAwarenessBuilder; +use scylla::load_balancing::{LatencyAwarenessBuilder, LoadBalancingPolicy}; use scylla::retry_policy::RetryPolicy; use scylla::speculative_execution::SimpleSpeculativeExecutionPolicy; use scylla::statement::Consistency; @@ -31,7 +31,7 @@ use crate::types::{ #[derive(Clone, Debug)] pub struct CassExecProfile { inner: ExecutionProfileBuilder, - load_balancing_kind: LoadBalancingKind, + load_balancing_kind: Option, load_balancing_config: LoadBalancingConfig, } @@ -39,19 +39,22 @@ impl CassExecProfile { fn new() -> Self { Self { inner: ExecutionProfile::builder(), - load_balancing_kind: LoadBalancingKind::RoundRobin, + load_balancing_kind: None, load_balancing_config: Default::default(), } } - pub(crate) async fn build(self) -> ExecutionProfile { - self.inner - .load_balancing_policy( - self.load_balancing_config - .build(self.load_balancing_kind) - .await, - ) - .build() + pub(crate) async fn build( + self, + cluster_default_lbp: Arc, + ) -> ExecutionProfile { + let load_balacing = if let Some(load_balancing_kind) = self.load_balancing_kind { + self.load_balancing_config.build(load_balancing_kind).await + } else { + cluster_default_lbp + }; + + self.inner.load_balancing_policy(load_balacing).build() } } @@ -346,7 +349,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_dc_aware_n( let profile_builder = ptr_to_ref_mut(profile); set_load_balance_dc_aware_n( - |load_balancing_kind| profile_builder.load_balancing_kind = load_balancing_kind, + |load_balancing_kind| profile_builder.load_balancing_kind = Some(load_balancing_kind), local_dc, local_dc_length, used_hosts_per_remote_dc, @@ -359,7 +362,7 @@ pub unsafe extern "C" fn cass_execution_profile_set_load_balance_round_robin( profile: *mut CassExecProfile, ) -> CassError { let profile_builder = ptr_to_ref_mut(profile); - profile_builder.load_balancing_kind = LoadBalancingKind::RoundRobin; + profile_builder.load_balancing_kind = Some(LoadBalancingKind::RoundRobin); CassError::CASS_OK } @@ -479,7 +482,7 @@ mod tests { /* Test valid configurations */ let profile = ptr_to_ref(profile_raw); { - assert_matches!(profile.load_balancing_kind, LoadBalancingKind::RoundRobin); + assert_matches!(profile.load_balancing_kind, None); assert!(profile.load_balancing_config.token_awareness_enabled); assert!(!profile.load_balancing_config.latency_awareness_enabled); } @@ -508,7 +511,7 @@ mod tests { let load_balancing_kind = &profile.load_balancing_kind; match load_balancing_kind { - LoadBalancingKind::DcAware { local_dc } => { + Some(LoadBalancingKind::DcAware { local_dc }) => { assert_eq!(local_dc, "eu") } _ => panic!("Expected preferred dc"), diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 5d4f214f..14c6b483 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -112,13 +112,14 @@ impl CassSessionInner { "Already connecting, closing, or connected".msg(), )); } + + let (mut session_builder, default_lbp) = session_builder_and_default_lbp_fut.await; + let mut exec_profile_map = HashMap::with_capacity(exec_profile_builder_map.len()); for (name, builder) in exec_profile_builder_map { - exec_profile_map.insert(name, builder.build().await.into_handle()); + exec_profile_map.insert(name, builder.build(default_lbp.clone()).await.into_handle()); } - // TODO: pass default_lbp to exec profiles above. - let (mut session_builder, _default_lbp) = session_builder_and_default_lbp_fut.await; if let Some(keyspace) = keyspace { session_builder = session_builder.use_keyspace(keyspace, false); }