context managerの`__enter__()`の呼び忘れを防ぎたい

はじめに

context managerの__enter__()の呼び忘れを防ぎたい

# こちらは正しい __enter__()が呼ばれている。
def ok():
    with f():
        do_something()

# こちらはだめ
def ng():
    f()  # 警告を出したい

方針

これはおそらくあんまり方法がなくって。gcに回収されるまでにオブジェクトが__enter__を使ったかカウントする(実質1/0なのでboolで良いけれど)みたいな感じにすることが限界そう。

もちろんチェック用の処理を挿入するということだったら他の方法があるけれど。それを忘れない人はwithを忘れないので。

やったこと

shouldenterというデコレーターを作った。これでwrapしてあげたものを使えば良い。

import contextlib


def f():
    g(10)


def g(x):
    with h() as message:
        print(message)
    h()  # だめ


@shouldenter
@contextlib.contextmanager
def h():
    print("before")
    yield "hai"
    print("after")

hがcontext manager。

ちゃんと怒られる。

qr_2101oSF.py:74: UserWarning: should use as context manager, (__enter__() is not called)

ただ常にstack frameを見る実装になっているのですごい微妙な気がしている。

実装

こんな感じ。wrapperで包んであげてる。

import warnings
import sys
import functools


_debug = True


def setdebug(value):
    global _debug
    _debug = value


def getdebug():
    global _debug
    return _debug


def shouldenter(fn, level=2):
    @functools.wraps(fn)
    def _shouldenter(*args, **kwargs):
        return ShouldEnter(fn(*args, **kwargs), level=level)

    return _shouldenter


class ShouldEnter:
    def __init__(self, internal=None, level=1, debug=False, message_class=UserWarning):
        self.internal = internal
        self.message_class = message_class
        self.used = False

        self.lineno = 0
        self.filename = None

        if debug or getdebug():
            # get context information for warning message
            f = sys._getframe(level)
            self.lineno = f.f_lineno
            self.filename = f.f_code.co_filename

    def __enter__(self):
        self.used = True
        if hasattr(self.internal, "__enter__"):
            return self.internal.__enter__()
        else:
            return self.internal

    def __exit__(self, exc_type, exc_value, traceback):
        if hasattr(self.internal, "__exit__"):
            return self.internal.__exit__(exc_type, exc_value, traceback)
        else:
            return None

    def __del__(self):
        if not self.used:
            self.warn()

    def warn(self):
        msg = "should use as context manager, (__enter__() is not called)"
        if self.filename is None:
            warnings.warn(msg, self.message_class)
        else:
            warnings.warn_explicit(msg, self.message_class, self.filename, self.lineno)

メソッドを置き換えたmockをもう少しstrictにしてみたい

メソッドを置き換えたmockをもう少しstrictにしてみたい。mockのpatchなどでobjectを置き換える時に属性の存在まではspecやspec_setで対応できるのだけれど。メソッドのsignatureまで含めて置き換え前のものと同じかどうか確認したい。

例えば、存在しない属性へのアクセスはエラーになる(これは期待通り)

class Ob:
    def hello(self, name):
        return "hello:{}".format(name)


class Tests(unittest.TestCase):
    def test_attr_missing(self):
        # 属性無しはOK
        m = mock.Mock(spec_set=Ob)
        with self.assertRaises(AttributeError):
            m.bye()

一方でこれはダメ、元のメソッドから見たら引数不足なものの、mockはそれを関知しない(これは期待通りではない)

class Tests(unittest.TestCase):
    def test_mismatch_signature(self):
        m = mock.Mock(spec_set=Ob)
        m.hello.side_effect = lambda: "*replaced*"

        # Ob.hello()から見たら不正な呼び出しなのだけれど。置き換えたmockとは合っているので通ってしまう
        got = m.hello()
        self.assertEqual(got, "*replaced*")

本来であればTypeErrorなどが発生してほしい。

Ob().hello()
# TypeError: hello() missing 1 required positional argument: 'name'

現状のwork-around

雑に replace_method_with_signature_check() という名前の関数を定義している。

これを使うと以下の様にAssertionErrorが出るようになる。

class Tests(unittest.TestCase):
    def test_mismatch_signature(self):
        m = mock.Mock(spec_set=Ob)

        # Ob.hello()に対して引数が不足した定義
        def hello():
            return "*replaced*"

        replace_method_with_signature_check(m, hello)

        got = m.hello()
        self.assertEqual(got, "*replaced*")

# AssertionError: expected hello()'s signature: (name), but ()

ちゃんとsignatureを考慮して見てくれる。もちろん、まともなsignatureの合った定義に書き換えたら呼び出し側の引数の不一致がわかりTypeErrorになる。

一応mockじゃないものに利用してしまった場合の事も考慮して type(m)m.__class__ を比較している(mockとmock以外を見分けるイディオム)。

実装

実装は以下の様な感じ。

import inspect

def replace_method_with_signature_check(m, fn, name=None):
    """mock中のmethodをsignatureを考慮して書き換えるもの"""
    spec = m.__class__
    typ = type(m)
    name = name or fn.__name__

    assert typ != spec, "{} == {}, maybe spec is not set?".format(typ, spec)

    sig_repr = str(inspect.signature(getattr(spec, name)))
    sig_repr = sig_repr.replace('(self, ', '(')  # xxx work-around
    fn_sig_repr = str(inspect.signature(fn))
    assert sig_repr == fn_sig_repr, "expected {}()'s signature: {}, but {}".format(name, sig_repr, fn_sig_repr)
    attr = getattr(m, name)
    attr.side_effect = fn

微妙な点も残っていて、メソッドの置き換えを考慮するのに、self部分をカットする部分がすごく雑。これは isnpect.signature() で取れる値の引数部分が変更不可能なせいでもあるのだけれど。本当に真面目に頑張るのならinspect.getfullargspec()の方を利用した方が良いかもしれない。

置き換えをオブジェクトで

もうちょっと不格好じゃない形で置き換えをしたい場合にはオブジェクトにしたほうが良いのかもしれない。

class MethodReplacer:
    def __init__(self, m):
        self.m = m

    def __getattr__(self, name):
        return partial(replace_method_with_signature_check, self.m, name=name)


m = mock.Mock(spec_set=Ob)
rep = MethodReplacer(m)
rep.hello(lambda name: "*replaced*")