diff --git a/claims.go b/claims.go index 52fd91d0..8a125c70 100644 --- a/claims.go +++ b/claims.go @@ -98,21 +98,13 @@ func validateBoundClaims(logger log.Logger, boundClaims, allClaims map[string]in var actVals, expVals []interface{} - switch v := actValue.(type) { - case []interface{}: - actVals = v - case string: - actVals = []interface{}{v} - default: + actVals, ok := normalizeList(actValue) + if !ok { return fmt.Errorf("received claim is not a string or list: %v", actValue) } - switch v := expValue.(type) { - case []interface{}: - expVals = v - case string: - expVals = []interface{}{v} - default: + expVals, ok = normalizeList(expValue) + if !ok { return fmt.Errorf("bound claim is not a string or list: %v", expValue) } @@ -135,3 +127,21 @@ func validateBoundClaims(logger log.Logger, boundClaims, allClaims map[string]in return nil } + +// normalizeList takes a string or list and returns a list. This is useful when +// providers are expected to return a list (typically of strings) but reduce it +// to a string type when the list count is 1. +func normalizeList(raw interface{}) ([]interface{}, bool) { + var normalized []interface{} + + switch v := raw.(type) { + case []interface{}: + normalized = v + case string: + normalized = []interface{}{v} + default: + return nil, false + } + + return normalized, true +} diff --git a/claims_test.go b/claims_test.go index 073e83b4..631257aa 100644 --- a/claims_test.go +++ b/claims_test.go @@ -2,6 +2,7 @@ package jwtauth import ( "encoding/json" + "reflect" "testing" "github.com/go-test/deep" @@ -376,3 +377,56 @@ func TestValidateBoundClaims(t *testing.T) { } } } + +func Test_normalizeList(t *testing.T) { + tests := []struct { + raw interface{} + normalized []interface{} + ok bool + }{ + { + raw: []interface{}{"green", 42}, + normalized: []interface{}{"green", 42}, + ok: true, + }, + { + raw: []interface{}{"green"}, + normalized: []interface{}{"green"}, + ok: true, + }, + { + raw: []interface{}{}, + normalized: []interface{}{}, + ok: true, + }, + { + raw: "green", + normalized: []interface{}{"green"}, + ok: true, + }, + { + raw: "", + normalized: []interface{}{""}, + ok: true, + }, + { + raw: 42, + normalized: nil, + ok: false, + }, + { + raw: nil, + normalized: nil, + ok: false, + }, + } + for _, tt := range tests { + normalized, ok := normalizeList(tt.raw) + if !reflect.DeepEqual(normalized, tt.normalized) { + t.Errorf("normalizeList() got normalized = %v, want %v", normalized, tt.normalized) + } + if ok != tt.ok { + t.Errorf("normalizeList() got ok = %v, want %v", ok, tt.ok) + } + } +} diff --git a/path_login.go b/path_login.go index 08703db0..3ef52756 100644 --- a/path_login.go +++ b/path_login.go @@ -339,7 +339,8 @@ func (b *jwtAuthBackend) createIdentity(allClaims map[string]interface{}, role * if groupsClaimRaw == nil { return nil, nil, fmt.Errorf("%q claim not found in token", role.GroupsClaim) } - groups, ok := groupsClaimRaw.([]interface{}) + + groups, ok := normalizeList(groupsClaimRaw) if !ok { return nil, nil, fmt.Errorf("%q claim could not be converted to string list", role.GroupsClaim) diff --git a/path_login_test.go b/path_login_test.go index fae3f8fb..8c923796 100644 --- a/path_login_test.go +++ b/path_login_test.go @@ -19,17 +19,34 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) -func setupBackend(t *testing.T, oidc, role_type_oidc, audience, boundClaims, boundCIDRs, jwks bool, defaultLeeway, expLeeway, nbfLeeway int) (logical.Backend, logical.Storage) { +type testConfig struct { + oidc bool + role_type_oidc bool + audience bool + boundClaims bool + boundCIDRs bool + jwks bool + defaultLeeway int + expLeeway int + nbfLeeway int + groupsClaim string +} + +func setupBackend(t *testing.T, cfg testConfig) (logical.Backend, logical.Storage) { b, storage := getBackend(t) + if cfg.groupsClaim == "" { + cfg.groupsClaim = "https://vault/groups" + } + var data map[string]interface{} - if oidc { + if cfg.oidc { data = map[string]interface{}{ "bound_issuer": "https://team-vault.auth0.com/", "oidc_discovery_url": "https://team-vault.auth0.com/", } } else { - if !jwks { + if !cfg.jwks { data = map[string]interface{}{ "bound_issuer": "https://team-vault.auth0.com/", "jwt_validation_pubkeys": ecdsaPubKey, @@ -64,7 +81,7 @@ func setupBackend(t *testing.T, oidc, role_type_oidc, audience, boundClaims, bou "role_type": "jwt", "bound_subject": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", "user_claim": "https://vault/user", - "groups_claim": "https://vault/groups", + "groups_claim": cfg.groupsClaim, "policies": "test", "period": "3s", "ttl": "1s", @@ -75,25 +92,25 @@ func setupBackend(t *testing.T, oidc, role_type_oidc, audience, boundClaims, bou "/org/primary": "primary_org", }, } - if role_type_oidc { + if cfg.role_type_oidc { data["role_type"] = "oidc" data["allowed_redirect_uris"] = "http://127.0.0.1" } - if audience { + if cfg.audience { data["bound_audiences"] = []string{"https://vault.plugin.auth.jwt.test", "another_audience"} } - if boundClaims { + if cfg.boundClaims { data["bound_claims"] = map[string]interface{}{ "color": "green", } } - if boundCIDRs { + if cfg.boundCIDRs { data["bound_cidrs"] = "127.0.0.42" } - data["clock_skew_leeway"] = defaultLeeway - data["expiration_leeway"] = expLeeway - data["not_before_leeway"] = nbfLeeway + data["clock_skew_leeway"] = cfg.defaultLeeway + data["expiration_leeway"] = cfg.expLeeway + data["not_before_leeway"] = cfg.nbfLeeway req = &logical.Request{ Operation: logical.CreateOperation, @@ -175,7 +192,12 @@ func TestLogin_JWT(t *testing.T) { func testLogin_JWT(t *testing.T, jwks bool) { // Test role_type oidc { - b, storage := setupBackend(t, false, true, true, false, false, jwks, 0, 0, 0) + cfg := testConfig{ + role_type_oidc: true, + audience: true, + jwks: jwks, + } + b, storage := setupBackend(t, cfg) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -226,7 +248,11 @@ func testLogin_JWT(t *testing.T, jwks bool) { // Test missing audience { - b, storage := setupBackend(t, false, false, false, false, false, jwks, 0, 0, 0) + + cfg := testConfig{ + jwks: jwks, + } + b, storage := setupBackend(t, cfg) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -279,7 +305,13 @@ func testLogin_JWT(t *testing.T, jwks bool) { { // run test with and without bound_cidrs configured for _, useBoundCIDRs := range []bool{false, true} { - b, storage := setupBackend(t, false, false, true, true, useBoundCIDRs, jwks, 0, 0, 0) + cfg := testConfig{ + audience: true, + boundClaims: true, + boundCIDRs: useBoundCIDRs, + jwks: jwks, + } + b, storage := setupBackend(t, cfg) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -368,7 +400,12 @@ func testLogin_JWT(t *testing.T, jwks bool) { } } - b, storage := setupBackend(t, false, false, true, true, false, jwks, 0, 0, 0) + cfg := testConfig{ + audience: true, + boundClaims: true, + jwks: jwks, + } + b, storage := setupBackend(t, cfg) // test invalid bound claim { @@ -645,7 +682,11 @@ func testLogin_JWT(t *testing.T, jwks bool) { // test invalid address { - b, storage := setupBackend(t, false, false, false, false, true, jwks, 0, 0, 0) + cfg := testConfig{ + boundCIDRs: true, + jwks: jwks, + } + b, storage := setupBackend(t, cfg) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", @@ -789,7 +830,13 @@ func testLogin_ExpiryClaims(t *testing.T, jwks bool) { } for i, tt := range tests { - b, storage := setupBackend(t, false, false, true, false, false, tt.JWKS, tt.DefaultLeeway, tt.ExpLeeway, 0) + cfg := testConfig{ + audience: true, + jwks: tt.JWKS, + defaultLeeway: tt.DefaultLeeway, + expLeeway: tt.ExpLeeway, + } + b, storage := setupBackend(t, cfg) req := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, b, storage) resp, err := b.HandleRequest(context.Background(), req) @@ -859,7 +906,14 @@ func testLogin_NotBeforeClaims(t *testing.T, jwks bool) { } for i, tt := range tests { - b, storage := setupBackend(t, false, false, true, false, false, tt.JWKS, tt.DefaultLeeway, 0, tt.NBFLeeway) + cfg := testConfig{ + audience: true, + jwks: tt.JWKS, + defaultLeeway: tt.DefaultLeeway, + expLeeway: 0, + nbfLeeway: tt.NBFLeeway, + } + b, storage := setupBackend(t, cfg) req := setupLogin(t, tt.IssuedAt, tt.Expiration, tt.NotBefore, b, storage) resp, err := b.HandleRequest(context.Background(), req) @@ -919,7 +973,12 @@ func setupLogin(t *testing.T, iat, exp, nbf time.Time, b logical.Backend, storag } func TestLogin_OIDC(t *testing.T) { - b, storage := setupBackend(t, true, false, true, false, false, false, -1, 0, 0) + cfg := testConfig{ + oidc: true, + audience: true, + defaultLeeway: -1, + } + b, storage := setupBackend(t, cfg) jwtData := getTestOIDC(t) @@ -1081,8 +1140,58 @@ func TestLogin_NestedGroups(t *testing.T) { } } +func TestLogin_OIDC_StringGroupClaim(t *testing.T) { + cfg := testConfig{ + oidc: true, + audience: true, + jwks: false, + defaultLeeway: -1, + groupsClaim: "https://vault/groups/string", + } + b, storage := setupBackend(t, cfg) + + jwtData := getTestOIDC(t) + + data := map[string]interface{}{ + "role": "plugin-test", + "jwt": jwtData, + } + + req := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "login", + Storage: storage, + Data: data, + Connection: &logical.Connection{ + RemoteAddr: "127.0.0.1", + }, + } + + resp, err := b.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("got nil response") + } + if resp.IsError() { + t.Fatalf("got error: %v", resp.Error()) + } + + auth := resp.Auth + switch { + case len(auth.GroupAliases) != 1 || auth.GroupAliases[0].Name != "just_a_string": + t.Fatal(auth.GroupAliases) + } +} + func TestLogin_JWKS_Concurrent(t *testing.T) { - b, storage := setupBackend(t, false, false, true, false, false, true, -1, 0, 0) + cfg := testConfig{ + audience: true, + jwks: true, + defaultLeeway: -1, + } + b, storage := setupBackend(t, cfg) cl := jwt.Claims{ Subject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients",