Shard aware batching - add Session::shard_for_statement & Batch::enforce_target_node #738

Open · wants to merge 17 commits into main
30 changes: 30 additions & 0 deletions Cargo.lock.msrv

Some generated files are not rendered by default.

39 changes: 37 additions & 2 deletions scylla-cql/src/types/serialize/row.rs
@@ -80,6 +80,17 @@ pub trait SerializeRow {
     /// the bind marker types and names so that the values can be properly
     /// type checked and serialized.
     fn is_empty(&self) -> bool;
+
+    /// Specialization that allows the driver to skip re-serializing the row
+    /// when it is already a `SerializedValues`.
+    ///
+    /// Note that when using this, it is the user's responsibility to ensure
+    /// that the `SerializedValues` was generated with the same prepared
+    /// statement that the query will be executed with.
+    #[inline]
+    fn already_serialized(&self) -> Option<&SerializedValues> {
+        None
+    }
 }

 macro_rules! fallback_impl_contents {
@@ -255,12 +266,36 @@ impl<T: SerializeRow + ?Sized> SerializeRow for &T {
         ctx: &RowSerializationContext<'_>,
         writer: &mut RowWriter,
     ) -> Result<(), SerializationError> {
-        <T as SerializeRow>::serialize(self, ctx, writer)
+        <T as SerializeRow>::serialize(*self, ctx, writer)
     }
 
     #[inline]
     fn is_empty(&self) -> bool {
-        <T as SerializeRow>::is_empty(self)
+        <T as SerializeRow>::is_empty(*self)
     }
+
+    #[inline]
+    fn already_serialized(&self) -> Option<&SerializedValues> {
+        <T as SerializeRow>::already_serialized(*self)
+    }
 }
+
+impl SerializeRow for SerializedValues {
+    fn serialize(
+        &self,
+        _ctx: &RowSerializationContext<'_>,
+        writer: &mut RowWriter,
+    ) -> Result<(), SerializationError> {
+        writer.append_serialize_row(self);
+        Ok(())
+    }
+
+    fn is_empty(&self) -> bool {
+        self.is_empty()
+    }
+
+    fn already_serialized(&self) -> Option<&SerializedValues> {
+        Some(self)
+    }
+}

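With `SerializedValues` now implementing `SerializeRow`, pre-serialized rows can be passed wherever row values are accepted, and `already_serialized` gives the driver a fast path that skips re-serialization. A minimal sketch of that fast path from a caller's perspective (the helper name is hypothetical; the import path follows the doc example in `batch.rs` below):

```rust
use scylla::serialize::row::{SerializeRow, SerializedValues};

// Hypothetical helper: surface the pre-serialized form when one exists, so
// the caller can skip the serialization pass entirely.
fn reuse_if_serialized(values: &impl SerializeRow) -> Option<&SerializedValues> {
    // `already_serialized` returns `Some` only when `values` already is
    // (a reference to) a `SerializedValues`, e.g. one produced by
    // `PreparedStatement::serialize_values`.
    values.already_serialized()
}
```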
2 changes: 2 additions & 0 deletions scylla/Cargo.toml
@@ -89,6 +89,8 @@ tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
 assert_matches = "1.5.0"
 rand_chacha = "0.3.1"
 time = "0.3"
+futures-batch = "0.6.1"
+tokio-stream = "0.1.14"
 
 [[bench]]
 name = "benchmark"
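These two dev-dependencies presumably support a shard-aware batching example or test. A sketch (assumed usage, not code from this PR) of how `futures-batch` chunks a `tokio-stream` into bounded groups, each of which would become one batch:

```rust
use std::time::Duration;

use futures::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;

#[tokio::main]
async fn main() {
    // Chunk a stream into groups of up to 16 items, flushing every 5 ms so a
    // slow producer cannot stall a partially filled batch.
    let items = tokio_stream::iter(0..100u32);
    let chunks = items.chunks_timeout(16, Duration::from_millis(5));
    futures::pin_mut!(chunks);

    while let Some(chunk) = chunks.next().await {
        // Each `chunk: Vec<u32>` would become one shard-aware `Batch`.
        println!("batch of {} items", chunk.len());
    }
}
```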
80 changes: 80 additions & 0 deletions scylla/src/statement/batch.rs
@@ -2,9 +2,13 @@ use std::borrow::Cow;
 use std::sync::Arc;
 
 use crate::history::HistoryListener;
+use crate::load_balancing;
 use crate::retry_policy::RetryPolicy;
+use crate::routing::Shard;
 use crate::statement::{prepared_statement::PreparedStatement, query::Query};
 use crate::transport::execution_profile::ExecutionProfileHandle;
+use crate::transport::NodeRef;
+use crate::Session;
 
 use super::StatementConfig;
 use super::{Consistency, SerialConsistency};
@@ -142,6 +146,82 @@ impl Batch {
     pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
         self.config.execution_profile_handle.as_ref()
     }
+
+    /// Associates the batch with a new execution profile whose load balancing
+    /// policy enforces the use of the provided [`NodeRef`] to the extent
+    /// possible.
+    ///
+    /// This is typically used in conjunction with
+    /// [`Session::shard_for_statement`]: build a batch by grouping together
+    /// all the statements that would be executed on the same shard.
+    ///
+    /// Since it is not guaranteed that subsequent calls to the load balancer
+    /// would assign the statement to the same node, use this method to
+    /// enforce the node that `shard_for_statement` originally chose for the
+    /// batch:
+    ///
+    /// ```rust
+    /// # use scylla::Session;
+    /// # use std::error::Error;
+    /// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
+    /// use scylla::{batch::Batch, serialize::row::SerializedValues};
+    ///
+    /// let prepared_statement = session
+    ///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
+    ///     .await?;
+    ///
+    /// let serialized_values: SerializedValues = prepared_statement.serialize_values(&(1, 2))?;
+    /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
+    ///
+    /// // Send that to a task that will handle statements targeted to the same shard
+    ///
+    /// // On that task:
+    /// // Build a batch with all the statements that would be executed on the same shard
+    ///
+    /// let mut batch: Batch = Default::default();
+    /// if let Some((node, shard_idx)) = shard {
+    ///     batch.enforce_target_node(&node, shard_idx, &session);
+    /// }
+    /// let mut batch_values = Vec::new();
+    ///
+    /// // As the task handling statements targeted to this shard receives them,
+    /// // it appends them to the batch
+    /// batch.append_statement(prepared_statement);
+    /// batch_values.push(serialized_values);
+    ///
+    /// // Run the batch
+    /// session.batch(&batch, batch_values).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    /// If the target node is no longer available by the time the statement is
+    /// executed, the driver falls back to the original load balancing policy:
+    /// - either the one currently set on the [`Batch`], if any,
+    /// - or the [`Session`]'s default, if the `Batch` has none.
+    pub fn enforce_target_node(
+        &mut self,
+        node: NodeRef<'_>,
+        shard: Shard,
+        base_execution_profile_from_session: &Session,
+    ) {
+        let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
+            base_execution_profile_from_session.get_default_execution_profile_handle()
+        });
+        self.set_execution_profile_handle(Some(
+            execution_profile_handle
+                .pointee_to_builder()
+                .load_balancing_policy(Arc::new(load_balancing::EnforceTargetShardPolicy::new(
+                    node,
+                    shard,
+                    execution_profile_handle.load_balancing_policy(),
+                )))
+                .build()
+                .into_handle(),
+        ))
+    }
 }
 
 impl Default for Batch {
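Putting `shard_for_statement` and `enforce_target_node` together at scale, the natural pattern is one batch per (node, shard) pair. A sketch under assumptions consistent with the doc example above: `shard_for_statement` yields an `Arc<Node>` with a public `host_id`, the `uuid` crate provides the map key, and the `(i32, i32)` rows match the two-column insert:

```rust
use std::collections::HashMap;

use scylla::batch::Batch;
use scylla::prepared_statement::PreparedStatement;
use scylla::routing::Shard;
use scylla::serialize::row::SerializedValues;
use scylla::Session;
use uuid::Uuid;

// One batch per (node, shard): every statement in a batch is routed to the
// shard that `shard_for_statement` picked for it.
async fn shard_aware_inserts(
    session: &Session,
    prepared: &PreparedStatement,
    rows: Vec<(i32, i32)>,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut per_shard: HashMap<(Uuid, Shard), (Batch, Vec<SerializedValues>)> = HashMap::new();

    for row in rows {
        let values = prepared.serialize_values(&row)?;
        if let Some((node, shard)) = session.shard_for_statement(prepared, &values)? {
            let (batch, batch_values) =
                per_shard.entry((node.host_id, shard)).or_insert_with(|| {
                    let mut batch = Batch::default();
                    // Pin the new batch to the node/shard chosen above.
                    batch.enforce_target_node(&node, shard, session);
                    (batch, Vec::new())
                });
            batch.append_statement(prepared.clone());
            batch_values.push(values);
        }
    }

    for (batch, batch_values) in per_shard.into_values() {
        session.batch(&batch, batch_values).await?;
    }
    Ok(())
}
```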
2 changes: 1 addition & 1 deletion scylla/src/statement/prepared_statement.rs
@@ -459,7 +459,7 @@ impl PreparedStatement {
         self.config.execution_profile_handle.as_ref()
     }
 
-    pub(crate) fn serialize_values(
+    pub fn serialize_values(
         &self,
         values: &impl SerializeRow,
     ) -> Result<SerializedValues, SerializationError> {
15 changes: 15 additions & 0 deletions scylla/src/transport/execution_profile.rs
@@ -485,4 +485,19 @@ impl ExecutionProfileHandle {
     pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
         self.0 .0.store(profile.0)
     }
+
+    /// Get the load balancing policy associated with this execution profile.
+    ///
+    /// This may be useful if one wants to construct a new load balancing
+    /// policy that is based on the one associated with this execution profile.
+    pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        // Exposed as a building block of `Batch::enforce_target_node`, in case
+        // a user wants more control than that method offers.
+        // The load balancing policy being reachable through the
+        // `ExecutionProfileHandle` is already public API (it is documented to
+        // be preserved by `pointee_to_builder`), so making this getter public
+        // doesn't prevent any more non-breaking evolution than would already
+        // be blocked anyway.
+        self.0 .0.load().load_balancing_policy.clone()
+    }
 }
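As a building block, the getter makes it possible to derive a pinned profile once and reuse it. The sketch below mirrors the body of `Batch::enforce_target_node` above; the function name is hypothetical and the import paths are assumed from the driver's public modules:

```rust
use std::sync::Arc;

use scylla::load_balancing::EnforceTargetShardPolicy;
use scylla::routing::Shard;
use scylla::transport::execution_profile::ExecutionProfileHandle;
use scylla::transport::Node;

// Derive a new profile from an existing handle, reusing its current load
// balancing policy as the fallback of a shard-enforcing policy.
fn pin_to_shard(
    handle: &ExecutionProfileHandle,
    node: &Arc<Node>,
    shard: Shard,
) -> ExecutionProfileHandle {
    handle
        .pointee_to_builder()
        .load_balancing_policy(Arc::new(EnforceTargetShardPolicy::new(
            node,
            shard,
            handle.load_balancing_policy(),
        )))
        .build()
        .into_handle()
}
```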
2 changes: 1 addition & 1 deletion scylla/src/transport/load_balancing/default.rs
@@ -703,7 +703,7 @@ impl DefaultPolicy {
         vec.into_iter()
     }
 
-    fn is_alive(node: NodeRef, _shard: Option<Shard>) -> bool {
+    pub(crate) fn is_alive(node: NodeRef, _shard: Option<Shard>) -> bool {
         // For now, we leave this as a stub until we have time to improve node events.
         // node.is_enabled() && !node.is_down()
         node.is_enabled()
62 changes: 62 additions & 0 deletions scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,62 @@
+use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
+use crate::{
+    routing::Shard,
+    transport::{cluster::ClusterData, Node},
+};
+use std::sync::Arc;
+use uuid::Uuid;
+
+/// This policy will always return the same node, unless it is no longer
+/// available, in which case it will fall back to the provided policy.
+///
+/// This is meant to be used for shard-aware batching.
+#[derive(Debug)]
+pub struct EnforceTargetShardPolicy {
+    target_node: Uuid,
+    shard: Shard,
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl EnforceTargetShardPolicy {
+    pub fn new(
+        target_node: &Arc<Node>,
+        shard: Shard,
+        fallback: Arc<dyn LoadBalancingPolicy>,
+    ) -> Self {
+        Self {
+            target_node: target_node.host_id,
+            shard,
+            fallback,
+        }
+    }
+}
+impl LoadBalancingPolicy for EnforceTargetShardPolicy {
+    fn pick<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
+        cluster
+            .known_peers
+            .get(&self.target_node)
+            .map(|node| (node, Some(self.shard)))
+            .filter(|&(node, shard)| DefaultPolicy::is_alive(node, shard))
+            .or_else(|| self.fallback.pick(query, cluster))
+    }
+
+    fn fallback<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> FallbackPlan<'a> {
+        self.fallback.fallback(query, cluster)
+    }
+
+    fn name(&self) -> String {
+        format!(
+            "Enforce target shard load balancing policy - Node: {} - fallback: {}",
+            self.target_node,
+            self.fallback.name()
+        )
+    }
+}
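For completeness, a sketch of constructing the policy directly, with a freshly built `DefaultPolicy` as the fallback. Import paths are assumptions, and `DefaultPolicy::builder().build()` is taken to return an `Arc<dyn LoadBalancingPolicy>` as in the driver's documentation:

```rust
use std::sync::Arc;

use scylla::load_balancing::{DefaultPolicy, EnforceTargetShardPolicy};
use scylla::routing::Shard;
use scylla::transport::execution_profile::ExecutionProfile;
use scylla::transport::Node;

// Build a profile that targets one (node, shard) and degrades gracefully to
// the default policy whenever the target is unavailable.
fn pinned_profile(node: &Arc<Node>, shard: Shard) -> ExecutionProfile {
    let fallback = DefaultPolicy::builder().build();
    ExecutionProfile::builder()
        .load_balancing_policy(Arc::new(EnforceTargetShardPolicy::new(
            node, shard, fallback,
        )))
        .build()
}
```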
6 changes: 5 additions & 1 deletion scylla/src/transport/load_balancing/mod.rs
@@ -12,9 +12,13 @@ use scylla_cql::{
 use std::time::Duration;
 
 mod default;
+mod enforce_node;
 mod plan;
-pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
 pub use plan::Plan;
+pub use {
+    default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
+    enforce_node::EnforceTargetShardPolicy,
+};
 
 /// Represents info about statement that can be used by load balancing policies.
 #[derive(Default, Clone, Debug)]