pytestのfixtureとcontextlib.contextmanagerでの例外の取り扱い方の違い

前回の記事でpytestのfixtureでもteardownが実行されることを確実にするにはtry-finallyで囲む必要があるという風に書いてしまっていた。

特にtestでの利用を想定して作ったわけではないけれど。test時のsetup/teardownのことを考えると、途中のコードが失敗した時にも必ずteardown部分が実行されて欲しいということがあるかもしれない。この場合にはtry-finallyで括ってあげるのが無難(たぶんpytestの方もそうなはずだけれど確認していない)

これが嘘という話

pytestでは正しくteardownが実行できている

例えば以下のようなコードをpytestで実行した時に、after部分のprintがちゃんと呼ばれる。

import pytest


@pytest.fixture
def ob():
    ob = object()
    print("")
    print("<<< before", id(ob))
    yield ob
    print(">>> after", id(ob))


def test_it(ob):
    print("**test", id(ob), "**")
    1 / 0

>>> after 4375605680 という出力が行われている。

pytest -s .
================================================= test session starts =================================================
platform darwin -- Python 3.5.2, pytest-3.0.2, py-1.4.31, pluggy-0.3.1
rootdir:$HOME/vboxshare/venvs/my3/sandbox/daily/20170722/example_pytest, inifile:
collected 1 items

02exception/test_it.py
<<< before 4375605680
**test 4375605680 **
F>>> after 4375605680


====================================================== FAILURES =======================================================
_______________________________________________________ test_it _______________________________________________________

ob = <object object at 0x104ce71b0>

    def test_it(ob):
        print("**test", id(ob), "**")
>       1 / 0
E       ZeroDivisionError: division by zero

02exception/test_it.py:15: ZeroDivisionError
============================================== 1 failed in 0.05 seconds ===============================================

contextlibの実装の詳細

contextlib.contextmanagerは内部で_GeneratorContextManagerというクラスが使われているのだけれど。このクラスの __exit__() の定義が期待していたものとは違っていた。以下のような定義(長いし難しいので全部読む必要は無い)。

class _GeneratorContextManager(ContextDecorator):
    def __exit__(self, type, value, traceback):
        if type is None:
            try:
                next(self.gen)
            except StopIteration:
                return
            else:
                raise RuntimeError("generator didn't stop")
        else:
            if value is None:
                # Need to force instantiation so we can reliably
                # tell if we get the same exception back
                value = type()
            try:
                self.gen.throw(type, value, traceback)
                raise RuntimeError("generator didn't stop after throw()")
            except StopIteration as exc:
                # Suppress StopIteration *unless* it's the same exception that
                # was passed to throw().  This prevents a StopIteration
                # raised inside the "with" statement from being suppressed.
                return exc is not value
            except RuntimeError as exc:
                # Likewise, avoid suppressing if a StopIteration exception
                # was passed to throw() and later wrapped into a RuntimeError
                # (see PEP 479).
                if exc.__cause__ is value:
                    return False
                raise
            except:
                # only re-raise if it's *not* the exception that was
                # passed to throw(), because __exit__() must not raise
                # an exception unless __exit__() itself failed.  But throw()
                # has to raise the exception to signal propagation, so this
                # fixes the impedance mismatch between the throw() protocol
                # and the __exit__() protocol.
                #
                if sys.exc_info()[1] is not value:
                    raise

context managerについて

実際のコードの内容に触れる前におさらいが必要かも。

context managerの機能についておさらいしておくと(詳しくはここ)、

  • context managerとして機能するオブジェクトは __enter__()__exit__(exc_type, exc_value, traceback) のmethodを持ったもの
  • 内部で例外が発生しない場合には、__exit__()の引数は全部None
  • 内部で例外が発生した場合に、__exit__() の戻り値がtruthyな値の場合には発生した例外が無視される

実際に読むべき重要なところ

読むのは例外が発生した時なので、最初のifはelse部分、valueも通常は入っているはず。まぁそんなわけで真面目に読むべきはここの部分。

                self.gen.throw(type, value, traceback)
                raise RuntimeError("generator didn't stop after throw()")

ここでself.genは渡された関数の生成するgeneratorが束縛されている(generator関数は呼び出す際に戻り値としてgenerator objectを返す)。 そしてgenerator.throw()は、ほとんど __exit__()と同じ引数をを取り、そのgeneratorの現在の位置情報を利用して例外を送出する。

例えば雑な例をあげてみる。以下の様なコードは以下のような出力を返す。

def gen():
    print("hai")
    yield 1
    print("hoi")


it = gen()
print(next(it))
it.throw(Exception, Exception("hmm"))

nextでyield 1までは実行された後に、throw()を呼ぶ(tracebackは渡していないけれど。たぶん良い感じのネストしたerror reportに使われるかexceptionの発生位置として使われる)。

hai
1
Traceback (most recent call last):
  File "qr_8838MIf.py", line 9, in <module>
    it.throw(Exception, Exception("hmm"))
  File "qr_8838MIf.py", line 3, in gen
    yield 1
Exception: hmm

pytestの方の実装

pytestの方はというとfixtures.pyとrunner.pyのあたりをみれば良いのだけれど。雑にnext()を呼んでいる。この辺(紹介は雑)

fixtures.py

def call_fixture_func(fixturefunc, request, kwargs):
    yieldctx = is_generator(fixturefunc)
    if yieldctx:
        it = fixturefunc(**kwargs)
        res = next(it)

        def teardown():
            try:
                next(it)
            except StopIteration:
                pass
            else:
                fail_fixturefunc(fixturefunc,
                    "yield_fixture function has more than one 'yield'")

        request.addfinalizer(teardown)
    else:
        res = fixturefunc(**kwargs)
    return res

runner.py

    def _callfinalizers(self, colitem):
        finalizers = self._finalizers.pop(colitem, None)
        exc = None
        while finalizers:
            fin = finalizers.pop()
            try:
                fin()
            except Exception:
                # XXX Only first exception will be seen by user,
                #     ideally all should be reported.
                if exc is None:
                    exc = sys.exc_info()
        if exc:
            py.builtin._reraise(*exc)

簡略化した話

詳細に立ち寄りすぎたので簡略化すると例えば以下の様なコードがあったとする(再掲)。

def ob():
    ob = object()
    print("")
    print("<<< before", id(ob))

    yield ob  # ここで例外が発生した場合の話

    print(">>> after", id(ob))

ここで generator.throw() が呼ばれれば当然yieldの位置までしか実行はされずに終わる。一方でnext()が呼ばれれば通常通りiteratorを進めるだけなので以降の処理が実行される(最終的にStopIterationが送出される)。

雑な対応方法

上のことから考えて雑な対応方法を考えると以下の様な感じになる。

  • なるべく元のcontextlibの実装を利用する
  • 最後まで動かし切るために next() は呼ぶ
class SafeContextManager(_GeneratorContextManager):
    def __exit__(self, type, value, traceback):
        try:
            next(self.gen)
        except StopIteration:
            if type is None:
                return
        else:
            if type is None:
                raise RuntimeError("generator didn't stop")
        return super().__exit__(type, value, traceback)

上の対応方法に従った形で自分用のcontxtlib.contextmanagerを作ってあげれば良い(safe?という修飾は良くない気もするけれど)。

  • どんなときでもnext()は呼ぶ。
  • type is None の場合には、元の実装と同じ形で処理をする
  • 最も末尾で元の__exit__()を呼ぶ。

(元の__exit__()が呼ばれたとき、type is Noneになることはありえない=next()が2度呼ばれるということはありえない)

from functools import wraps

def safecontextmanager(func):
    @wraps(func)
    def helper(*args, **kwds):
        return SafeContextManager(func, args, kwds)

    return helper

おしまい。