Skip to content

Commit

Permalink
Syscall Sysv: Add support for unreferencing callbacks
Browse files Browse the repository at this point in the history
Also deduplicate callbacks by function pointer.

Function equality in Go is problematic, which is why functions are not
comparable in the language spec or via reflect.Value.Equal().

For our particular use-case though I believe that comparing function
pointers is good enough, since we don't need to differentiate closures
based on scope variable content.
  • Loading branch information
pdf committed Jan 3, 2024
1 parent 57b69bd commit 78d550b
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 16 deletions.
57 changes: 57 additions & 0 deletions callback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,60 @@ func TestNewCallbackFloat32AndFloat64(t *testing.T) {
t.Errorf("cbTotalF64 not correct got %f but wanted %f", cbTotalF64, expectedCbTotalF64)
}
}

func TestCallbackDedup(t *testing.T) {
imp := func() bool {
return true
}

ref1 := purego.NewCallback(imp)
ref2 := purego.NewCallback(imp)

if ref1 != ref2 {
t.Errorf("deduplicate expected %d to equal %d", ref1, ref2)
}
}

func TestUnrefCallback(t *testing.T) {
imp := func() bool {
return true
}

if err := purego.UnrefCallback(imp); err == nil {
t.Errorf("unref of unknown callback produced nil but wanted error")
}

ref := purego.NewCallback(imp)

if err := purego.UnrefCallback(imp); err != nil {
t.Errorf("unref produced %v but wanted nil", err)
}
if err := purego.UnrefCallback(imp); err == nil {
t.Errorf("unref of already unref'd callback produced nil but wanted error")
}
if err := purego.UnrefCallbackPtr(ref); err == nil {
t.Errorf("unref of already unref'd callback ptr produced nil but wanted error")
}
}

func TestUnrefCallbackPtr(t *testing.T) {
imp := func() bool {
return true
}

if err := purego.UnrefCallbackPtr(0); err == nil {
t.Errorf("unref of unknown callback produced nil but wanted error")
}

ref := purego.NewCallback(imp)

if err := purego.UnrefCallbackPtr(ref); err != nil {
t.Errorf("callback unref produced %v but wanted nil", err)
}
if err := purego.UnrefCallbackPtr(ref); err == nil {
t.Errorf("callback unref of already unref'd callback produced nil but wanted error")
}
if err := purego.UnrefCallback(imp); err == nil {
t.Errorf("unref of already unref'd callback value produced nil but wanted error")
}
}
108 changes: 92 additions & 16 deletions syscall_sysv.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package purego

import (
"errors"
"reflect"
"runtime"
"sync"
Expand Down Expand Up @@ -33,12 +34,52 @@ func syscall_syscall9X(fn, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2,
return args.r1, args.r2, args.err
}

// UnrefCallback unreferences the associated callback (created by NewCallback) by reflect.Value.
func UnrefCallback(cb any) error {
val := reflect.ValueOf(cb)
if val.Kind() != reflect.Func {
panic("purego: the type must be a function but was not")
}
if val.IsNil() {
panic("purego: function must not be nil")
}

cbs.lock.Lock()
defer cbs.lock.Unlock()
ptr, ok := cbs.knownFn[val.Pointer()]
if !ok {
return errors.New(`callback not found`)
}
idx := cbs.knownIdx[ptr]
delete(cbs.knownFn, val.Pointer())
delete(cbs.knownIdx, ptr)
cbs.holes[idx] = struct{}{}
cbs.funcs[idx] = reflect.Value{}
return nil
}

// UnrefCallbackPtr unreferences the associated callback (created by NewCallback) by pointer.
func UnrefCallbackPtr(cb uintptr) error {
cbs.lock.Lock()
defer cbs.lock.Unlock()
idx, ok := cbs.knownIdx[cb]
if !ok {
return errors.New(`callback not found`)
}
val := cbs.funcs[idx]
delete(cbs.knownFn, val.Pointer())
delete(cbs.knownIdx, cb)
cbs.holes[idx] = struct{}{}
cbs.funcs[idx] = reflect.Value{}
return nil
}

// NewCallback converts a Go function to a function pointer conforming to the C calling convention.
// This is useful when interoperating with C code requiring callbacks. The argument is expected to be a
// function with zero or one uintptr-sized result. The function must not have arguments with size larger than the size
// of uintptr. Only a limited number of callbacks may be created in a single Go process, and any memory allocated
// for these callbacks is never released. At least 2000 callbacks can always be created. Although this function
// provides similar functionality to windows.NewCallback it is distinct.
// of uintptr. Only a limited number of callbacks may be live in a single Go process, and any memory allocated
// for these callbacks is not released until CallbackUnref is called. At least 2000 callbacks can always be live.
// Although this function provides similar functionality to windows.NewCallback it is distinct.
func NewCallback(fn interface{}) uintptr {
return compileCallback(fn)
}
Expand All @@ -47,10 +88,22 @@ func NewCallback(fn interface{}) uintptr {
// only increase this if you have added more to the callbackasm function
const maxCB = 2000

var cbs struct {
lock sync.Mutex
numFn int // the number of functions currently in cbs.funcs
funcs [maxCB]reflect.Value // the saved callbacks
var cbs = struct {
lock sync.RWMutex
holes map[int]struct{}
funcs [maxCB]reflect.Value // the saved callbacks
knownFn map[uintptr]uintptr
knownIdx map[uintptr]int
}{
holes: make(map[int]struct{}),
knownFn: make(map[uintptr]uintptr),
knownIdx: make(map[uintptr]int),
}

func init() {
for i := 0; i < maxCB; i++ {
cbs.holes[i] = struct{}{}
}
}

type callbackArgs struct {
Expand All @@ -77,6 +130,14 @@ func compileCallback(fn interface{}) uintptr {
if val.IsNil() {
panic("purego: function must not be nil")
}

cbs.lock.RLock()
if cb, ok := cbs.knownFn[val.Pointer()]; ok {
cbs.lock.RUnlock()
return cb
}
cbs.lock.RUnlock()

ty := val.Type()
for i := 0; i < ty.NumIn(); i++ {
in := ty.In(i)
Expand All @@ -100,14 +161,21 @@ output:
case ty.NumOut() > 1:
panic("purego: callbacks can only have one return")
}

cbs.lock.Lock()
defer cbs.lock.Unlock()
if cbs.numFn >= maxCB {
if len(cbs.holes) == 0 {
panic("purego: the maximum number of callbacks has been reached")
}
cbs.funcs[cbs.numFn] = val
cbs.numFn++
return callbackasmAddr(cbs.numFn - 1)

var idx int
for i := range cbs.holes {
idx = i
break
}
delete(cbs.holes, idx)
cbs.funcs[idx] = val
cbs.lock.Unlock()
return callbackasmAddr(idx, val)
}

const ptrSize = unsafe.Sizeof((*int)(nil))
Expand All @@ -129,9 +197,12 @@ var callbackWrap_call = callbackWrap
// callbackWrap is called by assembly code which determines which Go function to call.
// This function takes the arguments and passes them to the Go function and returns the result.
func callbackWrap(a *callbackArgs) {
cbs.lock.Lock()
cbs.lock.RLock()
fn := cbs.funcs[a.index]
cbs.lock.Unlock()
cbs.lock.RUnlock()
if !fn.IsValid() {
panic("purego: attempted call to unreferenced callback")
}
fnType := fn.Type()
args := make([]reflect.Value, fnType.NumIn())
frame := (*[callbackMaxFrame]uintptr)(a.args)
Expand Down Expand Up @@ -202,7 +273,7 @@ func callbackWrap(a *callbackArgs) {
// On ARM, runtime.callbackasm is a series of mov and branch instructions.
// R12 is loaded with the callback index. Each entry is two instructions,
// hence 8 bytes.
func callbackasmAddr(i int) uintptr {
func callbackasmAddr(i int, val reflect.Value) uintptr {
var entrySize int
switch runtime.GOARCH {
default:
Expand All @@ -214,5 +285,10 @@ func callbackasmAddr(i int) uintptr {
// followed by a branch instruction
entrySize = 8
}
return callbackasmABI0 + uintptr(i*entrySize)
addr := callbackasmABI0 + uintptr(i*entrySize)
cbs.lock.Lock()
cbs.knownFn[val.Pointer()] = addr
cbs.knownIdx[addr] = i
cbs.lock.Unlock()
return addr
}

0 comments on commit 78d550b

Please sign in to comment.