Skip to content

Commit

Permalink
Fixes (part 2) (#149)
Browse files Browse the repository at this point in the history
Following on from #148,
this PR contains various fixes:
* we now support generic type parameters that are structs
* we now conservatively register type info for all function decls/lits
in the module, and mark functions/methods in the module that contain
function literals as `noinline`, in case they create a closure that
needs to be serialized
* tweaks the logging so it generates more useful output
  • Loading branch information
chriso authored Jun 24, 2024
2 parents 892d764 + 0be6d66 commit 9fefde4
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 47 deletions.
96 changes: 66 additions & 30 deletions compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,25 @@ func (c *compiler) compile(path string) error {
pkgColors[fn] = color
}

// Add all packages from the module. Although these packages don't contain
// yield points, they may return closures that need to be serialized. For
// this to work, certain functions need to be marked as noinline and function
// literal types need to be registered.
//
// TODO: improve this by scanning dependencies to see if they need to be included
packages.Visit(pkgs, func(p *packages.Package) bool {
if p.Module == nil || p.Module.Dir != moduleDir {
return true
}
if p.PkgPath == coroutinePackage {
return true
}
if _, ok := colorsByPkg[p]; !ok {
colorsByPkg[p] = functionColors{}
}
return true
}, nil)

if c.onlyListFiles {
cwd, _ := os.Getwd()
for pkg := range colorsByPkg {
Expand Down Expand Up @@ -356,16 +375,43 @@ func (c *compiler) compilePackage(p *packages.Package, colors functionColors) er

case *ast.FuncDecl:
color := colorsByFunc[decl]
if color == nil && !containsColoredFuncLit(decl, colorsByFunc) {
gen.Decls = append(gen.Decls, decl)
continue

compiled := false
if color != nil || containsColoredFuncLit(decl, colorsByFunc) {
// Reject certain language features for now.
if err := unsupported(decl, p.TypesInfo); err != nil {
return err
}
scope := &scope{compiler: c, colors: colorsByFunc}
decl = scope.compileFuncDecl(p, decl, color)
compiled = true
}
// Reject certain language features for now.
if err := unsupported(decl, p.TypesInfo); err != nil {
return err

if compiled || containsFuncLit(decl) {
// If the function declaration contains function literals, we have to
// add the //go:noinline copmiler directive to prevent inlining or the
// resulting symbol name generated by the linker wouldn't match the
// predictions made in generateFunctypes.
//
// When functions are inlined, the linker creates a unique name
// combining the symbol name of the calling function and the symbol name
// of the closure. Knowing which functions will be inlined is difficult
// considering the score-base mechansim that Go uses and alterations
// like PGO, therefore we take the simple approach of disabling inlining
// instead.
//
// Note that we only need to do this for single-expression functions as
// otherwise the presence of a defer statement to unwind the coroutine
// already prevents inlining, however, it's simpler to always add the
// compiler directive.
if decl.Doc == nil {
decl.Doc = &ast.CommentGroup{}
}
decl.Doc.List = appendCommentGroup(decl.Doc.List, decl.Doc)
decl.Doc.List = appendComment(decl.Doc.List, "//go:noinline\n")
}
scope := &scope{compiler: c, colors: colorsByFunc}
gen.Decls = append(gen.Decls, scope.compileFuncDecl(p, decl, color))

gen.Decls = append(gen.Decls, decl)
}
}

Expand Down Expand Up @@ -400,6 +446,17 @@ func containsColoredFuncLit(decl *ast.FuncDecl, colorsByFunc map[ast.Node]*types
return
}

func containsFuncLit(decl *ast.FuncDecl) (yes bool) {
ast.Inspect(decl, func(n ast.Node) bool {
if _, ok := n.(*ast.FuncLit); ok {
yes = true
return false
}
return true
})
return
}

func addImports(p *packages.Package, f *ast.File, gen *ast.File) *ast.File {
imports := map[string]string{}

Expand Down Expand Up @@ -488,7 +545,7 @@ type scope struct {
}

func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color *types.Signature) *ast.FuncDecl {
log.Printf("compiling function %s %s", p.Name, fn.Name)
log.Printf("compiling function %s.%s", p.Name, fn.Name)

// Generate the coroutine function. At this stage, use the same name
// as the source function (and require that the caller use build tags
Expand All @@ -502,34 +559,13 @@ func (scope *scope) compileFuncDecl(p *packages.Package, fn *ast.FuncDecl, color
Body: scope.compileFuncBody(p, fnType, fn.Body, fn.Recv, color),
}

// If the function declaration contains function literals, we have to
// add the //go:noinline copmiler directive to prevent inlining or the
// resulting symbol name generated by the linker wouldn't match the
// predictions made in generateFunctypes.
//
// When functions are inlined, the linker creates a unique name
// combining the symbol name of the calling function and the symbol name
// of the closure. Knowing which functions will be inlined is difficult
// considering the score-base mechansim that Go uses and alterations
// like PGO, therefore we take the simple approach of disabling inlining
// instead.
//
// Note that we only need to do this for single-expression functions as
// otherwise the presence of a defer statement to unwind the coroutine
// already prevents inlining, however, it's simpler to always add the
// compiler directive.
gen.Doc.List = appendCommentGroup(gen.Doc.List, fn.Doc)
gen.Doc.List = appendComment(gen.Doc.List, "//go:noinline\n")

if color != nil && !isExpr(gen.Body) {
scope.colors[gen] = color
}
return gen
}

func (scope *scope) compileFuncLit(p *packages.Package, fn *ast.FuncLit, color *types.Signature) *ast.FuncLit {
log.Printf("compiling function literal %s", p.Name)

gen := &ast.FuncLit{
Type: funcTypeWithNamedResults(p, fn),
Body: scope.compileFuncBody(p, fn.Type, fn.Body, nil, color),
Expand Down
12 changes: 12 additions & 0 deletions compiler/coroutine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ func TestCoroutineYield(t *testing.T) {
coro: func() { InterfaceEmbedded() },
yields: []int{1, 1, 1},
},

{
name: "closure in separate package",
coro: func() { ClosureInSeparatePackage(3) },
yields: []int{3, 4, 5},
},

{
name: "closure via generic with struct type param",
coro: func() { GenericStructClosure(3) },
yields: []int{3, 5, 7},
},
}

// This emulates the installation of function type information by the
Expand Down
35 changes: 19 additions & 16 deletions compiler/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"go/ast"
"go/token"
"go/types"
"log"
"maps"
"slices"
"strconv"
Expand Down Expand Up @@ -283,7 +282,6 @@ func (c *compiler) generateFunctypes(p *packages.Package, f *ast.File, colors ma
if len(instances) == 0 {
// This can occur when a generic function is never instantiated/used,
// or when it's instantiated in a package not known to the compiler.
log.Printf("warning: cannot register runtime type information for generic function %s", fn)
continue
}
for _, instance := range instances {
Expand Down Expand Up @@ -489,20 +487,7 @@ func (g *genericInstance) typeArgOf(param *types.TypeParam) types.Type {
}

func (g *genericInstance) partial() bool {
sig := g.instance.Signature
params := sig.Params()
for i := 0; i < params.Len(); i++ {
if _, ok := params.At(i).Type().(*types.TypeParam); ok {
return true
}
}
results := sig.Results()
for i := 0; i < results.Len(); i++ {
if _, ok := results.At(i).Type().(*types.TypeParam); ok {
return true
}
}
return false
return containsTypeParam(g.instance.Signature)
}

func (g *genericInstance) scanRecvTypeArgs(fn func(*types.TypeParam, int, types.Type)) {
Expand Down Expand Up @@ -585,6 +570,24 @@ func writeGoShape(b *strings.Builder, tt types.Type) {
} else {
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
case *types.Struct:
b.WriteString("struct { ")
for i := 0; i < t.NumFields(); i++ {
if i > 0 {
b.WriteString("; ")
}
f := t.Field(i)

if f.Embedded() {
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
b.WriteString(f.Pkg().Path())
b.WriteByte('.')
b.WriteString(f.Name())
b.WriteByte(' ')
b.WriteString(f.Type().String())
}
b.WriteString(" }")
default:
panic(fmt.Sprintf("not implemented: %#v (%T)", tt, t))
}
Expand Down
36 changes: 36 additions & 0 deletions compiler/testdata/coroutine.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"unsafe"

"github.com/dispatchrun/coroutine"
"github.com/dispatchrun/coroutine/compiler/testdata/subpkg"
)

//go:generate coroc
Expand Down Expand Up @@ -741,3 +742,38 @@ func InterfaceEmbedded() {
coroutine.Yield[int, any](x.Value())
coroutine.Yield[int, any](x.Value())
}

func ClosureInSeparatePackage(n int) {
adder := subpkg.Adder(n)
for i := 0; i < n; i++ {
coroutine.Yield[int, any](adder(i))
}
}

func GenericStructClosure(n int) {
impl := AdderImpl{base: n, mul: 2}

boxed := &GenericAdder[AdderImpl]{adder: impl}
for i := 0; i < n; i++ {
coroutine.Yield[int, any](boxed.Add(i))
}
}

type adder interface {
Add(int) int
}

type AdderImpl struct {
base int
mul int
}

func (a AdderImpl) Add(n int) int { return a.base + n*a.mul }

var _ adder = AdderImpl{}

type GenericAdder[A adder] struct{ adder A }

func (b *GenericAdder[A]) Add(n int) int {
return b.adder.Add(n)
}
Loading

0 comments on commit 9fefde4

Please sign in to comment.