diff --git a/errors.go b/errors.go index aaebeb9a..e12af673 100644 --- a/errors.go +++ b/errors.go @@ -4,9 +4,8 @@ package awsbase import ( - "errors" - "github.com/hashicorp/aws-sdk-go-base/v2/internal/config" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/errs" ) // CannotAssumeRoleError occurs when AssumeRole cannot complete. @@ -14,8 +13,7 @@ type CannotAssumeRoleError = config.CannotAssumeRoleError // IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type. func IsCannotAssumeRoleError(err error) bool { - var e CannotAssumeRoleError - return errors.As(err, &e) + return errs.IsA[CannotAssumeRoleError](err) } // NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results. @@ -23,6 +21,5 @@ type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError // IsNoValidCredentialSourcesError returns true if the error contains the NoValidCredentialSourcesError type. func IsNoValidCredentialSourcesError(err error) bool { - var e NoValidCredentialSourcesError - return errors.As(err, &e) + return errs.IsA[NoValidCredentialSourcesError](err) } diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 00000000..c36b315c --- /dev/null +++ b/errors_test.go @@ -0,0 +1,85 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package awsbase + +import ( + "fmt" + "testing" +) + +func TestIsCannotAssumeRoleError(t *testing.T) { + testCases := []struct { + Name string + Err error + Expected bool + }{ + { + Name: "nil error", + }, + { + Name: "Top-level NoValidCredentialSourcesError", + Err: NoValidCredentialSourcesError{}, + }, + { + Name: "Top-level CannotAssumeRoleError", + Err: CannotAssumeRoleError{}, + Expected: true, + }, + { + Name: "Nested CannotAssumeRoleError", + Err: fmt.Errorf("test: %w", CannotAssumeRoleError{}), + Expected: true, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.Name, func(t *testing.T) { + got := IsCannotAssumeRoleError(testCase.Err) + + if got != testCase.Expected { + t.Errorf("got %t, expected %t", got, testCase.Expected) + } + }) + } +} + +func TestIsNoValidCredentialSourcesError(t *testing.T) { + testCases := []struct { + Name string + Err error + Expected bool + }{ + { + Name: "nil error", + }, + { + Name: "Top-level CannotAssumeRoleError", + Err: CannotAssumeRoleError{}, + }, + { + Name: "Top-level NoValidCredentialSourcesError", + Err: NoValidCredentialSourcesError{}, + Expected: true, + }, + { + Name: "Nested NoValidCredentialSourcesError", + Err: fmt.Errorf("test: %w", NoValidCredentialSourcesError{}), + Expected: true, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.Name, func(t *testing.T) { + got := IsNoValidCredentialSourcesError(testCase.Err) + + if got != testCase.Expected { + t.Errorf("got %t, expected %t", got, testCase.Expected) + } + }) + } +} diff --git a/internal/errs/errs.go b/internal/errs/errs.go new file mode 100644 index 00000000..ab6017d9 --- /dev/null +++ b/internal/errs/errs.go @@ -0,0 +1,21 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package errs + +import ( + "errors" +) + +// IsA indicates whether an error matches an error type. +func IsA[T error](err error) bool { + _, ok := As[T](err) + return ok +} + +// As is equivalent to errors.As(), but returns the value in-line. +func As[T error](err error) (T, bool) { + var as T + ok := errors.As(err, &as) + return as, ok +} diff --git a/tfawserr/awserr.go b/tfawserr/awserr.go new file mode 100644 index 00000000..7c4b4721 --- /dev/null +++ b/tfawserr/awserr.go @@ -0,0 +1,23 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tfawserr + +import ( + smithy "github.com/aws/smithy-go" + "github.com/hashicorp/aws-sdk-go-base/v2/internal/errs" +) + +// ErrCodeEquals returns true if the error matches all these conditions: +// - err is of type smithy.APIError +// - Error.Code() equals one of the passed codes +func ErrCodeEquals(err error, codes ...string) bool { + if apiErr, ok := errs.As[smithy.APIError](err); ok { + for _, code := range codes { + if apiErr.ErrorCode() == code { + return true + } + } + } + return false +} diff --git a/tfawserr/awserr_test.go b/tfawserr/awserr_test.go new file mode 100644 index 00000000..b5f2520d --- /dev/null +++ b/tfawserr/awserr_test.go @@ -0,0 +1,93 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tfawserr + +import ( + "fmt" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + smithy "github.com/aws/smithy-go" +) + +func TestErrCodeEquals(t *testing.T) { + testCases := map[string]struct { + Err error + Codes []string + Expected bool + }{ + "nil error": { + Err: nil, + Expected: false, + }, + "other error": { + Err: fmt.Errorf("other error"), + Expected: false, + }, + "Top-level smithy.GenericAPIError matching first code": { + Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}, + Codes: []string{"TestCode"}, + Expected: true, + }, + "Top-level smithy.GenericAPIError matching last code": { + Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}, + Codes: []string{"NotMatching", "TestCode"}, + Expected: true, + }, + "Top-level smithy.GenericAPIError no code": { + Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}, + }, + "Top-level smithy.GenericAPIError non-matching codes": { + Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}, + Codes: []string{"NotMatching", "AlsoNotMatching"}, + }, + "Wrapped smithy.GenericAPIError matching first code": { + Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}), + Codes: []string{"TestCode"}, + Expected: true, + }, + "Wrapped smithy.GenericAPIError matching last code": { + Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}), + Codes: []string{"NotMatching", "TestCode"}, + Expected: true, + }, + "Wrapped smithy.GenericAPIError non-matching codes": { + Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}), + Codes: []string{"NotMatching", "AlsoNotMatching"}, + }, + "Top-level sts ExpiredTokenException matching first code": { + Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}, + Codes: []string{"TestCode"}, + Expected: true, + }, + "Top-level sts ExpiredTokenException matching last code": { + Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}, + Codes: []string{"NotMatching", "TestCode"}, + Expected: true, + }, + "Wrapped sts ExpiredTokenException matching first code": { + Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}), + Codes: []string{"TestCode"}, + Expected: true, + }, + "Wrapped sts ExpiredTokenException matching last code": { + Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}), + Codes: []string{"NotMatching", "TestCode"}, + Expected: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + + t.Run(name, func(t *testing.T) { + got := ErrCodeEquals(testCase.Err, testCase.Codes...) + + if got != testCase.Expected { + t.Errorf("got %t, expected %t", got, testCase.Expected) + } + }) + } +}