-
Notifications
You must be signed in to change notification settings - Fork 128
/
rewrite.go
158 lines (138 loc) · 3.25 KB
/
rewrite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package main
import (
"bytes"
"errors"
"go/ast"
"go/printer"
"go/token"
"io/ioutil"
"path/filepath"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader"
)
// ErrImported is returned by addCode when the loaded package already imports
// "runtime/trace"; rewriteSource then uses the source file unmodified.
var (
	ErrImported = errors.New("trace already imported")
)
// rewriteSource instruments the Go source file at path with runtime tracing
// (see addCode) and writes the result, under the file's original base name,
// into a newly created temporary directory. It returns the path of that
// directory — not of the file inside it.
//
// If addCode reports ErrImported (the source already imports runtime/trace),
// the original file contents are copied unmodified instead. This function
// never removes the directory it creates.
func rewriteSource(path string) (string, error) {
	data, err := addCode(path)
	switch {
	case errors.Is(err, ErrImported):
		// Presumably the program manages tracing itself; pass the source through as-is.
		if data, err = ioutil.ReadFile(path); err != nil {
			return "", err
		}
	case err != nil:
		return "", err
	}

	tmpDir, err := ioutil.TempDir("", "gotracer_package")
	if err != nil {
		return "", err
	}
	// NOTE(review): tmpDir is left behind if the write below fails; cleaning it
	// up would need os.RemoveAll, which this file does not currently import.
	filename := filepath.Join(tmpDir, filepath.Base(path))
	if err := ioutil.WriteFile(filename, data, 0666); err != nil {
		return "", err
	}
	return tmpDir, nil
}
// addCode loads the Go source file at path, prepends trace start/stop
// statements (see createTraceStmts) to the body of its top-level main
// function, adds the "os", "runtime/trace" and "time" imports those
// statements need, and returns the re-printed source of that file.
//
// It returns ErrImported, without modifying anything, when the loaded package
// already directly imports "runtime/trace". It returns an error if no source
// file could be loaded from path, or if loading, type-checking or printing
// fails.
func addCode(path string) ([]byte, error) {
	var conf loader.Config
	if _, err := conf.FromArgs([]string{path}, false); err != nil {
		return nil, err
	}
	prog, err := conf.Load()
	if err != nil {
		return nil, err
	}

	// A *.go file argument is type-checked as a "created" package. Report an
	// error instead of panicking on the indexing below when nothing was
	// created (e.g. path was an import path rather than a file).
	if len(prog.Created) == 0 || len(prog.Created[0].Files) == 0 {
		return nil, errors.New("no Go source files loaded from " + path)
	}
	pkg := prog.Created[0]

	// Bail out if the program already uses runtime/trace. prog.Imported only
	// lists packages that were requested as import paths, so it is empty when
	// a file is loaded; the created package's own direct imports are what
	// reveal an existing runtime/trace import.
	for _, imp := range pkg.Pkg.Imports() {
		if imp.Path() == "runtime/trace" {
			return nil, ErrImported
		}
	}

	// TODO: find file with main func inside
	astFile := pkg.Files[0]

	// Imports referenced by the generated statements; AddImport only adds a
	// path when the file does not already import it.
	astutil.AddImport(prog.Fset, astFile, "os")
	astutil.AddImport(prog.Fset, astFile, "runtime/trace")
	astutil.AddImport(prog.Fset, astFile, "time")

	// Function declarations only occur at file scope, so scanning Decls is
	// enough to find main; a receiver-less main without a body is skipped.
	for _, decl := range astFile.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Recv != nil || fn.Name.Name != "main" || fn.Body == nil {
			continue
		}
		fn.Body.List = append(createTraceStmts(), fn.Body.List...)
		break
	}

	var buf bytes.Buffer
	if err := printer.Fprint(&buf, prog.Fset, astFile); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
// createTraceStmts builds the two statements that addCode splices into the
// top of main:
//
//	trace.Start(os.Stderr)
//	defer func() { time.Sleep(50 * time.Millisecond); trace.Stop() }()
//
// The package qualifiers trace, os and time are emitted as bare identifiers,
// so the receiving file must import "runtime/trace", "os" and "time" under
// their default names. The error returned by trace.Start is discarded. The
// generated nodes carry no source positions.
func createTraceStmts() []ast.Stmt {
	// sel builds the selector expression pkg.name.
	sel := func(pkg, name string) *ast.SelectorExpr {
		return &ast.SelectorExpr{X: ast.NewIdent(pkg), Sel: ast.NewIdent(name)}
	}

	start := &ast.ExprStmt{
		X: &ast.CallExpr{
			Fun:  sel("trace", "Start"),
			Args: []ast.Expr{sel("os", "Stderr")},
		},
	}

	delay := &ast.BinaryExpr{
		X:  &ast.BasicLit{Kind: token.INT, Value: "50"},
		Op: token.MUL,
		Y:  sel("time", "Millisecond"),
	}
	cleanup := &ast.FuncLit{
		Type: &ast.FuncType{Params: &ast.FieldList{}},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ExprStmt{X: &ast.CallExpr{Fun: sel("time", "Sleep"), Args: []ast.Expr{delay}}},
				&ast.ExprStmt{X: &ast.CallExpr{Fun: sel("trace", "Stop")}},
			},
		},
	}
	stop := &ast.DeferStmt{Call: &ast.CallExpr{Fun: cleanup}}

	return []ast.Stmt{start, stop}
}