From 38379b548cf827732e77467f28c830238437b8f0 Mon Sep 17 00:00:00 2001 From: mahigadamsetty <80254540+mahigadamsetty@users.noreply.github.com> Date: Sun, 3 Mar 2024 10:38:42 +0530 Subject: [PATCH 1/3] added a new api ArgsForCallCount on mock struct --- mock/mock.go | 61 ++++++++++++++++++++++++++++++++++++----------- mock/mock_test.go | 59 ++++++++++++++++++++++++++++++++------------- 2 files changed, 89 insertions(+), 31 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index 213bde2ea..09348b2e8 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -282,7 +282,7 @@ type Mock struct { ExpectedCalls []*Call // Holds the calls that were made to this mocked object. - Calls []Call + Calls map[string][]Call // test is An optional variable that holds the test struct, to be used when an // invalid mock call was made. @@ -528,8 +528,21 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen } call.totalCalls++ + fmt.Println(m.Calls) + // add the call - m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...)) + if m.Calls == nil { + fmt.Printf("hello") + m.Calls = make(map[string][]Call) + } + + calls, ok := m.Calls[methodName] + if !ok { + m.Calls[methodName] = []Call{*newCall(m, methodName, assert.CallerInfo(), arguments...)} + } else { + calls = append(calls, *newCall(m, methodName, assert.CallerInfo(), arguments...)) + m.Calls[methodName] = calls + } m.mutex.Unlock() // block if specified @@ -640,11 +653,11 @@ func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls m.mutex.Lock() defer m.mutex.Unlock() var actualCalls int - for _, call := range m.calls() { - if call.Method == methodName { - actualCalls++ - } + calls, ok := m.Calls[methodName] + if ok { + actualCalls = len(calls) } + return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls)) } @@ -658,8 +671,10 @@ func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interfac defer m.mutex.Unlock() if !m.methodWasCalled(methodName, arguments) { var calledWithArgs []string - for _, call := range m.calls() { - calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments)) + for _, calls := range m.Calls { + for _, call := range calls { + calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments)) + } } if len(calledWithArgs) == 0 { return assert.Fail(t, "Should have called with given arguments", @@ -712,6 +727,24 @@ func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...inte return false } +// ArgsForCallCount returns the arguments of a function for a specific call count(0 based index) +func (m *Mock) ArgsForCallCount(t TestingT, methodName string, count int) Arguments { + fmt.Println("ArgsForCallCount") + fmt.Println(m.Calls) + calls, ok := m.Calls[methodName] + if !ok { + assert.Fail(t, "ArgsForCallCount", + fmt.Sprintf("Expected %q to have been called with:\nbut no actual calls happened", methodName)) + } + + if len(calls) < count+1 { + assert.Fail(t, "ArgsForCallCount", + fmt.Sprintf("Expected %q to have been called with count:%d:\nbut no actual calls happened", methodName, count)) + } + + return calls[count].Arguments +} + // isArgsEqual compares arguments func isArgsEqual(expected Arguments, args []interface{}) bool { if len(expected) != len(args) { @@ -726,8 +759,9 @@ func isArgsEqual(expected Arguments, args []interface{}) bool { } func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { - for _, call := range m.calls() { - if call.Method == methodName { + calls, ok := m.Calls[methodName] + if ok { + for _, call := range calls { _, differences := Arguments(expected).Diff(call.Arguments) @@ -735,7 +769,6 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { // found the expected call return true } - } } // we didn't find the expected call @@ -746,9 +779,9 @@ func (m *Mock) expectedCalls() []*Call { return append([]*Call{}, m.ExpectedCalls...) } -func (m *Mock) calls() []Call { - return append([]Call{}, m.Calls...) -} +// func (m *Mock) calls() []Call { +// return append([]Call{}, m.Calls...) +// } /* Arguments diff --git a/mock/mock_test.go b/mock/mock_test.go index b80a8a75b..addfbde77 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -1180,14 +1180,31 @@ func Test_Mock_Called(t *testing.T) { var mockedService = new(TestExampleImplementation) mockedService.On("Test_Mock_Called", 1, 2, 3).Return(5, "6", true) + mockedService.On("Test_Mock_Called", 4, 5, 6).Return(8, "10", false) returnArguments := mockedService.Called(1, 2, 3) + returnArguments1 := mockedService.Called(4, 5, 6) if assert.Equal(t, 1, len(mockedService.Calls)) { - assert.Equal(t, "Test_Mock_Called", mockedService.Calls[0].Method) - assert.Equal(t, 1, mockedService.Calls[0].Arguments[0]) - assert.Equal(t, 2, mockedService.Calls[0].Arguments[1]) - assert.Equal(t, 3, mockedService.Calls[0].Arguments[2]) + // assert.Equal(t, "Test_Mock_Called", mockedService.Calls["Test_Mock_Called"][0].Method) + // assert.Equal(t, 1, mockedService.Calls["Test_Mock_Called"][0].Arguments[0]) + // assert.Equal(t, 2, mockedService.Calls["Test_Mock_Called"][0].Arguments[1]) + // assert.Equal(t, 3, mockedService.Calls["Test_Mock_Called"][0].Arguments[2]) + + returnArgs := mockedService.ArgsForCallCount(t, "Test_Mock_Called", 0) + if assert.Equal(t, 3, len(returnArgs)) { + assert.Equal(t, 1, returnArgs[0]) + assert.Equal(t, 2, returnArgs[1]) + assert.Equal(t, 3, returnArgs[2]) + } + + returnArgs = mockedService.ArgsForCallCount(t, "Test_Mock_Called", 1) + if assert.Equal(t, 3, len(returnArgs)) { + assert.Equal(t, 4, returnArgs[0]) + assert.Equal(t, 5, returnArgs[1]) + assert.Equal(t, 6, returnArgs[2]) + } + } if assert.Equal(t, 3, len(returnArguments)) { @@ -1196,6 +1213,12 @@ func Test_Mock_Called(t *testing.T) { assert.Equal(t, true, returnArguments[2]) } + if assert.Equal(t, 3, len(returnArguments1)) { + assert.Equal(t, 8, returnArguments1[0]) + assert.Equal(t, "10", returnArguments1[1]) + assert.Equal(t, false, returnArguments1[2]) + } + } func asyncCall(m *Mock, ch chan Arguments) { @@ -1221,10 +1244,10 @@ func Test_Mock_Called_blocks(t *testing.T) { returnArguments := <-ch if assert.Equal(t, 1, len(mockedService.Mock.Calls)) { - assert.Equal(t, "asyncCall", mockedService.Mock.Calls[0].Method) - assert.Equal(t, 1, mockedService.Mock.Calls[0].Arguments[0]) - assert.Equal(t, 2, mockedService.Mock.Calls[0].Arguments[1]) - assert.Equal(t, 3, mockedService.Mock.Calls[0].Arguments[2]) + assert.Equal(t, "asyncCall", mockedService.Calls["asyncCall"][0].Method) + assert.Equal(t, 1, mockedService.Calls["asyncCall"][0].Arguments[0]) + assert.Equal(t, 2, mockedService.Calls["asyncCall"][0].Arguments[1]) + assert.Equal(t, 3, mockedService.Calls["asyncCall"][0].Arguments[2]) } if assert.Equal(t, 3, len(returnArguments)) { @@ -1250,16 +1273,18 @@ func Test_Mock_Called_For_Bounded_Repeatability(t *testing.T) { returnArguments1 := mockedService.Called(1, 2, 3) returnArguments2 := mockedService.Called(1, 2, 3) - if assert.Equal(t, 2, len(mockedService.Calls)) { - assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Calls[0].Method) - assert.Equal(t, 1, mockedService.Calls[0].Arguments[0]) - assert.Equal(t, 2, mockedService.Calls[0].Arguments[1]) - assert.Equal(t, 3, mockedService.Calls[0].Arguments[2]) + // t.Error(mockedService.Calls) + + if assert.Equal(t, 2, len(mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"])) { + assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][0].Method) + assert.Equal(t, 1, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][0].Arguments[0]) + assert.Equal(t, 2, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][0].Arguments[1]) + assert.Equal(t, 3, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][0].Arguments[2]) - assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Calls[1].Method) - assert.Equal(t, 1, mockedService.Calls[1].Arguments[0]) - assert.Equal(t, 2, mockedService.Calls[1].Arguments[1]) - assert.Equal(t, 3, mockedService.Calls[1].Arguments[2]) + assert.Equal(t, "Test_Mock_Called_For_Bounded_Repeatability", mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][1].Method) + assert.Equal(t, 1, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][1].Arguments[0]) + assert.Equal(t, 2, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][1].Arguments[1]) + assert.Equal(t, 3, mockedService.Calls["Test_Mock_Called_For_Bounded_Repeatability"][1].Arguments[2]) } if assert.Equal(t, 3, len(returnArguments1)) { From 3ed6f734608ec564c8841f17dd8216db583ab91d Mon Sep 17 00:00:00 2001 From: mahigadamsetty <80254540+mahigadamsetty@users.noreply.github.com> Date: Sun, 3 Mar 2024 10:42:26 +0530 Subject: [PATCH 2/3] removed commented code --- mock/mock.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index 09348b2e8..7e71e89b5 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -779,10 +779,6 @@ func (m *Mock) expectedCalls() []*Call { return append([]*Call{}, m.ExpectedCalls...) } -// func (m *Mock) calls() []Call { -// return append([]Call{}, m.Calls...) -// } - /* Arguments */ From e110004e1f8e24428a032ca9600f4310c9486ecb Mon Sep 17 00:00:00 2001 From: mahigadamsetty <80254540+mahigadamsetty@users.noreply.github.com> Date: Sun, 3 Mar 2024 10:43:17 +0530 Subject: [PATCH 3/3] removed commented code --- mock/mock_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mock/mock_test.go b/mock/mock_test.go index addfbde77..113d33faa 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -1186,11 +1186,6 @@ func Test_Mock_Called(t *testing.T) { returnArguments1 := mockedService.Called(4, 5, 6) if assert.Equal(t, 1, len(mockedService.Calls)) { - // assert.Equal(t, "Test_Mock_Called", mockedService.Calls["Test_Mock_Called"][0].Method) - // assert.Equal(t, 1, mockedService.Calls["Test_Mock_Called"][0].Arguments[0]) - // assert.Equal(t, 2, mockedService.Calls["Test_Mock_Called"][0].Arguments[1]) - // assert.Equal(t, 3, mockedService.Calls["Test_Mock_Called"][0].Arguments[2]) - returnArgs := mockedService.ArgsForCallCount(t, "Test_Mock_Called", 0) if assert.Equal(t, 3, len(returnArgs)) { assert.Equal(t, 1, returnArgs[0])