From e5467b3819f46c4b770d068b12f5adf244061e15 Mon Sep 17 00:00:00 2001 From: Kai Cao Date: Wed, 6 Nov 2024 16:44:57 +0800 Subject: [PATCH 1/2] fix ut (#19836) Fix ut Approved by: @daviszhen, @sukki37 --- pkg/cdc/sinker_test.go | 46 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/pkg/cdc/sinker_test.go b/pkg/cdc/sinker_test.go index 9ad449d20c9b..5eb4a0431074 100644 --- a/pkg/cdc/sinker_test.go +++ b/pkg/cdc/sinker_test.go @@ -487,9 +487,13 @@ func Test_mysqlSinker_Sink(t *testing.T) { ar := NewCdcActiveRoutine() - sinker := NewMysqlSinker(sink, dbTblInfo, watermarkUpdater, tableDef, ar) - go sinker.Run(ctx, ar) - defer func() { sinker.Close() }() + s := NewMysqlSinker(sink, dbTblInfo, watermarkUpdater, tableDef, ar) + go s.Run(ctx, ar) + defer func() { + // call dummy to guarantee sqls has been sent, then close + s.SendDummy() + s.Close() + }() packerPool := fileservice.NewPool( 128, @@ -513,14 +517,14 @@ func Test_mysqlSinker_Sink(t *testing.T) { ckpBat.Vecs[1] = testutil.MakeInt32Vector([]int32{1, 2, 3}, nil) ckpBat.SetRowCount(3) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ outputTyp: OutputTypeSnapshot, fromTs: t0, toTs: t1, checkpointBat: ckpBat, }) assert.NoError(t, err) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ noMoreData: true, fromTs: t0, toTs: t1, @@ -542,7 +546,7 @@ func Test_mysqlSinker_Sink(t *testing.T) { deleteBat.SetRowCount(1) deleteAtomicBat.Append(packer, deleteBat, 1, 0) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ outputTyp: OutputTypeTail, fromTs: t1, toTs: t2, @@ -551,7 +555,7 @@ func Test_mysqlSinker_Sink(t *testing.T) { }) assert.NoError(t, err) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ outputTyp: OutputTypeTail, fromTs: t1, toTs: t2, @@ -560,7 +564,7 @@ func Test_mysqlSinker_Sink(t *testing.T) { }) assert.NoError(t, err) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ outputTyp: OutputTypeTail, fromTs: t1, toTs: t2, @@ -569,7 +573,7 @@ func Test_mysqlSinker_Sink(t *testing.T) { }) assert.NoError(t, err) - sinker.Sink(ctx, &DecoderOutput{ + s.Sink(ctx, &DecoderOutput{ noMoreData: true, fromTs: t1, toTs: t2, @@ -617,7 +621,11 @@ func Test_mysqlSinker_Sink_NoMoreData(t *testing.T) { s.preSqlBufLen = 128 s.sqlBufSendCh = make(chan []byte) go s.Run(ctx, ar) - defer func() { s.Close() }() + defer func() { + // call dummy to guarantee sqls has been sent, then close + s.SendDummy() + s.Close() + }() s.Sink(ctx, &DecoderOutput{ noMoreData: true, @@ -877,7 +885,7 @@ func Test_mysqlSinker_SendBeginCommitRollback(t *testing.T) { mock.ExpectRollback() ar := NewCdcActiveRoutine() - sinker := &mysqlSinker{ + s := &mysqlSinker{ mysql: &mysqlSink{ retryTimes: 3, retryDuration: 3 * time.Second, @@ -886,17 +894,21 @@ func Test_mysqlSinker_SendBeginCommitRollback(t *testing.T) { ar: ar, sqlBufSendCh: make(chan []byte), } - go sinker.Run(context.Background(), ar) - defer func() { sinker.Close() }() + go s.Run(context.Background(), ar) + defer func() { + // call dummy to guarantee sqls has been sent, then close + s.SendDummy() + s.Close() + }() - sinker.SendBegin() + s.SendBegin() assert.NoError(t, err) - sinker.SendCommit() + s.SendCommit() assert.NoError(t, err) - sinker.SendBegin() + s.SendBegin() assert.NoError(t, err) - sinker.SendRollback() + s.SendRollback() assert.NoError(t, err) } From 4cfa7013f7a36d736736a10b0c5a1a6a80811ef6 Mon Sep 17 00:00:00 2001 From: YANGGMM Date: Wed, 6 Nov 2024 18:36:48 +0800 Subject: [PATCH 2/2] check user after open check switch (#19830) check user after open check switch Approved by: @qingxinhome, @sukki37 --- .../versions/v2_0_1/tenant_upgrade_list.go | 16 +-- pkg/frontend/authenticate.go | 14 ++- pkg/frontend/session.go | 101 +++++++++++------- pkg/frontend/session_test.go | 8 +- 4 files changed, 88 insertions(+), 51 deletions(-) diff --git a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go index 525fe2c7e83d..1f5b3c426578 100644 --- a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go +++ b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go @@ -30,7 +30,7 @@ var tenantUpgEntries = []versions.UpgradeEntry{ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_last_changed timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_last_changed") @@ -38,7 +38,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil @@ -48,7 +48,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ var upg_mo_user_add_password_history = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_history text default '[]'", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_history") @@ -56,7 +56,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TEXT" { + if colInfo.IsExits { return true, nil } return false, nil @@ -66,7 +66,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column login_attempts int unsigned default 0", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "login_attempts") @@ -74,7 +74,7 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "INT UNSIGNED" { + if colInfo.IsExits { return true, nil } return false, nil @@ -84,14 +84,14 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ var upg_mo_user_add_lock_time = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column lock_time timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "lock_time") if err != nil { return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil diff --git a/pkg/frontend/authenticate.go b/pkg/frontend/authenticate.go index 319f90c34939..a1eff71020e3 100644 --- a/pkg/frontend/authenticate.go +++ b/pkg/frontend/authenticate.go @@ -1185,7 +1185,11 @@ const ( deletePitrFromMoPitrFormat = `delete from mo_catalog.mo_pitr where create_account = %d;` - getPasswordOfUserFormat = `select user_id, authentication_string, default_role, password_last_changed, password_history, status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + getPasswordOfUserFormat = `select user_id, authentication_string, default_role from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getLockInfoOfUserFormat = `select status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getExpiredTimeOfUserFormat = `select password_last_changed from mo_catalog.mo_user where user_name = "%s" order by user_id;` getPasswordHistotyOfUsrFormat = `select password_history from mo_catalog.mo_user where user_name = "%s";` @@ -1647,6 +1651,14 @@ func getPasswordHistotyOfUserSql(user string) string { return fmt.Sprintf(getPasswordHistotyOfUsrFormat, user) } +func getLockInfoOfUserSql(user string) string { + return fmt.Sprintf(getLockInfoOfUserFormat, user) +} + +func getExpiredTimeOfUserSql(user string) string { + return fmt.Sprintf(getExpiredTimeOfUserFormat, user) +} + func getSqlForUpdatePasswordHistoryOfUser(passwordHistory, user string) string { return fmt.Sprintf(updatePasswordHistoryOfUserFormat, passwordHistory, user) } diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index 8014f91d7966..97a00e4c4fef 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1113,12 +1113,15 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa tenant *TenantInfo err error rsset []ExecResult + userRsset []ExecResult tenantID int64 userID int64 pwd, accountStatus string + psw []byte accountVersion uint64 createVersion string lastChangedTime string + defPwdLife int userStatus string loginAttempts uint64 lockTime string @@ -1213,47 +1216,27 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if err != nil { return nil, err } - rsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) if err != nil { return nil, err } - if !execResultArrayHasData(rsset) { + if !execResultArrayHasData(userRsset) { return nil, moerr.NewInternalErrorf(tenantCtx, "there is no user %s", tenant.GetUser()) } - userID, err = rsset[0].GetInt64(tenantCtx, 0, 0) + userID, err = userRsset[0].GetInt64(tenantCtx, 0, 0) if err != nil { return nil, err } - pwd, err = rsset[0].GetString(tenantCtx, 0, 1) + pwd, err = userRsset[0].GetString(tenantCtx, 0, 1) if err != nil { return nil, err } //the default_role in the mo_user table. //the default_role is always valid. public or other valid role. - defaultRoleID, err = rsset[0].GetInt64(tenantCtx, 0, 2) - if err != nil { - return nil, err - } - - lastChangedTime, err = rsset[0].GetString(tenantCtx, 0, 3) - if err != nil { - return nil, err - } - - userStatus, err = rsset[0].GetString(tenantCtx, 0, 5) - if err != nil { - return nil, err - } - - loginAttempts, err = rsset[0].GetUint64(tenantCtx, 0, 6) - if err != nil { - return nil, err - } - - lockTime, err = rsset[0].GetString(tenantCtx, 0, 7) + defaultRoleID, err = userRsset[0].GetInt64(tenantCtx, 0, 2) if err != nil { return nil, err } @@ -1335,7 +1318,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa v2.CheckRoleDurationHistogram.Observe(ses.timestampMap[TSCheckRoleEnd].Sub(ses.timestampMap[TSCheckRoleStart]).Seconds()) } //------------------------------------------------------------------------------------------------------------------ - psw, err := GetPassWord(pwd) + psw, err = GetPassWord(pwd) if err != nil { return nil, err } @@ -1349,11 +1332,32 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if err != nil { return nil, err } + if needCheckLock { + // get user status, login_attempts, lock_time + userLockInfoSql := getLockInfoOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userLockInfoSql) + if err != nil { + return nil, err + } + userStatus, err = userRsset[0].GetString(tenantCtx, 0, 0) + if err != nil { + return nil, err + } + + loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 1) + if err != nil { + return nil, err + } + + lockTime, err = userRsset[0].GetString(tenantCtx, 0, 2) + if err != nil { + return nil, err + } + } /* if user lock status is locked check if the lock_time is not expired - if yes, return error */ if needCheckLock && userStatus == userStatusLock { if lockTimeExpired, err = checkLockTimeExpired(tenantCtx, ses, lockTime); err != nil { @@ -1382,15 +1386,31 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if !isSuperUser(tenant.GetUser()) { // check password expired var expired bool - expired, err = checkPasswordExpired(tenantCtx, ses, lastChangedTime) + + defPwdLife, err = whetherNeedCheckExpired(tenantCtx, ses) if err != nil { return nil, err } - if expired { - ses.getRoutine().setExpired(true) + + if defPwdLife > 0 { + userExpiredSql := getExpiredTimeOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userExpiredSql) + if err != nil { + return nil, err + } + lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 0) + if err != nil { + return nil, err + } + expired, err = checkPasswordExpired(defPwdLife, lastChangedTime) + if err != nil { + return nil, err + } + if expired { + ses.getRoutine().setExpired(true) + } } - // if need check lock if needCheckLock && userStatus == userStatusLock { // if user lock status is locked, update status to unlock if err = setUserUnlock(tenantCtx, tenant.GetUser(), bh); err != nil { @@ -1967,20 +1987,25 @@ func appendTraceField(fields []zap.Field, ctx context.Context) []zap.Field { return fields } -func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime string) (bool, error) { +func whetherNeedCheckExpired(ctx context.Context, ses *Session) (int, error) { var ( defaultPasswordLifetime int err error - lastChanged time.Time ) - // get the default password lifetime defaultPasswordLifetime, err = getPasswordLifetime(ctx, ses) if err != nil { - return false, err + return 0, err } + return defaultPasswordLifetime, nil +} + +func checkPasswordExpired(defPwdLifeTime int, lastChangedTime string) (bool, error) { + var ( + err error + lastChanged time.Time + ) - // if the default password lifetime is 0, the password never expires - if defaultPasswordLifetime == 0 { + if defPwdLifeTime <= 0 { return false, nil } @@ -1992,7 +2017,7 @@ func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime str // get the current time as utc time now := time.Now().UTC() - if lastChanged.AddDate(0, 0, defaultPasswordLifetime).Before(now) { + if lastChanged.AddDate(0, 0, defPwdLifeTime).Before(now) { return true, nil } diff --git a/pkg/frontend/session_test.go b/pkg/frontend/session_test.go index 0964b774acb2..f008451696a7 100644 --- a/pkg/frontend/session_test.go +++ b/pkg/frontend/session_test.go @@ -616,18 +616,18 @@ func TestCheckPasswordExpired(t *testing.T) { ses.SetTenantInfo(tenant) // password never expires - expired, err := checkPasswordExpired(ctx, ses, "2022-01-01 00:00:00") + expired, err := checkPasswordExpired(0, "2022-01-01 00:00:00") assert.NoError(t, err) assert.False(t, expired) // password not expires ses.gSysVars.Set(DefaultPasswordLifetime, int64(30)) - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.False(t, expired) // password not expires - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.True(t, expired) @@ -651,7 +651,7 @@ func TestCheckPasswordExpired(t *testing.T) { // getPasswordLifetime error ses.gSysVars.Set(DefaultPasswordLifetime, int64(-1)) - _, err = checkPasswordExpired(ctx, ses, "1") + _, err = checkPasswordExpired(1, "1") assert.Error(t, err) assert.True(t, expired) }