diff --git a/changelog/22583.txt b/changelog/22583.txt new file mode 100644 index 000000000000..0bc29d60fea8 --- /dev/null +++ b/changelog/22583.txt @@ -0,0 +1,3 @@ +```release-note:bug +core/quotas: Reduce overhead for role calculation when using cloud auth methods. +``` \ No newline at end of file diff --git a/http/util.go b/http/util.go index 61dc8360c846..ef3e2e5199e8 100644 --- a/http/util.go +++ b/http/util.go @@ -5,6 +5,7 @@ package http import ( "bytes" + "context" "errors" "fmt" "io/ioutil" @@ -59,11 +60,16 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler } r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) + role := core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()) + + // add an entry to the context to prevent recalculating request role unnecessarily + r = r.WithContext(context.WithValue(r.Context(), logical.CtxKeyRequestRole{}, role)) + quotaResp, err := core.ApplyRateLimitQuota(r.Context(), "as.Request{ Type: quotas.TypeRateLimit, Path: path, MountPath: mountPath, - Role: core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()), + Role: role, NamespacePath: ns.Path, ClientAddress: parseRemoteIPAddress(r), }) diff --git a/sdk/logical/request.go b/sdk/logical/request.go index 8a6ac241fe80..39d5bbe2625f 100644 --- a/sdk/logical/request.go +++ b/sdk/logical/request.go @@ -447,3 +447,9 @@ type CtxKeyInFlightRequestID struct{} func (c CtxKeyInFlightRequestID) String() string { return "in-flight-request-ID" } + +type CtxKeyRequestRole struct{} + +func (c CtxKeyRequestRole) String() string { + return "request-role" +} diff --git a/vault/login_mfa.go b/vault/login_mfa.go index 751c99c4f30d..a81dd512371c 100644 --- a/vault/login_mfa.go +++ b/vault/login_mfa.go @@ -791,12 +791,17 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu return nil, fmt.Errorf("namespace not found: %w", err) } + var role string + if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil { + role = reqRole.(string) + } + // The request successfully authenticated itself. Run the quota checks on // the original login request path before creating the token. quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: reqPath, MountPath: strings.TrimPrefix(mountPoint, ns.Path), - Role: c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx), + Role: role, NamespacePath: ns.Path, }) @@ -816,7 +821,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu // note that we don't need to handle the error for the following function right away. // The function takes the response as in input variable and modify it. So, the returned // arguments are resp and err. - leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp, loginRequestData) + leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, role, resp) if quotaResp.Access != nil { quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) diff --git a/vault/request_handling.go b/vault/request_handling.go index a11306c75120..86856a036f76 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -489,6 +489,10 @@ func (c *Core) switchedLockHandleRequest(httpCtx context.Context, req *logical.R if ok { ctx = context.WithValue(ctx, logical.CtxKeyInFlightRequestID{}, inFlightReqID) } + requestRole, ok := httpCtx.Value(logical.CtxKeyRequestRole{}).(string) + if ok { + ctx = context.WithValue(ctx, logical.CtxKeyRequestRole{}, requestRole) + } resp, err = c.handleCancelableRequest(ctx, req) req.SetTokenEntry(nil) cancel() @@ -1248,7 +1252,14 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp Path: resp.Auth.CreationPath, NamespaceID: ns.ID, } - if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil { + + // Check for request role + var role string + if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil { + role = reqRole.(string) + } + + if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, role); err != nil { // Best-effort clean up on error, so we log the cleanup error as // a warning but still return as internal error. if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { @@ -1477,12 +1488,18 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re return } + // Check for request role + var role string + if reqRole := ctx.Value(logical.CtxKeyRequestRole{}); reqRole != nil { + role = reqRole.(string) + } + // The request successfully authenticated itself. Run the quota checks // before creating lease. quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: req.Path, MountPath: strings.TrimPrefix(req.MountPoint, ns.Path), - Role: c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx), + Role: role, NamespacePath: ns.Path, }) @@ -1674,7 +1691,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re // Attach the display name, might be used by audit backends req.DisplayName = auth.DisplayName - leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp, req.Data) + leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, role, resp) leaseGenerated = leaseGen if errCreateToken != nil { return respTokenCreate, nil, errCreateToken @@ -1726,9 +1743,8 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re // LoginCreateToken creates a token as a result of a login request. // If MFA is enforced, mfa/validate endpoint calls this functions // after successful MFA validation to generate the token. -func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response, loginRequestData map[string]interface{}) (bool, *logical.Response, error) { +func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint, role string, resp *logical.Response) (bool, *logical.Response, error) { auth := resp.Auth - source := strings.TrimPrefix(mountPoint, credentialRoutePrefix) source = strings.ReplaceAll(source, "/", "-") @@ -1788,7 +1804,7 @@ func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, re } leaseGenerated := false - err = registerFunc(ctx, tokenTTL, reqPath, auth, c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx)) + err = registerFunc(ctx, tokenTTL, reqPath, auth, role) switch { case err == nil: if auth.TokenType != logical.TokenTypeBatch {