From ff2c2cce0d0b43e3ed9743fcd3408cbd9187f835 Mon Sep 17 00:00:00 2001 From: TotallyGamerJet Date: Wed, 3 Apr 2024 22:43:20 -0400 Subject: [PATCH] all: struct return fixes (#221) On amd64, there was a crash when calling objc.Send with a struct return type. This was due to calling the wrong objc message send function. On arm64, if a struct return contained structs that only contained floats it would expect it to be in R8 instead of the expected float registers Closes #223 --- func.go | 31 +++++++--- objc/objc_runtime_darwin.go | 78 +++++++++++++++++-------- objc/objc_runtime_darwin_test.go | 36 ++++++++++++ struct_arm64.go | 12 ++-- struct_test.go | 12 ++++ testdata/structtest/structreturn_test.c | 14 +++++ 6 files changed, 145 insertions(+), 38 deletions(-) diff --git a/func.go b/func.go index b0be0c37..7f3520c1 100644 --- a/func.go +++ b/func.go @@ -250,7 +250,8 @@ func RegisterFunc(fptr interface{}, cfn uintptr) { keepAlive = append(keepAlive, val) addInt(val.Pointer()) } else if runtime.GOARCH == "arm64" && outType.Size() > maxRegAllocStructSize { - if !isAllSameFloat(outType) || outType.NumField() > 4 { + isAllFloats, numFields := isAllSameFloat(outType) + if !isAllFloats || numFields > 4 { val := reflect.New(outType) keepAlive = append(keepAlive, val) syscall.arm64_r8 = val.Pointer() @@ -351,20 +352,34 @@ func RegisterFunc(fptr interface{}, cfn uintptr) { // maxRegAllocStructSize is the biggest a struct can be while still fitting in registers. // if it is bigger than this than enough space must be allocated on the heap and then passed into // the function as the first parameter on amd64 or in R8 on arm64. +// +// If you change this make sure to update it in objc_runtime_darwin.go const maxRegAllocStructSize = 16 -func isAllSameFloat(ty reflect.Type) bool { - first := ty.Field(0).Type.Kind() +func isAllSameFloat(ty reflect.Type) (allFloats bool, numFields int) { + allFloats = true + root := ty.Field(0).Type + for root.Kind() == reflect.Struct { + root = root.Field(0).Type + } + first := root.Kind() if first != reflect.Float32 && first != reflect.Float64 { - return false + allFloats = false } for i := 0; i < ty.NumField(); i++ { - f := ty.Field(i) - if f.Type.Kind() != first { - return false + f := ty.Field(i).Type + if f.Kind() == reflect.Struct { + var structNumFields int + allFloats, structNumFields = isAllSameFloat(f) + numFields += structNumFields + continue + } + numFields++ + if f.Kind() != first { + allFloats = false } } - return true + return allFloats, numFields } func checkStructFieldsSupported(ty reflect.Type) { diff --git a/objc/objc_runtime_darwin.go b/objc/objc_runtime_darwin.go index ed691e2c..b6855fd7 100644 --- a/objc/objc_runtime_darwin.go +++ b/objc/objc_runtime_darwin.go @@ -11,6 +11,7 @@ import ( "math" "reflect" "regexp" + "runtime" "unicode" "unsafe" @@ -20,28 +21,30 @@ import ( // TODO: support try/catch? // https://stackoverflow.com/questions/7062599/example-of-how-objective-cs-try-catch-implementation-is-executed-at-runtime var ( - objc_msgSend_fn uintptr - objc_msgSend func(obj ID, cmd SEL, args ...interface{}) ID - objc_msgSendSuper2_fn uintptr - objc_msgSendSuper2 func(super *objc_super, cmd SEL, args ...interface{}) ID - objc_getClass func(name string) Class - objc_getProtocol func(name string) *Protocol - objc_allocateClassPair func(super Class, name string, extraBytes uintptr) Class - objc_registerClassPair func(class Class) - sel_registerName func(name string) SEL - class_getSuperclass func(class Class) Class - class_getInstanceVariable func(class Class, name string) Ivar - class_getInstanceSize func(class Class) uintptr - class_addMethod func(class Class, name SEL, imp IMP, types string) bool - class_addIvar func(class Class, name string, size uintptr, alignment uint8, types string) bool - class_addProtocol func(class Class, protocol *Protocol) bool - ivar_getOffset func(ivar Ivar) uintptr - ivar_getName func(ivar Ivar) string - object_getClass func(obj ID) Class - object_getIvar func(obj ID, ivar Ivar) ID - object_setIvar func(obj ID, ivar Ivar, value ID) - protocol_getName func(protocol *Protocol) string - protocol_isEqual func(p *Protocol, p2 *Protocol) bool + objc_msgSend_fn uintptr + objc_msgSend_stret_fn uintptr + objc_msgSend func(obj ID, cmd SEL, args ...interface{}) ID + objc_msgSendSuper2_fn uintptr + objc_msgSendSuper2_stret_fn uintptr + objc_msgSendSuper2 func(super *objc_super, cmd SEL, args ...interface{}) ID + objc_getClass func(name string) Class + objc_getProtocol func(name string) *Protocol + objc_allocateClassPair func(super Class, name string, extraBytes uintptr) Class + objc_registerClassPair func(class Class) + sel_registerName func(name string) SEL + class_getSuperclass func(class Class) Class + class_getInstanceVariable func(class Class, name string) Ivar + class_getInstanceSize func(class Class) uintptr + class_addMethod func(class Class, name SEL, imp IMP, types string) bool + class_addIvar func(class Class, name string, size uintptr, alignment uint8, types string) bool + class_addProtocol func(class Class, protocol *Protocol) bool + ivar_getOffset func(ivar Ivar) uintptr + ivar_getName func(ivar Ivar) string + object_getClass func(obj ID) Class + object_getIvar func(obj ID, ivar Ivar) ID + object_setIvar func(obj ID, ivar Ivar, value ID) + protocol_getName func(protocol *Protocol) string + protocol_isEqual func(p *Protocol, p2 *Protocol) bool ) func init() { @@ -53,6 +56,16 @@ func init() { if err != nil { panic(fmt.Errorf("objc: %w", err)) } + if runtime.GOARCH == "amd64" { + objc_msgSend_stret_fn, err = purego.Dlsym(objc, "objc_msgSend_stret") + if err != nil { + panic(fmt.Errorf("objc: %w", err)) + } + objc_msgSendSuper2_stret_fn, err = purego.Dlsym(objc, "objc_msgSendSuper2_stret") + if err != nil { + panic(fmt.Errorf("objc: %w", err)) + } + } purego.RegisterFunc(&objc_msgSend, objc_msgSend_fn) objc_msgSendSuper2_fn, err = purego.Dlsym(objc, "objc_msgSendSuper2") if err != nil { @@ -104,12 +117,22 @@ func (id ID) SetIvar(ivar Ivar, value ID) { object_setIvar(id, ivar, value) } +// keep in sync with func.go +const maxRegAllocStructSize = 16 + // Send is a convenience method for sending messages to objects that can return any type. // This function takes a SEL instead of a string since RegisterName grabs the global Objective-C lock. // It is best to cache the result of RegisterName. func Send[T any](id ID, sel SEL, args ...any) T { var fn func(id ID, sel SEL, args ...any) T - purego.RegisterFunc(&fn, objc_msgSend_fn) + var zero T + if runtime.GOARCH == "amd64" && + reflect.ValueOf(zero).Kind() == reflect.Struct && + reflect.ValueOf(zero).Type().Size() > maxRegAllocStructSize { + purego.RegisterFunc(&fn, objc_msgSend_stret_fn) + } else { + purego.RegisterFunc(&fn, objc_msgSend_fn) + } return fn(id, sel, args...) } @@ -141,7 +164,14 @@ func SendSuper[T any](id ID, sel SEL, args ...any) T { superClass: id.Class(), } var fn func(objcSuper *objc_super, sel SEL, args ...any) T - purego.RegisterFunc(&fn, objc_msgSendSuper2_fn) + var zero T + if runtime.GOARCH == "amd64" && + reflect.ValueOf(zero).Kind() == reflect.Struct && + reflect.ValueOf(zero).Type().Size() > maxRegAllocStructSize { + purego.RegisterFunc(&fn, objc_msgSendSuper2_stret_fn) + } else { + purego.RegisterFunc(&fn, objc_msgSendSuper2_fn) + } return fn(super, sel, args...) } diff --git a/objc/objc_runtime_darwin_test.go b/objc/objc_runtime_darwin_test.go index 55ebd9e7..ba2e7b70 100644 --- a/objc/objc_runtime_darwin_test.go +++ b/objc/objc_runtime_darwin_test.go @@ -117,3 +117,39 @@ func TestSend(t *testing.T) { t.Failed() } } + +func ExampleSend() { + type NSRange struct { + Location, Range uint + } + class_NSString := objc.GetClass("NSString") + sel_stringWithUTF8String := objc.RegisterName("stringWithUTF8String:") + + fullString := objc.ID(class_NSString).Send(sel_stringWithUTF8String, "Hello, World!\x00") + subString := objc.ID(class_NSString).Send(sel_stringWithUTF8String, "lo, Wor\x00") + + r := objc.Send[NSRange](fullString, objc.RegisterName("rangeOfString:"), subString) + + fmt.Println(r) + + // Output: {3 7} +} + +func ExampleSendSuper() { + super := objc.AllocateClassPair(objc.GetClass("NSObject"), "SuperObject2", 0) + super.AddMethod(objc.RegisterName("doSomething"), objc.NewIMP(func(self objc.ID, _cmd objc.SEL) int { + return 16 + }), "i@:") + super.Register() + + child := objc.AllocateClassPair(super, "ChildObject2", 0) + child.AddMethod(objc.RegisterName("doSomething"), objc.NewIMP(func(self objc.ID, _cmd objc.SEL) int { + return 24 + }), "i@:") + child.Register() + + res := objc.SendSuper[int](objc.ID(child).Send(objc.RegisterName("new")), objc.RegisterName("doSomething")) + + fmt.Println(res) + // Output: 16 +} diff --git a/struct_arm64.go b/struct_arm64.go index 009d817c..11c36bd6 100644 --- a/struct_arm64.go +++ b/struct_arm64.go @@ -16,17 +16,17 @@ func getStruct(outType reflect.Type, syscall syscall15Args) (v reflect.Value) { return reflect.New(outType).Elem() case outSize <= 8: r1 := syscall.a1 - if isAllSameFloat(outType) { + if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats { r1 = syscall.f1 - if outType.NumField() == 2 { + if numFields == 2 { r1 = syscall.f2<<32 | syscall.f1 } } return reflect.NewAt(outType, unsafe.Pointer(&struct{ a uintptr }{r1})).Elem() case outSize <= 16: r1, r2 := syscall.a1, syscall.a2 - if isAllSameFloat(outType) { - switch outType.NumField() { + if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats { + switch numFields { case 4: r1 = syscall.f2<<32 | syscall.f1 r2 = syscall.f4<<32 | syscall.f3 @@ -42,8 +42,8 @@ func getStruct(outType reflect.Type, syscall syscall15Args) (v reflect.Value) { } return reflect.NewAt(outType, unsafe.Pointer(&struct{ a, b uintptr }{r1, r2})).Elem() default: - if isAllSameFloat(outType) && outType.NumField() <= 4 { - switch outType.NumField() { + if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats && numFields <= 4 { + switch numFields { case 4: return reflect.NewAt(outType, unsafe.Pointer(&struct{ a, b, c, d uintptr }{syscall.f1, syscall.f2, syscall.f3, syscall.f4})).Elem() case 3: diff --git a/struct_test.go b/struct_test.go index 66ab30f3..0d612a0a 100644 --- a/struct_test.go +++ b/struct_test.go @@ -585,6 +585,18 @@ func TestRegisterFunc_structReturns(t *testing.T) { t.Fatalf("ReturnFourDoubles returned %+v wanted %+v", ret, expected) } } + { + type FourDoublesInternal struct { + f struct{ a, b float64 } + g struct{ c, d float64 } + } + var ReturnFourDoublesInternal func(a, b, c, d float64) FourDoublesInternal + purego.RegisterLibFunc(&ReturnFourDoublesInternal, lib, "ReturnFourDoublesInternal") + expected := FourDoublesInternal{f: struct{ a, b float64 }{a: 1, b: 2}, g: struct{ c, d float64 }{c: 3, d: 4}} + if ret := ReturnFourDoublesInternal(1, 2, 3, 4); ret != expected { + t.Fatalf("ReturnFourDoublesInternal returned %+v wanted %+v", ret, expected) + } + } { type FiveDoubles struct{ a, b, c, d, e float64 } var ReturnFiveDoubles func(a, b, c, d, e float64) FiveDoubles diff --git a/testdata/structtest/structreturn_test.c b/testdata/structtest/structreturn_test.c index 7b470d7a..c55a0f97 100644 --- a/testdata/structtest/structreturn_test.c +++ b/testdata/structtest/structreturn_test.c @@ -129,6 +129,20 @@ struct FourDoubles ReturnFourDoubles(double a, double b, double c, double d) { return e; } +struct FourDoublesInternal{ + struct { + double a, b; + } f; + struct { + double c, d; + } g; +}; + +struct FourDoublesInternal ReturnFourDoublesInternal(double a, double b, double c, double d) { + struct FourDoublesInternal e = { {a, b}, {c, d} }; + return e; +} + struct FiveDoubles{ double a, b, c, d, e; };