diff --git a/builtin/logical/transit/backend_test.go b/builtin/logical/transit/backend_test.go index 17d44f021da3..562b6d3b9055 100644 --- a/builtin/logical/transit/backend_test.go +++ b/builtin/logical/transit/backend_test.go @@ -27,7 +27,7 @@ const ( testPlaintext = "the quick brown fox" ) -func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { +func createBackendWithStorage(t testing.TB) (*backend, logical.Storage) { config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} @@ -42,7 +42,7 @@ func createBackendWithStorage(t *testing.T) (*backend, logical.Storage) { return b, config.StorageView } -func createBackendWithSysView(t *testing.T) (*backend, logical.Storage) { +func createBackendWithSysView(t testing.TB) (*backend, logical.Storage) { sysView := logical.TestSystemView() storage := &logical.InmemStorage{} @@ -64,7 +64,7 @@ func createBackendWithSysView(t *testing.T) (*backend, logical.Storage) { return b, storage } -func createBackendWithSysViewWithStorage(t *testing.T, s logical.Storage) *backend { +func createBackendWithSysViewWithStorage(t testing.TB, s logical.Storage) *backend { sysView := logical.TestSystemView() conf := &logical.BackendConfig{ @@ -85,7 +85,7 @@ func createBackendWithSysViewWithStorage(t *testing.T, s logical.Storage) *backe return b } -func createBackendWithForceNoCacheWithSysViewWithStorage(t *testing.T, s logical.Storage) *backend { +func createBackendWithForceNoCacheWithSysViewWithStorage(t testing.TB, s logical.Storage) *backend { sysView := logical.TestSystemView() sysView.CachingDisabledVal = true diff --git a/builtin/logical/transit/path_decrypt.go b/builtin/logical/transit/path_decrypt.go index bd3d82541ac2..77d77f5176d1 100644 --- a/builtin/logical/transit/path_decrypt.go +++ b/builtin/logical/transit/path_decrypt.go @@ -9,7 +9,6 @@ import ( "github.com/hashicorp/vault/sdk/helper/errutil" "github.com/hashicorp/vault/sdk/helper/keysutil" "github.com/hashicorp/vault/sdk/logical" - "github.com/mitchellh/mapstructure" ) func (b *backend) pathDecrypt() *framework.Path { @@ -57,7 +56,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d var batchInputItems []BatchRequestItem var err error if batchInputRaw != nil { - err = mapstructure.Decode(batchInputRaw, &batchInputItems) + err = decodeBatchRequestItems(batchInputRaw, &batchInputItems) if err != nil { return nil, errwrap.Wrapf("failed to parse batch input: {{err}}", err) } diff --git a/builtin/logical/transit/path_decrypt_bench_test.go b/builtin/logical/transit/path_decrypt_bench_test.go new file mode 100644 index 000000000000..bc93fc5c4049 --- /dev/null +++ b/builtin/logical/transit/path_decrypt_bench_test.go @@ -0,0 +1,88 @@ +package transit + +import ( + "context" + "testing" + + "github.com/hashicorp/vault/sdk/logical" +) + +func BenchmarkTransit_BatchDecryption1(b *testing.B) { + BTransit_BatchDecryption(b, 1) +} + +func BenchmarkTransit_BatchDecryption10(b *testing.B) { + BTransit_BatchDecryption(b, 10) +} + +func BenchmarkTransit_BatchDecryption50(b *testing.B) { + BTransit_BatchDecryption(b, 50) +} + +func BenchmarkTransit_BatchDecryption100(b *testing.B) { + BTransit_BatchDecryption(b, 100) +} + +func BenchmarkTransit_BatchDecryption1000(b *testing.B) { + BTransit_BatchDecryption(b, 1_000) +} + +func BenchmarkTransit_BatchDecryption10000(b *testing.B) { + BTransit_BatchDecryption(b, 10_000) +} + +func BTransit_BatchDecryption(b *testing.B, bsize int) { + b.StopTimer() + + var resp *logical.Response + var err error + + backend, s := createBackendWithStorage(b) + + batchEncryptionInput := make([]interface{}, 0, bsize) + for i := 0; i < bsize; i++ { + batchEncryptionInput = append( + batchEncryptionInput, + map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, + ) + } + + batchEncryptionData := map[string]interface{}{ + "batch_input": batchEncryptionInput, + } + + batchEncryptionReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "encrypt/upserted_key", + Storage: s, + Data: batchEncryptionData, + } + resp, err = backend.HandleRequest(context.Background(), batchEncryptionReq) + if err != nil || (resp != nil && resp.IsError()) { + b.Fatalf("err:%v resp:%#v", err, resp) + } + + batchResponseItems := resp.Data["batch_results"].([]BatchResponseItem) + batchDecryptionInput := make([]interface{}, len(batchResponseItems)) + for i, item := range batchResponseItems { + batchDecryptionInput[i] = map[string]interface{}{"ciphertext": item.Ciphertext} + } + batchDecryptionData := map[string]interface{}{ + "batch_input": batchDecryptionInput, + } + + batchDecryptionReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "decrypt/upserted_key", + Storage: s, + Data: batchDecryptionData, + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + resp, err = backend.HandleRequest(context.Background(), batchDecryptionReq) + if err != nil || (resp != nil && resp.IsError()) { + b.Fatalf("err:%v resp:%#v", err, resp) + } + } +} diff --git a/builtin/logical/transit/path_encrypt.go b/builtin/logical/transit/path_encrypt.go index f085307620dc..414a3cbdc7d7 100644 --- a/builtin/logical/transit/path_encrypt.go +++ b/builtin/logical/transit/path_encrypt.go @@ -59,22 +59,22 @@ func (b *backend) pathEncrypt() *framework.Path { return &framework.Path{ Pattern: "encrypt/" + framework.GenericNameRegex("name"), Fields: map[string]*framework.FieldSchema{ - "name": &framework.FieldSchema{ + "name": { Type: framework.TypeString, Description: "Name of the policy", }, - "plaintext": &framework.FieldSchema{ + "plaintext": { Type: framework.TypeString, Description: "Base64 encoded plaintext value to be encrypted", }, - "context": &framework.FieldSchema{ + "context": { Type: framework.TypeString, Description: "Base64 encoded context for key derivation. Required if key derivation is enabled", }, - "nonce": &framework.FieldSchema{ + "nonce": { Type: framework.TypeString, Description: ` Base64 encoded nonce value. Must be provided if convergent encryption is @@ -85,7 +85,7 @@ encryption key) this nonce value is **never reused**. `, }, - "type": &framework.FieldSchema{ + "type": { Type: framework.TypeString, Default: "aes256-gcm96", Description: ` @@ -94,7 +94,7 @@ When performing an upsert operation, the type of key to create. Currently, "aes128-gcm96" (symmetric) and "aes256-gcm96" (symmetric) are the only types supported. Defaults to "aes256-gcm96".`, }, - "convergent_encryption": &framework.FieldSchema{ + "convergent_encryption": { Type: framework.TypeBool, Description: ` This parameter will only be used when a key is expected to be created. Whether @@ -107,7 +107,7 @@ you ensure that all nonces are unique for a given context. Failing to do so will severely impact the ciphertext's security.`, }, - "key_version": &framework.FieldSchema{ + "key_version": { Type: framework.TypeInt, Description: `The version of the key to use for encryption. Must be 0 (for latest) or a value greater than or equal @@ -127,6 +127,84 @@ to the min_encryption_version configured on the key.`, } } +// decodeBatchRequestItems is a fast path alternative to mapstructure.Decode to decode []BatchRequestItem. +// It aims to behave as closely possible to the original mapstructure.Decode and will return the same errors. +// https://github.com/hashicorp/vault/pull/8775/files#r437709722 +func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error { + if src == nil || dst == nil { + return nil + } + + items, ok := src.([]interface{}) + if !ok { + return fmt.Errorf("source data must be an array or slice, got %T", src) + } + + // Early return should happen before allocating the array if the batch is empty. + // However to comply with mapstructure output it's needed to allocate an empty array. + sitems := len(items) + *dst = make([]BatchRequestItem, sitems) + if sitems == 0 { + return nil + } + + // To comply with mapstructure output the same error type is needed. + var errs mapstructure.Error + + for i, iitem := range items { + item, ok := iitem.(map[string]interface{}) + if !ok { + return fmt.Errorf("[%d] expected a map, got '%T'", i, iitem) + } + + if v, has := item["context"]; has { + if casted, ok := v.(string); ok { + (*dst)[i].Context = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].context' expected type 'string', got unconvertible type '%T'", i, item["context"])) + } + } + + if v, has := item["ciphertext"]; has { + if casted, ok := v.(string); ok { + (*dst)[i].Ciphertext = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' expected type 'string', got unconvertible type '%T'", i, item["ciphertext"])) + } + } + + if v, has := item["plaintext"]; has { + if casted, ok := v.(string); ok { + (*dst)[i].Plaintext = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' expected type 'string', got unconvertible type '%T'", i, item["plaintext"])) + } + } + + if v, has := item["nonce"]; has { + if casted, ok := v.(string); ok { + (*dst)[i].Nonce = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].nonce' expected type 'string', got unconvertible type '%T'", i, item["nonce"])) + } + } + + if v, has := item["key_version"]; has { + if casted, ok := v.(int); ok { + (*dst)[i].KeyVersion = casted + } else { + errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].key_version' expected type 'int', got unconvertible type '%T'", i, item["key_version"])) + } + } + } + + if len(errs.Errors) > 0 { + return &errs + } + + return nil +} + func (b *backend) pathEncryptExistenceCheck(ctx context.Context, req *logical.Request, d *framework.FieldData) (bool, error) { name := d.Get("name").(string) p, _, err := b.lm.GetPolicy(ctx, keysutil.PolicyRequest{ @@ -146,11 +224,10 @@ func (b *backend) pathEncryptExistenceCheck(ctx context.Context, req *logical.Re func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) { name := d.Get("name").(string) var err error - batchInputRaw := d.Raw["batch_input"] var batchInputItems []BatchRequestItem if batchInputRaw != nil { - err = mapstructure.Decode(batchInputRaw, &batchInputItems) + err = decodeBatchRequestItems(batchInputRaw, &batchInputItems) if err != nil { return nil, errwrap.Wrapf("failed to parse batch input: {{err}}", err) } diff --git a/builtin/logical/transit/path_encrypt_bench_test.go b/builtin/logical/transit/path_encrypt_bench_test.go new file mode 100644 index 000000000000..e648c6e02fc3 --- /dev/null +++ b/builtin/logical/transit/path_encrypt_bench_test.go @@ -0,0 +1,68 @@ +package transit + +import ( + "context" + "testing" + + "github.com/hashicorp/vault/sdk/logical" +) + +func BenchmarkTransit_BatchEncryption1(b *testing.B) { + BTransit_BatchEncryption(b, 1) +} + +func BenchmarkTransit_BatchEncryption10(b *testing.B) { + BTransit_BatchEncryption(b, 10) +} + +func BenchmarkTransit_BatchEncryption50(b *testing.B) { + BTransit_BatchEncryption(b, 50) +} + +func BenchmarkTransit_BatchEncryption100(b *testing.B) { + BTransit_BatchEncryption(b, 100) +} + +func BenchmarkTransit_BatchEncryption1000(b *testing.B) { + BTransit_BatchEncryption(b, 1_000) +} + +func BenchmarkTransit_BatchEncryption10000(b *testing.B) { + BTransit_BatchEncryption(b, 10_000) +} + +func BTransit_BatchEncryption(b *testing.B, bsize int) { + b.StopTimer() + + var resp *logical.Response + var err error + + backend, s := createBackendWithStorage(b) + + batchEncryptionInput := make([]interface{}, 0, bsize) + for i := 0; i < bsize; i++ { + batchEncryptionInput = append( + batchEncryptionInput, + map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}, + ) + } + + batchEncryptionData := map[string]interface{}{ + "batch_input": batchEncryptionInput, + } + + batchEncryptionReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "encrypt/upserted_key", + Storage: s, + Data: batchEncryptionData, + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + resp, err = backend.HandleRequest(context.Background(), batchEncryptionReq) + if err != nil || (resp != nil && resp.IsError()) { + b.Fatalf("err:%v resp:%#v", err, resp) + } + } +} diff --git a/builtin/logical/transit/path_encrypt_test.go b/builtin/logical/transit/path_encrypt_test.go index 283d06fad016..0b1a65846a59 100644 --- a/builtin/logical/transit/path_encrypt_test.go +++ b/builtin/logical/transit/path_encrypt_test.go @@ -2,6 +2,7 @@ package transit import ( "context" + "reflect" "testing" "github.com/hashicorp/vault/sdk/logical" @@ -573,3 +574,129 @@ func TestTransit_BatchEncryptionCase12(t *testing.T) { t.Fatalf("expected an error") } } + +// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem. +func TestTransit_decodeBatchRequestItems(t *testing.T) { + tests := []struct { + name string + src interface{} + dest []BatchRequestItem + }{ + // basic edge cases of nil values + {name: "nil-nil", src: nil, dest: nil}, + {name: "nil-empty", src: nil, dest: []BatchRequestItem{}}, + {name: "empty-nil", src: []interface{}{}, dest: nil}, + { + name: "src-nil", + src: []interface{}{map[string]interface{}{}}, + dest: nil, + }, + // empty src & dest + { + name: "src-dest", + src: []interface{}{map[string]interface{}{}}, + dest: []BatchRequestItem{}, + }, + // empty src but with already populated dest, mapstructure discard pre-populated data. + { + name: "src-dest_pre_filled", + src: []interface{}{map[string]interface{}{}}, + dest: []BatchRequestItem{{}}, + }, + // two test per properties to test valid and invalid input + { + name: "src_plaintext-dest", + src: []interface{}{map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_plaintext_invalid-dest", + src: []interface{}{map[string]interface{}{"plaintext": 666}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_ciphertext-dest", + src: []interface{}{map[string]interface{}{"ciphertext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_ciphertext_invalid-dest", + src: []interface{}{map[string]interface{}{"ciphertext": 666}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_key_version-dest", + src: []interface{}{map[string]interface{}{"key_version": 1}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_key_version_invalid-dest", + src: []interface{}{map[string]interface{}{"key_version": "666"}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_nonce-dest", + src: []interface{}{map[string]interface{}{"nonce": "dGVzdGNvbnRleHQ="}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_nonce_invalid-dest", + src: []interface{}{map[string]interface{}{"nonce": 666}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_context-dest", + src: []interface{}{map[string]interface{}{"context": "dGVzdGNvbnRleHQ="}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_context_invalid-dest", + src: []interface{}{map[string]interface{}{"context": 666}}, + dest: []BatchRequestItem{}, + }, + { + name: "src_multi_order-dest", + src: []interface{}{ + map[string]interface{}{"context": "1"}, + map[string]interface{}{"context": "2"}, + map[string]interface{}{"context": "3"}, + }, + dest: []BatchRequestItem{}, + }, + { + name: "src_multi_with_invalid-dest", + src: []interface{}{ + map[string]interface{}{"context": "1"}, + map[string]interface{}{"context": "2", "key_version": "666"}, + map[string]interface{}{"context": "3"}, + }, + dest: []BatchRequestItem{}, + }, + { + name: "src_multi_with_multi_invalid-dest", + src: []interface{}{ + map[string]interface{}{"context": "1"}, + map[string]interface{}{"context": "2", "key_version": "666"}, + map[string]interface{}{"context": "3", "key_version": "1337"}, + }, + dest: []BatchRequestItem{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expectedDest := append(tt.dest[:0:0], tt.dest...) // copy of the dest state + expectedErr := mapstructure.Decode(tt.src, &expectedDest) + + gotErr := decodeBatchRequestItems(tt.src, &tt.dest) + gotDest := tt.dest + + if !reflect.DeepEqual(expectedErr, gotErr) { + t.Errorf("decodeBatchRequestItems unexpected error value, want: '%v', got: '%v'", expectedErr, gotErr) + } + + if !reflect.DeepEqual(expectedDest, gotDest) { + t.Errorf("decodeBatchRequestItems unexpected dest value, want: '%v', got: '%v'", expectedDest, gotDest) + } + }) + } +}