Skip to content

Commit

Permalink
[feat]: kill connected session if account suspended or dropped. (#9338)
Browse files Browse the repository at this point in the history
  • Loading branch information
volgariver6 authored May 14, 2023
1 parent b7ceaaa commit 90f59ec
Show file tree
Hide file tree
Showing 15 changed files with 893 additions and 235 deletions.
134 changes: 104 additions & 30 deletions pkg/proxy/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package proxy

import (
"context"
"fmt"
"net"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -225,63 +226,136 @@ func (c *clientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []by
return c.handleKillQuery(ev, resp)
case *setVarEvent:
return c.handleSetVar(ev)
case *suspendAccountEvent:
return c.handleSuspendAccount(ev, resp)
case *dropAccountEvent:
return c.handleDropAccount(ev, resp)
default:
}
return nil
}

// handleKillQuery handles the kill query event.
func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error {
sendErr := func(errMsg string) {
fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR]
payload := c.mysqlProto.MakeErrPayload(
fail.ErrorCode, fail.SqlStates[0], errMsg)
r := &frontend.Packet{
Length: 0,
SequenceID: 1,
Payload: payload,
}
sendResp(packetToBytes(r), resp)
}

cn, err := c.router.SelectByConnID(e.connID)
if err != nil {
c.log.Error("failed to select CN server", zap.Error(err))
sendErr(err.Error())
return err
func (c *clientConn) sendErr(errMsg string, resp chan<- []byte) {
fail := moerr.MysqlErrorMsgRefer[moerr.ER_ACCESS_DENIED_ERROR]
payload := c.mysqlProto.MakeErrPayload(
fail.ErrorCode, fail.SqlStates[0], errMsg)
r := &frontend.Packet{
Length: 0,
SequenceID: 1,
Payload: payload,
}
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()
sendResp(packetToBytes(r), resp)
}

func (c *clientConn) connAndExec(cn *CNServer, stmt string, resp chan<- []byte) error {
sc, r, err := c.router.Connect(cn, c.handshakePack, c.tun)
if err != nil {
c.log.Error("failed to connect to backend server", zap.Error(err))
sendErr(err.Error())
if resp != nil {
c.sendErr(err.Error(), resp)
}
return err
}
defer func() { _ = sc.Close() }()

if !isOKPacket(r) {
c.log.Error("failed to connect to cn to handle kill query event",
zap.String("query", e.stmt), zap.String("error", string(r)))
sendResp(r, resp)
c.log.Error("failed to connect to cn to handle event",
zap.String("query", stmt), zap.String("error", string(r)))
if resp != nil {
sendResp(r, resp)
}
return moerr.NewInternalErrorNoCtx("access error")
}

err = sc.ExecStmt(e.stmt, resp)
ok, err := sc.ExecStmt(stmt, resp)
if err != nil {
c.log.Error("failed to send query %s to server",
zap.String("query", e.stmt), zap.Error(err))
c.log.Error("failed to send query to server",
zap.String("query", stmt), zap.Error(err))
return err
}
if !ok {
return moerr.NewInternalErrorNoCtx("exec error")
}
return nil
}

// handleKillQuery handles the kill query event.
func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error {
cn, err := c.router.SelectByConnID(e.connID)
if err != nil {
c.log.Error("failed to select CN server", zap.Error(err))
c.sendErr(err.Error(), resp)
return err
}
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()

return c.connAndExec(cn, e.stmt, resp)
}

// handleSetVar handles the set variable event.
func (c *clientConn) handleSetVar(e *setVarEvent) error {
c.setVarStmts = append(c.setVarStmts, e.stmt)
return nil
}

// handleSuspendAccountEvent handles the suspend account event.
func (c *clientConn) handleSuspendAccount(e *suspendAccountEvent, resp chan<- []byte) error {
// a temp cn server.
cn := &CNServer{
addr: e.addr,
salt: c.mysqlProto.GetSalt(),
}
csp, _ := c.tun.getPipes()
if csp.inTxn() {
// TODO(volgariver6): this is for the compatibility with the case that we are
// now in a transaction. suspend or drop operation can not be activated within a
// transaction.
c.sendErr(moerr.NewInternalErrorNoCtx("administrative command is unsupported in transactions").Error(), resp)
}
if err := c.connAndExec(cn, e.stmt, resp); err != nil {
return err
}

// handle kill connection.
cns, err := c.router.SelectByTenant(e.account)
if err != nil {
return err
}
if len(cns) == 0 {
return nil
}
for _, cn := range cns {
// Before connect to backend server, update the salt.
cn.salt = c.mysqlProto.GetSalt()

go func(s *CNServer) {
query := fmt.Sprintf("kill connection %d", s.connID)
// No client to receive the result, so pass nil as the third
// parameter to ignore the result.
if err := c.connAndExec(s, query, nil); err != nil {
c.log.Error("failed to send query to server",
zap.String("query", query), zap.Error(err))
return
}
c.log.Info("kill connection on server succeeded",
zap.String("query", query), zap.String("server", s.addr))
}(cn)
}
return nil
}

// handleDropAccountEvent handles the drop account event.
func (c *clientConn) handleDropAccount(e *dropAccountEvent, resp chan<- []byte) error {
se := &suspendAccountEvent{
baseEvent: e.baseEvent,
stmt: e.stmt,
account: e.account,
addr: e.addr,
}
return c.handleSuspendAccount(se, resp)
}

// Close implements the ClientConn interface.
func (c *clientConn) Close() error {
return nil
Expand Down Expand Up @@ -323,12 +397,12 @@ func (c *clientConn) connectToBackend(sendToClient bool) (ServerConn, error) {
}

// Set the label session variable.
if err := sc.ExecStmt(c.clientInfo.genSetVarStmt(), nil); err != nil {
if _, err := sc.ExecStmt(c.clientInfo.genSetVarStmt(), nil); err != nil {
return nil, err
}
// Set the use defined variables, including session variables and user variables.
for _, stmt := range c.setVarStmts {
if err := sc.ExecStmt(stmt, nil); err != nil {
if _, err := sc.ExecStmt(stmt, nil); err != nil {
return nil, err
}
}
Expand Down
24 changes: 22 additions & 2 deletions pkg/proxy/client_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ func (c *mockClientConn) BuildConnWithServer(_ bool) (ServerConn, error) {
return nil, err
}
// Set the label session variable.
if err := sc.ExecStmt(c.clientInfo.genSetVarStmt(), nil); err != nil {
if _, err := sc.ExecStmt(c.clientInfo.genSetVarStmt(), nil); err != nil {
return nil, err
}
// Set the use defined variables, including session variables and user variables.
for _, stmt := range c.setVarStmts {
if err := sc.ExecStmt(stmt, nil); err != nil {
if _, err := sc.ExecStmt(stmt, nil); err != nil {
return nil, err
}
}
Expand All @@ -159,6 +159,26 @@ func (c *mockClientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<-
c.setVarStmts = append(c.setVarStmts, ev.stmt)
sendResp([]byte("ok"), resp)
return nil
case *suspendAccountEvent:
cns, err := c.router.SelectByTenant(ev.account)
if err != nil {
sendResp([]byte(err.Error()), resp)
return err
}
for _, cn := range cns {
sendResp([]byte(cn.addr), resp)
}
return nil
case *dropAccountEvent:
cns, err := c.router.SelectByTenant(ev.account)
if err != nil {
sendResp([]byte(err.Error()), resp)
return err
}
for _, cn := range cns {
sendResp([]byte(cn.addr), resp)
}
return nil
default:
sendResp([]byte("type not supported"), resp)
return moerr.NewInternalErrorNoCtx("type not supported")
Expand Down
36 changes: 34 additions & 2 deletions pkg/proxy/conn_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,17 @@ type connManager struct {

// Map from connection ID to CN server.
connIDServers map[uint32]*CNServer

// Map from Tenant to *CNServer list.
tenantConns map[Tenant]map[*CNServer]struct{}
}

// newConnManager creates a new connManager.
func newConnManager() *connManager {
m := &connManager{
conns: make(map[LabelHash]*connInfo),
connIDServers: make(map[uint32]*CNServer),
tenantConns: make(map[Tenant]map[*CNServer]struct{}),
}
return m
}
Expand Down Expand Up @@ -181,6 +185,14 @@ func (m *connManager) connect(cn *CNServer, t *tunnel) {
}
m.conns[cn.hash].cnTunnels.add(cn.uuid, t)
m.connIDServers[cn.connID] = cn

tenant := cn.reqLabel.Tenant
if tenant != "" {
if m.tenantConns[tenant] == nil {
m.tenantConns[tenant] = make(map[*CNServer]struct{})
}
m.tenantConns[tenant][cn] = struct{}{}
}
}

// disconnect removes a connection from connection manager.
Expand All @@ -193,6 +205,11 @@ func (m *connManager) disconnect(cn *CNServer, t *tunnel) {
}
ci.cnTunnels.del(cn.uuid, t)
delete(m.connIDServers, cn.connID)

tenant := cn.reqLabel.Tenant
if tenant != "" && m.tenantConns[tenant] != nil {
delete(m.tenantConns[tenant], cn)
}
}

// count returns the total connection count.
Expand Down Expand Up @@ -241,8 +258,8 @@ func (m *connManager) getLabelInfo(hash LabelHash) labelInfo {
return ci.label
}

// getCNServer returns a CN server which has the connection ID.
func (m *connManager) getCNServer(connID uint32) *CNServer {
// getCNServerByConnID returns a CN server which has the connection ID.
func (m *connManager) getCNServerByConnID(connID uint32) *CNServer {
m.Lock()
defer m.Unlock()
cn, ok := m.connIDServers[connID]
Expand All @@ -251,3 +268,18 @@ func (m *connManager) getCNServer(connID uint32) *CNServer {
}
return nil
}

// getCNServersByTenant returns a CN server list by tenant.
func (m *connManager) getCNServersByTenant(tenant Tenant) []*CNServer {
m.Lock()
defer m.Unlock()
cns, ok := m.tenantConns[tenant]
if !ok {
return nil
}
cnList := make([]*CNServer, 0, len(cns))
for cn := range cns {
cnList = append(cnList, cn)
}
return cnList
}
Loading

0 comments on commit 90f59ec

Please sign in to comment.