Skip to content

Commit

Permalink
refactor(middleware/cors): headers handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sixcolors committed Mar 11, 2024
1 parent 6a482d2 commit b792951
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 38 deletions.
83 changes: 45 additions & 38 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,56 +179,63 @@ func New(config ...Config) fiber.Handler {

// Simple request
if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin)
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)

if cfg.AllowCredentials {
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
return c.Next()
}

// Preflight request
c.Vary(fiber.HeaderOrigin)
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
}
} else if len(allowOrigin) > 0 {
// For non-credential requests, it's safe to set to '*' or specific origins
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)

// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
}
}

// Function to set CORS headers
func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
}
} else if len(allowOrigin) > 0 {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}

// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}
// Set Allow-Methods if not empty
if allowMethods != "" {
c.Set(fiber.HeaderAccessControlAllowMethods, allowMethods)
}

// Set MaxAge is set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
// Set Allow-Headers if not empty
if allowHeaders != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, allowHeaders)
} else {
h := c.Get(fiber.HeaderAccessControlRequestHeaders)
if h != "" {
c.Set(fiber.HeaderAccessControlAllowHeaders, h)
}
}

// Send 204 No Content
return c.SendStatus(fiber.StatusNoContent)
// Set MaxAge if set
if cfg.MaxAge > 0 {
c.Set(fiber.HeaderAccessControlMaxAge, maxAge)
} else if cfg.MaxAge < 0 {
c.Set(fiber.HeaderAccessControlMaxAge, "0")
}

// Set Expose-Headers if not empty
if exposeHeaders != "" {
c.Set(fiber.HeaderAccessControlExposeHeaders, exposeHeaders)
}
}
1 change: 1 addition & 0 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {

// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)

Expand Down

0 comments on commit b792951

Please sign in to comment.