Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(middleware/cors): Categorize requests correctly #2921

Merged
merged 7 commits into from
Mar 20, 2024
18 changes: 10 additions & 8 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ func New(config ...Config) fiber.Handler {
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))

// If the request does not have an Origin header, the request is outside the scope of CORS
if originHeader == "" {
// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
return c.Next()
}

Expand Down Expand Up @@ -211,8 +212,9 @@ func New(config ...Config) fiber.Handler {
}

// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}

Expand All @@ -233,14 +235,14 @@ func setCORSHeaders(c *fiber.Ctx, allowOrigin, allowMethods, allowHeaders, expos

if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin != "*" && allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
} else if allowOrigin == "*" {
if allowOrigin == "*" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.")
} else if allowOrigin != "" {
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
c.Set(fiber.HeaderAccessControlAllowCredentials, "true")
}
} else if len(allowOrigin) > 0 {
} else if allowOrigin != "" {
// For non-credential requests, it's safe to set to '*' or specific origins
c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin)
}
Expand Down
109 changes: 109 additions & 0 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func Test_CORS_Negative_MaxAge(t *testing.T) {

ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
app.Handler()(ctx)

Expand All @@ -49,6 +50,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand All @@ -59,6 +61,7 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)

Expand Down Expand Up @@ -87,6 +90,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)

// Perform request
handler(ctx)
Expand All @@ -101,6 +105,7 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)

utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
Expand Down Expand Up @@ -128,6 +133,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)

// Perform request
handler(ctx)
Expand All @@ -141,6 +147,7 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)

Expand Down Expand Up @@ -226,6 +233,7 @@ func Test_CORS_Subdomain(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")

// Perform request
Expand All @@ -240,6 +248,7 @@ func Test_CORS_Subdomain(t *testing.T) {
// Make request with domain only (disallowed)
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")

handler(ctx)
Expand All @@ -252,6 +261,7 @@ func Test_CORS_Subdomain(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")

handler(ctx)
Expand Down Expand Up @@ -366,6 +376,7 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)

handler(ctx)
Expand Down Expand Up @@ -422,6 +433,98 @@ func Test_CORS_Next(t *testing.T) {
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

// go test -run Test_CORS_Headers_BasedOnRequestType
func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(Config{}))
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

methods := []string{
fiber.MethodGet,
fiber.MethodPost,
fiber.MethodPut,
fiber.MethodDelete,
fiber.MethodPatch,
fiber.MethodHead,
}

// Get handler pointer
handler := app.Handler()

t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})

t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
// Make preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
utils.AssertEqual(t, 204, ctx.Response.StatusCode(), "Status code should be 204")
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)")
}
})

t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
utils.AssertEqual(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)")
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)")
}
})
}

func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
Expand All @@ -440,6 +543,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")

// Perform request
Expand All @@ -454,6 +558,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")

handler(ctx)
Expand All @@ -466,6 +571,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")

handler(ctx)
Expand Down Expand Up @@ -505,6 +611,7 @@ func Test_CORS_AllowOriginsFunc(t *testing.T) {
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")

handler(ctx)
Expand Down Expand Up @@ -652,6 +759,7 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)
Expand Down Expand Up @@ -742,6 +850,7 @@ func Test_CORS_AllowCredentials(t *testing.T) {
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin)

handler(ctx)
Expand Down
52 changes: 52 additions & 0 deletions middleware/cors/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package cors

import (
"testing"

"github.com/gofiber/fiber/v2/utils"
)

// go test -run -v Test_normalizeOrigin
Expand All @@ -16,6 +18,9 @@ func Test_normalizeOrigin(t *testing.T) {
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
{"http://", false, ""}, // Invalid origin should not be accepted.
{"file:///etc/passwd", false, ""}, // File scheme should not be accepted.
{"https://*example.com", false, ""}, // Wildcard domain should not be accepted.
{"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted.
{"http://example.com/path", false, ""}, // Path should not be accepted.
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
Expand Down Expand Up @@ -105,3 +110,50 @@ func Test_normalizeDomain(t *testing.T) {
}
}
}

func TestSubdomainMatch(t *testing.T) {
tests := []struct {
name string
sub subdomain
origin string
expected bool
}{
{
name: "match with valid subdomain",
sub: subdomain{prefix: "https://api.", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: true,
},
{
name: "no match with invalid prefix",
sub: subdomain{prefix: "https://api.", suffix: ".example.com"},
origin: "https://service.example.com",
expected: false,
},
{
name: "no match with invalid suffix",
sub: subdomain{prefix: "https://api.", suffix: ".example.com"},
origin: "https://api.example.org",
expected: false,
},
{
name: "no match with empty origin",
sub: subdomain{prefix: "https://api.", suffix: ".example.com"},
origin: "",
expected: false,
},
{
name: "partial match not considered a match",
sub: subdomain{prefix: "https://api.", suffix: ".example.com"},
origin: "https://api.example.com",
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sub.match(tt.origin)
utils.AssertEqual(t, tt.expected, got, "subdomain.match()")
})
}
}
Loading