Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rest-json: updates rest-json error code retriever util #837

Merged
merged 2 commits into from
Oct 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 21 additions & 41 deletions aws/protocol/restjson/decoder_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,48 @@ package restjson

import (
"encoding/json"
"fmt"
"io"
"strings"

"github.com/awslabs/smithy-go"
smithyjson "github.com/awslabs/smithy-go/json"
)

// GetErrorInfo util looks for code, __type, and message members in the
// json body. These members are optionally available, and the function
// returns the value of member if it is available. This function is useful to
// identify the error code, msg in a REST JSON error response.
func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) {
startToken, err := decoder.Token()
if err == io.EOF {
return "", "", nil
var errInfo struct {
Code string
Type string `json:"__type"`
Message string
}
if err != nil {
return "", "", err
}

if t, ok := startToken.(json.Delim); !ok || t.String() != "{" {
return "", "", fmt.Errorf("expected start token to be {")
}

for decoder.More() {
var target *string
t, err := decoder.Token()
if err != nil {
return "", "", err
}

switch st := t.(string); {
case strings.EqualFold(st, "code"):
fallthrough
case strings.EqualFold(st, "__type"):
target = &errorType
case strings.EqualFold(st, "message"):
target = &message
default:
smithyjson.DiscardUnknownField(decoder)
continue
}

v, err := decoder.Token()
if err != nil {
return errorType, message, err
err = decoder.Decode(&errInfo)
if err != nil {
if err == io.EOF {
return errorType, message, nil
}
*target = v.(string)
return errorType, message, err
}

endToken, err := decoder.Token()
if err != nil {
return "", "", err
// assign error type
if len(errInfo.Code) != 0 {
errorType = errInfo.Code
} else if len(errInfo.Type) != 0 {
errorType = errInfo.Type
}

if t, ok := endToken.(json.Delim); !ok || t.String() != "}" {
return "", "", fmt.Errorf("expected end token to be }")
// assign error message
if len(errInfo.Message) != 0 {
message = errInfo.Message
}

// sanitize error
errorType = SanitizeErrorCode(errorType)
if len(errorType) != 0 {
errorType = SanitizeErrorCode(errorType)
}

return errorType, message, nil
}

Expand Down
83 changes: 83 additions & 0 deletions aws/protocol/restjson/decoder_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package restjson

import (
"bytes"
"encoding/json"
"io"
"strings"
"testing"
)

func TestGetErrorInfo(t *testing.T) {
cases := map[string]struct {
errorResponse []byte
expectedErrorType string
expectedErrorMsg string
expectedDeserializationError string
}{
"error with code": {
errorResponse: []byte(`{"code": "errorCode", "message": "message for errorCode"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "message for errorCode",
},
"error with type": {
errorResponse: []byte(`{"__type": "errorCode", "message": "message for errorCode"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "message for errorCode",
},

"error with only message": {
errorResponse: []byte(`{"message": "message for errorCode"}`),
expectedErrorMsg: "message for errorCode",
},

"error with only code": {
errorResponse: []byte(`{"code": "errorCode"}`),
expectedErrorType: "errorCode",
},

"empty": {
errorResponse: []byte(``),
},

"unknownField": {
errorResponse: []byte(`{"xyz":"abc", "code": "errorCode"}`),
expectedErrorType: "errorCode",
},

"unexpectedEOF": {
errorResponse: []byte(`{"xyz":"abc"`),
expectedDeserializationError: io.ErrUnexpectedEOF.Error(),
},

"caseless compare": {
errorResponse: []byte(`{"Code": "errorCode", "Message": "errorMessage", "xyz": "abc"}`),
expectedErrorType: "errorCode",
expectedErrorMsg: "errorMessage",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
decoder := json.NewDecoder(bytes.NewReader(c.errorResponse))
actualType, actualMsg, err := GetErrorInfo(decoder)
if err != nil {
if len(c.expectedDeserializationError) == 0 {
t.Fatalf("expected no error, got %v", err.Error())
}

if e, a := c.expectedDeserializationError, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expected error to be %v, got %v", e, a)
}
}

if e, a := c.expectedErrorType, actualType; !strings.EqualFold(e, a) {
t.Fatalf("expected error type to be %v, got %v", e, a)
}

if e, a := c.expectedErrorMsg, actualMsg; !strings.EqualFold(e, a) {
t.Fatalf("expected error message to be %v, got %v", e, a)
}
})
}
}