mypyでProtocolを使ってmix-inを利用したクラスに型をつける
たまには、他の人の役に立つ記事も書こうということで書いてみる。
例えば、以下のようなmix-inを使ったコードがあるとする。トリビアルな例で特に必要になりそうなコードではないけれど、まぁ説明のためのコードなので許してほしい。
EnumerableMixinはmap()
を提供していて、このmap()
や他のメソッド(定義されていないけれど)は、each()
に依存するというmix-in。そしてList
はeach()
を実装するクラス。何の変哲もないmix-inのコード。
class EnumerableMixin: def map(self, fn): return [fn(x) for x in self.each()] class List(EnumerableMixin): def __init__(self, xs): self.xs = xs def each(self): return iter(self.xs)
実行結果はこう。
L = List([10, 20, 30]) print(L.map(lambda x: x * x)) # [100, 400, 900]
このようなコードに型を付けたい。
Genericを使って定義してみる
雰囲気でmypyで型を付けてみると以下の様な感じになるのではないか?ただこうしてしまうとeachが困る。
import typing as t A = t.TypeVar("A", covariant=True) B = t.TypeVar("B") class EnumerableMixin(t.Generic[A]): def map(self, fn: t.Callable[[A], B]) -> t.List[B]: return [fn(x) for x in self.each()] class List(EnumerableMixin[A]): def __init__(self, xs: t.List[A]) -> None: self.xs = xs def each(self) -> t.Iterator[A]: return iter(self.xs)
存在していないメソッドに依存してしまっているわけなので、当然といえば当然。
$ mypy --strict 01map.py 01map.py:9: error: "EnumerableMixin[A]" has no attribute "each"
このエラーをどうやって潰そうか?というのが今回の主題。
Protocolを使う
duck typingで詰まったらProtocolというのがmypyをいじってきて感じる経験知かもしれない。今回もProtocolを使うことにする(python3.8からはtyping_extensionsは不要)。
import typing as t import typing_extensions as tx A = t.TypeVar("A", covariant=True) B = t.TypeVar("B") class HasEach(tx.Protocol[A]): def each(self) -> t.Iterator[A]: ... class EnumerableMixin(t.Generic[A]): def map(self: HasEach[A], fn: t.Callable[[A], B]) -> t.List[B]: return [fn(x) for x in self.each()] class List(EnumerableMixin[A]): def __init__(self, xs: t.List[A]) -> None: self.xs = xs def each(self) -> t.Iterator[A]: return iter(self.xs) L = List([10, 20, 30]) if t.TYPE_CHECKING: reveal_type(L) result = L.map(lambda x: x * x) if t.TYPE_CHECKING: reveal_type(result) print(result)
実行結果。問題なさそう。HasEach
というProtocolを定義したのが肝。これをmapの定義時のselfの型として利用する。
$ mypy --strict 02map.py 02map.py:28: note: Revealed type is '02map.List[builtins.int*]' 02map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'
selfにProtocolを与えた場合の制限
ただしこの方法にも1つ制限があって、自分自身の持つメソッドや属性にアクセスできなくなる。
import typing_extensions as tx class P(tx.Protocol): def foo(self) -> str: ... class M: def bar(self: P) -> None: print(self.foo(), self.foo()) # error: "P" has no attribute "boo" print(self.boo()) def boo(self) -> str: return "boo"
自分自身がProtocolを継承する
先程の部分的にself
にProtocolを指定する方法は悪くはないのだけれど。制限がある。それを解消するために自分自身がProtocolとなっても良い。以下の例ではEnumerableMixin
自身がHasEach
を継承している。
class HasEach(tx.Protocol[A]): def each(self) -> t.Iterator[A]: ... class EnumerableMixin(HasEach[A]): def map(self, fn: t.Callable[[A], B]) -> t.List[B]: return [fn(x) for x in self.each()] class List(EnumerableMixin[A]): def __init__(self, xs: t.List[A]) -> None: self.xs = xs def each(self) -> t.Iterator[A]: return iter(self.xs) L = List([10, 20, 30]) if t.TYPE_CHECKING: reveal_type(L) result = L.map(lambda x: x * x) if t.TYPE_CHECKING: reveal_type(result) print(result)
このようにしても型チェックは通る。
$ mypy --strict 03map.py 03map.py:28: note: Revealed type is '03map.List[builtins.int*]' 03map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'
Protocolを継承したときの弱点
ただし、この方法にも一つだけ弱点がある。それはListがeach()
を実装していなくても型チェックが通ってしまうこと。例えばこういうコードでも型チェックが通ってしまう。
class List(EnumerableMixin[A]): def __init__(self, xs: t.List[A]) -> None: self.xs = xs # def each(self) -> t.Iterator[A]: # return iter(self.xs)
実行すると以下の様なエラーが出る。そしてこれはちょっと分かりづらいかもしれない。
$ python 03map.py Traceback (most recent call last): File "03map.py", line 29, in <module> result = L.map(lambda x: x * x) File "03map.py", line 15, in map return [fn(x) for x in self.each()] TypeError: 'NoneType' object is not iterable
というわけで以下のようにprotocolを定義した方が良いかもしれない。
class HasEach(tx.Protocol[A]): def each(self) -> t.Iterator[A]: raise NotImplementedError("each")
補足
(ちなみに先程のselfだけにProtocolを指定したコードは、mixin-in自体はProtocolを継承していないのでeach()
の実装にたどり着けない。なのでmap()
を利用した時点でエラーが出る)。
02map.py:29: error: Invalid self argument "List[int]" to attribute function "map" with type "
gist
いつものgist