diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index 074caa077a..c00cab6025 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -179,56 +179,63 @@ func New(config ...Config) fiber.Handler { // Simple request if c.Method() != fiber.MethodOptions { - c.Vary(fiber.HeaderOrigin) - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - - if cfg.AllowCredentials { - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } - if exposeHeaders != "" { - c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders) - } + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) return c.Next() } // Preflight request - c.Vary(fiber.HeaderOrigin) c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) - c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods) - if cfg.AllowCredentials { - // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' - if allowOrigin != "*" && allowOrigin != "" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - c.Set(fiber.HeaderAccessControlAllowCredentials, "true") - } else if allowOrigin == "*" { - c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) - log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") - } - } else if len(allowOrigin) > 0 { - // For non-credential requests, it's safe to set to '*' or specific origins + setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg) + + // Send 204 No Content + return c.SendStatus(fiber.StatusNoContent) + } +} + +// Function to set CORS headers +func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) { + c.Vary(fiber.HeaderOrigin) + + if cfg.AllowCredentials { + // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' + if allowOrigin != "*" && allowOrigin != "" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + c.Set(fiber.HeaderAccessControlAllowCredentials, "true") + } else if allowOrigin == "*" { + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") } + } else if len(allowOrigin) > 0 { + // For non-credential requests, it's safe to set to '*' or specific origins + c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) + } - // Set Allow-Headers if not empty - if allowHeaders != "" { - c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders) - } else { - h := c.Get(fiber.HeaderAccessControlRequestHeaders) - if h != "" { - c.Set(fiber.HeaderAccessControlAllowHeaders, h) - } - } + // Set Allow-Methods if not empty + if allowMethods != "" { + c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods) + } - // Set MaxAge is set - if cfg.MaxAge > 0 { - c.Set(fiber.HeaderAccessControlMaxAge, maxAge) - } else if cfg.MaxAge < 0 { - c.Set(fiber.HeaderAccessControlMaxAge, "0") + // Set Allow-Headers if not empty + if allowHeaders != "" { + c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders) + } else { + h := c.Get(fiber.HeaderAccessControlRequestHeaders) + if h != "" { + c.Set(fiber.HeaderAccessControlAllowHeaders, h) } + } - // Send 204 No Content - return c.SendStatus(fiber.StatusNoContent) + // Set MaxAge if set + if cfg.MaxAge > 0 { + c.Set(fiber.HeaderAccessControlMaxAge, maxAge) + } else if cfg.MaxAge < 0 { + c.Set(fiber.HeaderAccessControlMaxAge, "0") + } + + // Set Expose-Headers if not empty + if exposeHeaders != "" { + c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders) } } diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 3176e880cb..0d1b428c81 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -137,6 +137,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) { // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} + ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx)