functools.partialの元を辿る
時折、functools.partialで作られた値の元の値(関数)を辿りたくなることがある。そのようなことをしたくなった場合のメモ。
単純に辿りたい場合
単純に辿りたい場合は.funcを見れば良い。ついでにクラスである場合もサポートしてあげると親切。
def find_original(fn): if isinstance(fn, partial): fn = fn.func if inspect.isclass(fn): fn = fn.__init__ return fn
functools.partialの元を辿ることができる。
def f(x, y, *, z): return (x, y, z) g = partial(partial(f, 10), z=20) assert f == find_original(g)
あるいはクラスでも。
class Ob: def __init__(self, name, age=0): self.name = name self.age = age make = partial(partial(Ob, "foo"), age=20) assert Ob.__init__ == find_original(make)
nestしている場合
nestしている形も考えてwhileにする必要があるかと思うかもしれない。ifで大丈夫。partialが行われている時点で引数などはまとめられてしまうので。
assert f == partial(partial(f, 10, z=30), 20).func
test
テストはこんな感じに書けば良い(テスト書くの大事)。
import unittest class Tests(unittest.TestCase): def _callFut(self, *args, **kwargs): return find_original(*args, **kwargs) def test_it(self): def f(x, y, *, z): return (x, y, z) f0 = partial(f, 10) f1 = partial(f, z=10) g = partial(partial(f, 10), z=20) class Ob: def __init__(self, name, age=0): self.name = name self.age = age x0 = partial(Ob, "foo") x1 = partial(Ob, age=10) y = partial(partial(Ob, "foo"), age=20) candidates = [ (f, f), (f, f0), (f, f1), (f, g), (Ob.__init__, Ob), (Ob.__init__, x0), (Ob.__init__, x1), (Ob.__init__, y), ] for expected, target in candidates: with self.subTest(target=target): got = self._callFut(target) self.assertEqual(got, expected)
デフォルト値も一緒に取る
args,keywordsに入っている。
def find_original_with_arguments(fn): args = () kwargs = {} if isinstance(fn, partial): args = fn.args kwargs = fn.keywords fn = fn.func if inspect.isclass(fn): fn = fn.__init__ return fn, args, kwargs
型情報が欲しい場合
inspect.getfullargspecするとannotationsという属性に入っている。
import typing as t import inspect def f( x: int, y: int, *, z: int, i: t.Optional[int] = None, j: t.Optional[int] = None, ): return (x, y, z, i, j) spec = inspect.getfullargspec(f) print(spec.annotations) # {'x': <class 'int'>, 'y': <class 'int'>, 'z': <class 'int'>, 'i': typing.Union[int, NoneType], 'j': typing.Union[int, NoneType]}