Skip to content

Commit

Permalink
all: struct return fixes (#221)
Browse files Browse the repository at this point in the history
On amd64, there was a crash when calling objc.Send with a struct return type. This was due to calling the wrong objc message send function.

On arm64, if a struct return contained structs that only contained floats it would expect it to be in R8 instead of the expected float registers

Closes #223
  • Loading branch information
TotallyGamerJet authored and hajimehoshi committed Apr 4, 2024
1 parent 4889c1a commit ff2c2cc
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 38 deletions.
31 changes: 23 additions & 8 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,8 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
keepAlive = append(keepAlive, val)
addInt(val.Pointer())
} else if runtime.GOARCH == "arm64" && outType.Size() > maxRegAllocStructSize {
if !isAllSameFloat(outType) || outType.NumField() > 4 {
isAllFloats, numFields := isAllSameFloat(outType)
if !isAllFloats || numFields > 4 {
val := reflect.New(outType)
keepAlive = append(keepAlive, val)
syscall.arm64_r8 = val.Pointer()
Expand Down Expand Up @@ -351,20 +352,34 @@ func RegisterFunc(fptr interface{}, cfn uintptr) {
// maxRegAllocStructSize is the biggest a struct can be while still fitting in registers.
// if it is bigger than this than enough space must be allocated on the heap and then passed into
// the function as the first parameter on amd64 or in R8 on arm64.
//
// If you change this make sure to update it in objc_runtime_darwin.go
const maxRegAllocStructSize = 16

func isAllSameFloat(ty reflect.Type) bool {
first := ty.Field(0).Type.Kind()
func isAllSameFloat(ty reflect.Type) (allFloats bool, numFields int) {
allFloats = true
root := ty.Field(0).Type
for root.Kind() == reflect.Struct {
root = root.Field(0).Type
}
first := root.Kind()
if first != reflect.Float32 && first != reflect.Float64 {
return false
allFloats = false
}
for i := 0; i < ty.NumField(); i++ {
f := ty.Field(i)
if f.Type.Kind() != first {
return false
f := ty.Field(i).Type
if f.Kind() == reflect.Struct {
var structNumFields int
allFloats, structNumFields = isAllSameFloat(f)
numFields += structNumFields
continue
}
numFields++
if f.Kind() != first {
allFloats = false
}
}
return true
return allFloats, numFields
}

func checkStructFieldsSupported(ty reflect.Type) {
Expand Down
78 changes: 54 additions & 24 deletions objc/objc_runtime_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"math"
"reflect"
"regexp"
"runtime"
"unicode"
"unsafe"

Expand All @@ -20,28 +21,30 @@ import (
// TODO: support try/catch?
// https://stackoverflow.com/questions/7062599/example-of-how-objective-cs-try-catch-implementation-is-executed-at-runtime
var (
objc_msgSend_fn uintptr
objc_msgSend func(obj ID, cmd SEL, args ...interface{}) ID
objc_msgSendSuper2_fn uintptr
objc_msgSendSuper2 func(super *objc_super, cmd SEL, args ...interface{}) ID
objc_getClass func(name string) Class
objc_getProtocol func(name string) *Protocol
objc_allocateClassPair func(super Class, name string, extraBytes uintptr) Class
objc_registerClassPair func(class Class)
sel_registerName func(name string) SEL
class_getSuperclass func(class Class) Class
class_getInstanceVariable func(class Class, name string) Ivar
class_getInstanceSize func(class Class) uintptr
class_addMethod func(class Class, name SEL, imp IMP, types string) bool
class_addIvar func(class Class, name string, size uintptr, alignment uint8, types string) bool
class_addProtocol func(class Class, protocol *Protocol) bool
ivar_getOffset func(ivar Ivar) uintptr
ivar_getName func(ivar Ivar) string
object_getClass func(obj ID) Class
object_getIvar func(obj ID, ivar Ivar) ID
object_setIvar func(obj ID, ivar Ivar, value ID)
protocol_getName func(protocol *Protocol) string
protocol_isEqual func(p *Protocol, p2 *Protocol) bool
objc_msgSend_fn uintptr
objc_msgSend_stret_fn uintptr
objc_msgSend func(obj ID, cmd SEL, args ...interface{}) ID
objc_msgSendSuper2_fn uintptr
objc_msgSendSuper2_stret_fn uintptr
objc_msgSendSuper2 func(super *objc_super, cmd SEL, args ...interface{}) ID
objc_getClass func(name string) Class
objc_getProtocol func(name string) *Protocol
objc_allocateClassPair func(super Class, name string, extraBytes uintptr) Class
objc_registerClassPair func(class Class)
sel_registerName func(name string) SEL
class_getSuperclass func(class Class) Class
class_getInstanceVariable func(class Class, name string) Ivar
class_getInstanceSize func(class Class) uintptr
class_addMethod func(class Class, name SEL, imp IMP, types string) bool
class_addIvar func(class Class, name string, size uintptr, alignment uint8, types string) bool
class_addProtocol func(class Class, protocol *Protocol) bool
ivar_getOffset func(ivar Ivar) uintptr
ivar_getName func(ivar Ivar) string
object_getClass func(obj ID) Class
object_getIvar func(obj ID, ivar Ivar) ID
object_setIvar func(obj ID, ivar Ivar, value ID)
protocol_getName func(protocol *Protocol) string
protocol_isEqual func(p *Protocol, p2 *Protocol) bool
)

func init() {
Expand All @@ -53,6 +56,16 @@ func init() {
if err != nil {
panic(fmt.Errorf("objc: %w", err))
}
if runtime.GOARCH == "amd64" {
objc_msgSend_stret_fn, err = purego.Dlsym(objc, "objc_msgSend_stret")
if err != nil {
panic(fmt.Errorf("objc: %w", err))
}
objc_msgSendSuper2_stret_fn, err = purego.Dlsym(objc, "objc_msgSendSuper2_stret")
if err != nil {
panic(fmt.Errorf("objc: %w", err))
}
}
purego.RegisterFunc(&objc_msgSend, objc_msgSend_fn)
objc_msgSendSuper2_fn, err = purego.Dlsym(objc, "objc_msgSendSuper2")
if err != nil {
Expand Down Expand Up @@ -104,12 +117,22 @@ func (id ID) SetIvar(ivar Ivar, value ID) {
object_setIvar(id, ivar, value)
}

// keep in sync with func.go
const maxRegAllocStructSize = 16

// Send is a convenience method for sending messages to objects that can return any type.
// This function takes a SEL instead of a string since RegisterName grabs the global Objective-C lock.
// It is best to cache the result of RegisterName.
func Send[T any](id ID, sel SEL, args ...any) T {
var fn func(id ID, sel SEL, args ...any) T
purego.RegisterFunc(&fn, objc_msgSend_fn)
var zero T
if runtime.GOARCH == "amd64" &&
reflect.ValueOf(zero).Kind() == reflect.Struct &&
reflect.ValueOf(zero).Type().Size() > maxRegAllocStructSize {
purego.RegisterFunc(&fn, objc_msgSend_stret_fn)
} else {
purego.RegisterFunc(&fn, objc_msgSend_fn)
}
return fn(id, sel, args...)
}

Expand Down Expand Up @@ -141,7 +164,14 @@ func SendSuper[T any](id ID, sel SEL, args ...any) T {
superClass: id.Class(),
}
var fn func(objcSuper *objc_super, sel SEL, args ...any) T
purego.RegisterFunc(&fn, objc_msgSendSuper2_fn)
var zero T
if runtime.GOARCH == "amd64" &&
reflect.ValueOf(zero).Kind() == reflect.Struct &&
reflect.ValueOf(zero).Type().Size() > maxRegAllocStructSize {
purego.RegisterFunc(&fn, objc_msgSendSuper2_stret_fn)
} else {
purego.RegisterFunc(&fn, objc_msgSendSuper2_fn)
}
return fn(super, sel, args...)
}

Expand Down
36 changes: 36 additions & 0 deletions objc/objc_runtime_darwin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,39 @@ func TestSend(t *testing.T) {
t.Failed()
}
}

func ExampleSend() {
type NSRange struct {
Location, Range uint
}
class_NSString := objc.GetClass("NSString")
sel_stringWithUTF8String := objc.RegisterName("stringWithUTF8String:")

fullString := objc.ID(class_NSString).Send(sel_stringWithUTF8String, "Hello, World!\x00")
subString := objc.ID(class_NSString).Send(sel_stringWithUTF8String, "lo, Wor\x00")

r := objc.Send[NSRange](fullString, objc.RegisterName("rangeOfString:"), subString)

fmt.Println(r)

// Output: {3 7}
}

func ExampleSendSuper() {
super := objc.AllocateClassPair(objc.GetClass("NSObject"), "SuperObject2", 0)
super.AddMethod(objc.RegisterName("doSomething"), objc.NewIMP(func(self objc.ID, _cmd objc.SEL) int {
return 16
}), "i@:")
super.Register()

child := objc.AllocateClassPair(super, "ChildObject2", 0)
child.AddMethod(objc.RegisterName("doSomething"), objc.NewIMP(func(self objc.ID, _cmd objc.SEL) int {
return 24
}), "i@:")
child.Register()

res := objc.SendSuper[int](objc.ID(child).Send(objc.RegisterName("new")), objc.RegisterName("doSomething"))

fmt.Println(res)
// Output: 16
}
12 changes: 6 additions & 6 deletions struct_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ func getStruct(outType reflect.Type, syscall syscall15Args) (v reflect.Value) {
return reflect.New(outType).Elem()
case outSize <= 8:
r1 := syscall.a1
if isAllSameFloat(outType) {
if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats {
r1 = syscall.f1
if outType.NumField() == 2 {
if numFields == 2 {
r1 = syscall.f2<<32 | syscall.f1
}
}
return reflect.NewAt(outType, unsafe.Pointer(&struct{ a uintptr }{r1})).Elem()
case outSize <= 16:
r1, r2 := syscall.a1, syscall.a2
if isAllSameFloat(outType) {
switch outType.NumField() {
if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats {
switch numFields {
case 4:
r1 = syscall.f2<<32 | syscall.f1
r2 = syscall.f4<<32 | syscall.f3
Expand All @@ -42,8 +42,8 @@ func getStruct(outType reflect.Type, syscall syscall15Args) (v reflect.Value) {
}
return reflect.NewAt(outType, unsafe.Pointer(&struct{ a, b uintptr }{r1, r2})).Elem()
default:
if isAllSameFloat(outType) && outType.NumField() <= 4 {
switch outType.NumField() {
if isAllFloats, numFields := isAllSameFloat(outType); isAllFloats && numFields <= 4 {
switch numFields {
case 4:
return reflect.NewAt(outType, unsafe.Pointer(&struct{ a, b, c, d uintptr }{syscall.f1, syscall.f2, syscall.f3, syscall.f4})).Elem()
case 3:
Expand Down
12 changes: 12 additions & 0 deletions struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,18 @@ func TestRegisterFunc_structReturns(t *testing.T) {
t.Fatalf("ReturnFourDoubles returned %+v wanted %+v", ret, expected)
}
}
{
type FourDoublesInternal struct {
f struct{ a, b float64 }
g struct{ c, d float64 }
}
var ReturnFourDoublesInternal func(a, b, c, d float64) FourDoublesInternal
purego.RegisterLibFunc(&ReturnFourDoublesInternal, lib, "ReturnFourDoublesInternal")
expected := FourDoublesInternal{f: struct{ a, b float64 }{a: 1, b: 2}, g: struct{ c, d float64 }{c: 3, d: 4}}
if ret := ReturnFourDoublesInternal(1, 2, 3, 4); ret != expected {
t.Fatalf("ReturnFourDoublesInternal returned %+v wanted %+v", ret, expected)
}
}
{
type FiveDoubles struct{ a, b, c, d, e float64 }
var ReturnFiveDoubles func(a, b, c, d, e float64) FiveDoubles
Expand Down
14 changes: 14 additions & 0 deletions testdata/structtest/structreturn_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,20 @@ struct FourDoubles ReturnFourDoubles(double a, double b, double c, double d) {
return e;
}

struct FourDoublesInternal{
struct {
double a, b;
} f;
struct {
double c, d;
} g;
};

struct FourDoublesInternal ReturnFourDoublesInternal(double a, double b, double c, double d) {
struct FourDoublesInternal e = { {a, b}, {c, d} };
return e;
}

struct FiveDoubles{
double a, b, c, d, e;
};
Expand Down

0 comments on commit ff2c2cc

Please sign in to comment.