Skip to content

Commit

Permalink
[bug] proxy: current way of detecting txn status is not robust (#10486)
Browse files Browse the repository at this point in the history
  • Loading branch information
volgariver6 authored Jul 7, 2023
1 parent af32411 commit 8870ea8
Show file tree
Hide file tree
Showing 22 changed files with 372 additions and 360 deletions.
2 changes: 1 addition & 1 deletion pkg/frontend/cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (bse *baseStmtExecutor) ResponseBeforeExec(ctx context.Context, ses *Sessio
func (bse *baseStmtExecutor) ResponseAfterExec(ctx context.Context, ses *Session) error {
var err, retErr error
if bse.GetStatus() == stmtExecSuccess {
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, 0, int(COM_QUERY), "")
resp := NewOkResponse(bse.GetAffectedRows(), 0, 0, bse.GetServerStatus(), int(COM_QUERY), "")
if err = ses.GetMysqlProtocol().SendResponse(ctx, resp); err != nil {
retErr = moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
logStatementStatus(ctx, ses, bse.GetAst(), fail, retErr)
Expand Down
4 changes: 4 additions & 0 deletions pkg/frontend/computation_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ func (cwft *TxnComputationWrapper) GetAffectedRows() uint64 {
return cwft.compile.GetAffectedRows()
}

func (cwft *TxnComputationWrapper) GetServerStatus() uint16 {
return cwft.ses.GetServerStatus()
}

func (cwft *TxnComputationWrapper) Compile(requestCtx context.Context, u interface{}, fill func(interface{}, *batch.Batch) error) (interface{}, error) {
var err error
defer RecordStatementTxnID(requestCtx, cwft.ses)
Expand Down
79 changes: 30 additions & 49 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func (mce *MysqlCmdExecutor) handleSelectVariables(ve *tree.VarExpr, cwIndex, cw
mrs.AddRow(row)

mer := NewMysqlExecutionResult(0, 0, 0, 0, mrs)
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.GetRequestContext(), resp); err != nil {
return moerr.NewInternalError(ses.GetRequestContext(), "routine send response failed.")
Expand Down Expand Up @@ -625,7 +625,7 @@ func (mce *MysqlCmdExecutor) handleCmdFieldList(requestCtx context.Context, icfl
mysql CMD_FIELD_LIST response: End after the column has been sent.
send EOF packet
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -817,7 +817,7 @@ func (mce *MysqlCmdExecutor) handleShowErrors(cwIndex, cwsLen int) error {
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -961,7 +961,7 @@ func (mce *MysqlCmdExecutor) handleShowVariables(sv *tree.ShowVariables, proc *p
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1069,7 +1069,7 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(requestCtx context.Context, stmt

// mysql COM_QUERY response: End after the column has been sent.
// send EOF packet
err = protocol.SendEOFPacketIf(0, 0)
err = protocol.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand All @@ -1079,7 +1079,7 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(requestCtx context.Context, stmt
return err
}

err = protocol.sendEOFOrOkPacket(0, 0)
err = protocol.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -1270,7 +1270,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
return err
}

resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())

if len(results) == 0 {
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
Expand All @@ -1279,7 +1279,7 @@ func (mce *MysqlCmdExecutor) handleCallProcedure(ctx context.Context, call *tree
} else {
for i, result := range results {
mer := NewMysqlExecutionResult(0, 0, 0, 0, result.(*MysqlResultSet))
resp = SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
resp = mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, i, len(results))
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed. error:%v ", err)
}
Expand Down Expand Up @@ -1335,7 +1335,7 @@ func (mce *MysqlCmdExecutor) handleKill(ctx context.Context, k *tree.Kill) error
if err != nil {
return err
}
resp := NewGeneralOkResponse(COM_QUERY)
resp := NewGeneralOkResponse(COM_QUERY, mce.ses.GetServerStatus())
if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
}
Expand All @@ -1352,7 +1352,7 @@ func (mce *MysqlCmdExecutor) handleShowAccounts(ctx context.Context, sa *tree.Sh
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand All @@ -1369,7 +1369,7 @@ func (mce *MysqlCmdExecutor) handleShowSubscriptions(ctx context.Context, ss *tr
return err
}
mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)

if err = proto.SendResponse(ctx, resp); err != nil {
return moerr.NewInternalError(ctx, "routine send response failed. error:%v ", err)
Expand Down Expand Up @@ -1455,7 +1455,7 @@ func (mce *MysqlCmdExecutor) handleShowBackendServers(ctx context.Context, cwInd
}

mer := NewMysqlExecutionResult(0, 0, 0, 0, ses.GetMysqlResultSet())
resp := SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
resp := mce.ses.SetNewResponse(ResultResponse, 0, int(COM_QUERY), mer, cwIndex, cwsLen)
if err := proto.SendResponse(ses.requestCtx, resp); err != nil {
return moerr.NewInternalError(ses.requestCtx, "routine send response failed, error: %v ", err)
}
Expand Down Expand Up @@ -3030,7 +3030,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the column has been sent.
send EOF packet
*/
err = proto.SendEOFPacketIf(0, 0)
err = proto.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3083,7 +3083,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the data row has been sent.
After all row data has been sent, it sends the EOF or OK packet.
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3201,7 +3201,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the column has been sent.
send EOF packet
*/
err = proto.SendEOFPacketIf(0, 0)
err = proto.SendEOFPacketIf(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3247,7 +3247,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
mysql COM_QUERY response: End after the data row has been sent.
After all row data has been sent, it sends the EOF or OK packet.
*/
err = proto.sendEOFOrOkPacket(0, 0)
err = proto.sendEOFOrOkPacket(0, ses.GetServerStatus())
if err != nil {
return err
}
Expand Down Expand Up @@ -3504,26 +3504,7 @@ func (mce *MysqlCmdExecutor) doComQueryInProgress(requestCtx context.Context, in
}

func (mce *MysqlCmdExecutor) setResponse(cwIndex, cwsLen int, rspLen uint64) *Response {

//if the stmt has next stmt, should set the server status equals to 10
if cwIndex < cwsLen-1 {
return NewOkResponse(rspLen, 0, 0, SERVER_MORE_RESULTS_EXISTS, int(COM_QUERY), "")
} else {
return NewOkResponse(rspLen, 0, 0, 0, int(COM_QUERY), "")
}

}

func SetNewResponse(category int, status uint16, cmd int, d interface{}, cwIndex, cwsLen int) *Response {

//if the stmt has next stmt, should set the server status equals to 10
var resp *Response
if cwIndex < cwsLen-1 {
resp = NewResponse(category, SERVER_MORE_RESULTS_EXISTS, cmd, d)
} else {
resp = NewResponse(category, 0, cmd, d)
}
return resp
return mce.ses.SetNewResponse(OkResponse, rspLen, int(COM_QUERY), "", cwIndex, cwsLen)
}

// ExecRequest the server execute the commands from the client following the mysql's routine
Expand All @@ -3533,9 +3514,9 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
moe, ok := e.(*moerr.Error)
if !ok {
err = moerr.ConvertPanicError(requestCtx, e)
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
} else {
resp = NewGeneralErrorResponse(COM_QUERY, moe)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), moe)
}
}
}()
Expand All @@ -3559,7 +3540,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(SubStringFromBegin(query, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))))
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_QUERY, err)
resp = NewGeneralErrorResponse(COM_QUERY, mce.ses.GetServerStatus(), err)
}
return resp, nil
case COM_INIT_DB:
Expand All @@ -3568,7 +3549,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := "use `" + dbname + "`"
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_INIT_DB, err)
resp = NewGeneralErrorResponse(COM_INIT_DB, mce.ses.GetServerStatus(), err)
}

return resp, nil
Expand All @@ -3578,12 +3559,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
query := makeCmdFieldListSql(payload)
err = doComQuery(requestCtx, &UserInput{sql: query})
if err != nil {
resp = NewGeneralErrorResponse(COM_FIELD_LIST, err)
resp = NewGeneralErrorResponse(COM_FIELD_LIST, mce.ses.GetServerStatus(), err)
}

return resp, nil
case COM_PING:
resp = NewGeneralOkResponse(COM_PING)
resp = NewGeneralOkResponse(COM_PING, mce.ses.GetServerStatus())

return resp, nil

Expand All @@ -3600,7 +3581,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, err)
resp = NewGeneralErrorResponse(COM_STMT_PREPARE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3610,11 +3591,11 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
var prepareStmt *PrepareStmt
sql, prepareStmt, err = mce.parseStmtExecute(requestCtx, data)
if err != nil {
return NewGeneralErrorResponse(COM_STMT_EXECUTE, err), nil
return NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err), nil
}
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, err)
resp = NewGeneralErrorResponse(COM_STMT_EXECUTE, mce.ses.GetServerStatus(), err)
}
if prepareStmt.params != nil {
prepareStmt.params.GetNulls().Reset()
Expand All @@ -3629,7 +3610,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
data := req.GetData().([]byte)
err = mce.parseStmtSendLongData(requestCtx, data)
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, err)
resp = NewGeneralErrorResponse(COM_STMT_SEND_LONG_DATA, mce.ses.GetServerStatus(), err)
return resp, nil
}
return nil, nil
Expand All @@ -3645,7 +3626,7 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio

err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, err)
resp = NewGeneralErrorResponse(COM_STMT_CLOSE, mce.ses.GetServerStatus(), err)
}
return resp, nil

Expand All @@ -3659,12 +3640,12 @@ func (mce *MysqlCmdExecutor) ExecRequest(requestCtx context.Context, ses *Sessio
logDebug(ses, ses.GetDebugString(), "query trace", logutil.ConnectionIdField(ses.GetConnectionID()), logutil.QueryField(sql))
err = doComQuery(requestCtx, &UserInput{sql: sql})
if err != nil {
resp = NewGeneralErrorResponse(COM_STMT_RESET, err)
resp = NewGeneralErrorResponse(COM_STMT_RESET, mce.ses.GetServerStatus(), err)
}
return resp, nil

default:
resp = NewGeneralErrorResponse(req.GetCmd(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
resp = NewGeneralErrorResponse(req.GetCmd(), mce.ses.GetServerStatus(), moerr.NewInternalError(requestCtx, "unsupported command. 0x%x", req.GetCmd()))
}
return resp, nil
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/frontend/mysql_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *Prep
}
}
if numParams > 0 {
if err := mp.SendEOFPacketIf(0, 0); err != nil {
if err := mp.SendEOFPacketIf(0, mp.GetSession().GetServerStatus()); err != nil {
return err
}
}
Expand All @@ -487,7 +487,7 @@ func (mp *MysqlProtocolImpl) SendPrepareResponse(ctx context.Context, stmt *Prep
}
}
if numColumns > 0 {
if err := mp.SendEOFPacketIf(0, 0); err != nil {
if err := mp.SendEOFPacketIf(0, mp.GetSession().GetServerStatus()); err != nil {
return err
}
}
Expand Down Expand Up @@ -871,6 +871,10 @@ func (mp *MysqlProtocolImpl) readIntLenEnc(data []byte, pos int) (uint64, int, b
return uint64(data[pos]), pos + 1, true
}

func (mp *MysqlProtocolImpl) ReadIntLenEnc(data []byte, pos int) (uint64, int, bool) {
return mp.readIntLenEnc(data, pos)
}

// write an int with length encoded into the buffer at the position
// return position + the count of bytes for length encoded (1 or 3 or 4 or 9)
func (mp *MysqlProtocolImpl) writeIntLenEnc(data []byte, pos int, value uint64) int {
Expand Down
2 changes: 2 additions & 0 deletions pkg/frontend/mysql_protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,7 @@ func (tRM *TestRoutineManager) resultsetHandler(rs goetty.IOSession, msg interfa
case COM_PING:
resp = NewResponse(
OkResponse,
0, 0, 0,
0,
int(COM_PING),
nil,
Expand Down Expand Up @@ -2058,6 +2059,7 @@ func TestSendPrepareResponse(t *testing.T) {
}

proto := NewMysqlClientProtocol(0, ioses, 1024, sv)
proto.SetSession(&Session{})

st := tree.NewPrepareString(tree.Identifier(getPrepareStmtName(1)), "select ?, 1")
stmts, err := mysql.Parse(ctx, st.Sql, 1)
Expand Down
33 changes: 13 additions & 20 deletions pkg/frontend/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,28 @@ type Response struct {
warnings uint16
}

func NewResponse(category int, status uint16, cmd int, d interface{}) *Response {
func NewResponse(category int, affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
return &Response{
category: category,
status: status,
cmd: cmd,
data: d,
category: category,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
status: status,
cmd: cmd,
data: d,
}
}

func NewGeneralErrorResponse(cmd CommandType, err error) *Response {
return NewResponse(ErrorResponse, 0, int(cmd), err)
func NewGeneralErrorResponse(cmd CommandType, status uint16, err error) *Response {
return NewResponse(ErrorResponse, 0, 0, 0, status, int(cmd), err)
}

func NewGeneralOkResponse(cmd CommandType) *Response {
return NewResponse(OkResponse, 0, int(cmd), nil)
func NewGeneralOkResponse(cmd CommandType, status uint16) *Response {
return NewResponse(OkResponse, 0, 0, 0, status, int(cmd), nil)
}

func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
resp := &Response{
category: OkResponse,
status: status,
cmd: cmd,
data: d,
affectedRows: affectedRows,
lastInsertId: lastInsertId,
warnings: warnings,
}

return resp
return NewResponse(OkResponse, affectedRows, lastInsertId, warnings, status, cmd, d)
}

func (resp *Response) GetData() interface{} {
Expand Down
Loading

0 comments on commit 8870ea8

Please sign in to comment.