Skip to content

Commit

Permalink
Add valuesort, new control result mode and value normalizer (#237)
Browse files Browse the repository at this point in the history
* Merged in PR 123 from parent project

Signed-off-by: Bruce Ritchie <[email protected]>

* Adding resultmode control and custom value normalizer. Bumped version.

Signed-off-by: Bruce Ritchie <[email protected]>

* Added test to verify resultmode can be updated

Signed-off-by: Bruce Ritchie <[email protected]>

* Cargo fmt.

Signed-off-by: Bruce Ritchie <[email protected]>

* Cargo clippy.

Signed-off-by: Bruce Ritchie <[email protected]>

* elide lifetimes to get ci check to pass.

Signed-off-by: Bruce Ritchie <[email protected]>

* Updates after merge with upstream

Signed-off-by: Bruce Ritchie <[email protected]>

* Added valuesort test, updated changelog.

Signed-off-by: Bruce Ritchie <[email protected]>

---------

Signed-off-by: Bruce Ritchie <[email protected]>
  • Loading branch information
Omega359 authored Dec 20, 2024
1 parent e08bc06 commit ac188cb
Show file tree
Hide file tree
Showing 12 changed files with 432 additions and 229 deletions.
155 changes: 103 additions & 52 deletions CHANGELOG.md

Large diffs are not rendered by default.

305 changes: 149 additions & 156 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ resolver = "2"
members = ["sqllogictest", "sqllogictest-bin", "sqllogictest-engines", "tests"]

[workspace.package]
version = "0.23.1"
version = "0.24.0"
edition = "2021"
homepage = "https://github.com/risinglightdb/sqllogictest-rs"
keywords = ["sql", "database", "parser", "cli"]
Expand Down
4 changes: 2 additions & 2 deletions sqllogictest-bin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ glob = "0.3"
itertools = "0.13"
quick-junit = { version = "0.5" }
rand = "0.8"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
sqllogictest-engines = { path = "../sqllogictest-engines", version = "0.24" }
tokio = { version = "1", features = [
"rt",
"rt-multi-thread",
Expand Down
5 changes: 3 additions & 2 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
use rand::distributions::DistString;
use rand::seq::SliceRandom;
use sqllogictest::{
default_column_validator, default_validator, update_record_with_output, AsyncDB, Injected,
MakeConnection, Record, Runner,
default_column_validator, default_normalizer, default_validator, update_record_with_output,
AsyncDB, Injected, MakeConnection, Record, Runner,
};

#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
Expand Down Expand Up @@ -770,6 +770,7 @@ async fn update_record<M: MakeConnection>(
&record_output,
"\t",
default_validator,
default_normalizer,
default_column_validator,
) {
Some(new_record) => {
Expand Down
2 changes: 1 addition & 1 deletion sqllogictest-engines/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ postgres-types = { version = "0.2.8", features = ["derive", "with-chrono-0_4"] }
rust_decimal = { version = "1.36.0", features = ["tokio-pg"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
sqllogictest = { path = "../sqllogictest", version = "0.23" }
sqllogictest = { path = "../sqllogictest", version = "0.24" }
thiserror = "2"
tokio = { version = "1", features = [
"rt",
Expand Down
49 changes: 49 additions & 0 deletions sqllogictest/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub enum QueryExpect<T: ColumnType> {
Results {
types: Vec<T>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
label: Option<String>,
results: Vec<String>,
},
Expand All @@ -98,6 +99,7 @@ impl<T: ColumnType> QueryExpect<T> {
Self::Results {
types: Vec::new(),
sort_mode: None,
result_mode: None,
label: None,
results: Vec::new(),
}
Expand Down Expand Up @@ -287,6 +289,7 @@ impl<T: ColumnType> std::fmt::Display for Record<T> {
}
Record::Control(c) => match c {
Control::SortMode(m) => write!(f, "control sortmode {}", m.as_str()),
Control::ResultMode(m) => write!(f, "control resultmode {}", m.as_str()),
Control::Substitution(s) => write!(f, "control substitution {}", s.as_str()),
},
Record::Condition(cond) => match cond {
Expand Down Expand Up @@ -435,6 +438,8 @@ impl PartialEq for ExpectedError {
pub enum Control {
/// Control sort mode.
SortMode(SortMode),
/// control result mode.
ResultMode(ResultMode),
/// Control whether or not to substitute variables in the SQL.
Substitution(bool),
}
Expand Down Expand Up @@ -545,6 +550,38 @@ impl ControlItem for SortMode {
}
}

/// Whether the results should be parsed as value-wise or row-wise
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ResultMode {
/// Results are in a single column
ValueWise,
/// The default option where results are in columns separated by spaces
RowWise,
}

impl ControlItem for ResultMode {
fn try_from_str(s: &str) -> Result<Self, ParseErrorKind> {
match s {
"rowwise" => Ok(Self::RowWise),
"valuewise" => Ok(Self::ValueWise),
_ => Err(ParseErrorKind::InvalidSortMode(s.to_string())),
}
}

fn as_str(&self) -> &'static str {
match self {
Self::RowWise => "rowwise",
Self::ValueWise => "valuewise",
}
}
}

impl fmt::Display for ResultMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{self:?}")
}
}

/// The error type for parsing sqllogictest.
#[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
#[error("parse error at {loc}: {kind}")]
Expand Down Expand Up @@ -754,6 +791,7 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
QueryExpect::Results {
types,
sort_mode,
result_mode: None,
label,
results: Vec::new(),
}
Expand Down Expand Up @@ -812,6 +850,12 @@ fn parse_inner<T: ColumnType>(loc: &Location, script: &str) -> Result<Vec<Record
});
}
["control", res @ ..] => match res {
["resultmode", result_mode] => match ResultMode::try_from_str(result_mode) {
Ok(result_mode) => {
records.push(Record::Control(Control::ResultMode(result_mode)))
}
Err(k) => return Err(k.at(loc)),
},
["sortmode", sort_mode] => match SortMode::try_from_str(sort_mode) {
Ok(sort_mode) => records.push(Record::Control(Control::SortMode(sort_mode))),
Err(k) => return Err(k.at(loc)),
Expand Down Expand Up @@ -988,6 +1032,11 @@ mod tests {
parse_roundtrip::<DefaultColumnType>("../tests/slt/rowsort.slt")
}

#[test]
fn test_valuesort() {
parse_roundtrip::<DefaultColumnType>("../tests/slt/valuesort.slt")
}

#[test]
fn test_substitution() {
parse_roundtrip::<DefaultColumnType>("../tests/substitution/basic.slt")
Expand Down
85 changes: 73 additions & 12 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,26 +449,39 @@ fn format_column_diff(expected: &str, actual: &str, colorize: bool) -> String {
format!("[Expected] {expected}\n[Actual ] {actual}")
}

/// Normalizer will be used by [`Runner`] to normalize the result values
///
/// # Default
///
/// By default, the ([`default_normalizer`]) will be used to normalize values.
pub type Normalizer = fn(s: &String) -> String;

/// Trim and replace multiple whitespaces with one.
#[allow(clippy::ptr_arg)]
fn normalize_string(s: &String) -> String {
pub fn default_normalizer(s: &String) -> String {
s.trim().split_ascii_whitespace().join(" ")
}

/// Validator will be used by [`Runner`] to validate the output.
///
/// # Default
///
/// By default ([`default_validator`]), we will use compare normalized results.
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;

pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
let expected_results = expected.iter().map(normalize_string).collect_vec();
/// By default, the ([`default_validator`]) will be used compare normalized results.
pub type Validator =
fn(normalizer: Normalizer, actual: &[Vec<String>], expected: &[String]) -> bool;

pub fn default_validator(
normalizer: Normalizer,
actual: &[Vec<String>],
expected: &[String],
) -> bool {
let expected_results = expected.iter().map(normalizer).collect_vec();
// Default, we compare normalized results. Whitespace characters are ignored.
let normalized_rows = actual
.iter()
.map(|strs| strs.iter().map(normalize_string).join(" "))
.map(|strs| strs.iter().map(normalizer).join(" "))
.collect_vec();

normalized_rows == expected_results
}

Expand Down Expand Up @@ -502,9 +515,12 @@ pub struct Runner<D: AsyncDB, M: MakeConnection> {
conn: Connections<D, M>,
// validator is used for validate if the result of query equals to expected.
validator: Validator,
// normalizer is used to normalize the result text
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
substitution: Option<Substitution>,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
/// 0 means never hashing
hash_threshold: usize,
/// Labels for condition `skipif` and `onlyif`.
Expand All @@ -518,9 +534,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
pub fn new(make_conn: M) -> Self {
Runner {
validator: default_validator,
normalizer: default_normalizer,
column_type_validator: default_column_validator,
substitution: None,
sort_mode: None,
result_mode: None,
hash_threshold: 0,
labels: HashSet::new(),
conn: Connections::new(make_conn),
Expand All @@ -532,6 +550,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
self.labels.insert(label.to_string());
}

pub fn with_normalizer(&mut self, normalizer: Normalizer) {
self.normalizer = normalizer;
}
pub fn with_validator(&mut self, validator: Validator) {
self.validator = validator;
}
Expand Down Expand Up @@ -769,15 +790,31 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
QueryExpect::Error(_) => None,
}
.or(self.sort_mode);

let mut value_sort = false;
match sort_mode {
None | Some(SortMode::NoSort) => {}
Some(SortMode::RowSort) => {
rows.sort_unstable();
}
Some(SortMode::ValueSort) => todo!("value sort"),
Some(SortMode::ValueSort) => {
rows = rows
.iter()
.flat_map(|row| row.iter())
.map(|s| vec![s.to_owned()])
.collect();
rows.sort_unstable();
value_sort = true;
}
};

if self.hash_threshold > 0 && rows.len() * types.len() > self.hash_threshold {
let num_values = if value_sort {
rows.len()
} else {
rows.len() * types.len()
};

if self.hash_threshold > 0 && num_values > self.hash_threshold {
let mut md5 = md5::Md5::new();
for line in &rows {
for value in line {
Expand Down Expand Up @@ -808,6 +845,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
Control::SortMode(sort_mode) => {
self.sort_mode = Some(sort_mode);
}
Control::ResultMode(result_mode) => {
self.result_mode = Some(result_mode);
}
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
(s @ None, true) => *s = Some(Substitution::default()),
(s @ Some(_), false) => *s = None,
Expand Down Expand Up @@ -996,7 +1036,17 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.at(loc));
}

if !(self.validator)(rows, &expected_results) {
let actual_results = match self.result_mode {
Some(ResultMode::ValueWise) => rows
.iter()
.flat_map(|strs| strs.iter())
.map(|str| vec![str.to_string()])
.collect_vec(),
// default to rowwise
_ => rows.clone(),
};

if !(self.validator)(self.normalizer, &actual_results, &expected_results) {
let output_rows =
rows.iter().map(|strs| strs.iter().join(" ")).collect_vec();
return Err(TestErrorKind::QueryResultMismatch {
Expand Down Expand Up @@ -1167,9 +1217,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
conn_builder(target.clone(), db_name.clone()).map(Ok)
}),
validator: self.validator,
normalizer: self.normalizer,
column_type_validator: self.column_type_validator,
substitution: self.substitution.clone(),
sort_mode: self.sort_mode,
result_mode: self.result_mode,
hash_threshold: self.hash_threshold,
labels: self.labels.clone(),
};
Expand Down Expand Up @@ -1240,6 +1292,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
filename: impl AsRef<Path>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
) -> Result<(), Box<dyn std::error::Error>> {
use std::io::{Read, Seek, SeekFrom, Write};
Expand Down Expand Up @@ -1355,6 +1408,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
&record_output,
col_separator,
validator,
normalizer,
column_type_validator,
)
.unwrap_or(record);
Expand Down Expand Up @@ -1384,6 +1438,7 @@ pub fn update_record_with_output<T: ColumnType>(
record_output: &RecordOutput<T>,
col_separator: &str,
validator: Validator,
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<T>,
) -> Option<Record<T>> {
match (record.clone(), record_output) {
Expand Down Expand Up @@ -1523,7 +1578,7 @@ pub fn update_record_with_output<T: ColumnType>(
QueryExpect::Results {
results: expected_results,
..
} if validator(rows, expected_results) => expected_results.clone(),
} if validator(normalizer, rows, expected_results) => expected_results.clone(),
_ => rows.iter().map(|cols| cols.join(col_separator)).collect(),
};
let types = match &expected {
Expand All @@ -1541,17 +1596,22 @@ pub fn update_record_with_output<T: ColumnType>(
connection,
expected: match expected {
QueryExpect::Results {
sort_mode, label, ..
sort_mode,
label,
result_mode,
..
} => QueryExpect::Results {
results,
types,
sort_mode,
result_mode,
label,
},
QueryExpect::Error(_) => QueryExpect::Results {
results,
types,
sort_mode: None,
result_mode: None,
label: None,
},
},
Expand Down Expand Up @@ -2009,6 +2069,7 @@ Caused by:
&record_output,
" ",
default_validator,
default_normalizer,
strict_column_validator,
);

Expand Down
5 changes: 4 additions & 1 deletion tests/custom_type/custom_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,8 @@ fn test() {
let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
tester.with_column_validator(strict_column_validator);

tester.run_file("./custom_type/custom_type.slt").unwrap();
let r = tester.run_file("./custom_type/custom_type.slt");
if let Err(err) = r {
eprintln!("{:?}", err);
}
}
Loading

0 comments on commit ac188cb

Please sign in to comment.