Skip to content

Commit

Permalink
[bug] proxy: current way of detecting txn status is not robust
Browse files Browse the repository at this point in the history
Reason: current way of detecting a connection is in active txn
or not is by parse the SQL statement from client. If it matches
begin, commit or rollback, we change the txn status. This way is
not robust, because the statement may contains some comment strings.

Fix: check the txn status by the OK and EOF packet returned from server.
Those packets contains the txn status. But currently, MO server
does not set the status in OK and EOF packets correctly. So first,
set the correty status in OK and EOF packets, then set the txn status
in proxy session.
  • Loading branch information
volgariver6 committed Jul 7, 2023
1 parent a45f309 commit b0c5821
Show file tree
Hide file tree
Showing 22 changed files with 372 additions and 360 deletions.
2 changes: 1 addition & 1 deletion pkg/frontend/cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (bse *baseStmtExecutor) ResponseBeforeExec(ctx context.Context, ses *Sessio
func (bse *baseStmtExecutor) ResponseAfterExec(ctx context.Context, ses *Session) error {
var err, retErr error
if bse.GetStatus() == stmtExecSuccess {
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, bse.GetServerStatus(), int(COM_QUERY), "")
if err = ses.GetMysqlProtocol().SendResponse(ctx, resp); err != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
logStatementStatus(ctx, ses, bse.GetAst(), fail, retErr)
Expand Down
4 changes: 4 additions & 0 deletions pkg/frontend/computation_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ func (cwft *TxnComputationWrapper) GetAffectedRows() uint64 {
return cwft.compile.GetAffectedRows()
}

func (cwft *TxnComputationWrapper) GetServerStatus() uint16 {
return cwft.ses.GetServerStatus()
}

func (cwft *TxnComputationWrapper) Compile(requestCtx context.Context, u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
var err error
defer RecordStatementTxnID(requestCtx, cwft.ses)
Expand Down
79 changes: 30 additions & 49 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func (mce *MysqlCmdExecutor) handleSelectVariables(ve *tree.VarExpr, cwIndex, cw
mrs.AddRow(row)

mer := NewMysqlExecutionResult(0, 0, 0, 0, mrs)
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.GetRequestContext(), resp); err != nil {
return moerr.NewInternalError(ses.GetRequestContext(), "routine send response failed.")
Expand Down Expand Up @@ -625,7 +625,7 @@ func (mce *MysqlCmdExecutor) handleCmdFieldList(requestCtx context.Context, icfl
mysql CMD_FIELD_LIST response: End after the column has been sent.
send EOF packet
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -817,7 +817,7 @@ func (mce *MysqlCmdExecutor) handleShowErrors(cwIndex, cwsLen int) error {
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -961,7 +961,7 @@ func (mce *MysqlCmdExecutor) handleShowVariables(sv *tree.ShowVariables, proc *p
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1069,7 +1069,7 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(requestCtx context.Context, stmt

// mysql COM_QUERY response: End after the column has been sent.
// send EOF packet
err = protocol.SendEOFPacketIf(0, 0)
err = protocol.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand All @@ -1079,7 +1079,7 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(requestCtx context.Context, stmt
return err
}

err = protocol.sendEOFOrOkPacket(0, 0)
err = protocol.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -1270,7 +1270,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
return err
}

resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())

if len(results) == 0 {
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
Expand All @@ -1279,7 +1279,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
} else {
for i, result := range results {
mer := NewMysqlExecutionResult(0, 0, 0, 0, result.(*MysqlResultSet))
resp = SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
resp = mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
}
Expand Down Expand Up @@ -1335,7 +1335,7 @@ func (mce *MysqlCmdExecutor) handleKill(ctx context.Context, k *tree.Kill) error
if err != nil {
return err
}
resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())
if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
}
Expand All @@ -1352,7 +1352,7 @@ func (mce *MysqlCmdExecutor) handleShowAccounts(ctx context.Context, sa *tree.Sh
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand All @@ -1369,7 +1369,7 @@ func (mce *MysqlCmdExecutor) handleShowSubscriptions(ctx context.Context, ss *tr
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1455,7 +1455,7 @@ func (mce *MysqlCmdExecutor) handleShowBackendServers(ctx context.Context, cwInd
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed, error: %v ", err)
}
Expand Down Expand Up @@ -3030,7 +3030,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the column has been sent.
send EOF packet
*/
err = proto.SendEOFPacketIf(0, 0)
err = proto.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3083,7 +3083,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the data row has been sent.
After all row data has been sent, it sends the EOF or OK packet.
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3201,7 +3201,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the column has been sent.
send EOF packet
*/
err = proto.SendEOFPacketIf(0, 0)
err = proto.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3247,7 +3247,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the data row has been sent.
After all row data has been sent, it sends the EOF or OK packet.
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3504,26 +3504,7 @@ func (mce *MysqlCmdExecutor) doComQueryInProgress(requestCtx context.Context, in
}

func (mce *MysqlCmdExecutor) setResponse(cwIndex, cwsLen int, rspLen uint64) *Response {

//if the stmt has next stmt, should set the server status equals to 10
if cwIndex < cwsLen-1 {
return NewOkResponse(rspLen, 0, 0, SERVER_MORE_RESULTS_EXISTS, int(COM_QUERY), "")
} else {
return NewOkResponse(rspLen, 0, 0, 0, int(COM_QUERY), "")
}

}

func SetNewResponse(category int, status uint16, cmd int, d interface{}, cwIndex, cwsLen int) *Response {

//if the stmt has next stmt, should set the server status equals to 10
var resp *Response
if cwIndex < cwsLen-1 {
resp = NewResponse(category, SERVER_MORE_RESULTS_EXISTS, cmd, d)
} else {
resp = NewResponse(category, 0, cmd, d)
}
return resp
return mce.ses.SetNewResponse(OkResponse, rspLen, int(COM_QUERY), "", cwIndex, cwsLen)
}

// ExecRequest the server execute the commands from the client following the mysql's routine
Expand All @@ -3533,9 +3514,9 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
moe, ok := e.(*moerr.Error)
if !ok {
err = moerr.ConvertPanicError(requestCtx, e)
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
} else {
resp = NewGeneralErrorResponse(COM_QUERY, moe)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), moe)
}
}
}()
Expand All @@ -3559,7 +3540,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(SubStringFromBegin(query, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))))
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
}
return resp, nil
case COM_INIT_DB:
Expand All @@ -3568,7 +3549,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := "use `" + dbname + "`"
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_INIT_DB, err)
resp = NewGeneralErrorResponse(COM_INIT_DB, mce.ses.GetServerStatus(), err)
}

return resp, nil
Expand All @@ -3578,12 +3559,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := makeCmdFieldListSql(payload)
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_FIELD_LIST, err)
resp = NewGeneralErrorResponse(COM_FIELD_LIST, mce.ses.GetServerStatus(), err)
}

return resp, nil
case COM_PING:
resp = NewGeneralOkResponse(COM_PING)
resp = NewGeneralOkResponse(COM_PING, mce.ses.GetServerStatus())

return resp, nil

Expand All @@ -3600,7 +3581,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, err)
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3610,11 +3591,11 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
var prepareStmt *PrepareStmt
sql, prepareStmt, err = mce.parseStmtExecute(requestCtx, data)
if err != nil {
return NewGeneralErrorResponse(COM_STMT_EXECUTE, err), nil
return NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err), nil
}
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, err)
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err)
}
if prepareStmt.params != nil {
prepareStmt.params.GetNulls().Reset()
Expand All @@ -3629,7 +3610,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
data := req.GetData().([]byte)
err = mce.parseStmtSendLongData(requestCtx, data)
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, err)
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, mce.ses.GetServerStatus(), err)
return resp, nil
}
return nil, nil
Expand All @@ -3645,7 +3626,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, err)
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3659,12 +3640,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(sql))
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_RESET, err)
resp = NewGeneralErrorResponse(COM_STMT_RESET, mce.ses.GetServerStatus(), err)
}
return resp, nil

default:
resp = NewGeneralErrorResponse(req.GetCmd(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
resp = NewGeneralErrorResponse(req.GetCmd(), mce.ses.GetServerStatus(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
}
return resp, nil
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/frontend/mysql_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *Prep
}
}
if numParams > 0 {
if err := mp.SendEOFPacketIf(0, 0); err != nil {
if err := mp.SendEOFPacketIf(0, mp.GetSession().GetServerStatus()); err != nil {
return err
}
}
Expand All @@ -487,7 +487,7 @@ func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *Prep
}
}
if numColumns > 0 {
if err := mp.SendEOFPacketIf(0, 0); err != nil {
if err := mp.SendEOFPacketIf(0, mp.GetSession().GetServerStatus()); err != nil {
return err
}
}
Expand Down Expand Up @@ -871,6 +871,10 @@ func (mp *MysqlProtocolImpl) readIntLenEnc(data []byte, pos int) (uint64, int, b
return uint64(data[pos]), pos + 1, true
}

func (mp *MysqlProtocolImpl) ReadIntLenEnc(data []byte, pos int) (uint64, int, bool) {
return mp.readIntLenEnc(data, pos)
}

// write an int with length encoded into the buffer at the position
// return position + the count of bytes for length encoded (1 or 3 or 4 or 9)
func (mp *MysqlProtocolImpl) writeIntLenEnc(data []byte, pos int, value uint64) int {
Expand Down
2 changes: 2 additions & 0 deletions pkg/frontend/mysql_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,7 @@ func (tRM *TestRoutineManager) resultsetHandler(rs goetty.IOSession, msg interfa
case COM_PING:
resp = NewResponse(
OkResponse,
0, 0, 0,
0,
int(COM_PING),
nil,
Expand Down Expand Up @@ -2058,6 +2059,7 @@ func TestSendPrepareResponse(t *testing.T) {
}

proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
proto.SetSession(&Session{})

st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?, 1")
stmts, err := mysql.Parse(ctx, st.Sql, 1)
Expand Down
33 changes: 13 additions & 20 deletions pkg/frontend/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,28 @@ type Response struct {
warnings uint16
}

func NewResponse(category int, status uint16, cmd int, d interface{}) *Response {
func NewResponse(category int, affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
return &Response{
category: category,
status: status,
cmd: cmd,
data: d,
category: category,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
status: status,
cmd: cmd,
data: d,
}
}

func NewGeneralErrorResponse(cmd CommandType, err error) *Response {
return NewResponse(ErrorResponse, 0, int(cmd), err)
func NewGeneralErrorResponse(cmd CommandType, status uint16, err error) *Response {
return NewResponse(ErrorResponse, 0, 0, 0, status, int(cmd), err)
}

func NewGeneralOkResponse(cmd CommandType) *Response {
return NewResponse(OkResponse, 0, int(cmd), nil)
func NewGeneralOkResponse(cmd CommandType, status uint16) *Response {
return NewResponse(OkResponse, 0, 0, 0, status, int(cmd), nil)
}

func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
resp := &Response{
category: OkResponse,
status: status,
cmd: cmd,
data: d,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
}

return resp
return NewResponse(OkResponse, affectedRows, lastInsertId, warnings, status, cmd, d)
}

func (resp *Response) GetData() interface{} {
Expand Down
Loading

0 comments on commit b0c5821

Please sign in to comment.