diff --git a/vault/external_tests/policy/no_default_test.go b/vault/external_tests/policy/policy_test.go similarity index 52% rename from vault/external_tests/policy/no_default_test.go rename to vault/external_tests/policy/policy_test.go index fcba0d011028..4d80cd0022da 100644 --- a/vault/external_tests/policy/no_default_test.go +++ b/vault/external_tests/policy/policy_test.go @@ -8,8 +8,10 @@ import ( "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/builtin/credential/ldap" + credUserpass "github.com/hashicorp/vault/builtin/credential/userpass" ldaphelper "github.com/hashicorp/vault/helper/testhelpers/ldap" vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/sdk/helper/strutil" "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" ) @@ -181,3 +183,144 @@ func TestPolicy_NoConfiguredPolicy(t *testing.T) { t.Fatalf("failed to renew lease, got: %v", secret.Auth.LeaseDuration) } } + +func TestPolicy_TokenRenewal(t *testing.T) { + cases := []struct { + name string + tokenPolicies []string + identityPolicies []string + }{ + { + "default only", + nil, + nil, + }, + { + "with token policies", + []string{"token-policy"}, + nil, + }, + { + "with identity policies", + nil, + []string{"identity-policy"}, + }, + { + "with token and identity policies", + []string{"token-policy"}, + []string{"identity-policy"}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": credUserpass.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) + client := cluster.Cores[0].Client + + // Enable userpass auth + err := client.Sys().EnableAuthWithOptions("userpass", &api.EnableAuthOptions{ + Type: "userpass", + }) + if err != nil { + t.Fatal(err) + } + + // Add a user to userpass backend + data := map[string]interface{}{ + "password": "testpassword", + } + if len(tc.tokenPolicies) > 0 { + data["token_policies"] = tc.tokenPolicies + } + _, err = client.Logical().Write("auth/userpass/users/testuser", data) + if err != nil { + t.Fatal(err) + } + + // Set up entity if we're testing against an identity_policies + if len(tc.identityPolicies) > 0 { + auths, err := client.Sys().ListAuth() + if err != nil { + t.Fatal(err) + } + userpassAccessor := auths["userpass/"].Accessor + + resp, err := client.Logical().Write("identity/entity", map[string]interface{}{ + "name": "test-entity", + "policies": tc.identityPolicies, + }) + if err != nil { + t.Fatal(err) + } + entityID := resp.Data["id"].(string) + + // Create an alias + resp, err = client.Logical().Write("identity/entity-alias", map[string]interface{}{ + "name": "testuser", + "mount_accessor": userpassAccessor, + "canonical_id": entityID, + }) + if err != nil { + t.Fatal(err) + } + } + + // Authenticate + secret, err := client.Logical().Write("auth/userpass/login/testuser", map[string]interface{}{ + "password": "testpassword", + }) + if err != nil { + t.Fatal(err) + } + clientToken := secret.Auth.ClientToken + + // Verify the policies exist in the login response + expectedTokenPolicies := append([]string{"default"}, tc.tokenPolicies...) + if !strutil.EquivalentSlices(secret.Auth.TokenPolicies, expectedTokenPolicies) { + t.Fatalf("token policy mismatch:\nexpected: %v\ngot: %v", expectedTokenPolicies, secret.Auth.TokenPolicies) + } + + if !strutil.EquivalentSlices(secret.Auth.IdentityPolicies, tc.identityPolicies) { + t.Fatalf("identity policy mismatch:\nexpected: %v\ngot: %v", tc.identityPolicies, secret.Auth.IdentityPolicies) + } + + expectedPolicies := append(expectedTokenPolicies, tc.identityPolicies...) + if !strutil.EquivalentSlices(secret.Auth.Policies, expectedPolicies) { + t.Fatalf("policy mismatch:\nexpected: %v\ngot: %v", expectedPolicies, secret.Auth.Policies) + } + + // Renew token + secret, err = client.Logical().Write("auth/token/renew", map[string]interface{}{ + "token": clientToken, + }) + if err != nil { + t.Fatal(err) + } + + // Verify the policies exist in the renewal response + if !strutil.EquivalentSlices(secret.Auth.TokenPolicies, expectedTokenPolicies) { + t.Fatalf("policy mismatch:\nexpected: %v\ngot: %v", expectedTokenPolicies, secret.Auth.TokenPolicies) + } + + if !strutil.EquivalentSlices(secret.Auth.IdentityPolicies, tc.identityPolicies) { + t.Fatalf("identity policy mismatch:\nexpected: %v\ngot: %v", tc.identityPolicies, secret.Auth.IdentityPolicies) + } + + if !strutil.EquivalentSlices(secret.Auth.Policies, expectedPolicies) { + t.Fatalf("policy mismatch:\nexpected: %v\ngot: %v", expectedPolicies, secret.Auth.Policies) + } + }) + } +} diff --git a/vault/request_handling.go b/vault/request_handling.go index b6090fa965fe..0dae4e309938 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -836,9 +836,8 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp } // Only the token store is allowed to return an auth block, for any - // other request this is an internal error. We exclude renewal of a token, - // since it does not need to be re-registered - if resp != nil && resp.Auth != nil && !strings.HasPrefix(req.Path, "auth/token/renew") { + // other request this is an internal error. + if resp != nil && resp.Auth != nil { if !strings.HasPrefix(req.Path, "auth/token/") { c.logger.Error("unexpected Auth response for non-token backend", "request_path", req.Path) retErr = multierror.Append(retErr, ErrInternalError) @@ -868,24 +867,34 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp return nil, nil, ErrInternalError } - resp.Auth.TokenPolicies = policyutil.SanitizePolicies(resp.Auth.Policies, policyutil.DoNotAddDefaultPolicy) - switch resp.Auth.TokenType { - case logical.TokenTypeBatch: - case logical.TokenTypeService: - if err := c.expiration.RegisterAuth(ctx, &logical.TokenEntry{ - TTL: auth.TTL, - Policies: auth.TokenPolicies, - Path: resp.Auth.CreationPath, - NamespaceID: ns.ID, - }, resp.Auth); err != nil { - // Best-effort clean up on error, so we log the cleanup error as - // a warning but still return as internal error. - if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { - c.logger.Warn("failed to clean up token lease during auth/token/ request", "request_path", req.Path, "error", err) + // We skip expiration manager registration for token renewal since it + // does not need to be re-registered + if strings.HasPrefix(req.Path, "auth/token/renew") { + // We build the "policies" list to be returned by starting with + // token policies, and add identity policies right after this + // conditional + resp.Auth.Policies = policyutil.SanitizePolicies(resp.Auth.TokenPolicies, policyutil.DoNotAddDefaultPolicy) + } else { + resp.Auth.TokenPolicies = policyutil.SanitizePolicies(resp.Auth.Policies, policyutil.DoNotAddDefaultPolicy) + + switch resp.Auth.TokenType { + case logical.TokenTypeBatch: + case logical.TokenTypeService: + if err := c.expiration.RegisterAuth(ctx, &logical.TokenEntry{ + TTL: auth.TTL, + Policies: auth.TokenPolicies, + Path: resp.Auth.CreationPath, + NamespaceID: ns.ID, + }, resp.Auth); err != nil { + // Best-effort clean up on error, so we log the cleanup error as + // a warning but still return as internal error. + if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { + c.logger.Warn("failed to clean up token lease during auth/token/ request", "request_path", req.Path, "error", err) + } + c.logger.Error("failed to register token lease during auth/token/ request", "request_path", req.Path, "error", err) + retErr = multierror.Append(retErr, ErrInternalError) + return nil, auth, retErr } - c.logger.Error("failed to register token lease during auth/token/ request", "request_path", req.Path, "error", err) - retErr = multierror.Append(retErr, ErrInternalError) - return nil, auth, retErr } }