Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check user after open check switch to main #19837

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions pkg/bootstrap/versions/v2_0_1/tenant_upgrade_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ var tenantUpgEntries = []versions.UpgradeEntry{
var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{
Schema: catalog.MO_CATALOG,
TableName: catalog.MO_USER,
UpgType: versions.MODIFY_COLUMN,
UpgType: versions.ADD_COLUMN,
UpgSql: "alter table mo_catalog.mo_user add column password_last_changed timestamp default utc_timestamp",
CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) {
colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_last_changed")
if err != nil {
return false, err
}

if colInfo.ColType == "TIMESTAMP" {
if colInfo.IsExits {
return true, nil
}
return false, nil
Expand All @@ -48,15 +48,15 @@ var upg_mo_user_add_password_last_changed = versions.UpgradeEntry{
var upg_mo_user_add_password_history = versions.UpgradeEntry{
Schema: catalog.MO_CATALOG,
TableName: catalog.MO_USER,
UpgType: versions.MODIFY_COLUMN,
UpgType: versions.ADD_COLUMN,
UpgSql: "alter table mo_catalog.mo_user add column password_history text default '[]'",
CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) {
colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "password_history")
if err != nil {
return false, err
}

if colInfo.ColType == "TEXT" {
if colInfo.IsExits {
return true, nil
}
return false, nil
Expand All @@ -66,15 +66,15 @@ var upg_mo_user_add_password_history = versions.UpgradeEntry{
var upg_mo_user_add_login_attempts = versions.UpgradeEntry{
Schema: catalog.MO_CATALOG,
TableName: catalog.MO_USER,
UpgType: versions.MODIFY_COLUMN,
UpgType: versions.ADD_COLUMN,
UpgSql: "alter table mo_catalog.mo_user add column login_attempts int unsigned default 0",
CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) {
colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "login_attempts")
if err != nil {
return false, err
}

if colInfo.ColType == "INT UNSIGNED" {
if colInfo.IsExits {
return true, nil
}
return false, nil
Expand All @@ -84,14 +84,14 @@ var upg_mo_user_add_login_attempts = versions.UpgradeEntry{
var upg_mo_user_add_lock_time = versions.UpgradeEntry{
Schema: catalog.MO_CATALOG,
TableName: catalog.MO_USER,
UpgType: versions.MODIFY_COLUMN,
UpgType: versions.ADD_COLUMN,
UpgSql: "alter table mo_catalog.mo_user add column lock_time timestamp default utc_timestamp",
CheckFunc: func(txn executor.TxnExecutor, accountId uint32) (bool, error) {
colInfo, err := versions.CheckTableColumn(txn, accountId, catalog.MO_CATALOG, catalog.MO_USER, "lock_time")
if err != nil {
return false, err
}
if colInfo.ColType == "TIMESTAMP" {
if colInfo.IsExits {
return true, nil
}
return false, nil
Expand Down
14 changes: 13 additions & 1 deletion pkg/frontend/authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,11 @@ const (

deletePitrFromMoPitrFormat = `delete from mo_catalog.mo_pitr where create_account = %d;`

getPasswordOfUserFormat = `select user_id, authentication_string, default_role, password_last_changed, password_history, status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;`
getPasswordOfUserFormat = `select user_id, authentication_string, default_role from mo_catalog.mo_user where user_name = "%s" order by user_id;`

getLockInfoOfUserFormat = `select status, login_attempts, lock_time from mo_catalog.mo_user where user_name = "%s" order by user_id;`

getExpiredTimeOfUserFormat = `select password_last_changed from mo_catalog.mo_user where user_name = "%s" order by user_id;`

getPasswordHistotyOfUsrFormat = `select password_history from mo_catalog.mo_user where user_name = "%s";`

Expand Down Expand Up @@ -1648,6 +1652,14 @@ func getPasswordHistotyOfUserSql(user string) string {
return fmt.Sprintf(getPasswordHistotyOfUsrFormat, user)
}

func getLockInfoOfUserSql(user string) string {
return fmt.Sprintf(getLockInfoOfUserFormat, user)
}

func getExpiredTimeOfUserSql(user string) string {
return fmt.Sprintf(getExpiredTimeOfUserFormat, user)
}

func getSqlForUpdatePasswordHistoryOfUser(passwordHistory, user string) string {
return fmt.Sprintf(updatePasswordHistoryOfUserFormat, passwordHistory, user)
}
Expand Down
112 changes: 69 additions & 43 deletions pkg/frontend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -1113,12 +1113,15 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
tenant *TenantInfo
err error
rsset []ExecResult
userRsset []ExecResult
tenantID int64
userID int64
pwd, accountStatus string
psw []byte
accountVersion uint64
createVersion string
lastChangedTime string
defPwdLife int
userStatus string
loginAttempts uint64
lockTime string
Expand Down Expand Up @@ -1213,47 +1216,27 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
if err != nil {
return nil, err
}
rsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser)
userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, sqlForPasswordOfUser)
if err != nil {
return nil, err
}
if !execResultArrayHasData(rsset) {
if !execResultArrayHasData(userRsset) {
return nil, moerr.NewInternalErrorf(tenantCtx, "there is no user %s", tenant.GetUser())
}

userID, err = rsset[0].GetInt64(tenantCtx, 0, 0)
userID, err = userRsset[0].GetInt64(tenantCtx, 0, 0)
if err != nil {
return nil, err
}

pwd, err = rsset[0].GetString(tenantCtx, 0, 1)
pwd, err = userRsset[0].GetString(tenantCtx, 0, 1)
if err != nil {
return nil, err
}

//the default_role in the mo_user table.
//the default_role is always valid. public or other valid role.
defaultRoleID, err = rsset[0].GetInt64(tenantCtx, 0, 2)
if err != nil {
return nil, err
}

lastChangedTime, err = rsset[0].GetString(tenantCtx, 0, 3)
if err != nil {
return nil, err
}

userStatus, err = rsset[0].GetString(tenantCtx, 0, 5)
if err != nil {
return nil, err
}

loginAttempts, err = rsset[0].GetUint64(tenantCtx, 0, 6)
if err != nil {
return nil, err
}

lockTime, err = rsset[0].GetString(tenantCtx, 0, 7)
defaultRoleID, err = userRsset[0].GetInt64(tenantCtx, 0, 2)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1335,7 +1318,7 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
v2.CheckRoleDurationHistogram.Observe(ses.timestampMap[TSCheckRoleEnd].Sub(ses.timestampMap[TSCheckRoleStart]).Seconds())
}
//------------------------------------------------------------------------------------------------------------------
psw, err := GetPassWord(pwd)
psw, err = GetPassWord(pwd)
if err != nil {
return nil, err
}
Expand All @@ -1349,11 +1332,32 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
if err != nil {
return nil, err
}
if needCheckLock {
// get user status, login_attempts, lock_time
userLockInfoSql := getLockInfoOfUserSql(tenant.GetUser())
userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userLockInfoSql)
if err != nil {
return nil, err
}
userStatus, err = userRsset[0].GetString(tenantCtx, 0, 0)
if err != nil {
return nil, err
}

loginAttempts, err = userRsset[0].GetUint64(tenantCtx, 0, 1)
if err != nil {
return nil, err
}

lockTime, err = userRsset[0].GetString(tenantCtx, 0, 2)
if err != nil {
return nil, err
}
}

/*
if user lock status is locked
check if the lock_time is not expired
if yes, return error
*/
if needCheckLock && userStatus == userStatusLock {
if lockTimeExpired, err = checkLockTimeExpired(tenantCtx, ses, lockTime); err != nil {
Expand Down Expand Up @@ -1382,15 +1386,31 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
if !isSuperUser(tenant.GetUser()) {
// check password expired
var expired bool
expired, err = checkPasswordExpired(tenantCtx, ses, lastChangedTime)

defPwdLife, err = whetherNeedCheckExpired(tenantCtx, ses)
if err != nil {
return nil, err
}
if expired {
ses.getRoutine().setExpired(true)

if defPwdLife > 0 {
userExpiredSql := getExpiredTimeOfUserSql(tenant.GetUser())
userRsset, err = executeSQLInBackgroundSession(tenantCtx, ses, userExpiredSql)
if err != nil {
return nil, err
}
lastChangedTime, err = userRsset[0].GetString(tenantCtx, 0, 0)
if err != nil {
return nil, err
}
expired, err = checkPasswordExpired(defPwdLife, lastChangedTime)
if err != nil {
return nil, err
}
if expired {
ses.getRoutine().setExpired(true)
}
}

// if need check lock
if needCheckLock && userStatus == userStatusLock {
// if user lock status is locked, update status to unlock
if err = setUserUnlock(tenantCtx, tenant.GetUser(), bh); err != nil {
Expand All @@ -1403,7 +1423,8 @@ func (ses *Session) AuthenticateUser(ctx context.Context, userInput string, dbNa
if !isSuperUser(tenant.GetUser()) && needCheckLock {
if userStatus != userStatusLock {
loginAttempts++
if maxLoginAttempts, err = getLoginAttempts(tenantCtx, ses); err != nil {
maxLoginAttempts, err = getLoginAttemptMaxTimes(tenantCtx, ses)
if err != nil {
return nil, err
}
if int64(loginAttempts) >= maxLoginAttempts {
Expand Down Expand Up @@ -1967,20 +1988,25 @@ func appendTraceField(fields []zap.Field, ctx context.Context) []zap.Field {
return fields
}

func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime string) (bool, error) {
func whetherNeedCheckExpired(ctx context.Context, ses *Session) (int, error) {
var (
defaultPasswordLifetime int
err error
lastChanged time.Time
)
// get the default password lifetime
defaultPasswordLifetime, err = getPasswordLifetime(ctx, ses)
if err != nil {
return false, err
return 0, err
}
return defaultPasswordLifetime, nil
}

func checkPasswordExpired(defPwdLifeTime int, lastChangedTime string) (bool, error) {
var (
err error
lastChanged time.Time
)

// if the default password lifetime is 0, the password never expires
if defaultPasswordLifetime == 0 {
if defPwdLifeTime <= 0 {
return false, nil
}

Expand All @@ -1992,7 +2018,7 @@ func checkPasswordExpired(ctx context.Context, ses *Session, lastChangedTime str

// get the current time as utc time
now := time.Now().UTC()
if lastChanged.AddDate(0, 0, defaultPasswordLifetime).Before(now) {
if lastChanged.AddDate(0, 0, defPwdLifeTime).Before(now) {
return true, nil
}

Expand Down Expand Up @@ -2041,7 +2067,7 @@ func checkLockTimeExpired(ctx context.Context, ses *Session, lockTime string) (b
return true, nil
}

func getLoginAttempts(ctx context.Context, ses *Session) (int64, error) {
func getLoginAttemptMaxTimes(ctx context.Context, ses *Session) (int64, error) {
value, err := ses.GetGlobalSysVar(ConnectionControlFailedConnectionsThreshold)
if err != nil {
return 0, err
Expand Down Expand Up @@ -2071,11 +2097,11 @@ func getLoginMaxDelay(ctx context.Context, ses *Session) (int64, error) {

func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error) {
var (
loginTimes int64
loginMaxTimes int64
err error
loginMaxDelay int64
)
loginTimes, err = getLoginAttempts(ctx, ses)
loginMaxTimes, err = getLoginAttemptMaxTimes(ctx, ses)
if err != nil {
return false, err
}
Expand All @@ -2084,7 +2110,7 @@ func whetherCheckLoginAttempts(ctx context.Context, ses *Session) (bool, error)
if err != nil {
return false, err
}
return loginTimes > 0 && loginMaxDelay > 0, nil
return loginMaxTimes > 0 && loginMaxDelay > 0, nil
}

func setUserUnlock(ctx context.Context, userName string, bh BackgroundExec) error {
Expand Down
8 changes: 4 additions & 4 deletions pkg/frontend/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,18 +616,18 @@ func TestCheckPasswordExpired(t *testing.T) {
ses.SetTenantInfo(tenant)

// password never expires
expired, err := checkPasswordExpired(ctx, ses, "2022-01-01 00:00:00")
expired, err := checkPasswordExpired(0, "2022-01-01 00:00:00")
assert.NoError(t, err)
assert.False(t, expired)

// password not expires
ses.gSysVars.Set(DefaultPasswordLifetime, int64(30))
expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05"))
expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -10).Format("2006-01-02 15:04:05"))
assert.NoError(t, err)
assert.False(t, expired)

// password not expires
expired, err = checkPasswordExpired(ctx, ses, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05"))
expired, err = checkPasswordExpired(30, time.Now().AddDate(0, 0, -31).Format("2006-01-02 15:04:05"))
assert.NoError(t, err)
assert.True(t, expired)

Expand All @@ -651,7 +651,7 @@ func TestCheckPasswordExpired(t *testing.T) {

// getPasswordLifetime error
ses.gSysVars.Set(DefaultPasswordLifetime, int64(-1))
_, err = checkPasswordExpired(ctx, ses, "1")
_, err = checkPasswordExpired(1, "1")
assert.Error(t, err)
assert.True(t, expired)
}
Expand Down
Loading