Skip to content

Commit

Permalink
caching_session: make generic over session APIs
Browse files Browse the repository at this point in the history
In a similar fashion to Session, CachingSession was also made generic
over the session kind.

Co-authored-by: Wojciech Przytuła <[email protected]>
  • Loading branch information
piodul and wprzytula committed Oct 15, 2024
1 parent 8d03975 commit 611e48c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 25 deletions.
2 changes: 1 addition & 1 deletion scylla/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ pub use statement::query;
pub use frame::response::cql_to_rust;
pub use frame::response::cql_to_rust::FromRow;

pub use transport::caching_session::CachingSession;
pub use transport::caching_session::{CachingSession, LegacyCachingSession};
pub use transport::execution_profile::ExecutionProfile;
pub use transport::legacy_query_result::LegacyQueryResult;
pub use transport::query_result::QueryResult;
Expand Down
128 changes: 107 additions & 21 deletions scylla/src/transport/caching_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::statement::{PagingState, PagingStateResponse};
use crate::transport::errors::QueryError;
use crate::transport::iterator::LegacyRowIterator;
use crate::transport::partitioner::PartitionerName;
use crate::{LegacyQueryResult, LegacySession};
use crate::{LegacyQueryResult, QueryResult};
use bytes::Bytes;
use dashmap::DashMap;
use futures::future::try_join_all;
Expand All @@ -16,6 +16,11 @@ use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::sync::Arc;

use super::iterator::RawIterator;
use super::session::{
CurrentDeserializationApi, DeserializationApiKind, GenericSession, LegacyDeserializationApi,
};

/// Contains just the parts of a prepared statement that were returned
/// from the database. All remaining parts (query string, page size,
/// consistency, etc.) are taken from the Query passed
Expand All @@ -31,23 +36,28 @@ struct RawPreparedStatementData {

/// Provides auto caching while executing queries
#[derive(Debug)]
pub struct CachingSession<S = RandomState>
pub struct GenericCachingSession<DeserializationApi, S = RandomState>
where
S: Clone + BuildHasher,
DeserializationApi: DeserializationApiKind,
{
session: LegacySession,
session: GenericSession<DeserializationApi>,
/// The prepared statement cache size
/// If a prepared statement is added while the limit is reached, the oldest prepared statement
/// is removed from the cache
max_capacity: usize,
cache: DashMap<String, RawPreparedStatementData, S>,
}

impl<S> CachingSession<S>
pub type CachingSession<S = RandomState> = GenericCachingSession<CurrentDeserializationApi, S>;
pub type LegacyCachingSession<S = RandomState> = GenericCachingSession<LegacyDeserializationApi, S>;

impl<DeserApi, S> GenericCachingSession<DeserApi, S>
where
S: Default + BuildHasher + Clone,
DeserApi: DeserializationApiKind,
{
pub fn from(session: LegacySession, cache_size: usize) -> Self {
pub fn from(session: GenericSession<DeserApi>, cache_size: usize) -> Self {
Self {
session,
max_capacity: cache_size,
Expand All @@ -56,20 +66,88 @@ where
}
}

impl<S> CachingSession<S>
impl<DeserApi, S> GenericCachingSession<DeserApi, S>
where
S: BuildHasher + Clone,
DeserApi: DeserializationApiKind,
{
/// Builds a [`CachingSession`] from a [`Session`], a cache size, and a [`BuildHasher`].,
/// using a customer hasher.
pub fn with_hasher(session: LegacySession, cache_size: usize, hasher: S) -> Self {
pub fn with_hasher(session: GenericSession<DeserApi>, cache_size: usize, hasher: S) -> Self {
Self {
session,
max_capacity: cache_size,
cache: DashMap::with_hasher(hasher),
}
}
}

impl<S> GenericCachingSession<CurrentDeserializationApi, S>
where
S: BuildHasher + Clone,
{
/// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
pub async fn execute_unpaged(
&self,
query: impl Into<Query>,
values: impl SerializeRow,
) -> Result<QueryResult, QueryError> {
let query = query.into();
let prepared = self.add_prepared_statement_owned(query).await?;
self.session.execute_unpaged(&prepared, values).await
}

/// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache
pub async fn execute_iter(
&self,
query: impl Into<Query>,
values: impl SerializeRow,
) -> Result<RawIterator, QueryError> {
let query = query.into();
let prepared = self.add_prepared_statement_owned(query).await?;
self.session.execute_iter(prepared, values).await
}

/// Does the same thing as [`Session::execute_single_page`] but uses the prepared statement cache
pub async fn execute_single_page(
&self,
query: impl Into<Query>,
values: impl SerializeRow,
paging_state: PagingState,
) -> Result<(QueryResult, PagingStateResponse), QueryError> {
let query = query.into();
let prepared = self.add_prepared_statement_owned(query).await?;
self.session
.execute_single_page(&prepared, values, paging_state)
.await
}

/// Does the same thing as [`Session::batch`] but uses the prepared statement cache\
/// Prepares batch using CachingSession::prepare_batch if needed and then executes it
pub async fn batch(
&self,
batch: &Batch,
values: impl BatchValues,
) -> Result<QueryResult, QueryError> {
let all_prepared: bool = batch
.statements
.iter()
.all(|stmt| matches!(stmt, BatchStatement::PreparedStatement(_)));

if all_prepared {
self.session.batch(batch, &values).await
} else {
let prepared_batch: Batch = self.prepare_batch(batch).await?;

self.session.batch(&prepared_batch, &values).await
}
}
}

impl<S> GenericCachingSession<LegacyDeserializationApi, S>
where
S: BuildHasher + Clone,
{
/// Does the same thing as [`Session::execute_unpaged`] but uses the prepared statement cache
pub async fn execute_unpaged(
&self,
Expand Down Expand Up @@ -126,7 +204,13 @@ where
self.session.batch(&prepared_batch, &values).await
}
}
}

impl<DeserApi, S> GenericCachingSession<DeserApi, S>
where
S: BuildHasher + Clone,
DeserApi: DeserializationApiKind,
{
/// Prepares all statements within the batch and returns a new batch where every
/// statement is prepared.
/// Uses the prepared statements cache.
Expand Down Expand Up @@ -212,7 +296,7 @@ where
self.max_capacity
}

pub fn get_session(&self) -> &LegacySession {
pub fn get_session(&self) -> &GenericSession<DeserApi> {
&self.session
}
}
Expand All @@ -229,7 +313,7 @@ mod tests {
use crate::{
batch::{Batch, BatchStatement},
prepared_statement::PreparedStatement,
CachingSession, LegacySession,
LegacyCachingSession, LegacySession,
};
use futures::TryStreamExt;
use std::collections::BTreeSet;
Expand Down Expand Up @@ -273,8 +357,8 @@ mod tests {
session
}

async fn create_caching_session() -> CachingSession {
let session = CachingSession::from(new_for_test(true).await, 2);
async fn create_caching_session() -> LegacyCachingSession {
let session = LegacyCachingSession::from(new_for_test(true).await, 2);

// Add a row, this makes it easier to check if the caching works combined with the regular execute fn on Session
session
Expand Down Expand Up @@ -385,7 +469,7 @@ mod tests {
}

async fn assert_test_batch_table_rows_contain(
sess: &CachingSession,
sess: &LegacyCachingSession,
expected_rows: &[(i32, i32)],
) {
let selected_rows: BTreeSet<(i32, i32)> = sess
Expand Down Expand Up @@ -431,18 +515,18 @@ mod tests {
}
}

let _session: CachingSession<std::collections::hash_map::RandomState> =
CachingSession::from(new_for_test(true).await, 2);
let _session: CachingSession<CustomBuildHasher> =
CachingSession::from(new_for_test(true).await, 2);
let _session: CachingSession<CustomBuildHasher> =
CachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
let _session: LegacyCachingSession<std::collections::hash_map::RandomState> =
LegacyCachingSession::from(new_for_test(true).await, 2);
let _session: LegacyCachingSession<CustomBuildHasher> =
LegacyCachingSession::from(new_for_test(true).await, 2);
let _session: LegacyCachingSession<CustomBuildHasher> =
LegacyCachingSession::with_hasher(new_for_test(true).await, 2, Default::default());
}

#[tokio::test]
async fn test_batch() {
setup_tracing();
let session: CachingSession = create_caching_session().await;
let session: LegacyCachingSession = create_caching_session().await;

session
.execute_unpaged(
Expand Down Expand Up @@ -565,7 +649,8 @@ mod tests {
#[tokio::test]
async fn test_parameters_caching() {
setup_tracing();
let session: CachingSession = CachingSession::from(new_for_test(true).await, 100);
let session: LegacyCachingSession =
LegacyCachingSession::from(new_for_test(true).await, 100);

session
.execute_unpaged("CREATE TABLE tbl (a int PRIMARY KEY, b int)", ())
Expand Down Expand Up @@ -618,7 +703,8 @@ mod tests {
}

// This test uses CDC which is not yet compatible with Scylla's tablets.
let session: CachingSession = CachingSession::from(new_for_test(false).await, 100);
let session: LegacyCachingSession =
LegacyCachingSession::from(new_for_test(false).await, 100);

session
.execute_unpaged(
Expand Down
6 changes: 3 additions & 3 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::transport::topology::{
use crate::utils::test_utils::{
create_new_session_builder, supports_feature, unique_keyspace_name,
};
use crate::CachingSession;
use crate::ExecutionProfile;
use crate::LegacyCachingSession;
use crate::LegacyQueryResult;
use crate::{LegacySession, SessionBuilder};
use assert_matches::assert_matches;
Expand Down Expand Up @@ -2012,7 +2012,7 @@ async fn rename(session: &LegacySession, rename_str: &str) {
.unwrap();
}

async fn rename_caching(session: &CachingSession, rename_str: &str) {
async fn rename_caching(session: &LegacyCachingSession, rename_str: &str) {
session
.execute_unpaged(format!("ALTER TABLE tab RENAME {}", rename_str), &())
.await
Expand Down Expand Up @@ -2230,7 +2230,7 @@ async fn test_unprepared_reprepare_in_caching_session_execute() {
session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session.use_keyspace(ks, false).await.unwrap();

let caching_session: CachingSession = CachingSession::from(session, 64);
let caching_session: LegacyCachingSession = LegacyCachingSession::from(session, 64);

caching_session
.execute_unpaged(
Expand Down

0 comments on commit 611e48c

Please sign in to comment.