diff --git a/compiler/compile.go b/compiler/compile.go index 8efe8e1..2c7c756 100644 --- a/compiler/compile.go +++ b/compiler/compile.go @@ -214,6 +214,25 @@ func (c *compiler) compile(path string) error { pkgColors[fn] = color } + // Add all packages from the module. Although these packages don't contain + // yield points, they may return closures that need to be serialized. For + // this to work, certain functions need to be marked as noinline and function + // literal types need to be registered. + // + // TODO: improve this by scanning dependencies to see if they need to be included + packages.Visit(pkgs, func(p *packages.Package) bool { + if p.Module == nil || p.Module.Dir != moduleDir { + return true + } + if p.PkgPath == coroutinePackage { + return true + } + if _, ok := colorsByPkg[p]; !ok { + colorsByPkg[p] = functionColors{} + } + return true + }, nil) + if c.onlyListFiles { cwd, _ := os.Getwd() for pkg := range colorsByPkg { @@ -356,16 +375,43 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er case *ast.FuncDecl: color := colorsByFunc[decl] - if color == nil && !containsColoredFuncLit(decl, colorsByFunc) { - gen.Decls = append(gen.Decls, decl) - continue + + compiled := false + if color != nil || containsColoredFuncLit(decl, colorsByFunc) { + // Reject certain language features for now. + if err := unsupported(decl, p.TypesInfo); err != nil { + return err + } + scope := &scope{compiler: c, colors: colorsByFunc} + decl = scope.compileFuncDecl(p, decl, color) + compiled = true } - // Reject certain language features for now. - if err := unsupported(decl, p.TypesInfo); err != nil { - return err + + if compiled || containsFuncLit(decl) { + // If the function declaration contains function literals, we have to + // add the //go:noinline copmiler directive to prevent inlining or the + // resulting symbol name generated by the linker wouldn't match the + // predictions made in generateFunctypes. + // + // When functions are inlined, the linker creates a unique name + // combining the symbol name of the calling function and the symbol name + // of the closure. Knowing which functions will be inlined is difficult + // considering the score-base mechansim that Go uses and alterations + // like PGO, therefore we take the simple approach of disabling inlining + // instead. + // + // Note that we only need to do this for single-expression functions as + // otherwise the presence of a defer statement to unwind the coroutine + // already prevents inlining, however, it's simpler to always add the + // compiler directive. + if decl.Doc == nil { + decl.Doc = &ast.CommentGroup{} + } + decl.Doc.List = appendCommentGroup(decl.Doc.List, decl.Doc) + decl.Doc.List = appendComment(decl.Doc.List, "//go:noinline\n") } - scope := &scope{compiler: c, colors: colorsByFunc} - gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color)) + + gen.Decls = append(gen.Decls, decl) } } @@ -400,6 +446,17 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types return } +func containsFuncLit(decl *ast.FuncDecl) (yes bool) { + ast.Inspect(decl, func(n ast.Node) bool { + if _, ok := n.(*ast.FuncLit); ok { + yes = true + return false + } + return true + }) + return +} + func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File { imports := map[string]string{} @@ -488,7 +545,7 @@ type scope struct { } func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl { - log.Printf("compiling function %s %s", p.Name, fn.Name) + log.Printf("compiling function %s.%s", p.Name, fn.Name) // Generate the coroutine function. At this stage, use the same name // as the source function (and require that the caller use build tags @@ -502,25 +559,6 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color Body: scope.compileFuncBody(p, fnType, fn.Body, fn.Recv, color), } - // If the function declaration contains function literals, we have to - // add the //go:noinline copmiler directive to prevent inlining or the - // resulting symbol name generated by the linker wouldn't match the - // predictions made in generateFunctypes. - // - // When functions are inlined, the linker creates a unique name - // combining the symbol name of the calling function and the symbol name - // of the closure. Knowing which functions will be inlined is difficult - // considering the score-base mechansim that Go uses and alterations - // like PGO, therefore we take the simple approach of disabling inlining - // instead. - // - // Note that we only need to do this for single-expression functions as - // otherwise the presence of a defer statement to unwind the coroutine - // already prevents inlining, however, it's simpler to always add the - // compiler directive. - gen.Doc.List = appendCommentGroup(gen.Doc.List, fn.Doc) - gen.Doc.List = appendComment(gen.Doc.List, "//go:noinline\n") - if color != nil && !isExpr(gen.Body) { scope.colors[gen] = color } @@ -528,8 +566,6 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color } func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit { - log.Printf("compiling function literal %s", p.Name) - gen := &ast.FuncLit{ Type: funcTypeWithNamedResults(p, fn), Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color), diff --git a/compiler/coroutine_test.go b/compiler/coroutine_test.go index a0bd7ee..e515f9d 100644 --- a/compiler/coroutine_test.go +++ b/compiler/coroutine_test.go @@ -282,6 +282,18 @@ func TestCoroutineYield(t *testing.T) { coro: func() { InterfaceEmbedded() }, yields: []int{1, 1, 1}, }, + + { + name: "closure in separate package", + coro: func() { ClosureInSeparatePackage(3) }, + yields: []int{3, 4, 5}, + }, + + { + name: "closure via generic with struct type param", + coro: func() { GenericStructClosure(3) }, + yields: []int{3, 5, 7}, + }, } // This emulates the installation of function type information by the diff --git a/compiler/function.go b/compiler/function.go index f0cdbfe..08518b9 100644 --- a/compiler/function.go +++ b/compiler/function.go @@ -5,7 +5,6 @@ import ( "go/ast" "go/token" "go/types" - "log" "maps" "slices" "strconv" @@ -283,7 +282,6 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma if len(instances) == 0 { // This can occur when a generic function is never instantiated/used, // or when it's instantiated in a package not known to the compiler. - log.Printf("warning: cannot register runtime type information for generic function %s", fn) continue } for _, instance := range instances { @@ -489,20 +487,7 @@ func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type { } func (g *genericInstance) partial() bool { - sig := g.instance.Signature - params := sig.Params() - for i := 0; i < params.Len(); i++ { - if _, ok := params.At(i).Type().(*types.TypeParam); ok { - return true - } - } - results := sig.Results() - for i := 0; i < results.Len(); i++ { - if _, ok := results.At(i).Type().(*types.TypeParam); ok { - return true - } - } - return false + return containsTypeParam(g.instance.Signature) } func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) { @@ -585,6 +570,24 @@ func writeGoShape(b *strings.Builder, tt types.Type) { } else { panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) } + case *types.Struct: + b.WriteString("struct { ") + for i := 0; i < t.NumFields(); i++ { + if i > 0 { + b.WriteString("; ") + } + f := t.Field(i) + + if f.Embedded() { + panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) + } + b.WriteString(f.Pkg().Path()) + b.WriteByte('.') + b.WriteString(f.Name()) + b.WriteByte(' ') + b.WriteString(f.Type().String()) + } + b.WriteString(" }") default: panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t)) } diff --git a/compiler/testdata/coroutine.go b/compiler/testdata/coroutine.go index 7f8c092..e89f427 100644 --- a/compiler/testdata/coroutine.go +++ b/compiler/testdata/coroutine.go @@ -9,6 +9,7 @@ import ( "unsafe" "github.com/dispatchrun/coroutine" + "github.com/dispatchrun/coroutine/compiler/testdata/subpkg" ) //go:generate coroc @@ -741,3 +742,38 @@ func InterfaceEmbedded() { coroutine.Yield[int, any](x.Value()) coroutine.Yield[int, any](x.Value()) } + +func ClosureInSeparatePackage(n int) { + adder := subpkg.Adder(n) + for i := 0; i < n; i++ { + coroutine.Yield[int, any](adder(i)) + } +} + +func GenericStructClosure(n int) { + impl := AdderImpl{base: n, mul: 2} + + boxed := &GenericAdder[AdderImpl]{adder: impl} + for i := 0; i < n; i++ { + coroutine.Yield[int, any](boxed.Add(i)) + } +} + +type adder interface { + Add(int) int +} + +type AdderImpl struct { + base int + mul int +} + +func (a AdderImpl) Add(n int) int { return a.base + n*a.mul } + +var _ adder = AdderImpl{} + +type GenericAdder[A adder] struct{ adder A } + +func (b *GenericAdder[A]) Add(n int) int { + return b.adder.Add(n) +} diff --git a/compiler/testdata/coroutine_durable.go b/compiler/testdata/coroutine_durable.go index 35d157b..6b15b8a 100644 --- a/compiler/testdata/coroutine_durable.go +++ b/compiler/testdata/coroutine_durable.go @@ -4,6 +4,7 @@ package testdata import ( coroutine "github.com/dispatchrun/coroutine" + subpkg "github.com/dispatchrun/coroutine/compiler/testdata/subpkg" math "math" reflect "reflect" time "time" @@ -4123,6 +4124,145 @@ func InterfaceEmbedded() { coroutine.Yield[int, any](_f0.X3) } } + +//go:noinline +func ClosureInSeparatePackage(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 func(int) int + X2 int + X3 int + } = coroutine.Push[struct { + IP int + X0 int + X1 func(int) int + X2 int + X3 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 func(int) int + X2 int + X3 int + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = subpkg.Adder(_f0.X0) + _f0.IP = 2 + fallthrough + case _f0.IP < 5: + switch { + case _f0.IP < 3: + _f0.X2 = 0 + _f0.IP = 3 + fallthrough + case _f0.IP < 5: + for ; _f0.X2 < _f0.X0; _f0.X2, _f0.IP = _f0.X2+1, 3 { + switch { + case _f0.IP < 4: + _f0.X3 = _f0.X1(_f0.X2) + _f0.IP = 4 + fallthrough + case _f0.IP < 5: + coroutine.Yield[int, any](_f0.X3) + } + } + } + } +} + +//go:noinline +func GenericStructClosure(_fn0 int) { + _c := coroutine.LoadContext[int, any]() + var _f0 *struct { + IP int + X0 int + X1 AdderImpl + X2 *GenericAdder[AdderImpl] + X3 int + X4 int + } = coroutine.Push[struct { + IP int + X0 int + X1 AdderImpl + X2 *GenericAdder[AdderImpl] + X3 int + X4 int + }](&_c.Stack) + if _f0.IP == 0 { + *_f0 = struct { + IP int + X0 int + X1 AdderImpl + X2 *GenericAdder[AdderImpl] + X3 int + X4 int + }{X0: _fn0} + } + defer func() { + if !_c.Unwinding() { + coroutine.Pop(&_c.Stack) + } + }() + switch { + case _f0.IP < 2: + _f0.X1 = AdderImpl{base: _f0.X0, mul: 2} + _f0.IP = 2 + fallthrough + case _f0.IP < 3: + _f0.X2 = &GenericAdder[AdderImpl]{adder: _f0.X1} + _f0.IP = 3 + fallthrough + case _f0.IP < 6: + switch { + case _f0.IP < 4: + _f0.X3 = 0 + _f0.IP = 4 + fallthrough + case _f0.IP < 6: + for ; _f0.X3 < _f0.X0; _f0.X3, _f0.IP = _f0.X3+1, 4 { + switch { + case _f0.IP < 5: + _f0.X4 = _f0.X2. + Add(_f0.X3) + _f0.IP = 5 + fallthrough + case _f0.IP < 6: + coroutine.Yield[int, any](_f0.X4) + } + } + } + } +} + +type adder interface { + Add(int) int +} + +type AdderImpl struct { + base int + mul int +} + +func (a AdderImpl) Add(n int) int { return a.base + n*a.mul } + +var _ adder = AdderImpl{} + +type GenericAdder[A adder] struct{ adder A } + +func (b *GenericAdder[A]) Add(n int) int { + return b.adder.Add(n) +} func init() { _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/dispatchrun/coroutine/compiler/testdata.(*Box).Closure") _types.RegisterClosure[func(_fn0 int), struct { @@ -4134,6 +4274,7 @@ func init() { } }]("github.com/dispatchrun/coroutine/compiler/testdata.(*Box).Closure.func1") _types.RegisterFunc[func()]("github.com/dispatchrun/coroutine/compiler/testdata.(*Box).YieldAndInc") + _types.RegisterFunc[func(n int) (_ int)]("github.com/dispatchrun/coroutine/compiler/testdata.(*GenericAdder[go.shape.struct { github.com/dispatchrun/coroutine/compiler/testdata.base int; github.com/dispatchrun/coroutine/compiler/testdata.mul int }]).Add") _types.RegisterFunc[func(_fn1 int) (_ func(int))]("github.com/dispatchrun/coroutine/compiler/testdata.(*GenericBox[go.shape.int]).Closure") _types.RegisterClosure[func(_fn0 int), struct { F uintptr @@ -4156,11 +4297,14 @@ func init() { }]("github.com/dispatchrun/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Closure.func1") _types.RegisterFunc[func()]("github.com/dispatchrun/coroutine/compiler/testdata.(*IdentityGenericStruct[go.shape.int]).Run") _types.RegisterFunc[func(_fn1 int)]("github.com/dispatchrun/coroutine/compiler/testdata.(*MethodGeneratorState).MethodGenerator") + _types.RegisterFunc[func(n int) (_ int)]("github.com/dispatchrun/coroutine/compiler/testdata.AdderImpl.Add") + _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.ClosureInSeparatePackage") _types.RegisterFunc[func(n int)]("github.com/dispatchrun/coroutine/compiler/testdata.Double") _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.EllipsisClosure") _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.EvenSquareGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.FizzBuzzIfGenerator") _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.FizzBuzzSwitchGenerator") + _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.GenericStructClosure") _types.RegisterFunc[func(n int)]("github.com/dispatchrun/coroutine/compiler/testdata.Identity") _types.RegisterFunc[func(n int)]("github.com/dispatchrun/coroutine/compiler/testdata.IdentityGenericClosureInt") _types.RegisterFunc[func(_fn0 int)]("github.com/dispatchrun/coroutine/compiler/testdata.IdentityGenericClosure[go.shape.int]") diff --git a/compiler/testdata/subpkg/adder.go b/compiler/testdata/subpkg/adder.go new file mode 100644 index 0000000..072db73 --- /dev/null +++ b/compiler/testdata/subpkg/adder.go @@ -0,0 +1,9 @@ +//go:build !durable + +package subpkg + +func Adder(n int) func(int) int { + return func(x int) int { + return x + n + } +} diff --git a/compiler/testdata/subpkg/adder_durable.go b/compiler/testdata/subpkg/adder_durable.go new file mode 100644 index 0000000..1416f82 --- /dev/null +++ b/compiler/testdata/subpkg/adder_durable.go @@ -0,0 +1,18 @@ +//go:build durable + +package subpkg + +import _types "github.com/dispatchrun/coroutine/types" +//go:noinline +func Adder(n int) func(int) int { + return func(x int) int { + return x + n + } +} +func init() { + _types.RegisterFunc[func(n int) (_ func(int) int)]("github.com/dispatchrun/coroutine/compiler/testdata/subpkg.Adder") + _types.RegisterClosure[func(x int) (_ int), struct { + F uintptr + X0 int + }]("github.com/dispatchrun/coroutine/compiler/testdata/subpkg.Adder.func1") +} diff --git a/compiler/types.go b/compiler/types.go index d0f31b3..bbdaa63 100644 --- a/compiler/types.go +++ b/compiler/types.go @@ -117,7 +117,9 @@ func typeExpr(p *packages.Package, typ types.Type, typeArg func(*types.TypeParam case *types.TypeParam: if typeArg != nil { - return typeExpr(p, typeArg(t), typeArg) + if known := typeArg(t); known != nil { + return typeExpr(p, known, typeArg) + } } obj := t.Obj() ident := ast.NewIdent(obj.Name()) @@ -239,3 +241,55 @@ func substituteFieldList(p *packages.Package, f *ast.FieldList, typeArg func(*ty } return &ast.FieldList{List: fields} } + +func containsTypeParam(typ types.Type) bool { + if typ == nil { + return false + } + switch t := typ.(type) { + case *types.Basic: + case *types.Slice: + return containsTypeParam(t.Elem()) + case *types.Array: + return containsTypeParam(t.Elem()) + case *types.Pointer: + return containsTypeParam(t.Elem()) + case *types.Map: + return containsTypeParam(t.Elem()) || containsTypeParam(t.Key()) + case *types.Chan: + return containsTypeParam(t.Elem()) + case *types.Named: + if args := t.TypeArgs(); args != nil { + for i := 0; i < args.Len(); i++ { + if containsTypeParam(args.At(i)) { + return true + } + } + } + case *types.Tuple: + for i := 0; i < t.Len(); i++ { + if containsTypeParam(t.At(i).Type()) { + return true + } + } + case *types.Signature: + if recv := t.Recv(); recv != nil { + if containsTypeParam(recv.Type()) { + return true + } + } + if containsTypeParam(t.Params()) { + return true + } + if containsTypeParam(t.Results()) { + return true + } + case *types.TypeParam: + return true + case *types.Interface: + case *types.Struct: + default: + panic(fmt.Sprintf("not implemented: %T", typ)) + } + return false +}