diff --git a/pkg/bastion/dbinit.go b/pkg/bastion/dbinit.go index 15d3011a..c47468d4 100644 --- a/pkg/bastion/dbinit.go +++ b/pkg/bastion/dbinit.go @@ -1,10 +1,11 @@ package bastion // import "moul.io/sshportal/pkg/bastion" import ( + "crypto/rand" "fmt" "io/ioutil" "log" - "math/rand" + "math/big" "os" "os/user" "strings" @@ -617,7 +618,10 @@ func DBInit(db *gorm.DB) error { } if count == 0 { // if no admin, create an account for the first connection - inviteToken := randStringBytes(16) + inviteToken, err := randStringBytes(16) + if err != nil { + return err + } if os.Getenv("SSHPORTAL_DEFAULT_ADMIN_INVITE_TOKEN") != "" { inviteToken = os.Getenv("SSHPORTAL_DEFAULT_ADMIN_INVITE_TOKEN") } @@ -673,12 +677,16 @@ func DBInit(db *gorm.DB) error { }).Error } -func randStringBytes(n int) string { +func randStringBytes(n int) (string, error) { const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, n) for i := range b { - b[i] = letterBytes[rand.Intn(len(letterBytes))] + r, err := rand.Int(rand.Reader, big.NewInt(int64(len(letterBytes)))) + if err != nil { + return "", fmt.Errorf("failed to generate random string: %s", err) + } + b[i] = letterBytes[r.Int64()] } - return string(b) + return string(b), nil } diff --git a/pkg/bastion/shell.go b/pkg/bastion/shell.go index 1742b820..d81b9449 100644 --- a/pkg/bastion/shell.go +++ b/pkg/bastion/shell.go @@ -1640,11 +1640,15 @@ GLOBAL OPTIONS: name = c.String("name") } + r, err := randStringBytes(16) + if err != nil { + return err + } user := dbmodels.User{ Name: name, Email: email, Comment: c.String("comment"), - InviteToken: randStringBytes(16), + InviteToken: r, } if _, err := govalidator.ValidateStruct(user); err != nil {