From 38ab39b937990f44196b2e0402ef02eed9baab75 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Mon, 18 Mar 2024 21:53:13 -0300 Subject: [PATCH] refactor(middleware/cors): origin validation and normalization --- middleware/cors/cors.go | 22 ++++++---------------- middleware/cors/cors_test.go | 2 ++ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 9159a1340c..7debfdfaa0 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -119,33 +119,23 @@ func New(config ...Config) fiber.Handler { allowSOrigins := []subdomain{} allowAllOrigins := false - // processOrigin processes an origin string, normalizes it and checks its validity - // it will panic if the origin is invalid - processOrigin := func(origin string) (string, bool) { - trimmedOrigin := strings.TrimSpace(origin) - isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) - if !isValid { - log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) - panic("[CORS] Invalid origin provided in configuration") - } - return normalizedOrigin, true - } - // Validate and normalize static AllowOrigins if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { origins := strings.Split(cfg.AllowOrigins, ",") for _, origin := range origins { if i := strings.Index(origin, "://*."); i != -1 { - normalizedOrigin, isValid := processOrigin(origin[:i+3] + origin[i+4:]) + trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:]) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { - continue + panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin) } sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]} allowSOrigins = append(allowSOrigins, sd) } else { - normalizedOrigin, isValid := processOrigin(origin) + trimmedOrigin := strings.TrimSpace(origin) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { - continue + panic("[CORS] Invalid origin format in configuration: " + trimmedOrigin) } allowOrigins = append(allowOrigins, normalizedOrigin) } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 57f5f91205..2e0b5c2244 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -190,7 +190,9 @@ func Test_CORS_Invalid_Origins_Panic(t *testing.T) { "http://foo.[a-z]*.example.com", "http://*", "https://*", + "http://*.com*", "invalid url", + "http://origin.com,invalid url", // add more invalid origins as needed }