diff --git a/src/binder/copy.rs b/src/binder/copy.rs index 7b2aaf3f..91785394 100644 --- a/src/binder/copy.rs +++ b/src/binder/copy.rs @@ -87,14 +87,18 @@ impl Binder { } => { let (table, is_system, is_view) = self.bind_table_id(&table_name)?; if is_system { - return Err(BindError::CopyTo("system table".into())); + return Err( + ErrorKind::CopyTo("system table".into()).with_spanned(&table_name) + ); } else if is_view { - return Err(BindError::CopyTo("view".into())); + return Err(ErrorKind::CopyTo("view".into()).with_spanned(&table_name)); } let cols = self.bind_table_columns(&table_name, &columns)?; (table, cols) } - CopySource::Query(_) => return Err(BindError::CopyTo("query".into())), + CopySource::Query(query) => { + return Err(ErrorKind::CopyTo("query".into()).with_spanned(&*query)); + } }; let types = self.type_(cols)?; let types = self.egraph.add(Node::Type(types)); diff --git a/src/binder/create_function.rs b/src/binder/create_function.rs index c9fe3c38..4bf83b62 100644 --- a/src/binder/create_function.rs +++ b/src/binder/create_function.rs @@ -58,26 +58,28 @@ impl Binder { }: crate::parser::CreateFunction, ) -> Result { let Ok((schema_name, function_name)) = split_name(&name) else { - return Err(BindError::BindFunctionError( + return Err(ErrorKind::BindFunctionError( "failed to parse the input function name".to_string(), - )); + ) + .with_spanned(&name)); }; let schema_name = schema_name.to_string(); let name = function_name.to_string(); let Some(return_type) = return_type else { - return Err(BindError::BindFunctionError( + return Err(ErrorKind::BindFunctionError( "`return type` must be specified".to_string(), - )); + ) + .into()); }; let return_type = crate::types::DataType::from(&return_type); // TODO: language check (e.g., currently only support sql) let Some(language) = language else { - return Err(BindError::BindFunctionError( - "`language` must be specified".to_string(), - )); + return Err( + ErrorKind::BindFunctionError("`language` must be specified".to_string()).into(), + ); }; let language = language.to_string(); @@ -88,7 +90,11 @@ impl Binder { | Some(CreateFunctionBody::AsAfterOptions(expr)) => match expr { Expr::Value(Value::SingleQuotedString(s)) => s, Expr::Value(Value::DollarQuotedString(s)) => s.value, - _ => return Err(BindError::BindFunctionError("expected string".into())), + _ => { + return Err( + ErrorKind::BindFunctionError("expected string".into()).with_spanned(&expr) + ) + } }, Some(CreateFunctionBody::Return(return_expr)) => { // Note: this is a current work around, and we are assuming return sql udf @@ -96,9 +102,10 @@ impl Binder { format!("select {}", &return_expr.to_string()) } None => { - return Err(BindError::BindFunctionError( + return Err(ErrorKind::BindFunctionError( "AS or RETURN must be specified".to_string(), - )); + ) + .into()); } }; diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 003bcc14..642ee7fe 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -62,16 +62,18 @@ impl Binder { let schema = self .catalog .get_schema_by_name(schema_name) - .ok_or_else(|| BindError::InvalidSchema(schema_name.into()))?; + .ok_or_else(|| ErrorKind::InvalidSchema(schema_name.into()).with_spanned(&name))?; if schema.get_table_by_name(table_name).is_some() { - return Err(BindError::TableExists(table_name.into())); + return Err(ErrorKind::TableExists(table_name.into()).with_spanned(&name)); } // check duplicated column names let mut set = HashSet::new(); for col in &columns { if !set.insert(col.name.value.to_lowercase()) { - return Err(BindError::ColumnExists(col.name.value.to_lowercase())); + return Err( + ErrorKind::ColumnExists(col.name.value.to_lowercase()).with_spanned(col) + ); } } @@ -80,18 +82,20 @@ impl Binder { if ordered_pk_ids.len() > 1 { // multi primary key should be declared by "primary key(c1, c2...)" syntax - return Err(BindError::NotSupportedTSQL); + return Err(ErrorKind::NotSupportedTSQL.into()); } let pks_name_from_constraints = Binder::pks_name_from_constraints(&constraints); if has_pk_from_column && !pks_name_from_constraints.is_empty() { // can't get primary key both from "primary key(c1, c2...)" syntax and // column's option - return Err(BindError::NotSupportedTSQL); + return Err(ErrorKind::NotSupportedTSQL.into()); } else if !has_pk_from_column { - for name in &pks_name_from_constraints { - if !set.contains(name) { - return Err(BindError::InvalidColumn(name.clone())); + for name in pks_name_from_constraints { + if !set.contains(&name.value.to_lowercase()) { + return Err( + ErrorKind::InvalidColumn(name.value.to_lowercase()).with_span(name.span) + ); } } // We have used `pks_name_from_constraints` to get the primary keys' name sorted by @@ -102,7 +106,7 @@ impl Binder { .map(|name| { columns .iter() - .position(|c| c.name.value.eq_ignore_ascii_case(name)) + .position(|c| c.name.value.eq_ignore_ascii_case(&name.value)) .unwrap() as ColumnId }) .collect(); @@ -153,20 +157,15 @@ impl Binder { } /// get the primary keys' name sorted by declaration order in "primary key(c1, c2..)" syntax. - fn pks_name_from_constraints(constraints: &[TableConstraint]) -> Vec { + fn pks_name_from_constraints(constraints: &[TableConstraint]) -> &[Ident] { for constraint in constraints { match constraint { - TableConstraint::PrimaryKey { columns, .. } => { - return columns - .iter() - .map(|ident| ident.value.to_lowercase()) - .collect() - } + TableConstraint::PrimaryKey { columns, .. } => return columns, _ => continue, } } // no primary key - vec![] + &[] } } diff --git a/src/binder/create_view.rs b/src/binder/create_view.rs index d4e71c64..e73daaec 100644 --- a/src/binder/create_view.rs +++ b/src/binder/create_view.rs @@ -16,16 +16,18 @@ impl Binder { let schema = self .catalog .get_schema_by_name(schema_name) - .ok_or_else(|| BindError::InvalidSchema(schema_name.into()))?; + .ok_or_else(|| ErrorKind::InvalidSchema(schema_name.into()).with_spanned(&name))?; if schema.get_table_by_name(table_name).is_some() { - return Err(BindError::TableExists(table_name.into())); + return Err(ErrorKind::TableExists(table_name.into()).with_spanned(&name)); } // check duplicated column names let mut set = HashSet::new(); for col in &columns { if !set.insert(col.name.value.to_lowercase()) { - return Err(BindError::ColumnExists(col.name.value.to_lowercase())); + return Err( + ErrorKind::ColumnExists(col.name.value.to_lowercase()).with_spanned(col) + ); } } @@ -35,7 +37,7 @@ impl Binder { // TODO: support inferring column names from query if columns.len() != output_types.len() { - return Err(BindError::ViewAliasesMismatch); + return Err(ErrorKind::ViewAliasesMismatch.with_spanned(&name)); } let columns: Vec = columns diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 144ddcd2..18eb3959 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -4,19 +4,21 @@ use super::*; impl Binder { pub(super) fn bind_delete(&mut self, delete: Delete) -> Result { - let from = match delete.from { + let from = match &delete.from { FromTable::WithFromKeyword(t) => t, FromTable::WithoutKeyword(t) => t, }; if from.len() != 1 || !from[0].joins.is_empty() { - return Err(BindError::Todo(format!("delete from {from:?}"))); + return Err(ErrorKind::Todo(format!("delete from {from:?}")).with_spanned(&delete.from)); } let TableFactor::Table { name, alias, .. } = &from[0].relation else { - return Err(BindError::Todo(format!("delete from {from:?}"))); + return Err( + ErrorKind::Todo(format!("delete from {from:?}")).with_spanned(&from[0].relation) + ); }; let (table_id, is_system, is_view) = self.bind_table_id(name)?; if is_system || is_view { - return Err(BindError::CanNotDelete); + return Err(ErrorKind::CanNotDelete.with_spanned(name)); } let scan = self.bind_table_def(name, alias.clone(), true)?; let cond = self.bind_where(delete.selection)?; diff --git a/src/binder/drop.rs b/src/binder/drop.rs index 1f74e29a..bae14185 100644 --- a/src/binder/drop.rs +++ b/src/binder/drop.rs @@ -11,10 +11,10 @@ impl Binder { cascade: bool, ) -> Result { if !matches!(object_type, ObjectType::Table | ObjectType::View) { - return Err(BindError::Todo(format!("drop {object_type:?}"))); + return Err(ErrorKind::Todo(format!("drop {object_type:?}")).into()); } if cascade { - return Err(BindError::Todo("cascade drop".into())); + return Err(ErrorKind::Todo("cascade drop".into()).into()); } let mut table_ids = Vec::with_capacity(names.len()); for name in names { @@ -24,7 +24,8 @@ impl Binder { if if_exists && result.is_none() { continue; } - let table_id = result.ok_or_else(|| BindError::InvalidTable(table_name.into()))?; + let table_id = result + .ok_or_else(|| ErrorKind::InvalidTable(table_name.into()).with_spanned(&name))?; let id = self.egraph.add(Node::Table(table_id)); table_ids.push(id); } diff --git a/src/binder/error.rs b/src/binder/error.rs new file mode 100644 index 00000000..40711121 --- /dev/null +++ b/src/binder/error.rs @@ -0,0 +1,220 @@ +//! The error type of bind operations. +//! +//! To raise an error in binder, construct an `ErrorKind` and attach a span if possible: +//! +//! ```ignore +//! return Err(ErrorKind::InvalidTable("table".into()).into()); +//! return Err(ErrorKind::InvalidTable("table".into()).with_span(ident.span)); +//! return Err(ErrorKind::InvalidTable("table".into()).with_spanned(object_name)); +//! ``` + +use sqlparser::ast::{Ident, ObjectType, Spanned}; +use sqlparser::tokenizer::Span; + +use crate::planner::TypeError; + +/// The error type of bind operations. +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub struct BindError(#[from] Box); + +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +struct Inner { + #[source] + kind: ErrorKind, + span: Option, + sql: Option, +} + +impl std::fmt::Display for BindError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::fmt::Display for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.kind)?; + if let Some(sql) = &self.sql + && let Some(span) = self.span + { + write!(f, "\n\n{}", highlight_sql(sql, span))?; + } else if let Some(span) = self.span { + // " at Line: {}, Column: {}" + write!(f, "{}", span.start)?; + } + Ok(()) + } +} + +/// The error type of bind operations. +#[derive(thiserror::Error, Debug, PartialEq, Eq)] +pub enum ErrorKind { + #[error("invalid schema {0:?}")] + InvalidSchema(String), + #[error("invalid table {0:?}")] + InvalidTable(String), + #[error("invalid column {0:?}")] + InvalidColumn(String), + #[error("table {0:?} already exists")] + TableExists(String), + #[error("column {0:?} already exists")] + ColumnExists(String), + #[error("duplicated alias {0:?}")] + DuplicatedAlias(String), + #[error("duplicate CTE name {0:?}")] + DuplicatedCteName(String), + #[error("table {0:?} has {1} columns available but {2} columns specified")] + ColumnCountMismatch(String, usize, usize), + #[error("invalid expression {0}")] + InvalidExpression(String), + #[error("not nullable column {0:?}")] + NotNullableColumn(String), + #[error("ambiguous column {0:?} (use {1})")] + AmbiguousColumn(String, String), + #[error("invalid table name {0:?}")] + InvalidTableName(Vec), + #[error("SQL not supported")] + NotSupportedTSQL, + #[error("invalid SQL")] + InvalidSQL, + #[error("cannot cast {0:?} to {1:?}")] + CastError(crate::types::DataValue, crate::types::DataType), + #[error("{0}")] + BindFunctionError(String), + #[error("type error: {0}")] + TypeError(TypeError), + #[error("aggregate function calls cannot be nested")] + NestedAgg, + #[error("WHERE clause cannot contain aggregates")] + AggInWhere, + #[error("GROUP BY clause cannot contain aggregates")] + AggInGroupBy, + #[error("window function calls cannot be nested")] + NestedWindow, + #[error("WHERE clause cannot contain window functions")] + WindowInWhere, + #[error("HAVING clause cannot contain window functions")] + WindowInHaving, + #[error("column {0:?} must appear in the GROUP BY clause or be used in an aggregate function")] + ColumnNotInAgg(String), + #[error("ORDER BY items must appear in the select list if DISTINCT is specified")] + OrderKeyNotInDistinct, + #[error("{0:?} is not an aggregate function")] + NotAgg(String), + #[error("unsupported object name: {0:?}")] + UnsupportedObjectName(ObjectType), + #[error("not supported yet: {0}")] + Todo(String), + #[error("can not copy to {0}")] + CopyTo(String), + #[error("can only insert into table")] + CanNotInsert, + #[error("can only delete from table")] + CanNotDelete, + #[error("VIEW aliases mismatch query result")] + ViewAliasesMismatch, + #[error("pragma does not exist: {0}")] + NoPragma(String), +} + +impl ErrorKind { + /// Create a `BindError` with a span. + pub fn with_span(self, span: Span) -> BindError { + BindError(Box::new(Inner { + kind: self, + span: Some(span), + sql: None, + })) + } + + /// Create a `BindError` with a span from a `Spanned` object. + pub fn with_spanned(self, span: &impl Spanned) -> BindError { + self.with_span(span.span()) + } +} + +impl BindError { + /// Set the SQL string for the error. + pub fn with_sql(mut self, sql: &str) -> BindError { + self.0.sql = Some(sql.to_string()); + self + } +} + +impl From for BindError { + fn from(kind: ErrorKind) -> Self { + BindError(Box::new(Inner { + kind, + span: None, + sql: None, + })) + } +} + +impl From for BindError { + fn from(kind: TypeError) -> Self { + BindError(Box::new(Inner { + kind: ErrorKind::TypeError(kind), + span: None, + sql: None, + })) + } +} + +/// Highlight the SQL string at the given span. +fn highlight_sql(sql: &str, span: Span) -> String { + let lines: Vec<&str> = sql.lines().collect(); + if span.start.line == 0 || span.start.line as usize > lines.len() { + return String::new(); + } + + let error_line = lines[span.start.line as usize - 1]; + let prefix = format!("LINE {}: ", span.start.line); + let mut indicator = " ".repeat(prefix.len()).to_string(); + + if span.start.column > 0 && span.start.column as usize <= error_line.len() { + for _ in 1..span.start.column { + indicator.push(' '); + } + let caret_count = if span.end.column > span.start.column { + span.end.column - span.start.column + } else { + 1 + }; + for _ in 0..caret_count { + indicator.push('^'); + } + } + + format!("{}{}\n{}", prefix, error_line, indicator) +} + +#[cfg(test)] +mod tests { + use sqlparser::tokenizer::Location; + + use super::*; + + #[test] + fn test_bind_error_size() { + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::(), + "the size of BindError should be one pointer" + ); + } + + #[test] + fn test_highlight_sql() { + let sql = "SELECT * FROM table WHERE id = 1"; + let span = Span::new(Location::new(1, 15), Location::new(1, 20)); + assert_eq!( + highlight_sql(sql, span), + " +LINE 1: SELECT * FROM table WHERE id = 1 + ^^^^^ + " + .trim() + ); + } +} diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 7bf1a036..fa1627f6 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -3,6 +3,7 @@ use rust_decimal::Decimal; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; +use sqlparser::tokenizer::Span; use super::*; use crate::parser::{ @@ -20,10 +21,11 @@ impl Binder { // parameter-like (i.e., `$1`) values at present // TODO: consider formally `bind_parameter` in the future // e.g., lambda function support, etc. - if let Value::Placeholder(key) = v { + if let Value::Placeholder(key) = &v { self.udf_context - .get_expr(&key) - .map_or_else(|| Err(BindError::InvalidSQL), |&e| Ok(e)) + .get_expr(key) + .cloned() + .ok_or_else(|| ErrorKind::InvalidSQL.with_spanned(&v)) } else { Ok(self.egraph.add(Node::Constant(v.into()))) } @@ -99,21 +101,24 @@ impl Binder { fn bind_ident(&self, idents: impl IntoIterator) -> Result { let idents = idents .into_iter() - .map(|ident| Ident::new(ident.value.to_lowercase())) + .map(|ident| Ident::with_span(ident.span, ident.value.to_lowercase())) .collect_vec(); - let (_schema_name, table_name, column_name) = match idents.as_slice() { - [column] => (None, None, &column.value), - [table, column] => (None, Some(&table.value), &column.value), - [schema, table, column] => (Some(&schema.value), Some(&table.value), &column.value), - _ => return Err(BindError::InvalidTableName(idents)), + let (_schema_ident, table_ident, column_ident) = match idents.as_slice() { + [column] => (None, None, column), + [table, column] => (None, Some(table), column), + [schema, table, column] => (Some(schema), Some(table), column), + _ => { + let span = Span::union_iter(idents.iter().map(|ident| ident.span)); + return Err(ErrorKind::InvalidTableName(idents).with_span(span)); + } }; // Special check for sql udf - if let Some(id) = self.udf_context.get_expr(column_name) { + if let Some(id) = self.udf_context.get_expr(&column_ident.value) { return Ok(*id); } - self.find_alias(column_name, table_name.map(|s| s.as_str())) + self.find_alias(column_ident, table_ident) } fn bind_binary_op(&mut self, left: Expr, op: BinaryOperator, right: Expr) -> Result { @@ -174,7 +179,7 @@ impl Binder { match data_type { DataType::Date => { let date = value.parse().map_err(|_| { - BindError::CastError( + ErrorKind::CastError( DataValue::String(value.into()), crate::types::DataType::Date, ) @@ -183,7 +188,7 @@ impl Binder { } DataType::Timestamp(_, _) => { let timestamp = value.parse().map_err(|_| { - BindError::CastError( + ErrorKind::CastError( DataValue::String(value.into()), crate::types::DataType::Timestamp, ) @@ -325,8 +330,8 @@ impl Binder { let mut distinct = false; let function_args = match &func.args { FunctionArguments::None => &[], - FunctionArguments::Subquery(_) => { - return Err(BindError::Todo("subquery argument".into())) + FunctionArguments::Subquery(subquery) => { + return Err(ErrorKind::Todo("subquery argument".into()).with_spanned(&**subquery)); } FunctionArguments::List(arg_list) => { distinct = arg_list.duplicate_treatment == Some(DuplicateTreatment::Distinct); @@ -356,10 +361,11 @@ impl Binder { let catalog = self.catalog(); let Ok((schema_name, function_name)) = split_name(&func.name) else { - return Err(BindError::BindFunctionError(format!( + return Err(ErrorKind::BindFunctionError(format!( "failed to parse the function name {}", func.name - ))); + )) + .with_spanned(&func.name)); }; // See if the input function is sql udf @@ -368,18 +374,20 @@ impl Binder { // Create the brand new `udf_context` let Ok(context) = UdfContext::create_udf_context(function_args, function_catalog) else { - return Err(BindError::InvalidExpression( - "failed to create udf context".to_string(), - )); + return Err( + ErrorKind::InvalidExpression("failed to create udf context".into()) + .with_spanned(&func.name), + ); }; let mut udf_context = HashMap::new(); // Bind each expression in the newly created `udf_context` for (c, e) in context { let Ok(e) = self.bind_expr(e) else { - return Err(BindError::BindFunctionError( - "failed to bind arguments within the given sql udf".to_string(), - )); + return Err(ErrorKind::BindFunctionError( + "failed to bind arguments within the given sql udf".into(), + ) + .with_spanned(&func.name)); }; udf_context.insert(c, e); } @@ -387,14 +395,15 @@ impl Binder { // Parse the sql body using `function_catalog` let dialect = GenericDialect {}; let Ok(ast) = Parser::parse_sql(&dialect, &function_catalog.body) else { - return Err(BindError::InvalidSQL); + return Err(ErrorKind::InvalidSQL.with_spanned(&func.name)); }; // Extract the corresponding udf expression out from `ast` let Ok(expr) = UdfContext::extract_udf_expression(ast) else { - return Err(BindError::InvalidExpression( - "failed to bind the sql udf expression".to_string(), - )); + return Err(ErrorKind::InvalidExpression( + "failed to bind the sql udf expression".into(), + ) + .with_spanned(&func.name)); }; let stashed_udf_context = self.udf_context.get_context(); @@ -404,9 +413,10 @@ impl Binder { // Bind the expression in sql udf body let Ok(bind_result) = self.bind_expr(expr) else { - return Err(BindError::InvalidExpression( - "failed to bind the expression".to_string(), - )); + return Err( + ErrorKind::InvalidExpression("failed to bind the expression".into()) + .with_spanned(&func.name), + ); }; // Restore the context after binding @@ -436,21 +446,23 @@ impl Binder { }; let mut id = self.egraph.add(node); if let Some(window) = func.over { - id = self.bind_window_function(id, window)?; + id = self.bind_window_function(id, window, &func.name)?; } Ok(id) } - fn bind_window_function(&mut self, func: Id, window: WindowType) -> Result { + fn bind_window_function(&mut self, func: Id, window: WindowType, name: &ObjectName) -> Result { let window = match window { WindowType::WindowSpec(window) => window, - WindowType::NamedWindow(_) => return Err(BindError::Todo("named window".into())), + WindowType::NamedWindow(name) => { + return Err(ErrorKind::Todo("named window".into()).with_span(name.span)); + } }; if !self.node(func).is_window_function() { - return Err(BindError::NotAgg(self.node(func).to_string())); + return Err(ErrorKind::NotAgg(self.node(func).to_string()).with_spanned(name)); } if !self.overs(func).is_empty() { - return Err(BindError::NestedWindow); + return Err(ErrorKind::NestedWindow.with_spanned(name)); } let partitionby = self.bind_exprs(window.partition_by)?; let orderby = self.bind_orderby(window.order_by)?; diff --git a/src/binder/insert.rs b/src/binder/insert.rs index fe73a9b4..29ba1064 100644 --- a/src/binder/insert.rs +++ b/src/binder/insert.rs @@ -3,23 +3,15 @@ use super::*; impl Binder { - pub fn bind_insert( - &mut self, - Insert { - table_name, - columns, - source, - .. - }: Insert, - ) -> Result { - let Some(source) = source else { - return Err(BindError::InvalidSQL); + pub fn bind_insert(&mut self, insert: Insert) -> Result { + let Some(source) = insert.source else { + return Err(ErrorKind::InvalidSQL.with_spanned(&insert)); }; - let (table, is_internal, is_view) = self.bind_table_id(&table_name)?; + let (table, is_internal, is_view) = self.bind_table_id(&insert.table_name)?; if is_internal || is_view { - return Err(BindError::CanNotInsert); + return Err(ErrorKind::CanNotInsert.with_spanned(&insert.table_name)); } - let cols = self.bind_table_columns(&table_name, &columns)?; + let cols = self.bind_table_columns(&insert.table_name, &insert.columns)?; let source = self.bind_query(*source)?.0; let id = self.egraph.add(Node::Insert([table, cols, source])); Ok(id) diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 67de1eff..61fd2fe1 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -12,7 +12,8 @@ use crate::array; use crate::catalog::function::FunctionCatalog; use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId}; use crate::parser::*; -use crate::planner::{Expr as Node, RecExpr, TypeError, TypeSchemaAnalysis}; +use crate::planner::{Expr as Node, RecExpr, TypeSchemaAnalysis}; +use crate::types::DataValue; pub mod copy; mod create_function; @@ -20,87 +21,19 @@ mod create_table; mod create_view; mod delete; mod drop; +mod error; mod expr; mod insert; mod select; mod table; -pub use create_function::CreateFunction; -pub use create_table::CreateTable; +pub use self::create_function::CreateFunction; +pub use self::create_table::CreateTable; +pub use self::error::BindError; +use self::error::ErrorKind; pub type Result = std::result::Result; -/// The error type of bind operations. -#[derive(thiserror::Error, Debug, PartialEq, Eq)] -pub enum BindError { - #[error("invalid schema {0:?}")] - InvalidSchema(String), - #[error("invalid table {0:?}")] - InvalidTable(String), - #[error("invalid column {0:?}")] - InvalidColumn(String), - #[error("table {0:?} already exists")] - TableExists(String), - #[error("column {0:?} already exists")] - ColumnExists(String), - #[error("duplicated alias {0:?}")] - DuplicatedAlias(String), - #[error("duplicate CTE name {0:?}")] - DuplicatedCteName(String), - #[error("table {0:?} has {1} columns available but {2} columns specified")] - ColumnCountMismatch(String, usize, usize), - #[error("invalid expression {0}")] - InvalidExpression(String), - #[error("not nullable column {0:?}")] - NotNullableColumn(String), - #[error("ambiguous column {0:?} (use {1})")] - AmbiguousColumn(String, String), - #[error("invalid table name {0:?}")] - InvalidTableName(Vec), - #[error("SQL not supported")] - NotSupportedTSQL, - #[error("invalid SQL")] - InvalidSQL, - #[error("cannot cast {0:?} to {1:?}")] - CastError(crate::types::DataValue, crate::types::DataType), - #[error("{0}")] - BindFunctionError(String), - #[error("type error: {0}")] - TypeError(#[from] TypeError), - #[error("aggregate function calls cannot be nested")] - NestedAgg, - #[error("WHERE clause cannot contain aggregates")] - AggInWhere, - #[error("GROUP BY clause cannot contain aggregates")] - AggInGroupBy, - #[error("window function calls cannot be nested")] - NestedWindow, - #[error("WHERE clause cannot contain window functions")] - WindowInWhere, - #[error("HAVING clause cannot contain window functions")] - WindowInHaving, - #[error("column {0:?} must appear in the GROUP BY clause or be used in an aggregate function")] - ColumnNotInAgg(String), - #[error("ORDER BY items must appear in the select list if DISTINCT is specified")] - OrderKeyNotInDistinct, - #[error("{0:?} is not an aggregate function")] - NotAgg(String), - #[error("unsupported object name: {0:?}")] - UnsupportedObjectName(ObjectType), - #[error("not supported yet: {0}")] - Todo(String), - #[error("can not copy to {0}")] - CopyTo(String), - #[error("can only insert into table")] - CanNotInsert, - #[error("can only delete from table")] - CanNotDelete, - #[error("VIEW aliases mismatch query result")] - ViewAliasesMismatch, - #[error("pragma does not exist: {0}")] - NoPragma(String), -} - /// The binder resolves all expressions referring to schema objects such as /// tables or views with their column names and types. pub struct Binder { @@ -166,35 +99,40 @@ impl UdfContext { /// expression out from the input `ast` pub fn extract_udf_expression(ast: Vec) -> Result { if ast.len() != 1 { - return Err(BindError::InvalidExpression( + return Err(ErrorKind::InvalidExpression( "the query for sql udf should contain only one statement".to_string(), - )); + ) + .into()); } // Extract the expression out let Statement::Query(query) = ast[0].clone() else { - return Err(BindError::InvalidExpression( + return Err(ErrorKind::InvalidExpression( "invalid function definition, please recheck the syntax".to_string(), - )); + ) + .into()); }; let SetExpr::Select(select) = *query.body else { - return Err(BindError::InvalidExpression( + return Err(ErrorKind::InvalidExpression( "missing `select` body for sql udf expression, please recheck the syntax" .to_string(), - )); + ) + .into()); }; if select.projection.len() != 1 { - return Err(BindError::InvalidExpression( + return Err(ErrorKind::InvalidExpression( "`projection` should contain only one `SelectItem`".to_string(), - )); + ) + .into()); } let SelectItem::UnnamedExpr(expr) = select.projection[0].clone() else { - return Err(BindError::InvalidExpression( + return Err(ErrorKind::InvalidExpression( "expect `UnnamedExpr` for `projection`".to_string(), - )); + ) + .into()); }; Ok(expr) @@ -210,7 +148,9 @@ impl UdfContext { match current_arg { FunctionArg::Unnamed(arg) => { let FunctionArgExpr::Expr(e) = arg else { - return Err(BindError::InvalidExpression("invalid syntax".to_string())); + return Err( + ErrorKind::InvalidExpression("invalid syntax".into()).into() + ); }; if catalog.arg_names[i].is_empty() { ret.insert(format!("${}", i + 1), e.clone()); @@ -220,7 +160,7 @@ impl UdfContext { ret.insert(catalog.arg_names[i].clone(), e.clone()); } } - _ => return Err(BindError::InvalidExpression("invalid syntax".to_string())), + _ => return Err(ErrorKind::InvalidExpression("invalid syntax".into()).into()), } } } @@ -316,10 +256,14 @@ impl Binder { Statement::Explain { statement, analyze, .. } => self.bind_explain(*statement, analyze), + Statement::Pragma { name, value, .. } => self.bind_pragma(name, value), + Statement::SetVariable { + variables, value, .. + } => self.bind_set(variables.as_ref(), value), Statement::ShowVariable { .. } | Statement::ShowCreate { .. } - | Statement::ShowColumns { .. } => Err(BindError::NotSupportedTSQL), - _ => Err(BindError::InvalidSQL), + | Statement::ShowColumns { .. } => Err(ErrorKind::NotSupportedTSQL.into()), + _ => Err(ErrorKind::InvalidSQL.into()), } } @@ -338,7 +282,7 @@ impl Binder { fn add_table_alias(&mut self, table_name: &str) -> Result<()> { let context = self.contexts.last_mut().unwrap(); if !context.table_aliases.insert(table_name.into()) { - return Err(BindError::DuplicatedAlias(table_name.into())); + return Err(ErrorKind::DuplicatedAlias(table_name.into()).into()); } Ok(()) } @@ -350,24 +294,30 @@ impl Binder { } /// Add a CTE to the current context. - fn add_cte(&mut self, table_name: &str, query: Id, columns: HashMap) -> Result<()> { + fn add_cte( + &mut self, + table_ident: &Ident, + query: Id, + columns: HashMap, + ) -> Result<()> { let context = self.contexts.last_mut().unwrap(); + let table_name = table_ident.value.to_lowercase(); if context .ctes - .insert(table_name.into(), (query, columns)) + .insert(table_name.clone(), (query, columns)) .is_some() { - return Err(BindError::DuplicatedCteName(table_name.into())); + return Err(ErrorKind::DuplicatedCteName(table_name).with_span(table_ident.span)); } Ok(()) } /// Find an alias. - fn find_alias(&self, column_name: &str, table_name: Option<&str>) -> Result { + fn find_alias(&self, column_ident: &Ident, table_ident: Option<&Ident>) -> Result { for context in self.contexts.iter().rev() { - if let Some(map) = context.column_aliases.get(column_name) { - if let Some(table_name) = table_name { - if let Some(id) = map.get(table_name) { + if let Some(map) = context.column_aliases.get(&column_ident.value) { + if let Some(table_ident) = table_ident { + if let Some(id) = map.get(&table_ident.value) { return Ok(*id); } } else if map.len() == 1 { @@ -375,13 +325,14 @@ impl Binder { } else { let use_ = map .keys() - .map(|table_name| format!("\"{table_name}.{column_name}\"")) + .map(|table_name| format!("\"{table_name}.{column_ident}\"")) .join(" or "); - return Err(BindError::AmbiguousColumn(column_name.into(), use_)); + return Err(ErrorKind::AmbiguousColumn(column_ident.value.clone(), use_) + .with_span(column_ident.span)); } } } - Err(BindError::InvalidColumn(column_name.into())) + Err(ErrorKind::InvalidColumn(column_ident.value.clone()).with_span(column_ident.span)) } /// Find an CTE. @@ -441,6 +392,32 @@ impl Binder { }); Ok(id) } + + pub fn bind_pragma(&mut self, name: ObjectName, value: Option) -> Result { + let name_string = name.to_string().to_lowercase(); + match name_string.as_str() { + "enable_optimizer" | "disable_optimizer" => {} + name_str => return Err(ErrorKind::NoPragma(name_str.into()).with_spanned(&name)), + } + let name_id = self.egraph.add(Node::Constant(name_string.into())); + let value_id = self.egraph.add(Node::Constant( + value.map_or(DataValue::Null, DataValue::from), + )); + let id = self.egraph.add(Node::Pragma([name_id, value_id])); + Ok(id) + } + + pub fn bind_set(&mut self, variables: &[ObjectName], values: Vec) -> Result { + if variables.len() != 1 || values.len() != 1 { + return Err(ErrorKind::InvalidSQL.into()); + } + let name_id = self + .egraph + .add(Node::Constant(variables[0].to_string().into())); + let value_id = self.bind_expr(values.into_iter().next().unwrap())?; + let id = self.egraph.add(Node::Set([name_id, value_id])); + Ok(id) + } } /// Split an object name into `(schema name, table name)`. @@ -448,7 +425,7 @@ fn split_name(name: &ObjectName) -> Result<(&str, &str)> { Ok(match name.0.as_slice() { [table] => (RootCatalog::DEFAULT_SCHEMA_NAME, &table.value), [schema, table] => (&schema.value, &table.value), - _ => return Err(BindError::InvalidTableName(name.0.clone())), + _ => return Err(ErrorKind::InvalidTableName(name.0.clone()).with_spanned(name)), }) } @@ -457,7 +434,7 @@ fn lower_case_name(name: &ObjectName) -> ObjectName { ObjectName( name.0 .iter() - .map(|ident| Ident::new(ident.value.to_lowercase())) + .map(|ident| Ident::with_span(ident.span, ident.value.to_lowercase())) .collect::>(), ) } diff --git a/src/binder/select.rs b/src/binder/select.rs index 4f1e78f1..0de24eab 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -1,5 +1,7 @@ // Copyright 2024 RisingLight Project Authors. Licensed under Apache-2.0. +use sqlparser::tokenizer::Span; + use super::*; use crate::parser::{Expr, Query, SelectItem, SetExpr}; @@ -16,7 +18,7 @@ impl Binder { pub(super) fn bind_query_internal(&mut self, query: Query) -> Result { if let Some(with) = query.with { if with.recursive { - return Err(BindError::Todo("recursive CTE".into())); + return Err(ErrorKind::Todo("recursive CTE".into()).with_spanned(&with)); } for cte in with.cte_tables { self.bind_cte(cte)?; @@ -25,7 +27,7 @@ impl Binder { let child = match *query.body { SetExpr::Select(select) => self.bind_select(*select, query.order_by)?, SetExpr::Values(values) => self.bind_values(values)?, - _ => return Err(BindError::Todo("unknown set expr".into())), + body => return Err(ErrorKind::Todo("unknown set expr".into()).with_spanned(&body)), }; let limit = match query.limit { Some(expr) => self.bind_expr(expr)?, @@ -51,11 +53,12 @@ impl Binder { let expected_column_num = self.schema(query).len(); let actual_column_num = alias.columns.len(); if actual_column_num != expected_column_num { - return Err(BindError::ColumnCountMismatch( + return Err(ErrorKind::ColumnCountMismatch( table_alias.clone(), expected_column_num, actual_column_num, - )); + ) + .with_spanned(&alias)); } for (column, id) in alias.columns.iter().zip(self.schema(query)) { columns.insert(column.name.value.to_lowercase(), id); @@ -70,7 +73,7 @@ impl Binder { columns.insert(name, id); } } - self.add_cte(&table_alias, query, columns)?; + self.add_cte(&alias.name, query, columns)?; Ok(query) } @@ -79,7 +82,9 @@ impl Binder { let projection = self.bind_projection(select.projection, from)?; let mut where_ = self.bind_where(select.selection)?; let groupby = match select.group_by { - GroupByExpr::All(_) => return Err(BindError::Todo("group by all".into())), + GroupByExpr::All(_) => { + return Err(ErrorKind::Todo("group by all".into()).with_spanned(&select.group_by)) + } GroupByExpr::Expressions(exprs, _) if exprs.is_empty() => None, GroupByExpr::Expressions(exprs, _) => Some(self.bind_groupby(exprs)?), }; @@ -148,10 +153,11 @@ impl Binder { pub(super) fn bind_where(&mut self, selection: Option) -> Result { let id = self.bind_selection(selection)?; if !self.aggs(id).is_empty() { - return Err(BindError::AggInWhere); + return Err(ErrorKind::AggInWhere.into()); // TODO: raise error in `bind_selection` to + // get the correct span } if !self.overs(id).is_empty() { - return Err(BindError::WindowInWhere); + return Err(ErrorKind::WindowInWhere.into()); // TODO: ditto } Ok(id) } @@ -160,7 +166,7 @@ impl Binder { fn bind_having(&mut self, selection: Option) -> Result { let id = self.bind_selection(selection)?; if !self.overs(id).is_empty() { - return Err(BindError::WindowInHaving); + return Err(ErrorKind::WindowInHaving.into()); // TODO: ditto } Ok(id) } @@ -179,7 +185,7 @@ impl Binder { fn bind_groupby(&mut self, group_by: Vec) -> Result { let id = self.bind_exprs(group_by)?; if !self.aggs(id).is_empty() { - return Err(BindError::AggInGroupBy); + return Err(ErrorKind::AggInGroupBy.into()); // TODO: ditto } Ok(id) } @@ -209,9 +215,11 @@ impl Binder { let column_len = values[0].len(); for row in values { if row.len() != column_len { - return Err(BindError::InvalidExpression( + let span = Span::union_iter(row.iter().map(|e| e.span())); + return Err(ErrorKind::InvalidExpression( "VALUES lists must all be the same length".into(), - )); + ) + .with_span(span)); } bound_values.push(self.bind_exprs(row)?); } @@ -231,7 +239,7 @@ impl Binder { // check nested agg for child in aggs.iter().flat_map(|agg| agg.children()) { if !self.aggs(*child).is_empty() { - return Err(BindError::NestedAgg); + return Err(ErrorKind::NestedAgg.into()); // TODO: ditto } } let mut list: Vec<_> = aggs.into_iter().map(|agg| self.egraph.add(agg)).collect(); @@ -277,7 +285,7 @@ impl Binder { } if let Node::Column(cid) = &expr { let name = self.catalog.get_column(cid).unwrap().name().to_string(); - return Err(BindError::ColumnNotInAgg(name)); + return Err(ErrorKind::ColumnNotInAgg(name).into()); } for child in expr.children_mut() { *child = self.rewrite_agg_in_expr(*child, schema)?; @@ -318,7 +326,7 @@ impl Binder { _ => id, }; if !distinct_on.contains(key) { - return Err(BindError::OrderKeyNotInDistinct); + return Err(ErrorKind::OrderKeyNotInDistinct.into()); } } // for all projection items that are not in DISTINCT list, diff --git a/src/binder/table.rs b/src/binder/table.rs index b739c85d..e01cf279 100644 --- a/src/binder/table.rs +++ b/src/binder/table.rs @@ -172,7 +172,7 @@ impl Binder { let ref_id = self .catalog .get_table_id_by_name(schema_name, table_name) - .ok_or_else(|| BindError::InvalidTable(table_name.into()))?; + .ok_or_else(|| ErrorKind::InvalidTable(table_name.into()))?; let table = self.catalog.get_table(&ref_id).unwrap(); let table_occurence = { @@ -220,7 +220,7 @@ impl Binder { let table_ref_id = self .catalog .get_table_id_by_name(schema_name, table_name) - .ok_or_else(|| BindError::InvalidTable(table_name.into()))?; + .ok_or_else(|| ErrorKind::InvalidTable(table_name.into()).with_spanned(&name))?; let table = self.catalog.get_table(&table_ref_id).unwrap(); @@ -230,9 +230,9 @@ impl Binder { let mut ids = vec![]; for col in columns { let col_name = col.value.to_lowercase(); - let col = table - .get_column_by_name(&col_name) - .ok_or_else(|| BindError::InvalidColumn(col_name.clone()))?; + let col = table.get_column_by_name(&col_name).ok_or_else(|| { + ErrorKind::InvalidColumn(col_name.clone()).with_span(col.span) + })?; ids.push(col.id()); } ids @@ -259,7 +259,7 @@ impl Binder { let table_ref_id = self .catalog .get_table_id_by_name(schema_name, table_name) - .ok_or_else(|| BindError::InvalidTable(table_name.into()))?; + .ok_or_else(|| ErrorKind::InvalidTable(table_name.into()).with_spanned(&name))?; let table = self.catalog.get_table(&table_ref_id).unwrap(); let id = self.egraph.add(Node::Table(table_ref_id)); Ok(( diff --git a/src/db.rs b/src/db.rs index 8c4dc422..a2e7b442 100644 --- a/src/db.rs +++ b/src/db.rs @@ -10,8 +10,8 @@ use risinglight_proto::rowset::block_statistics::BlockStatisticsType; use crate::array::Chunk; use crate::binder::bind_header; use crate::catalog::{RootCatalog, RootCatalogRef, TableRefId}; -use crate::parser::{parse, ParserError, Statement}; -use crate::planner::Statistics; +use crate::parser::{parse, ParserError}; +use crate::planner::{Expr, RecExpr, Statistics}; use crate::storage::{ InMemoryStorage, SecondaryStorage, SecondaryStorageOptions, Storage, StorageColumnRef, StorageImpl, Table, @@ -99,12 +99,11 @@ impl Database { let stmts = parse(&sql)?; let mut outputs: Vec = vec![]; for stmt in stmts { - if self.handle_set(&stmt)? { + let mut binder = crate::binder::Binder::new(self.catalog.clone()); + let mut plan = binder.bind(stmt.clone()).map_err(|e| e.with_sql(&sql))?; + if self.handle_set(&plan)? { continue; } - - let mut binder = crate::binder::Binder::new(self.catalog.clone()); - let mut plan = binder.bind(stmt.clone())?; if !self.config.lock().unwrap().disable_optimizer { plan = optimizer.optimize(plan); } @@ -155,47 +154,42 @@ impl Database { Ok(stat) } - /// Mock the row count of a table for planner test. - fn handle_set(&self, stmt: &Statement) -> Result { - if let Statement::Pragma { name, .. } = stmt { - match name.to_string().as_str() { + /// Handle PRAGMA and SET statements. + fn handle_set(&self, plan: &RecExpr) -> Result { + let root = &plan.as_ref()[plan.as_ref().len() - 1]; + match root { + Expr::Pragma([name, _value]) => match plan[*name].as_const().as_str() { "enable_optimizer" => { self.config.lock().unwrap().disable_optimizer = false; - return Ok(true); + Ok(true) } "disable_optimizer" => { self.config.lock().unwrap().disable_optimizer = true; - return Ok(true); + Ok(true) } - name => { - return Err(crate::binder::BindError::NoPragma(name.into()).into()); + name => Err(Error::Internal(format!("no such pragma: {name}"))), + }, + Expr::Set([name, value]) => match plan[*name].as_const().as_str() { + // Mock the row count of a table for planner test. + name if name.starts_with("mock_rowcount_") => { + let table_name = name.strip_prefix("mock_rowcount_").unwrap(); + let count = plan[*value].as_const().as_usize().unwrap().unwrap() as u32; + let table_id = self + .catalog + .get_table_id_by_name("postgres", table_name) + .ok_or_else(|| Error::Internal("table not found".into()))?; + self.config + .lock() + .unwrap() + .mock_stat + .get_or_insert_with(Default::default) + .add_row_count(table_id, count); + Ok(true) } - } + _ => Ok(false), + }, + _ => Ok(false), } - let Statement::SetVariable { - variables, value, .. - } = stmt - else { - return Ok(false); - }; - let Some(table_name) = variables[0].0[0].value.strip_prefix("mock_rowcount_") else { - return Ok(false); - }; - let count = value[0] - .to_string() - .parse::() - .map_err(|_| Error::Internal("invalid count".into()))?; - let table_id = self - .catalog - .get_table_id_by_name("postgres", table_name) - .ok_or_else(|| Error::Internal("table not found".into()))?; - self.config - .lock() - .unwrap() - .mock_stat - .get_or_insert_with(Default::default) - .add_row_count(table_id, count); - Ok(true) } /// Return all available pragma options. diff --git a/src/planner/explain.rs b/src/planner/explain.rs index 0cb2a03f..efabd09f 100644 --- a/src/planner/explain.rs +++ b/src/planner/explain.rs @@ -371,6 +371,20 @@ impl<'a> Explain<'a> { with_meta(vec![]), vec![self.child(child).pretty()], ), + Pragma([name, value]) => Pretty::childless_record( + "Pragma", + with_meta(vec![ + ("name", self.expr(name).pretty()), + ("value", self.expr(value).pretty()), + ]), + ), + Set([name, value]) => Pretty::childless_record( + "Set", + with_meta(vec![ + ("name", self.expr(name).pretty()), + ("value", self.expr(value).pretty()), + ]), + ), Empty(_) => Pretty::childless_record("Empty", with_meta(vec![])), Max1Row(child) => Pretty::fieldless_record("Max1Row", vec![self.expr(child).pretty()]), } diff --git a/src/planner/mod.rs b/src/planner/mod.rs index 8d4a07e5..e00cd261 100644 --- a/src/planner/mod.rs +++ b/src/planner/mod.rs @@ -129,6 +129,8 @@ define_language! { ExtSource(Box), "explain" = Explain(Id), // (explain child) "analyze" = Analyze(Id), // (analyze child) + "pragma" = Pragma([Id; 2]), // (pragma name value) + "set" = Set([Id; 2]), // (set name value) // internal functions "empty" = Empty(Id), // (empty child) diff --git a/src/types/value.rs b/src/types/value.rs index 0e9681c7..82c74524 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -186,6 +186,14 @@ impl DataValue { })) } + /// Get the string value. + pub fn as_str(&self) -> &str { + match self { + Self::String(s) => s, + _ => panic!("not a string: {:?}", self), + } + } + /// Cast the value to another type. pub fn cast(&self, ty: &DataType) -> Result { Ok(ArrayImpl::from(self).cast(ty)?.get(0)) @@ -240,6 +248,12 @@ impl From> for DataValue { } } +impl From for DataValue { + fn from(v: String) -> Self { + Self::String(v.into()) + } +} + #[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)] pub enum ParseValueError { #[error("invalid interval: {0}")]