Skip to content

Commit

Permalink
Support DECLARE in PG DO statement
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengguanLi committed Jun 24, 2024
1 parent 3c0f503 commit ba3c23a
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ public SQLStatement parseFunction() {

if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
parserParameters(stmt.getParameters(), stmt);
parseParameters(stmt.getParameters(), stmt);
accept(Token.RPAREN);
}

Expand Down Expand Up @@ -1921,7 +1921,7 @@ public SQLStatement parseBlock() {
}

if (lexer.token() == Token.IDENTIFIER || lexer.token() == Token.CURSOR) {
parserParameters(block.getParameters(), block);
parseParameters(block.getParameters(), block);
for (SQLParameter param : block.getParameters()) {
param.setParent(block);
}
Expand Down Expand Up @@ -1965,7 +1965,7 @@ public SQLStatement parseBlock() {
return block;
}

protected void parserParameters(List<SQLParameter> parameters, SQLObject parent) {
private void parseParameters(List<SQLParameter> parameters, SQLObject parent) {
for (; ; ) {
SQLParameter parameter = new SQLParameter();
parameter.setParent(parent);
Expand Down Expand Up @@ -1994,7 +1994,7 @@ protected void parserParameters(List<SQLParameter> parameters, SQLObject parent)

if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
this.parserParameters(parameter.getCursorParameters(), parameter);
this.parseParameters(parameter.getCursorParameters(), parameter);
accept(Token.RPAREN);
}

Expand Down Expand Up @@ -2154,7 +2154,7 @@ protected void parserParameters(List<SQLParameter> parameters, SQLObject parent)
accept(Token.IDENTIFIER);
if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
this.parserParameters(functionDataType.getParameters(), functionDataType);
this.parseParameters(functionDataType.getParameters(), functionDataType);
accept(Token.RPAREN);
}
accept(Token.RETURN);
Expand All @@ -2180,7 +2180,7 @@ protected void parserParameters(List<SQLParameter> parameters, SQLObject parent)
accept(Token.IDENTIFIER);
if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
this.parserParameters(procedureDataType.getParameters(), procedureDataType);
this.parseParameters(procedureDataType.getParameters(), procedureDataType);
accept(Token.RPAREN);
}

Expand Down Expand Up @@ -2803,7 +2803,7 @@ public SQLCreateProcedureStatement parseCreateProcedure() {

if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
parserParameters(stmt.getParameters(), stmt);
parseParameters(stmt.getParameters(), stmt);
accept(Token.RPAREN);
}

Expand Down Expand Up @@ -3072,7 +3072,7 @@ public SQLStatement parseCreateType() {
}

if (lexer.identifierEquals(FnvHash.Constants.STATIC)) {
this.parserParameters(stmt.getParameters(), stmt);
this.parseParameters(stmt.getParameters(), stmt);
} else if (lexer.token() == Token.TABLE) {
lexer.nextToken();
accept(Token.OF);
Expand Down Expand Up @@ -3106,11 +3106,11 @@ public SQLStatement parseCreateType() {
} else {
if (lexer.token() == Token.LPAREN) {
lexer.nextToken();
this.parserParameters(stmt.getParameters(), stmt);
this.parseParameters(stmt.getParameters(), stmt);
stmt.setParen(true);
accept(Token.RPAREN);
} else {
this.parserParameters(stmt.getParameters(), stmt);
this.parseParameters(stmt.getParameters(), stmt);
if (lexer.token() == Token.END) {
lexer.nextToken();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,40 @@
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.SQLStatementImpl;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.dialect.postgresql.visitor.PGASTVisitor;
import com.alibaba.druid.sql.visitor.SQLASTVisitor;

public class PGDoStatement extends SQLStatementImpl implements PGSQLStatement {
private boolean isDollarQuoted;
private SQLName name;

private SQLName funcName;

private SQLStatement block;
private SQLIdentifierExpr language;

public PGDoStatement() {
isDollarQuoted = true;
}

protected void accept0(SQLASTVisitor visitor) {
accept0((PGASTVisitor) visitor);
}

@Override
public void accept0(PGASTVisitor visitor) {
if (visitor.visit(this)) {
acceptChild(visitor, funcName);
acceptChild(visitor, block);
}
visitor.endVisit(this);
}

public boolean isDollarQuoted() {
return isDollarQuoted;
}

public void setDollarQuoted(boolean dollarQuoted) {
this.isDollarQuoted = dollarQuoted;
}

public SQLName getName() {
return name;
}
Expand All @@ -49,14 +60,6 @@ public void setName(SQLName name) {
this.name = name;
}

public SQLName getFuncName() {
return funcName;
}

public void setFuncName(SQLName funcName) {
this.funcName = funcName;
}

public SQLStatement getBlock() {
return block;
}
Expand All @@ -67,4 +70,12 @@ public void setBlock(SQLStatement block) {
}
this.block = block;
}

public SQLIdentifierExpr getLanguage() {
return language;
}

public void setLanguage(SQLIdentifierExpr language) {
this.language = language;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class PGLexer extends Lexer {
map.put("MATCHED", Token.MATCHED);
map.put("PARTITION", Token.PARTITION);
map.put("INTERVAL", Token.INTERVAL);
map.put("LANGUAGE", Token.LANGUAGE);

DEFAULT_PG_KEYWORDS = new Keywords(map);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
*/
package com.alibaba.druid.sql.dialect.postgresql.parser;

import com.alibaba.druid.sql.ast.SQLDataType;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.SQLParameter;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.*;
import com.alibaba.druid.sql.ast.statement.*;
Expand Down Expand Up @@ -414,55 +417,147 @@ public PGDoStatement parseDo() {
stmt.setDbType(dbType);

accept(Token.DO);

stmt.setFuncName(this.exprParser.name());

if (lexer.token() == Token.DECLARE) {
parseVariables(stmt);
if (lexer.token() == Token.DOLLAR_DOLLAR) {
stmt.setDollarQuoted(true);
lexer.nextToken();
}


String labelName = null;
if (lexer.token() == Token.IDENTIFIER) {
labelName = lexer.stringVal();
lexer.nextToken();
}

SQLStatement block;
if (lexer.token() == Token.BEGIN) {
block = this.parseBlock();
if (lexer.token() == Token.BEGIN
|| lexer.token() == Token.DECLARE) {
block = this.parseBlock(labelName);
} else {
block = this.parseStatement();
}
stmt.setBlock(block);
if (lexer.token() == Token.IDENTIFIER) {
SQLName endFuncName = this.exprParser.name();
if (!stmt.getFuncName().equals(endFuncName)) {
printError(lexer.token());

if (lexer.token() == Token.DOLLAR_DOLLAR) {
lexer.nextToken();
if(lexer.token() != Token.SEMI) {
accept(Token.LANGUAGE);
stmt.setLanguage(this.exprParser.identifier());
}
accept(Token.SEMI);
}

return stmt;
}

public void parseVariables(PGDoStatement stmt) {
accept(Token.DECLARE);
if (lexer.token() != Token.BEGIN) {
// todo: parseVariables
throw new ParserException("TODO " + lexer.info());
}
}

public SQLBlockStatement parseBlock() {
public SQLBlockStatement parseBlock(String labelName) {
SQLBlockStatement block = new SQLBlockStatement();
block.setDbType(dbType);
block.setHaveBeginEnd(false);
if (labelName != null) {
block.setLabelName(labelName);
}

if (lexer.token() == Token.DECLARE) {
lexer.nextToken();
}
if (lexer.token() == Token.IDENTIFIER || lexer.token() == Token.CURSOR) {
parseParameters(block.getParameters(), block);
for (SQLParameter param : block.getParameters()) {
param.setParent(block);
}
}

accept(Token.BEGIN);
List<SQLStatement> statementList = block.getStatementList();
this.parseStatementList(statementList, -1, block);
if (lexer.token() != Token.END
&& statementList.size() > 0
&& !statementList.isEmpty()
&& (statementList.get(statementList.size() - 1) instanceof SQLCommitStatement
|| statementList.get(statementList.size() - 1) instanceof SQLRollbackStatement)) {
block.setEndOfCommit(true);
return block;
}
accept(Token.END);

Token token = lexer.token();
if (token != Token.SEMI) {
if (lexer.token() == Token.IDENTIFIER) {
labelName = lexer.stringVal();
if (!block.getLabelName().equals(labelName)) {
printError(lexer.token());
}
}
}

accept(Token.SEMI);
return block;
}

private void parseParameters(List<SQLParameter> parameters, SQLObject parent) {
for (;;) {
SQLParameter parameter = new SQLParameter();
parameter.setParent(parent);

SQLName name;
SQLDataType dataType = null;
name = this.exprParser.name();
if (lexer.token() == Token.IN) {
lexer.nextToken();

if (lexer.token() == Token.OUT) {
lexer.nextToken();
parameter.setParamType(SQLParameter.ParameterType.INOUT);
} else {
parameter.setParamType(SQLParameter.ParameterType.IN);
}
} else if (lexer.token() == Token.OUT) {
lexer.nextToken();

if (lexer.token() == Token.IN) {
lexer.nextToken();
parameter.setParamType(SQLParameter.ParameterType.INOUT);
} else {
parameter.setParamType(SQLParameter.ParameterType.OUT);
}
} else if (lexer.token() == Token.INOUT) {
lexer.nextToken();
parameter.setParamType(SQLParameter.ParameterType.INOUT);
}

dataType = this.exprParser.parseDataType(false);

if (lexer.token() == Token.NOT) {
lexer.nextToken();
accept(Token.NULL);
parameter.setNotNull(true);
}

if (lexer.token() == Token.COLONEQ || lexer.token() == Token.DEFAULT) {
lexer.nextToken();
parameter.setDefaultValue(this.exprParser.expr());
}

parameter.setName(name);
parameter.setDataType(dataType);

parameters.add(parameter);
Token token = lexer.token();
if (token == Token.COMMA || token == Token.SEMI || token == Token.IS) {
lexer.nextToken();
}

token = lexer.token();
if (token != Token.BEGIN
&& token != Token.RPAREN
&& token != Token.EOF
&& token != Token.FUNCTION
&& !lexer.identifierEquals("DETERMINISTIC")) {
continue;
}

break;
}
}

@Override
public SQLStatement parseIf() {
accept(Token.IF);
Expand Down
Loading

0 comments on commit ba3c23a

Please sign in to comment.