pythonで**kwargsにもう少し細かく型を付けたい

例えば以下の様な関数helloがあるとする。可変長引数を使って定義されている。とてもtrivialな例ではあるけれど説明用なので。。

from typing import Any


def greet(prefix: str, *, name: str) -> None:
    print(f"{prefix}, {name}")


def hello(**params: Any) -> None:
    greet("hello", **params)

ここでhelloを呼び出す際に以下のようにtypoしたとする。これをmypyなどの静的解析で検知したい。

# TypeError: greet() got an unexpected keyword argument 'nam'
hello(nam="foo")

TypedDict

仮に可変長引数の部分全体をDictとして見るような定義だったら上手くいく1。辞書に限って考えれば、typing.TypedDictを利用することで取りうる値の範囲に制限を加えられる。

from typing import TypedDict

class ParamsDict(TypedDict):
    name: str

# error: Extra key 'nam' for TypedDict "ParamsDict"
params : ParamsDict = {"nam": "foo"}

このTypedDictを上手く流用することで可変長引数も同様の形で制限できないか?というようなissueがmypyにも存在した。

まだできるようにはなっていないが要望はあるらしい。Extendというgeneric typeを用意して以下のように書けるようにするという方針らしい。

from typing import TypedDict
from mypy_extensions import Expand

# https://github.com/python/mypy/issues/4441


class ParamsDict(TypedDict):
    name: str


def greet(prefix: str, *, name: str) -> None:
    print(f"{prefix}, {name}")


def hello(**params: Expand[ParamsDict]) -> None:
    greet("hello", **params)


hello(nam="foo")

もちろん、現状では動かない。

work-around

それでは現在利用できる範囲でもう少し正確に型を付けたいときにはどうすれば良いかというと、リンク先のissueでも言及されていたがtyping.overloadを乱用することでごまかせるかもしれない。単に既存の型定義を上書きするだけといえばだけだけれど。定義が複数ないと怒られてしまうのでダミー的な定義も追加しておく。

from typing import overload, Any, TYPE_CHECKING


def greet(prefix: str, *, name: str) -> None:
    print(f"{prefix}, {name}")


@overload
def hello(*, name: str) -> None:
    ...


# suppress "Single overload definition, multiple required"
@overload
def hello(*, _: object = ...) -> None:
    ...


def hello(**params: Any) -> None:
    greet("hello", **params)


hello(name="foo")
hello(nam="x")

一応動く。

$ mypy --strict --pretty 02overload.py
02overload.py:24: error: No overload variant of "hello" matches argument type "str"
    hello(nam="x")
    ^
02overload.py:24: note: Possible overload variants:
02overload.py:24: note:     def hello(*, name: str) -> None
02overload.py:24: note:     def hello(*, _: object = ...) -> None

余談

まぁ、そもそもあんまり可変長引数を乱用するのもどうかと思うし、実装に触れるならこんなまどろっこしい事をせずに素直に全部の引数定義を明示的に書いてしまえば良いと言う話はある。なので、この方法は既存のライブラリに対するstubを作るようなときだけに使いたくなる方法かもしれない。

def greet(prefix: str, *, name: str) -> None:
    print(f"{prefix}, {name}")


def hello(*, name:str) -> None:
    greet("hello", name=name)

あとデコレーター用の関数などにはこの方法は適さない。

gist


  1. まぁご存知の通り、Dictの値部分だけの型を指定する事になっている。 https://mypy.readthedocs.io/en/latest/getting_started.html?highlight=var-args#more-function-signatures

enumとdataclassesを含んだ値をテキトーにJSONとしてseiralize/deserializeしたい

昔に似たようなタイトルの記事を書いていましたが、これとはちょっと違った内容です。

enumやdataclassesを含んだ値をテキトーにJSONとしてserialize/desserializeしたくなった。パフォーマンスは全く気にしなくて良い。 どちらかといえば、schemaの定義なしにJSON側にpythonのオブジェクトの構造に関する情報を持たせたいというニュアンスのほうが強い。 テキストファイルとしてエディタで開いて出力された値を書き換えたりなどしたかったので、pickleではダメだった。

とりあえず、以下の様なオブジェクトが一致することを目指す。

import enum
import dataclasses


class Color(enum.Enum):
    Red = 1
    Green = 2
    Blue = 3
    Yellow = 4


@dataclasses.dataclass
class Person:
    name: str
    initial: str = dataclasses.field(init=False)
    color: Color

    def __post_init__(self):
        self.initial = self.name.upper()[0]


p = Person(name="foo", color=Color.Green)
print(p == loads(dumps(p)))

雑にまとめるとこんな感じ。

  • dumps(),loads()を定義したい
  • init=Falseなフィールドが存在する
  • enumのフィールドが存在する
  • dataclassesを使ったフィールドが存在する(上の例ではフィールドにはdataclassesが使われていないが、実際には使われている)

これに関するserialize/deserializeをでっちあげたかった。

dumps

enum__enum__ というフィールドでwrapして出力することにする。ペアとなる値は定義したEnumの属性名。例えば、Color.Greenは以下の様に出力される。

{"__enum__": "__main__.Color.Green"}

dataclasssesは __dataclass__ というフィールドでwrapして出力することにする。こちらはどのようなデータかという情報も欲しいので、typeとvalueというフィールドを持った辞書にする。どのモジュールに所属しているかの情報もほしいのでフルパスでtypeは出力する

{
  "__dataclass__": {
    "type": "__main__.Person",
    "value": {
      "name": "foo",
      "initial": "F",
      "color": {
        "__enum__": "__main__.Color.Green"
      }
    }
  }
}

init=Falseなフィールドは、pythonコンストラクターを呼び出したときに利用できない。自動で生成される __init__() の引数として扱えない。 dataclasses.fields()を利用して、対象のフィールドを探してpopする。

実装はこんな感じ。手抜きの実装をするときに、キーワード引数を初期値部分にキャッシュ用のdictを入れるのが便利(二度とdeleteはできなくなるけど)。

import enum
import json
import dataclasses


def dumps(ob):
    return json.dumps(ob, indent=2, default=_default)


def _default(ob, *, _cache={}):
    if dataclasses.is_dataclass(ob):
        d = dataclasses.asdict(ob)
        omit_keys = _cache.get(ob.__class__)
        if omit_keys is None:
            omit_keys = _cache[ob.__class__] = [
                f.name for f in dataclasses.fields(ob) if not f.init
            ]
        for k in omit_keys:
            d.pop(k)
        return {
            "__dataclass__": {
                "type": f"{ob.__class__.__module__}.{ob.__class__.__name__}",
                "value": d,
            }
        }
    elif isinstance(ob, enum.Enum):
        return {
            "__enum__": f"{ob.__class__.__module__}.{ob.__class__.__name__}.{ob.name}",
        }
    raise TypeError(f"unexpected type {ob!r}")

loads

dumpsを作ったタイミングでwrapした__enum____dataclass__に対する特別扱いを入れる。JSONのオブジェクト部分のhookとしてobject_pairs_hookが用意されている。ここにdictではなく自作の関数を定義する。あと特筆する事があるとしたら、sys.modulesとimportlib.import_moduleを見て、モジュールをimportすることくらい。

実装は以下の様な感じ。

import sys
import json
from importlib import import_module


def loads(s):
    return json.loads(s, object_pairs_hook=_on_pairs)


def _on_pairs(itr):
    d = {k: v for k, v in itr}
    if "__enum__" in d:
        module, clsname, attr = d["__enum__"].rsplit(".", 3)
        m = sys.modules.get(module)
        if m is None:
            m = import_module(module)
        cls = getattr(m, clsname)
        return getattr(cls, attr)
    elif "__dataclass__" in d:
        v = d["__dataclass__"]
        module, clsname = v["type"].rsplit(".", 2)
        m = sys.modules.get(module)
        if m is None:
            m = import_module(module)
        cls = getattr(m, clsname)
        return cls(**v["value"])
    else:
        return d

実行

実際動く。validationなどは無い。

if __name__ == "__main__":

    class Color(enum.Enum):
        Red = 1
        Green = 2
        Blue = 3
        Yellow = 4

    @dataclasses.dataclass
    class Person:
        name: str
        initial: str = dataclasses.field(init=False)
        color: Color

        def __post_init__(self):
            self.initial = self.name.upper()[0]

    p = Person(name="foo", color=Color.Green)
    print(dumps(p))
    s = """
{
  "__dataclass__": {
    "type": "__main__.Person",
    "value": {
      "name": "foo",
      "color": {
        "__enum__": "__main__.Color.Green"
      }
    }
  }
}
"""
    print(loads(s))
    print(p == loads(dumps(p)))

実行結果

{
  "__dataclass__": {
    "type": "__main__.Person",
    "value": {
      "name": "foo",
      "color": {
        "__enum__": "__main__.Color.Green"
      }
    }
  }
}
Person(name='foo', initial='F', color=<Color.Green: 2>)
True

細々とした話

  • init=Falseにするくらいならpropertyを使ってはどうか? -> 値を更新したくなる。setterを書こうとすると保存する属性をアレコレするのが面倒
  • 本当は更新した値をserializeしたあとにその値を使いたくない? -> 使いたい(頑張ればできるが。。)
  • 結局、init=Falseとか使ってしまうと対称性が壊れる。全部引数として受け入れるようにして、別途クラスメソッドなどを定義したら? -> めんどくさい
  • popダサくない? -> ダサい

(もちろん、frozen=Trueにしてdataclasses.replaceを利用というimmutableっぽい使い方ができるならやっている。何も無いところからはじめるならそれも検討に入れる)

gist

追記

gistだけ修正した。

  • init=Falseに対応
  • dataclasses.asdict()を使うとネストしたdataclassesに対応できない -> 修正