diff --git a/helper_test.go b/helper_test.go new file mode 100644 index 0000000..5148d77 --- /dev/null +++ b/helper_test.go @@ -0,0 +1,17 @@ +package analysisutil_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func WriteFiles(t *testing.T, filemap map[string]string) string { + t.Helper() + dir, clean, err := analysistest.WriteFiles(filemap) + if err != nil { + t.Fatal("unexpected error:", err) + } + t.Cleanup(clean) + return dir +} diff --git a/testdata/src/objectof/main.go b/testdata/src/objectof/main.go deleted file mode 100644 index 408aa3d..0000000 --- a/testdata/src/objectof/main.go +++ /dev/null @@ -1,15 +0,0 @@ -package objectof - -import ( - "fmt" - "io" - "vendored" -) - -type A int -var EOF = io.EOF -var _ = vendored.EOF - -func main() { - fmt.Println(EOF) -} diff --git a/testdata/src/objectof/vendor/vendored/vendored.go b/testdata/src/objectof/vendor/vendored/vendored.go deleted file mode 100644 index c965bf3..0000000 --- a/testdata/src/objectof/vendor/vendored/vendored.go +++ /dev/null @@ -1,5 +0,0 @@ -package vendored - -import "io" - -var EOF = io.EOF diff --git a/types.go b/types.go index 4773ac4..46b9706 100644 --- a/types.go +++ b/types.go @@ -196,3 +196,13 @@ func mergeTypesInfo(i1, i2 *types.Info) { // InitOrder i1.InitOrder = append(i1.InitOrder, i2.InitOrder...) } + +// Under returns the most bottom underlying type. +func Under(t types.Type) types.Type { + switch t := t.(type) { + case *types.Named: + return Under(t.Underlying()) + default: + return t + } +} diff --git a/types_test.go b/types_test.go index 6b44660..a539c33 100644 --- a/types_test.go +++ b/types_test.go @@ -1,57 +1,117 @@ package analysisutil_test import ( + "errors" + "fmt" "go/token" + "go/types" + "path/filepath" "testing" "github.com/gostaticanalysis/analysisutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/analysistest" - "golang.org/x/tools/go/analysis/passes/buildssa" ) -const pkg = "objectof" +func TestObjectOf(t *testing.T) { + t.Parallel() -var typesAnalyzer = &analysis.Analyzer{ - Name: "test_types", - Run: run_test_types, - Requires: []*analysis.Analyzer{ - buildssa.Analyzer, - }, -} - -func Test_Types(t *testing.T) { - testdata := analysistest.TestData() - analysistest.Run(t, testdata, typesAnalyzer, pkg) -} - -func run_test_types(pass *analysis.Pass) (interface{}, error) { - tests := []struct { - path, name string - found bool + cases := map[string]struct { + src string + pkg string // blank means same as the map key + name string + found bool }{ - {"fmt", "Println", true}, - {pkg, "A", true}, - {pkg, "EOF", true}, - {"io", "EOF", true}, - {"reflect", "Kind", false}, - {"a", "ok", false}, - {"vendored", "EOF", true}, - {"c", "EOF", false}, - {"database/sql", "*DB", false}, + "standard": {`import _ "fmt"`, "fmt", "Println", true}, + "unimport": {"", "fmt", "Println", false}, + "notexiststd": {`import _ "fmt"`, "fmt", "NOTEXIST", false}, + "typename": {"type A int", "", "A", true}, + "unexportvar": {"var n int", "", "n", true}, + "exportvar": {"var N int", "", "N", true}, + "notexist": {"", "", "NOTEXIST", false}, + "vendored": {`import _ "fmt"`, "vendor/fmt", "Println", true}, + "pointer": {"type A int", "", "*A", false}, } - for _, tt := range tests { - tt := tt - obj := analysisutil.ObjectOf(pass, tt.path, tt.name) + for name, tt := range cases { + name, tt := name, tt + t.Run(name, func(t *testing.T) { + t.Parallel() + a := &analysis.Analyzer{ + Name: name + "Analyzer", + Run: func(pass *analysis.Pass) (interface{}, error) { + pkg := name + if tt.pkg != "" { + pkg = tt.pkg + } + obj := analysisutil.ObjectOf(pass, pkg, tt.name) + switch { + case tt.found && obj == nil: + return nil, errors.New("expect found but not found") + case !tt.found && obj != nil: + return nil, fmt.Errorf("unexpected return value: %v", obj) + } + return nil, nil + }, + } + path := filepath.Join(name, name+".go") + dir := WriteFiles(t, map[string]string{ + path: fmt.Sprintf("package %s\n%s", name, tt.src), + }) + analysistest.Run(t, dir, a, name) + }) + } - if obj == nil && tt.found { - pass.Reportf(token.NoPos, "objectof could not find %s.%s", tt.path, tt.name) - } - if obj != nil && !tt.found { - pass.Reportf(token.NoPos, "objectof found %s.%s, which does not exist", tt.path, tt.name) +} + +func TestUnder(t *testing.T) { + t.Parallel() + + lookup := func(pass *analysis.Pass, n string) (types.Type, error) { + _, obj := pass.Pkg.Scope().LookupParent(n, token.NoPos) + if obj == nil { + return nil, fmt.Errorf("does not find: %s", n) } + return obj.Type(), nil + } + + cases := map[string]struct { + src string + typ string + want string + }{ + "nonamed": {"", "int", "int"}, + "named": {"type A int", "A", "int"}, + "twonamed": {"type A int; type B A", "B", "int"}, } - return nil, nil + for name, tt := range cases { + name, tt := name, tt + t.Run(name, func(t *testing.T) { + t.Parallel() + a := &analysis.Analyzer{ + Name: name + "Analyzer", + Run: func(pass *analysis.Pass) (interface{}, error) { + typ, err := lookup(pass, tt.typ) + if err != nil { + return nil, err + } + want, err := lookup(pass, tt.want) + if err != nil { + return nil, err + } + got := analysisutil.Under(typ) + if !types.Identical(want, got) { + return nil, fmt.Errorf("want %v but got %v", want, got) + } + return nil, nil + }, + } + path := filepath.Join(name, name+".go") + dir := WriteFiles(t, map[string]string{ + path: fmt.Sprintf("package %s\n%s", name, tt.src), + }) + analysistest.Run(t, dir, a, name) + }) + } }