From aaeaa5da16210f00ab06e287d69618cfda14f69e Mon Sep 17 00:00:00 2001 From: Tareq Sharafy Date: Sat, 14 Jan 2023 17:29:38 +0200 Subject: [PATCH] compare unwrapped errors using DeepEqual (#617) --- docs/index.md | 4 +++- matchers/match_error_matcher.go | 12 +++++++++++- matchers/match_error_matcher_test.go | 15 +++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/index.md b/docs/index.md index 1a6b5b2ca..c28b3b338 100644 --- a/docs/index.md +++ b/docs/index.md @@ -794,7 +794,9 @@ succeeds if `ACTUAL` is a non-nil `error` that matches `EXPECTED`. `EXPECTED` mu - A string, in which case `ACTUAL.Error()` will be compared against `EXPECTED`. - A matcher, in which case `ACTUAL.Error()` is tested against the matcher. -- An error, in which case `ACTUAL` and `EXPECTED` are compared via `reflect.DeepEqual()`. If they are not deeply equal, they are tested by `errors.Is(ACTUAL, EXPECTED)`. (The latter allows to test whether `ACTUAL` wraps an `EXPECTED` error.) +- An error, in which case anyo of the following is satisfied: + - `errors.Is(ACTUAL, EXPECTED)` returns `true` + - `ACTUAL` or any of the errors it wraps (directly or indirectly) equals `EXPECTED` in terms of `reflect.DeepEqual()`. Any other type for `EXPECTED` is an error. diff --git a/matchers/match_error_matcher.go b/matchers/match_error_matcher.go index c8993a86d..827475ea5 100644 --- a/matchers/match_error_matcher.go +++ b/matchers/match_error_matcher.go @@ -25,7 +25,17 @@ func (matcher *MatchErrorMatcher) Match(actual interface{}) (success bool, err e expected := matcher.Expected if isError(expected) { - return reflect.DeepEqual(actualErr, expected) || errors.Is(actualErr, expected.(error)), nil + // first try the built-in errors.Is + if errors.Is(actualErr, expected.(error)) { + return true, nil + } + // if not, try DeepEqual along the error chain + for unwrapped := actualErr; unwrapped != nil; unwrapped = errors.Unwrap(unwrapped) { + if reflect.DeepEqual(unwrapped, expected) { + return true, nil + } + } + return false, nil } if isString(expected) { diff --git a/matchers/match_error_matcher_test.go b/matchers/match_error_matcher_test.go index a65b8c00e..99ea9acef 100644 --- a/matchers/match_error_matcher_test.go +++ b/matchers/match_error_matcher_test.go @@ -16,6 +16,14 @@ func (c CustomError) Error() string { return "an error" } +type ComplexError struct { + Key string +} + +func (t *ComplexError) Error() string { + return fmt.Sprintf("err: %s", t.Key) +} + var _ = Describe("MatchErrorMatcher", func() { Context("When asserting against an error", func() { When("passed an error", func() { @@ -37,6 +45,12 @@ var _ = Describe("MatchErrorMatcher", func() { Expect(outerErr).Should(MatchError(innerErr)) }) + + It("uses deep equality with unwrapped errors", func() { + innerErr := &ComplexError{Key: "abc"} + outerErr := fmt.Errorf("outer error wrapping: %w", &ComplexError{Key: "abc"}) + Expect(outerErr).To(MatchError(innerErr)) + }) }) When("actual an expected are both pointers to an error", func() { @@ -130,6 +144,7 @@ var _ = Describe("MatchErrorMatcher", func() { }) Expect(failuresMessages[0]).To(ContainSubstring("{s: \"foo\"}\nnot to match error\n : foo")) }) + }) type mockErr string