Skip to content

Commit

Permalink
Merge pull request #524 from hashicorp/f-aws-sdk-go-v2-errs
Browse files Browse the repository at this point in the history
Add `awsbase.ErrCodeEquals`, AWS SDK for Go v2 variant of helper in `v2/awsv1shim/tfawserr`
  • Loading branch information
gdavison committed Jun 22, 2023
2 parents 82b6e9e + f31cf73 commit 4a36d3d
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 6 deletions.
9 changes: 3 additions & 6 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,22 @@
package awsbase

import (
"errors"

"github.com/hashicorp/aws-sdk-go-base/v2/internal/config"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// CannotAssumeRoleError occurs when AssumeRole cannot complete.
type CannotAssumeRoleError = config.CannotAssumeRoleError

// IsCannotAssumeRoleError returns true if the error contains the CannotAssumeRoleError type.
func IsCannotAssumeRoleError(err error) bool {
var e CannotAssumeRoleError
return errors.As(err, &e)
return errs.IsA[CannotAssumeRoleError](err)
}

// NoValidCredentialSourcesError occurs when all credential lookup methods have been exhausted without results.
type NoValidCredentialSourcesError = config.NoValidCredentialSourcesError

// IsNoValidCredentialSourcesError returns true if the error contains the NoValidCredentialSourcesError type.
func IsNoValidCredentialSourcesError(err error) bool {
var e NoValidCredentialSourcesError
return errors.As(err, &e)
return errs.IsA[NoValidCredentialSourcesError](err)
}
85 changes: 85 additions & 0 deletions errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package awsbase

import (
"fmt"
"testing"
)

func TestIsCannotAssumeRoleError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
Expected: true,
},
{
Name: "Nested CannotAssumeRoleError",
Err: fmt.Errorf("test: %w", CannotAssumeRoleError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsCannotAssumeRoleError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}

func TestIsNoValidCredentialSourcesError(t *testing.T) {
testCases := []struct {
Name string
Err error
Expected bool
}{
{
Name: "nil error",
},
{
Name: "Top-level CannotAssumeRoleError",
Err: CannotAssumeRoleError{},
},
{
Name: "Top-level NoValidCredentialSourcesError",
Err: NoValidCredentialSourcesError{},
Expected: true,
},
{
Name: "Nested NoValidCredentialSourcesError",
Err: fmt.Errorf("test: %w", NoValidCredentialSourcesError{}),
Expected: true,
},
}

for _, testCase := range testCases {
testCase := testCase

t.Run(testCase.Name, func(t *testing.T) {
got := IsNoValidCredentialSourcesError(testCase.Err)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}
21 changes: 21 additions & 0 deletions internal/errs/errs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package errs

import (
"errors"
)

// IsA indicates whether an error matches an error type.
func IsA[T error](err error) bool {
_, ok := As[T](err)
return ok
}

// As is equivalent to errors.As(), but returns the value in-line.
func As[T error](err error) (T, bool) {
var as T
ok := errors.As(err, &as)
return as, ok
}
23 changes: 23 additions & 0 deletions tfawserr/awserr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
smithy "github.com/aws/smithy-go"
"github.com/hashicorp/aws-sdk-go-base/v2/internal/errs"
)

// ErrCodeEquals returns true if the error matches all these conditions:
// - err is of type smithy.APIError
// - Error.Code() equals one of the passed codes
func ErrCodeEquals(err error, codes ...string) bool {
if apiErr, ok := errs.As[smithy.APIError](err); ok {
for _, code := range codes {
if apiErr.ErrorCode() == code {
return true
}
}
}
return false
}
93 changes: 93 additions & 0 deletions tfawserr/awserr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package tfawserr

import (
"fmt"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts/types"
smithy "github.com/aws/smithy-go"
)

func TestErrCodeEquals(t *testing.T) {
testCases := map[string]struct {
Err error
Codes []string
Expected bool
}{
"nil error": {
Err: nil,
Expected: false,
},
"other error": {
Err: fmt.Errorf("other error"),
Expected: false,
},
"Top-level smithy.GenericAPIError matching first code": {
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"TestCode"},
Expected: true,
},
"Top-level smithy.GenericAPIError matching last code": {
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
"Top-level smithy.GenericAPIError no code": {
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
},
"Top-level smithy.GenericAPIError non-matching codes": {
Err: &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"},
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
"Wrapped smithy.GenericAPIError matching first code": {
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"TestCode"},
Expected: true,
},
"Wrapped smithy.GenericAPIError matching last code": {
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
"Wrapped smithy.GenericAPIError non-matching codes": {
Err: fmt.Errorf("test: %w", &smithy.GenericAPIError{Code: "TestCode", Message: "TestMessage"}),
Codes: []string{"NotMatching", "AlsoNotMatching"},
},
"Top-level sts ExpiredTokenException matching first code": {
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"TestCode"},
Expected: true,
},
"Top-level sts ExpiredTokenException matching last code": {
Err: &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")},
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
"Wrapped sts ExpiredTokenException matching first code": {
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"TestCode"},
Expected: true,
},
"Wrapped sts ExpiredTokenException matching last code": {
Err: fmt.Errorf("test: %w", &types.ExpiredTokenException{ErrorCodeOverride: aws.String("TestCode"), Message: aws.String("TestMessage")}),
Codes: []string{"NotMatching", "TestCode"},
Expected: true,
},
}

for name, testCase := range testCases {
testCase := testCase

t.Run(name, func(t *testing.T) {
got := ErrCodeEquals(testCase.Err, testCase.Codes...)

if got != testCase.Expected {
t.Errorf("got %t, expected %t", got, testCase.Expected)
}
})
}
}

0 comments on commit 4a36d3d

Please sign in to comment.