Shard aware batching - add Session::shard_for_statement & Batch::enforce_target_node #738

Open
wants to merge 17 commits into base: main
Changes from 6 commits
79 changes: 78 additions & 1 deletion scylla/src/statement/batch.rs
@@ -3,7 +3,8 @@ use std::sync::Arc;
use crate::history::HistoryListener;
use crate::retry_policy::RetryPolicy;
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
use crate::transport::execution_profile::ExecutionProfileHandle;
use crate::transport::{execution_profile::ExecutionProfileHandle, Node};
use crate::Session;

use super::StatementConfig;
pub use super::{Consistency, SerialConsistency};
@@ -144,6 +145,82 @@ impl Batch {
pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
self.config.execution_profile_handle.as_ref()
}

/// Associates the batch with a new execution profile whose load balancing policy
/// enforces the use of the provided [`Node`] to the extent possible.
///
/// This is typically used in conjunction with [`Session::shard_for_statement`]: you build a batch
/// by grouping together all the statements that would be executed on the same shard.
///
/// Since the load balancer is not guaranteed to route subsequent calls to the same node, use this
/// method to pin the batch to the node that `shard_for_statement` originally selected:
///
/// ```rust
/// # use scylla::Session;
/// # use std::error::Error;
/// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
/// use scylla::{
/// batch::Batch,
/// frame::value::{SerializedValues, ValueList},
/// };
///
/// let prepared_statement = session
/// .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
/// .await?;
///
/// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
/// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
///
/// // Send this statement to the task that handles statements targeted to the same shard
///
/// // On that task:
/// // Build a batch from all the statements that target the same shard
///
/// let mut batch: Batch = Default::default();
/// if let Some((node, _shard_idx)) = shard {
/// batch.enforce_target_node(&node, &session);
/// }
/// let mut batch_values = Vec::new();
///
/// // As the task handling statements targeted to this shard receives them,
/// // it appends them to the batch
/// batch.append_statement(prepared_statement);
/// batch_values.push(serialized_values);
///
/// // Run the batch
/// session.batch(&batch, batch_values).await?;
/// # Ok(())
/// # }
/// ```
///
/// If the target node is no longer available when the statement is executed, the policy falls back
/// to the original load balancing policy:
/// - either the one currently set on the [`Batch`], if any,
/// - or the [`Session`]'s default one.
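///
/// A short, illustrative sketch of giving the batch its own execution profile, which then serves as
/// the fallback (the empty profile built here and the `ExecutionProfile` import path are assumptions
/// for illustration only):
///
/// ```rust
/// # use scylla::{batch::Batch, transport::{ExecutionProfile, Node}, Session};
/// # use std::sync::Arc;
/// # fn pin_with_custom_fallback(session: &Session, node: &Arc<Node>) {
/// let mut batch: Batch = Default::default();
/// // The load balancing policy of this batch-level profile becomes the fallback...
/// batch.set_execution_profile_handle(Some(ExecutionProfile::builder().build().into_handle()));
/// // ...otherwise the session's default profile would have been used.
/// batch.enforce_target_node(node, session);
/// # }
/// ```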
pub fn enforce_target_node(
&mut self,
node: &Arc<Node>,
base_execution_profile_from_session: &Session,
) {
let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
base_execution_profile_from_session.get_default_execution_profile_handle()
});
self.set_execution_profile_handle(Some(
execution_profile_handle
.pointee_to_builder()
.load_balancing_policy(Arc::new(
crate::load_balancing::EnforceTargetNodePolicy::new(
node,
execution_profile_handle.load_balancing_policy(),
),
))
.build()
.into_handle(),
))
}
}

impl Default for Batch {
4 changes: 4 additions & 0 deletions scylla/src/transport/execution_profile.rs
@@ -485,4 +485,8 @@ impl ExecutionProfileHandle {
pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
self.0 .0.store(profile.0)
}

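/// Returns the load balancing policy of the execution profile currently pointed to by this handle.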
pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
self.0 .0.load().load_balancing_policy.clone()
}
}
2 changes: 1 addition & 1 deletion scylla/src/transport/load_balancing/default.rs
@@ -594,7 +594,7 @@ impl DefaultPolicy {
vec.into_iter()
}

fn is_alive(node: &NodeRef<'_>) -> bool {
pub(crate) fn is_alive(node: &NodeRef<'_>) -> bool {
// For now, we leave this as stub, until we have time to improve node events.
// node.is_enabled() && !node.is_down()
node.is_enabled()
47 changes: 47 additions & 0 deletions scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,47 @@
use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
use crate::transport::{cluster::ClusterData, Node};
use std::sync::Arc;

/// This policy always returns the same node, unless it is no longer available, in which case it
/// falls back to the provided policy.
///
/// This is meant to be used for shard-aware batching.
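///
/// A minimal sketch of wiring this policy into an execution profile by hand, as
/// `Batch::enforce_target_node` does internally (the `ExecutionProfile` import path and the
/// `fallback` argument are assumptions for illustration):
///
/// ```rust
/// # use std::sync::Arc;
/// # use scylla::transport::{ExecutionProfile, Node};
/// # use scylla::load_balancing::{EnforceTargetNodePolicy, LoadBalancingPolicy};
/// # fn wire(node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) {
/// // Pin to `node`, falling back to `fallback` if the node becomes unavailable.
/// let policy = Arc::new(EnforceTargetNodePolicy::new(node, fallback));
/// let handle = ExecutionProfile::builder()
///     .load_balancing_policy(policy)
///     .build()
///     .into_handle();
/// # }
/// ```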
#[derive(Debug)]
pub struct EnforceTargetNodePolicy {
target_node: uuid::Uuid,
fallback: Arc<dyn LoadBalancingPolicy>,
}

impl EnforceTargetNodePolicy {
pub fn new(target_node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> Self {
Self {
target_node: target_node.host_id,
fallback,
}
}
}
impl LoadBalancingPolicy for EnforceTargetNodePolicy {
fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
cluster
.known_peers
.get(&self.target_node)
.filter(DefaultPolicy::is_alive)
.or_else(|| self.fallback.pick(query, cluster))
}

fn fallback<'a>(
&'a self,
query: &'a RoutingInfo,
cluster: &'a ClusterData,
) -> FallbackPlan<'a> {
self.fallback.fallback(query, cluster)
}

fn name(&self) -> String {
format!(
"Enforce target node Load balancing policy - Node: {} - fallback: {}",
self.target_node,
self.fallback.name()
)
}
}
6 changes: 5 additions & 1 deletion scylla/src/transport/load_balancing/mod.rs
@@ -9,9 +9,13 @@ use scylla_cql::{errors::QueryError, frame::types};
use std::time::Duration;

mod default;
mod enforce_node;
mod plan;
pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
pub use plan::Plan;
pub use {
default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
enforce_node::EnforceTargetNodePolicy,
};

/// Represents info about statement that can be used by load balancing policies.
#[derive(Default, Clone, Debug)]
76 changes: 60 additions & 16 deletions scylla/src/transport/session.rs
@@ -8,6 +8,7 @@ use crate::frame::types::LegacyConsistency;
use crate::history;
use crate::history::HistoryListener;
use crate::retry_policy::RetryPolicy;
use crate::routing;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use bytes::Bytes;
@@ -923,15 +924,7 @@ impl Session {
.as_ref()
.map(|pk| prepared.get_partitioner_name().hash(pk));

let statement_info = RoutingInfo {
consistency: prepared
.get_consistency()
.unwrap_or(self.default_execution_profile_handle.access().consistency),
serial_consistency: prepared.get_serial_consistency(),
token,
keyspace: prepared.get_keyspace_name(),
is_confirmed_lwt: prepared.is_confirmed_lwt(),
};
let statement_info = self.routing_info_from_prepared_statement(prepared, token);

let span =
RequestSpan::new_prepared(partition_key.as_ref(), token, serialized_values.size());
@@ -1839,13 +1832,64 @@ impl Session {
prepared: &PreparedStatement,
serialized_values: &SerializedValues,
) -> Result<Option<Token>, QueryError> {
match self.calculate_partition_key(prepared, serialized_values) {
Ok(Some(partition_key)) => {
let partitioner_name = prepared.get_partitioner_name();
Ok(Some(partitioner_name.hash(&partition_key)))
}
Ok(None) => Ok(None),
Err(err) => Err(err),
Ok(self
.calculate_partition_key(prepared, serialized_values)?
.map(|partition_key| prepared.get_partitioner_name().hash(&partition_key)))
}

/// Returns the node/shard that the load balancer would potentially target if this statement were executed now.
///
/// This can help when constituting shard-aware batches (see [`Batch::enforce_target_node`]).
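///
/// A minimal, illustrative sketch of using the returned node/shard pair as a grouping key, so that
/// statements landing on the same shard end up in the same batch (the table and the grouping map
/// are examples only):
///
/// ```rust
/// # use scylla::Session;
/// # use std::error::Error;
/// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
/// use scylla::frame::value::{SerializedValues, ValueList};
/// use std::collections::HashMap;
///
/// let prepared_statement = session
///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
///     .await?;
/// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
///
/// // Group statement values by the (node, shard) pair they would be routed to; statements whose
/// // target cannot be determined (`None`) could instead go to a catch-all batch.
/// let mut per_shard_values = HashMap::new();
/// if let Some((node, shard)) = session.shard_for_statement(&prepared_statement, &serialized_values)? {
///     per_shard_values
///         .entry((node.host_id, shard))
///         .or_insert_with(Vec::new)
///         .push(serialized_values);
/// }
/// # Ok(())
/// # }
/// ```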
#[allow(clippy::type_complexity)]
pub fn shard_for_statement(
&self,
prepared: &PreparedStatement,
serialized_values: &SerializedValues,
) -> Result<Option<(Arc<Node>, Option<routing::Shard>)>, QueryError> {
let token = match self.calculate_token(prepared, serialized_values)? {
Some(token) => token,
None => return Ok(None),
};
let routing_info = self.routing_info_from_prepared_statement(prepared, Some(token));
let cluster_data = self.cluster.get_data();
let execution_profile = prepared
.config
.execution_profile_handle
.as_ref()
.unwrap_or_else(|| self.get_default_execution_profile_handle())
.access();
let mut query_plan = load_balancing::Plan::new(
&*execution_profile.load_balancing_policy,
&routing_info,
&cluster_data,
);
// We can't return the full iterator here because it borrows from local variables.
// Two designs could make that possible:
// - Construct a self-referential struct and implement Iterator on it, e.g. via Ouroboros
// - Take a closure as a parameter, hand it the local iterator, and return whatever the closure returns
// Most likely, though, callers will use this for some kind of shard-awareness optimization for batching,
// and are consequently not interested in subsequent nodes, so for now let's just expose the first node,
// as it is simpler.
Ok(query_plan.next().map(move |node| {
let shard = node.sharder().map(|sharder| sharder.shard_of(token));
(node.clone(), shard)
}))
}

fn routing_info_from_prepared_statement<'p>(
&self,
prepared: &'p PreparedStatement,
token: Option<Token>,
) -> RoutingInfo<'p> {
RoutingInfo {
consistency: prepared
.get_consistency()
.unwrap_or(self.default_execution_profile_handle.access().consistency),
serial_consistency: prepared.get_serial_consistency(),
token,
keyspace: prepared.get_keyspace_name(),
is_confirmed_lwt: prepared.is_confirmed_lwt(),
}
}

Expand Down