diff --git a/middleware/csrf.go b/middleware/csrf.go index 6899700c7..adf12210b 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -6,7 +6,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) type ( @@ -103,6 +102,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } + if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -132,7 +132,7 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = random.String(config.TokenLength) // Generate token + token = randomString(config.TokenLength) } else { token = k.Value // Reuse token } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 6bccdbe4d..6b20297ee 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" ) @@ -233,7 +232,7 @@ func TestCSRF(t *testing.T) { assert.Error(t, h(c)) // Valid CSRF token - token := random.String(32) + token := randomString(32) req.Header.Set(echo.HeaderCookie, "_csrf="+token) req.Header.Set(echo.HeaderXCSRFToken, token) if assert.NoError(t, h(c)) { diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 0f7c9141d..f66961fe2 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -10,7 +10,6 @@ import ( "time" "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -410,7 +409,7 @@ func TestNewRateLimiterMemoryStore(t *testing.T) { func generateAddressList(count int) []string { addrs := make([]string, count) for i := 0; i < count; i++ { - addrs[i] = random.String(15) + addrs[i] = randomString(15) } return addrs } diff --git a/middleware/request_id.go b/middleware/request_id.go index 8c5ff6605..e29c8f50d 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -2,7 +2,6 @@ package middleware import ( "github.com/labstack/echo/v4" - "github.com/labstack/gommon/random" ) type ( @@ -12,7 +11,7 @@ type ( Skipper Skipper // Generator defines a function to generate an ID. - // Optional. Default value random.String(32). + // Optional. Defaults to generator for random string of length 32. Generator func() string // RequestIDHandler defines a function which is executed for a request id. @@ -73,5 +72,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { } func generator() string { - return random.String(32) + return randomString(32) } diff --git a/middleware/util.go b/middleware/util.go index ab951a0e9..aa34d78f3 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -1,6 +1,8 @@ package middleware import ( + "crypto/rand" + "fmt" "strings" ) @@ -52,3 +54,18 @@ func matchSubdomain(domain, pattern string) bool { } return false } + +func randomString(length uint8) string { + charset := "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + + bytes := make([]byte, length) + _, err := rand.Read(bytes) + if err != nil { + // we are out of random. let the request fail + panic(fmt.Errorf("echo randomString failed to read random bytes: %w", err)) + } + for i, b := range bytes { + bytes[i] = charset[b%byte(len(charset))] + } + return string(bytes) +} diff --git a/middleware/util_test.go b/middleware/util_test.go index df1d26295..7562d4a5f 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -93,3 +93,27 @@ func Test_matchSubdomain(t *testing.T) { assert.Equal(t, v.expected, matchSubdomain(v.domain, v.pattern)) } } + +func TestRandomString(t *testing.T) { + var testCases = []struct { + name string + whenLength uint8 + expect string + }{ + { + name: "ok, 16", + whenLength: 16, + }, + { + name: "ok, 32", + whenLength: 32, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + uid := randomString(tc.whenLength) + assert.Len(t, uid, int(tc.whenLength)) + }) + } +}