mypyでProtocolを使ってmix-inを利用したクラスに型をつける

たまには、他の人の役に立つ記事も書こうということで書いてみる。

例えば、以下のようなmix-inを使ったコードがあるとする。トリビアルな例で特に必要になりそうなコードではないけれど、まぁ説明のためのコードなので許してほしい。

EnumerableMixinはmap()を提供していて、このmap()や他のメソッド(定義されていないけれど)は、each()に依存するというmix-in。そしてListeach()を実装するクラス。何の変哲もないmix-inのコード。

class EnumerableMixin:
    def map(self, fn):
        return [fn(x) for x in self.each()]


class List(EnumerableMixin):
    def __init__(self, xs):
        self.xs = xs

    def each(self):
        return iter(self.xs)

実行結果はこう。

L = List([10, 20, 30])
print(L.map(lambda x: x * x))
# [100, 400, 900]

このようなコードに型を付けたい。

Genericを使って定義してみる

雰囲気でmypyで型を付けてみると以下の様な感じになるのではないか?ただこうしてしまうとeachが困る。

import typing as t

A = t.TypeVar("A", covariant=True)
B = t.TypeVar("B")


class EnumerableMixin(t.Generic[A]):
    def map(self, fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)

存在していないメソッドに依存してしまっているわけなので、当然といえば当然。

$ mypy --strict 01map.py
01map.py:9: error: "EnumerableMixin[A]" has no attribute "each"

このエラーをどうやって潰そうか?というのが今回の主題。

Protocolを使う

duck typingで詰まったらProtocolというのがmypyをいじってきて感じる経験知かもしれない。今回もProtocolを使うことにする(python3.8からはtyping_extensionsは不要)。

import typing as t
import typing_extensions as tx

A = t.TypeVar("A", covariant=True)
B = t.TypeVar("B")


class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        ...


class EnumerableMixin(t.Generic[A]):
    def map(self: HasEach[A], fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)


L = List([10, 20, 30])
if t.TYPE_CHECKING:
    reveal_type(L)
result = L.map(lambda x: x * x)
if t.TYPE_CHECKING:
    reveal_type(result)
print(result)

実行結果。問題なさそう。HasEachというProtocolを定義したのが肝。これをmapの定義時のselfの型として利用する。

$ mypy --strict 02map.py
02map.py:28: note: Revealed type is '02map.List[builtins.int*]'
02map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'

selfにProtocolを与えた場合の制限

ただしこの方法にも1つ制限があって、自分自身の持つメソッドや属性にアクセスできなくなる。

import typing_extensions as tx


class P(tx.Protocol):
    def foo(self) -> str:
        ...


class M:
    def bar(self: P) -> None:
        print(self.foo(), self.foo())

        # error: "P" has no attribute "boo"
        print(self.boo())

    def boo(self) -> str:
        return "boo"

自分自身がProtocolを継承する

先程の部分的にselfにProtocolを指定する方法は悪くはないのだけれど。制限がある。それを解消するために自分自身がProtocolとなっても良い。以下の例ではEnumerableMixin自身がHasEachを継承している。

class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        ...


class EnumerableMixin(HasEach[A]):
    def map(self, fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)


L = List([10, 20, 30])
if t.TYPE_CHECKING:
    reveal_type(L)
result = L.map(lambda x: x * x)
if t.TYPE_CHECKING:
    reveal_type(result)
print(result)

このようにしても型チェックは通る。

$ mypy --strict 03map.py
03map.py:28: note: Revealed type is '03map.List[builtins.int*]'
03map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'

Protocolを継承したときの弱点

ただし、この方法にも一つだけ弱点がある。それはListがeach()を実装していなくても型チェックが通ってしまうこと。例えばこういうコードでも型チェックが通ってしまう。

class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    # def each(self) -> t.Iterator[A]:
    #     return iter(self.xs)

実行すると以下の様なエラーが出る。そしてこれはちょっと分かりづらいかもしれない。

$ python 03map.py
Traceback (most recent call last):
  File "03map.py", line 29, in <module>
    result = L.map(lambda x: x * x)
  File "03map.py", line 15, in map
    return [fn(x) for x in self.each()]
TypeError: 'NoneType' object is not iterable

というわけで以下のようにprotocolを定義した方が良いかもしれない。

class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        raise NotImplementedError("each")

補足

(ちなみに先程のselfだけにProtocolを指定したコードは、mixin-in自体はProtocolを継承していないのでeach()の実装にたどり着けない。なのでmap()を利用した時点でエラーが出る)。

02map.py:29: error: Invalid self argument "List[int]" to attribute function "map" with type "

gist

いつものgist

最近はprestringにcodeobjectというものを組み込もうとしている

https://github.com/podhmo/prestring:cite:emebed

(この機能は妄想の類でまだ未完成なものなので注意)

最近はprestringにcodeobjectというものを組み込もうとしている。実はこのcodeobjectというものの雛形は既に作って使われていたりする。例えばhandofcatsmonogusaの中で。

なぜcodeobjectが欲しくなったのか?

なぜcodeobjectが欲しくなったのかを頭の中の整理も兼ねて文章にしてみることにする。prestringの出力するコード上の名前の管理が難しいと感じたのが発端だった。

prestringは文字列を生成するコードなのだけれど。通常のpythonオブジェクトを触るのと同じ感覚で取り扱おうとしたときに困った点が幾つか出てきた。これはwith構文を使っていることによる欠点かもしれない。これを解消する術がないかと言うようなことを考えていた。

定義はできるが利用ができない

通常プログラミングでは定義とその利用が存在する。一方でprestring上のコードでは定義を気軽に書くことはできてもその利用を記述することが難しい。

定義はできる

例えばfoo()という文字列を返す関数をprestringで記述してみることにする。

from prestring.python import Module

m = Module()
with m.def_("foo"):
    m.return_("'foo'")
print(m)

これは何の変哲もない以下の様なコードを出力する。

def foo():
    return 'foo'

このコードに引数を追加して関数でwrapしたりなどすることで、このfooの定義部分にバリエーションを持たせる事はできる。例えば以下の様な形で。関数名を変えてみたり関数の内部の挙動を一部変更したりすることを受け入れる事はできる。定義部分を可変にするのは気軽な行い。

def emit_foo(m: Module, name: str, *, sep: str):
    with m.def_(name, "message: str"):
        m.return_(f"f'foo{sep}{{message}}'")
    return m


m = Module()
m = emit_foo(m, "do_foo", sep=":")

# def do_foo(message: str):
#     return f'foo:{message}'

利用ができない (やり辛い)

一方でこの関数を利用しようとした場合にはどうか?

"do_foo"というemit_foo()関数に渡した文字列に頼ることになる。あるいは1つ前の例であれば、foo()という名前の関数が出力されるであろうことを知っている必要があった。つまり利用する際にはこの関数の名前などを明確に認識する必要があった。

直近の例で言えばこの定義した関数であるdo_foo()を利用する側のコードは以下の様になる。

with m.for_("i", "range(5)"):
    m.stmt("do_foo(str(i))")

# for i in range(5):
#    do_foo(str(i))

せめて、m.stmt(do_foo(i))のような形で扱えないものだろうか?

利用のしづらさは変数の名前にも

同様のことは出力されるであろうコード中の変数間でも起きる。例えばテキトーにRMSEを計算するような関数を書いてみる。これにしたってaccという変数名への参照くらいは文字列としてのそれを意識せず使いたい。

with m.def_('rmse', 'xs', 'ys'):
    m.stmt('acc = 0')
    m.stmt('assert len(xs) == len(ys)')
    with m.for_('x, y', 'zip(xs, ys)'):
        m.stmt('acc += (x - y) ** 2')
    m.return_('math.sqrt(acc / len(xs))')

# def rmse(xs, ys):
#     acc = 0
#     assert len(xs) == len(ys)
#     for x, y in zip(xs, ys):
#         acc += (x - y) ** 2
#     return math.sqrt(acc / len(xs))

既存のモジュールをimportするときにも

同様のことは既存のモジュールをimportするときにも起こる。例えば現状のprestringでは特定のモジュールをimportする文とimportされたものを使う文には何の結びつきもなかった。

例えば、ちょっとした標準入力をparseする以下の様なコードがあるとする。

import re
import sys


pattern = re.compile('^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\\s*(?P<message>\\S+)', re.IGNORECASE)
for line in sys.stdin:
    m = pattern.search(line)
    if m is not None:
        print(m.groupdict())

このコードのimport部分などを含めて既存のprestringの場合には以下の様に書く必要がある。importされるreモジュールやsysモジュールの文とそれを利用する文には繋がりがない。従って存在するであろう名前のリストを頭の中に入れた上で書く必要があった。

from prestring.python import Module

m = Module()

m.import_("re")
m.import_("sys")
m.sep()
m.stmt(
    "pattern = re.compile({!r}, re.IGNORECASE)",
    r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
)

with m.for_("line", "sys.stdin"):
    m.stmt("m = pattern.search(line)")
    with m.if_("m is not None"):
        m.stmt("print(m.groupdict())")
print(m)

きびしい。

総括

まとめると以下のような形。

  • prestringは定義は記述できるが利用が記述し辛い

    • 利用で煩雑になるのは名前の管理 (symbolの管理)
    • 利用とは

      • 同一モジュール内で定義したものへの参照 (e.g. 定義した関数の利用)
      • 同一スコープ内で定義したものへの参照 (e.g. 関数内での変数の利用)
      • 別のモジュールからimportしてきたものへの参照 (e.g. 既存のモジュールからimportしてきた関数の利用)

codeobject

そういうわけで利用のしやすさを気にした試みをcodeobjectと呼ぶことにしている。codeがオブジェクトになっているので実行しておしまいではなく後から参照(利用)できるという形になれば良い。

importでSymbolオブジェクトが返る

(これは既に現在のprestringに組み込まれている(pythonのみ))

例えばimportの例を考えてみる。今までは何も返さなかったm.import_()がSymbolオブジェクトを返すという形にしたらどうだろう?

ここでSymbolオブジェクトとは文字列化したときにそれそのものの表現を返すもののこと。そこからgetattrしたり関数感覚で呼び出した場合にもそれそのものの表現を返す。

foo = Symbol("foo")

print(str(foo)) # -> foo
print(str(foo.bar("name", 1))) # -> foo.bar("name", 1)

これがimportで返ってくると先ほどのimportされるコードの例は以下の様に書ける。

from prestring.python import Module

m = Module()

re = m.import_("re")
sys = m.import_("sys")

m.sep()
m.stmt(
    "pattern = {}",
    re.compile(
        r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
        re.IGNORECASE,
    ),
)

with m.for_("line", sys.stdin):
    m.stmt("m = pattern.search(line)")
    with m.if_("m is not None"):
        m.stmt("print(m.groupdict())")
print(m)

違いがわかりにくいかもしれないのでdiffを貼っておく。m.import_()がsymbolオブジェクトを返しそれを通常のpythonオブジェクトと同じ感覚で利用できるようになった。

--- 03use-import.py  2020-02-23 13:37:24.000000000 +0900
+++ 04use-import-symbol.py    2020-02-23 13:36:11.000000000 +0900
@@ -2,15 +2,19 @@
 
 m = Module()
 
-m.import_("re")
-m.import_("sys")
+re = m.import_("re")
+sys = m.import_("sys")
+
 m.sep()
 m.stmt(
-    "pattern = re.compile({!r}, re.IGNORECASE)",
-    r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
+    "pattern = {}",
+    re.compile(
+        r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
+        re.IGNORECASE,
+    ),
 )
 
-with m.for_("line", "sys.stdin"):
+with m.for_("line", sys.stdin):
     m.stmt("m = pattern.search(line)")
     with m.if_("m is not None"):
         m.stmt("print(m.groupdict())")

これは一歩前進といえるかもしれない(おそらくこの感覚に共感があるのは自分だけだろうけれど。presetringのユーザーはたぶん自分ひとりなので)。

一方でこのコードで浮かび上がってくるのは変数の参照部分(利用部分)の煩雑さ。

Symbolオブジェクトが返る代入文

変数間の参照も同様に手軽に行いたい。Symbolオブジェクトを手にする事ができれば同様の成功体験を変数間の参照に対しても得られる様になる。はず。また、print()などの組み込みのsymbolに対しても同様に扱いたい。

ここで let() というメソッドを用意することにした。そして既存のprestringのモジュールのメソッドとして組み込むのは憚られるということで別途CodeObjectModuleというオブジェクトから参照することにしてみる(この辺りのAPI設計はまだ固まっていない)。

co = CodeobjectModule(m)

co.import_("re")  # m.import_()でも良い
pattern = co.let("pattern", re.compile(".*"))

print(pattern) # -> pattern

print(m)
# import re
# pattern = re.compile(".*")

また、直接Symbolオブジェクトを返すsymbol()というメソッドも用意している。これらを使うと以下の様に書ける様になる。

from prestring.python import Module
from prestring.codeobject import CodeObjectModule

m = Module()
co = CodeObjectModule(m)

re = co.import_("re")
sys = co.import_("sys")

m.sep()
pattern = co.let(
    "pattern",
    re.compile(
        r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
        re.IGNORECASE,
    ),
)


with m.for_("line", sys.stdin):
    matched = co.let("matched", pattern.search(co.symbol("line")))
    with m.if_(f"{matched} is not None"):
        print_ = co.symbol("print")
        m.stmt(print_(matched.groupdict()))
print(m)

こちらも先程のimportだけからなるコードに対するdiffを貼っておく。

--- 04use-import-symbol.py   2020-02-23 13:36:11.000000000 +0900
+++ 06use-import-code-object.py   2020-02-23 13:49:22.000000000 +0900
@@ -1,21 +1,25 @@
 from prestring.python import Module
+from prestring.codeobject import CodeObjectModule
 
 m = Module()
+co = CodeObjectModule(m)
 
-re = m.import_("re")
-sys = m.import_("sys")
+re = co.import_("re")
+sys = co.import_("sys")
 
 m.sep()
-m.stmt(
-    "pattern = {}",
+pattern = co.let(
+    "pattern",
     re.compile(
         r"^(?P<label>DEBUG|INFO|WARNING|ERROR|CRITICAL):\s*(?P<message>\S+)",
         re.IGNORECASE,
     ),
 )
 
+
 with m.for_("line", sys.stdin):
-    m.stmt("m = pattern.search(line)")
-    with m.if_("m is not None"):
-        m.stmt("print(m.groupdict())")
+    matched = co.let("matched", pattern.search(co.symbol("line")))
+    with m.if_(f"{matched} is not None"):
+        print_ = co.symbol("print")
+        m.stmt(print_(matched.groupdict()))
 print(m)

(ちなみにインターフェイスがまだ固まりきってはいないが、CodeObjectModuleMixinというMixinクラスも用意していて、これを継承したprestring.python.PythonModuleを作ってあげれば、m.let()などとすべてをmだけで書けるようにはなる)

定義した関数への参照

(これはまだ未実装)

例えばこのSymbolオブジェクトを返す試みをm.def_()などで定義を書いたときに返してみてはどうだろう? 幸い(?)pythonのwith構文はスコープを持たないのでas開いた参照はずっと利用できる。

with m.def_("hello", "name") as hello:
    m.stmt("print(f'hello: {name}')")

# use hello
m.stmt(hello("foo"))  # -> hello("foo")
m.stmt(hello("bar"))  # -> hello("bar")

あるいは記述したコードの実行を遅らせつつ参照だけをしたいという場合にはcodeobjectというデコレータを付加して扱うということも考えられる(実はmonogusaで一部これを利用している)。

@codeobject
def hello(m: Module, name: str) -> Module:
    with m.def_(name, "name") as hello:
        m.stmt("print(f'hello: {name}')")
    return m

何かしらの操作(call,getattr)を行うとsymbolオブジェクトとして振る舞う。利用ができる。

m.stmt(hello("foo"))  # -> hello("foo")

単体でstmt()を利用すると定義が出力される。

m.stmt(hello)

# def hello(name):
#     print(f'hello {name}')

goへの利用

ここからはおまけでこれらの機能をgoでも扱えないかと言うようなことを考えたりしていた。

例えばこんな感じの記述ができるようになる。

from prestring.codeobject import CodeObjectModule
from go import Module


m = Module()
co = CodeObjectModule(m, assign_op=":=")

conf = co.import_("github.com/podhmo/appkit/conf")
foo = co.import_("github.com/podhmo/appkit/foo")

with m.func("run", return_="error"):
    filename = co.let("filename", "*config")
    use = co.symbol("use")

    c, err = co.letN(("c", "err"), conf.LoadConfig(filename))
    with m.if_(f"{err} != nil"):
        m.return_(err)

    fooOb = co.let("fooOb", foo.FromConfig(c))
    m.return_(use(fooOb))
print(m)

以前よりかは文字列味が減ったように感じる(ダブルクォートがうるさくさない)。 ちなみに以下の様なコードを出力する。

func run() error {
    filename := *config
    c, err := conf.LoadConfig(filename)
    if err != nil  {
        return err
    }
    fooOb := foo.FromConfig(c)
    return use(fooOb)
}

ただし幾つか課題がある

  • :== の使い分けをどうするか?
  • ポインターの扱いをどうするか?
  • goでは同一パッケージで異なるファイルがありうる (pythonは1ファイル1モジュール)
  • (func(){\n ... \n}() という構文をどう扱うか? (with構文だと厳しい))

そして仮にpythonからgoを生成するならば以下の様な状況になってほしい。

  • python側のモジュールの階層構造をそのままgoに転写したい

    • うまくオブジェクトの __module__ を使い回す事はできないだろうか?

gist

いつもの