From f4ed57a9fdf0e175f205d97e65b67ed28c37a861 Mon Sep 17 00:00:00 2001 From: LiuBo Date: Tue, 9 May 2023 17:33:34 +0800 Subject: [PATCH] [feat]: kill connected session if account suspended or dropped. If suspend or drop account is requested, kill all connections that belongs to this account. --- pkg/proxy/client_conn.go | 127 ++++++++++++++----- pkg/proxy/client_conn_test.go | 24 +++- pkg/proxy/conn_manager.go | 36 +++++- pkg/proxy/event.go | 148 ++++++++++++++++++++-- pkg/proxy/event_test.go | 226 +++++++++++++++++++++++++++++++++- pkg/proxy/handler_test.go | 203 +++++++++++++++++++----------- pkg/proxy/mysql_conn_buf.go | 1 + pkg/proxy/router.go | 14 ++- pkg/proxy/server_conn.go | 18 ++- pkg/proxy/server_conn_test.go | 77 ++++++++++-- 10 files changed, 740 insertions(+), 134 deletions(-) diff --git a/pkg/proxy/client_conn.go b/pkg/proxy/client_conn.go index 16789565cef89..5b68b643ac797 100644 --- a/pkg/proxy/client_conn.go +++ b/pkg/proxy/client_conn.go @@ -16,6 +16,7 @@ package proxy import ( "context" + "fmt" "net" "strings" "sync/atomic" @@ -234,63 +235,129 @@ func (c *clientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []by return c.handleKillQuery(ev, resp) case *setVarEvent: return c.handleSetVar(ev) + case *suspendAccountEvent: + return c.handleSuspendAccount(ev, resp) + case *dropAccountEvent: + return c.handleDropAccount(ev, resp) default: } return nil } -// handleKillQuery handles the kill query event. -func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error { - sendErr := func(errMsg string) { - fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR] - payload := c.mysqlProto.MakeErrPayload( - fail.ErrorCode, fail.SqlStates[0], errMsg) - r := &frontend.Packet{ - Length: 0, - SequenceID: 1, - Payload: payload, - } - sendResp(packetToBytes(r), resp) - } - - cn, err := c.router.SelectByConnID(e.connID) - if err != nil { - c.log.Error("failed to select CN server", zap.Error(err)) - sendErr(err.Error()) - return err +func (c *clientConn) sendErr(errMsg string, resp chan<- []byte) { + fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR] + payload := c.mysqlProto.MakeErrPayload( + fail.ErrorCode, fail.SqlStates[0], errMsg) + r := &frontend.Packet{ + Length: 0, + SequenceID: 1, + Payload: payload, } - // Before connect to backend server, update the salt. - cn.salt = c.mysqlProto.GetSalt() + sendResp(packetToBytes(r), resp) +} +func (c *clientConn) connAndExec(cn *CNServer, stmt string, resp chan<- []byte) error { sc, r, err := c.router.Connect(cn, c.handshakePack, c.tun) if err != nil { c.log.Error("failed to connect to backend server", zap.Error(err)) - sendErr(err.Error()) + if resp != nil { + c.sendErr(err.Error(), resp) + } return err } defer func() { _ = sc.Close() }() if !isOKPacket(r) { - c.log.Error("failed to connect to cn to handle kill query event", - zap.String("query", e.stmt), zap.String("error", string(r))) - sendResp(r, resp) + c.log.Error("failed to connect to cn to handle event", + zap.String("query", stmt), zap.String("error", string(r))) + if resp != nil { + sendResp(r, resp) + } return moerr.NewInternalErrorNoCtx("access error") } - err = sc.ExecStmt(e.stmt, resp) + ok, err := sc.ExecStmt(stmt, resp) if err != nil { - c.log.Error("failed to send query %s to server", - zap.String("query", e.stmt), zap.Error(err)) + c.log.Error("failed to send query to server", + zap.String("query", stmt), zap.Error(err)) + return err + } + if !ok { + return moerr.NewInternalErrorNoCtx("exec error") } return nil } +// handleKillQuery handles the kill query event. +func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error { + cn, err := c.router.SelectByConnID(e.connID) + if err != nil { + c.log.Error("failed to select CN server", zap.Error(err)) + c.sendErr(err.Error(), resp) + return err + } + // Before connect to backend server, update the salt. + cn.salt = c.mysqlProto.GetSalt() + + return c.connAndExec(cn, e.stmt, resp) +} + // handleSetVar handles the set variable event. func (c *clientConn) handleSetVar(e *setVarEvent) error { c.setVarStmts = append(c.setVarStmts, e.stmt) return nil } +// handleSuspendAccountEvent handles the suspend account event. +func (c *clientConn) handleSuspendAccount(e *suspendAccountEvent, resp chan<- []byte) error { + // a temp cn server. + cn := &CNServer{ + addr: e.addr, + salt: c.mysqlProto.GetSalt(), + } + if err := c.connAndExec(cn, e.stmt, resp); err != nil { + return err + } + + // handle kill connection. + cns, err := c.router.SelectByTenant(e.account) + if err != nil { + return err + } + if len(cns) == 0 { + return nil + } + for _, cn := range cns { + // Before connect to backend server, update the salt. + cn.salt = c.mysqlProto.GetSalt() + + go func(s *CNServer) { + query := fmt.Sprintf("kill connection %d", s.connID) + // No client to receive the result, so pass nil as the third + // parameter to ignore the result. + if err := c.connAndExec(s, query, nil); err != nil { + c.log.Error("failed to send query to server", + zap.String("query", query), zap.Error(err)) + return + } + c.log.Info("kill connection on server succeeded", + zap.String("query", query), zap.String("server", s.addr)) + }(cn) + } + return nil +} + +// handleDropAccountEvent handles the drop account event. +func (c *clientConn) handleDropAccount(e *dropAccountEvent, resp chan<- []byte) error { + se := &suspendAccountEvent{ + baseEvent: e.baseEvent, + stmt: e.stmt, + account: e.account, + addr: e.addr, + } + return c.handleSuspendAccount(se, resp) +} + // Close implements the ClientConn interface. func (c *clientConn) Close() error { return nil @@ -332,12 +399,12 @@ func (c *clientConn) connectToBackend(sendToClient bool) (ServerConn, error) { } // Set the label session variable. - if err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil { + if _, err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil { return nil, err } // Set the use defined variables, including session variables and user variables. for _, stmt := range c.setVarStmts { - if err := sc.ExecStmt(stmt, nil); err != nil { + if _, err := sc.ExecStmt(stmt, nil); err != nil { return nil, err } } diff --git a/pkg/proxy/client_conn_test.go b/pkg/proxy/client_conn_test.go index 144d53b7aedad..ccae120c6ad8b 100644 --- a/pkg/proxy/client_conn_test.go +++ b/pkg/proxy/client_conn_test.go @@ -134,12 +134,12 @@ func (c *mockClientConn) BuildConnWithServer(_ bool) (ServerConn, error) { return nil, err } // Set the label session variable. - if err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil { + if _, err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil { return nil, err } // Set the use defined variables, including session variables and user variables. for _, stmt := range c.setVarStmts { - if err := sc.ExecStmt(stmt, nil); err != nil { + if _, err := sc.ExecStmt(stmt, nil); err != nil { return nil, err } } @@ -159,6 +159,26 @@ func (c *mockClientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- c.setVarStmts = append(c.setVarStmts, ev.stmt) sendResp([]byte("ok"), resp) return nil + case *suspendAccountEvent: + cns, err := c.router.SelectByTenant(ev.account) + if err != nil { + sendResp([]byte(err.Error()), resp) + return err + } + for _, cn := range cns { + sendResp([]byte(cn.addr), resp) + } + return nil + case *dropAccountEvent: + cns, err := c.router.SelectByTenant(ev.account) + if err != nil { + sendResp([]byte(err.Error()), resp) + return err + } + for _, cn := range cns { + sendResp([]byte(cn.addr), resp) + } + return nil default: sendResp([]byte("type not supported"), resp) return moerr.NewInternalErrorNoCtx("type not supported") diff --git a/pkg/proxy/conn_manager.go b/pkg/proxy/conn_manager.go index 086410417fc6f..a48b138c61417 100644 --- a/pkg/proxy/conn_manager.go +++ b/pkg/proxy/conn_manager.go @@ -125,6 +125,9 @@ type connManager struct { // Map from connection ID to CN server. connIDServers map[uint32]*CNServer + + // Map from Tenant to *CNServer list. + tenantConns map[Tenant]map[*CNServer]struct{} } // newConnManager creates a new connManager. @@ -132,6 +135,7 @@ func newConnManager() *connManager { m := &connManager{ conns: make(map[LabelHash]*connInfo), connIDServers: make(map[uint32]*CNServer), + tenantConns: make(map[Tenant]map[*CNServer]struct{}), } return m } @@ -181,6 +185,14 @@ func (m *connManager) connect(cn *CNServer, t *tunnel) { } m.conns[cn.hash].cnTunnels.add(cn.uuid, t) m.connIDServers[cn.connID] = cn + + tenant := cn.reqLabel.Tenant + if tenant != "" { + if m.tenantConns[tenant] == nil { + m.tenantConns[tenant] = make(map[*CNServer]struct{}) + } + m.tenantConns[tenant][cn] = struct{}{} + } } // disconnect removes a connection from connection manager. @@ -193,6 +205,11 @@ func (m *connManager) disconnect(cn *CNServer, t *tunnel) { } ci.cnTunnels.del(cn.uuid, t) delete(m.connIDServers, cn.connID) + + tenant := cn.reqLabel.Tenant + if tenant != "" && m.tenantConns[tenant] != nil { + delete(m.tenantConns[tenant], cn) + } } // count returns the total connection count. @@ -241,8 +258,8 @@ func (m *connManager) getLabelInfo(hash LabelHash) labelInfo { return ci.label } -// getCNServer returns a CN server which has the connection ID. -func (m *connManager) getCNServer(connID uint32) *CNServer { +// getCNServerByConnID returns a CN server which has the connection ID. +func (m *connManager) getCNServerByConnID(connID uint32) *CNServer { m.Lock() defer m.Unlock() cn, ok := m.connIDServers[connID] @@ -251,3 +268,18 @@ func (m *connManager) getCNServer(connID uint32) *CNServer { } return nil } + +// getCNServersByTenant returns a CN server list by tenant. +func (m *connManager) getCNServersByTenant(tenant Tenant) []*CNServer { + m.Lock() + defer m.Unlock() + cns, ok := m.tenantConns[tenant] + if !ok { + return nil + } + cnList := make([]*CNServer, 0, len(cns)) + for cn := range cns { + cnList = append(cnList, cn) + } + return cnList +} diff --git a/pkg/proxy/event.go b/pkg/proxy/event.go index 1d5ee5735a191..d0bafd4461163 100644 --- a/pkg/proxy/event.go +++ b/pkg/proxy/event.go @@ -16,8 +16,11 @@ package proxy import ( "fmt" + "io" + "net" "regexp" "strconv" + "strings" "sync" ) @@ -31,6 +34,10 @@ func (t eventType) String() string { return "KillQuery" case TypeSetVar: return "SetVar" + case TypeSuspendAccount: + return "SuspendAccount" + case TypeDropAccount: + return "DropAccount" } return "Unknown" } @@ -42,12 +49,25 @@ const ( TypeKillQuery eventType = 1 // TypeSetVar indicates the set variable statement. TypeSetVar eventType = 2 + // TypeSuspendAccount indicates the suspend account statement. + TypeSuspendAccount eventType = 3 + // TypeDropAccount indicates the drop account statement. + TypeDropAccount eventType = 4 ) var ( - set = "[sS][eE][tT]" - session = "[sS][eE][sS][sS][iI][oO][nN]" - local = "[lL][oO][cC][aA][lL]" + kill = "[kK][iI][lL][lL]" + query = "[qQ][uU][eE][rR][yY]" + set = "[sS][eE][tT]" + session = "[sS][eE][sS][sS][iI][oO][nN]" + local = "[lL][oO][cC][aA][lL]" + alter = "[aA][lL][tT][eE][rR]" + account = "[aA][cC][cC][oO][uU][nN][tT]" + ifExists = "?:[iI][fF]\\s+[eE][xX][iI][sS][tT][sS]\\s+" + suspend = "[sS][uU][sS][pP][eE][nN][dD]" + drop = "[dD][rR][oO][pP]" + + num = "\\d+" spaceAtLeastZero = "\\s*" spaceAtLeastOne = "\\s+" varName = "[a-zA-Z][a-zA-Z0-9_]*" @@ -60,7 +80,12 @@ var ( // patterMap is a map eventType => pattern string var patternMap = map[eventType]string{ - TypeKillQuery: `^([kK][iI][lL][lL])\s+([qQ][uU][eE][rR][yY])\s+(\d+)$`, + // Sample: kill query 10 + TypeKillQuery: fmt.Sprintf(`^(%s)%s(%s)%s(%s)$`, + kill, spaceAtLeastOne, + query, spaceAtLeastOne, + num), + // TypeSetVar matches set session variable: // - set session key=value; // - set local key=value; @@ -81,6 +106,19 @@ var patternMap = map[eventType]string{ at, varName, // @key spaceAtLeastZero, assign, spaceAtLeastZero, // = or := varValue), // value + + // Sample: alter account [if exists] acc1 suspend + TypeSuspendAccount: fmt.Sprintf(`^(%s)%s(%s)%s(%s)?(%s)%s(%s)$`, + alter, spaceAtLeastOne, + account, spaceAtLeastOne, ifExists, + varName, spaceAtLeastOne, + suspend), + + // Sample: drop account [if exists] acc1 + TypeDropAccount: fmt.Sprintf(`^(%s)%s(%s)%s(%s)?(%s)$`, + drop, spaceAtLeastOne, + account, spaceAtLeastOne, ifExists, + varName), } // IEvent is the event interface. @@ -114,6 +152,7 @@ func sendResp(r []byte, c chan<- []byte) { type eventReq struct { // msg is a MySQL packet bytes. msg []byte + dst io.Writer } // eventReqPool is used to fetch event request from pool. @@ -129,7 +168,7 @@ var eventReqPool = sync.Pool{ // and do not need to send to dst anymore. func makeEvent(req *eventReq) (IEvent, bool) { if req == nil || len(req.msg) < preRecvLen { - return nil, true + return nil, false } if req.msg[4] == byte(cmdQuery) { stmt := getStatement(req.msg) @@ -143,7 +182,7 @@ func makeEvent(req *eventReq) (IEvent, bool) { } } if !matched { - return nil, true + return nil, false } switch typ { case TypeKillQuery: @@ -151,6 +190,31 @@ func makeEvent(req *eventReq) (IEvent, bool) { case TypeSetVar: // This event should be sent to dst, so return false, return makeSetVarEvent(stmt), false + case TypeSuspendAccount: + c, ok := req.dst.(net.Conn) + if !ok { + return nil, false + } + addr := c.RemoteAddr().String() + if strings.Index(addr, "sock") > 0 && strings.Index(addr, "unix") < 0 { + addr = "unix://" + addr + } + // Although the suspend statement could be handled directly, but + // we need to know if the result of this statement is ok. So it + // is handled in the event, and thus the second return value is true. + return makeSuspendAccountEvent(stmt, + addr, regexp.MustCompile(patternMap[typ])), true + case TypeDropAccount: + c, ok := req.dst.(net.Conn) + if !ok { + return nil, false + } + addr := c.RemoteAddr().String() + if strings.Index(addr, "sock") > 0 && strings.Index(addr, "unix") < 0 { + addr = "unix://" + addr + } + return makeDropAccountEvent(stmt, + addr, regexp.MustCompile(patternMap[typ])), true default: return nil, true } @@ -163,10 +227,10 @@ func makeEvent(req *eventReq) (IEvent, bool) { // the connection ID on it. type killQueryEvent struct { baseEvent - // The ID of connection that needs to be killed. - connID uint32 // stmt is the statement that will be sent to server. stmt string + // The ID of connection that needs to be killed. + connID uint32 } // makeKillQueryEvent creates a event with TypeKillQuery type. @@ -181,8 +245,8 @@ func makeKillQueryEvent(stmt string, reg *regexp.Regexp) IEvent { } e := &killQueryEvent{ - connID: uint32(connID), stmt: stmt, + connID: uint32(connID), } e.typ = TypeKillQuery return e @@ -215,3 +279,69 @@ func makeSetVarEvent(stmt string) IEvent { func (e *setVarEvent) eventType() eventType { return TypeSetVar } + +// suspendAccountEvent is the event that "alter account xxx suspend" statement +// is captured. We need to send this alter statement to event and execute it, +// and if the result is ok, then construct "kill connection" statement and send it +// to CN servers which have connections with the account. +type suspendAccountEvent struct { + baseEvent + // stmt is the statement that will be sent to server. + stmt string + // account is the tenant text value. + account Tenant + // addr is the remote server's address, which is used + // to build connection to execute suspend statement. + addr string +} + +// makeSuspendAccountEvent creates a event with TypeSuspendAccount type. +func makeSuspendAccountEvent(stmt string, addr string, reg *regexp.Regexp) IEvent { + items := reg.FindStringSubmatch(stmt) + if len(items) != 5 { + return nil + } + e := &suspendAccountEvent{ + stmt: stmt, + account: Tenant(items[3]), + addr: addr, + } + e.typ = TypeSuspendAccount + return e +} + +func (e *suspendAccountEvent) eventType() eventType { + return TypeSuspendAccount +} + +// dropAccountEvent is the event that "drop account xxx" statement +// is captured. The actions are the same as suspendAccountEvent. +type dropAccountEvent struct { + baseEvent + // stmt is the statement that will be sent to server. + stmt string + // account is the tenant text value. + account Tenant + // addr is the remote server's address, which is used + // to build connection to execute drop statement. + addr string +} + +// makeDropAccountEvent creates a event with TypeDropAccount type. +func makeDropAccountEvent(stmt string, addr string, reg *regexp.Regexp) IEvent { + items := reg.FindStringSubmatch(stmt) + if len(items) != 4 { + return nil + } + e := &dropAccountEvent{ + stmt: stmt, + account: Tenant(items[3]), + addr: addr, + } + e.typ = TypeDropAccount + return e +} + +func (e *dropAccountEvent) eventType() eventType { + return TypeDropAccount +} diff --git a/pkg/proxy/event_test.go b/pkg/proxy/event_test.go index bc3dd3abe1f92..a294c4bb5597c 100644 --- a/pkg/proxy/event_test.go +++ b/pkg/proxy/event_test.go @@ -30,12 +30,12 @@ import ( func TestMakeEvent(t *testing.T) { e, r := makeEvent(nil) require.Nil(t, e) - require.True(t, r) + require.False(t, r) t.Run("kill query", func(t *testing.T) { e, r = makeEvent(&eventReq{msg: makeSimplePacket("kill quer8y 12")}) require.Nil(t, e) - require.True(t, r) + require.False(t, r) e, r = makeEvent(&eventReq{msg: makeSimplePacket("kill query 123")}) require.NotNil(t, e) @@ -47,7 +47,7 @@ func TestMakeEvent(t *testing.T) { e, r = makeEvent(&eventReq{msg: makeSimplePacket("set ")}) require.Nil(t, e) - require.True(t, r) + require.False(t, r) }) t.Run("set var", func(t *testing.T) { @@ -101,9 +101,61 @@ func TestMakeEvent(t *testing.T) { for _, stmt := range stmtsInvalid { e, r = makeEvent(&eventReq{msg: makeSimplePacket(stmt)}) require.Nil(t, e) - require.True(t, r) + require.False(t, r) } }) + + t.Run("suspend account", func(t *testing.T) { + n1, _ := net.Pipe() + defer n1.Close() + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("alter account a1 suspend"), + dst: n1, + }) + require.NotNil(t, e) + require.True(t, r) + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("alter account if exists a1 suspend"), + dst: n1, + }) + require.NotNil(t, e) + require.True(t, r) + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("alter1 account a1 suspend"), + dst: n1, + }) + require.Nil(t, e) + require.False(t, r) + }) + + t.Run("drop account", func(t *testing.T) { + n1, _ := net.Pipe() + defer n1.Close() + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("drop account a1"), + dst: n1, + }) + require.NotNil(t, e) + require.True(t, r) + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("drop account if exists a1"), + dst: n1, + }) + require.NotNil(t, e) + require.True(t, r) + + e, r = makeEvent(&eventReq{ + msg: makeSimplePacket("dr1op account a1"), + dst: n1, + }) + require.Nil(t, e) + require.False(t, r) + }) } func TestKillQueryEvent(t *testing.T) { @@ -424,6 +476,166 @@ func TestSetVarEvent(t *testing.T) { } } +func TestSuspendDropAccountEvent(t *testing.T) { + defer leaktest.AfterTest(t)() + + tp := newTestProxyHandler(t) + defer tp.closeFn() + + temp := os.TempDir() + addr1 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) + require.NoError(t, os.RemoveAll(addr1)) + cn1 := testMakeCNServer("uuid1", addr1, 10, "", labelInfo{Tenant: "t1"}) + stopFn1 := startTestCNServer(t, tp.ctx, addr1) + defer func() { + require.NoError(t, stopFn1()) + }() + + addr2 := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) + require.NoError(t, os.RemoveAll(addr2)) + cn2 := testMakeCNServer("uuid2", addr2, 20, "", labelInfo{Tenant: "t2"}) + stopFn2 := startTestCNServer(t, tp.ctx, addr2) + defer func() { + require.NoError(t, stopFn2()) + }() + + tu1 := newTunnel(tp.ctx, tp.logger, nil) + defer func() { _ = tu1.Close() }() + tu2 := newTunnel(tp.ctx, tp.logger, nil) + defer func() { _ = tu2.Close() }() + + // Client2 will send "kill query 10", which will route to the server which + // has connection ID 10. In this case, the connection is server1. + clientProxy1, _ := net.Pipe() + serverProxy1, _ := net.Pipe() + + cc1 := newMockClientConn(clientProxy1, "t1", labelInfo{}, tp.ru, tu1) + require.NotNil(t, cc1) + sc1 := newMockServerConn(serverProxy1) + require.NotNil(t, sc1) + + clientProxy2, client2 := net.Pipe() + serverProxy2, _ := net.Pipe() + + cc2 := newMockClientConn(clientProxy2, "t2", labelInfo{}, tp.ru, tu2) + require.NotNil(t, cc2) + sc2 := newMockServerConn(serverProxy2) + require.NotNil(t, sc2) + + res := make(chan []byte) + st := stopper.NewStopper("test-event", stopper.WithLogger(tp.logger.RawLogger())) + defer st.Stop() + err := st.RunNamedTask("test-event-handler", func(ctx context.Context) { + for { + select { + case e := <-tu2.reqC: + err := cc2.HandleEvent(ctx, e, tu2.respC) + require.NoError(t, err) + case r := <-tu2.respC: + if len(r) > 0 { + res <- r + } + case <-ctx.Done(): + return + } + } + }) + require.NoError(t, err) + + // tunnel1 is on cn1, tenant is t1. + _, _, err = tp.ru.Connect(cn1, testPacket, tu1) + require.NoError(t, err) + + // tunnel2 is on cn2, tenant is t2. + _, _, err = tp.ru.Connect(cn2, testPacket, tu2) + require.NoError(t, err) + + err = tu1.run(cc1, sc1) + require.NoError(t, err) + require.Nil(t, tu1.ctx.Err()) + + func() { + tu1.mu.Lock() + defer tu1.mu.Unlock() + require.True(t, tu1.mu.started) + }() + + err = tu2.run(cc2, sc2) + require.NoError(t, err) + require.Nil(t, tu2.ctx.Err()) + + func() { + tu2.mu.Lock() + defer tu2.mu.Unlock() + require.True(t, tu2.mu.started) + }() + + tu1.mu.Lock() + csp1 := tu1.mu.csp + scp1 := tu1.mu.scp + tu1.mu.Unlock() + + tu2.mu.Lock() + csp2 := tu2.mu.csp + scp2 := tu2.mu.scp + tu2.mu.Unlock() + + barrierStart1, barrierEnd1 := make(chan struct{}), make(chan struct{}) + barrierStart2, barrierEnd2 := make(chan struct{}), make(chan struct{}) + csp1.testHelper.beforeSend = func() { + <-barrierStart1 + <-barrierEnd1 + } + csp2.testHelper.beforeSend = func() { + <-barrierStart2 + <-barrierEnd2 + } + + csp1.mu.Lock() + require.True(t, csp1.mu.started) + csp1.mu.Unlock() + + scp1.mu.Lock() + require.True(t, scp1.mu.started) + scp1.mu.Unlock() + + csp2.mu.Lock() + require.True(t, csp2.mu.started) + csp2.mu.Unlock() + + scp2.mu.Lock() + require.True(t, scp2.mu.started) + scp2.mu.Unlock() + + // Client2 writes some MySQL packets. + sendEventCh := make(chan struct{}, 1) + errChan := make(chan error, 1) + go func() { + <-sendEventCh + // client2 send kill connection to cn1, which tenant t1 connect to. + if _, err := client2.Write(makeSimplePacket("alter account t1 suspend")); err != nil { + errChan <- err + return + } + }() + + sendEventCh <- struct{}{} + barrierStart2 <- struct{}{} + barrierEnd2 <- struct{}{} + + addr := string(<-res) + // This test case is mainly focus on if the query is route to the + // right cn server, but not the result of the query. So we just + // check the address which is handled is equal to cn1, but not cn2. + require.Equal(t, cn1.addr, addr) + + select { + case err = <-errChan: + t.Fatalf("require no error, but got %v", err) + default: + } +} + func TestEventType_String(t *testing.T) { e1 := baseEvent{} require.Equal(t, "Unknown", e1.eventType().String()) @@ -433,4 +645,10 @@ func TestEventType_String(t *testing.T) { e3 := setVarEvent{} require.Equal(t, "SetVar", e3.eventType().String()) + + e4 := suspendAccountEvent{} + require.Equal(t, "SuspendAccount", e4.eventType().String()) + + e5 := dropAccountEvent{} + require.Equal(t, "DropAccount", e5.eventType().String()) } diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index 507e3d424a82f..4c05fb13d44fe 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -240,7 +240,7 @@ func TestHandler_HandleWithSSL(t *testing.T) { require.Equal(t, int64(1), s.counterSet.connTotal.Load()) } -func TestHandler_HandleEventKillQuery(t *testing.T) { +func testWithServer(t *testing.T, fn func(*testing.T, string, *Server)) { defer leaktest.AfterTest(t)() temp := os.TempDir() @@ -276,90 +276,147 @@ func TestHandler_HandleEventKillQuery(t *testing.T) { err = s.Start() require.NoError(t, err) - db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr)) - // connect to server. - require.NoError(t, err) - require.NotNil(t, db1) - defer func() { - _ = db1.Close() - }() - res, err := db1.Exec("select 1") - require.NoError(t, err) - connID, _ := res.LastInsertId() // fake connection id + fn(t, listenAddr, s) +} - db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr)) - // connect to server. - require.NoError(t, err) - require.NotNil(t, db2) - defer func() { - _ = db2.Close() - }() +func TestHandler_HandleEventKillQuery(t *testing.T) { + testWithServer(t, func(t *testing.T, addr string, s *Server) { + db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db1) + defer func() { + _ = db1.Close() + }() + res, err := db1.Exec("select 1") + require.NoError(t, err) + connID, _ := res.LastInsertId() // fake connection id - _, err = db2.Exec("kill query 9999") - require.Error(t, err) + db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db2) + defer func() { + _ = db2.Close() + }() - _, err = db2.Exec(fmt.Sprintf("kill query %d", connID)) - require.NoError(t, err) + _, err = db2.Exec("kill query 9999") + require.Error(t, err) - require.Equal(t, int64(2), s.counterSet.connAccepted.Load()) + _, err = db2.Exec(fmt.Sprintf("kill query %d", connID)) + require.NoError(t, err) + + require.Equal(t, int64(2), s.counterSet.connAccepted.Load()) + }) } func TestHandler_HandleEventSetVar(t *testing.T) { - defer leaktest.AfterTest(t)() + testWithServer(t, func(t *testing.T, addr string, s *Server) { + db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db1) + defer func() { + _ = db1.Close() + }() + _, err = db1.Exec("set session cn_label='acc1'") + require.NoError(t, err) - temp := os.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime()) - listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) - require.NoError(t, os.RemoveAll(listenAddr)) - cfg := Config{ - ListenAddress: "unix://" + listenAddr, - RebalanceDisabled: true, - } - hc := &mockHAKeeperClient{} - addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) - require.NoError(t, os.RemoveAll(addr)) - cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) - hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{}) - // start backend server. - stopFn := startTestCNServer(t, ctx, addr) - defer func() { - require.NoError(t, stopFn()) - }() + res, err := db1.Query("show session variables") + require.NoError(t, err) + defer res.Close() + var varName, varValue string + for res.Next() { + err := res.Scan(&varName, &varValue) + require.NoError(t, err) + require.Equal(t, "cn_label", varName) + require.Equal(t, "acc1", varValue) + } - // start proxy. - s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()), - WithHAKeeperClient(hc)) - defer func() { - err := s.Close() + require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) + }) +} + +func TestHandler_HandleEventSuspendAccount(t *testing.T) { + testWithServer(t, func(t *testing.T, addr string, s *Server) { + db1, err := sql.Open("mysql", fmt.Sprintf("a1#root:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db1) + defer func() { + _ = db1.Close() + }() + _, err = db1.Exec("select 1") require.NoError(t, err) - }() - require.NoError(t, err) - require.NotNil(t, s) - err = s.Start() - require.NoError(t, err) - db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr)) - // connect to server. - require.NoError(t, err) - require.NotNil(t, db1) - defer func() { - _ = db1.Close() - }() - _, err = db1.Exec("set session cn_label='acc1'") - require.NoError(t, err) + db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db2) + defer func() { + _ = db2.Close() + }() - res, err := db1.Query("show session variables") - require.NoError(t, err) - defer res.Close() - var varName, varValue string - for res.Next() { - err := res.Scan(&varName, &varValue) + _, err = db2.Exec("alter account a1 suspend") require.NoError(t, err) - require.Equal(t, "cn_label", varName) - require.Equal(t, "acc1", varValue) - } - require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) + time.Sleep(time.Millisecond * 200) + res, err := db1.Query("show global variables") + require.NoError(t, err) + defer res.Close() + var rows int + var varName, varValue string + for res.Next() { + rows += 1 + err := res.Scan(&varName, &varValue) + require.NoError(t, err) + require.Equal(t, "killed", varName) + require.Equal(t, "yes", varValue) + } + require.Equal(t, 1, rows) + + require.Equal(t, int64(2), s.counterSet.connAccepted.Load()) + }) +} + +func TestHandler_HandleEventDropAccount(t *testing.T) { + testWithServer(t, func(t *testing.T, addr string, s *Server) { + db1, err := sql.Open("mysql", fmt.Sprintf("a1#root:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db1) + defer func() { + _ = db1.Close() + }() + _, err = db1.Exec("select 1") + require.NoError(t, err) + + db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) + // connect to server. + require.NoError(t, err) + require.NotNil(t, db2) + defer func() { + _ = db2.Close() + }() + + _, err = db2.Exec("drop account a1") + require.NoError(t, err) + + time.Sleep(time.Millisecond * 200) + res, err := db1.Query("show global variables") + require.NoError(t, err) + defer res.Close() + var rows int + var varName, varValue string + for res.Next() { + rows += 1 + err := res.Scan(&varName, &varValue) + require.NoError(t, err) + require.Equal(t, "killed", varName) + require.Equal(t, "yes", varValue) + } + require.Equal(t, 1, rows) + + require.Equal(t, int64(2), s.counterSet.connAccepted.Load()) + }) } diff --git a/pkg/proxy/mysql_conn_buf.go b/pkg/proxy/mysql_conn_buf.go index 8549f3e8458b3..d5590d74fd553 100644 --- a/pkg/proxy/mysql_conn_buf.go +++ b/pkg/proxy/mysql_conn_buf.go @@ -159,6 +159,7 @@ func (b *msgBuf) consumeMsg(msg []byte, dst io.Writer) bool { req := eventReqPool.Get().(*eventReq) defer eventReqPool.Put(req) req.msg = msg + req.dst = dst e, r := makeEvent(req) if e == nil { return false diff --git a/pkg/proxy/router.go b/pkg/proxy/router.go index 1515469c9db65..43144cf769ada 100644 --- a/pkg/proxy/router.go +++ b/pkg/proxy/router.go @@ -35,7 +35,13 @@ type Router interface { // SelectByConnID selects the CN server which has the connection ID. SelectByConnID(connID uint32) (*CNServer, error) + // SelectByTenant selects the CN servers belongs to the tenant. + SelectByTenant(tenant Tenant) ([]*CNServer, error) + // SelectByLabel selects the best CN server with the label. + // This is the only method that allocate *CNServer, and other + // SelectXXX method in this interface select CNServer from the + // ones it allocated. SelectByLabel(label labelInfo) (*CNServer, error) // Connect connects to the CN server and returns the connection. @@ -45,6 +51,7 @@ type Router interface { } // CNServer represents the backend CN server, including salt, tenant, uuid and address. +// When there is a new client connection, a new CNServer will be created. type CNServer struct { // connID is the backend CN server's connection ID, which is global unique // and is tracked in connManager. @@ -113,7 +120,7 @@ func newRouter( // SelectByConnID implements the CNConnector interface. func (r *router) SelectByConnID(connID uint32) (*CNServer, error) { - cn := r.rebalancer.connManager.getCNServer(connID) + cn := r.rebalancer.connManager.getCNServerByConnID(connID) if cn == nil { return nil, moerr.NewInternalErrorNoCtx("no available CN server.") } @@ -125,6 +132,11 @@ func (r *router) SelectByConnID(connID uint32) (*CNServer, error) { }, nil } +// SelectByTenant implements the CNConnector interface. +func (r *router) SelectByTenant(tenant Tenant) ([]*CNServer, error) { + return r.rebalancer.connManager.getCNServersByTenant(tenant), nil +} + // SelectByLabel implements the CNConnector interface. func (r *router) SelectByLabel(label labelInfo) (*CNServer, error) { var cns []*CNServer diff --git a/pkg/proxy/server_conn.go b/pkg/proxy/server_conn.go index 0dad1707d3b28..8dc5005a4d8fe 100644 --- a/pkg/proxy/server_conn.go +++ b/pkg/proxy/server_conn.go @@ -40,7 +40,8 @@ type ServerConn interface { // ExecStmt executes a simple statement, it sends a query to backend server. // After it finished, server connection should be closed immediately because // it is a temp connection. - ExecStmt(stmt string, resp chan<- []byte) error + // The first return value indicates that if the execution result is OK. + ExecStmt(stmt string, resp chan<- []byte) (bool, error) // Close closes the connection to CN server. Close() error } @@ -138,29 +139,34 @@ func (s *serverConn) HandleHandshake(handshakeResp *frontend.Packet) (*frontend. } // ExecStmt implements the ServerConn interface. -func (s *serverConn) ExecStmt(stmt string, resp chan<- []byte) error { +func (s *serverConn) ExecStmt(stmt string, resp chan<- []byte) (bool, error) { req := make([]byte, 1, len(stmt)+1) req[0] = byte(cmdQuery) req = append(req, []byte(stmt)...) s.mysqlProto.SetSequenceID(0) if err := s.mysqlProto.WritePacket(req); err != nil { - return err + return false, err } + execOK := true for { // readPacket makes sure return value is a whole MySQL packet. res, err := s.readPacket() if err != nil { - return err + return false, err } bs := packetToBytes(res) if resp != nil { sendResp(bs, resp) } - if isEOFPacket(bs) || isOKPacket(bs) || isErrPacket(bs) { + if isEOFPacket(bs) || isOKPacket(bs) { + break + } + if isErrPacket(bs) { + execOK = false break } } - return nil + return execOK, nil } // Close implements the ServerConn interface. diff --git a/pkg/proxy/server_conn_test.go b/pkg/proxy/server_conn_test.go index 2e48798d6dc5d..c891dffea324d 100644 --- a/pkg/proxy/server_conn_test.go +++ b/pkg/proxy/server_conn_test.go @@ -75,9 +75,9 @@ func (s *mockServerConn) RawConn() net.Conn { return s.conn } func (s *mockServerConn) HandleHandshake(_ *frontend.Packet) (*frontend.Packet, error) { return nil, nil } -func (s *mockServerConn) ExecStmt(stmt string, resp chan<- []byte) error { +func (s *mockServerConn) ExecStmt(stmt string, resp chan<- []byte) (bool, error) { sendResp(makeOKPacket(), resp) - return nil + return true, nil } func (s *mockServerConn) Close() error { if s.conn != nil { @@ -96,6 +96,8 @@ type testCNServer struct { listener net.Listener started bool quit chan interface{} + + globalVars map[string]string } type testHandler struct { @@ -103,14 +105,16 @@ type testHandler struct { connID uint32 conn goetty.IOSession sessionVars map[string]string + server *testCNServer } func startTestCNServer(t *testing.T, ctx context.Context, addr string) func() error { b := &testCNServer{ - ctx: ctx, - scheme: "tcp", - addr: addr, - quit: make(chan interface{}), + ctx: ctx, + scheme: "tcp", + addr: addr, + quit: make(chan interface{}), + globalVars: make(map[string]string), } if strings.Contains(addr, "sock") { b.scheme = "unix" @@ -187,6 +191,7 @@ func (s *testCNServer) Start() error { mysqlProto: frontend.NewMysqlClientProtocol( cid, c, 0, &fp), sessionVars: make(map[string]string), + server: s, } go func(h *testHandler) { testHandle(h) @@ -220,6 +225,10 @@ func testHandle(h *testHandler) { h.handleSetVar(packet) } else if string(packet.Payload[1:]) == "show session variables" { h.handleShowVar() + } else if string(packet.Payload[1:]) == "show global variables" { + h.handleShowGlobalVar() + } else if strings.HasPrefix(string(packet.Payload[1:]), "kill connection") { + h.handleKillConn() } else { h.handleCommon() } @@ -243,6 +252,12 @@ func (h *testHandler) handleSetVar(packet *frontend.Packet) { _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), 0, 0, "")) } +func (h *testHandler) handleKillConn() { + h.server.globalVars["killed"] = "yes" + h.mysqlProto.SetSequenceID(1) + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), 0, 0, "")) +} + func (h *testHandler) handleShowVar() { h.mysqlProto.SetSequenceID(1) err := h.mysqlProto.SendColumnCountPacket(2) @@ -291,6 +306,54 @@ func (h *testHandler) handleShowVar() { _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, 0)) } +func (h *testHandler) handleShowGlobalVar() { + h.mysqlProto.SetSequenceID(1) + err := h.mysqlProto.SendColumnCountPacket(2) + if err != nil { + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) + return + } + cols := []*plan.ColDef{ + {Typ: &plan.Type{Id: int32(types.T_char)}, Name: "Variable_name"}, + {Typ: &plan.Type{Id: int32(types.T_char)}, Name: "Value"}, + } + columns := make([]interface{}, len(cols)) + res := &frontend.MysqlResultSet{} + for i, col := range cols { + c := new(frontend.MysqlColumn) + c.SetName(col.Name) + c.SetOrgName(col.Name) + c.SetTable(col.Typ.Table) + c.SetOrgTable(col.Typ.Table) + c.SetAutoIncr(col.Typ.AutoIncr) + c.SetSchema("") + c.SetDecimal(col.Typ.Scale) + columns[i] = c + res.AddColumn(c) + } + for _, c := range columns { + if err := h.mysqlProto.SendColumnDefinitionPacket(context.TODO(), c.(frontend.Column), 3); err != nil { + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) + return + } + } + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, 0)) + for k, v := range h.server.globalVars { + row := make([]interface{}, 2) + row[0] = k + row[1] = v + res.AddRow(row) + } + ses := &frontend.Session{} + ses.SetRequestContext(context.Background()) + h.mysqlProto.SetSession(ses) + if err := h.mysqlProto.SendResultSetTextBatchRow(res, res.GetRowCount()); err != nil { + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) + return + } + _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, 0)) +} + func (s *testCNServer) Stop() error { close(s.quit) _ = s.listener.Close() @@ -403,7 +466,7 @@ func TestServerConn_ExecStmt(t *testing.T) { require.NoError(t, err) require.NotEqual(t, 0, int(sc.ConnID())) resp := make(chan []byte, 10) - err = sc.ExecStmt("kill query", resp) + _, err = sc.ExecStmt("kill query", resp) require.NoError(t, err) res := <-resp ok := isOKPacket(res)