functools.partialの元を辿る

時折、functools.partialで作られた値の元の値(関数)を辿りたくなることがある。そのようなことをしたくなった場合のメモ。

単純に辿りたい場合

単純に辿りたい場合は.funcを見れば良い。ついでにクラスである場合もサポートしてあげると親切。

def find_original(fn):
    if isinstance(fn, partial):
        fn = fn.func
    if inspect.isclass(fn):
        fn = fn.__init__
    return fn

functools.partialの元を辿ることができる。

def f(x, y, *, z):
    return (x, y, z)

g = partial(partial(f, 10), z=20)
assert f == find_original(g)

あるいはクラスでも。

class Ob:
    def __init__(self, name, age=0):
        self.name = name
        self.age = age

make = partial(partial(Ob, "foo"), age=20)
assert Ob.__init__ == find_original(make)

nestしている場合

nestしている形も考えてwhileにする必要があるかと思うかもしれない。ifで大丈夫。partialが行われている時点で引数などはまとめられてしまうので。

assert f == partial(partial(f, 10, z=30), 20).func

test

テストはこんな感じに書けば良い(テスト書くの大事)。

import unittest


class Tests(unittest.TestCase):
    def _callFut(self, *args, **kwargs):
        return find_original(*args, **kwargs)

    def test_it(self):
        def f(x, y, *, z):
            return (x, y, z)

        f0 = partial(f, 10)
        f1 = partial(f, z=10)
        g = partial(partial(f, 10), z=20)

        class Ob:
            def __init__(self, name, age=0):
                self.name = name
                self.age = age

        x0 = partial(Ob, "foo")
        x1 = partial(Ob, age=10)
        y = partial(partial(Ob, "foo"), age=20)

        candidates = [
            (f, f),
            (f, f0),
            (f, f1),
            (f, g),
            (Ob.__init__, Ob),
            (Ob.__init__, x0),
            (Ob.__init__, x1),
            (Ob.__init__, y),
        ]
        for expected, target in candidates:
            with self.subTest(target=target):
                got = self._callFut(target)
                self.assertEqual(got, expected)

デフォルト値も一緒に取る

args,keywordsに入っている。

def find_original_with_arguments(fn):
    args = ()
    kwargs = {}
    if isinstance(fn, partial):
        args = fn.args
        kwargs = fn.keywords
        fn = fn.func
    if inspect.isclass(fn):
        fn = fn.__init__
    return fn, args, kwargs

型情報が欲しい場合

inspect.getfullargspecするとannotationsという属性に入っている。

import typing as t
import inspect


def f(
    x: int,
    y: int,
    *,
    z: int,
    i: t.Optional[int] = None,
    j: t.Optional[int] = None,
):
    return (x, y, z, i, j)


spec = inspect.getfullargspec(f)
print(spec.annotations)
# {'x': <class 'int'>, 'y': <class 'int'>, 'z': <class 'int'>, 'i': typing.Union[int, NoneType], 'j': typing.Union[int, NoneType]}

argparseで--limitと--no-limitのような相互排他なオプションの定義の仕方のメモ

ドキュメントを見れば分かることだけれど、たまに忘れるので。

import argparse


def parse(argv=None):
    parser = argparse.ArgumentParser()
    limit_group = parser.add_mutually_exclusive_group()
    limit_group.add_argument("--limit", dest="limit", type=int)
    limit_group.add_argument("--no-limit", dest="limit", action="store_const", conost=None)
    parser.set_defaults(limit=100)
    return parser.parse_args(argv)

結果

print(parse(["--limit", "100"]))  # Namespace(limit=100)
print(parse(["--limit", "50"]))  # Namespace(limit=50)
print(parse(["--no-limit"]))  # Namespace(limit=None)
print(parse(["--no-limit", "--limit", "50"]))  # error: argument --limit: not allowed with argument --no-limit

詳細

parse_args後の値を決めるのはdest

オプション名と異なる名前で値を挿入したい場合にはdestをつける

まじめに相互排他にするならadd_mutually_exclusive_group()を使う

なくても動くけれど、2つのオプションを同時に使おうとしたときにどのような挙動になるのか不安になるので。

default値はset_defaults()をつけるとより安全

実際のところは、add_argument()にdefaultオプションを付けても良いけれど。複数の値を見る可能性がある場合(今回の場合は--limit--no-limitがlimitを見る)には、set_defaults()が無難。 ちょっとデフォルト値とオプションの定義が離れたりで見にくくなったりするかもだけれど。

参考