From ba1076d8b3b67cdaf7bf92c95b3641636a039be2 Mon Sep 17 00:00:00 2001 From: Paul Dufour Date: Wed, 22 Jun 2022 10:31:35 +0100 Subject: [PATCH] Add .Unset method to mock (#982) * Add .Off method to mock * Update README.md * Update mock.go * Update mock_test.go * Update README.md * Fix tests * Add unset test * remove prints * fix test * update readme --- README.md | 25 +++++++++++++++ mock/mock.go | 37 ++++++++++++++++++++++ mock/mock_test.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+) diff --git a/README.md b/README.md index 6b9ff9ad3..ce6d3de28 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,31 @@ func TestSomethingWithPlaceholder(t *testing.T) { } + +// TestSomethingElse2 is a third example that shows how you can use +// the Unset method to cleanup handlers and then add new ones. +func TestSomethingElse2(t *testing.T) { + + // create an instance of our test object + testObj := new(MyMockedObject) + + // setup expectations with a placeholder in the argument list + mockCall := testObj.On("DoSomething", mock.Anything).Return(true, nil) + + // call the code we are testing + targetFuncThatDoesSomethingWithObj(testObj) + + // assert that the expectations were met + testObj.AssertExpectations(t) + + // remove the handler now so we can add another one that takes precedence + mockCall.Unset() + + // return false now instead of true + testObj.On("DoSomething", mock.Anything).Return(false, nil) + + testObj.AssertExpectations(t) +} ``` For more information on how to write mock code, check out the [API documentation for the `mock` package](http://godoc.org/github.com/stretchr/testify/mock). diff --git a/mock/mock.go b/mock/mock.go index fefc8e985..769aed8b3 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -199,6 +199,43 @@ func (c *Call) On(methodName string, arguments ...interface{}) *Call { return c.Parent.On(methodName, arguments...) } +// Unset removes a mock handler from being called. +// test.On("func", mock.Anything).Unset() +func (c *Call) Unset() *Call { + var unlockOnce sync.Once + + for _, arg := range c.Arguments { + if v := reflect.ValueOf(arg); v.Kind() == reflect.Func { + panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg)) + } + } + + c.lock() + defer unlockOnce.Do(c.unlock) + + foundMatchingCall := false + + for i, call := range c.Parent.ExpectedCalls { + if call.Method == c.Method { + _, diffCount := call.Arguments.Diff(c.Arguments) + if diffCount == 0 { + foundMatchingCall = true + // Remove from ExpectedCalls + c.Parent.ExpectedCalls = append(c.Parent.ExpectedCalls[:i], c.Parent.ExpectedCalls[i+1:]...) + } + } + } + + if !foundMatchingCall { + unlockOnce.Do(c.unlock) + c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n", + callString(c.Method, c.Arguments, true), + ) + } + + return c +} + // Mock is the workhorse used to track activity on another object. // For an example of its usage, refer to the "Example Usage" section at the top // of this document. diff --git a/mock/mock_test.go b/mock/mock_test.go index f8befa87b..211568690 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -462,6 +462,87 @@ func Test_Mock_On_WithFuncTypeArg(t *testing.T) { }) } +func Test_Mock_Unset(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + call := mockedService. + On("TheExampleMethodFuncType", "argA"). + Return("blah") + + found, foundCall := mockedService.findExpectedCall("TheExampleMethodFuncType", "argA") + require.NotEqual(t, -1, found) + require.Equal(t, foundCall, call) + + call.Unset() + + found, foundCall = mockedService.findExpectedCall("TheExampleMethodFuncType", "argA") + require.Equal(t, -1, found) + + var expectedCall *Call + require.Equal(t, expectedCall, foundCall) + + fn := func(string) error { return nil } + assert.Panics(t, func() { + mockedService.TheExampleMethodFuncType(fn) + }) +} + +// Since every time you call On it creates a new object +// the last time you call Unset it will only unset the last call +func Test_Mock_Chained_UnsetOnlyUnsetsLastCall(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + // determine our current line number so we can assert the expected calls callerInfo properly + _, _, line, _ := runtime.Caller(0) + mockedService. + On("TheExampleMethod1", 1, 1). + Return(0). + On("TheExampleMethod2", 2, 2). + On("TheExampleMethod3", 3, 3, 3). + Return(nil). + Unset() + + expectedCalls := []*Call{ + { + Parent: &mockedService.Mock, + Method: "TheExampleMethod1", + Arguments: []interface{}{1, 1}, + ReturnArguments: []interface{}{0}, + callerInfo: []string{fmt.Sprintf("mock_test.go:%d", line+2)}, + }, + { + Parent: &mockedService.Mock, + Method: "TheExampleMethod2", + Arguments: []interface{}{2, 2}, + ReturnArguments: []interface{}{}, + callerInfo: []string{fmt.Sprintf("mock_test.go:%d", line+4)}, + }, + } + assert.Equal(t, 2, len(expectedCalls)) + assert.Equal(t, expectedCalls, mockedService.ExpectedCalls) +} + +func Test_Mock_UnsetIfAlreadyUnsetFails(t *testing.T) { + // make a test impl object + var mockedService = new(TestExampleImplementation) + + mock1 := mockedService. + On("TheExampleMethod1", 1, 1). + Return(1) + + assert.Equal(t, 1, len(mockedService.ExpectedCalls)) + mock1.Unset() + assert.Equal(t, 0, len(mockedService.ExpectedCalls)) + + assert.Panics(t, func() { + mock1.Unset() + }) + + assert.Equal(t, 0, len(mockedService.ExpectedCalls)) +} + func Test_Mock_Return(t *testing.T) { // make a test impl object