From 1bd8336f7058bc32fa4dffc2fff318002e6432ec Mon Sep 17 00:00:00 2001 From: Mat Ryer Date: Fri, 7 Jul 2017 17:49:29 +0100 Subject: [PATCH] Added customer importer --- package/moq/importer.go | 145 ++++++++++++++++++++++++++++++++++++++++ package/moq/moq.go | 3 +- 2 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 package/moq/importer.go diff --git a/package/moq/importer.go b/package/moq/importer.go new file mode 100644 index 0000000..66c6109 --- /dev/null +++ b/package/moq/importer.go @@ -0,0 +1,145 @@ +package moq + +// taken from https://github.com/ernesto-jimenez/gogen +// Copyright (c) 2015 Ernesto Jiménez + +import ( + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "io/ioutil" + "os" + "path" + "path/filepath" + "strings" +) + +type customImporter struct { + imported map[string]*types.Package + base types.Importer + skipTestFiles bool +} + +func (i *customImporter) Import(path string) (*types.Package, error) { + var err error + if path == "" || path[0] == '.' { + path, err = filepath.Abs(filepath.Clean(path)) + if err != nil { + return nil, err + } + path = stripGopath(path) + } + if pkg, ok := i.imported[path]; ok { + return pkg, nil + } + pkg, err := i.fsPkg(path) + if err != nil { + return nil, err + } + i.imported[path] = pkg + return pkg, nil +} + +func gopathDir(pkg string) (string, error) { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + absPath, err := filepath.Abs(path.Join(gopath, "src", pkg)) + if err != nil { + return "", err + } + if dir, err := os.Stat(absPath); err == nil && dir.IsDir() { + return absPath, nil + } + } + return "", fmt.Errorf("%s not in $GOPATH", pkg) +} + +func removeGopath(p string) string { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) + } + return p +} + +func (i *customImporter) fsPkg(pkg string) (*types.Package, error) { + dir, err := gopathDir(pkg) + if err != nil { + return importOrErr(i.base, pkg, err) + } + + dirFiles, err := ioutil.ReadDir(dir) + if err != nil { + return importOrErr(i.base, pkg, err) + } + + fset := token.NewFileSet() + var files []*ast.File + for _, fileInfo := range dirFiles { + if fileInfo.IsDir() { + continue + } + n := fileInfo.Name() + if path.Ext(fileInfo.Name()) != ".go" { + continue + } + if i.skipTestFiles && strings.Contains(fileInfo.Name(), "_test.go") { + continue + } + file := path.Join(dir, n) + src, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + f, err := parser.ParseFile(fset, file, src, 0) + if err != nil { + return nil, err + } + files = append(files, f) + } + conf := types.Config{ + Importer: i, + } + p, err := conf.Check(pkg, fset, files, nil) + + if err != nil { + return importOrErr(i.base, pkg, err) + } + return p, nil +} + +func importOrErr(base types.Importer, pkg string, err error) (*types.Package, error) { + p, impErr := base.Import(pkg) + if impErr != nil { + return nil, err + } + return p, nil +} + +// newImporter returns an importer that will try to import code from gopath before using go/importer.Default and skipping test files +func newImporter() types.Importer { + return &customImporter{ + imported: make(map[string]*types.Package), + base: importer.Default(), + skipTestFiles: true, + } +} + +// // DefaultWithTestFiles same as Default but it parses test files too +// func DefaultWithTestFiles() types.Importer { +// return &customImporter{ +// imported: make(map[string]*types.Package), +// base: importer.Default(), +// skipTestFiles: false, +// } +// } + +// stripGopath teks the directory to a package and remove the gopath to get the +// cannonical package name +func stripGopath(p string) string { + for _, gopath := range strings.Split(os.Getenv("GOPATH"), ":") { + p = strings.Replace(p, path.Join(gopath, "src")+"/", "", 1) + } + return p +} diff --git a/package/moq/moq.go b/package/moq/moq.go index 9353abb..be3af3d 100644 --- a/package/moq/moq.go +++ b/package/moq/moq.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "go/ast" - "go/importer" "go/parser" "go/token" "go/types" @@ -77,7 +76,7 @@ func (m *Mocker) Mock(w io.Writer, name ...string) error { files[i] = file i++ } - conf := types.Config{Importer: importer.Default()} + conf := types.Config{Importer: newImporter()} tpkg, err := conf.Check(m.src, m.fset, files, nil) if err != nil { return err