diff --git a/compiler/compile.go b/compiler/compile.go index 552db89..d47e00f 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -291,8 +291,29 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er } case *ast.FuncDecl: - color, ok := colorsByFunc[decl] - if !ok { + declcolor := colorsByFunc[decl] + + // Check for colored function literals in the body. If found, color the + // function decl so that it's compiled. + var err error + ast.Inspect(decl, func(n ast.Node) bool { + if f, ok := n.(*ast.FuncLit); ok { + if litcolor, ok := colorsByFunc[f]; ok { + if declcolor == nil { + declcolor = litcolor + colorsByFunc[decl] = declcolor + } else if !types.Identical(declcolor, litcolor) { + err = fmt.Errorf("function %s.%s has more than one color: %v vs. %v", p.Name, decl.Name, declcolor, litcolor) + } + } + } + return true + }) + if err != nil { + return err + } + + if declcolor == nil { gen.Decls = append(gen.Decls, decl) continue } @@ -302,7 +323,7 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er } scope := &scope{compiler: c, colors: colorsByFunc} - gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) + gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, declcolor)) } } diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 412244a..112d149 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -3174,6 +3174,64 @@ func varArgs(_fn0 ...int) { } } } + +//go:noinline +func YieldFromClosure(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 func() + } = coroutine.Push[struct { + IP int + X0 int + X1 func() + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 func() + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = closure(_f0.X0) + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X1() + } +} + +//go:noinline +func closure(_fn0 int) (_ func()) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + } = coroutine.Push[struct { + IP int + X0 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + return func() { coroutine.Yield[int, any](_f0.X0) } +} func init() { _types.RegisterFunc[func(n int)]("github.com/stealthrocket/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.EvenSquareGenerator") @@ -3289,6 +3347,7 @@ func init() { X3 []func() } }]("github.com/stealthrocket/coroutine/compiler/testdata.YieldAndDeferAssign.func2") + _types.RegisterFunc[func(_fn0 int)]("github.com/stealthrocket/coroutine/compiler/testdata.YieldFromClosure") _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.YieldingDurations") _types.RegisterClosure[func(), struct { F uintptr @@ -3303,5 +3362,13 @@ func init() { _types.RegisterFunc[func()]("github.com/stealthrocket/coroutine/compiler/testdata.YieldingExpressionDesugaring") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.a") _types.RegisterFunc[func(_fn0 int) (_ int)]("github.com/stealthrocket/coroutine/compiler/testdata.b") + _types.RegisterFunc[func(_fn0 int) (_ func())]("github.com/stealthrocket/coroutine/compiler/testdata.closure") + _types.RegisterClosure[func(), struct { + F uintptr + X0 *struct { + IP int + X0 int + } + }]("github.com/stealthrocket/coroutine/compiler/testdata.closure.func2") _types.RegisterFunc[func(_fn0 ...int)]("github.com/stealthrocket/coroutine/compiler/testdata.varArgs") }