Skip to content

Commit

Permalink
Use reflect.Value.Pointer() to compare pointers
Browse files Browse the repository at this point in the history
Fixes #1076
  • Loading branch information
AlexanderYastrebov committed Oct 7, 2024
1 parent 2fc4e39 commit e14a11f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 17 deletions.
28 changes: 12 additions & 16 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,12 +491,12 @@ func validateEqualArgs(expected, actual interface{}) error {
return nil
}

// Same asserts that two pointers reference the same object.
// Same asserts that two arguments reference the same object.
//
// assert.Same(t, ptr1, ptr2)
// assert.Same(t, arg1, arg2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
// Both arguments can be pointers, channels, functions, maps, slices or strings.
// Argument sameness is determined based on the equality of both type and value.
func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
Expand All @@ -511,12 +511,12 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
return true
}

// NotSame asserts that two pointers do not reference the same object.
// NotSame asserts that two arguments do not reference the same object.
//
// assert.NotSame(t, ptr1, ptr2)
// assert.NotSame(t, arg1, arg2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
// Both arguments can be pointers, channels, functions, maps, slices or strings.
// Argument sameness is determined based on the equality of both type and value.
func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
Expand All @@ -534,17 +534,13 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
// they point to the same object
func samePointers(first, second interface{}) bool {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false
}

firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
if firstType != secondType {
switch firstPtr.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.String:
return firstPtr.Kind() == secondPtr.Kind() && firstPtr.Pointer() == secondPtr.Pointer()
default:
return false
}

// compare pointer addresses
return first == second
}

// formatUnequalValues takes two values of arbitrary types and returns string
Expand Down
68 changes: 67 additions & 1 deletion assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,12 @@ func TestNotSame(t *testing.T) {

func Test_samePointers(t *testing.T) {
p := ptr(2)
c1, c2 := make(chan int), make(chan int)
f1, f2 := func() {}, func() {}
m1, m2 := map[int]int{1: 2}, map[int]int{1: 2}
p1, p2 := ptr(3), ptr(3)
s1, s2 := []int{4, 5}, []int{4, 5}
str1, str2 := "6", "6_"[:1] // ensure strings use different backing arrays

type args struct {
first interface{}
Expand Down Expand Up @@ -634,6 +640,66 @@ func Test_samePointers(t *testing.T) {
args: args{first: [2]int{1, 2}, second: []int{1, 2}},
assertion: False,
},
{
name: "chan(1) == chan(1)",
args: args{first: c1, second: c1},
assertion: True,
},
{
name: "func(1) == func(1)",
args: args{first: f1, second: f1},
assertion: True,
},
{
name: "map(1) == map(1)",
args: args{first: m1, second: m1},
assertion: True,
},
{
name: "ptr(1) == ptr(1)",
args: args{first: p1, second: p1},
assertion: True,
},
{
name: "slice(1) == slice(1)",
args: args{first: s1, second: s1},
assertion: True,
},
{
name: "string(1) == string(1)",
args: args{first: str1, second: str1},
assertion: True,
},
{
name: "chan(1) != chan(2)",
args: args{first: c1, second: c2},
assertion: False,
},
{
name: "func(1) != func(2)",
args: args{first: f1, second: f2},
assertion: False,
},
{
name: "map(1) != map(2)",
args: args{first: m1, second: m2},
assertion: False,
},
{
name: "ptr(1) != ptr(2)",
args: args{first: p1, second: p2},
assertion: False,
},
{
name: "slice(1) != slice(2)",
args: args{first: s1, second: s2},
assertion: False,
},
{
name: "string(1) != string(2)",
args: args{first: str1, second: str2},
assertion: False,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -2505,7 +2571,7 @@ Diff:
@@ -1,2 +1,2 @@
-(time.Time) 2020-09-24 00:00:00 +0000 UTC
+(time.Time) 2020-09-25 00:00:00 +0000 UTC
`

actual = diff(
Expand Down

0 comments on commit e14a11f

Please sign in to comment.