pythonでclassの階層関係がわかりにくいときに手動で良い感じに把握したいという試み

久しぶりにpythonの初歩的な話?をしたいと思ったので色々書いてみることにします。

すごく雑に言えば、pythonで、あるモジュールが何をやっているのかわからないという所からpydocを使えば良いかもしれないという話をしたり。 そこからpydocをプログラムとして使うこともできたりするよみたいな話をしつつ、出力結果を自分が見たい形にするというようなことができたら便利と言うような話をします。

このモジュールがなにやっているかわからない

例えば、あなたの現在書いているコードでmatplotlib.backends.backend_svgというモジュールが使われているということがわかったととします。このモジュールが提供しているシンボルが何かというのがわからない状況(IDEを使っているならおもむろにモジュール名を入力して補完させてみるというのもありかもしれません)。

このとき、いきなりブラウザを立ち上げてドキュメントを探してみたりするのも良いですが。pydocを使うという方法もあります。個人的には、ネットにつなげるのが面倒。より詳しく言うと検索エンジンから所定のバージョンに対応するドキュメントを見つけて、それを開くのが面倒と感じるような時に使ったりします。

そんなpydocですが。一般的には、pythonがインストールされると、pydocコマンドが使えるようになっているのではないでしょうか?

通常はこれをコマンドとして使って、以下の様にして該当のモジュールのヘルプを見ることができるはずなのですが。見えない場合があります(確かにインストールされているパッケージのはずなのに。。)。

$ pydoc matplotlib.backends.backend_svg
No Python documentation found for 'matplotlib.backends.backend_svg'.
Use help() to get the interactive help utility.
Use help(str) for help on the str class.

実際の所、venvやvirtualenvでインストールされたパッケージの場合には見つからないことがあります。幸い、pydocは、python -m <module>での呼び出しにも対応しているので以下の様にしても良いです。この場合には上手くヘルプを見ることができます(暇な時にどうしてそうなるのかを調べてみるというのも良いかもしれないですね)。

$ python -m pydoc matplotlib.backends.backend_svg
Help on module matplotlib.backends.backend_svg in matplotlib.backends:

NAME
    matplotlib.backends.backend_svg

CLASSES
    builtins.object
        XMLWriter
    matplotlib.backend_bases.FigureCanvasBase(builtins.object)
        FigureCanvasSVG

...

    unicode_literals = _Feature((2, 6, 0, 'alpha', 2), (3, 0, 0, 'alpha', ...
    verbose = <matplotlib.Verbose object>

VERSION
    2.0.2

FILE
    /home/podhmo/my/lib/python3.6/site-packages/matplotlib/backends/backend_svg.py

ちなみに、このpydocはhtmlでのちょっとしたUIも提供してくれていたりします。python -m pydoc -b などとするとブラウザが立ち上がります。ブラウザが良いという人にはそちらが良いかもしれないです。

このクラスがなにやっているかわからない

何かわからないけれど。色々な機能を提供しているモジュールがあり、何か意味がありそうなクラスはありそうなのだけれど全体を把握できない、そんな時にクラスだけの一覧を得たいというときにはpydocから少し離れてみると良いかもしれません。

例えばそのモジュールで定義されているクラスは全部でどんなのがあるのかということが知りたければ、直接モジュールに対応するファイルをgrepしてみても良いかもしれないですし。

$ grep '^class ' `python -c 'import matplotlib.backends.backend_svg as m; print(m.__file__)'`
class XMLWriter(object):
class RendererSVG(RendererBase):
class FigureCanvasSVG(FigureCanvasBase):
class FigureManagerSVG(FigureManagerBase):

もうすこしまじめに考えてインタプリタで調べてみても良いかもしれない。

$ python
Python 3.6.4 (default, Jan  5 2018, 02:35:40)
[GCC 7.2.1 20171224] on linux
Type "help", "copyright", "credits" or "license" for more information.

>>> import matplotlib.backends.backend_svg as m
>>> import inspect
>>> classes = [name for name, val in m.__dict__.items() if inspect.isclass(val) and val.__module__ == m.__name__]
>>> print("\n".join(classes))
XMLWriter
RendererSVG
FigureCanvasSVG
FigureManagerSVG
FigureCanvas
FigureManager

どんなクラスがあるのか絞り込めました。とりあえず、FigureCanvasSVGに焦点を当ててみることにしましょう。

先程pydocを使うことでモジュール単位のヘルプページを見ることができると言いましたが。巨大なモジュールだった場合にモジュール全部のヘルプを見たいわけじゃないんだよと思うことがあると思います。特にコンソールで作業しているときには邪魔になることがあったりしますね。

まぁ何にせよpydocに対象のモジュールのクラスなど一部を与えたら、pydocはそれだけについて表示してくれます。特定の部分についてだけ見たい場合はどんどん"."をつなげていきましょう。

$ python -m pydoc matplotlib.backends.backend_svg.FigureCanvasSVG | head -n 20
Help on class FigureCanvasSVG in matplotlib.backends.backend_svg:

matplotlib.backends.backend_svg.FigureCanvasSVG = class FigureCanvasSVG(matplotlib.backend_bases.FigureCanvasBase)
 |  The canvas the figure renders into.
 |  
 |  Public attributes
 |  
 |      *figure*
 |          A :class:`matplotlib.figure.Figure` instance
 |  
 |  Method resolution order:
 |      FigureCanvasSVG
 |      matplotlib.backend_bases.FigureCanvasBase
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  get_default_filetype(self)
 |      Get the default savefig file format as specified in rcParam
 |      ``savefig.format``. Returned string excludes period. Overridden

いやいや、この内特定のメソッドだけ調べられれば十分なんだよ。ということならメソッド名も付け足してあげれば良いです。どんどん"."をつなげていきましょう(本当はprint_figureが良かったけれど出力が長すぎた)。

$ python -m pydoc matplotlib.backends.backend_svg.FigureCanvasSVG.resize
Help on function resize in matplotlib.backends.backend_svg.FigureCanvasSVG:

matplotlib.backends.backend_svg.FigureCanvasSVG.resize = resize(self, w, h)
    set the canvas size in pixels

メソッドだけを見ることができるます。

継承関係

ところでこのクラスの継承関係はどうなっているのかと言うと、先程の出力の Method resolution order(MRO)の箇所を見てあげれば良いです。

builtins.object
  |
matplotlib.backend_bases.FigureCanvasBase
  |
matplotlib.backends.backend_svg.FigureCanvasSVG

つまるところ

class FigureCanvasBase(object):
    pass

class FigureCanvasSVG(FigureCanvasBase):
    pass

ということですね。

pydocをプログラムから使う

pydocをプログラム上からimportして使うこともできます。ちょっと異なる出力結果が欲しいけれどそれら全部を実装したくないという時に部分的に再利用すると便利だったりします。

実際、pydocの出力を文字列で得るのは簡単で。コンソール上での出力(htmlではない)を得るには以下の様に書くだけです。

import pydoc


ob, name = pydoc.resolve("matplotlib.backends.backend_bases.FigureCanvasSVG")
print(pydoc.plaintext.document(ob))

特定のモジュールやクラスに対するヘルプを手に入ります。resolveは実質値をそのまま渡せば良いので以下で直接importしたものを渡してあげても良いです。実際的にはdocumentというメソッドがマジカルな感じに空気を読んでやってくれます。

import pydoc
from matplotlib.backends.backend_svg import FigureCanvasSVG

print(pydoc.plaintext.document(FigureCanvasSVG.resize))
# 'resize(self, w, h)\n    set the canvas size in pixels\n'

実際、documentは、内部的には以下の様なメソッドになっています。そんなわけで何かを渡したら、ちょうど良い位置でちょうど良い感じのdocumentの文字列が返ってきます。

class Doc:

    def document(self, object, name=None, *args):
        """Generate documentation for an object."""
        args = (object, name) + args
        # 'try' clause is to attempt to handle the possibility that inspect
        # identifies something in a way that pydoc itself has issues handling;
        # think 'super' and how it is a descriptor (which raises the exception
        # by lacking a __name__ attribute) and an instance.
        if inspect.isgetsetdescriptor(object): return self.docdata(*args)
        if inspect.ismemberdescriptor(object): return self.docdata(*args)
        try:
            if inspect.ismodule(object): return self.docmodule(*args)
            if inspect.isclass(object): return self.docclass(*args)
            if inspect.isroutine(object): return self.docroutine(*args)
        except AttributeError:
            pass
        if isinstance(object, property): return self.docproperty(*args)
        return self.docother(*args)

クラスの階層構造を自分の好きな形式で出力してみる

複雑なオブジェクトの場合にはpydocの出力の量が多すぎてわけがわからないという風に感じることがあったりします。そういう時に自分で良い感じに出力してあげてみたりすると良いかもしれません。

例えばクラスの継承関係だけを取りたい場合。これは mro() でどうにかなりますね。

>>> from matplotlib.backends.backend_svg import FigureCanvasSVG
>>> FigureCanvasSVG.mro()
[<class 'matplotlib.backends.backend_svg.FigureCanvasSVG'>,
 <class 'matplotlib.backend_bases.FigureCanvasBase'>,
 <class 'object'>]

例えば、クラスの持つメソッドだけに興味を持つ場合には、メソッドだけをとりだしてみてみると良さそうです。

>>> from matplotlib.backends.backend_svg import FigureCanvasSVG
>>> import inspect
>>> [name for name, attr in FigureCanvasSVG.__dict__.items() if inspect.isroutine(attr)]
['print_svg',
 'print_svgz',
 '_print_svg',
 'get_default_filetype']

もう少しまじめに考えると以下の様な関数を用意して分類すると便利かもしれません。

def get_kind(attr):
    if isinstance(attr, staticmethod):
        return "static_method"
    elif isinstance(attr, classmethod):
        return "class_method"
    elif isinstance(attr, property):
        return "property"
    elif inspect.isroutine(attr):  # xxx
        return "method"
    else:
        return "data"

ちなみに上記のコードは、inspect.classify_class_attrs の一部です。これの改良版に pydoc.classify_class_attrs という関数があったりします。

>>> attrs = [(name, kind) for name, kind, cls, _ in pydoc.classify_class_attrs(FigureCanvasSVG) if cls == FigureCanvasSVG]
## special methodが邪魔かも
>>> [(name, kind) for name, kind in attrs if not (name.startswith("__") and name.endswith("__"))]
[('_print_svg', 'method'),
 ('filetypes', 'data'),
 ('fixed_dpi', 'data'),
 ('get_default_filetype', 'method'),
 ('print_svg', 'method'),
 ('print_svgz', 'method')]
 ('print_svgz', 'method')]

名前だけがわかってもメソッドの引数なども欲しいですね。こういう時に頑張ってコードを書くよりpydocに任せてしまうと便利です。

>>> names = [name for name, kind in attrs if kind == "method"]
>>> print("".join([pydoc.plaintext.document(getattr(FigureCanvasSVG, name)) for name in names]))
_print_svg(self, filename, svgwriter, **kwargs)
get_default_filetype(self)
    Get the default savefig file format as specified in rcParam
    ``savefig.format``. Returned string excludes period. Overridden
    in backends that only support a single file type.
print_svg(self, filename, *args, **kwargs)
print_svgz(self, filename, *args, **kwargs)

ちょっとクラス名つけてインデントしてみたいですね。

>>> text = pydoc.plaintext
>>> content = text.indent("".join([text.document(getattr(FigureCanvasSVG, name)) for name in names]))
>>> print("\n".join([FigureCanvasSVG.__name__, content]))
FigureCanvasSVG
    _print_svg(self, filename, svgwriter, **kwargs)
    get_default_filetype(self)
        Get the default savefig file format as specified in rcParam
        ``savefig.format``. Returned string excludes period. Overridden
        in backends that only support a single file type.
    print_svg(self, filename, *args, **kwargs)
    print_svgz(self, filename, *args, **kwargs)

これで自分のクラスのところでだけ定義されたメソッドの一覧が出ましたが、祖先のクラスについても考えてみたいですよね。というかまさにクラス階層が複雑ということは祖先の定義と自分自身のクラス定義の関係性が見えてこないという話なので。

そろそろまじめに関数にしてみましょう。

def shape_text(this_cls, doc=pydoc.plaintext):
    attrs = [
        (name, kind) for name, kind, cls, _ in pydoc.classify_class_attrs(this_cls)
        if cls == this_cls
    ]
    attrs = [
        (name, kind) for name, kind in attrs if not (name.startswith("__") and name.endswith("__"))
    ]
    method_names = [name for name, kind in attrs if kind == "method"]
    content = doc.indent("".join([doc.document(getattr(this_cls, name)) for name in method_names]))
    mro = " <- ".join([cls.__name__ for cls in this_cls.mro()])
    return "\n".join([mro, content])

ちょっとdocstringが邪魔なので取り除きます。

import re

def filter_by_indent(s, level, rx=re.compile("^\s+")):
    for line in s.split("\n"):
        m = rx.search(line)
        if m is None or len(m.group(0)) <= level:
            yield line

親もみたいので所定のclassのmro()を全部printしてみます。

from matplotlib.backends.backend_svg import FigureCanvasSVG  # noqa


for cls in FigureCanvasSVG.mro():
    if cls == object:
        break
    text = shape_text(cls)
    print("\n".join(filter_by_indent(text, 4)))

戻り値がないのが残念ですが。なんとなく見えてきましたね。

FigureCanvasSVG <- FigureCanvasBase <- object
    _print_svg(self, filename, svgwriter, **kwargs)
    get_default_filetype(self)
    print_svg(self, filename, *args, **kwargs)
    print_svgz(self, filename, *args, **kwargs)

FigureCanvasBase <- object
    _get_output_canvas(self, format)
    _idle_draw_cntx(self)
    blit(self, bbox=None)
    button_press_event(self, x, y, button, dblclick=False, guiEvent=None)
    button_release_event(self, x, y, button, guiEvent=None)
    close_event(self, guiEvent=None)
    draw(self, *args, **kwargs)
    draw_cursor(self, event)
    draw_event(self, renderer)
    draw_idle(self, *args, **kwargs)
    enter_notify_event(self, guiEvent=None, xy=None)
    flush_events(self)
    get_default_filename(self)
    get_width_height(self)
    get_window_title(self)
    grab_mouse(self, ax)
    idle_event(self, guiEvent=None)
    is_saving(self)
    key_press_event(self, key, guiEvent=None)
    key_release_event(self, key, guiEvent=None)
    leave_notify_event(self, guiEvent=None)
    motion_notify_event(self, x, y, guiEvent=None)
    mpl_connect(self, s, func)
    mpl_disconnect(self, cid)
    new_timer(self, *args, **kwargs)
    onHilite(self, ev)
    onRemove(self, ev)
    pick(self, mouseevent)
    pick_event(self, mouseevent, artist, **kwargs)
    print_figure(self, filename, dpi=None, facecolor=None, edgecolor=None, orientation='portrait', format=None, **kwargs)
    release_mouse(self, ax)
    resize(self, w, h)
    resize_event(self)
    scroll_event(self, x, y, step, guiEvent=None)
    set_window_title(self, title)
    start_event_loop(self, timeout)
    start_event_loop_default(self, timeout=0)
    stop_event_loop(self)
    stop_event_loop_default(self)
    switch_backends(self, FigureCanvasClass)

ついでにオーバーライドされているものについても修飾付きで出力させてみたいですね。親クラスが持っているメソッドを再定義するやつです。

--- 03shape.py   2018-06-22 23:31:19.839708975 +0900
+++ 04shape.py    2018-06-22 23:30:58.850685952 +0900
@@ -10,8 +10,18 @@
     attrs = [
         (name, kind) for name, kind in attrs if not (name.startswith("__") and name.endswith("__"))
     ]
+    attrs = [(name, kind) for name, kind in attrs if not name.startswith("_")]
     method_names = [name for name, kind in attrs if kind == "method"]
-    content = doc.indent("".join([doc.document(getattr(this_cls, name)) for name in method_names]))
+    method_annotations = [
+        "@OVERRIDE: " if any(c for c in this_cls.mro()[1:] if hasattr(c, name)) else ""
+        for name in method_names
+    ]
+    method_docs = [
+        prefix + doc.document(getattr(this_cls, name))
+        for prefix, name in zip(method_annotations, method_names)
+    ]
+
+    content = doc.indent("".join(method_docs))
     mro = " <- ".join([cls.__name__ for cls in this_cls.mro()])
     return "\n".join([mro, content])

例えば wsgiref.simple_server.WSGIServer あたりを見てみると面白いかもしれません。

from wsgiref.simple_server import WSGIServer

for cls in WSGIServer.mro():
    if cls == object:
        break
    text = shape_text(cls)
    print("\n".join(filter_by_indent(text, 4)))

こういう感じ。

WSGIServer <- HTTPServer <- TCPServer <- BaseServer <- object
    get_app(self)
    @OVERRIDE: server_bind(self)
    set_app(self, application)
    setup_environ(self)

HTTPServer <- TCPServer <- BaseServer <- object
    @OVERRIDE: server_bind(self)

TCPServer <- BaseServer <- object
    @OVERRIDE: close_request(self, request)
    fileno(self)
    get_request(self)
    @OVERRIDE: server_activate(self)
    server_bind(self)
    @OVERRIDE: server_close(self)
    @OVERRIDE: shutdown_request(self, request)

BaseServer <- object
    close_request(self, request)
    finish_request(self, request, client_address)
    handle_error(self, request, client_address)
    handle_request(self)
    handle_timeout(self)
    process_request(self, request, client_address)
    serve_forever(self, poll_interval=0.5)
    server_activate(self)
    server_close(self)
    service_actions(self)
    shutdown(self)
    shutdown_request(self, request)
    verify_request(self, request, client_address)

最後に

gistです

pythonでCSVを消費する処理を再開可能にしたい

github.com

はじめに

CSVを消費する処理を再開可能にしたいという気持ちになりました。具体的には、1つ1つの処理にそこそこ時間が掛かる(30秒から1分)ものをそこそこ多く(104件くらい)処理しないといけないことがあったのですが。DBとか用意したり使ったりするの面倒だなと思ったときのことです。

CSVを消費したい(再開したい)

例えば以下の様なイメージです(実際の処理とは異なります)。

input.csv

id,x,y
1,10,20
2,100,100

このようなcsvがあって、これらの各行に対して処理を行う(例えば和を求める)必要があるとします。

output.csv

id,v
1,30 // 本当は結構重たい
<-- このあたりで止めたい(再開したい)
2,200

実際、重いと言っても計算的なものではなく主に帯域制限的なものが原因です。なので並列化とかほぼ意味がない状態なのですが。途中で失敗したら辛いという感じの状況です。

随時終わるたびに書き出していき、終わったところまで入力を削るみたいな作業をして手動で入力となるファイルを書き換えていっても良いのですが。だるい。

csvresumable

まぁそんなわけでだるかったので。ちょっとしたライブラリを作ることにしました(まだ開発途中なのでAPIの変更は普通にあると思います)。具体的には以下の様な形で動きます。

  • 通常のCSVのDictReaderと同様に動く
  • どこまで終わったのかを別途記録する(history.csv(実際には別の場所に記録されます))
  • (再開時には、記録していたところまでの入力はスキップする)

状態などを管理するのは面倒だったので、完全にinsertだけで済むようにしました。

例えば上の例で言えば、処理の途中で止めたいということは

id,x,y
1,10,20
<-- ここで止めたい
2,100,100

以下の様な履歴(history.csv)を用意し(csvである必要はない)

id
1

再開(resume)の際はこれとzipしたiterator(概念上)に対して処理を行えば良いということになります。履歴に残ったものはスキップしてしまえば良いということです。

ちなみに、pandasなどのinterfaceを用意しなかった理由は、そもそも計算や集計が目的ではなかったためです。単なる情報をくれるevent streamとしてCSVがあれば良いというだけだったので(つまりCSVである理由も特にない可能性があります。そのあたりも込みでinterfaceが変わりうるという感じです)。

実際の利用例

以下の様な形で書けます。csv.DictReaderのかわりにcsvresumable.DictReader を使います。

import json
import time
import sys
from csvresumable import DictReader

with open("input.csv") as rf:
    r = DictReader(rf)
    for row in r:
        print("start", row["id"], file=sys.stderr)

        # 重たい処理
        time.sleep(2)
        print(json.dumps(row))  # 重たい処理をした結果のつもり
        sys.stdout.flush()

CSVから1行ずつ取得していき、それを入力として何らかの処理を行うという形です。この処理がそこそこ重めの処理(と言っても先程言った通りにほぼほぼ流量制限が原因で高速化できない)になっており、数十秒程度掛かると言ったものだとします。

例えば先程のinput.csv (再掲)に対して実行し、

id,x,y
1,10,20
2,100,100

実行を途中で止めます(id=2の部分は計算が終わらない。あるいはエラー)。

$ python main.py
start 1
{"id": "1", "x": "10", "y": "20"}
start 2
    KeyboardInterrupt

途中で止まったので、途中から再開したいはずです。ここで RESUME=1 という環境変数と一緒に実行すると再開(resume)できます。

$ RESUME=1 python main.py
start 2
{"id": "2", "x": "100", "y": "100"}

# 全部実行し終わった後なら何も出力されない
$ RESUME=1 python main.py

ちなみにRESUMEをつけないとはじめからやり直しです(つまり何も知らない人にとってはただのcsv.DictReaderとして動く)。

$ python main.py
start 1
{"id": "1", "x": "10", "y": "20"}
start 2
{"id": "2", "x": "100", "y": "100"}

環境変数以外の方法

ところで環境変数で設定というのが、設定より規約(CoC)っぽい感じがして嫌という人いると思います。そんな人は真面目にオプションを与えてください。以下の様な形で。

--- 00add/main.py    2018-06-14 17:45:23.000000000 +0900
+++ 01add/main.py 2018-06-14 18:36:18.000000000 +0900
@@ -1,10 +1,15 @@
 import json
 import time
 import sys
+import argparse
 from csvresumable import DictReader
 
+parser = argparse.ArgumentParser()
+parser.add_argument("--resume", action="store_true")
+args = parser.parse_args()
+
 with open("input.csv") as rf:
-    r = DictReader(rf)
+    r = DictReader(rf, resume=args.resume)
     for row in r:
         print("start", row["id"], file=sys.stderr)

--resume を使ってresumeできます

$ python main.py
start 1
{"id": "1", "x": "10", "y": "20"}
start 2
    KeyboardInterrupt

# resume
$ python main.py --resume
start 2
{"id": "2", "x": "100", "y": "100"}

idとして扱う値を変えたい場合

idとして使う値を変えたくなることがあるかもしれません。その場合にはkeyオプションがあります。これはsorted()関数と同様のイメージで考えてもらえれば良いです。渡されるCSVというのはかならずしも常に自分の意図した通りの構造で渡されるということがなかったりしますし。

例えば、以下の様な形かもしれません。

groupId,userId,name,age,cache
1,1,foo,8,1000
1,2,bar,10,200
2,3,boo,2,0
3,4,bar,2,100

デフォルトでは左端をidとして扱いますが、上の例では左端のgroupIdではなくuserIdをidとして消費したくなると思います。このような場合には以下の様に書けば良いです。

import time
import sys
from csvresumable import DictReader

with open("input.csv") as rf:
    r = DictReader(rf, key=lambda row: row["userId"])  # key=を使う
    for row in r:
        print("start", row["userId"], file=sys.stderr)

        # 重たい処理
        time.sleep(2)
        print(row["name"], int(row["age"]) / int(row["cache"]))
        sys.stdout.flush()

テキトウに年齢(勤続年数?)を貯金で割って、1万円?を稼ぐのに何年掛かるのかというような値でも計算してみましょう(これはテキトーな例です)。

python main.py
start 1
foo 0.008
start 2
bar 0.05
start 3
Traceback (most recent call last):
  File "main.py", line 12, in <module>
    print(row["name"], int(row["age"]) / int(row["cache"]))
ZeroDivisionError: division by zero

おや、エラーになってしまいましたね。0除算を気にしてませんでした(まぁこういう感じでたまに考慮漏れのエラーがあったりします)。 テキトウに直したら。

--- 02groupid/main.py    2018-06-15 01:13:54.785744154 +0900
+++ 03groupid/main.py 2018-06-15 01:22:14.473955859 +0900
@@ -9,5 +9,9 @@
 
         # 重たい処理
         time.sleep(2)
-        print(row["name"], int(row["age"]) / int(row["cache"]))
+        if int(row["cache"]) == 0:
+            ans = "-"
+        else:
+            ans = int(row["age"]) / int(row["cache"])
+        print(row["name"], ans)
         sys.stdout.flush()

再開します。

$ RESUME=1 python main.py
start 3
boo -
start 4
bar 0.02

途中から再開できてますね。

複数のCSVを合成した結果を元に消費したい場合

ところで、今までは入力がひとつだけでしたが。複数の入力が必要になることもあると思います。そのような場合にも一応対応はしています。

直列につなぐ場合(concat)

単純に複数に分割されたファイルを入力だとしましょう。

input.csv

groupId,userId,name,age,cache
1,1,foo,8,1000
1,2,bar,10,200

input2.csv

groupId,userId,name,age,cache
2,3,boo,2,0
3,4,bar,2,100

そのような場合はつなげるだけです。

import time
import sys
from csvresumable import DictReader

# 1つではなく2つなのでforループ
for filename in ["input.csv", "input2.csv"]:
    with open(filename) as rf:
        r = DictReader(rf, key=lambda row: row["userId"])
        for row in r:
            print("start", row["userId"], file=sys.stderr)

            # 重たい処理
            time.sleep(2)
            if int(row["cache"]) == 0:
                ans = "-"
            else:
                ans = int(row["age"]) / int(row["cache"])
            print(row["name"], ans)
            sys.stdout.flush()

注意点としてはDictReaderのiteratorをリストなどにして消費しないようにしてください。

$ python main.py
start 1
foo 0.008
start 2
    KeyboardInterrupt
$ RESUME=1 python main.py
start 2
bar 0.05
start 3
boo -
start 4
    KeyboardInterrupt
$ RESUME=1 python main.py
start 4
bar 0.02

ファイルの切れ目など関係なくresumeできています。これは当然と言えば当然なのですが。渡すファイルの順序は常に一定にしてください(あるときは input2.csv input.csv などの順序であるなど順序が不定の場合にはおかしくなります)。

並列につなぐ場合(groupby)

今度は並列につなぐ場合を考えてみます。例えば先程のcsvについてgroupIdでgroupingされた結果に対する何らかの処理をしてみるということにしてみましょう。そのような場合でも考え方は同様です。毎回常に一定の順序でeventが発生するevent streamのようなものが構成されていればそれで十分です(入力がCSVである必要も特にありません)。

このような場合には csvresumable.iterate を使います。

import time
import sys
import csv
import itertools
import csvresumable

# event streamはiteratorであれば良い
def gen(files):
    sources = [csv.DictReader(open(f)) for f in files]
    sorted_sources = sorted(itertools.chain.from_iterable(sources), key=lambda row: row["groupId"])
    return itertools.groupby(sorted_sources, key=lambda row: row["groupId"])


for group_id, rows in csvresumable.iterate(gen(["input.csv", "input2.csv"])):
    print("start group_id", group_id, file=sys.stderr)
    time.sleep(2)
    print("total", sum(int(row["cache"]) for row in rows))

# groupingは以下の様な形
# 1 [{"groupId": "1", "userId": "1", "name": "foo", "age": "8", "cache": "1000"},
#    {"groupId": "1", "userId": "2", "name": "bar", "age": "10", "cache": "200"}
#   ]
# 2 [{"groupId": "2", "userId": "3", "name": "boo", "age": "2", "cache": "0"}]
# 3 [{"groupId": "3", "userId": "4", "name": "bar", "age": "2", "cache": "100"}]

例を見てわかる通り、実は入力がCSVである必要はありません。一定の順序を保った何らかのeventのstreamであれば大丈夫です(pythonで言えばiterator)。 defaultではitertools.groupbyなどと同様にiterateされた行をリストと捉えての最初の要素をidとして扱いますが(key=lambda xs: xs[0])、もちろんkeyオプションがとれます。

chainしてsortしてとやっているので、原理的には与えられたファイルを全部見ているわけですが。そもそも冒頭で触れたように元となる入力の数自体はせいぜい104程度しかありません。なのでそこまでコストというわけでもないです。

途中で止めてRESUMEで再開できます。

$ python main.py
start group_id 1
total 1200
    KeyboardInterrupt

$ RESUME=1 python main.py
start group_id 2
total 0
start group_id 3
total 100

おまけ

ちなみにchainしてsortしてgroupbyというのは結構よくやる処理なのですが。毎回書くのもめんどくさいのでconcat_groupbyという関数を用意しています。

def gen(files):
    source = [csv.DictReader(open(f)) for f in files]
    return csvresumable.concat_groupby(source, key=lambda row: row["groupId"])

ところで先頭N件だけ取りたいという場合にはitertools.isliceが使えます。

def gen(files, *, size=None):
    sources = [csv.DictReader(open(f)) for f in files]
    itr = csvresumable.concat_groupby(sources, key=lambda row: row["groupId"])
    if size is not None:
        itr = itertools.islice(itr, size)
    return itr

対象となるevent streamは消費しないように気をつけてください(消費というのはlist(gen(files))のようなことを指してます)。

再開時に過去の出力を覚えておきたい場合

さて、いままでは処理の中断・再開を扱ってきましたが。出力を全て通して行いつつ、実際の処理自体は中断・再開したいということがあります。例えば、先程のスクリプトが以下のようなMakefileに書かれていたタスクだったとします。

default:
  python main.py | tee output.csv

ここで、処理を再開したときには、過去分も含めた全ての実行結果が渡されて欲しいはずです(もちろん、呼び出すスクリプト側でファイル入出力を行い、追記でやるという案もあります)。

このようなときにちょっと一手間を加えると良い感じにできます。

import time
import sys
import csv
import csvresumable


def gen(files):
    source = itertools.chain.from_iterable([csv.DictReader(open(f)) for f in files])
    return csvresumable.concat_groupby(source, key=lambda row: row["groupId"])

# captueで包んだ
with csvresumable.capture():
    for group_id, rows in csvresumable.iterate(gen(["input.csv", "input2.csv"])):
        print("start group_id", group_id, file=sys.stderr)
        time.sleep(2)
        print("total", sum(int(row["cache"]) for row in rows))
        sys.stdout.flush()  # 呼び出し方によってはbufferingされてしまう場合もある

captureという名前が良いかはまだ微妙ですが。このコンテキストマネージャでくるんであげるとその間の出力を記録しておけます(ちなみに引数で記録したいstreamは変更できます。デフォルトが標準出力)。 再開時には記録していた出力を再度出力してくれるため、再開(resume)時にも過去も含めた全てを出力をしてれるようになります。

$ make
python main.py | tee output.txt
start group_id 1
total 1200
start group_id 2
    KeyboardInterrupt
make: *** [Makefile:2: default] Error 130

$ RESUME=1 make
python main.py | tee output.txt
total 1200
start group_id 2
total 0
start group_id 3
total 100

そんなわけでteeを使っていても、再開後のファイル中に全ての結果が残ります。

$ cat output.txt
total 1200
total 0
total 100

最後に

裏側の話はまた今度。