Skip to content

Commit

Permalink
fix(sql): Improve quotation handling (#2964)
Browse files Browse the repository at this point in the history
Quotation marks are not really well handled at the moment. Sometimes,
they're stripped, sometimes they're not. Rather than introducing massive
changes to preserve the fact that an identifier was originally quotes, I
propose the strip them whenever they end up being stored in a `String`.
On the generator side, we can than automatically add the quotes back if
we detect that they are needed.
  • Loading branch information
NicolasRichard authored May 18, 2023
1 parent 976894a commit e28316a
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package wvlet.airframe.sql.model
import wvlet.airframe.sql.analyzer.AnalyzerContext
import wvlet.airframe.sql.catalog.DataType
import wvlet.airframe.sql.catalog.DataType._
import wvlet.airframe.sql.model.Expression.{AllColumns, MultiSourceColumn}
import wvlet.airframe.sql.model.Expression.{AllColumns, MultiSourceColumn, QName}
import wvlet.airframe.sql.parser.SQLGenerator
import wvlet.log.LogSupport

Expand Down Expand Up @@ -431,6 +431,9 @@ object Expression {
QuotedIdentifier(x.stripPrefix("\"").stripSuffix("\""), None)
} else if (x.matches("[0-9]+")) {
DigitId(x, None)
} else if (!x.matches("[0-9a-zA-Z_]*")) {
// Quotations are needed with special characters to generate valid SQL
QuotedIdentifier(x, None)
} else {
UnquotedIdentifier(x, None)
}
Expand All @@ -444,6 +447,7 @@ object Expression {
case class QName(parts: List[String], nodeLocation: Option[NodeLocation]) extends LeafExpression {
def fullName: String = parts.mkString(".")
override def toString: String = fullName
override def sqlExpr: String = parts.map(Expression.newIdentifier).map(_.sqlExpr).mkString(".")
}
object QName {
def apply(s: String, nodeLocation: Option[NodeLocation]): QName = {
Expand Down Expand Up @@ -609,7 +613,7 @@ object Expression {
override def dataType: DataType = expr.dataType

override def sqlExpr: String = {
s"${expr.sqlExpr} AS ${fullName}"
s"${expr.sqlExpr} AS ${QName.apply(fullName, None).sqlExpr}"
}

override def sourceColumns: Seq[SourceColumn] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ package wvlet.airframe.sql.model

import wvlet.airframe.sql.analyzer.QuerySignatureConfig
import wvlet.airframe.sql.catalog.{Catalog, DataType}
import wvlet.airframe.sql.model.Expression.GroupingKey
import wvlet.airframe.sql.model.Expression.{GroupingKey, QName}
import wvlet.airframe.sql.model.LogicalPlan.Relation
import wvlet.log.LogSupport

Expand Down Expand Up @@ -78,7 +78,7 @@ case class ResolvedAttribute(
with LogSupport {

override lazy val resolved = true
override def sqlExpr: String = s"${prefix}${name}"
override def sqlExpr: String = QName.apply(fullName, None).sqlExpr

override def withQualifier(newQualifier: Option[String]): Attribute = {
this.copy(qualifier = newQualifier)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ object SQLGenerator extends LogSupport {
case c: CTERelationRef =>
c.name
case TableRef(t, _) =>
printExpression(t)
printNameWithQuotationsIfNeeded(t.fullName)
case t: TableScan =>
t.fullName
printNameWithQuotationsIfNeeded(t.fullName)
case Limit(in, l, _) =>
val s = seqBuilder
s += printRelation(in, context)
Expand Down Expand Up @@ -379,15 +379,15 @@ object SQLGenerator extends LogSupport {
s"(${printExpression(expr)})"
case a: Alias =>
val e = printExpression(a.expr)
s"${e} AS ${a.name}"
s"${e} AS ${printNameWithQuotationsIfNeeded(a.name)}"
case SingleColumn(ex, _, _) =>
printExpression(ex)
case m: MultiSourceColumn =>
m.sqlExpr
case a: AllColumns =>
a.fullName
case a: Attribute =>
a.fullName
printNameWithQuotationsIfNeeded(a.fullName)
case SortItem(key, ordering, nullOrdering, _) =>
val k = printExpression(key)
val o = ordering.map(x => s" ${x}").getOrElse("")
Expand Down Expand Up @@ -517,4 +517,8 @@ object SQLGenerator extends LogSupport {
case other => unknown(other)
}
}

private def printNameWithQuotationsIfNeeded(name: String): String = {
QName.apply(name, None).sqlExpr
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,15 +409,15 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito

override def visitSelectSingle(ctx: SelectSingleContext): Attribute = {
val alias = Option(ctx.AS())
.map(x => expression(ctx.identifier()))
.orElse(Option(ctx.identifier()).map(expression(_)))
.map(_ => visitIdentifier(ctx.identifier()))
.orElse(Option(ctx.identifier()).map(visitIdentifier))
val child = expression(ctx.expression())
val qualifier = child match {
case a: Attribute => a.qualifier
case _ => None
}
SingleColumn(child, qualifier, getLocation(ctx))
.withAlias(alias.map(_.sqlExpr))
.withAlias(alias.map(_.value))
}

override def visitExpression(ctx: ExpressionContext): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,5 @@
select SUBSTRING(a FROM 1 FOR 5) FROM b
- sql: |
select user_agent || 'x', count(*) from impression group by 1
- sql: |
select * FROM "café"
Original file line number Diff line number Diff line change
Expand Up @@ -1066,4 +1066,8 @@ class TypeResolverTest extends AirSpec with ResolverTestHelper {
// No ambiguity error
}

test("Resolve quoted identifiers") {
analyze("select \"prénom\" from (select name as \"prénom\" from A)")
// No error
}
}

0 comments on commit e28316a

Please sign in to comment.