From 28166faf97a76fb20488556f84ddceb1c915d630 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 13 Jun 2018 17:44:32 -0400 Subject: [PATCH 01/12] Add an idle timeout for the server Because tidy operations can be long-running, this also changes all tidy operations to behave the same operationally (kick off the process, get a warning back, log errors to server log) and makes them all run in a goroutine. This could mean a sort of hard stop if Vault gets sealed because the function won't have the read lock. This should generally be okay (running tidy again should pick back up where it left off), but future work could use cleanup funcs to trigger the functions to stop. --- .../credential/approle/path_tidy_user_id.go | 208 +++++----- .../aws/path_tidy_identity_whitelist.go | 83 ++-- .../aws/path_tidy_roletag_blacklist.go | 84 ++-- builtin/logical/pki/backend.go | 2 + builtin/logical/pki/path_tidy.go | 191 ++++----- command/server.go | 3 +- vault/expiration.go | 26 +- vault/logical_system.go | 17 +- vault/token_store.go | 362 +++++++++--------- 9 files changed, 535 insertions(+), 441 deletions(-) diff --git a/builtin/credential/approle/path_tidy_user_id.go b/builtin/credential/approle/path_tidy_user_id.go index 590cb7284d41..6137e0594fbe 100644 --- a/builtin/credential/approle/path_tidy_user_id.go +++ b/builtin/credential/approle/path_tidy_user_id.go @@ -26,141 +26,153 @@ func pathTidySecretID(b *backend) *framework.Path { } // tidySecretID is used to delete entries in the whitelist that are expired. -func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) error { - grabbed := atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) - } else { - return fmt.Errorf("SecretID tidy operation already running") +func (b *backend) tidySecretID(ctx context.Context, s logical.Storage) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidySecretIDCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - var result error + go func() { + defer atomic.StoreUint32(b.tidySecretIDCASGuard, 0) - tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error { - roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse) - if err != nil { - return err - } + var result error - // List all the accessors and add them all to a map - accessorHashes, err := s.List(ctx, accessorIDPrefixToUse) - if err != nil { - return err - } - accessorMap := make(map[string]bool, len(accessorHashes)) - for _, accessorHash := range accessorHashes { - accessorMap[accessorHash] = true - } + // Don't cancel when the original client request goes away + ctx = context.Background() - secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error { - lock := b.secretIDLock(secretIDHMAC) - lock.Lock() - defer lock.Unlock() + logger := b.Logger().Named("tidy") - entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC) - secretIDEntry, err := s.Get(ctx, entryIndex) + tidyFunc := func(secretIDPrefixToUse, accessorIDPrefixToUse string) error { + roleNameHMACs, err := s.List(ctx, secretIDPrefixToUse) if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err) + return err } - if secretIDEntry == nil { - result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC)) - return nil + // List all the accessors and add them all to a map + accessorHashes, err := s.List(ctx, accessorIDPrefixToUse) + if err != nil { + return err } - - if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 { - return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC) + accessorMap := make(map[string]bool, len(accessorHashes)) + for _, accessorHash := range accessorHashes { + accessorMap[accessorHash] = true } - var result secretIDStorageEntry - if err := secretIDEntry.DecodeJSON(&result); err != nil { - return err - } + secretIDCleanupFunc := func(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse string) error { + lock := b.secretIDLock(secretIDHMAC) + lock.Lock() + defer lock.Unlock() - // If a secret ID entry does not have a corresponding accessor - // entry, revoke the secret ID immediately - accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) - if err != nil { - return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err) - } - if accessorEntry == nil { - if err := s.Delete(ctx, entryIndex); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err) + entryIndex := fmt.Sprintf("%s%s%s", secretIDPrefixToUse, roleNameHMAC, secretIDHMAC) + secretIDEntry, err := s.Get(ctx, entryIndex) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching SecretID %q: {{err}}", secretIDHMAC), err) + } + + if secretIDEntry == nil { + result = multierror.Append(result, fmt.Errorf("entry for SecretID %q is nil", secretIDHMAC)) + return nil + } + + if secretIDEntry.Value == nil || len(secretIDEntry.Value) == 0 { + return fmt.Errorf("found entry for SecretID %q but actual SecretID is empty", secretIDHMAC) + } + + var result secretIDStorageEntry + if err := secretIDEntry.DecodeJSON(&result); err != nil { + return err } - return nil - } - // ExpirationTime not being set indicates non-expiring SecretIDs - if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) { - // Clean up the accessor of the secret ID first - err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) + // If a secret ID entry does not have a corresponding accessor + // entry, revoke the secret ID immediately + accessorEntry, err := b.secretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) if err != nil { - return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err) + return errwrap.Wrapf("failed to read secret ID accessor entry: {{err}}", err) + } + if accessorEntry == nil { + if err := s.Delete(ctx, entryIndex); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting secret ID %q from storage: {{err}}", secretIDHMAC), err) + } + return nil + } + + // ExpirationTime not being set indicates non-expiring SecretIDs + if !result.ExpirationTime.IsZero() && time.Now().After(result.ExpirationTime) { + // Clean up the accessor of the secret ID first + err = b.deleteSecretIDAccessorEntry(ctx, s, result.SecretIDAccessor, secretIDPrefixToUse) + if err != nil { + return errwrap.Wrapf("failed to delete secret ID accessor entry: {{err}}", err) + } + + if err := s.Delete(ctx, entryIndex); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err) + } + + return nil } - if err := s.Delete(ctx, entryIndex); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting SecretID %q from storage: {{err}}", secretIDHMAC), err) + // At this point, the secret ID is not expired and is valid. Delete + // the corresponding accessor from the accessorMap. This will leave + // only the dangling accessors in the map which can then be cleaned + // up later. + salt, err := b.Salt(ctx) + if err != nil { + return err } + delete(accessorMap, salt.SaltID(result.SecretIDAccessor)) return nil } - // At this point, the secret ID is not expired and is valid. Delete - // the corresponding accessor from the accessorMap. This will leave - // only the dangling accessors in the map which can then be cleaned - // up later. - salt, err := b.Salt(ctx) - if err != nil { - return err + for _, roleNameHMAC := range roleNameHMACs { + secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) + if err != nil { + return err + } + for _, secretIDHMAC := range secretIDHMACs { + err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse) + if err != nil { + return err + } + } } - delete(accessorMap, salt.SaltID(result.SecretIDAccessor)) - return nil - } - - for _, roleNameHMAC := range roleNameHMACs { - secretIDHMACs, err := s.List(ctx, fmt.Sprintf("%s%s", secretIDPrefixToUse, roleNameHMAC)) - if err != nil { - return err - } - for _, secretIDHMAC := range secretIDHMACs { - err = secretIDCleanupFunc(secretIDHMAC, roleNameHMAC, secretIDPrefixToUse) + // Accessor indexes were not getting cleaned up until 0.9.3. This is a fix + // to clean up the dangling accessor entries. + for accessorHash, _ := range accessorMap { + // Ideally, locking should be performed here. But for that, accessors + // are required in plaintext, which are not available. Hence performing + // a racy cleanup. + err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash) if err != nil { return err } } - } - // Accessor indexes were not getting cleaned up until 0.9.3. This is a fix - // to clean up the dangling accessor entries. - for accessorHash, _ := range accessorMap { - // Ideally, locking should be performed here. But for that, accessors - // are required in plaintext, which are not available. Hence performing - // a racy cleanup. - err = s.Delete(ctx, secretIDAccessorPrefix+accessorHash) - if err != nil { - return err - } + return nil } - return nil - } - - err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix) - if err != nil { - return err - } - err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix) - if err != nil { - return err - } + err := tidyFunc(secretIDPrefix, secretIDAccessorPrefix) + if err != nil { + logger.Error("error tidying global secret IDs", "error", err) + return + } + err = tidyFunc(secretIDLocalPrefix, secretIDAccessorLocalPrefix) + if err != nil { + logger.Error("error tidying local secret IDs", "error", err) + return + } + }() - return result + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidySecretIDUpdate is used to delete the expired SecretID entries func (b *backend) pathTidySecretIDUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidySecretID(ctx, req.Storage) + return b.tidySecretID(ctx, req.Storage) } const pathTidySecretIDSyn = "Trigger the clean-up of expired SecretID entries." diff --git a/builtin/credential/aws/path_tidy_identity_whitelist.go b/builtin/credential/aws/path_tidy_identity_whitelist.go index f1abe2308614..2ff035b8d9af 100644 --- a/builtin/credential/aws/path_tidy_identity_whitelist.go +++ b/builtin/credential/aws/path_tidy_identity_whitelist.go @@ -33,53 +33,72 @@ expiration, before it is removed from the backend storage.`, } // tidyWhitelistIdentity is used to delete entries in the whitelist that are expired. -func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0) - } else { - return fmt.Errorf("identity whitelist tidy operation already running") +func (b *backend) tidyWhitelistIdentity(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidyWhitelistCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - bufferDuration := time.Duration(safety_buffer) * time.Second + go func() { + defer atomic.StoreUint32(b.tidyWhitelistCASGuard, 0) - identities, err := s.List(ctx, "whitelist/identity/") - if err != nil { - return err - } + // Don't cancel when the original client request goes away + ctx = context.Background() - for _, instanceID := range identities { - identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID) - if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err) - } + logger := b.Logger().Named("wltidy") - if identityEntry == nil { - return fmt.Errorf("identity entry for instanceID %q is nil", instanceID) - } + bufferDuration := time.Duration(safety_buffer) * time.Second - if identityEntry.Value == nil || len(identityEntry.Value) == 0 { - return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID) - } + doTidy := func() error { + identities, err := s.List(ctx, "whitelist/identity/") + if err != nil { + return err + } + + for _, instanceID := range identities { + identityEntry, err := s.Get(ctx, "whitelist/identity/"+instanceID) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching identity of instanceID %q: {{err}}", instanceID), err) + } + + if identityEntry == nil { + return fmt.Errorf("identity entry for instanceID %q is nil", instanceID) + } + + if identityEntry.Value == nil || len(identityEntry.Value) == 0 { + return fmt.Errorf("found identity entry for instanceID %q but actual identity is empty", instanceID) + } + + var result whitelistIdentity + if err := identityEntry.DecodeJSON(&result); err != nil { + return err + } + + if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { + if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err) + } + } + } - var result whitelistIdentity - if err := identityEntry.DecodeJSON(&result); err != nil { - return err + return nil } - if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := s.Delete(ctx, "whitelist/identity"+instanceID); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting identity of instanceID %q from storage: {{err}}", instanceID), err) - } + if err := doTidy(); err != nil { + logger.Error("error running whitelist tidy", "error", err) + return } - } + }() - return nil + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidyIdentityWhitelistUpdate is used to delete entries in the whitelist that are expired. func (b *backend) pathTidyIdentityWhitelistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int)) + return b.tidyWhitelistIdentity(ctx, req.Storage, data.Get("safety_buffer").(int)) } const pathTidyIdentityWhitelistSyn = ` diff --git a/builtin/credential/aws/path_tidy_roletag_blacklist.go b/builtin/credential/aws/path_tidy_roletag_blacklist.go index a29837110d2b..4eaafc22df4b 100644 --- a/builtin/credential/aws/path_tidy_roletag_blacklist.go +++ b/builtin/credential/aws/path_tidy_roletag_blacklist.go @@ -33,52 +33,72 @@ expiration, before it is removed from the backend storage.`, } // tidyBlacklistRoleTag is used to clean-up the entries in the role tag blacklist. -func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) error { - grabbed := atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) - if grabbed { - defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0) - } else { - return fmt.Errorf("roletag blacklist tidy operation already running") +func (b *backend) tidyBlacklistRoleTag(ctx context.Context, s logical.Storage, safety_buffer int) (*logical.Response, error) { + if !atomic.CompareAndSwapUint32(b.tidyBlacklistCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - bufferDuration := time.Duration(safety_buffer) * time.Second - tags, err := s.List(ctx, "blacklist/roletag/") - if err != nil { - return err - } + go func() { + defer atomic.StoreUint32(b.tidyBlacklistCASGuard, 0) - for _, tag := range tags { - tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag) - if err != nil { - return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err) - } + // Don't cancel when the original client request goes away + ctx = context.Background() - if tagEntry == nil { - return fmt.Errorf("tag entry for tag %q is nil", tag) - } + logger := b.Logger().Named("bltidy") - if tagEntry.Value == nil || len(tagEntry.Value) == 0 { - return fmt.Errorf("found entry for tag %q but actual tag is empty", tag) - } + bufferDuration := time.Duration(safety_buffer) * time.Second - var result roleTagBlacklistEntry - if err := tagEntry.DecodeJSON(&result); err != nil { - return err - } + doTidy := func() error { + tags, err := s.List(ctx, "blacklist/roletag/") + if err != nil { + return err + } - if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { - if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil { - return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err) + for _, tag := range tags { + tagEntry, err := s.Get(ctx, "blacklist/roletag/"+tag) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching tag %q: {{err}}", tag), err) + } + + if tagEntry == nil { + return fmt.Errorf("tag entry for tag %q is nil", tag) + } + + if tagEntry.Value == nil || len(tagEntry.Value) == 0 { + return fmt.Errorf("found entry for tag %q but actual tag is empty", tag) + } + + var result roleTagBlacklistEntry + if err := tagEntry.DecodeJSON(&result); err != nil { + return err + } + + if time.Now().After(result.ExpirationTime.Add(bufferDuration)) { + if err := s.Delete(ctx, "blacklist/roletag"+tag); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting tag %q from storage: {{err}}", tag), err) + } + } } + + return nil } - } - return nil + if err := doTidy(); err != nil { + logger.Error("error running blacklist tidy", "error", err) + return + } + }() + + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // pathTidyRoletagBlacklistUpdate is used to clean-up the entries in the role tag blacklist. func (b *backend) pathTidyRoletagBlacklistUpdate(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - return nil, b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int)) + return b.tidyBlacklistRoleTag(ctx, req.Storage, data.Get("safety_buffer").(int)) } const pathTidyRoletagBlacklistSyn = ` diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index de8a3517c9f3..9cf894d78389 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -85,6 +85,7 @@ func Backend() *backend { } b.crlLifetime = time.Hour * 72 + b.tidyCASGuard = new(uint32) return &b } @@ -94,6 +95,7 @@ type backend struct { crlLifetime time.Duration revokeStorageLock sync.RWMutex + tidyCASGuard *uint32 } const backendHelp = ` diff --git a/builtin/logical/pki/path_tidy.go b/builtin/logical/pki/path_tidy.go index 9b0a86df9cf3..ccb76af39a97 100644 --- a/builtin/logical/pki/path_tidy.go +++ b/builtin/logical/pki/path_tidy.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "fmt" + "sync/atomic" "time" "github.com/hashicorp/errwrap" @@ -59,116 +60,128 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr bufferDuration := time.Duration(safetyBuffer) * time.Second - var resp *logical.Response + if !atomic.CompareAndSwapUint32(b.tidyCASGuard, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil + } - if tidyCertStore { - serials, err := req.Storage.List(ctx, "certs/") - if err != nil { - return nil, errwrap.Wrapf("error fetching list of certs: {{err}}", err) - } + go func() { + defer atomic.StoreUint32(b.tidyCASGuard, 0) - for _, serial := range serials { - certEntry, err := req.Storage.Get(ctx, "certs/"+serial) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err) - } + // Don't cancel when the original client request goes away + ctx = context.Background() - if certEntry == nil { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err) - } - } + logger := b.Logger().Named("tidy") - if certEntry.Value == nil || len(certEntry.Value) == 0 { - if resp == nil { - resp = &logical.Response{} - } - resp.AddWarning(fmt.Sprintf("Certificate entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err) + doTidy := func() error { + if tidyCertStore { + serials, err := req.Storage.List(ctx, "certs/") + if err != nil { + return errwrap.Wrapf("error fetching list of certs: {{err}}", err) } - } - - cert, err := x509.ParseCertificate(certEntry.Value) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err) - } - if time.Now().After(cert.NotAfter.Add(bufferDuration)) { - if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err) + for _, serial := range serials { + certEntry, err := req.Storage.Get(ctx, "certs/"+serial) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error fetching certificate %q: {{err}}", serial), err) + } + + if certEntry == nil { + logger.Warn("certificate entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting nil entry with serial %s: {{err}}", serial), err) + } + } + + if certEntry.Value == nil || len(certEntry.Value) == 0 { + logger.Warn("certificate entry has no value; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting entry with nil value with serial %s: {{err}}", serial), err) + } + } + + cert, err := x509.ParseCertificate(certEntry.Value) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to parse stored certificate with serial %q: {{err}}", serial), err) + } + + if time.Now().After(cert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete(ctx, "certs/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from storage: {{err}}", serial), err) + } + } } } - } - } - - if tidyRevocationList { - b.revokeStorageLock.Lock() - defer b.revokeStorageLock.Unlock() - tidiedRevoked := false + if tidyRevocationList { + b.revokeStorageLock.Lock() + defer b.revokeStorageLock.Unlock() - revokedSerials, err := req.Storage.List(ctx, "revoked/") - if err != nil { - return nil, errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err) - } + tidiedRevoked := false - var revInfo revocationInfo - for _, serial := range revokedSerials { - revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err) - } - - if revokedEntry == nil { - if resp == nil { - resp = &logical.Response{} + revokedSerials, err := req.Storage.List(ctx, "revoked/") + if err != nil { + return errwrap.Wrapf("error fetching list of revoked certs: {{err}}", err) } - resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s is nil; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err) - } - } - if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { - if resp == nil { - resp = &logical.Response{} + var revInfo revocationInfo + for _, serial := range revokedSerials { + revokedEntry, err := req.Storage.Get(ctx, "revoked/"+serial) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to fetch revoked cert with serial %q: {{err}}", serial), err) + } + + if revokedEntry == nil { + logger.Warn("revoked entry is nil; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting nil revoked entry with serial %s: {{err}}", serial), err) + } + } + + if revokedEntry.Value == nil || len(revokedEntry.Value) == 0 { + logger.Warn("revoked entry has nil value; tidying up since it is no longer useful for any server operations", "serial", serial) + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err) + } + } + + err = revokedEntry.DecodeJSON(&revInfo) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err) + } + + revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) + if err != nil { + return errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err) + } + + if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) { + if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { + return errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err) + } + tidiedRevoked = true + } } - resp.AddWarning(fmt.Sprintf("Revoked entry for serial %s has nil value; tidying up since it is no longer useful for any server operations", serial)) - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting revoked entry with nil value with serial %s: {{err}}", serial), err) - } - } - err = revokedEntry.DecodeJSON(&revInfo) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error decoding revocation entry for serial %q: {{err}}", serial), err) - } - - revokedCert, err := x509.ParseCertificate(revInfo.CertificateBytes) - if err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("unable to parse stored revoked certificate with serial %q: {{err}}", serial), err) - } - - if time.Now().After(revokedCert.NotAfter.Add(bufferDuration)) { - if err := req.Storage.Delete(ctx, "revoked/"+serial); err != nil { - return nil, errwrap.Wrapf(fmt.Sprintf("error deleting serial %q from revoked list: {{err}}", serial), err) + if tidiedRevoked { + if err := buildCRL(ctx, b, req); err != nil { + return err + } } - tidiedRevoked = true } + + return nil } - if tidiedRevoked { - if err := buildCRL(ctx, b, req); err != nil { - return nil, err - } + if err := doTidy(); err != nil { + logger.Error("error running tidy", "error", err) + return } - } + }() + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") return resp, nil } diff --git a/command/server.go b/command/server.go index f3e2fbd22aba..08743edca3f4 100644 --- a/command/server.go +++ b/command/server.go @@ -935,7 +935,8 @@ CLUSTER_SYNTHESIS_COMPLETE: } server := &http.Server{ - Handler: handler, + Handler: handler, + IdleTimeout: 10 * time.Minute, } go server.Serve(ln.Listener) } diff --git a/vault/expiration.go b/vault/expiration.go index 5e42c3fb3f2f..7924a401704a 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -191,15 +191,17 @@ func (m *ExpirationManager) Tidy() error { var tidyErrors *multierror.Error + logger := m.logger.Named("tidy") + if !atomic.CompareAndSwapInt32(m.tidyLock, 0, 1) { - m.logger.Warn("tidy operation on leases is already in progress") - return fmt.Errorf("tidy operation on leases is already in progress") + logger.Warn("tidy operation on leases is already in progress") + return nil } defer atomic.CompareAndSwapInt32(m.tidyLock, 1, 0) - m.logger.Info("beginning tidy operation on leases") - defer m.logger.Info("finished tidy operation on leases") + logger.Info("beginning tidy operation on leases") + defer logger.Info("finished tidy operation on leases") // Create a cache to keep track of looked up tokens tokenCache := make(map[string]bool) @@ -208,7 +210,7 @@ func (m *ExpirationManager) Tidy() error { tidyFunc := func(leaseID string) { countLease++ if countLease%500 == 0 { - m.logger.Info("tidying leases", "progress", countLease) + logger.Info("tidying leases", "progress", countLease) } le, err := m.loadEntry(leaseID) @@ -225,7 +227,7 @@ func (m *ExpirationManager) Tidy() error { var isValid, ok bool revokeLease := false if le.ClientToken == "" { - m.logger.Debug("revoking lease which has an empty token", "lease_id", leaseID) + logger.Debug("revoking lease which has an empty token", "lease_id", leaseID) revokeLease = true deletedCountEmptyToken++ goto REVOKE_CHECK @@ -249,7 +251,7 @@ func (m *ExpirationManager) Tidy() error { } if te == nil { - m.logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID) + logger.Debug("revoking lease which holds an invalid token", "lease_id", leaseID) revokeLease = true deletedCountInvalidToken++ tokenCache[le.ClientToken] = false @@ -262,7 +264,7 @@ func (m *ExpirationManager) Tidy() error { return } - m.logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID) + logger.Debug("revoking lease which contains an invalid token", "lease_id", leaseID) revokeLease = true deletedCountInvalidToken++ goto REVOKE_CHECK @@ -285,10 +287,10 @@ func (m *ExpirationManager) Tidy() error { return err } - m.logger.Info("number of leases scanned", "count", countLease) - m.logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken) - m.logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken) - m.logger.Info("number of leases successfully revoked", "count", revokedCount) + logger.Info("number of leases scanned", "count", countLease) + logger.Info("number of leases which had empty tokens", "count", deletedCountEmptyToken) + logger.Info("number of leases which had invalid tokens", "count", deletedCountInvalidToken) + logger.Info("number of leases successfully revoked", "count", revokedCount) return tidyErrors.ErrorOrNil() } diff --git a/vault/logical_system.go b/vault/logical_system.go index 821d534ef4c5..bd909c1fd07d 100644 --- a/vault/logical_system.go +++ b/vault/logical_system.go @@ -1182,12 +1182,17 @@ func (b *SystemBackend) handleCORSDelete(ctx context.Context, req *logical.Reque } func (b *SystemBackend) handleTidyLeases(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { - err := b.Core.expiration.Tidy() - if err != nil { - b.Backend.Logger().Error("failed to tidy leases", "error", err) - return handleErrorNoReadOnlyForward(err) - } - return nil, err + go func() { + err := b.Core.expiration.Tidy() + if err != nil { + b.Backend.Logger().Error("failed to tidy leases", "error", err) + return + } + }() + + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } func (b *SystemBackend) invalidate(ctx context.Context, key string) { diff --git a/vault/token_store.go b/vault/token_store.go index d986a79f528e..7bafaa703fd0 100644 --- a/vault/token_store.go +++ b/vault/token_store.go @@ -130,7 +130,7 @@ type TokenStore struct { saltLock sync.RWMutex salt *salt.Salt - tidyLock *int32 + tidyLock *uint32 identityPoliciesDeriverFunc func(string) (*identity.Entity, []string, error) } @@ -150,7 +150,7 @@ func NewTokenStore(ctx context.Context, logger log.Logger, c *Core, config *logi tokensPendingDeletion: &sync.Map{}, saltLock: sync.RWMutex{}, identityPoliciesDeriverFunc: c.fetchEntityAndDerivedPolicies, - tidyLock: new(int32), + tidyLock: new(uint32), } if c.policyStore != nil { @@ -1290,204 +1290,224 @@ func (ts *TokenStore) lookupBySaltedAccessor(ctx context.Context, saltedAccessor // handleTidy handles the cleaning up of leaked accessor storage entries and // cleaning up of leases that are associated to tokens that are expired. func (ts *TokenStore) handleTidy(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - var tidyErrors *multierror.Error - - if !atomic.CompareAndSwapInt32(ts.tidyLock, 0, 1) { - ts.logger.Warn("tidy operation on tokens is already in progress") - return nil, fmt.Errorf("tidy operation on tokens is already in progress") + if !atomic.CompareAndSwapUint32(ts.tidyLock, 0, 1) { + resp := &logical.Response{} + resp.AddWarning("Tidy operation already in progress.") + return resp, nil } - defer atomic.CompareAndSwapInt32(ts.tidyLock, 1, 0) + go func() { + defer atomic.StoreUint32(ts.tidyLock, 0) - ts.logger.Info("beginning tidy operation on tokens") - defer ts.logger.Info("finished tidy operation on tokens") + // Don't cancel when the original client request goes away + ctx = context.Background() - // List out all the accessors - saltedAccessorList, err := ts.view.List(ctx, accessorPrefix) - if err != nil { - return nil, errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err) - } + logger := ts.logger.Named("tidy") - // First, clean up secondary index entries that are no longer valid - parentList, err := ts.view.List(ctx, parentPrefix) - if err != nil { - return nil, errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err) - } + var tidyErrors *multierror.Error - var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64 + doTidy := func() error { - // Scan through the secondary index entries; if there is an entry - // with the token's salt ID at the end, remove it - for _, parent := range parentList { - countParentEntries++ + ts.logger.Info("beginning tidy operation on tokens") + defer ts.logger.Info("finished tidy operation on tokens") - // Get the children - children, err := ts.view.List(ctx, parentPrefix+parent) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err)) - continue - } - - // First check if the salt ID of the parent exists, and if not mark this so - // that deletion of children later with this loop below applies to all - // children - originalChildrenCount := int64(len(children)) - exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true) - if exists == nil { - ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent) - } - - var deletedChildrenCount int64 - for _, child := range children { - countParentList++ - if countParentList%500 == 0 { - ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList) + // List out all the accessors + saltedAccessorList, err := ts.view.List(ctx, accessorPrefix) + if err != nil { + return errwrap.Wrapf("failed to fetch accessor index entries: {{err}}", err) } - // Look up tainted entries so we can be sure that if this isn't - // found, it doesn't exist. Doing the following without locking - // since appropriate locks cannot be held with salted token IDs. - // Also perform deletion if the parent doesn't exist any more. - te, _ := ts.lookupSalted(ctx, child, true) - // If the child entry is not nil, but the parent doesn't exist, then turn - // that child token into an orphan token. Theres no deletion in this case. - if te != nil && exists == nil { - lock := locksutil.LockForKey(ts.tokenLocks, te.ID) - lock.Lock() - - te.Parent = "" - err = ts.store(ctx, te) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err)) - } - lock.Unlock() - continue + // First, clean up secondary index entries that are no longer valid + parentList, err := ts.view.List(ctx, parentPrefix) + if err != nil { + return errwrap.Wrapf("failed to fetch secondary index entries: {{err}}", err) } - // Otherwise, if the entry doesn't exist, or if the parent doesn't exist go - // on with the delete on the secondary index - if te == nil || exists == nil { - index := parentPrefix + parent + child - ts.logger.Debug("deleting invalid secondary index", "index", index) - err = ts.view.Delete(ctx, index) + + var countParentEntries, deletedCountParentEntries, countParentList, deletedCountParentList int64 + + // Scan through the secondary index entries; if there is an entry + // with the token's salt ID at the end, remove it + for _, parent := range parentList { + countParentEntries++ + + // Get the children + children, err := ts.view.List(ctx, parentPrefix+parent) if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err)) + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read secondary index: {{err}}", err)) continue } - deletedChildrenCount++ - } - } - // Add current children deleted count to the total count - deletedCountParentList += deletedChildrenCount - // N.B.: We don't call delete on the parent prefix since physical.Backend.Delete - // implementations should be in charge of deleting empty prefixes. - // If we deleted all the children, then add that to our deleted parent entries count. - if originalChildrenCount == deletedChildrenCount { - deletedCountParentEntries++ - } - } - - var countAccessorList, - deletedCountAccessorEmptyToken, - deletedCountAccessorInvalidToken, - deletedCountInvalidTokenInAccessor int64 - - // For each of the accessor, see if the token ID associated with it is - // a valid one. If not, delete the leases associated with that token - // and delete the accessor as well. - for _, saltedAccessor := range saltedAccessorList { - countAccessorList++ - if countAccessorList%500 == 0 { - ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList) - } - accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err)) - continue - } + // First check if the salt ID of the parent exists, and if not mark this so + // that deletion of children later with this loop below applies to all + // children + originalChildrenCount := int64(len(children)) + exists, _ := ts.lookupSalted(ctx, strings.TrimSuffix(parent, "/"), true) + if exists == nil { + ts.logger.Debug("deleting invalid parent prefix entry", "index", parentPrefix+parent) + } - // A valid accessor storage entry should always have a token ID - // in it. If not, it is an invalid accessor entry and needs to - // be deleted. - if accessorEntry.TokenID == "" { - index := accessorPrefix + saltedAccessor - // If deletion of accessor fails, move on to the next - // item since this is just a best-effort operation - err = ts.view.Delete(ctx, index) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err)) - continue + var deletedChildrenCount int64 + for _, child := range children { + countParentList++ + if countParentList%500 == 0 { + ts.logger.Info("checking validity of tokens in secondary index list", "progress", countParentList) + } + + // Look up tainted entries so we can be sure that if this isn't + // found, it doesn't exist. Doing the following without locking + // since appropriate locks cannot be held with salted token IDs. + // Also perform deletion if the parent doesn't exist any more. + te, _ := ts.lookupSalted(ctx, child, true) + // If the child entry is not nil, but the parent doesn't exist, then turn + // that child token into an orphan token. Theres no deletion in this case. + if te != nil && exists == nil { + lock := locksutil.LockForKey(ts.tokenLocks, te.ID) + lock.Lock() + + te.Parent = "" + err = ts.store(ctx, te) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to convert child token into an orphan token: {{err}}", err)) + } + lock.Unlock() + continue + } + // Otherwise, if the entry doesn't exist, or if the parent doesn't exist go + // on with the delete on the secondary index + if te == nil || exists == nil { + index := parentPrefix + parent + child + ts.logger.Debug("deleting invalid secondary index", "index", index) + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete secondary index: {{err}}", err)) + continue + } + deletedChildrenCount++ + } + } + // Add current children deleted count to the total count + deletedCountParentList += deletedChildrenCount + // N.B.: We don't call delete on the parent prefix since physical.Backend.Delete + // implementations should be in charge of deleting empty prefixes. + // If we deleted all the children, then add that to our deleted parent entries count. + if originalChildrenCount == deletedChildrenCount { + deletedCountParentEntries++ + } } - deletedCountAccessorEmptyToken++ - } - lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID) - lock.RLock() + var countAccessorList, + deletedCountAccessorEmptyToken, + deletedCountAccessorInvalidToken, + deletedCountInvalidTokenInAccessor int64 + + // For each of the accessor, see if the token ID associated with it is + // a valid one. If not, delete the leases associated with that token + // and delete the accessor as well. + for _, saltedAccessor := range saltedAccessorList { + countAccessorList++ + if countAccessorList%500 == 0 { + ts.logger.Info("checking if accessors contain valid tokens", "progress", countAccessorList) + } - // Look up tainted variants so we only find entries that truly don't - // exist - saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err)) - lock.RUnlock() - continue - } - te, err := ts.lookupSalted(ctx, saltedID, true) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err)) - lock.RUnlock() - continue - } + accessorEntry, err := ts.lookupBySaltedAccessor(ctx, saltedAccessor, true) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read the accessor index: {{err}}", err)) + continue + } - lock.RUnlock() + // A valid accessor storage entry should always have a token ID + // in it. If not, it is an invalid accessor entry and needs to + // be deleted. + if accessorEntry.TokenID == "" { + index := accessorPrefix + saltedAccessor + // If deletion of accessor fails, move on to the next + // item since this is just a best-effort operation + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete the accessor index: {{err}}", err)) + continue + } + deletedCountAccessorEmptyToken++ + } - // If token entry is not found assume that the token is not valid any - // more and conclude that accessor, leases, and secondary index entries - // for this token should not exist as well. - if te == nil { - ts.logger.Info("deleting token with nil entry", "salted_token", saltedID) + lock := locksutil.LockForKey(ts.tokenLocks, accessorEntry.TokenID) + lock.RLock() - // RevokeByToken expects a '*logical.TokenEntry'. For the - // purposes of tidying, it is sufficient if the token - // entry only has ID set. - tokenEntry := &logical.TokenEntry{ - ID: accessorEntry.TokenID, - } + // Look up tainted variants so we only find entries that truly don't + // exist + saltedID, err := ts.SaltID(ctx, accessorEntry.TokenID) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to read salt id: {{err}}", err)) + lock.RUnlock() + continue + } + te, err := ts.lookupSalted(ctx, saltedID, true) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to lookup tainted ID: {{err}}", err)) + lock.RUnlock() + continue + } - // Attempt to revoke the token. This will also revoke - // the leases associated with the token. - err := ts.expiration.RevokeByToken(tokenEntry) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err)) - continue + lock.RUnlock() + + // If token entry is not found assume that the token is not valid any + // more and conclude that accessor, leases, and secondary index entries + // for this token should not exist as well. + if te == nil { + ts.logger.Info("deleting token with nil entry", "salted_token", saltedID) + + // RevokeByToken expects a '*logical.TokenEntry'. For the + // purposes of tidying, it is sufficient if the token + // entry only has ID set. + tokenEntry := &logical.TokenEntry{ + ID: accessorEntry.TokenID, + } + + // Attempt to revoke the token. This will also revoke + // the leases associated with the token. + err := ts.expiration.RevokeByToken(tokenEntry) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to revoke leases of expired token: {{err}}", err)) + continue + } + deletedCountInvalidTokenInAccessor++ + + index := accessorPrefix + saltedAccessor + + // If deletion of accessor fails, move on to the next item since + // this is just a best-effort operation. We do this last so that on + // next run if something above failed we still have the accessor + // entry to try again. + err = ts.view.Delete(ctx, index) + if err != nil { + tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err)) + continue + } + deletedCountAccessorInvalidToken++ + } } - deletedCountInvalidTokenInAccessor++ - index := accessorPrefix + saltedAccessor + ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries) + ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries) + ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList) + ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList) + ts.logger.Info("number of accessors scanned", "count", countAccessorList) + ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken) + ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor) + ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken) - // If deletion of accessor fails, move on to the next item since - // this is just a best-effort operation. We do this last so that on - // next run if something above failed we still have the accessor - // entry to try again. - err = ts.view.Delete(ctx, index) - if err != nil { - tidyErrors = multierror.Append(tidyErrors, errwrap.Wrapf("failed to delete accessor entry: {{err}}", err)) - continue - } - deletedCountAccessorInvalidToken++ + return tidyErrors.ErrorOrNil() } - } - ts.logger.Info("number of entries scanned in parent prefix", "count", countParentEntries) - ts.logger.Info("number of entries deleted in parent prefix", "count", deletedCountParentEntries) - ts.logger.Info("number of tokens scanned in parent index list", "count", countParentList) - ts.logger.Info("number of tokens revoked in parent index list", "count", deletedCountParentList) - ts.logger.Info("number of accessors scanned", "count", countAccessorList) - ts.logger.Info("number of deleted accessors which had empty tokens", "count", deletedCountAccessorEmptyToken) - ts.logger.Info("number of revoked tokens which were invalid but present in accessors", "count", deletedCountInvalidTokenInAccessor) - ts.logger.Info("number of deleted accessors which had invalid tokens", "count", deletedCountAccessorInvalidToken) + if err := doTidy(); err != nil { + logger.Error("error running tidy", "error", err) + return + } + }() - return nil, tidyErrors.ErrorOrNil() + resp := &logical.Response{} + resp.AddWarning("Tidy operation successfully started. Any information from the operation will be printed to Vault's server logs.") + return resp, nil } // handleUpdateLookupAccessor handles the auth/token/lookup-accessor path for returning From 99c124487c596f532c390fe794c07dd2e67fa89c Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 13 Jun 2018 21:24:38 -0400 Subject: [PATCH 02/12] Fix up tidy test --- vault/expiration_test.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/vault/expiration_test.go b/vault/expiration_test.go index e0b93ec43bc8..b4997e6a08e9 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -1,6 +1,7 @@ package vault import ( + "bytes" "context" "fmt" "reflect" @@ -38,6 +39,14 @@ func TestExpiration_Tidy(t *testing.T) { var err error exp := mockExpiration(t) + + // We use this later for tidy testing where we need to check the output + logOut := new(bytes.Buffer) + logger := log.New(&log.LoggerOptions{ + Output: logOut, + }) + exp.logger = logger + if err := exp.Restore(nil); err != nil { t.Fatal(err) } @@ -212,9 +221,11 @@ func TestExpiration_Tidy(t *testing.T) { } } - if !(err1 != nil && err1.Error() == "tidy operation on leases is already in progress") && - !(err2 != nil && err2.Error() == "tidy operation on leases is already in progress") { - t.Fatalf("expected at least one of err1 or err2 to be set; err1: %#v\n err2:%#v\n", err1, err2) + if err1 != nil || err2 != nil { + t.Fatalf("got an error: err1: %v; err2: %v", err1, err2) + } + if !strings.Contains(logOut.String(), "tidy operation on leases is already in progress") { + t.Fatalf("expected to see a warning saying operation in progress, output is %s", logOut.String()) } root, err := exp.tokenStore.rootToken(context.Background()) From fc1a813441c68de7662ceac9ba150b00d7712a6d Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 13:55:24 -0400 Subject: [PATCH 03/12] Add deadline to cluster connections and an idle timeout to the cluster server, plus add readheader/read timeout to api server --- command/server.go | 6 ++++-- vault/request_forwarding.go | 30 +++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/command/server.go b/command/server.go index 08743edca3f4..98001f3e9cb6 100644 --- a/command/server.go +++ b/command/server.go @@ -935,8 +935,10 @@ CLUSTER_SYNTHESIS_COMPLETE: } server := &http.Server{ - Handler: handler, - IdleTimeout: 10 * time.Minute, + Handler: handler, + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + IdleTimeout: 5 * time.Minute, } go server.Serve(ln.Listener) } diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index c4c0e73cd977..ce61b320f21b 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -76,7 +76,11 @@ func (c *Core) startForwarding(ctx context.Context) error { // duties. Doing it this way instead of listening via the server and gRPC // allows us to re-use the same port via ALPN. We can just tell the server // to serve a given conn and which handler to use. - fws := &http2.Server{} + fws := &http2.Server{ + // Our forwarding connections heartbeat regularly so anything else we + // want to go away/get cleaned up pretty rapidly + IdleTimeout: 5 * HeartbeatInterval, + } // Shutdown coordination logic shutdown := new(uint32) @@ -147,6 +151,20 @@ func (c *Core) startForwarding(ctx context.Context) error { // Type assert to TLS connection and handshake to populate the // connection state tlsConn := conn.(*tls.Conn) + + // Set a deadline for the handshake. This will cause clients + // that don't successfully auth to be kicked out quickly. + // Cluster connections should be reliable so being marginally + // aggressive here is fine. + err = tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + if err != nil { + if c.logger.IsDebug() { + c.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + err = tlsConn.Handshake() if err != nil { if c.logger.IsDebug() { @@ -156,6 +174,16 @@ func (c *Core) startForwarding(ctx context.Context) error { continue } + // Now, set it back to unlimited + err = tlsConn.SetDeadline(time.Time{}) + if err != nil { + if c.logger.IsDebug() { + c.logger.Debug("error setting deadline for cluster connection", "error", err) + } + tlsConn.Close() + continue + } + switch tlsConn.ConnectionState().NegotiatedProtocol { case requestForwardingALPN: if !ha { From 2e2e1c60d450077821e7cb22120dd9042d543dab Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 14:29:33 -0400 Subject: [PATCH 04/12] Add proxy header timeout when wrapping in proxy proto --- helper/proxyutil/proxyutil.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/helper/proxyutil/proxyutil.go b/helper/proxyutil/proxyutil.go index 875e74831c94..f18ec1d14905 100644 --- a/helper/proxyutil/proxyutil.go +++ b/helper/proxyutil/proxyutil.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "sync" + "time" proxyproto "github.com/armon/go-proxyproto" "github.com/hashicorp/errwrap" @@ -41,12 +42,14 @@ func WrapInProxyProto(listener net.Listener, config *ProxyProtoConfig) (net.List switch config.Behavior { case "use_always": newLn = &proxyproto.Listener{ - Listener: listener, + Listener: listener, + ProxyHeaderTimeout: 10 * time.Second, } case "allow_authorized", "deny_unauthorized": newLn = &proxyproto.Listener{ - Listener: listener, + Listener: listener, + ProxyHeaderTimeout: 10 * time.Second, SourceCheck: func(addr net.Addr) (bool, error) { config.RLock() defer config.RUnlock() From d7a5df3583bd777165f498bca6489838ed31a3f8 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 16:00:19 -0400 Subject: [PATCH 05/12] Fix approle build --- builtin/credential/approle/path_tidy_user_id_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/builtin/credential/approle/path_tidy_user_id_test.go b/builtin/credential/approle/path_tidy_user_id_test.go index b52b711356f4..314a69175abe 100644 --- a/builtin/credential/approle/path_tidy_user_id_test.go +++ b/builtin/credential/approle/path_tidy_user_id_test.go @@ -64,7 +64,7 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) { t.Fatalf("bad: len(accessorHashes); expect 3, got %d", len(accessorHashes)) } - err = b.tidySecretID(context.Background(), storage) + _, err = b.tidySecretID(context.Background(), storage) if err != nil { t.Fatal(err) } From c44018d9b3f1bfc68b12d187291b55a89657f7a4 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 18:36:55 -0400 Subject: [PATCH 06/12] Update request_forwarding.go --- vault/request_forwarding.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vault/request_forwarding.go b/vault/request_forwarding.go index ce61b320f21b..3cbd22f33a63 100644 --- a/vault/request_forwarding.go +++ b/vault/request_forwarding.go @@ -156,7 +156,7 @@ func (c *Core) startForwarding(ctx context.Context) error { // that don't successfully auth to be kicked out quickly. // Cluster connections should be reliable so being marginally // aggressive here is fine. - err = tlsConn.SetDeadline(time.Now().Add(10 * time.Second)) + err = tlsConn.SetDeadline(time.Now().Add(30 * time.Second)) if err != nil { if c.logger.IsDebug() { c.logger.Debug("error setting deadline for cluster connection", "error", err) From 0e22fa143dd7eff3d18f15973b98d79a2edd2833 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 18:38:05 -0400 Subject: [PATCH 07/12] Capture req locally since tests need it --- builtin/logical/pki/path_tidy.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/builtin/logical/pki/path_tidy.go b/builtin/logical/pki/path_tidy.go index ccb76af39a97..4d9ea992db00 100644 --- a/builtin/logical/pki/path_tidy.go +++ b/builtin/logical/pki/path_tidy.go @@ -66,6 +66,12 @@ func (b *backend) pathTidyWrite(ctx context.Context, req *logical.Request, d *fr return resp, nil } + // Tests using framework will screw up the storage so make a locally + // scoped req to hold a reference + req = &logical.Request{ + Storage: req.Storage, + } + go func() { defer atomic.StoreUint32(b.tidyCASGuard, 0) From 8a328190a1b56800291bfaacd4fe53a84f627f67 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 18:42:19 -0400 Subject: [PATCH 08/12] Add a timeout to tests to account for async tidy. It's super sadly monster because that passes more often than less monster timeouts, but really it's just flaky for reasons that likely have to do with logical.Framework testing stuff because there's no evidence of any problem other than it just not having run in time. --- builtin/logical/pki/backend_test.go | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 3b80d5972545..83729bc0df32 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -12,6 +12,7 @@ import ( "crypto/x509/pkix" "encoding/base64" "encoding/pem" + "errors" "fmt" "math" "math/big" @@ -1429,6 +1430,20 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int "tidy_cert_store": true, "tidy_revocation_list": true, }, + Check: func(resp *logical.Response) error { + if resp.IsError() { + return fmt.Errorf("got an error resp: %v", resp.Error()) + } + if len(resp.Warnings) != 1 { + return fmt.Errorf("expected a warning, resp is %#v", *resp) + } + + // Give time for the certificates to pass the safety buffer + t.Logf("Sleeping for 15 seconds to allow tidy to work") + t.Logf(time.Now().String()) + time.Sleep(45 * time.Second) + return nil + }, }, // We do *not* expect to find these @@ -1436,8 +1451,10 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int Operation: logical.ReadOperation, PreFlight: setSerialUnderTest, Check: func(resp *logical.Response) error { + t.Logf(time.Now().String()) if resp != nil { - return fmt.Errorf("expected no response") + t.Logf("failed at first check, resp is %#v", *resp) + return errors.New("expected no response") } serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) @@ -1451,6 +1468,7 @@ func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, int PreFlight: setSerialUnderTest, Check: func(resp *logical.Response) error { if resp != nil { + t.Logf("failed at second check, resp is %#v", *resp) return fmt.Errorf("expected no response") } From 2fa4514e554532cc28bece14a0a6d33536af489f Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 18:44:57 -0400 Subject: [PATCH 09/12] Add sleeps for tidy since it's now async --- builtin/credential/approle/path_tidy_user_id_test.go | 4 ++++ vault/token_store_test.go | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/builtin/credential/approle/path_tidy_user_id_test.go b/builtin/credential/approle/path_tidy_user_id_test.go index 314a69175abe..f4e8b8e91da6 100644 --- a/builtin/credential/approle/path_tidy_user_id_test.go +++ b/builtin/credential/approle/path_tidy_user_id_test.go @@ -3,6 +3,7 @@ package approle import ( "context" "testing" + "time" "github.com/hashicorp/vault/logical" ) @@ -69,6 +70,9 @@ func TestAppRole_TidyDanglingAccessors(t *testing.T) { t.Fatal(err) } + // It runs async so we give it a bit of time to run + time.Sleep(10 * time.Second) + accessorHashes, err = storage.List(context.Background(), "accessor/") if err != nil { t.Fatal(err) diff --git a/vault/token_store_test.go b/vault/token_store_test.go index c20c36b0c066..9acff94c2040 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -3777,6 +3777,9 @@ func TestTokenStore_HandleTidyCase1(t *testing.T) { t.Fatalf("err:%v resp:%v", err, resp) } + // Tidy runs async so give it time + time.Sleep(10 * time.Second) + // Tidy should have removed all the dangling accessor entries resp, err = ts.HandleRequest(context.Background(), accessorListReq) if err != nil || (resp != nil && resp.IsError()) { @@ -3909,6 +3912,9 @@ func TestTokenStore_HandleTidy_parentCleanup(t *testing.T) { t.Fatalf("err:%v resp:%v", err, resp) } + // Tidy runs async so give it time + time.Sleep(10 * time.Second) + // Tidy should have removed all the dangling accessor entries resp, err = ts.HandleRequest(context.Background(), accessorListReq) if err != nil || (resp != nil && resp.IsError()) { From 7632e1dafd6520a1ac0b7788e34968732617986d Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 14 Jun 2018 20:40:39 -0400 Subject: [PATCH 10/12] Use deep.Equal for a dynamo test --- physical/dynamodb/dynamodb_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/physical/dynamodb/dynamodb_test.go b/physical/dynamodb/dynamodb_test.go index 70e69802d553..28a63d51d3f6 100644 --- a/physical/dynamodb/dynamodb_test.go +++ b/physical/dynamodb/dynamodb_test.go @@ -6,10 +6,10 @@ import ( "math/rand" "net/http" "os" - "reflect" "testing" "time" + "github.com/go-test/deep" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/vault/helper/logging" "github.com/hashicorp/vault/physical" @@ -106,8 +106,8 @@ func TestDynamoDBBackend(t *testing.T) { if err != nil { t.Fatalf("err: %s", err) } - if !reflect.DeepEqual(inputEntry, entry) { - t.Fatalf("exp: %#v, act: %#v", inputEntry, entry) + if diff := deep.Equal(inputEntry, entry); diff != nil { + t.Fatal(diff) } }) } From 5c7cf69b40e1ff2cb9ad144ff06a576b26470f0c Mon Sep 17 00:00:00 2001 From: Jim Kalafut Date: Fri, 15 Jun 2018 12:40:29 -0700 Subject: [PATCH 11/12] Update to a newer dynamodb-local container --- physical/dynamodb/dynamodb_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/physical/dynamodb/dynamodb_test.go b/physical/dynamodb/dynamodb_test.go index 28a63d51d3f6..a831a29150fd 100644 --- a/physical/dynamodb/dynamodb_test.go +++ b/physical/dynamodb/dynamodb_test.go @@ -285,7 +285,7 @@ func prepareDynamoDBTestContainer(t *testing.T) (cleanup func(), retAddress stri t.Fatalf("Failed to connect to docker: %s", err) } - resource, err := pool.Run("deangiberson/aws-dynamodb-local", "latest", []string{}) + resource, err := pool.Run("cnadiminti/dynamodb-local", "latest", []string{}) if err != nil { t.Fatalf("Could not start local DynamoDB: %s", err) } From 712cc87004614a8d5e20c283586406c78a2f9255 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Sat, 16 Jun 2018 15:58:24 -0400 Subject: [PATCH 12/12] Modernize CA testing steps --- builtin/logical/pki/backend.go | 6 +- builtin/logical/pki/backend_test.go | 761 +------------------------ builtin/logical/pki/ca_test.go | 570 ++++++++++++++++++ builtin/logical/pki/path_roles_test.go | 2 +- 4 files changed, 578 insertions(+), 761 deletions(-) create mode 100644 builtin/logical/pki/ca_test.go diff --git a/builtin/logical/pki/backend.go b/builtin/logical/pki/backend.go index 9cf894d78389..60e943acb8a3 100644 --- a/builtin/logical/pki/backend.go +++ b/builtin/logical/pki/backend.go @@ -12,7 +12,7 @@ import ( // Factory creates a new backend implementing the logical.Backend interface func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() + b := Backend(conf) if err := b.Setup(ctx, conf); err != nil { return nil, err } @@ -20,7 +20,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, } // Backend returns a new Backend framework struct -func Backend() *backend { +func Backend(conf *logical.BackendConfig) *backend { var b backend b.Backend = &framework.Backend{ Help: strings.TrimSpace(backendHelp), @@ -86,6 +86,7 @@ func Backend() *backend { b.crlLifetime = time.Hour * 72 b.tidyCASGuard = new(uint32) + b.storage = conf.StorageView return &b } @@ -93,6 +94,7 @@ func Backend() *backend { type backend struct { *framework.Backend + storage logical.Storage crlLifetime time.Duration revokeStorageLock sync.RWMutex tidyCASGuard *uint32 diff --git a/builtin/logical/pki/backend_test.go b/builtin/logical/pki/backend_test.go index 83729bc0df32..9be7b55dd3a4 100644 --- a/builtin/logical/pki/backend_test.go +++ b/builtin/logical/pki/backend_test.go @@ -12,7 +12,6 @@ import ( "crypto/x509/pkix" "encoding/base64" "encoding/pem" - "errors" "fmt" "math" "math/big" @@ -135,64 +134,6 @@ func TestPKI_RequireCN(t *testing.T) { } } -// Performs basic tests on CA functionality -// Uses the RSA CA key -func TestBackend_RSAKey(t *testing.T) { - initTest.Do(setCerts) - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - b, err := Factory(context.Background(), &logical.BackendConfig{ - Logger: nil, - System: &logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - }, - }) - if err != nil { - t.Fatalf("Unable to create backend: %s", err) - } - - testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, - } - - intdata := map[string]interface{}{} - reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, rsaCACert, rsaCAKey, ecCACert, intdata, reqdata)...) - - logicaltest.Test(t, testCase) -} - -// Performs basic tests on CA functionality -// Uses the EC CA key -func TestBackend_ECKey(t *testing.T) { - initTest.Do(setCerts) - defaultLeaseTTLVal := time.Hour * 24 - maxLeaseTTLVal := time.Hour * 24 * 32 - b, err := Factory(context.Background(), &logical.BackendConfig{ - Logger: nil, - System: &logical.StaticSystemView{ - DefaultLeaseTTLVal: defaultLeaseTTLVal, - MaxLeaseTTLVal: maxLeaseTTLVal, - }, - }) - if err != nil { - t.Fatalf("Unable to create backend: %s", err) - } - - testCase := logicaltest.TestCase{ - Backend: b, - Steps: []logicaltest.TestStep{}, - } - - intdata := map[string]interface{}{} - reqdata := map[string]interface{}{} - testCase.Steps = append(testCase.Steps, generateCATestingSteps(t, ecCACert, ecCAKey, rsaCACert, intdata, reqdata)...) - - logicaltest.Test(t, testCase) -} - func TestBackend_CSRValues(t *testing.T) { initTest.Do(setCerts) defaultLeaseTTLVal := time.Hour * 24 @@ -806,702 +747,6 @@ func generateCSRSteps(t *testing.T, caCert, caKey string, intdata, reqdata map[s return ret } -// Generates steps to test out CA configuration -- certificates + CRL expiry, -// and ensure that the certificates are readable after storing them -func generateCATestingSteps(t *testing.T, caCert, caKey, otherCaCert string, intdata, reqdata map[string]interface{}) []logicaltest.TestStep { - setSerialUnderTest := func(req *logical.Request) error { - req.Path = serialUnderTest - return nil - } - - ret := []logicaltest.TestStep{ - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": strings.Join([]string{caKey, caCert}, "\n"), - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/crl", - Data: map[string]interface{}{ - "expiry": "16h", - }, - }, - - // Ensure we can fetch it back via unauthenticated means, in various formats - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "cert/ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - if resp.Data["certificate"].(string) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca/pem", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - if !reflect.DeepEqual(rawBytes, []byte(caCert)) { - return fmt.Errorf("CA certificate:\n%#v\ndoes not match original:\n%#v\n", rawBytes, []byte(caCert)) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: rawBytes, - }))) - if pemBytes != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "config/crl", - Check: func(resp *logical.Response) error { - if resp.Data["expiry"].(string) != "16h" { - return fmt.Errorf("CRL lifetimes do not match (got %s)", resp.Data["expiry"].(string)) - } - return nil - }, - }, - - // Ensure that both parts of the PEM bundle are required - // Here, just the cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": caCert, - }, - ErrorOk: true, - }, - - // Here, just the key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: map[string]interface{}{ - "pem_bundle": caKey, - }, - ErrorOk: true, - }, - - // Ensure we can fetch it back via unauthenticated means, in various formats - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "cert/ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - if resp.Data["certificate"].(string) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", resp.Data["certificate"].(string), caCert) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca/pem", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - if string(rawBytes) != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", string(rawBytes), caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "ca", - Unauthenticated: true, - Check: func(resp *logical.Response) error { - rawBytes := resp.Data["http_raw_body"].([]byte) - pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: rawBytes, - }))) - if pemBytes != caCert { - return fmt.Errorf("CA certificate:\n%s\ndoes not match original:\n%s\n", pemBytes, caCert) - } - if resp.Data["http_content_type"].(string) != "application/pkix-cert" { - return fmt.Errorf("expected application/pkix-cert as content-type, but got %s", resp.Data["http_content_type"].(string)) - } - return nil - }, - }, - - // Test a bunch of generation stuff - logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: "root", - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/generate/exported", - Data: map[string]interface{}{ - "common_name": "Root Cert", - "ttl": "180h", - }, - Check: func(resp *logical.Response) error { - intdata["root"] = resp.Data["certificate"].(string) - intdata["rootkey"] = resp.Data["private_key"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["root"].(string), intdata["rootkey"].(string)}, "\n") - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/generate/exported", - Data: map[string]interface{}{ - "common_name": "intermediate.cert.com", - }, - Check: func(resp *logical.Response) error { - intdata["intermediatecsr"] = resp.Data["csr"].(string) - intdata["intermediatekey"] = resp.Data["private_key"].(string) - return nil - }, - }, - - // Re-load the root key in so we can sign it - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - delete(reqdata, "ttl") - reqdata["csr"] = intdata["intermediatecsr"].(string) - reqdata["common_name"] = "intermediate.cert.com" - reqdata["ttl"] = "10s" - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/sign-intermediate", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "csr") - delete(reqdata, "common_name") - delete(reqdata, "ttl") - intdata["intermediatecert"] = resp.Data["certificate"].(string) - reqdata["serial_number"] = resp.Data["serial_number"].(string) - reqdata["rsa_int_serial_number"] = resp.Data["serial_number"].(string) - reqdata["certificate"] = resp.Data["certificate"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n") - return nil - }, - }, - - // First load in this way to populate the private key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - return nil - }, - }, - - // Now test setting the intermediate, signed CA cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/set-signed", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "certificate") - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // We expect to find a zero revocation time - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) != 0 { - return fmt.Errorf("expected a zero revocation time") - } - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "revoke", - Data: reqdata, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 1 { - t.Fatalf("length of revoked list not 1; %d", len(revokedList)) - } - revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":") - if revokedString != reqdata["serial_number"].(string) { - t.Fatalf("got serial %s, expecting %s", revokedString, reqdata["serial_number"].(string)) - } - delete(reqdata, "serial_number") - return nil - }, - }, - - // Do it all again, with EC keys and DER format - logicaltest.TestStep{ - Operation: logical.DeleteOperation, - Path: "root", - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/generate/exported", - Data: map[string]interface{}{ - "common_name": "Root Cert", - "ttl": "180h", - "key_type": "ec", - "key_bits": 384, - "format": "der", - }, - Check: func(resp *logical.Response) error { - certBytes, _ := base64.StdEncoding.DecodeString(resp.Data["certificate"].(string)) - certPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }))) - keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string)) - keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: keyBytes, - }))) - intdata["root"] = certPem - intdata["rootkey"] = keyPem - reqdata["pem_bundle"] = strings.Join([]string{certPem, keyPem}, "\n") - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/generate/exported", - Data: map[string]interface{}{ - "format": "der", - "key_type": "ec", - "key_bits": 384, - "common_name": "intermediate.cert.com", - }, - Check: func(resp *logical.Response) error { - csrBytes, _ := base64.StdEncoding.DecodeString(resp.Data["csr"].(string)) - csrPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }))) - keyBytes, _ := base64.StdEncoding.DecodeString(resp.Data["private_key"].(string)) - keyPem := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ - Type: "EC PRIVATE KEY", - Bytes: keyBytes, - }))) - intdata["intermediatecsr"] = csrPem - intdata["intermediatekey"] = keyPem - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - delete(reqdata, "ttl") - reqdata["csr"] = intdata["intermediatecsr"].(string) - reqdata["common_name"] = "intermediate.cert.com" - reqdata["ttl"] = "10s" - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "root/sign-intermediate", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "csr") - delete(reqdata, "common_name") - delete(reqdata, "ttl") - intdata["intermediatecert"] = resp.Data["certificate"].(string) - reqdata["serial_number"] = resp.Data["serial_number"].(string) - reqdata["ec_int_serial_number"] = resp.Data["serial_number"].(string) - reqdata["certificate"] = resp.Data["certificate"].(string) - reqdata["pem_bundle"] = strings.Join([]string{intdata["intermediatekey"].(string), resp.Data["certificate"].(string)}, "\n") - return nil - }, - }, - - // First load in this way to populate the private key - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "config/ca", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "pem_bundle") - return nil - }, - }, - - // Now test setting the intermediate, signed CA cert - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "intermediate/set-signed", - Data: reqdata, - Check: func(resp *logical.Response) error { - delete(reqdata, "certificate") - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - // We expect to find a zero revocation time - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) != 0 { - return fmt.Errorf("expected a zero revocation time") - } - - return nil - }, - }, - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "revoke", - Data: reqdata, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 2 { - t.Fatalf("length of revoked list not 2; %d", len(revokedList)) - } - found := false - for _, revEntry := range revokedList { - revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":") - if revokedString == reqdata["serial_number"].(string) { - found = true - } - } - if !found { - t.Fatalf("did not find %s in CRL", reqdata["serial_number"].(string)) - } - delete(reqdata, "serial_number") - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Make sure both serial numbers we expect to find are found - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) == 0 { - return fmt.Errorf("expected a non-zero revocation time") - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - if resp.Data["revocation_time"].(int64) == 0 { - return fmt.Errorf("expected a non-zero revocation time") - } - - // Give time for the certificates to pass the safety buffer - t.Logf("Sleeping for 15 seconds to allow safety buffer time to pass before testing tidying") - time.Sleep(15 * time.Second) - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // This shouldn't do anything since the safety buffer is too long - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "3h", - "tidy_cert_store": true, - "tidy_revocation_list": true, - }, - }, - - // We still expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Both should appear in the CRL - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 2 { - t.Fatalf("length of revoked list not 2; %d", len(revokedList)) - } - foundRsa := false - foundEc := false - for _, revEntry := range revokedList { - revokedString := certutil.GetHexFormatted(revEntry.SerialNumber.Bytes(), ":") - if revokedString == reqdata["rsa_int_serial_number"].(string) { - foundRsa = true - } - if revokedString == reqdata["ec_int_serial_number"].(string) { - foundEc = true - } - } - if !foundRsa || !foundEc { - t.Fatalf("did not find an expected entry in CRL") - } - - return nil - }, - }, - - // This shouldn't do anything since the boolean values default to false - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "1s", - }, - }, - - // We still expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil && resp.Data["error"] != nil && resp.Data["error"].(string) != "" { - return fmt.Errorf("got an error: %s", resp.Data["error"].(string)) - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // This should remove the values since the safety buffer is short - logicaltest.TestStep{ - Operation: logical.UpdateOperation, - Path: "tidy", - Data: map[string]interface{}{ - "safety_buffer": "1s", - "tidy_cert_store": true, - "tidy_revocation_list": true, - }, - Check: func(resp *logical.Response) error { - if resp.IsError() { - return fmt.Errorf("got an error resp: %v", resp.Error()) - } - if len(resp.Warnings) != 1 { - return fmt.Errorf("expected a warning, resp is %#v", *resp) - } - - // Give time for the certificates to pass the safety buffer - t.Logf("Sleeping for 15 seconds to allow tidy to work") - t.Logf(time.Now().String()) - time.Sleep(45 * time.Second) - return nil - }, - }, - - // We do *not* expect to find these - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - t.Logf(time.Now().String()) - if resp != nil { - t.Logf("failed at first check, resp is %#v", *resp) - return errors.New("expected no response") - } - - serialUnderTest = "cert/" + reqdata["ec_int_serial_number"].(string) - - return nil - }, - }, - - logicaltest.TestStep{ - Operation: logical.ReadOperation, - PreFlight: setSerialUnderTest, - Check: func(resp *logical.Response) error { - if resp != nil { - t.Logf("failed at second check, resp is %#v", *resp) - return fmt.Errorf("expected no response") - } - - serialUnderTest = "cert/" + reqdata["rsa_int_serial_number"].(string) - - return nil - }, - }, - - // Both should be gone from the CRL - logicaltest.TestStep{ - Operation: logical.ReadOperation, - Path: "crl", - Data: reqdata, - Check: func(resp *logical.Response) error { - crlBytes := resp.Data["http_raw_body"].([]byte) - certList, err := x509.ParseCRL(crlBytes) - if err != nil { - t.Fatalf("err: %s", err) - } - revokedList := certList.TBSCertList.RevokedCertificates - if len(revokedList) != 0 { - t.Fatalf("length of revoked list not 0; %d", len(revokedList)) - } - - return nil - }, - }, - } - - return ret -} - // Generates steps to test out various role permutations func generateRoleSteps(t *testing.T, useCSRs bool) []logicaltest.TestStep { roleVals := roleEntry{ @@ -2158,7 +1403,7 @@ func TestBackend_PathFetchCertList(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) @@ -2285,7 +1530,7 @@ func TestBackend_SignVerbatim(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) @@ -2841,7 +2086,7 @@ func TestBackend_SignSelfIssued(t *testing.T) { storage := &logical.InmemStorage{} config.StorageView = storage - b := Backend() + b := Backend(config) err := b.Setup(context.Background(), config) if err != nil { t.Fatal(err) diff --git a/builtin/logical/pki/ca_test.go b/builtin/logical/pki/ca_test.go new file mode 100644 index 000000000000..82bac5af0c03 --- /dev/null +++ b/builtin/logical/pki/ca_test.go @@ -0,0 +1,570 @@ +package pki + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" + "math/big" + mathrand "math/rand" + "strings" + "testing" + "time" + + "github.com/go-test/deep" + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/helper/certutil" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +func TestBackend_CA_Steps(t *testing.T) { + var b *backend + + factory := func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { + be, err := Factory(ctx, conf) + if err == nil { + b = be.(*backend) + } + return be, err + } + + coreConfig := &vault.CoreConfig{ + LogicalBackends: map[string]logical.Factory{ + "pki": factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + client := cluster.Cores[0].Client + + // Set RSA/EC CA certificates + var rsaCAKey, rsaCACert, ecCAKey, ecCACert string + { + cak, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + marshaledKey, err := x509.MarshalECPrivateKey(cak) + if err != nil { + panic(err) + } + keyPEMBlock := &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: marshaledKey, + } + ecCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) + if err != nil { + panic(err) + } + subjKeyID, err := certutil.GetSubjKeyID(cak) + if err != nil { + panic(err) + } + caCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "root.localhost", + }, + SubjectKeyId: subjKeyID, + DNSNames: []string{"root.localhost"}, + KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign), + SerialNumber: big.NewInt(mathrand.Int63()), + NotBefore: time.Now().Add(-30 * time.Second), + NotAfter: time.Now().Add(262980 * time.Hour), + BasicConstraintsValid: true, + IsCA: true, + } + caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, cak.Public(), cak) + if err != nil { + panic(err) + } + caCertPEMBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + ecCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) + + rak, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(err) + } + marshaledKey = x509.MarshalPKCS1PrivateKey(rak) + keyPEMBlock = &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: marshaledKey, + } + rsaCAKey = strings.TrimSpace(string(pem.EncodeToMemory(keyPEMBlock))) + if err != nil { + panic(err) + } + subjKeyID, err = certutil.GetSubjKeyID(rak) + if err != nil { + panic(err) + } + caBytes, err = x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, rak.Public(), rak) + if err != nil { + panic(err) + } + caCertPEMBlock = &pem.Block{ + Type: "CERTIFICATE", + Bytes: caBytes, + } + rsaCACert = strings.TrimSpace(string(pem.EncodeToMemory(caCertPEMBlock))) + } + + // Setup backends + var rsaRoot, rsaInt, ecRoot, ecInt *backend + { + if err := client.Sys().Mount("rsaroot", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + rsaRoot = b + + if err := client.Sys().Mount("rsaint", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + rsaInt = b + + if err := client.Sys().Mount("ecroot", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + ecRoot = b + + if err := client.Sys().Mount("ecint", &api.MountInput{ + Type: "pki", + Config: api.MountConfigInput{ + DefaultLeaseTTL: "16h", + MaxLeaseTTL: "60h", + }, + }); err != nil { + t.Fatal(err) + } + ecInt = b + } + + t.Run("teststeps", func(t *testing.T) { + t.Run("rsa", func(t *testing.T) { + t.Parallel() + subClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + subClient.SetToken(client.Token()) + runSteps(t, rsaRoot, rsaInt, subClient, "rsaroot/", "rsaint/", rsaCACert, rsaCAKey) + }) + t.Run("ec", func(t *testing.T) { + t.Parallel() + subClient, err := client.Clone() + if err != nil { + t.Fatal(err) + } + subClient.SetToken(client.Token()) + runSteps(t, ecRoot, ecInt, subClient, "ecroot/", "ecint/", ecCACert, ecCAKey) + }) + }) +} + +func runSteps(t *testing.T, rootB, intB *backend, client *api.Client, rootName, intName, caCert, caKey string) { + // Load CA cert/key in and ensure we can fetch it back in various formats, + // unauthenticated + { + // Attempt import but only provide one the cert + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": caCert, + }) + if err == nil { + t.Fatal("expected error") + } + } + + // Same but with only the key + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": caKey, + }) + if err == nil { + t.Fatal("expected error") + } + } + + // Import CA bundle + { + _, err := client.Logical().Write(rootName+"config/ca", map[string]interface{}{ + "pem_bundle": strings.Join([]string{caKey, caCert}, "\n"), + }) + if err != nil { + t.Fatal(err) + } + } + + prevToken := client.Token() + client.SetToken("") + + // cert/ca path + { + resp, err := client.Logical().Read(rootName + "cert/ca") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if diff := deep.Equal(resp.Data["certificate"].(string), caCert); diff != nil { + t.Fatal(diff) + } + } + // ca/pem path (raw string) + { + req := &logical.Request{ + Path: "ca/pem", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if diff := deep.Equal(resp.Data["http_raw_body"].([]byte), []byte(caCert)); diff != nil { + t.Fatal(diff) + } + if resp.Data["http_content_type"].(string) != "application/pkix-cert" { + t.Fatal("wrong content type") + } + } + + // ca (raw DER bytes) + { + req := &logical.Request{ + Path: "ca", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + rawBytes := resp.Data["http_raw_body"].([]byte) + pemBytes := strings.TrimSpace(string(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: rawBytes, + }))) + if diff := deep.Equal(pemBytes, caCert); diff != nil { + t.Fatal(diff) + } + if resp.Data["http_content_type"].(string) != "application/pkix-cert" { + t.Fatal("wrong content type") + } + } + + client.SetToken(prevToken) + } + + // Configure an expiry on the CRL and verify what comes back + { + // Set CRL config + { + _, err := client.Logical().Write(rootName+"config/crl", map[string]interface{}{ + "expiry": "16h", + }) + if err != nil { + t.Fatal(err) + } + } + + // Verify it + { + resp, err := client.Logical().Read(rootName + "config/crl") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["expiry"].(string) != "16h" { + t.Fatal("expected a 16 hour expiry") + } + } + } + + // Test generating a root, an intermediate, signing it, setting signed, and + // revoking it + + // We'll need this later + var intSerialNumber string + { + // First, delete the existing CA info + { + _, err := client.Logical().Delete(rootName + "root") + if err != nil { + t.Fatal(err) + } + } + + var rootPEM, rootKey, rootPEMBundle string + // Test exported root generation + { + resp, err := client.Logical().Write(rootName+"root/generate/exported", map[string]interface{}{ + "common_name": "Root Cert", + "ttl": "180h", + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + rootPEM = resp.Data["certificate"].(string) + rootKey = resp.Data["private_key"].(string) + rootPEMBundle = strings.Join([]string{rootPEM, rootKey}, "\n") + // This is really here to keep the use checker happy + if rootPEMBundle == "" { + t.Fatal("bad root pem bundle") + } + } + + var intPEM, intCSR, intKey string + // Test exported intermediate CSR generation + { + resp, err := client.Logical().Write(intName+"intermediate/generate/exported", map[string]interface{}{ + "common_name": "intermediate.cert.com", + "ttl": "180h", + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + intCSR = resp.Data["csr"].(string) + intKey = resp.Data["private_key"].(string) + // This is really here to keep the use checker happy + if intCSR == "" || intKey == "" { + t.Fatal("int csr or key empty") + } + } + + // Test signing + { + resp, err := client.Logical().Write(rootName+"root/sign-intermediate", map[string]interface{}{ + "common_name": "intermediate.cert.com", + "ttl": "10s", + "csr": intCSR, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + intPEM = resp.Data["certificate"].(string) + intSerialNumber = resp.Data["serial_number"].(string) + } + + // Test setting signed + { + resp, err := client.Logical().Write(intName+"intermediate/set-signed", map[string]interface{}{ + "certificate": intPEM, + }) + if err != nil { + t.Fatal(err) + } + if resp != nil { + t.Fatal("expected nil response") + } + } + + // Verify we can find it via the root + { + resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["revocation_time"].(json.Number).String() != "0" { + t.Fatal("expected a zero revocation time") + } + } + + // Revoke the intermediate + { + resp, err := client.Logical().Write(rootName+"revoke", map[string]interface{}{ + "serial_number": intSerialNumber, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + } + } + + verifyRevocation := func(t *testing.T, serial string, shouldFind bool) { + // Verify it is now revoked + { + resp, err := client.Logical().Read(rootName + "cert/" + intSerialNumber) + if err != nil { + t.Fatal(err) + } + switch shouldFind { + case true: + if resp == nil { + t.Fatal("nil response") + } + if resp.Data["revocation_time"].(json.Number).String() == "0" { + t.Fatal("expected a non-zero revocation time") + } + default: + if resp != nil { + t.Fatalf("expected nil response, got %#v", *resp) + } + } + } + + // Fetch the CRL and make sure it shows up + { + req := &logical.Request{ + Path: "crl", + Operation: logical.ReadOperation, + Storage: rootB.storage, + } + resp, err := rootB.HandleRequest(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("nil response") + } + crlBytes := resp.Data["http_raw_body"].([]byte) + certList, err := x509.ParseCRL(crlBytes) + if err != nil { + t.Fatal(err) + } + switch shouldFind { + case true: + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 1 { + t.Fatalf("bad length of revoked list: %d", len(revokedList)) + } + revokedString := certutil.GetHexFormatted(revokedList[0].SerialNumber.Bytes(), ":") + if revokedString != intSerialNumber { + t.Fatalf("bad revoked serial: %s", revokedString) + } + default: + revokedList := certList.TBSCertList.RevokedCertificates + if len(revokedList) != 0 { + t.Fatalf("bad length of revoked list: %d", len(revokedList)) + } + } + } + } + + // Validate current state of revoked certificates + verifyRevocation(t, intSerialNumber, true) + + // Give time for the safety buffer to pass before tidying + time.Sleep(10 * time.Second) + + // Test tidying + { + // Run with a high safety buffer, nothing should happen + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "3h", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, true) + } + + // Run with both values set false, nothing should happen + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "1s", + "tidy_cert_store": false, + "tidy_revocation_list": false, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, true) + } + + // Run with a short safety buffer and both set to true, both should be cleared + { + resp, err := client.Logical().Write(rootName+"tidy", map[string]interface{}{ + "safety_buffer": "1s", + "tidy_cert_store": true, + "tidy_revocation_list": true, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected warnings") + } + + // Wait a few seconds as it runs in a goroutine + time.Sleep(5 * time.Second) + + // Check to make sure we still find the cert and see it on the CRL + verifyRevocation(t, intSerialNumber, false) + } + } +} diff --git a/builtin/logical/pki/path_roles_test.go b/builtin/logical/pki/path_roles_test.go index f3927781aa23..6f915eeca742 100644 --- a/builtin/logical/pki/path_roles_test.go +++ b/builtin/logical/pki/path_roles_test.go @@ -15,7 +15,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { config.StorageView = &logical.InmemStorage{} var err error - b := Backend() + b := Backend(config) err = b.Setup(context.Background(), config) if err != nil { t.Fatal(err)