Skip to content

Commit

Permalink
Merge pull request #143 from muzarski/typechecks
Browse files Browse the repository at this point in the history
binding: typechecks
  • Loading branch information
wprzytula authored Oct 2, 2024
2 parents 9ac038f + 8dbe9d1 commit af43106
Show file tree
Hide file tree
Showing 14 changed files with 1,438 additions and 120 deletions.
1 change: 1 addition & 0 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
- [ ] I have split my patch into logically separate commits.
- [ ] All commit messages clearly explain what they change and why.
- [ ] PR description sums up the changes and reasons why they should be introduced.
- [ ] I have implemented Rust unit tests for the features/changes introduced.
- [ ] I have enabled appropriate tests in `.github/workflows/build.yml` in `gtest_filter`.
- [ ] I have enabled appropriate tests in `.github/workflows/cassandra.yml` in `gtest_filter`.
7 changes: 0 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,6 @@ The driver inherits almost all the features of C/C++ and Rust drivers, such as:
<tr>
<td colspan=2 align="center" style="font-weight:bold">Collection</td>
</tr>
<tr>
<td>cass_collection_new_from_data_type</td>
<td rowspan="2">Unimplemented</td>
</tr>
<tr>
<td>cass_collection_data_type</td>
</tr>
<tr>
<td>cass_collection_append_custom[_n]</td>
<td>Unimplemented because of the same reasons as binding for statements.<br> <b>Note</b>: The driver does not check whether the type of the appended value is compatible with the type of the collection items.</td>
Expand Down
2 changes: 1 addition & 1 deletion scylla-rust-wrapper/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ pub unsafe extern "C" fn cass_batch_add_statement(

match &statement.statement {
Statement::Simple(q) => state.batch.append_statement(q.query.clone()),
Statement::Prepared(p) => state.batch.append_statement((**p).clone()),
Statement::Prepared(p) => state.batch.append_statement(p.statement.clone()),
};

state.bound_values.push(statement.bound_values.clone());
Expand Down
6 changes: 0 additions & 6 deletions scylla-rust-wrapper/src/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@
//! It can be used for binding named parameter in CassStatement or field by name in CassUserType.
//! * Functions from make_appender don't take any extra argument, as they are for use by CassCollection
//! functions - values are appended to collection.
use crate::{cass_types::CassDataType, value::CassCqlValue};

pub fn is_compatible_type(_data_type: &CassDataType, _value: &Option<CassCqlValue>) -> bool {
// TODO: cppdriver actually checks types.
true
}

macro_rules! make_index_binder {
($this:ty, $consume_v:expr, $fn_by_idx:ident, $e:expr, [$($arg:ident @ $t:ty), *]) => {
Expand Down
208 changes: 174 additions & 34 deletions scylla-rust-wrapper/src/cass_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include!(concat!(env!("OUT_DIR"), "/cppdriver_data_types.rs"));
include!(concat!(env!("OUT_DIR"), "/cppdriver_data_query_error.rs"));
include!(concat!(env!("OUT_DIR"), "/cppdriver_batch_types.rs"));

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct UDTDataType {
// Vec to preserve the order of types
pub field_types: Vec<(String, Arc<CassDataType>)>,
Expand Down Expand Up @@ -87,6 +87,42 @@ impl UDTDataType {
pub fn get_field_by_index(&self, index: usize) -> Option<&Arc<CassDataType>> {
self.field_types.get(index).map(|(_, b)| b)
}

fn typecheck_equals(&self, other: &UDTDataType) -> bool {
// See: https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L354-L386

if !any_string_empty_or_both_equal(&self.keyspace, &other.keyspace) {
return false;
}
if !any_string_empty_or_both_equal(&self.name, &other.name) {
return false;
}

// A comment from cpp-driver:
//// UDT's can be considered equal as long as the mutual first fields shared
//// between them are equal. UDT's are append only as far as fields go, so a
//// newer 'version' of the UDT data type after a schema change event should be
//// treated as equivalent in this scenario, by simply looking at the first N
//// mutual fields they should share.
//
// Iterator returned from zip() is perfect for checking the first mutual fields.
for (field, other_field) in self.field_types.iter().zip(other.field_types.iter()) {
// Compare field names.
if field.0 != other_field.0 {
return false;
}
// Compare field types.
if !field.1.typecheck_equals(&other_field.1) {
return false;
}
}

true
}
}

fn any_string_empty_or_both_equal(s1: &str, s2: &str) -> bool {
s1.is_empty() || s2.is_empty() || s1 == s2
}

impl Default for UDTDataType {
Expand All @@ -95,27 +131,106 @@ impl Default for UDTDataType {
}
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum MapDataType {
Untyped,
Key(Arc<CassDataType>),
KeyAndValue(Arc<CassDataType>, Arc<CassDataType>),
}

#[derive(Clone, Debug, PartialEq)]
pub enum CassDataType {
Value(CassValueType),
UDT(UDTDataType),
List {
// None stands for untyped list.
typ: Option<Arc<CassDataType>>,
frozen: bool,
},
Set {
// None stands for untyped set.
typ: Option<Arc<CassDataType>>,
frozen: bool,
},
Map {
key_type: Option<Arc<CassDataType>>,
val_type: Option<Arc<CassDataType>>,
typ: MapDataType,
frozen: bool,
},
// Empty vector stands for untyped tuple.
Tuple(Vec<Arc<CassDataType>>),
Custom(String),
}

impl CassDataType {
/// Checks for equality during typechecks.
///
/// This takes into account the fact that tuples/collections may be untyped.
pub fn typecheck_equals(&self, other: &CassDataType) -> bool {
match self {
CassDataType::Value(t) => *t == other.get_value_type(),
CassDataType::UDT(udt) => match other {
CassDataType::UDT(other_udt) => udt.typecheck_equals(other_udt),
_ => false,
},
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => match other {
CassDataType::List { typ: other_typ, .. }
| CassDataType::Set { typ: other_typ, .. } => {
// If one of them is list, and the other is set, fail the typecheck.
if self.get_value_type() != other.get_value_type() {
return false;
}
match (typ, other_typ) {
// One of them is untyped, skip the typecheck for subtype.
(None, _) | (_, None) => true,
(Some(typ), Some(other_typ)) => typ.typecheck_equals(other_typ),
}
}
_ => false,
},
CassDataType::Map { typ: t, .. } => match other {
CassDataType::Map { typ: t_other, .. } => match (t, t_other) {
// See https://github.com/scylladb/cpp-driver/blob/master/src/data_type.hpp#L218
// In cpp-driver the types are held in a vector.
// The logic is following:

// If either of vectors is empty, skip the typecheck.
(MapDataType::Untyped, _) => true,
(_, MapDataType::Untyped) => true,

// Otherwise, the vectors should have equal length and we perform the typecheck for subtypes.
(MapDataType::Key(k), MapDataType::Key(k_other)) => k.typecheck_equals(k_other),
(
MapDataType::KeyAndValue(k, v),
MapDataType::KeyAndValue(k_other, v_other),
) => k.typecheck_equals(k_other) && v.typecheck_equals(v_other),
_ => false,
},
_ => false,
},
CassDataType::Tuple(sub) => match other {
CassDataType::Tuple(other_sub) => {
// If either of tuples is untyped, skip the typecheck for subtypes.
if sub.is_empty() || other_sub.is_empty() {
return true;
}

// If both are non-empty, check for subtypes equality.
if sub.len() != other_sub.len() {
return false;
}
sub.iter()
.zip(other_sub.iter())
.all(|(typ, other_typ)| typ.typecheck_equals(other_typ))
}
_ => false,
},
CassDataType::Custom(_) => {
unimplemented!("Cpp-rust-driver does not support custom types!")
}
}
}
}

impl From<NativeType> for CassValueType {
fn from(native_type: NativeType) -> CassValueType {
match native_type {
Expand Down Expand Up @@ -160,16 +275,18 @@ pub fn get_column_type_from_cql_type(
frozen: *frozen,
},
CollectionType::Map(key, value) => CassDataType::Map {
key_type: Some(Arc::new(get_column_type_from_cql_type(
key,
user_defined_types,
keyspace_name,
))),
val_type: Some(Arc::new(get_column_type_from_cql_type(
value,
user_defined_types,
keyspace_name,
))),
typ: MapDataType::KeyAndValue(
Arc::new(get_column_type_from_cql_type(
key,
user_defined_types,
keyspace_name,
)),
Arc::new(get_column_type_from_cql_type(
value,
user_defined_types,
keyspace_name,
)),
),
frozen: *frozen,
},
CollectionType::Set(set) => CassDataType::Set {
Expand Down Expand Up @@ -222,10 +339,19 @@ impl CassDataType {
}
}
CassDataType::Map {
key_type, val_type, ..
typ: MapDataType::Untyped,
..
} => None,
CassDataType::Map {
typ: MapDataType::Key(k),
..
} => (index == 0).then_some(k),
CassDataType::Map {
typ: MapDataType::KeyAndValue(k, v),
..
} => match index {
0 => key_type.as_ref(),
1 => val_type.as_ref(),
0 => Some(k),
1 => Some(v),
_ => None,
},
CassDataType::Tuple(v) => v.get(index),
Expand All @@ -243,17 +369,28 @@ impl CassDataType {
}
},
CassDataType::Map {
key_type, val_type, ..
typ: MapDataType::KeyAndValue(_, _),
..
} => Err(CassError::CASS_ERROR_LIB_BAD_PARAMS),
CassDataType::Map {
typ: MapDataType::Key(k),
frozen,
} => {
if key_type.is_some() && val_type.is_some() {
Err(CassError::CASS_ERROR_LIB_BAD_PARAMS)
} else if key_type.is_none() {
*key_type = Some(sub_type);
Ok(())
} else {
*val_type = Some(sub_type);
Ok(())
}
*self = CassDataType::Map {
typ: MapDataType::KeyAndValue(k.clone(), sub_type),
frozen: *frozen,
};
Ok(())
}
CassDataType::Map {
typ: MapDataType::Untyped,
frozen,
} => {
*self = CassDataType::Map {
typ: MapDataType::Key(sub_type),
frozen: *frozen,
};
Ok(())
}
CassDataType::Tuple(types) => {
types.push(sub_type);
Expand Down Expand Up @@ -305,8 +442,10 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType {
frozen: false,
},
ColumnType::Map(key, value) => CassDataType::Map {
key_type: Some(Arc::new(get_column_type(key.as_ref()))),
val_type: Some(Arc::new(get_column_type(value.as_ref()))),
typ: MapDataType::KeyAndValue(
Arc::new(get_column_type(key.as_ref())),
Arc::new(get_column_type(value.as_ref())),
),
frozen: false,
},
ColumnType::Set(boxed_type) => CassDataType::Set {
Expand Down Expand Up @@ -357,8 +496,7 @@ pub unsafe extern "C" fn cass_data_type_new(value_type: CassValueType) -> *const
},
CassValueType::CASS_VALUE_TYPE_TUPLE => CassDataType::Tuple(Vec::new()),
CassValueType::CASS_VALUE_TYPE_MAP => CassDataType::Map {
key_type: None,
val_type: None,
typ: MapDataType::Untyped,
frozen: false,
},
CassValueType::CASS_VALUE_TYPE_UDT => CassDataType::UDT(UDTDataType::new()),
Expand Down Expand Up @@ -555,9 +693,11 @@ pub unsafe extern "C" fn cass_data_type_sub_type_count(data_type: *const CassDat
CassDataType::Value(..) => 0,
CassDataType::UDT(udt_data_type) => udt_data_type.field_types.len() as size_t,
CassDataType::List { typ, .. } | CassDataType::Set { typ, .. } => typ.is_some() as size_t,
CassDataType::Map {
key_type, val_type, ..
} => key_type.is_some() as size_t + val_type.is_some() as size_t,
CassDataType::Map { typ, .. } => match typ {
MapDataType::Untyped => 0,
MapDataType::Key(_) => 1,
MapDataType::KeyAndValue(_, _) => 2,
},
CassDataType::Tuple(v) => v.len() as size_t,
CassDataType::Custom(..) => 0,
}
Expand Down
Loading

0 comments on commit af43106

Please sign in to comment.