go言語でASTの解析にgo/typesの機能を使うことの威力について

goでASTの解析が手軽なのは便利。ところでこれにgo/typesの機能を使うととても便利になる。その威力を一番わかりやすく体験できそうな例を思いついたので紹介する。 (実際にはgo/typesの機能の使いかたはこれいがいにも色々ある)

go/types?

go/typesというのは概ねgoで型情報を扱う時に使うことになるパッケージ。(ちなみにgolang.org/x/tools/go/loader経由でソースコードを読むと自動で使われるので個人的にはloader経由で使うことが多い)

importしたパッケージのpathを変える

importしたパッケージのpathを変えてみたい。例えばの例なのでテキトウだけれど。fmtパッケージのimportをfmt2パッケージのimportに変える例を紹介する。例えば以下の様なコードが

fmt.Println("hello")

以下の様になってくれれば良いということ。

fmt2.Println("hello")

(あるいはgomvpkgの処理のsubsetと考えてみても良いかもしれない)

fmt.Printlnをfmt2.Println

AST上ではそんなに大変でもなくて。fmtやfmt2の部分は以下のようなgo/ast上の値に格納されている。

&ast.SelectorExpr{
  X: &ast.Ident{
    NamePos: 93,
    Name:    "fmt",
  },
  Sel: &ast.Ident{
    NamePos: 97,
    Name:    "Println",
  },
}

雰囲気で呼んで欲しいのだけれど。ast.IdentName を書き換えれば。fmt.Printlnfmt2.Println などにすることができる。

ASTからgoのコードを生成することも訳なくて。以下のようにgo/printerの関数が用意されている。

pp := &printer.Config{Tabwidth: 8, Mode: printer.UseSpaces | printer.TabIndent}
pp.Fprint(os.Stdout, prog.Fset, f) // fは*ast.File

めんどくさいのは同名の値が存在する場合

めんどくさいのは同名の値が存在する場合。例えば以下の様な少し恣意的なコードがあるとする。

type s struct{}

func (s s)Println(x string) {}

func main(){
    fmt.Println("xxx")

    {
        fmt := s{}
        fmt.Println("yyy") // ここだけfmt2に変わってはだめ。
    }

    fmt.Println("xxx")
}

ast上ではどちらもast.IdentのNameに"fmt"が入るのだけれど。オブジェクトの名前がfmtのsから作ったコードの部分は書き換えてほしくない。こういうようなコードに対応するときにASTベースで考えて対応しようとするとフロー解析っぽいことをしないとだめなのでひどくだるい。つらい。

そんなときにgo/types経由で扱うとすごい簡単に対応できて良い。

x/toolsのloaderで関数内のコードの型チェックを有効にする状態で読み込んだとき(defaultではその状態)に、info.Uses,info.Defsという値に、使われたときの型の値定義されたときの型の値がそれぞれ手に入る。(実際のinfoの方はgo/types.Info)

// go/typesのapi.go

type Info struct {
    Defs map[*ast.Ident]Object
    Uses map[*ast.Ident]Object

    Types map[ast.Expr]TypeAndValue
    Implicits map[ast.Node]Object
    Selections map[*ast.SelectorExpr]*Selection
    Scopes map[ast.Node]*Scope
    InitOrder []*Initializer
}

またこれらのmapにアクセスするよりもObjectOf()を使うのが便利。これでast上の各表記に対応するオブジェクトが取れる(似たようなTypesOf()というメソッドもある)。

func (info *Info) ObjectOf(id *ast.Ident) Object {
    if obj := info.Defs[id]; obj != nil {
        return obj
    }
    return info.Uses[id]
}

各オブジェクトはそれが定義されているパッケージが存在するはずで。それは簡単に取得できる。

ob := info.ObjectOf(id)
ob.Pkg() // *types.Package

このパッケージを比較して書き換えたかったimportのものだったら書き換えるということができれば良いというわけ。

実際以下の様なコードのfmtをfmt2に書き換えることができる

package main

import "fmt"

type s struct{}

func (s s)Println(x string) {}

func main(){
    fmt.Println("xxx")

    {
        fmt := s{}
        fmt.Println("yyy")
    }

    fmt.Println("xxx")
}

こんな感じで。中央のfmtという値を作った部分は書き換えられていない(すばらしい)。

--- before.go    2018-04-08 20:45:34.311077503 +0900
+++ after.go  2018-04-08 20:45:46.785757895 +0900
@@ -1,18 +1,18 @@
 package main
 
-import "fmt"
+import "fmt2"
 
 type s struct{}
 
-func (s s)Println(x string) {}
+func (s s) Println(x string) {}
 
-func main(){
-  fmt.Println("xxx")
+func main() {
+   fmt2.Println("xxx")
 
    {
        fmt := s{}
        fmt.Println("yyy")
    }
 
-  fmt.Println("xxx")
+   fmt2.Println("xxx")
 }

(実際にはembeddedのあたりの対応がこれだけでは不足していたりする)

code

実際に動作を確認できるコード。

package main

import (
    "go/ast"
    "go/parser"
    "go/printer"
    "log"
    "os"

    "golang.org/x/tools/go/ast/astutil"
    "golang.org/x/tools/go/loader"
)

func main() {
    if err := run(); err != nil {
        log.Fatalf("%+v", err)
    }
}

func run() error {
    source := `
package main

import "fmt"

type s struct{}

func (s s)Println(x string) {}

func main(){
  fmt.Println("xxx")

  {
      fmt := s{}
      fmt.Println("yyy")
  }

  fmt.Println("xxx")
}
`
    loader := loader.Config{ParserMode: parser.ParseComments}
    astf, err := loader.ParseFile("main.go", source)
    if err != nil {
        return err
    }
    loader.CreateFromFiles("main", astf)

    prog, err := loader.Load()
    if err != nil {
        return err
    }

    main := prog.Package("main")
    fmtpkg := prog.Package("fmt").Pkg
    for _, f := range main.Files {
        ast.Inspect(f, func(node ast.Node) bool {
            if t, _ := node.(*ast.SelectorExpr); t != nil {
                if main.Info.ObjectOf(t.Sel).Pkg() == fmtpkg {
                    ast.Inspect(t.X, func(node ast.Node) bool {
                        if t, _ := node.(*ast.Ident); t != nil {
                            if t.Name == "fmt" && t.Obj == nil {
                                t.Name = "fmt2"
                            }
                            return false
                        }
                        return true
                    })
                }
                return false
            }
            return true
        })

        astutil.RewriteImport(prog.Fset, f, "fmt", "fmt2")

        pp := &printer.Config{Tabwidth: 8, Mode: printer.UseSpaces | printer.TabIndent}
        pp.Fprint(os.Stdout, prog.Fset, f)
    }
    return nil
}