Skip to content

Commit

Permalink
[bug] proxy: current way of detecting txn status is not robust
Browse files Browse the repository at this point in the history
Reason: current way of detecting a connection is in active txn
or not is by parse the SQL statement from client. If it matches
begin, commit or rollback, we change the txn status. This way is
not robust, because the statement may contains some comment strings.

Fix: check the txn status by the OK and EOF packet returned from server.
Those packets contains the txn status. But currently, MO server
does not set the status in OK and EOF packets correctly. So first,
set the correty status in OK and EOF packets, then set the txn status
in proxy session.
  • Loading branch information
volgariver6 committed Jul 7, 2023
1 parent d5bc8ba commit 1d49fe1
Show file tree
Hide file tree
Showing 21 changed files with 353 additions and 348 deletions.
2 changes: 1 addition & 1 deletion pkg/frontend/cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (bse *baseStmtExecutor) ResponseBeforeExec(ctx context.Context, ses *Sessio
func (bse *baseStmtExecutor) ResponseAfterExec(ctx context.Context, ses *Session) error {
var err, retErr error
if bse.GetStatus() == stmtExecSuccess {
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, bse.GetServerStatus(), int(COM_QUERY), "")
if err = ses.GetMysqlProtocol().SendResponse(ctx, resp); err != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
logStatementStatus(ctx, ses, bse.GetAst(), fail, retErr)
Expand Down
4 changes: 4 additions & 0 deletions pkg/frontend/computation_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ func (cwft *TxnComputationWrapper) GetAffectedRows() uint64 {
return cwft.compile.GetAffectedRows()
}

func (cwft *TxnComputationWrapper) GetServerStatus() uint16 {
return cwft.ses.GetServerStatus()
}

func (cwft *TxnComputationWrapper) Compile(requestCtx context.Context, u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
var err error
defer RecordStatementTxnID(requestCtx, cwft.ses)
Expand Down
65 changes: 23 additions & 42 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func (mce *MysqlCmdExecutor) handleSelectVariables(ve *tree.VarExpr, cwIndex, cw
mrs.AddRow(row)

mer := NewMysqlExecutionResult(0, 0, 0, 0, mrs)
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.GetRequestContext(), resp); err != nil {
return moerr.NewInternalError(ses.GetRequestContext(), "routine send response failed.")
Expand Down Expand Up @@ -817,7 +817,7 @@ func (mce *MysqlCmdExecutor) handleShowErrors(cwIndex, cwsLen int) error {
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -961,7 +961,7 @@ func (mce *MysqlCmdExecutor) handleShowVariables(sv *tree.ShowVariables, proc *p
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1270,7 +1270,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
return err
}

resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())

if len(results) == 0 {
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
Expand All @@ -1279,7 +1279,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
} else {
for i, result := range results {
mer := NewMysqlExecutionResult(0, 0, 0, 0, result.(*MysqlResultSet))
resp = SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
resp = mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
}
Expand Down Expand Up @@ -1335,7 +1335,7 @@ func (mce *MysqlCmdExecutor) handleKill(ctx context.Context, k *tree.Kill) error
if err != nil {
return err
}
resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())
if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
}
Expand All @@ -1352,7 +1352,7 @@ func (mce *MysqlCmdExecutor) handleShowAccounts(ctx context.Context, sa *tree.Sh
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand All @@ -1369,7 +1369,7 @@ func (mce *MysqlCmdExecutor) handleShowSubscriptions(ctx context.Context, ss *tr
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1455,7 +1455,7 @@ func (mce *MysqlCmdExecutor) handleShowBackendServers(ctx context.Context, cwInd
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed, error: %v ", err)
}
Expand Down Expand Up @@ -3504,26 +3504,7 @@ func (mce *MysqlCmdExecutor) doComQueryInProgress(requestCtx context.Context, in
}

func (mce *MysqlCmdExecutor) setResponse(cwIndex, cwsLen int, rspLen uint64) *Response {

//if the stmt has next stmt, should set the server status equals to 10
if cwIndex < cwsLen-1 {
return NewOkResponse(rspLen, 0, 0, SERVER_MORE_RESULTS_EXISTS, int(COM_QUERY), "")
} else {
return NewOkResponse(rspLen, 0, 0, 0, int(COM_QUERY), "")
}

}

func SetNewResponse(category int, status uint16, cmd int, d interface{}, cwIndex, cwsLen int) *Response {

//if the stmt has next stmt, should set the server status equals to 10
var resp *Response
if cwIndex < cwsLen-1 {
resp = NewResponse(category, SERVER_MORE_RESULTS_EXISTS, cmd, d)
} else {
resp = NewResponse(category, 0, cmd, d)
}
return resp
return mce.ses.SetNewResponse(OkResponse, rspLen, int(COM_QUERY), "", cwIndex, cwsLen)
}

// ExecRequest the server execute the commands from the client following the mysql's routine
Expand All @@ -3533,9 +3514,9 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
moe, ok := e.(*moerr.Error)
if !ok {
err = moerr.ConvertPanicError(requestCtx, e)
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
} else {
resp = NewGeneralErrorResponse(COM_QUERY, moe)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), moe)
}
}
}()
Expand All @@ -3559,7 +3540,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(SubStringFromBegin(query, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))))
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
}
return resp, nil
case COM_INIT_DB:
Expand All @@ -3568,7 +3549,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := "use `" + dbname + "`"
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_INIT_DB, err)
resp = NewGeneralErrorResponse(COM_INIT_DB, mce.ses.GetServerStatus(), err)
}

return resp, nil
Expand All @@ -3578,12 +3559,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := makeCmdFieldListSql(payload)
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_FIELD_LIST, err)
resp = NewGeneralErrorResponse(COM_FIELD_LIST, mce.ses.GetServerStatus(), err)
}

return resp, nil
case COM_PING:
resp = NewGeneralOkResponse(COM_PING)
resp = NewGeneralOkResponse(COM_PING, mce.ses.GetServerStatus())

return resp, nil

Expand All @@ -3600,7 +3581,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, err)
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3610,11 +3591,11 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
var prepareStmt *PrepareStmt
sql, prepareStmt, err = mce.parseStmtExecute(requestCtx, data)
if err != nil {
return NewGeneralErrorResponse(COM_STMT_EXECUTE, err), nil
return NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err), nil
}
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, err)
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err)
}
if prepareStmt.params != nil {
prepareStmt.params.GetNulls().Reset()
Expand All @@ -3629,7 +3610,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
data := req.GetData().([]byte)
err = mce.parseStmtSendLongData(requestCtx, data)
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, err)
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, mce.ses.GetServerStatus(), err)
return resp, nil
}
return nil, nil
Expand All @@ -3645,7 +3626,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, err)
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3659,12 +3640,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(sql))
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_RESET, err)
resp = NewGeneralErrorResponse(COM_STMT_RESET, mce.ses.GetServerStatus(), err)
}
return resp, nil

default:
resp = NewGeneralErrorResponse(req.GetCmd(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
resp = NewGeneralErrorResponse(req.GetCmd(), mce.ses.GetServerStatus(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
}
return resp, nil
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/frontend/mysql_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,10 @@ func (mp *MysqlProtocolImpl) readIntLenEnc(data []byte, pos int) (uint64, int, b
return uint64(data[pos]), pos + 1, true
}

func (mp *MysqlProtocolImpl) ReadIntLenEnc(data []byte, pos int) (uint64, int, bool) {
return mp.readIntLenEnc(data, pos)
}

// write an int with length encoded into the buffer at the position
// return position + the count of bytes for length encoded (1 or 3 or 4 or 9)
func (mp *MysqlProtocolImpl) writeIntLenEnc(data []byte, pos int, value uint64) int {
Expand Down
1 change: 1 addition & 0 deletions pkg/frontend/mysql_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,7 @@ func (tRM *TestRoutineManager) resultsetHandler(rs goetty.IOSession, msg interfa
case COM_PING:
resp = NewResponse(
OkResponse,
0, 0, 0,
0,
int(COM_PING),
nil,
Expand Down
33 changes: 13 additions & 20 deletions pkg/frontend/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,28 @@ type Response struct {
warnings uint16
}

func NewResponse(category int, status uint16, cmd int, d interface{}) *Response {
func NewResponse(category int, affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
return &Response{
category: category,
status: status,
cmd: cmd,
data: d,
category: category,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
status: status,
cmd: cmd,
data: d,
}
}

func NewGeneralErrorResponse(cmd CommandType, err error) *Response {
return NewResponse(ErrorResponse, 0, int(cmd), err)
func NewGeneralErrorResponse(cmd CommandType, status uint16, err error) *Response {
return NewResponse(ErrorResponse, 0, 0, 0, status, int(cmd), err)
}

func NewGeneralOkResponse(cmd CommandType) *Response {
return NewResponse(OkResponse, 0, int(cmd), nil)
func NewGeneralOkResponse(cmd CommandType, status uint16) *Response {
return NewResponse(OkResponse, 0, 0, 0, status, int(cmd), nil)
}

func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
resp := &Response{
category: OkResponse,
status: status,
cmd: cmd,
data: d,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
}

return resp
return NewResponse(OkResponse, affectedRows, lastInsertId, warnings, status, cmd, d)
}

func (resp *Response) GetData() interface{} {
Expand Down
12 changes: 12 additions & 0 deletions pkg/frontend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1945,6 +1945,18 @@ func (ses *Session) getGlobalSystemVariableValue(varName string) (interface{}, e
return nil, moerr.NewInternalError(ctx, "can not resolve global system variable %s", varName)
}

func (ses *Session) SetNewResponse(category int, affectedRows uint64, cmd int, d interface{}, cwIndex, cwsLen int) *Response {
// If the stmt has next stmt, should add SERVER_MORE_RESULTS_EXISTS to the server status.
var resp *Response
if cwIndex < cwsLen-1 {
resp = NewResponse(category, affectedRows, 0, 0,
ses.GetServerStatus()|SERVER_MORE_RESULTS_EXISTS, cmd, d)
} else {
resp = NewResponse(category, affectedRows, 0, 0, ses.GetServerStatus(), cmd, d)
}
return resp
}

func checkPlanIsInsertValues(proc *process.Process,
p *plan.Plan) (bool, [][]colexec.ExpressionExecutor) {
qry := p.GetQuery()
Expand Down
8 changes: 4 additions & 4 deletions pkg/frontend/status_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (pse *PrepareStmtExecutor) ResponseAfterExec(ctx context.Context, ses *Sess
return retErr
}
} else {
resp := NewOkResponse(pse.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(pse.GetAffectedRows(), 0, 0, pse.GetServerStatus(), int(COM_QUERY), "")
if err2 = ses.GetMysqlProtocol().SendResponse(ctx, resp); err2 != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err2)
logStatementStatus(ctx, ses, pse.GetAst(), fail, retErr)
Expand Down Expand Up @@ -146,7 +146,7 @@ func (pse *PrepareStringExecutor) ResponseAfterExec(ctx context.Context, ses *Se
return retErr
}
} else {
resp := NewOkResponse(pse.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(pse.GetAffectedRows(), 0, 0, pse.GetServerStatus(), int(COM_QUERY), "")
if err2 = ses.GetMysqlProtocol().SendResponse(ctx, resp); err2 != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err2)
logStatementStatus(ctx, ses, pse.GetAst(), fail, retErr)
Expand Down Expand Up @@ -175,7 +175,7 @@ func (de *DeallocateExecutor) ResponseAfterExec(ctx context.Context, ses *Sessio
var err2, retErr error
//we will not send response in COM_STMT_CLOSE command
if ses.GetCmd() != COM_STMT_CLOSE {
resp := NewOkResponse(de.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(de.GetAffectedRows(), 0, 0, de.GetServerStatus(), int(COM_QUERY), "")
if err2 = ses.GetMysqlProtocol().SendResponse(ctx, resp); err2 != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err2)
logStatementStatus(ctx, ses, de.GetAst(), fail, retErr)
Expand Down Expand Up @@ -455,7 +455,7 @@ type InsertExecutor struct {
func (ie *InsertExecutor) ResponseAfterExec(ctx context.Context, ses *Session) error {
var err, retErr error
if ie.GetStatus() == stmtExecSuccess {
resp := NewOkResponse(ie.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(ie.GetAffectedRows(), 0, 0, ie.GetServerStatus(), int(COM_QUERY), "")
resp.lastInsertId = 1
if err = ses.GetMysqlProtocol().SendResponse(ctx, resp); err != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand Down
14 changes: 14 additions & 0 deletions pkg/frontend/test/types_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/frontend/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ type ComputationWrapper interface {
RecordExecPlan(ctx context.Context) error

GetLoadTag() bool

GetServerStatus() uint16
}

type ColumnInfo interface {
Expand Down
Loading

0 comments on commit 1d49fe1

Please sign in to comment.