From cb6b2ac069ed8ff65a1e59c616cfdd05f13ad118 Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Wed, 6 Nov 2024 14:54:32 +0800 Subject: [PATCH 1/6] check user logic after open check switch --- .../versions/v2_0_1/tenant_upgrade_list.go | 16 ++-- pkg/frontend/session.go | 87 +++++++++++-------- pkg/frontend/session_test.go | 8 +- 3 files changed, 63 insertions(+), 48 deletions(-) diff --git a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go index 525fe2c7e83d..1f5b3c426578 100644 --- a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go +++ b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go @@ -30,7 +30,7 @@ var tenantUpgEntries = []versions.UpgradeEntry{ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_last_changed timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_last_changed") @@ -38,7 +38,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil @@ -48,7 +48,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ var upg_mo_user_add_password_history = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_history text default '[]'", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_history") @@ -56,7 +56,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TEXT" { + if colInfo.IsExits { return true, nil } return false, nil @@ -66,7 +66,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column login_attempts int unsigned default 0", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "login_attempts") @@ -74,7 +74,7 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "INT UNSIGNED" { + if colInfo.IsExits { return true, nil } return false, nil @@ -84,14 +84,14 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ var upg_mo_user_add_lock_time = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column lock_time timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "lock_time") if err != nil { return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index 8014f91d7966..e6f4203cea7b 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1113,12 +1113,14 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa tenant *TenantInfo err error rsset []ExecResult + userRsset []ExecResult tenantID int64 userID int64 pwd, accountStatus string accountVersion uint64 createVersion string lastChangedTime string + defPwdLife int userStatus string loginAttempts uint64 lockTime string @@ -1213,47 +1215,27 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if err != nil { return nil, err } - rsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) if err != nil { return nil, err } - if !execResultArrayHasData(rsset) { + if !execResultArrayHasData(userRsset) { return nil, moerr.NewInternalErrorf(tenantCtx, "there is no user %s", tenant.GetUser()) } - userID, err = rsset[0].GetInt64(tenantCtx, 0, 0) + userID, err = userRsset[0].GetInt64(tenantCtx, 0, 0) if err != nil { return nil, err } - pwd, err = rsset[0].GetString(tenantCtx, 0, 1) + pwd, err = userRsset[0].GetString(tenantCtx, 0, 1) if err != nil { return nil, err } //the default_role in the mo_user table. //the default_role is always valid. public or other valid role. - defaultRoleID, err = rsset[0].GetInt64(tenantCtx, 0, 2) - if err != nil { - return nil, err - } - - lastChangedTime, err = rsset[0].GetString(tenantCtx, 0, 3) - if err != nil { - return nil, err - } - - userStatus, err = rsset[0].GetString(tenantCtx, 0, 5) - if err != nil { - return nil, err - } - - loginAttempts, err = rsset[0].GetUint64(tenantCtx, 0, 6) - if err != nil { - return nil, err - } - - lockTime, err = rsset[0].GetString(tenantCtx, 0, 7) + defaultRoleID, err = userRsset[0].GetInt64(tenantCtx, 0, 2) if err != nil { return nil, err } @@ -1349,6 +1331,22 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if err != nil { return nil, err } + if needCheckLock { + userStatus, err = userRsset[0].GetString(tenantCtx, 0, 5) + if err != nil { + return nil, err + } + + loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 6) + if err != nil { + return nil, err + } + + lockTime, err = userRsset[0].GetString(tenantCtx, 0, 7) + if err != nil { + return nil, err + } + } /* if user lock status is locked @@ -1382,12 +1380,24 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if !isSuperUser(tenant.GetUser()) { // check password expired var expired bool - expired, err = checkPasswordExpired(tenantCtx, ses, lastChangedTime) + + defPwdLife, err = whetherNeedCheckExpired(tenantCtx, ses) if err != nil { return nil, err } - if expired { - ses.getRoutine().setExpired(true) + + if defPwdLife > 0 { + lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 3) + if err != nil { + return nil, err + } + expired, err = checkPasswordExpired(defPwdLife, lastChangedTime) + if err != nil { + return nil, err + } + if expired { + ses.getRoutine().setExpired(true) + } } // if need check lock @@ -1400,7 +1410,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa } } else { - if !isSuperUser(tenant.GetUser()) && needCheckLock { + if needCheckLock && !isSuperUser(tenant.GetUser()) { if userStatus != userStatusLock { loginAttempts++ if maxLoginAttempts, err = getLoginAttempts(tenantCtx, ses); err != nil { @@ -1967,20 +1977,25 @@ func appendTraceField(fields []zap.Field, ctx context.Context) []zap.Field { return fields } -func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime string) (bool, error) { +func whetherNeedCheckExpired(ctx context.Context, ses *Session) (int, error) { var ( defaultPasswordLifetime int err error - lastChanged time.Time ) - // get the default password lifetime defaultPasswordLifetime, err = getPasswordLifetime(ctx, ses) if err != nil { - return false, err + return 0, err } + return defaultPasswordLifetime, nil +} + +func checkPasswordExpired(defPwdLifeTime int, lastChangedTime string) (bool, error) { + var ( + err error + lastChanged time.Time + ) - // if the default password lifetime is 0, the password never expires - if defaultPasswordLifetime == 0 { + if defPwdLifeTime <= 0 { return false, nil } @@ -1992,7 +2007,7 @@ func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime str // get the current time as utc time now := time.Now().UTC() - if lastChanged.AddDate(0, 0, defaultPasswordLifetime).Before(now) { + if lastChanged.AddDate(0, 0, defPwdLifeTime).Before(now) { return true, nil } diff --git a/pkg/frontend/session_test.go b/pkg/frontend/session_test.go index 0964b774acb2..f008451696a7 100644 --- a/pkg/frontend/session_test.go +++ b/pkg/frontend/session_test.go @@ -616,18 +616,18 @@ func TestCheckPasswordExpired(t *testing.T) { ses.SetTenantInfo(tenant) // password never expires - expired, err := checkPasswordExpired(ctx, ses, "2022-01-01 00:00:00") + expired, err := checkPasswordExpired(0, "2022-01-01 00:00:00") assert.NoError(t, err) assert.False(t, expired) // password not expires ses.gSysVars.Set(DefaultPasswordLifetime, int64(30)) - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.False(t, expired) // password not expires - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.True(t, expired) @@ -651,7 +651,7 @@ func TestCheckPasswordExpired(t *testing.T) { // getPasswordLifetime error ses.gSysVars.Set(DefaultPasswordLifetime, int64(-1)) - _, err = checkPasswordExpired(ctx, ses, "1") + _, err = checkPasswordExpired(1, "1") assert.Error(t, err) assert.True(t, expired) } From 50ebedd89fa780f45f2245e6b2993728bd0e4e0d Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Wed, 6 Nov 2024 15:50:38 +0800 Subject: [PATCH 2/6] fix --- pkg/frontend/session.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index e6f4203cea7b..4d567cb5ff09 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1117,6 +1117,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa tenantID int64 userID int64 pwd, accountStatus string + psw []byte accountVersion uint64 createVersion string lastChangedTime string @@ -1317,7 +1318,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa v2.CheckRoleDurationHistogram.Observe(ses.timestampMap[TSCheckRoleEnd].Sub(ses.timestampMap[TSCheckRoleStart]).Seconds()) } //------------------------------------------------------------------------------------------------------------------ - psw, err := GetPassWord(pwd) + psw, err = GetPassWord(pwd) if err != nil { return nil, err } @@ -1410,7 +1411,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa } } else { - if needCheckLock && !isSuperUser(tenant.GetUser()) { + if !isSuperUser(tenant.GetUser()) && needCheckLock { if userStatus != userStatusLock { loginAttempts++ if maxLoginAttempts, err = getLoginAttempts(tenantCtx, ses); err != nil { From d8ebe79025eb7bdf426fee53435570d17243ae1b Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Wed, 6 Nov 2024 16:41:35 +0800 Subject: [PATCH 3/6] fix --- pkg/frontend/authenticate.go | 14 +++++++++++++- pkg/frontend/session.go | 19 +++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/pkg/frontend/authenticate.go b/pkg/frontend/authenticate.go index 9e0d94575606..6435645e10ce 100644 --- a/pkg/frontend/authenticate.go +++ b/pkg/frontend/authenticate.go @@ -1186,7 +1186,11 @@ const ( deletePitrFromMoPitrFormat = `delete from mo_catalog.mo_pitr where create_account = %d;` - getPasswordOfUserFormat = `select user_id, authentication_string, default_role, password_last_changed, password_history, status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + getPasswordOfUserFormat = `select user_id, authentication_string, default_role from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getLockInfoOfUserFormat = `select status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getExpiredTimeOfUserFormat = `select password_last_changed from mo_catalog.mo_user where user_name = "%s" order by user_id;` getPasswordHistotyOfUsrFormat = `select password_history from mo_catalog.mo_user where user_name = "%s";` @@ -1648,6 +1652,14 @@ func getPasswordHistotyOfUserSql(user string) string { return fmt.Sprintf(getPasswordHistotyOfUsrFormat, user) } +func getLockInfoOfUserSql(user string) string { + return fmt.Sprintf(getLockInfoOfUserFormat, user) +} + +func getExpiredTimeOfUserSql(user string) string { + return fmt.Sprintf(getExpiredTimeOfUserFormat, user) +} + func getSqlForUpdatePasswordHistoryOfUser(passwordHistory, user string) string { return fmt.Sprintf(updatePasswordHistoryOfUserFormat, passwordHistory, user) } diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index 4d567cb5ff09..b7161a177501 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1333,17 +1333,23 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa return nil, err } if needCheckLock { - userStatus, err = userRsset[0].GetString(tenantCtx, 0, 5) + // get user status, login_attempts, lock_time + userLockInfoSql := getLockInfoOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userLockInfoSql) + if err != nil { + return nil, err + } + userStatus, err = userRsset[0].GetString(tenantCtx, 0, 0) if err != nil { return nil, err } - loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 6) + loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 1) if err != nil { return nil, err } - lockTime, err = userRsset[0].GetString(tenantCtx, 0, 7) + lockTime, err = userRsset[0].GetString(tenantCtx, 0, 2) if err != nil { return nil, err } @@ -1388,7 +1394,12 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa } if defPwdLife > 0 { - lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 3) + userExpiredSql := getExpiredTimeOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userExpiredSql) + if err != nil { + return nil, err + } + lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 0) if err != nil { return nil, err } From 47b9f620e8a58c6a64a06fa551673187950ca04a Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Wed, 6 Nov 2024 15:55:23 +0800 Subject: [PATCH 4/6] fix --- pkg/frontend/session.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index b7161a177501..97a00e4c4fef 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1358,7 +1358,6 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa /* if user lock status is locked check if the lock_time is not expired - if yes, return error */ if needCheckLock && userStatus == userStatusLock { if lockTimeExpired, err = checkLockTimeExpired(tenantCtx, ses, lockTime); err != nil { @@ -1412,7 +1411,6 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa } } - // if need check lock if needCheckLock && userStatus == userStatusLock { // if user lock status is locked, update status to unlock if err = setUserUnlock(tenantCtx, tenant.GetUser(), bh); err != nil { From 2aae23618f88b28df3a0fe6cd882bc6173e928ba Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Wed, 6 Nov 2024 18:10:11 +0800 Subject: [PATCH 5/6] fix --- pkg/frontend/session.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index 97a00e4c4fef..b61ff396dcc4 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1423,7 +1423,8 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if !isSuperUser(tenant.GetUser()) && needCheckLock { if userStatus != userStatusLock { loginAttempts++ - if maxLoginAttempts, err = getLoginAttempts(tenantCtx, ses); err != nil { + maxLoginAttempts, err = getLoginAttemptMaxTimes(tenantCtx, ses) + if err != nil { return nil, err } if int64(loginAttempts) >= maxLoginAttempts { @@ -2066,7 +2067,7 @@ func checkLockTimeExpired(ctx context.Context, ses *Session, lockTime string) (b return true, nil } -func getLoginAttempts(ctx context.Context, ses *Session) (int64, error) { +func getLoginAttemptMaxTimes(ctx context.Context, ses *Session) (int64, error) { value, err := ses.GetGlobalSysVar(ConnectionControlFailedConnectionsThreshold) if err != nil { return 0, err @@ -2096,11 +2097,11 @@ func getLoginMaxDelay(ctx context.Context, ses *Session) (int64, error) { func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) { var ( - loginTimes int64 + loginMaxTimes int64 err error loginMaxDelay int64 ) - loginTimes, err = getLoginAttempts(ctx, ses) + loginMaxTimes, err = getLoginAttemptMaxTimes(ctx, ses) if err != nil { return false, err } @@ -2109,7 +2110,7 @@ func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) if err != nil { return false, err } - return loginTimes > 0 && loginMaxDelay > 0, nil + return loginMaxTimes > 0 && loginMaxDelay > 0, nil } func setUserUnlock(ctx context.Context, userName string, bh BackgroundExec) error { From df331783e44640631f62e2946013aee689741bea Mon Sep 17 00:00:00 2001 From: YANGGMM <18916107305@163.com> Date: Thu, 7 Nov 2024 09:50:55 +0800 Subject: [PATCH 6/6] fix --- pkg/frontend/session.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index b61ff396dcc4..19d9125d0052 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1328,7 +1328,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa return nil, err } - needCheckLock, err = whetherCheckLoginAttempts(tenantCtx, ses) + needCheckLock, err = whetherNeedCheckLoginAttempts(tenantCtx, ses) if err != nil { return nil, err } @@ -1423,8 +1423,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if !isSuperUser(tenant.GetUser()) && needCheckLock { if userStatus != userStatusLock { loginAttempts++ - maxLoginAttempts, err = getLoginAttemptMaxTimes(tenantCtx, ses) - if err != nil { + if maxLoginAttempts, err = getLoginAttempts(tenantCtx, ses); err != nil { return nil, err } if int64(loginAttempts) >= maxLoginAttempts { @@ -2067,7 +2066,7 @@ func checkLockTimeExpired(ctx context.Context, ses *Session, lockTime string) (b return true, nil } -func getLoginAttemptMaxTimes(ctx context.Context, ses *Session) (int64, error) { +func getLoginAttempts(ctx context.Context, ses *Session) (int64, error) { value, err := ses.GetGlobalSysVar(ConnectionControlFailedConnectionsThreshold) if err != nil { return 0, err @@ -2095,13 +2094,13 @@ func getLoginMaxDelay(ctx context.Context, ses *Session) (int64, error) { return delay, nil } -func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) { +func whetherNeedCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) { var ( loginMaxTimes int64 err error loginMaxDelay int64 ) - loginMaxTimes, err = getLoginAttemptMaxTimes(ctx, ses) + loginMaxTimes, err = getLoginAttempts(ctx, ses) if err != nil { return false, err }