diff --git a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go index 525fe2c7e83d..1f5b3c426578 100644 --- a/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go +++ b/pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go @@ -30,7 +30,7 @@ var tenantUpgEntries = []versions.UpgradeEntry{ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_last_changed timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_last_changed") @@ -38,7 +38,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil @@ -48,7 +48,7 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{ var upg_mo_user_add_password_history = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column password_history text default '[]'", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_history") @@ -56,7 +56,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "TEXT" { + if colInfo.IsExits { return true, nil } return false, nil @@ -66,7 +66,7 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column login_attempts int unsigned default 0", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "login_attempts") @@ -74,7 +74,7 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ return false, err } - if colInfo.ColType == "INT UNSIGNED" { + if colInfo.IsExits { return true, nil } return false, nil @@ -84,14 +84,14 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{ var upg_mo_user_add_lock_time = versions.UpgradeEntry{ Schema: catalog.MO_CATALOG, TableName: catalog.MO_USER, - UpgType: versions.MODIFY_COLUMN, + UpgType: versions.ADD_COLUMN, UpgSql: "alter table mo_catalog.mo_user add column lock_time timestamp default utc_timestamp", CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) { colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "lock_time") if err != nil { return false, err } - if colInfo.ColType == "TIMESTAMP" { + if colInfo.IsExits { return true, nil } return false, nil diff --git a/pkg/frontend/authenticate.go b/pkg/frontend/authenticate.go index d26b0246750e..7381ce454288 100644 --- a/pkg/frontend/authenticate.go +++ b/pkg/frontend/authenticate.go @@ -1186,7 +1186,11 @@ const ( deletePitrFromMoPitrFormat = `delete from mo_catalog.mo_pitr where create_account = %d;` - getPasswordOfUserFormat = `select user_id, authentication_string, default_role, password_last_changed, password_history, status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + getPasswordOfUserFormat = `select user_id, authentication_string, default_role from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getLockInfoOfUserFormat = `select status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;` + + getExpiredTimeOfUserFormat = `select password_last_changed from mo_catalog.mo_user where user_name = "%s" order by user_id;` getPasswordHistotyOfUsrFormat = `select password_history from mo_catalog.mo_user where user_name = "%s";` @@ -1648,6 +1652,14 @@ func getPasswordHistotyOfUserSql(user string) string { return fmt.Sprintf(getPasswordHistotyOfUsrFormat, user) } +func getLockInfoOfUserSql(user string) string { + return fmt.Sprintf(getLockInfoOfUserFormat, user) +} + +func getExpiredTimeOfUserSql(user string) string { + return fmt.Sprintf(getExpiredTimeOfUserFormat, user) +} + func getSqlForUpdatePasswordHistoryOfUser(passwordHistory, user string) string { return fmt.Sprintf(updatePasswordHistoryOfUserFormat, passwordHistory, user) } diff --git a/pkg/frontend/session.go b/pkg/frontend/session.go index 8014f91d7966..19d9125d0052 100644 --- a/pkg/frontend/session.go +++ b/pkg/frontend/session.go @@ -1113,12 +1113,15 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa tenant *TenantInfo err error rsset []ExecResult + userRsset []ExecResult tenantID int64 userID int64 pwd, accountStatus string + psw []byte accountVersion uint64 createVersion string lastChangedTime string + defPwdLife int userStatus string loginAttempts uint64 lockTime string @@ -1213,47 +1216,27 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if err != nil { return nil, err } - rsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser) if err != nil { return nil, err } - if !execResultArrayHasData(rsset) { + if !execResultArrayHasData(userRsset) { return nil, moerr.NewInternalErrorf(tenantCtx, "there is no user %s", tenant.GetUser()) } - userID, err = rsset[0].GetInt64(tenantCtx, 0, 0) + userID, err = userRsset[0].GetInt64(tenantCtx, 0, 0) if err != nil { return nil, err } - pwd, err = rsset[0].GetString(tenantCtx, 0, 1) + pwd, err = userRsset[0].GetString(tenantCtx, 0, 1) if err != nil { return nil, err } //the default_role in the mo_user table. //the default_role is always valid. public or other valid role. - defaultRoleID, err = rsset[0].GetInt64(tenantCtx, 0, 2) - if err != nil { - return nil, err - } - - lastChangedTime, err = rsset[0].GetString(tenantCtx, 0, 3) - if err != nil { - return nil, err - } - - userStatus, err = rsset[0].GetString(tenantCtx, 0, 5) - if err != nil { - return nil, err - } - - loginAttempts, err = rsset[0].GetUint64(tenantCtx, 0, 6) - if err != nil { - return nil, err - } - - lockTime, err = rsset[0].GetString(tenantCtx, 0, 7) + defaultRoleID, err = userRsset[0].GetInt64(tenantCtx, 0, 2) if err != nil { return nil, err } @@ -1335,7 +1318,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa v2.CheckRoleDurationHistogram.Observe(ses.timestampMap[TSCheckRoleEnd].Sub(ses.timestampMap[TSCheckRoleStart]).Seconds()) } //------------------------------------------------------------------------------------------------------------------ - psw, err := GetPassWord(pwd) + psw, err = GetPassWord(pwd) if err != nil { return nil, err } @@ -1345,15 +1328,36 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa return nil, err } - needCheckLock, err = whetherCheckLoginAttempts(tenantCtx, ses) + needCheckLock, err = whetherNeedCheckLoginAttempts(tenantCtx, ses) if err != nil { return nil, err } + if needCheckLock { + // get user status, login_attempts, lock_time + userLockInfoSql := getLockInfoOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userLockInfoSql) + if err != nil { + return nil, err + } + userStatus, err = userRsset[0].GetString(tenantCtx, 0, 0) + if err != nil { + return nil, err + } + + loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 1) + if err != nil { + return nil, err + } + + lockTime, err = userRsset[0].GetString(tenantCtx, 0, 2) + if err != nil { + return nil, err + } + } /* if user lock status is locked check if the lock_time is not expired - if yes, return error */ if needCheckLock && userStatus == userStatusLock { if lockTimeExpired, err = checkLockTimeExpired(tenantCtx, ses, lockTime); err != nil { @@ -1382,15 +1386,31 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa if !isSuperUser(tenant.GetUser()) { // check password expired var expired bool - expired, err = checkPasswordExpired(tenantCtx, ses, lastChangedTime) + + defPwdLife, err = whetherNeedCheckExpired(tenantCtx, ses) if err != nil { return nil, err } - if expired { - ses.getRoutine().setExpired(true) + + if defPwdLife > 0 { + userExpiredSql := getExpiredTimeOfUserSql(tenant.GetUser()) + userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userExpiredSql) + if err != nil { + return nil, err + } + lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 0) + if err != nil { + return nil, err + } + expired, err = checkPasswordExpired(defPwdLife, lastChangedTime) + if err != nil { + return nil, err + } + if expired { + ses.getRoutine().setExpired(true) + } } - // if need check lock if needCheckLock && userStatus == userStatusLock { // if user lock status is locked, update status to unlock if err = setUserUnlock(tenantCtx, tenant.GetUser(), bh); err != nil { @@ -1967,20 +1987,25 @@ func appendTraceField(fields []zap.Field, ctx context.Context) []zap.Field { return fields } -func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime string) (bool, error) { +func whetherNeedCheckExpired(ctx context.Context, ses *Session) (int, error) { var ( defaultPasswordLifetime int err error - lastChanged time.Time ) - // get the default password lifetime defaultPasswordLifetime, err = getPasswordLifetime(ctx, ses) if err != nil { - return false, err + return 0, err } + return defaultPasswordLifetime, nil +} + +func checkPasswordExpired(defPwdLifeTime int, lastChangedTime string) (bool, error) { + var ( + err error + lastChanged time.Time + ) - // if the default password lifetime is 0, the password never expires - if defaultPasswordLifetime == 0 { + if defPwdLifeTime <= 0 { return false, nil } @@ -1992,7 +2017,7 @@ func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime str // get the current time as utc time now := time.Now().UTC() - if lastChanged.AddDate(0, 0, defaultPasswordLifetime).Before(now) { + if lastChanged.AddDate(0, 0, defPwdLifeTime).Before(now) { return true, nil } @@ -2069,13 +2094,13 @@ func getLoginMaxDelay(ctx context.Context, ses *Session) (int64, error) { return delay, nil } -func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) { +func whetherNeedCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) { var ( - loginTimes int64 + loginMaxTimes int64 err error loginMaxDelay int64 ) - loginTimes, err = getLoginAttempts(ctx, ses) + loginMaxTimes, err = getLoginAttempts(ctx, ses) if err != nil { return false, err } @@ -2084,7 +2109,7 @@ func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) if err != nil { return false, err } - return loginTimes > 0 && loginMaxDelay > 0, nil + return loginMaxTimes > 0 && loginMaxDelay > 0, nil } func setUserUnlock(ctx context.Context, userName string, bh BackgroundExec) error { diff --git a/pkg/frontend/session_test.go b/pkg/frontend/session_test.go index 0964b774acb2..f008451696a7 100644 --- a/pkg/frontend/session_test.go +++ b/pkg/frontend/session_test.go @@ -616,18 +616,18 @@ func TestCheckPasswordExpired(t *testing.T) { ses.SetTenantInfo(tenant) // password never expires - expired, err := checkPasswordExpired(ctx, ses, "2022-01-01 00:00:00") + expired, err := checkPasswordExpired(0, "2022-01-01 00:00:00") assert.NoError(t, err) assert.False(t, expired) // password not expires ses.gSysVars.Set(DefaultPasswordLifetime, int64(30)) - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.False(t, expired) // password not expires - expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) + expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05")) assert.NoError(t, err) assert.True(t, expired) @@ -651,7 +651,7 @@ func TestCheckPasswordExpired(t *testing.T) { // getPasswordLifetime error ses.gSysVars.Set(DefaultPasswordLifetime, int64(-1)) - _, err = checkPasswordExpired(ctx, ses, "1") + _, err = checkPasswordExpired(1, "1") assert.Error(t, err) assert.True(t, expired) }