From 6c3f81236984b8e16fa5790f52d3005eb89c24e1 Mon Sep 17 00:00:00 2001 From: Trock <35254251+GGXXLL@users.noreply.github.com> Date: Tue, 27 Apr 2021 10:08:06 +0800 Subject: [PATCH] feat(unierr): allow nil errors (close #125) (#126) --- unierr/error.go | 9 ++++++ unierr/error_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/unierr/error.go b/unierr/error.go index 41b246e7..dbf540c0 100644 --- a/unierr/error.go +++ b/unierr/error.go @@ -44,6 +44,12 @@ func Newf(code codes.Code, format string, args ...interface{}) *Error { // Wrap annotates an error with a codes.Code func Wrap(err error, code codes.Code) *Error { + if err == nil { + return &Error{ + msg: code.String(), + code: code, + } + } err = errors.WithStack(err) return &Error{ err: err, @@ -181,6 +187,9 @@ type stackTracer interface { // StackTrace implements the interface of errors.Wrap() func (e *Error) StackTrace() errors.StackTrace { + if e.err == nil { + return nil + } if err, ok := e.err.(stackTracer); ok { return err.StackTrace() } diff --git a/unierr/error_test.go b/unierr/error_test.go index 5db95229..21a4c66e 100644 --- a/unierr/error_test.go +++ b/unierr/error_test.go @@ -2,11 +2,12 @@ package unierr import ( "encoding/json" - "errors" - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" "strings" "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" ) func TestServerError_UnmarshalJSON(t *testing.T) { @@ -61,3 +62,62 @@ func TestServerError_CustomPrinter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []byte(`{"code":10,"message":"FOO"}`), bytes) } + +func TestWrap(t *testing.T) { + type args struct { + err error + code codes.Code + } + tests := []struct { + name string + args args + want string + }{ + {"err_nil", args{nil, codes.Aborted}, codes.Aborted.String()}, + {"err_foo", args{errors.New("foo"), codes.Aborted}, "foo"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testError := Wrap(tt.args.err, tt.args.code) + assert.Equal(t, tt.want, testError.Error()) + byts, err := json.Marshal(testError) + assert.NoError(t, err) + var result *Error + err = json.Unmarshal(byts, &result) + assert.NoError(t, err) + assert.Equal(t, testError.code, result.code) + assert.Equal(t, testError.msg, result.msg) + assert.True(t, IsAbortedErr(result)) + + status := testError.GRPCStatus() + assert.Equal(t, codes.Aborted, status.Code()) + assert.Equal(t, testError.Error(), status.Message()) + + }) + } +} + +func TestError_StackTrace(t *testing.T) { + type args struct { + err error + code codes.Code + } + tests := []struct { + name string + args args + want int + }{ + {"err_nil", args{nil, codes.Aborted}, 0}, + {"err_foo", args{errors.New("foo"), codes.Aborted}, 3}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &Error{ + err: tt.args.err, + code: tt.args.code, + } + s := e.StackTrace() + assert.Equal(t, tt.want, len(s)) + }) + } +}