Skip to content

Commit

Permalink
coroc: capture function receivers in closure types (#129)
Browse files Browse the repository at this point in the history
This fixes a bug in the generation of closure types where function
receivers weren't considered.
  • Loading branch information
chriso authored Dec 14, 2023
2 parents 8fd653a + 09278ad commit 83d5568
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 3 deletions.
6 changes: 6 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ func TestCoroutineYield(t *testing.T) {
result: 42,
},

{
name: "closure capturing receiver and param",
coro: func() { StructClosure(3) },
yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000},
},

{
name: "identity generic",
coro: func() { IdentityGeneric[int](11) },
Expand Down
43 changes: 41 additions & 2 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"maps"
"slices"
"strconv"
"strings"

"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
Expand Down Expand Up @@ -82,7 +83,8 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func
}

signature := functionTypeOf(fn)
for _, fields := range []*ast.FieldList{signature.Params, signature.Results} {
recv := functionRecvOf(fn)
for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} {
if fields != nil {
for _, field := range fields.List {
for _, name := range field.Names {
Expand Down Expand Up @@ -205,7 +207,33 @@ func packagePath(p *packages.Package) string {
}

func functionPath(p *packages.Package, f *ast.FuncDecl) string {
return packagePath(p) + "." + f.Name.Name
var b strings.Builder
b.WriteString(packagePath(p))
if f.Recv != nil {
signature := p.TypesInfo.Defs[f.Name].Type().(*types.Signature)
recvType := signature.Recv().Type()
isptr := false
if ptr, ok := recvType.(*types.Pointer); ok {
recvType = ptr.Elem()
isptr = true
}
b.WriteByte('.')
if isptr {
b.WriteString("(*")
}
switch t := recvType.(type) {
case *types.Named:
b.WriteString(t.Obj().Name())
default:
panic(fmt.Sprintf("not implemented: %T", t))
}
if isptr {
b.WriteByte(')')
}
}
b.WriteByte('.')
b.WriteString(f.Name.Name)
return b.String()
}

func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) {
Expand Down Expand Up @@ -302,6 +330,17 @@ func functionTypeOf(fn ast.Node) *ast.FuncType {
}
}

func functionRecvOf(fn ast.Node) *ast.FieldList {
switch f := fn.(type) {
case *ast.FuncDecl:
return f.Recv
case *ast.FuncLit:
return nil
default:
panic("node is neither *ast.FuncDecl or *ast.FuncLit")
}
}

func functionBodyOf(fn ast.Node) *ast.BlockStmt {
switch f := fn.(type) {
case *ast.FuncDecl:
Expand Down
27 changes: 27 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,3 +577,30 @@ func IdentityGenericInt(n int) {
func IdentityGenericStructInt(n int) {
(&IdentityGenericStruct[int]{n: n}).Run()
}

type Box struct {
x int
}

func (b *Box) Closure(y int) func(int) {
// Force compilation of this method and the closure within.
// Remove once #84 is fixed.
coroutine.Yield[int, any](-1)

return func(z int) {
coroutine.Yield[int, any](b.x)
coroutine.Yield[int, any](y)
coroutine.Yield[int, any](z)
b.x++
y++
z++ // mutation is lost
}
}

func StructClosure(n int) {
box := Box{10}
fn := box.Closure(100)
for i := 0; i < n; i++ {
fn(1000)
}
}
151 changes: 150 additions & 1 deletion compiler/testdata/coroutine_durable.go
Original file line number Diff line number Diff line change
Expand Up @@ -3252,7 +3252,156 @@ func IdentityGenericInt(n int) { IdentityGeneric[int](n) }

//go:noinline
func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() }

type Box struct {
x int
}

//go:noinline
func (_fn0 *Box) Closure(_fn1 int) (_ func(int)) {
_c := coroutine.LoadContext[int, any]()
var _f1 *struct {
IP int
X0 *Box
X1 int
} = coroutine.Push[struct {
IP int
X0 *Box
X1 int
}](&_c.Stack)
if _f1.IP == 0 {
*_f1 = struct {
IP int
X0 *Box
X1 int
}{X0: _fn0, X1: _fn1}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f1.IP < 2:

coroutine.Yield[int, any](-1)
_f1.IP = 2
fallthrough
case _f1.IP < 3:

return func(_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
} = coroutine.Push[struct {
IP int
X0 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
}{X0: _fn0}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f0.IP < 2:
coroutine.Yield[int, any](_f1.X0.x)
_f0.IP = 2
fallthrough
case _f0.IP < 3:
coroutine.Yield[int, any](_f1.X1)
_f0.IP = 3
fallthrough
case _f0.IP < 4:
coroutine.Yield[int, any](_f0.X0)
_f0.IP = 4
fallthrough
case _f0.IP < 5:
_f1.X0.
x++
_f0.IP = 5
fallthrough
case _f0.IP < 6:
_f1.X1++
_f0.IP = 6
fallthrough
case _f0.IP < 7:
_f0.X0++
}
}
}
panic("unreachable")
}

//go:noinline
func StructClosure(_fn0 int) {
_c := coroutine.LoadContext[int, any]()
var _f0 *struct {
IP int
X0 int
X1 Box
X2 func(int)
X3 int
} = coroutine.Push[struct {
IP int
X0 int
X1 Box
X2 func(int)
X3 int
}](&_c.Stack)
if _f0.IP == 0 {
*_f0 = struct {
IP int
X0 int
X1 Box
X2 func(int)
X3 int
}{X0: _fn0}
}
defer func() {
if !_c.Unwinding() {
coroutine.Pop(&_c.Stack)
}
}()
switch {
case _f0.IP < 2:
_f0.X1 = Box{10}
_f0.IP = 2
fallthrough
case _f0.IP < 3:
_f0.X2 = _f0.X1.Closure(100)
_f0.IP = 3
fallthrough
case _f0.IP < 5:
switch {
case _f0.IP < 4:
_f0.X3 = 0
_f0.IP = 4
fallthrough
case _f0.IP < 5:
for ; _f0.X3 < _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 {
_f0.X2(1000)
}
}
}
}
func init() {
_types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure")
_types.RegisterClosure[func(_fn0 int), struct {
F uintptr
X0 *struct {
IP int
X0 *Box
X1 int
}
}]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2")
_types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator")
Expand All @@ -3261,7 +3410,6 @@ func init() {
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt")
_types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue")
_types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator")
_types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops")
_types.RegisterFunc[func(_fn0 int, _fn1 func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range")
_types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers")
Expand Down Expand Up @@ -3357,6 +3505,7 @@ func init() {
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure")
_types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator")
_types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.VarArgs")
_types.RegisterFunc[func(_fn0 *int, _fn1, _fn2 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign")
Expand Down

0 comments on commit 83d5568

Please sign in to comment.