functools.partialの元を辿る

時折、functools.partialで作られた値の元の値(関数)を辿りたくなることがある。そのようなことをしたくなった場合のメモ。

単純に辿りたい場合

単純に辿りたい場合は.funcを見れば良い。ついでにクラスである場合もサポートしてあげると親切。

def find_original(fn):
    if isinstance(fn, partial):
        fn = fn.func
    if inspect.isclass(fn):
        fn = fn.__init__
    return fn

functools.partialの元を辿ることができる。

def f(x, y, *, z):
    return (x, y, z)

g = partial(partial(f, 10), z=20)
assert f == find_original(g)

あるいはクラスでも。

class Ob:
    def __init__(self, name, age=0):
        self.name = name
        self.age = age

make = partial(partial(Ob, "foo"), age=20)
assert Ob.__init__ == find_original(make)

nestしている場合

nestしている形も考えてwhileにする必要があるかと思うかもしれない。ifで大丈夫。partialが行われている時点で引数などはまとめられてしまうので。

assert f == partial(partial(f, 10, z=30), 20).func

test

テストはこんな感じに書けば良い(テスト書くの大事)。

import unittest


class Tests(unittest.TestCase):
    def _callFut(self, *args, **kwargs):
        return find_original(*args, **kwargs)

    def test_it(self):
        def f(x, y, *, z):
            return (x, y, z)

        f0 = partial(f, 10)
        f1 = partial(f, z=10)
        g = partial(partial(f, 10), z=20)

        class Ob:
            def __init__(self, name, age=0):
                self.name = name
                self.age = age

        x0 = partial(Ob, "foo")
        x1 = partial(Ob, age=10)
        y = partial(partial(Ob, "foo"), age=20)

        candidates = [
            (f, f),
            (f, f0),
            (f, f1),
            (f, g),
            (Ob.__init__, Ob),
            (Ob.__init__, x0),
            (Ob.__init__, x1),
            (Ob.__init__, y),
        ]
        for expected, target in candidates:
            with self.subTest(target=target):
                got = self._callFut(target)
                self.assertEqual(got, expected)

デフォルト値も一緒に取る

args,keywordsに入っている。

def find_original_with_arguments(fn):
    args = ()
    kwargs = {}
    if isinstance(fn, partial):
        args = fn.args
        kwargs = fn.keywords
        fn = fn.func
    if inspect.isclass(fn):
        fn = fn.__init__
    return fn, args, kwargs

型情報が欲しい場合

inspect.getfullargspecするとannotationsという属性に入っている。

import typing as t
import inspect


def f(
    x: int,
    y: int,
    *,
    z: int,
    i: t.Optional[int] = None,
    j: t.Optional[int] = None,
):
    return (x, y, z, i, j)


spec = inspect.getfullargspec(f)
print(spec.annotations)
# {'x': <class 'int'>, 'y': <class 'int'>, 'z': <class 'int'>, 'i': typing.Union[int, NoneType], 'j': typing.Union[int, NoneType]}