diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index d6e92c8..a278da2 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -215,6 +215,12 @@ func TestCoroutineYield(t *testing.T) { result: 42, }, + { + name: "closure capturing receiver and param", + coro: func() { StructClosure(3) }, + yields: []int{-1, 10, 100, 1000, 11, 101, 1000, 12, 102, 1000}, + }, + { name: "identity generic", coro: func() { IdentityGeneric[int](11) }, diff --git a/compiler/function.go b/compiler/function.go index e468bd0..7a4a4f8 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -9,6 +9,7 @@ import ( "maps" "slices" "strconv" + "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" @@ -82,7 +83,8 @@ func collectFunctypes(p *packages.Package, name string, fn ast.Node, scope *func } signature := functionTypeOf(fn) - for _, fields := range []*ast.FieldList{signature.Params, signature.Results} { + recv := functionRecvOf(fn) + for _, fields := range []*ast.FieldList{recv, signature.Params, signature.Results} { if fields != nil { for _, field := range fields.List { for _, name := range field.Names { @@ -205,7 +207,33 @@ func packagePath(p *packages.Package) string { } func functionPath(p *packages.Package, f *ast.FuncDecl) string { - return packagePath(p) + "." + f.Name.Name + var b strings.Builder + b.WriteString(packagePath(p)) + if f.Recv != nil { + signature := p.TypesInfo.Defs[f.Name].Type().(*types.Signature) + recvType := signature.Recv().Type() + isptr := false + if ptr, ok := recvType.(*types.Pointer); ok { + recvType = ptr.Elem() + isptr = true + } + b.WriteByte('.') + if isptr { + b.WriteString("(*") + } + switch t := recvType.(type) { + case *types.Named: + b.WriteString(t.Obj().Name()) + default: + panic(fmt.Sprintf("not implemented: %T", t)) + } + if isptr { + b.WriteByte(')') + } + } + b.WriteByte('.') + b.WriteString(f.Name.Name) + return b.String() } func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors map[ast.Node]*types.Signature) { @@ -302,6 +330,17 @@ func functionTypeOf(fn ast.Node) *ast.FuncType { } } +func functionRecvOf(fn ast.Node) *ast.FieldList { + switch f := fn.(type) { + case *ast.FuncDecl: + return f.Recv + case *ast.FuncLit: + return nil + default: + panic("node is neither *ast.FuncDecl or *ast.FuncLit") + } +} + func functionBodyOf(fn ast.Node) *ast.BlockStmt { switch f := fn.(type) { case *ast.FuncDecl: diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 92d7218..97fc7fd 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -577,3 +577,30 @@ func IdentityGenericInt(n int) { func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + +type Box struct { + x int +} + +func (b *Box) Closure(y int) func(int) { + // Force compilation of this method and the closure within. + // Remove once #84 is fixed. + coroutine.Yield[int, any](-1) + + return func(z int) { + coroutine.Yield[int, any](b.x) + coroutine.Yield[int, any](y) + coroutine.Yield[int, any](z) + b.x++ + y++ + z++ // mutation is lost + } +} + +func StructClosure(n int) { + box := Box{10} + fn := box.Closure(100) + for i := 0; i < n; i++ { + fn(1000) + } +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index f13ce47..0e4b610 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3252,7 +3252,156 @@ func IdentityGenericInt(n int) { IdentityGeneric[int](n) } //go:noinline func IdentityGenericStructInt(n int) { (&IdentityGenericStruct[int]{n: n}).Run() } + +type Box struct { + x int +} + +//go:noinline +func (_fn0 *Box) Closure(_fn1 int) (_ func(int)) { + _c := coroutine.LoadContext[int, any]() + var _f1 *struct { + IP int + X0 *Box + X1 int + } = coroutine.Push[struct { + IP int + X0 *Box + X1 int + }](&_c.Stack) + if _f1.IP == 0 { + *_f1 = struct { + IP int + X0 *Box + X1 int + }{X0: _fn0, X1: _fn1} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f1.IP < 2: + + coroutine.Yield[int, any](-1) + _f1.IP = 2 + fallthrough + case _f1.IP < 3: + + return func(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + } = coroutine.Push[struct { + IP int + X0 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + coroutine.Yield[int, any](_f1.X0.x) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + coroutine.Yield[int, any](_f1.X1) + _f0.IP = 3 + fallthrough + case _f0.IP < 4: + coroutine.Yield[int, any](_f0.X0) + _f0.IP = 4 + fallthrough + case _f0.IP < 5: + _f1.X0. + x++ + _f0.IP = 5 + fallthrough + case _f0.IP < 6: + _f1.X1++ + _f0.IP = 6 + fallthrough + case _f0.IP < 7: + _f0.X0++ + } + } + } + panic("unreachable") +} + +//go:noinline +func StructClosure(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 Box + X2 func(int) + X3 int + } = coroutine.Push[struct { + IP int + X0 int + X1 Box + X2 func(int) + X3 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 Box + X2 func(int) + X3 int + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = Box{10} + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X2 = _f0.X1.Closure(100) + _f0.IP = 3 + fallthrough + case _f0.IP < 5: + switch { + case _f0.IP < 4: + _f0.X3 = 0 + _f0.IP = 4 + fallthrough + case _f0.IP < 5: + for ; _f0.X3 < _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 { + _f0.X2(1000) + } + } + } +} func init() { + _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure") + _types.RegisterClosure[func(_fn0 int), struct { + F uintptr + X0 *struct { + IP int + X0 *Box + X1 int + } + }]("github.com/stealthrocket/coroutine/compiler/testdata.(*Box).Closure.func2") + _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.FizzBuzzIfGenerator") @@ -3261,7 +3410,6 @@ func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericInt") _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.IdentityGenericStructInt") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.LoopBreakAndContinue") - _types.RegisterFunc[func(_fn1 int)]("github.com/stealthrocket/coroutine/compiler/testdata.MethodGenerator") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.NestedLoops") _types.RegisterFunc[func(_fn0 int, _fn1 func(int))]("github.com/stealthrocket/coroutine/compiler/testdata.Range") _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.Range10ClosureCapturingPointers") @@ -3357,6 +3505,7 @@ func init() { _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwice") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.SquareGeneratorTwiceLoop") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.StructClosure") _types.RegisterFunc[func(_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.TypeSwitchingGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.VarArgs") _types.RegisterFunc[func(_fn0 *int, _fn1, _fn2 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign")