From c5f9c910580e926b6d3ef620e9a6eb38d3cfb7d8 Mon Sep 17 00:00:00 2001 From: lovasoa Date: Tue, 25 Jul 2023 15:48:20 +0200 Subject: [PATCH] avoid database roundtrips for simple static select statements --- CHANGELOG.md | 1 + src/webserver/database/mod.rs | 64 ++++++----- src/webserver/database/sql.rs | 200 ++++++++++++++++++++++++++++------ 3 files changed, 204 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 66f73faa..adb117aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Added a new `json` component, which allows building a JSON API entirely in SQL with SQLPage ! Now creating an api over your database is as simple as `SELECT 'json' AS component, JSON_OBJECT('hello', 'world') AS contents`. + - `SELECT` statements that contain only static values are now interpreted directly by SQLPage, and do not result in a database query. This greatly improves the performance of pages that contain many static elements. ## 0.8.0 (2023-07-17) diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 82b9b75a..828802d0 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -2,12 +2,11 @@ mod sql; mod sql_pseudofunctions; use anyhow::{anyhow, Context}; -use futures_util::stream::{self, BoxStream, Stream}; +use futures_util::stream::Stream; use futures_util::StreamExt; use serde_json::{Map, Value}; use std::borrow::Cow; use std::fmt::{Display, Formatter}; -use std::future::ready; use std::path::Path; use std::time::Duration; @@ -29,6 +28,8 @@ use sqlx::{ Any, AnyPool, Arguments, Column, ConnectOptions, Decode, Either, Executor, Row, Statement, }; +use self::sql::ParsedSQLStatement; + pub struct Database { pub(crate) connection: AnyPool, } @@ -98,42 +99,49 @@ pub async fn stream_query_results<'a>( sql_file: &'a ParsedSqlFile, request: &'a RequestInfo, ) -> impl Stream + 'a { - stream_query_results_direct(db, sql_file, request) - .await - .unwrap_or_else(|error| stream::once(ready(Err(error))).boxed()) - .map(|res| match res { - Ok(Either::Right(r)) => DbItem::Row(row_to_json(&r)), - Ok(Either::Left(res)) => { - log::debug!("Finished query with result: {:?}", res); - DbItem::FinishedQuery + async_stream::stream! { + let mut connection = match db.connection.acquire().await { + Ok(c) => c, + Err(e) => { + let err_msg = format!("Unable to acquire a database connection to execute the SQL file. All of the {} {:?} connections are busy.", db.connection.size(), db.connection.any_kind()); + yield DbItem::Error(anyhow::Error::new(e).context(err_msg)); + return; } - Err(e) => DbItem::Error(e), - }) -} - -pub async fn stream_query_results_direct<'a>( - db: &'a Database, - sql_file: &'a ParsedSqlFile, - request: &'a RequestInfo, -) -> anyhow::Result>>> { - Ok(async_stream::stream! { - let mut connection = db.connection.acquire().await - .with_context(|| anyhow::anyhow!("Unable to acquire a database connection to execute the SQL file. All of the {} {:?} connections are busy.", db.connection.size(), db.connection.any_kind()))?; + }; for res in &sql_file.statements { match res { - Ok(stmt)=>{ - let query = bind_parameters(stmt, request) - .with_context(|| format!("Unable to bind parameters to the SQL statement: {stmt}"))?; + ParsedSQLStatement::Statement(stmt)=>{ + let query = match bind_parameters(stmt, request) { + Ok(q) => q, + Err(e) => { + yield DbItem::Error(e); + continue; + } + }; let mut stream = query.fetch_many(&mut connection); while let Some(elem) = stream.next().await { - yield elem.with_context(|| format!("Error while running SQL: {stmt}")) + yield parse_single_sql_result(elem) } }, - Err(e) => yield Err(clone_anyhow_err(e)), + ParsedSQLStatement::StaticSimpleSelect(value) => { + yield DbItem::Row(value.clone().into()) + } + ParsedSQLStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)), } } } - .boxed()) +} + +#[inline] +fn parse_single_sql_result(res: sqlx::Result>) -> DbItem { + match res { + Ok(Either::Right(r)) => DbItem::Row(row_to_json(&r)), + Ok(Either::Left(res)) => { + log::debug!("Finished query with result: {:?}", res); + DbItem::FinishedQuery + } + Err(e) => DbItem::Error(e.into()), + } } fn clone_anyhow_err(err: &anyhow::Error) -> anyhow::Error { diff --git a/src/webserver/database/sql.rs b/src/webserver/database/sql.rs index 29da0562..ac9bbb11 100644 --- a/src/webserver/database/sql.rs +++ b/src/webserver/database/sql.rs @@ -4,8 +4,8 @@ use crate::file_cache::AsyncFromStrWithState; use crate::{AppState, Database}; use async_trait::async_trait; use sqlparser::ast::{ - DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value, VisitMut, - VisitorMut, + DataType, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Statement, Value, + VisitMut, VisitorMut, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::{Parser, ParserError}; @@ -18,7 +18,13 @@ use std::ops::ControlFlow; #[derive(Default)] pub struct ParsedSqlFile { - pub(super) statements: Vec>, + pub(super) statements: Vec, +} + +pub(super) enum ParsedSQLStatement { + Statement(PreparedStatement), + StaticSimpleSelect(serde_json::Map), + Error(anyhow::Error), } impl ParsedSqlFile { @@ -38,19 +44,29 @@ impl ParsedSqlFile { } }; while parser.consume_token(&SemiColon) {} + if let Some(static_statement) = extract_static_simple_select(&stmt) { + log::debug!("Optimised a static simple select to avoid a trivial database query: {stmt} optimized to {static_statement:?}"); + statements.push(ParsedSQLStatement::StaticSimpleSelect(static_statement)); + continue; + } let db_kind = db.connection.any_kind(); let parameters = ParameterExtractor::extract_parameters(&mut stmt, db_kind); let query = stmt.to_string(); let param_types = get_param_types(¶meters); let stmt_res = db.prepare_with(&query, ¶m_types).await; - match &stmt_res { - Ok(_) => log::debug!("Successfully prepared SQL statement '{query}'"), - Err(err) => log::warn!("{err:#}"), - } - let statement_result = stmt_res.map(|statement| PreparedStatement { - statement, - parameters, - }); + let statement_result = match stmt_res { + Ok(statement) => { + log::debug!("Successfully prepared SQL statement '{query}'"); + ParsedSQLStatement::Statement(PreparedStatement { + statement, + parameters, + }) + } + Err(err) => { + log::warn!("{err:#}"); + ParsedSQLStatement::Error(err) + } + }; statements.push(statement_result); } statements.shrink_to_fit(); @@ -60,7 +76,7 @@ impl ParsedSqlFile { fn finish_with_error( err: ParserError, mut parser: Parser, - mut statements: Vec>, + mut statements: Vec, ) -> ParsedSqlFile { let mut err_msg = "SQL syntax error before: ".to_string(); for _ in 0..32 { @@ -71,15 +87,15 @@ impl ParsedSqlFile { _ = write!(&mut err_msg, "{next_token} "); } let error = anyhow::Error::from(err).context(err_msg); - statements.push(Err(error)); + statements.push(ParsedSQLStatement::Error(error)); ParsedSqlFile { statements } } fn from_err(e: impl Into) -> Self { Self { - statements: vec![Err(e - .into() - .context("SQLPage could not parse the SQL file"))], + statements: vec![ParsedSQLStatement::Error( + e.into().context("SQLPage could not parse the SQL file"), + )], } } } @@ -110,6 +126,57 @@ fn map_param(mut name: String) -> StmtParam { } } +fn extract_static_simple_select( + stmt: &Statement, +) -> Option> { + let set_expr = match stmt { + Statement::Query(q) + if q.limit.is_none() + && q.fetch.is_none() + && q.order_by.is_empty() + && q.with.is_none() + && q.offset.is_none() + && q.locks.is_empty() => + { + q.body.as_ref() + } + _ => return None, + }; + let select_items = match set_expr { + sqlparser::ast::SetExpr::Select(s) + if s.cluster_by.is_empty() + && s.distinct.is_none() + && s.distribute_by.is_empty() + && s.from.is_empty() + && s.group_by.is_empty() + && s.having.is_none() + && s.into.is_none() + && s.lateral_views.is_empty() + && s.named_window.is_empty() + && s.qualify.is_none() + && s.selection.is_none() + && s.sort_by.is_empty() + && s.top.is_none() => + { + &s.projection + } + _ => return None, + }; + let mut map = serde_json::Map::with_capacity(select_items.len()); + for select_item in select_items { + let sqlparser::ast::SelectItem::ExprWithAlias { expr, alias } = select_item else { return None }; + let value = match expr { + Expr::Value(Value::Boolean(b)) => serde_json::Value::Bool(*b), + Expr::Value(Value::Number(n, _)) => serde_json::Value::Number(n.parse().ok()?), + Expr::Value(Value::SingleQuotedString(s)) => serde_json::Value::String(s.clone()), + Expr::Value(Value::Null) => serde_json::Value::Null, + _ => return None, + }; + map.insert(alias.value.clone(), value); + } + Some(map) +} + struct ParameterExtractor { db_kind: AnyKind, parameters: Vec, @@ -291,23 +358,90 @@ fn sqlpage_func_name(func_name_parts: &[Ident]) -> &str { } } -#[test] -fn test_statement_rewrite() { - let sql = "select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"; - let mut ast = Parser::parse_sql(&GenericDialect, sql).unwrap(); - let parameters = ParameterExtractor::extract_parameters(&mut ast[0], AnyKind::Postgres); - assert_eq!( - ast[0].to_string(), +#[cfg(test)] +mod test { + use super::*; + + fn parse_stmt(sql: &str) -> Statement { + let mut ast = Parser::parse_sql(&GenericDialect, sql).unwrap(); + assert_eq!(ast.len(), 1); + ast.pop().unwrap() + } + + #[test] + fn test_statement_rewrite() { + let mut ast = parse_stmt("select $a from t where $x > $a OR $x = sqlpage.cookie('cookoo')"); + let parameters = ParameterExtractor::extract_parameters(&mut ast, AnyKind::Postgres); + assert_eq!( + ast.to_string(), "SELECT CAST($1 AS TEXT) FROM t WHERE CAST($2 AS TEXT) > CAST($3 AS TEXT) OR CAST($4 AS TEXT) = CAST($5 AS TEXT)" ); - assert_eq!( - parameters, - [ - StmtParam::GetOrPost("a".to_string()), - StmtParam::GetOrPost("x".to_string()), - StmtParam::GetOrPost("a".to_string()), - StmtParam::GetOrPost("x".to_string()), - StmtParam::Cookie("cookoo".to_string()), - ] - ); + assert_eq!( + parameters, + [ + StmtParam::GetOrPost("a".to_string()), + StmtParam::GetOrPost("x".to_string()), + StmtParam::GetOrPost("a".to_string()), + StmtParam::GetOrPost("x".to_string()), + StmtParam::Cookie("cookoo".to_string()), + ] + ); + } + + #[test] + fn test_static_extract() { + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer, null as nothing" + )), + Some( + serde_json::json!({ + "hello": "hello", + "answer": 42, + "nothing": (), + }) + .as_object() + .unwrap() + .clone() + ) + ); + } + + #[test] + fn test_static_extract_doesnt_match() { + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer limit 0" + )), + None + ); + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer order by 1" + )), + None + ); + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer offset 1" + )), + None + ); + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer where 1 = 0" + )), + None + ); + assert_eq!( + extract_static_simple_select(&parse_stmt( + "select 'hello' as hello, 42 as answer FROM t" + )), + None + ); + assert_eq!( + extract_static_simple_select(&parse_stmt("select x'CAFEBABE' as hello, 42 as answer")), + None + ); + } }