From c04094ab1359d7e8223143a73e865cfdd3747edd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Gardfj=C3=A4ll?= Date: Tue, 14 May 2024 09:05:15 +0200 Subject: [PATCH 1/2] avoid data races in Arguments.Diff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes a concurrency issue that would lead to testify mocks producing data races detected by go test -race. These data races would occur whenever a mock pointer argument was concurrently modified. The reason being that Arguments.Diff uses the %v format specifier to get a presentable string for the argument. This also traverses the pointed-to data structure, which would lead to the data race. Signed-off-by: Peter Gardfjäll Signed-off-by: Johannes Würbach --- mock/mock_test.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/mock/mock_test.go b/mock/mock_test.go index 3dc9e0b1e..fd5d5f612 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -2177,6 +2177,42 @@ func Test_MockReturnAndCalledConcurrent(t *testing.T) { wg.Wait() } +type argType struct{ Question string } + +type pointerArgMock struct{ Mock } + +func (m *pointerArgMock) Question(arg *argType) int { + args := m.Called(arg) + return args.Int(0) +} + +// Exercises calling a mock with a pointer value that gets modified concurrently. Prior to fix +// https://github.com/stretchr/testify/pull/1598 this would fail when running go test with the -race +// flag, due to Arguments.Diff printing the format with specifier %v which traverses the pointed to +// data structure (that is being concurrently modified by another goroutine). +func Test_CallMockWithConcurrentlyModifiedPointerArg(t *testing.T) { + m := &pointerArgMock{} + m.On("Question", Anything).Return(42) + + ptrArg := &argType{Question: "What's the meaning of life?"} + + // Emulates a situation where the pointer value gets concurrently updated by another thread. + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + ptrArg.Question = "What is 7 * 6?" + }() + + // This is where we would get a data race since Arguments.Diff would traverse the pointed to + // struct while being updated. Something go test -race would identify as a data race. + value := m.Question(ptrArg) + assert.Equal(t, 42, value) + wg.Wait() + + m.AssertExpectations(t) +} + type timer struct{ Mock } func (s *timer) GetTime(i int) string { From 98cc35fc74b30d61226db57c3dea5b9170725140 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20W=C3=BCrbach?= Date: Fri, 3 Jan 2025 20:57:25 +0100 Subject: [PATCH 2/2] avoid rendering output --- mock/mock.go | 53 +++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index c95eeeca8..8e6fceb37 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -235,7 +235,7 @@ func (c *Call) Unset() *Call { var index int // write index for _, call := range c.Parent.ExpectedCalls { if call.Method == c.Method { - _, diffCount := call.Arguments.Diff(c.Arguments) + diffCount := call.Arguments.countDiff(c.Arguments) if diffCount == 0 { foundMatchingCall = true // Remove from ExpectedCalls - just skip it @@ -389,7 +389,7 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, * for i, call := range m.ExpectedCalls { if call.Method == method { - _, diffCount := call.Arguments.Diff(arguments) + diffCount := call.Arguments.countDiff(arguments) if diffCount == 0 { expectedCall = call if call.Repeatability > -1 { @@ -755,7 +755,7 @@ func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool { for _, call := range m.calls() { if call.Method == methodName { - _, differences := Arguments(expected).Diff(call.Arguments) + differences := Arguments(expected).countDiff(call.Arguments) if differences == 0 { // found the expected call @@ -957,11 +957,34 @@ func (args Arguments) Is(objects ...interface{}) bool { // // Returns the diff string and number of differences found. func (args Arguments) Diff(objects []interface{}) (string, int) { + return args.diff(objects, true) +} + +// countDiff gets the number of differences between the arguments +// and the specified objects. +// +// Returns the diff number of differences found. +func (args Arguments) countDiff(objects []interface{}) int { + _, count := args.diff(objects, false) + return count +} + +func noOpSprintf(format string, args ...interface{}) string { + return "" +} + +// diff allows for the diffing of arguments and objects. +func (args Arguments) diff(objects []interface{}, includeOutput bool) (string, int) { // TODO: could return string as error and nil for No difference output := "\n" var differences int + optSprintf := fmt.Sprintf + if !includeOutput { + optSprintf = noOpSprintf + } + maxArgCount := len(args) if len(objects) > maxArgCount { maxArgCount = len(objects) @@ -976,7 +999,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { actualFmt = "(Missing)" } else { actual = objects[i] - actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + actualFmt = optSprintf("(%[1]T=%[1]v)", actual) } if len(args) <= i { @@ -984,7 +1007,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { expectedFmt = "(Missing)" } else { expected = args[i] - expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + expectedFmt = optSprintf("(%[1]T=%[1]v)", expected) } if matcher, ok := expected.(argumentMatcher); ok { @@ -992,16 +1015,16 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { func() { defer func() { if r := recover(); r != nil { - actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) + actualFmt = optSprintf("panic in argument matcher: %v", r) } }() matches = matcher.Matches(actual) }() if matches { - output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + output = optSprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) } else { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + output = optSprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) } } else { switch expected := expected.(type) { @@ -1010,13 +1033,13 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) + output = optSprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) } case *IsTypeArgument: actualT := reflect.TypeOf(actual) if actualT != expected.t { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, safeTypeName(expected.t), safeTypeName(actualT), actualFmt) + output = optSprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, safeTypeName(expected.t), safeTypeName(actualT), actualFmt) } case *FunctionalOptionsArgument: var name string @@ -1027,26 +1050,26 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { const tName = "[]interface{}" if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 { differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt) + output = optSprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt) } else { if ef, af := assertOpts(expected.values, actual); ef == "" && af == "" { // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName) + output = optSprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName) } else { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef) + output = optSprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef) } } default: if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) + output = optSprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) } else { // not match differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + output = optSprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) } } }