Skip to content

Commit

Permalink
[feat]: kill connected session if account suspended or dropped.
Browse files Browse the repository at this point in the history
If suspend or drop account is requested, kill all connections
that belongs to this account.
  • Loading branch information
volgariver6 committed May 9, 2023
1 parent 5ef2f14 commit f4ed57a
Show file tree
Hide file tree
Showing 10 changed files with 740 additions and 134 deletions.
127 changes: 97 additions & 30 deletions pkg/proxy/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package proxy

import (
"context"
"fmt"
"net"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -234,63 +235,129 @@ func (c *clientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []by
return c.handleKillQuery(ev, resp)
case *setVarEvent:
return c.handleSetVar(ev)
case *suspendAccountEvent:
return c.handleSuspendAccount(ev, resp)
case *dropAccountEvent:
return c.handleDropAccount(ev, resp)
default:
}
return nil
}

// handleKillQuery handles the kill query event.
func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error {
sendErr := func(errMsg string) {
fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR]
payload := c.mysqlProto.MakeErrPayload(
fail.ErrorCode, fail.SqlStates[0], errMsg)
r := &frontend.Packet{
Length: 0,
SequenceID: 1,
Payload: payload,
}
sendResp(packetToBytes(r), resp)
}

cn, err := c.router.SelectByConnID(e.connID)
if err != nil {
c.log.Error("failed to select CN server", zap.Error(err))
sendErr(err.Error())
return err
func (c *clientConn) sendErr(errMsg string, resp chan<- []byte) {
fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR]
payload := c.mysqlProto.MakeErrPayload(
fail.ErrorCode, fail.SqlStates[0], errMsg)
r := &frontend.Packet{
Length: 0,
SequenceID: 1,
Payload: payload,
}
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()
sendResp(packetToBytes(r), resp)
}

func (c *clientConn) connAndExec(cn *CNServer, stmt string, resp chan<- []byte) error {
sc, r, err := c.router.Connect(cn, c.handshakePack, c.tun)
if err != nil {
c.log.Error("failed to connect to backend server", zap.Error(err))
sendErr(err.Error())
if resp != nil {
c.sendErr(err.Error(), resp)
}
return err
}
defer func() { _ = sc.Close() }()

if !isOKPacket(r) {
c.log.Error("failed to connect to cn to handle kill query event",
zap.String("query", e.stmt), zap.String("error", string(r)))
sendResp(r, resp)
c.log.Error("failed to connect to cn to handle event",
zap.String("query", stmt), zap.String("error", string(r)))
if resp != nil {
sendResp(r, resp)
}
return moerr.NewInternalErrorNoCtx("access error")
}

err = sc.ExecStmt(e.stmt, resp)
ok, err := sc.ExecStmt(stmt, resp)
if err != nil {
c.log.Error("failed to send query %s to server",
zap.String("query", e.stmt), zap.Error(err))
c.log.Error("failed to send query to server",
zap.String("query", stmt), zap.Error(err))
return err
}
if !ok {
return moerr.NewInternalErrorNoCtx("exec error")
}
return nil
}

// handleKillQuery handles the kill query event.
func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error {
cn, err := c.router.SelectByConnID(e.connID)
if err != nil {
c.log.Error("failed to select CN server", zap.Error(err))
c.sendErr(err.Error(), resp)
return err
}
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()

return c.connAndExec(cn, e.stmt, resp)
}

// handleSetVar handles the set variable event.
func (c *clientConn) handleSetVar(e *setVarEvent) error {
c.setVarStmts = append(c.setVarStmts, e.stmt)
return nil
}

// handleSuspendAccountEvent handles the suspend account event.
func (c *clientConn) handleSuspendAccount(e *suspendAccountEvent, resp chan<- []byte) error {
// a temp cn server.
cn := &CNServer{
addr: e.addr,
salt: c.mysqlProto.GetSalt(),
}
if err := c.connAndExec(cn, e.stmt, resp); err != nil {
return err
}

// handle kill connection.
cns, err := c.router.SelectByTenant(e.account)
if err != nil {
return err
}
if len(cns) == 0 {
return nil
}
for _, cn := range cns {
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()

go func(s *CNServer) {
query := fmt.Sprintf("kill connection %d", s.connID)
// No client to receive the result, so pass nil as the third
// parameter to ignore the result.
if err := c.connAndExec(s, query, nil); err != nil {
c.log.Error("failed to send query to server",
zap.String("query", query), zap.Error(err))
return
}
c.log.Info("kill connection on server succeeded",
zap.String("query", query), zap.String("server", s.addr))
}(cn)
}
return nil
}

// handleDropAccountEvent handles the drop account event.
func (c *clientConn) handleDropAccount(e *dropAccountEvent, resp chan<- []byte) error {
se := &suspendAccountEvent{
baseEvent: e.baseEvent,
stmt: e.stmt,
account: e.account,
addr: e.addr,
}
return c.handleSuspendAccount(se, resp)
}

// Close implements the ClientConn interface.
func (c *clientConn) Close() error {
return nil
Expand Down Expand Up @@ -332,12 +399,12 @@ func (c *clientConn) connectToBackend(sendToClient bool) (ServerConn, error) {
}

// Set the label session variable.
if err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil {
if _, err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil {
return nil, err
}
// Set the use defined variables, including session variables and user variables.
for _, stmt := range c.setVarStmts {
if err := sc.ExecStmt(stmt, nil); err != nil {
if _, err := sc.ExecStmt(stmt, nil); err != nil {
return nil, err
}
}
Expand Down
24 changes: 22 additions & 2 deletions pkg/proxy/client_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ func (c *mockClientConn) BuildConnWithServer(_ bool) (ServerConn, error) {
return nil, err
}
// Set the label session variable.
if err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil {
if _, err := sc.ExecStmt(c.labelInfo.genSetVarStmt(), nil); err != nil {
return nil, err
}
// Set the use defined variables, including session variables and user variables.
for _, stmt := range c.setVarStmts {
if err := sc.ExecStmt(stmt, nil); err != nil {
if _, err := sc.ExecStmt(stmt, nil); err != nil {
return nil, err
}
}
Expand All @@ -159,6 +159,26 @@ func (c *mockClientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<-
c.setVarStmts = append(c.setVarStmts, ev.stmt)
sendResp([]byte("ok"), resp)
return nil
case *suspendAccountEvent:
cns, err := c.router.SelectByTenant(ev.account)
if err != nil {
sendResp([]byte(err.Error()), resp)
return err
}
for _, cn := range cns {
sendResp([]byte(cn.addr), resp)
}
return nil
case *dropAccountEvent:
cns, err := c.router.SelectByTenant(ev.account)
if err != nil {
sendResp([]byte(err.Error()), resp)
return err
}
for _, cn := range cns {
sendResp([]byte(cn.addr), resp)
}
return nil
default:
sendResp([]byte("type not supported"), resp)
return moerr.NewInternalErrorNoCtx("type not supported")
Expand Down
36 changes: 34 additions & 2 deletions pkg/proxy/conn_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,17 @@ type connManager struct {

// Map from connection ID to CN server.
connIDServers map[uint32]*CNServer

// Map from Tenant to *CNServer list.
tenantConns map[Tenant]map[*CNServer]struct{}
}

// newConnManager creates a new connManager.
func newConnManager() *connManager {
m := &connManager{
conns: make(map[LabelHash]*connInfo),
connIDServers: make(map[uint32]*CNServer),
tenantConns: make(map[Tenant]map[*CNServer]struct{}),
}
return m
}
Expand Down Expand Up @@ -181,6 +185,14 @@ func (m *connManager) connect(cn *CNServer, t *tunnel) {
}
m.conns[cn.hash].cnTunnels.add(cn.uuid, t)
m.connIDServers[cn.connID] = cn

tenant := cn.reqLabel.Tenant
if tenant != "" {
if m.tenantConns[tenant] == nil {
m.tenantConns[tenant] = make(map[*CNServer]struct{})
}
m.tenantConns[tenant][cn] = struct{}{}
}
}

// disconnect removes a connection from connection manager.
Expand All @@ -193,6 +205,11 @@ func (m *connManager) disconnect(cn *CNServer, t *tunnel) {
}
ci.cnTunnels.del(cn.uuid, t)
delete(m.connIDServers, cn.connID)

tenant := cn.reqLabel.Tenant
if tenant != "" && m.tenantConns[tenant] != nil {
delete(m.tenantConns[tenant], cn)
}
}

// count returns the total connection count.
Expand Down Expand Up @@ -241,8 +258,8 @@ func (m *connManager) getLabelInfo(hash LabelHash) labelInfo {
return ci.label
}

// getCNServer returns a CN server which has the connection ID.
func (m *connManager) getCNServer(connID uint32) *CNServer {
// getCNServerByConnID returns a CN server which has the connection ID.
func (m *connManager) getCNServerByConnID(connID uint32) *CNServer {
m.Lock()
defer m.Unlock()
cn, ok := m.connIDServers[connID]
Expand All @@ -251,3 +268,18 @@ func (m *connManager) getCNServer(connID uint32) *CNServer {
}
return nil
}

// getCNServersByTenant returns a CN server list by tenant.
func (m *connManager) getCNServersByTenant(tenant Tenant) []*CNServer {
m.Lock()
defer m.Unlock()
cns, ok := m.tenantConns[tenant]
if !ok {
return nil
}
cnList := make([]*CNServer, 0, len(cns))
for cn := range cns {
cnList = append(cnList, cn)
}
return cnList
}
Loading

0 comments on commit f4ed57a

Please sign in to comment.