Skip to content

Commit

Permalink
Merge pull request #837 from aws/fix-unexpected-panic
Browse files Browse the repository at this point in the history
`rest-json`: updates rest-json error code retriever util
  • Loading branch information
skotambkar committed Oct 21, 2020
2 parents 5550687 + 3abce74 commit e113848
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 41 deletions.
62 changes: 21 additions & 41 deletions aws/protocol/restjson/decoder_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,48 @@ package restjson

import (
"encoding/json"
"fmt"
"io"
"strings"

"github.com/awslabs/smithy-go"
smithyjson "github.com/awslabs/smithy-go/json"
)

// GetErrorInfo util looks for code, __type, and message members in the
// json body. These members are optionally available, and the function
// returns the value of member if it is available. This function is useful to
// identify the error code, msg in a REST JSON error response.
func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) {
startToken, err := decoder.Token()
if err == io.EOF {
return "", "", nil
var errInfo struct {
Code string
Type string `json:"__type"`
Message string
}
if err != nil {
return "", "", err
}

if t, ok := startToken.(json.Delim); !ok || t.String() != "{" {
return "", "", fmt.Errorf("expected start token to be {")
}

for decoder.More() {
var target *string
t, err := decoder.Token()
if err != nil {
return "", "", err
}

switch st := t.(string); {
case strings.EqualFold(st, "code"):
fallthrough
case strings.EqualFold(st, "__type"):
target = &errorType
case strings.EqualFold(st, "message"):
target = &message
default:
smithyjson.DiscardUnknownField(decoder)
continue
}

v, err := decoder.Token()
if err != nil {
return errorType, message, err
err = decoder.Decode(&errInfo)
if err != nil {
if err == io.EOF {
return errorType, message, nil
}
*target = v.(string)
return errorType, message, err
}

endToken, err := decoder.Token()
if err != nil {
return "", "", err
// assign error type
if len(errInfo.Code) != 0 {
errorType = errInfo.Code
} else if len(errInfo.Type) != 0 {
errorType = errInfo.Type
}

if t, ok := endToken.(json.Delim); !ok || t.String() != "}" {
return "", "", fmt.Errorf("expected end token to be }")
// assign error message
if len(errInfo.Message) != 0 {
message = errInfo.Message
}

// sanitize error
errorType = SanitizeErrorCode(errorType)
if len(errorType) != 0 {
errorType = SanitizeErrorCode(errorType)
}

return errorType, message, nil
}

Expand Down
83 changes: 83 additions & 0 deletions aws/protocol/restjson/decoder_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package restjson

import (
"bytes"
"encoding/json"
"io"
"strings"
"testing"
)

func TestGetErrorInfo(t *testing.T) {
cases := map[string]struct {
errorResponse []byte
expectedErrorType string
expectedErrorMsg string
expectedDeserializationError string
}{
"error with code": {
errorResponse: []byte(`{"code": "errorCode", "message": "message for errorCode"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "message for errorCode",
},
"error with type": {
errorResponse: []byte(`{"__type": "errorCode", "message": "message for errorCode"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "message for errorCode",
},

"error with only message": {
errorResponse: []byte(`{"message": "message for errorCode"}`),
expectedErrorMsg: "message for errorCode",
},

"error with only code": {
errorResponse: []byte(`{"code": "errorCode"}`),
expectedErrorType: "errorCode",
},

"empty": {
errorResponse: []byte(``),
},

"unknownField": {
errorResponse: []byte(`{"xyz":"abc", "code": "errorCode"}`),
expectedErrorType: "errorCode",
},

"unexpectedEOF": {
errorResponse: []byte(`{"xyz":"abc"`),
expectedDeserializationError: io.ErrUnexpectedEOF.Error(),
},

"caseless compare": {
errorResponse: []byte(`{"Code": "errorCode", "Message": "errorMessage", "xyz": "abc"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "errorMessage",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
decoder := json.NewDecoder(bytes.NewReader(c.errorResponse))
actualType, actualMsg, err := GetErrorInfo(decoder)
if err != nil {
if len(c.expectedDeserializationError) == 0 {
t.Fatalf("expected no error, got %v", err.Error())
}

if e, a := c.expectedDeserializationError, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to be %v, got %v", e, a)
}
}

if e, a := c.expectedErrorType, actualType; !strings.EqualFold(e, a) {
t.Fatalf("expected error type to be %v, got %v", e, a)
}

if e, a := c.expectedErrorMsg, actualMsg; !strings.EqualFold(e, a) {
t.Fatalf("expected error message to be %v, got %v", e, a)
}
})
}
}

0 comments on commit e113848

Please sign in to comment.