From 6a2e51ee479d0296293de6a60c7404b98b950a93 Mon Sep 17 00:00:00 2001 From: Onsi Fakhouri Date: Wed, 26 Oct 2022 13:40:12 -0600 Subject: [PATCH] First pass at gcustom: a convenience package for making custom matchers. Documentation to follow. --- gcustom/gcustom_suite_test.go | 13 +++ gcustom/make_matcher.go | 151 ++++++++++++++++++++++++++ gcustom/make_matcher_test.go | 195 ++++++++++++++++++++++++++++++++++ 3 files changed, 359 insertions(+) create mode 100644 gcustom/gcustom_suite_test.go create mode 100644 gcustom/make_matcher.go create mode 100644 gcustom/make_matcher_test.go diff --git a/gcustom/gcustom_suite_test.go b/gcustom/gcustom_suite_test.go new file mode 100644 index 000000000..45c577ccf --- /dev/null +++ b/gcustom/gcustom_suite_test.go @@ -0,0 +1,13 @@ +package gcustom_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestGcustom(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Gcustom Suite") +} diff --git a/gcustom/make_matcher.go b/gcustom/make_matcher.go new file mode 100644 index 000000000..8395da4fd --- /dev/null +++ b/gcustom/make_matcher.go @@ -0,0 +1,151 @@ +package gcustom + +import ( + "fmt" + "reflect" + "strings" + "text/template" + + "github.com/onsi/gomega/format" +) + +var interfaceType = reflect.TypeOf((*interface{})(nil)).Elem() +var errInterface = reflect.TypeOf((*error)(nil)).Elem() + +var defaultTemplate = template.Must(ParseTemplate("{{if .Failure}}Custom matcher failed for:{{else}}Custom matcher succeeded (but was expected to fail) for:{{end}}\n{{.FormattedActual}}")) + +func formatObject(object any, indent ...uint) string { + indentation := uint(0) + if len(indent) > 0 { + indentation = indent[0] + } + return format.Object(object, indentation) +} + +func ParseTemplate(templ string) (*template.Template, error) { + return template.New("template").Funcs(template.FuncMap{ + "format": formatObject, + }).Parse(templ) +} + +func MakeMatcher(matchFunc any, args ...any) CustomGomegaMatcher { + t := reflect.TypeOf(matchFunc) + if !(t.Kind() == reflect.Func && t.NumIn() == 1 && t.NumOut() == 2 && t.Out(0).Kind() == reflect.Bool && t.Out(1).Implements(errInterface)) { + panic("MakeMatcher must be passed a function that takes one argument and returns (bool, error)") + } + var finalMatchFunc func(actual any) (bool, error) + if t.In(0) == interfaceType { + finalMatchFunc = matchFunc.(func(actual any) (bool, error)) + } else { + matchFuncValue := reflect.ValueOf(matchFunc) + finalMatchFunc = reflect.MakeFunc(reflect.TypeOf(finalMatchFunc), + func(args []reflect.Value) []reflect.Value { + actual := args[0].Interface() + if reflect.TypeOf(actual).AssignableTo(t.In(0)) { + return matchFuncValue.Call([]reflect.Value{reflect.ValueOf(actual)}) + } else { + return []reflect.Value{ + reflect.ValueOf(false), + reflect.ValueOf(fmt.Errorf("Matcher expected actual of type <%s>. Got:\n%s", t.In(0), format.Object(actual, 1))), + } + } + }).Interface().(func(actual any) (bool, error)) + } + + matcher := CustomGomegaMatcher{ + matchFunc: finalMatchFunc, + templateMessage: defaultTemplate, + } + + for _, arg := range args { + switch v := arg.(type) { + case string: + matcher = matcher.WithMessage(v) + case *template.Template: + matcher = matcher.WithPrecompiledTemplate(v) + } + } + + return matcher +} + +type CustomGomegaMatcher struct { + matchFunc func(actual any) (bool, error) + templateMessage *template.Template + templateData any + customFailureMessage func(actual any) string + customNegatedFailureMessage func(actual any) string +} + +func (c CustomGomegaMatcher) WithMessage(message string) CustomGomegaMatcher { + return c.WithTemplate("Expected:\n{{.FormattedActual}}\n{{.To}} " + message) +} + +func (c CustomGomegaMatcher) WithTemplate(templ string, data ...any) CustomGomegaMatcher { + return c.WithPrecompiledTemplate(template.Must(ParseTemplate(templ)), data...) +} + +func (c CustomGomegaMatcher) WithPrecompiledTemplate(templ *template.Template, data ...any) CustomGomegaMatcher { + c.templateMessage = templ + c.templateData = nil + if len(data) > 0 { + c.templateData = data[0] + } + return c +} + +func (c CustomGomegaMatcher) WithTemplateData(data any) CustomGomegaMatcher { + c.templateData = data + return c +} + +func (c CustomGomegaMatcher) Match(actual any) (bool, error) { + return c.matchFunc(actual) +} + +func (c CustomGomegaMatcher) FailureMessage(actual any) string { + return c.renderTemplateMessage(actual, true) +} + +func (c CustomGomegaMatcher) NegatedFailureMessage(actual any) string { + return c.renderTemplateMessage(actual, false) +} + +type templateData struct { + Failure bool + NegatedFailure bool + To string + FormattedActual string + Actual any + Data any +} + +func (c CustomGomegaMatcher) renderTemplateMessage(actual any, isFailure bool) string { + var data templateData + formattedActual := format.Object(actual, 1) + if isFailure { + data = templateData{ + Failure: true, + NegatedFailure: false, + To: "to", + FormattedActual: formattedActual, + Actual: actual, + Data: c.templateData, + } + } else { + data = templateData{ + Failure: false, + NegatedFailure: true, + To: "not to", + FormattedActual: formattedActual, + Actual: actual, + Data: c.templateData, + } + } + b := &strings.Builder{} + err := c.templateMessage.Execute(b, data) + if err != nil { + return fmt.Sprintf("Failed to render failure message template: %s", err.Error()) + } + return b.String() +} diff --git a/gcustom/make_matcher_test.go b/gcustom/make_matcher_test.go new file mode 100644 index 000000000..371253fc1 --- /dev/null +++ b/gcustom/make_matcher_test.go @@ -0,0 +1,195 @@ +package gcustom_test + +import ( + "errors" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gcustom" +) + +type someType struct { + Name string +} + +var _ = Describe("MakeMatcher", func() { + Describe("validating and wrapping the MatchFunc", func() { + DescribeTable("it panics when passed an invalid function", func(f any) { + Expect(func() { + gcustom.MakeMatcher(f) + }).To(PanicWith("MakeMatcher must be passed a function that takes one argument and returns (bool, error)")) + }, + Entry("a non-function", "foo"), + Entry("a non-function", 1), + Entry("a function with no input", func() (bool, error) { return false, nil }), + Entry("a function with too many inputs", func(a int, b string) (bool, error) { return false, nil }), + Entry("a function with no outputs", func(a any) {}), + Entry("a function with insufficient outputs", func(a any) bool { return false }), + Entry("a function with insufficient outputs", func(a any) error { return nil }), + Entry("a function with too many outputs", func(a any) (bool, error, string) { return false, nil, "" }), + Entry("a function with the wrong types of outputs", func(a any) (int, error) { return 1, nil }), + Entry("a function with the wrong types of outputs", func(a any) (bool, int) { return false, 1 }), + ) + + Context("when the match func accepts any actual", func() { + It("always passes in the actual, regardless of type", func() { + var passedIn any + m := gcustom.MakeMatcher(func(a any) (bool, error) { + passedIn = a + return true, nil + }) + + m.Match(1) + Ω(passedIn).Should(Equal(1)) + + m.Match("foo") + Ω(passedIn).Should(Equal("foo")) + + m.Match(someType{"foo"}) + Ω(passedIn).Should(Equal(someType{"foo"})) + + c := make(chan bool) + m.Match(c) + Ω(passedIn).Should(Equal(c)) + }) + }) + + Context("when the match func accepts a specific type", func() { + It("ensure the type matches before calling func", func() { + var passedIn any + m := gcustom.MakeMatcher(func(a int) (bool, error) { + passedIn = a + return true, nil + }) + + success, err := m.Match(1) + Ω(success).Should(BeTrue()) + Ω(err).ShouldNot(HaveOccurred()) + Ω(passedIn).Should(Equal(1)) + + passedIn = nil + success, err = m.Match(1.2) + Ω(success).Should(BeFalse()) + Ω(err).Should(MatchError(ContainSubstring("Matcher expected actual of type . Got:\n : 1.2"))) + Ω(passedIn).Should(BeNil()) + + m = gcustom.MakeMatcher(func(a someType) (bool, error) { + passedIn = a + return true, nil + }) + + success, err = m.Match(someType{"foo"}) + Ω(success).Should(BeTrue()) + Ω(err).ShouldNot(HaveOccurred()) + Ω(passedIn).Should(Equal(someType{"foo"})) + + passedIn = nil + success, err = m.Match("foo") + Ω(success).Should(BeFalse()) + Ω(err).Should(MatchError(ContainSubstring("Matcher expected actual of type . Got:\n : foo"))) + Ω(passedIn).Should(BeNil()) + + }) + }) + }) + + It("calls the matchFunc and returns whatever it returns when Match is called", func() { + m := gcustom.MakeMatcher(func(a int) (bool, error) { + if a == 0 { + return true, nil + } + if a == 1 { + return false, nil + } + return false, errors.New("bam") + }) + + Ω(m.Match(0)).Should(BeTrue()) + Ω(m.Match(1)).Should(BeFalse()) + success, err := m.Match(2) + Ω(success).Should(BeFalse()) + Ω(err).Should(MatchError("bam")) + }) + + Describe("rendering messages", func() { + var m gcustom.CustomGomegaMatcher + BeforeEach(func() { + m = gcustom.MakeMatcher(func(a any) (bool, error) { return false, nil }) + }) + + Context("when no message is configured", func() { + It("renders a simple canned message", func() { + Ω(m.FailureMessage(3)).Should(Equal("Custom matcher failed for:\n : 3")) + Ω(m.NegatedFailureMessage(3)).Should(Equal("Custom matcher succeeded (but was expected to fail) for:\n : 3")) + }) + }) + + Context("when a simple message is configured", func() { + It("tacks that message onto the end of a formatted string", func() { + m = m.WithMessage("have been confabulated") + Ω(m.FailureMessage(3)).Should(Equal("Expected:\n : 3\nto have been confabulated")) + Ω(m.NegatedFailureMessage(3)).Should(Equal("Expected:\n : 3\nnot to have been confabulated")) + + m = gcustom.MakeMatcher(func(a any) (bool, error) { return false, nil }, "have been confabulated") + Ω(m.FailureMessage(3)).Should(Equal("Expected:\n : 3\nto have been confabulated")) + Ω(m.NegatedFailureMessage(3)).Should(Equal("Expected:\n : 3\nnot to have been confabulated")) + + }) + }) + + Context("when a template is registered", func() { + It("uses that template", func() { + m = m.WithTemplate("{{.Failure}} {{.NegatedFailure}} {{.To}} {{.FormattedActual}} {{.Actual.Name}}") + Ω(m.FailureMessage(someType{"foo"})).Should(Equal("true false to : {Name: \"foo\"} foo")) + Ω(m.NegatedFailureMessage(someType{"foo"})).Should(Equal("false true not to : {Name: \"foo\"} foo")) + + }) + }) + + Context("when a template with custom data is registered", func() { + It("provides that custom data", func() { + m = m.WithTemplate("{{.Failure}} {{.NegatedFailure}} {{.To}} {{.FormattedActual}} {{.Actual.Name}} {{.Data}}", 17) + + Ω(m.FailureMessage(someType{"foo"})).Should(Equal("true false to : {Name: \"foo\"} foo 17")) + Ω(m.NegatedFailureMessage(someType{"foo"})).Should(Equal("false true not to : {Name: \"foo\"} foo 17")) + }) + + It("provides a mechanism for formatting custom data", func() { + m = m.WithTemplate("{{format .Data}}", 17) + + Ω(m.FailureMessage(0)).Should(Equal(": 17")) + Ω(m.NegatedFailureMessage(0)).Should(Equal(": 17")) + + m = m.WithTemplate("{{format .Data 1}}", 17) + + Ω(m.FailureMessage(0)).Should(Equal(" : 17")) + Ω(m.NegatedFailureMessage(0)).Should(Equal(" : 17")) + + }) + }) + + Context("when a precompiled template is registered", func() { + It("uses that template", func() { + templ, err := gcustom.ParseTemplate("{{.Failure}} {{.NegatedFailure}} {{.To}} {{.FormattedActual}} {{.Actual.Name}} {{format .Data}}") + Ω(err).ShouldNot(HaveOccurred()) + + m = m.WithPrecompiledTemplate(templ, 17) + Ω(m.FailureMessage(someType{"foo"})).Should(Equal("true false to : {Name: \"foo\"} foo : 17")) + Ω(m.NegatedFailureMessage(someType{"foo"})).Should(Equal("false true not to : {Name: \"foo\"} foo : 17")) + }) + + It("can also take a template as an argument upon construction", func() { + templ, err := gcustom.ParseTemplate("{{.To}} {{format .Data}}") + Ω(err).ShouldNot(HaveOccurred()) + m = gcustom.MakeMatcher(func(a any) (bool, error) { return false, nil }, templ) + + Ω(m.FailureMessage(0)).Should(Equal("to : nil")) + Ω(m.NegatedFailureMessage(0)).Should(Equal("not to : nil")) + + m = m.WithTemplateData(17) + Ω(m.FailureMessage(0)).Should(Equal("to : 17")) + Ω(m.NegatedFailureMessage(0)).Should(Equal("not to : 17")) + }) + }) + }) +})