functools.partialの元を辿る
時折、functools.partialで作られた値の元の値(関数)を辿りたくなることがある。そのようなことをしたくなった場合のメモ。
単純に辿りたい場合
単純に辿りたい場合は.funcを見れば良い。ついでにクラスである場合もサポートしてあげると親切。
def find_original(fn): if isinstance(fn, partial): fn = fn.func if inspect.isclass(fn): fn = fn.__init__ return fn
functools.partialの元を辿ることができる。
def f(x, y, *, z): return (x, y, z) g = partial(partial(f, 10), z=20) assert f == find_original(g)
あるいはクラスでも。
class Ob: def __init__(self, name, age=0): self.name = name self.age = age make = partial(partial(Ob, "foo"), age=20) assert Ob.__init__ == find_original(make)
nestしている場合
nestしている形も考えてwhileにする必要があるかと思うかもしれない。ifで大丈夫。partialが行われている時点で引数などはまとめられてしまうので。
assert f == partial(partial(f, 10, z=30), 20).func
test
テストはこんな感じに書けば良い(テスト書くの大事)。
import unittest class Tests(unittest.TestCase): def _callFut(self, *args, **kwargs): return find_original(*args, **kwargs) def test_it(self): def f(x, y, *, z): return (x, y, z) f0 = partial(f, 10) f1 = partial(f, z=10) g = partial(partial(f, 10), z=20) class Ob: def __init__(self, name, age=0): self.name = name self.age = age x0 = partial(Ob, "foo") x1 = partial(Ob, age=10) y = partial(partial(Ob, "foo"), age=20) candidates = [ (f, f), (f, f0), (f, f1), (f, g), (Ob.__init__, Ob), (Ob.__init__, x0), (Ob.__init__, x1), (Ob.__init__, y), ] for expected, target in candidates: with self.subTest(target=target): got = self._callFut(target) self.assertEqual(got, expected)
デフォルト値も一緒に取る
args,keywordsに入っている。
def find_original_with_arguments(fn): args = () kwargs = {} if isinstance(fn, partial): args = fn.args kwargs = fn.keywords fn = fn.func if inspect.isclass(fn): fn = fn.__init__ return fn, args, kwargs
型情報が欲しい場合
inspect.getfullargspecするとannotationsという属性に入っている。
import typing as t import inspect def f( x: int, y: int, *, z: int, i: t.Optional[int] = None, j: t.Optional[int] = None, ): return (x, y, z, i, j) spec = inspect.getfullargspec(f) print(spec.annotations) # {'x': <class 'int'>, 'y': <class 'int'>, 'z': <class 'int'>, 'i': typing.Union[int, NoneType], 'j': typing.Union[int, NoneType]}
argparseで--limitと--no-limitのような相互排他なオプションの定義の仕方のメモ
ドキュメントを見れば分かることだけれど、たまに忘れるので。
import argparse def parse(argv=None): parser = argparse.ArgumentParser() limit_group = parser.add_mutually_exclusive_group() limit_group.add_argument("--limit", dest="limit", type=int) limit_group.add_argument("--no-limit", dest="limit", action="store_const", conost=None) parser.set_defaults(limit=100) return parser.parse_args(argv)
結果
print(parse(["--limit", "100"])) # Namespace(limit=100) print(parse(["--limit", "50"])) # Namespace(limit=50) print(parse(["--no-limit"])) # Namespace(limit=None) print(parse(["--no-limit", "--limit", "50"])) # error: argument --limit: not allowed with argument --no-limit
詳細
parse_args後の値を決めるのはdest
オプション名と異なる名前で値を挿入したい場合にはdestをつける
まじめに相互排他にするならadd_mutually_exclusive_group()を使う
なくても動くけれど、2つのオプションを同時に使おうとしたときにどのような挙動になるのか不安になるので。
default値はset_defaults()をつけるとより安全
実際のところは、add_argument()にdefaultオプションを付けても良いけれど。複数の値を見る可能性がある場合(今回の場合は--limit
と--no-limit
がlimitを見る)には、set_defaults()が無難。
ちょっとデフォルト値とオプションの定義が離れたりで見にくくなったりするかもだけれど。