mypyでProtocolを使ってmix-inを利用したクラスに型をつける

たまには、他の人の役に立つ記事も書こうということで書いてみる。

例えば、以下のようなmix-inを使ったコードがあるとする。トリビアルな例で特に必要になりそうなコードではないけれど、まぁ説明のためのコードなので許してほしい。

EnumerableMixinはmap()を提供していて、このmap()や他のメソッド(定義されていないけれど)は、each()に依存するというmix-in。そしてListeach()を実装するクラス。何の変哲もないmix-inのコード。

class EnumerableMixin:
    def map(self, fn):
        return [fn(x) for x in self.each()]


class List(EnumerableMixin):
    def __init__(self, xs):
        self.xs = xs

    def each(self):
        return iter(self.xs)

実行結果はこう。

L = List([10, 20, 30])
print(L.map(lambda x: x * x))
# [100, 400, 900]

このようなコードに型を付けたい。

Genericを使って定義してみる

雰囲気でmypyで型を付けてみると以下の様な感じになるのではないか?ただこうしてしまうとeachが困る。

import typing as t

A = t.TypeVar("A", covariant=True)
B = t.TypeVar("B")


class EnumerableMixin(t.Generic[A]):
    def map(self, fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)

存在していないメソッドに依存してしまっているわけなので、当然といえば当然。

$ mypy --strict 01map.py
01map.py:9: error: "EnumerableMixin[A]" has no attribute "each"

このエラーをどうやって潰そうか?というのが今回の主題。

Protocolを使う

duck typingで詰まったらProtocolというのがmypyをいじってきて感じる経験知かもしれない。今回もProtocolを使うことにする(python3.8からはtyping_extensionsは不要)。

import typing as t
import typing_extensions as tx

A = t.TypeVar("A", covariant=True)
B = t.TypeVar("B")


class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        ...


class EnumerableMixin(t.Generic[A]):
    def map(self: HasEach[A], fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)


L = List([10, 20, 30])
if t.TYPE_CHECKING:
    reveal_type(L)
result = L.map(lambda x: x * x)
if t.TYPE_CHECKING:
    reveal_type(result)
print(result)

実行結果。問題なさそう。HasEachというProtocolを定義したのが肝。これをmapの定義時のselfの型として利用する。

$ mypy --strict 02map.py
02map.py:28: note: Revealed type is '02map.List[builtins.int*]'
02map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'

selfにProtocolを与えた場合の制限

ただしこの方法にも1つ制限があって、自分自身の持つメソッドや属性にアクセスできなくなる。

import typing_extensions as tx


class P(tx.Protocol):
    def foo(self) -> str:
        ...


class M:
    def bar(self: P) -> None:
        print(self.foo(), self.foo())

        # error: "P" has no attribute "boo"
        print(self.boo())

    def boo(self) -> str:
        return "boo"

自分自身がProtocolを継承する

先程の部分的にselfにProtocolを指定する方法は悪くはないのだけれど。制限がある。それを解消するために自分自身がProtocolとなっても良い。以下の例ではEnumerableMixin自身がHasEachを継承している。

class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        ...


class EnumerableMixin(HasEach[A]):
    def map(self, fn: t.Callable[[A], B]) -> t.List[B]:
        return [fn(x) for x in self.each()]


class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    def each(self) -> t.Iterator[A]:
        return iter(self.xs)


L = List([10, 20, 30])
if t.TYPE_CHECKING:
    reveal_type(L)
result = L.map(lambda x: x * x)
if t.TYPE_CHECKING:
    reveal_type(result)
print(result)

このようにしても型チェックは通る。

$ mypy --strict 03map.py
03map.py:28: note: Revealed type is '03map.List[builtins.int*]'
03map.py:31: note: Revealed type is 'builtins.list[builtins.int*]'

Protocolを継承したときの弱点

ただし、この方法にも一つだけ弱点がある。それはListがeach()を実装していなくても型チェックが通ってしまうこと。例えばこういうコードでも型チェックが通ってしまう。

class List(EnumerableMixin[A]):
    def __init__(self, xs: t.List[A]) -> None:
        self.xs = xs

    # def each(self) -> t.Iterator[A]:
    #     return iter(self.xs)

実行すると以下の様なエラーが出る。そしてこれはちょっと分かりづらいかもしれない。

$ python 03map.py
Traceback (most recent call last):
  File "03map.py", line 29, in <module>
    result = L.map(lambda x: x * x)
  File "03map.py", line 15, in map
    return [fn(x) for x in self.each()]
TypeError: 'NoneType' object is not iterable

というわけで以下のようにprotocolを定義した方が良いかもしれない。

class HasEach(tx.Protocol[A]):
    def each(self) -> t.Iterator[A]:
        raise NotImplementedError("each")

補足

(ちなみに先程のselfだけにProtocolを指定したコードは、mixin-in自体はProtocolを継承していないのでeach()の実装にたどり着けない。なのでmap()を利用した時点でエラーが出る)。

02map.py:29: error: Invalid self argument "List[int]" to attribute function "map" with type "

gist

いつものgist