Skip to content

Commit

Permalink
Merge pull request #200 from muzarski/rust-driver-07553fee
Browse files Browse the repository at this point in the history
  • Loading branch information
muzarski authored Oct 30, 2024
2 parents 94b6d16 + 0e3251d commit 39dd81e
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 51 deletions.
8 changes: 4 additions & 4 deletions scylla-rust-wrapper/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions scylla-rust-wrapper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ categories = ["database"]
license = "MIT OR Apache-2.0"

[dependencies]
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0", features = [
scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "64b4afcd", features = [
"ssl",
] }
tokio = { version = "1.27.0", features = ["full"] }
Expand All @@ -32,7 +32,7 @@ bindgen = "0.65"
chrono = "0.4.20"

[dev-dependencies]
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.14.0" }
scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "64b4afcd" }

assert_matches = "1.5.0"
ntest = "0.9.3"
Expand Down
49 changes: 38 additions & 11 deletions scylla-rust-wrapper/src/cass_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,26 @@ impl From<&QueryError> for CassError {
match error {
QueryError::DbError(db_error, _string) => CassError::from(db_error),
QueryError::BadQuery(bad_query) => CassError::from(bad_query),
QueryError::IoError(_io_error) => CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
QueryError::ProtocolError(_str) => CassError::CASS_ERROR_SERVER_PROTOCOL_ERROR,
QueryError::InvalidMessage(_string) => CassError::CASS_ERROR_SERVER_INVALID_QUERY,
QueryError::TimeoutError => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT, // This may be either read or write timeout error
QueryError::TooManyOrphanedStreamIds(_) => CassError::CASS_ERROR_LIB_INVALID_STATE,
QueryError::UnableToAllocStreamId => CassError::CASS_ERROR_LIB_NO_STREAMS,
QueryError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT,
QueryError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION,
QueryError::CqlResponseParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE,
QueryError::CqlRequestSerialization(_) => CassError::CASS_ERROR_LIB_MESSAGE_ENCODE,
QueryError::BodyExtensionsParseError(_) => {
CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE
}
QueryError::EmptyPlan => CassError::CASS_ERROR_LIB_INVALID_STATE,
QueryError::CqlResultParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE,
QueryError::CqlErrorParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE,
QueryError::MetadataError(_) => CassError::CASS_ERROR_LIB_INVALID_STATE,
// I know that TranslationError (corresponding to CASS_ERROR_LIB_HOST_RESOLUTION)
// is hidden under the ConnectionPoolError.
// However, we still have a lot work to do when it comes to error conversion.
// I will address it, once we start resolving all issues related to error conversion.
QueryError::ConnectionPoolError(_) => CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
QueryError::BrokenConnection(_) => CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
// QueryError is non_exhaustive
_ => CassError::CASS_ERROR_LAST_ENTRY,
}
}
}
Expand Down Expand Up @@ -62,6 +73,10 @@ impl From<&BadQuery> for CassError {
BadQuery::Other(_other_query) => CassError::CASS_ERROR_LAST_ENTRY,
BadQuery::SerializationError(_) => CassError::CASS_ERROR_LAST_ENTRY,
BadQuery::TooManyQueriesInBatchStatement(_) => CassError::CASS_ERROR_LAST_ENTRY,
// BadQuery is non_exhaustive
// For now, since all other variants return LAST_ENTRY,
// let's do it here as well.
_ => CassError::CASS_ERROR_LAST_ENTRY,
}
}
}
Expand All @@ -75,19 +90,29 @@ impl From<&NewSessionError> for CassError {
NewSessionError::EmptyKnownNodesList => CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE,
NewSessionError::DbError(_db_error, _string) => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::BadQuery(_bad_query) => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::IoError(_io_error) => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::ProtocolError(_str) => {
CassError::CASS_ERROR_LIB_UNABLE_TO_DETERMINE_PROTOCOL
}
NewSessionError::InvalidMessage(_string) => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::TimeoutError => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::TooManyOrphanedStreamIds(_) => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::UnableToAllocStreamId => CassError::CASS_ERROR_LAST_ENTRY,
NewSessionError::RequestTimeout(_) => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT,
NewSessionError::TranslationError(_) => CassError::CASS_ERROR_LIB_HOST_RESOLUTION,
NewSessionError::CqlResponseParseError(_) => {
NewSessionError::CqlRequestSerialization(_) => CassError::CASS_ERROR_LIB_MESSAGE_ENCODE,
NewSessionError::BodyExtensionsParseError(_) => {
CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE
}
NewSessionError::EmptyPlan => CassError::CASS_ERROR_LIB_INVALID_STATE,
NewSessionError::CqlResultParseError(_) => {
CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE
}
NewSessionError::CqlErrorParseError(_) => CassError::CASS_ERROR_LIB_UNEXPECTED_RESPONSE,
NewSessionError::MetadataError(_) => CassError::CASS_ERROR_LIB_INVALID_STATE,
// I know that TranslationError (corresponding to CASS_ERROR_LIB_HOST_RESOLUTION)
// is hidden under the ConnectionPoolError.
// However, we still have a lot work to do when it comes to error conversion.
// I will address it, once we start resolving all issues related to error conversion.
NewSessionError::ConnectionPoolError(_) => CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
NewSessionError::BrokenConnection(_) => CassError::CASS_ERROR_LIB_UNABLE_TO_CONNECT,
// NS error is non_exhaustive
_ => CassError::CASS_ERROR_LAST_ENTRY,
}
}
}
Expand All @@ -98,6 +123,8 @@ impl From<&BadKeyspaceName> for CassError {
BadKeyspaceName::Empty => CassError::CASS_ERROR_LAST_ENTRY,
BadKeyspaceName::TooLong(_string, _usize) => CassError::CASS_ERROR_LAST_ENTRY,
BadKeyspaceName::IllegalCharacter(_string, _char) => CassError::CASS_ERROR_LAST_ENTRY,
// non_exhaustive
_ => CassError::CASS_ERROR_LAST_ENTRY,
}
}
}
Expand Down
13 changes: 9 additions & 4 deletions scylla-rust-wrapper/src/cass_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ impl CassDataType {

pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
match column_type {
ColumnType::Custom(s) => CassDataType::Custom((*s).clone()),
ColumnType::Custom(s) => CassDataType::Custom(s.clone().into_owned()),
ColumnType::Ascii => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_ASCII),
ColumnType::Boolean => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BOOLEAN),
ColumnType::Blob => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BLOB),
Expand Down Expand Up @@ -459,10 +459,15 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
} => CassDataType::UDT(UDTDataType {
field_types: field_types
.iter()
.map(|(name, col_type)| ((*name).clone(), Arc::new(get_column_type(col_type))))
.map(|(name, col_type)| {
(
name.clone().into_owned(),
Arc::new(get_column_type(col_type)),
)
})
.collect(),
keyspace: (*keyspace).clone(),
name: (*type_name).clone(),
keyspace: keyspace.clone().into_owned(),
name: type_name.clone().into_owned(),
frozen: false,
}),
ColumnType::SmallInt => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_SMALL_INT),
Expand Down
8 changes: 4 additions & 4 deletions scylla-rust-wrapper/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,13 +689,13 @@ pub unsafe extern "C" fn cass_cluster_set_retry_policy(
let cluster = ptr_to_ref_mut(cluster_raw);

let retry_policy: Arc<dyn RetryPolicy> = match ptr_to_ref(retry_policy) {
DefaultRetryPolicy(default) => default.clone(),
FallthroughRetryPolicy(fallthrough) => fallthrough.clone(),
DowngradingConsistencyRetryPolicy(downgrading) => downgrading.clone(),
DefaultRetryPolicy(default) => Arc::clone(default) as _,
FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _,
DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _,
};

exec_profile_builder_modify(&mut cluster.default_execution_profile_builder, |builder| {
builder.retry_policy(retry_policy.clone_boxed())
builder.retry_policy(retry_policy)
});
}

Expand Down
10 changes: 5 additions & 5 deletions scylla-rust-wrapper/src/exec_profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,13 +376,13 @@ pub unsafe extern "C" fn cass_execution_profile_set_retry_policy(
profile: *mut CassExecProfile,
retry_policy: *const CassRetryPolicy,
) -> CassError {
let retry_policy: &dyn RetryPolicy = match ptr_to_ref(retry_policy) {
DefaultRetryPolicy(default) => default.as_ref(),
FallthroughRetryPolicy(fallthrough) => fallthrough.as_ref(),
DowngradingConsistencyRetryPolicy(downgrading) => downgrading.as_ref(),
let retry_policy: Arc<dyn RetryPolicy> = match ptr_to_ref(retry_policy) {
DefaultRetryPolicy(default) => Arc::clone(default) as _,
FallthroughRetryPolicy(fallthrough) => Arc::clone(fallthrough) as _,
DowngradingConsistencyRetryPolicy(downgrading) => Arc::clone(downgrading) as _,
};
let profile_builder = ptr_to_ref_mut(profile);
profile_builder.modify_in_place(|builder| builder.retry_policy(retry_policy.clone_boxed()));
profile_builder.modify_in_place(|builder| builder.retry_policy(retry_policy));

CassError::CASS_OK
}
Expand Down
8 changes: 4 additions & 4 deletions scylla-rust-wrapper/src/prepared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ impl CassPrepared {
let variable_col_data_types = statement
.get_variable_col_specs()
.iter()
.map(|col_spec| Arc::new(get_column_type(&col_spec.typ)))
.map(|col_spec| Arc::new(get_column_type(col_spec.typ())))
.collect();

let result_col_data_types: Arc<Vec<Arc<CassDataType>>> = Arc::new(
statement
.get_result_set_col_specs()
.iter()
.map(|col_spec| Arc::new(get_column_type(&col_spec.typ)))
.map(|col_spec| Arc::new(get_column_type(col_spec.typ())))
.collect(),
);

Expand All @@ -50,7 +50,7 @@ impl CassPrepared {
.statement
.get_variable_col_specs()
.iter()
.position(|col_spec| col_spec.name == name)?;
.position(|col_spec| col_spec.name() == name)?;

match self.variable_col_data_types.get(index) {
Some(dt) => Some(dt),
Expand Down Expand Up @@ -108,7 +108,7 @@ pub unsafe extern "C" fn cass_prepared_parameter_name(
.get(index as usize)
{
Some(col_spec) => {
write_str_to_c(&col_spec.name, name, name_length);
write_str_to_c(col_spec.name(), name, name_length);
CassError::CASS_OK
}
None => CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS,
Expand Down
20 changes: 8 additions & 12 deletions scylla-rust-wrapper/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ pub struct CassResult {

pub struct CassResultData {
pub paging_state_response: PagingStateResponse,
pub col_specs: Vec<ColumnSpec>,
pub col_specs: Vec<ColumnSpec<'static>>,
pub col_data_types: Arc<Vec<Arc<CassDataType>>>,
pub tracing_id: Option<Uuid>,
}

impl CassResultData {
pub fn from_result_payload(
paging_state_response: PagingStateResponse,
col_specs: Vec<ColumnSpec>,
col_specs: Vec<ColumnSpec<'static>>,
maybe_col_data_types: Option<Arc<Vec<Arc<CassDataType>>>>,
tracing_id: Option<Uuid>,
) -> CassResultData {
Expand All @@ -43,7 +43,7 @@ impl CassResultData {
Arc::new(
col_specs
.iter()
.map(|col_spec| Arc::new(get_column_type(&col_spec.typ)))
.map(|col_spec| Arc::new(get_column_type(col_spec.typ())))
.collect(),
)
});
Expand Down Expand Up @@ -903,8 +903,8 @@ pub unsafe extern "C" fn cass_row_get_column_by_name_n(
.iter()
.enumerate()
.find(|(_, spec)| {
is_case_sensitive && spec.name == name_str
|| !is_case_sensitive && spec.name.eq_ignore_ascii_case(name_str)
is_case_sensitive && spec.name() == name_str
|| !is_case_sensitive && spec.name().eq_ignore_ascii_case(name_str)
})
.map(|(index, _)| {
return match row_from_raw.columns.get(index) {
Expand All @@ -930,7 +930,7 @@ pub unsafe extern "C" fn cass_result_column_name(
}

let column_spec: &ColumnSpec = result_from_raw.metadata.col_specs.get(index_usize).unwrap();
let column_name = column_spec.name.as_str();
let column_name = column_spec.name();

write_str_to_c(column_name, name, name_length);

Expand Down Expand Up @@ -1406,12 +1406,8 @@ mod tests {

use super::{cass_result_column_count, cass_result_column_type, CassResult, CassResultData};

fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec {
ColumnSpec {
table_spec: TableSpec::borrowed("ks", "tbl"),
name: name.to_owned(),
typ,
}
fn col_spec(name: &'static str, typ: ColumnType<'static>) -> ColumnSpec<'static> {
ColumnSpec::borrowed(name, typ, TableSpec::borrowed("ks", "tbl"))
}

const FIRST_COLUMN_NAME: &str = "bigint_col";
Expand Down
4 changes: 2 additions & 2 deletions scylla-rust-wrapper/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ impl CassStatement {
.iter()
.enumerate()
.filter(|(_, col)| {
is_case_sensitive && col.name == name_str
|| !is_case_sensitive && col.name.eq_ignore_ascii_case(name_str)
is_case_sensitive && col.name() == name_str
|| !is_case_sensitive && col.name().eq_ignore_ascii_case(name_str)
})
.map(|(i, _)| i)
.collect();
Expand Down
4 changes: 2 additions & 2 deletions scylla-rust-wrapper/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use scylla::{
SetOrListSerializationErrorKind, TupleSerializationErrorKind,
UdtSerializationErrorKind,
},
writers::WrittenCellProof,
CellWriter, SerializationError,
writers::{CellWriter, WrittenCellProof},
SerializationError,
},
};
use uuid::Uuid;
Expand Down
2 changes: 1 addition & 1 deletion tests/src/integration/tests/test_prepared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ CASSANDRA_INTEGRATION_TEST_F(PreparedTests, FailFastWhenPreparedIDChangesDuringR
insert_statement.bind<Integer>(0, Integer(0));
insert_statement.bind<Integer>(1, Integer(1));
Result result = session_.execute(insert_statement, false);
EXPECT_TRUE(contains(result.error_message(), "Prepared statement Id changed"));
EXPECT_TRUE(contains(result.error_message(), "Prepared statement id changed after repreparation"));
}

/**
Expand Down

0 comments on commit 39dd81e

Please sign in to comment.