signal handleするコードのテスト

はじめに

signalをhandleするコード自体は手軽に書ける。

import signal
import sys


def on_sigint(signum, frame):
    print("hmm")
    sys.exit(1)

signal.signal(signal.SIGINT, on_sigint)

しかしこれが確実にtrapされたことを確認するテストを書くのはだるい

面倒くさい理由

面倒くさい理由はいくつかあって、まず、signalをtrapするというのはプログラム全体に影響を及ぼす。 そしてmain threadでしか受け取れないので気軽にthreadingでごまかすということも出来ない。

試行錯誤した結果

しょうがないのでmultiprocessingで頑張る。

import sys
import unittest
import signal


class Ob(object):
    def __str__(self):
        return hex(id(self))


def do_something(calculate, _on_trap=None):
    ob = Ob()

    def on_trap(signum, frame):
        if _on_trap is not None:  # for test
            _on_trap(ob)
        print("cleanup with ", ob)
        sys.exit(1)

    signal.signal(signal.SIGHUP, on_trap)
    signal.signal(signal.SIGINT, on_trap)
    signal.signal(signal.SIGTERM, on_trap)

    # fetch anything?

    calculate(ob)  # do something

    # save db?


class Tests(unittest.TestCase):
    def test_it(self):
        from multiprocessing import Process, Queue
        import time

        q = Queue()
        init = 1
        called = 10
        q.put(init)

        def calculate(ob):
            print("before calculate", ob)
            time.sleep(1)  # waiting for killed
            print("after calculate", ob)

        def _on_trap(ob):
            self.assertEqual(q.get(), init)
            q.put(called)

        p = Process(target=lambda: do_something(calculate, _on_trap=_on_trap))
        p.start()
        time.sleep(0.1)
        p.terminate()  # SIGTERM
        p.join()
        self.assertEqual(q.get(), called)

if __name__ == "__main__":
    unittest.main()
    # before calculate 0x1030b66d8
    # cleanup with  0x1030b66d8

afterは呼ばれていないので途中で中断されている(process.terminate()によるSIGTERM)。 そしてqueueの値はcalledになっている。

before calculate 0x103f6f7b8
cleanup with  0x103f6f7b8
.
----------------------------------------------------------------------
Ran 1 test in 0.182s

OK