読者です 読者をやめる 読者になる 読者になる

pyramid-swagger-routerと一緒にtoyboxのswaggerのvalidationを使ってみる

相も変わらずtoyboxというリポジトリで作業していた。昨日試しに作ってみたswaggerのinput,outputのvalidationを行うものを昔作っていた pyramid-swagger-router と一緒に使ってみるようにしてみた。ココまで来るとそろそろswaggerに対するnormalizer + iteratorみたいなものが欲しくなってくる。

動くサンプルは ここ

使い方

だいたい以下のようなMakefileを使うという感じで察して欲しい(yapfのコメントアウトはASTのmergeに使うredbaronというpackageのバグを踏んでしまっていたので。yapfはgofmtのpython版みたいなもの)。

gen: spec router schema

# fmt:
#  yapf -r -i --style='{based_on_style: chromium, indent_width: 4}' app

spec:
  cp ../swagger/swagger.yml .
  gsed -i 's/operationId: /operationId: views./g' swagger.yml

router:
  pyramid-swagger-router --driver=./driver.py:Driver --logging=DEBUG ./swagger.yml app

schema:
  swagger-marshmallow-codegen --full --logging=DEBUG ./swagger.yml > app/schema.py

run:
  PYTHONPATH=. python app/__init__.py

このようにするとだいたい以下のようなコードが生成される。 使うswagger specは 昨日のものと同じもの

routes.py

def includeme_swagger_router(config):
    config.add_route('views', '/')
    config.add_route('views1', '/add')
    config.add_route('views2', '/dateadd')
    config.scan('.views')


def includeme(config):
    config.include(includeme_swagger_router)

views.py

from pyramid.view import (
    view_config
)
from . import (
    schema
)
from toybox.swagger import (
    withswagger
)


@view_config(decorator=withswagger(schema.Input, schema.Output), renderer='vjson', request_method='GET', route_name='views')
def hello(context, request):
    """

    request.GET:

        * 'name'  -  `{"type": "string", "example": "Ada", "default": "Friend"}`
    """
    return {}


@view_config(decorator=withswagger(schema.AddInput, schema.AddOutput), renderer='vjson', request_method='POST', route_name='views1')
def add(context, request):
    """


    request.json_body:

    ```
        {
          "type": "object",
          "properties": {
            "x": {
              "type": "integer"
            },
            "y": {
              "type": "integer"
            }
          },
          "required": [
            "x",
            "y"
          ]
        }
    ```
    """
    return {}


@view_config(decorator=withswagger(schema.DateaddInput, schema.DateaddOutput), renderer='vjson', request_method='POST', route_name='views2')
def dateadd(context, request):
    """


    request.json_body:

    ```
        {
          "type": "object",
          "properties": {
            "value": {
              "type": "string",
              "format": "date"
            },
            "addend": {
              "minimum": 1,
              "type": "integer"
            },
            "unit": {
              "type": "string",
              "default": "days",
              "enum": [
                "days",
                "minutes"
              ]
            }
          },
          "required": [
            "addend"
          ]
        }
    ```
    """
    return {}

init.pyは自分で書かなければダメ。

# -*- coding:utf-8 -*-
from toybox.simpleapi import run


def includeme(config):
    config.include("toybox.swagger")
    config.include("app.routes")
    config.scan("app.views")


if __name__ == "__main__":
    run.include(includeme)
    run(port=5001)

以下の事が自動で行われる

  • routingの定義
  • viewの内部のformat validation

もちろんroutingができるだけなので内部のviewは自分で実装してあげないとダメ。

diff --git a/examples/swagger2/app/views.py b/examples/swagger2/app/views.py
index a531b8c..85b5a96 100644
--- a/examples/swagger2/app/views.py
+++ b/examples/swagger2/app/views.py
@@ -1,3 +1,4 @@
+from datetime import datetime, timedelta
 from pyramid.view import (
     view_config
 )
@@ -17,7 +18,7 @@ def hello(context, request):
 
         * 'name'  -  `{"type": "string", "example": "Ada", "default": "Friend"}`
     """
-    return {}
+    return {'message': 'Welcome, {}!'.format(request.GET["name"])}
 
 
 @view_config(decorator=withswagger(schema.AddInput, schema.AddOutput), renderer='vjson', request_method='POST', route_name='views1')
@@ -45,7 +46,9 @@ def add(context, request):
         }
     ```
     """
-    return {}
+    x = request.json["x"]
+    y = request.json["y"]
+    return {"result": x + y}
 
 
 @view_config(decorator=withswagger(schema.DateaddInput, schema.DateaddOutput), renderer='vjson', request_method='POST', route_name='views2')
@@ -82,4 +85,13 @@ def dateadd(context, request):
         }
     ```
     """
-    return {}
\ No newline at end of file
+    value = request.json["value"]
+    addend = request.json["addend"]
+    unit = request.json["unit"]
+    value = value or datetime.utcnow()
+    if unit == 'minutes':
+        delta = timedelta(minutes=addend)
+    else:
+        delta = timedelta(days=addend)
+    result = value + delta
+    return {'result': result}

pyramid-swagger-routerについて

昔記事書いていた

http://pod.hatenablog.com/entry/2017/01/03/202551

生成したmarshmallowのschemaをwrapしてpyramidから使えるようにしてみた

swaggerからmarshmallowのschemaを生成する機能は昔から作っていて、デフォルトでは definitions 部分だけしか見ないのだけれど。--full というオプションをつけると paths 以下の parametersresponses も見るようになっている。ここで生成したschemaをpyramidから使えるようにした。まだpypiにはあげていない。

だいたい以下の様な手順で使う。

$ pip install swagger-marshmallow-codegen
$ swagger-marshmallow-codegen swagger.yml --full > schema.py

あとは以下の様な感じでコードを書けばOK。以下が重要。

  • config.include("toybox.swagger")
  • withswaggerで生成されたschemaを渡したdecoratorをつける
from datetime import datetime, timedelta
from toybox.simpleapi import simple_view, run
from toybox.swagger import withswagger
import schema  # ./schema.py


@simple_view("/", decorator=withswagger(input=schema.Input, output=schema.Output))
def hello(request):
    return {'message': 'Welcome, {}!'.format(request.GET["name"])}


@simple_view("/add", request_method="POST", decorator=withswagger(input=schema.AddInput, output=schema.AddOutput))
def add(request):
    x = request.json["x"]
    y = request.json["y"]
    return {"result": x + y}


@simple_view("/dateadd", request_method="POST", decorator=withswagger(input=schema.DateaddInput, output=schema.DateaddOutput))
def dateadd(request):
    value = request.json["value"]
    addend = request.json["addend"]
    unit = request.json["unit"]
    value = value or datetime.utcnow()
    if unit == 'minutes':
        delta = timedelta(minutes=addend)
    else:
        delta = timedelta(days=addend)
    result = value + delta
    return {'result': result}


if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.DEBUG)
    run.include("toybox.swagger")
    run(port=5001)

viewの内部では全てswaggerのspecを満たしたrequsestであることが保証されている。ダメだった場合にはHTTPBadRequestが返る。

例えば以下の様な感じ

$ http POST :5001/add x=10 y=20
HTTP/1.0 200 OK
Content-Length: 14
Content-Type: application/json
Date: Sat, 18 Mar 2017 18:24:52 GMT
Server: WSGIServer/0.2 CPython/3.5.2

{
    "result": 30
}

$ http POST :5001/add x=10
HTTP/1.0 400 Bad Request
Content-Length: 107
Content-Type: application/json
Date: Sat, 18 Mar 2017 18:24:55 GMT
Server: WSGIServer/0.2 CPython/3.5.2

{
    "code": "400 Bad Request",
    "message": {
        "y": [
            "Missing data for required field."
        ]
    },
    "title": "Bad Request"
}

$ http POST :5001/add
HTTP/1.0 400 Bad Request
Content-Length: 214
Content-Type: application/json
Date: Sat, 18 Mar 2017 18:24:57 GMT
Server: WSGIServer/0.2 CPython/3.5.2

{
    "code": "400 Bad Request",
    "message": "The server could not comply with the request since it is either malformed or otherwise incorrect.\n\n\nExpecting value: line 1 column 1 (char 0)\n\n",
    "title": "Bad Request"
}

viewへのinputとして渡せるのはswaggerと同様に以下。

  • path – /foo/{name} みたいなやつ
  • query – /?foo=bar みたいなやつ
  • header – request header
  • body – REST APIなどでjsonをpostしたときのやつ
  • form – 普通のPOST

example全体は以下のリンク先。

https://github.com/podhmo/toybox/tree/master/examples/swagger

どうでも良いこと

ところでmarshmallowのvalidationがload(deserialization)時にだけ掛かるということに気づき衝撃を受けた。しょうがないのでserialize後deserializeするみたいなコードになっている。

1ファイルでapi serverを作る用の環境を整えていた

個人用のメモです。

はじめに

手元で色々弄る用に1ファイルでweb serverを作る用の環境を整えていた。1ファイルが良い理由はいろいろな試行錯誤をするための実験をしたいからです。

pythonで使うwebフレームワークとしてはpyramidが好きなのですが、ところどころ1ファイルだけでアプリを作るにはあんまりうれしくない感じの状態で少しだけ調整が必要になったりします(恣意的な評価)。あと、用途としてはAPIサーバーを作る事が多いのですが、どちらかと言うとpyramidのデフォルトはサーバー側でHTMLを出力するアプリ向けの構成になっています。

そして色々な用途用のscaffoldが用意されてはいるのですが。そのまま使うということもなかったり。一方で、結局フルのpyramidの機能を使う分には1ファイルでのあぷりでは限界があります。もちろん真面目に開発するときには1ファイルで作るということなどまずないので、不要といえば不要なのですが。詳しい話をすると色々な機能のあれこれが昔存在したPasteDeployというパッケージの機能から作られる事が前提となっており、それ用の設定ファイル(.iniファイル)が必要とする形になっています。

そんな感じで自分用のsnipetを集めるよりは、ちょっとしたフレームワーク地味た何かにまとめておこうと思い始めたのでした。

つくりたいもの

つくりたいものはざっくりいうと以下のような感じです。

  • 1ファイルで作られたアプリケーションに注力
  • 主にjson responseを返すAPIサーバー用の機能をデフォルトにする
  • pyramidの機能は潰さずに使えるようにする(後で真面目に作る時には移行が手軽にできるようにする)

作っている最中のものは toybox というリポジトリにおいてあります。

hello world

hello worldはbottleやflask並みに短いです。つまり色々な部分を覆い隠したショートカットが存在するということです。

from toybox.simpleapi import simple_view, run


@simple_view("/hello/{name}")
def hello(request):
    return {"message": "hello {}".format(request.matchdict["name"])}

if __name__ == "__main__":
    run(port=8080)

大変短い。

サーバーの実行はそのままpythonで実行するだけです。defaultではwsgirefのサーバーが立ち上がります(python3ならそれなりに早い(本番で使えるとは言っていない))。

$ python app.py
scanning __main__
running host='0.0.0.0', port=8080

以下の様な結果を返します。

$ http GET :8080/hello/world
{
    "message": "hello world"
}
$ http GET :8080/404
{
    "code": "404 Not Found",
    "message": "The resource could not be found.\n\n\ndebug_notfound of url http://localhost:8080/; path_info: '/', context: <pyramid.traversal.DefaultRootFactory object at 0x109e3b8d0>, view_name: '', subpath: (), traversed: (), root: <pyramid.traversal.DefaultRootFactory object at 0x109e3b8d0>, vroot: <pyramid.traversal.DefaultRootFactory object at 0x109e3b8d0>, vroot_path: ()\n\n",
    "title": "Not Found"
}

デフォルトの404エラーが text/html ではなく application/json なのが嬉しいところです。

これは元々、pyramidでもdefaultの設定で、Accept: application/json のヘッダーがついているrequestに関しては、 application/json のresponseを返す様になっていました。これを常に有効にしています。

もう1つ、500のInternel Server Errorの時にもJSONでかえってきます。例えば以下の様なview callableをテキトウに書いてあげると。

@simple_view("/500")
def error(request):
    return 10 / 0

/500 GETAPIを定義しました。これはruntime errorになるはずのviewです。実際にrequestしてみると以下の様なresponseが返ってきます。

$ http GET :8080/500
{
    "code": "500 Internal Server Error",
    "message": "division by zero",
    "title": "Internal Server Error",
    "traceback": [
        "Traceback (most recent call last):",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/tweens.py\", line 22, in excview_tween",
        "    response = handler(request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/router.py\", line 155, in handle_request",
        "    view_name",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/view.py\", line 612, in _call_view",
        "    response = view_callable(context, request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/viewderivers.py\", line 351, in authdebug_view",
        "    return view(context, request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/viewderivers.py\", line 438, in rendered_view",
        "    result = view(context, request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/viewderivers.py\", line 147, in _requestonly_view",
        "    response = view(request)",
        "  File \"app.py\", line 11, in error",
        "    return 10 / 0",
        "ZeroDivisionError: division by zero"
    ]
}

tracebackなどは debug=True のときだけ返すようにしたいですが。今のところ本番環境で使うようなことは考えていないので常にでてきます。

simple_view

pyramidにはsimple_viewというデコレータが存在しません。その代わりに、config.add_route(), config.add_view() というディレクティブと view_config() というデコレータがあります。すごく雑に言えば、simple_view というのはrouteとviewの定義を同時にやってしまっているものです。

さて、pyramidの開発で結構ハマってしまうのはviewの定義が上手く行っているかどうかなのですが。それを確認するために proutes というコマンドが使われます。 例えば以下の様な形で使います。

$ proutes development.ini

同様に、pcreate, pserve, pshell, proutes, pviews, ptweens, prequest, pdistrepor などがありますがたいていのコマンドはPasteDeploy由来の設定ファイルが必要になります。

routeの定義だけでも確認したいということで以下の様な形で確認できるようにしました。ただし実際のpyramidの proutes と全く同じものではありません。あくまで簡易版です。

if __name__ == "__main__":
    # run(port=8080)
    run.proutes()

今現在定義されているrouteは以下の様な感じです。

$ python app.py 
scanning __main__
500 /500 __main__.error *
helloname /hello/{name} __main__.hello *

pyramidの機能の利用

ところで現状のコードの場合には、datetimeをresponseとして出力しようとするとエラーになります。

@simple_view("/now")
def now(request):
    from datetime import datetime
    return {"now": datetime.now()}

エラーになります。

$ http GET :8080/now

{
    "code": "500 Internal Server Error",
    "message": "datetime.datetime(2017, 2, 19, 21, 59, 59, 309713) is not JSON serializable",
    "title": "Internal Server Error",
    "traceback": [
        "Traceback (most recent call last):",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/tweens.py\", line 22, in excview_tween",
        "    response = handler(request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/router.py\", line 155, in handle_request",
        "    view_name",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/view.py\", line 612, in _call_view",
        "    response = view_callable(context, request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/viewderivers.py\", line 351, in authdebug_view",
        "    return view(context, request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/viewderivers.py\", line 461, in rendered_view",
        "    request, result, view_inst, context)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/renderers.py\", line 432, in render_view",
        "    return self.render_to_response(response, system, request=request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/renderers.py\", line 455, in render_to_response",
        "    result = self.render(value, system_values, request=request)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/renderers.py\", line 451, in render",
        "    result = renderer(value, system_values)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/renderers.py\", line 275, in _render",
        "    return self.serializer(value, default=default, **self.kw)",
        "  File \"/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/json/__init__.py\", line 237, in dumps",
        "    **kw).encode(obj)",
        "  File \"/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/json/encoder.py\", line 198, in encode",
        "    chunks = self.iterencode(o, _one_shot=True)",
        "  File \"/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/json/encoder.py\", line 256, in iterencode",
        "    return _iterencode(o, 0)",
        "  File \"/me/venvs/my3/lib/python3.5/site-packages/pyramid/renderers.py\", line 288, in default",
        "    raise TypeError('%r is not JSON serializable' % (obj,))",
        "TypeError: datetime.datetime(2017, 2, 19, 21, 59, 59, 309713) is not JSON serializable"
    ]
}

これはpyramidでもそうで。jsonのrendererの設定を追加してあげる必要があります。

def support_datetime_response(config):
    from pyramid.renderers import JSON
    from datetime import datetime

    # override: json renderer
    json_renderer = JSON()

    def datetime_adapter(obj, request):
        return obj.isoformat()
    json_renderer.add_adapter(datetime, datetime_adapter)
    config.add_renderer('json', json_renderer)


if __name__ == "__main__":
    run.add_modify(support_datetime_response)
    run(port=8080)

今度は大丈夫。

$bash http GET :8080/
{
    "now": "2017-02-19T22:04:15.534517"
}

add_modifyという名前はやめて、pyramidと同様にinclude()という名前にして文字列も受け取れるようにするかどうかは考え中。

1ファイルのアプリでview_configを使う。

誰も特をしないpyramidの話。

pyramidのconfiguration

pyramidにはすごく雑にいうと以下の2つのconfigurationの方法がある。

  • declarative configuration
  • imperative configuration

すごく雑に言えば、declarative configurationはデコレーターを使った設定(内部でvenusianが使われる)。imperative configurationはConfiguratorオブジェクトのdirectiveを直接使って設定する方。viewの設定に限って言えば、前者が view_config の利用。後者が add_view の利用。

通常1ファイルではadd_viewを使う

ドキュメントにもある通り通常は、1ファイルのアプリを作るときにはadd_viewを使うことが多い。

from wsgiref.simple_server import make_server
from pyramid.config import Configurator
from pyramid.response import Response


def hello_world(request):
    return Response('Hello %(name)s!' % request.matchdict)

if __name__ == '__main__':
    config = Configurator()
    config.add_route('hello', '/hello/{name}')
    config.add_view(hello_world, route_name='hello')
    app = config.make_wsgi_app()
    server = make_server('0.0.0.0', 8080, app)
    server.serve_forever()

1ファイルでも view_config を使いたい

view_configを使いたい。やっぱりpathの設定とview関数が遠い状態は面倒くさい。どうにかできないかと試行錯誤した結果出来るようになった。以下の様にすると1ファイルでview_configが使える

from wsgiref.simple_server import make_server
from pyramid.view import view_config
from pyramid.config import Configurator
from pyramid.response import Response


@view_config(route_name="hello")
def hello_world(request):
    return Response('Hello %(name)s!' % request.matchdict)


if __name__ == '__main__':
    config = Configurator()
    config.add_route('hello', '/hello/{name}')
    config.scan(__name__)
    app = config.make_wsgi_app()
    server = make_server('0.0.0.0', 8080, app)
    server.serve_forever()

詳しく説明すると、それこそpyramidの内部のことまでふれなければいけないので省略するけれど。declarative configurationはconfig.scan()の部分により、Configuratorのactionの実行に代わる。venusianというライブラリが使われていて、moduleの開始地点から登録されたclosure(たしか。ココは記憶で書いている)を実行する。そしてこの登録時の名前空間には呼び出されたタイミングでのモジュール名が使われる。

ここでpythonの話になる。例えば上のファイルの名前が app.py だった時に、 python app.py として実行された場合には、つまりエントリーポイントのモジュールの名前は"__main__" として実行される。モジュール名自体は __name__ に格納されている。そんなわけで、__name__ をconfig.scanに渡している。

おまけ

ちょっと試して終了みたいなアプリの時に、一度だけrequestを受け取れば十分みたいなときがある。そのような場合にはserve_foreverを使うよりhandle_requestを使ったほうが楽。

server = make_server('0.0.0.0', 8080, app)
server.handle_request()

すると、1回だけrequestを捌いたら終了してくれる

おまけ2

こういう1ファイルの時にrouteとpathの指定が面倒と言うことがあるその場合には以下の様なdirectiveないしはdecoratorを作ってあげると良いかもしれない。

import venusian
from pyramid.config import PHASE1_CONFIG


def add_simple_view(config, view, path, *args, **kwargs):
    def callback():
        route_name = view.__qualname__
        config.add_route(route_name, path)
        config.add_view(view, route_name=route_name, *args, **kwargs)
    discriminator = ('add_simple_view', path)
    config.action(discriminator, callback, order=PHASE1_CONFIG)


# venusian対応
class simple_view(object):
    def __init__(self, path, *args, **kwargs):
        self.path = path
        self.args = args
        self.kwargs = kwargs

    def register(self, scanner, name, wrapped):
        scanner.config.add_simple_view(wrapped, self.path, *self.args, **self.kwargs)

    def __call__(self, wrapped):
        venusian.attach(wrapped, self.register)
        return wrapped

def includeme(config):
    config.add_directive("add_simple_view", add_simple_view)

これを使うといよいよflaskっぽくなる。

from wsgiref.simple_server import make_server
from pyramid.config import Configurator
from pyramid.response import Response
from my import simple_view


@simple_view("/hello/{name}")
def hello_world(request):
    return Response('Hello %(name)s!' % request.matchdict)


if __name__ == '__main__':
    config = Configurator()
    config.include("my")  # ここで上のadd_simple_viewなどが使えるようになる
    config.scan(__name__)
    app = config.make_wsgi_app()
    server = make_server('0.0.0.0', 8080, app)
    server.serve_forever()

参考

最近pythonでcliのコマンドを作る時にやっていること

最近cliのコマンドを作る時にやっていることをまとめてみる。ここでのコマンドは特にパッケージとして提供されるシェルなどから実行されるコマンドのことを指している。

何が問題?

特にパッケージの提供者とパッケージのユーザーの望みが全く乖離せず一致している場合は問題がない。ユーザーが必要としている機能をパッケージの作者が提供すれば良い。 問題はところどころカスタマイズしたくなるような場合。このようなケースは自分がパッケージの作者でありユーザーである時によく発生するので面白い。パッケージの機能としては含めたくないものの現在のプロジェクトの範疇では必要となる、ただし新たなサブパッケージの様な何かを作る程汎用性があるとは思えないというような場合など。このような時にどうすれば良いのかということについてある程度回答ができるようになったのでまとめてみる。

おさらい

上の問題についてとりあえずpythonでの話しに限定して書いてみることにする。その前にpythonについてのおさらいの様な説明を書く。例えば、簡単なhelloというコマンドを作ってみる。

$ hello
hello world
$ hello --target someone
hello someone

実行したら hello world というメッセージを出力して終了する(実際に作成するコードではまともな何らかの処理になるイメージ)。--targetオプションで指定した文字列をworldの代わりに表示する。

def run(target):
    print("hello {target}".format(target=target))


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--target", default="world")
    args = parser.parse_args()
    run(args.target)


if __name__ == "__main__":
    main()

pythonの場合は以下のようなsetup.pyを書いてあげるとパッケージとしてインストールできるようになる。

from setuptools import setup

setup(name='hello',
      version='0.0',
      description='hello',
      packages=['.'],
      entry_points="""
      [console_scripts]
hello=hello:main
"""
)

現在は以下のような状況。pip install -e . などでインストールしてみる。

$ tree
├── hello.py
└── setup.py
$ pip install -e .

パッケージとしてインストールされていれば。他のパッケージからimportすることもできるし-mオプション経由でpythonコマンドから呼び出す事もできる。

$ python
Python 3.5.2 (default, Sep 19 2016, 02:49:52) 
[GCC 4.2.1 Compatible Apple LLVM 7.3.0 (clang-703.0.31)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import hello
>>> hello.run("world")
hello world
>>>
$ python -m hello
hello world
$ python -m hello --target someone
hello someone

また、上のsetup.pyでは console_scripts の設定も書いたので、helloで呼び出す事もできる。

$ hello
hello world

ここまでがおさらい。パッケージ(ここではhello パッケージ)の提供者が何らかの機能(ここでは hello worldと表示するだけ)を提供しているという状態になった。

本題

ここからが本題。さてこのhelloパッケージを便利に使っているとする。ところでちょっとした機能の変更を加えたいとする。それはほんの1行に過ぎない変更かもしれない。あるいはその変更が良いものとして恒久的に残りうるものとも限らない。そんなある意味独善的だったり個人的な変更を少しだけ加えたい。このような場合にどうするかという話。

ユーザーが自分で独自のコマンドを作っている場合

ユーザーが自分の手で元のパッケージのコードをライブラリレベルで使っていてそれをラップしたようなコマンドを作っている場合は特に何も気にしなくて良い。通常コードを書くときと同様に対応すれば良い。ここではあまり問題にならない。

ユーザーがコマンドを単に利用者として使っている場合

ユーザーが提供されているコマンドをそのまま使っている場合。こちらの場合に問題が起きる。例えばhelloの代わりにgoobyeを表示するように変えたいとする。このようなときには、わざわざラッパー用のコマンドを作るだったりパッケージを作り直さ無くてはいけない(ここが面倒くさい)。つまりユーザーが自分で独自のコマンドを作らないといけない。

最近やっていること

そんなわけでちょっとした変更を加えたい時にちょっとした変更が加えられるコマンドをどのように作るべきかみたいなことを色々考えた結果、以下の様な形にするというのが良いという結論になった。

$ hello
hello world
$ hello --driver=./my.py:MyDriver
goodbye world

やっていることは単純で --driver というオプションを渡せるようにするということ。--driverに渡す文字列は利用したいdriverのパス。 インストールされているパッケージを利用するなら以下の様に渡す。

--driver foo.bar.boo:OurDriver

とは言え、このままであれば別途パッケージを作ってインストールしたり環境変数のPYTHONPATHにわざわざ入れてあげたりしなければ使えないので不便。ということで物理的なファイルのパスも受け取れるようにする。

--driver ./my.py:MyDriver

コードは以下の様になる。

import magicalimport


class Driver:
    def run(self, target):
        print("hello {target}".format(target=target))


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--target", default="world")
    parser.add_argument("--driver", default="{}:Driver".format(__name__))
    args = parser.parse_args()

    driver_class = magicalimport.import_symbol(args.driver, sep=":")
    driver = driver_class()
    driver.run(args.target)


if __name__ == "__main__":
    main()

magicalimportは個人的に作ったライブラリで物理的なファイルパスを指定してのimportをサポートするもの。そんなに大きなライブラリというわけでもないので依存したくなければ中のコードを除いて自分で似たような機能のものを作っても良い。

このようにdriverというオプションで渡した文字列からコマンドの実行用のインスタンスを作成するという仕組みにしておくと後で捗ると言うことが分かった。 例えば以下の様にして挙動を変えられる。

$ hello
hello world
$ hello --target=world --driver=./my.py:MyDriver
goodbye world

このときmy.pyは以下のようなもの。

class MyDriver:
    def run(self, target):
        print("goodbye {target}".format(target=target))

これでちょっとした思いつきで拡張したコード片をテキトウに置いておき、それを --driverオプションに渡すというような形でちょっとした挙動の変更ができるようになる。これがちょっとした試行錯誤に都合が良いと最近は思っている。

応用例

例えば、最近作った swagger-marshmallow-codegenなどでもこの方法は使われている。これはswaggerの定義ファイル(APIの仕様をjsonschemaに似た形式で書いたファイル)からmarshmallow(schemaライブラリ)のコードを生成するコマンドを提供している。

そして最近のおしごとではmongodbを使っているので、bson.ObjectIdをサポートしたschemaが生成したいという要求があった。ところが個人的な信条として独自にmongodb用のコードをここには含めたくない。一方でmongodbに対応した別のパッケージ(リポジトリ)を作る気も起きなかった。このような時に先程の様にdriverをオプションとして渡してあげられるようになったので便利になった。パッケージ自体の作成者も自分ではあるけれど。先にdriver経由のところで実装してみて良さそうと思ったら元のパッケージに反映させるみたいなことをやったりしている。

$ swagger-marshmallow-codegen --driver=./me:CustomDriver swagger.yaml > app/schema.py

swagger-marshmallow-codegen でpaths以下も見るようにした

swagger-marshmallow-codegen でpaths以下も見るようにした。あまりきれいとはいえない感じかもしれないけれど。

paths以下を見るということ

今まではdefinitions以下しか見なかったのだけれど。通常swaggerでapiの定義をするときにはpaths以下にも色々書く。というよりrequestとresponseがどのようになっているかはpaths以下の定義を見る事が多い。例えば以下のようなAPI定義(apiaryのtuotialからもらってきた)はどうなっているかというと。

paths:
  /message/{name}:
    x-summary: Message operations
    x-description: Operation description in Markdown
    get:
      summary: Get a message of the day
      description: |
       Description of the operation in Markdown
      operationId: getMessage
      parameters:
        - name: name
          in: path
          description: name to include in the message
          type: string
          x-example: 'Hello, Adam!'
      responses:
        default:
          description: Bad request
        200:
          description: Successful response
          schema:
            $ref: '#/definitions/Message'
          examples:
            'application/json':
              message: 'Hello, Adam!'
definitions:
  Message:
    required:
      - message
    properties:
      message:
        type: string
        default: 'Hello, Adam!'

これは GET /message/{name} というようなAPIが存在していて、そのrequestの形式で許可するものはpathのみ(apiaryのsampleはqueryになっていたけれどそれは間違い。こちらでは修正している)。また、outputとしてstatus=200のresponseは Message のschemaになっている。

Input, Output

先程のAPI定義からswagger-marshmallo-codegenを使ってschemaのコードを生成してみる。今度からInput,Outputというclassも生成されるようになった。具体的には以下の様なもの。

# -*- coding:utf-8 -*-
from marshmallow import (
    Schema,
    fields
)


class Message(Schema):
    message = fields.String(required=True, missing=lambda: 'Hello, Adam!')


class MessageNameInput(object):
    class Get(object):
        """Get a message of the day"""

        class Path(Schema):
            name = fields.String(description='name to include in the message')


class MessageNameOutput(object):
    class Get200(Message):
        """Successful response"""
        pass

APIのrequestとresponse毎にInput,Outputが存在していてそのバリエーション毎にクラスが別れている。外側のクラスはnamespaceみたいなもの。

今回のAPIに関して言えば、以下のようなrequestとresponseになる。

GET /message/adam
{"message": "Hello, Adam!"}

それぞれ MessageNameInputMessageNameOutput が対応している。

x-marshmallow-name

Input,Outputの名前はpathのpatternからすごく雑に変換して決めている。

/message/{name} -> /message/name -> message, name -> MessageName

気に入らない場合もあるだろうから、 x-marshmallow-name で名前を決められるようにした。

--- 00schema.yaml    2017-01-17 06:29:00.000000000 +0900
+++ 01schema.yaml 2017-01-17 06:33:33.000000000 +0900
@@ -2,6 +2,7 @@
   /message/{name}:
     x-summary: Message operations
     x-description: Operation description in Markdown
+    x-marshmallow-name: Message
     get:
       summary: Get a message of the day
       description: |

以下のような修正を加えると出力結果は以下の様に変わる。

# -*- coding:utf-8 -*-
from marshmallow import (
    Schema,
    fields
)


class Message(Schema):
    message = fields.String(required=True, missing=lambda: 'Hello, Adam!')


class MessageInput(object):
    class Get(object):
        """Get a message of the day"""

        class Path(Schema):
            name = fields.String(description='name to include in the message')


class MessageOutput(object):
    class Get200(Message):
        """Successful response"""
        pass

もう少し複雑なもの

Path以外にもparametersは色々ある。query,formData,body,header (詳しくはここ)。それらも見る。

例えばGithubAPIの一部のAPI定義をすこしだけ弄った以下のようなyamを渡すと以下のようなコードを生成する。

definitions:
  emailsPost:
    items:
      type: string
      pattern: ".+@.+"
    type: array
  label:
    properties:
      color:
        maxLength: 6
        minLength: 6
        type: string
      name:
        type: string
      url:
        type: string
    type: object
  labels:
    items:
      $ref: '#/definitions/label'
    type: array
  labelsBody:
    items:
      type: string
    type: array

parameters:
  owner:
    description: Name of repository owner.
    in: path
    name: owner
    required: true
    type: string
  repo:
    description: Name of repository.
    in: path
    name: repo
    required: true
    type: string
  number:
    description: Number of issue.
    in: path
    name: number
    required: true
    type: integer
  X-Github-Media-Type:
    description: |
      You can check the current version of media type in responses.
    in: header
    name: X-GitHub-Media-Type
    type: string
  Accept:
    description: Is used to set specified media type.
    in: header
    name: Accept
    type: string
  X-RateLimit-Limit:
    in: header
    name: X-RateLimit-Limit
    type: integer
  X-RateLimit-Remaining:
    in: header
    name: X-RateLimit-Remaining
    type: integer
  X-RateLimit-Reset:
    in: header
    name: X-RateLimit-Reset
    type: integer
  X-GitHub-Request-Id:
    in: header
    name: X-GitHub-Request-Id
    type: integer


responses:
  labels:
    description: OK
    schema:
      $ref: '#/definitions/labels'
  label-created:
    description: Created
    schema:
      $ref: '#/definitions/label'


paths:
  '/repos/{owner}/{repo}/issues/{number}/labels':
    delete:
      description: Remove all labels from an issue.
      parameters:
        - $ref: "#/parameters/owner"
        - $ref: "#/parameters/repo"
        - $ref: "#/parameters/number"
        - $ref: "#/parameters/X-Github-Media-Type"
        - $ref: "#/parameters/Accept"
        - $ref: "#/parameters/X-RateLimit-Limit"
        - $ref: "#/parameters/X-RateLimit-Remaining"
        - $ref: "#/parameters/X-RateLimit-Reset"
        - $ref: "#/parameters/X-GitHub-Request-Id"
      responses:
        '204':
          description: |
            No content.
        '403':
          description: |
            API rate limit exceeded. See http://developer.github.com/v3/#rate-limiting
            for details.
    get:
      description: List labels on an issue.
      parameters:
        - $ref: "#/parameters/owner"
        - $ref: "#/parameters/repo"
        - $ref: "#/parameters/number"
        - $ref: "#/parameters/X-Github-Media-Type"
        - $ref: "#/parameters/Accept"
        - $ref: "#/parameters/X-RateLimit-Limit"
        - $ref: "#/parameters/X-RateLimit-Remaining"
        - $ref: "#/parameters/X-RateLimit-Reset"
        - $ref: "#/parameters/X-GitHub-Request-Id"
      responses:
        '200':
          $ref: "#/responses/labels"
        '403':
          description: |
            API rate limit exceeded. See http://developer.github.com/v3/#rate-limiting
            for details.
    x-marshmallow-name: IssuedLabels
    post:
      description: Add labels to an issue.
      parameters:
        - $ref: "#/parameters/owner"
        - $ref: "#/parameters/repo"
        - $ref: "#/parameters/number"
        - $ref: "#/parameters/X-Github-Media-Type"
        - $ref: "#/parameters/Accept"
        - $ref: "#/parameters/X-RateLimit-Limit"
        - $ref: "#/parameters/X-RateLimit-Remaining"
        - $ref: "#/parameters/X-RateLimit-Reset"
        - $ref: "#/parameters/X-GitHub-Request-Id"
        - in: body
          name: body
          required: true
          schema:
            $ref: '#/definitions/emailsPost'
      responses:
        '201':
          $ref: "#/responses/label-created"
        '403':
          description: |
            API rate limit exceeded. See http://developer.github.com/v3/#rate-limiting
            for details.
    put:
      description: |
        Replace all labels for an issue.
        Sending an empty array ([]) will remove all Labels from the Issue.
      parameters:
        - $ref: "#/parameters/owner"
        - $ref: "#/parameters/repo"
        - $ref: "#/parameters/number"
        - $ref: "#/parameters/X-Github-Media-Type"
        - $ref: "#/parameters/Accept"
        - $ref: "#/parameters/X-RateLimit-Limit"
        - $ref: "#/parameters/X-RateLimit-Remaining"
        - $ref: "#/parameters/X-RateLimit-Reset"
        - $ref: "#/parameters/X-GitHub-Request-Id"
        - in: body
          name: body
          required: true
          schema:
            $ref: '#/definitions/emailsPost'
      responses:
        '201':
          $ref: "#/responses/label-created"
        '403':
          description: |
            API rate limit exceeded. See http://developer.github.com/v3/#rate-limiting
            for details.

こういう感じ。

# -*- coding:utf-8 -*-
from marshmallow import (
    Schema,
    fields
)
from marshmallow.validate import (
    Length,
    Regexp
)
from swagger_marshmallow_codegen.schema import (
    PrimitiveValueSchema
)
import re


class Label(Schema):
    color = fields.String(validate=[Length(min=6, max=6, equal=None)])
    name = fields.String()
    url = fields.String()


class IssuedLabelsInput(object):
    class Delete(object):
        class Header(Schema):
            X_GitHub_Media_Type = fields.String(description='You can check the current version of media type in responses.\n', dump_to='X-GitHub-Media-Type', load_from='X-GitHub-Media-Type')
            Accept = fields.String(description='Is used to set specified media type.')
            X_RateLimit_Limit = fields.Integer(dump_to='X-RateLimit-Limit', load_from='X-RateLimit-Limit')
            X_RateLimit_Remaining = fields.Integer(dump_to='X-RateLimit-Remaining', load_from='X-RateLimit-Remaining')
            X_RateLimit_Reset = fields.Integer(dump_to='X-RateLimit-Reset', load_from='X-RateLimit-Reset')
            X_GitHub_Request_Id = fields.Integer(dump_to='X-GitHub-Request-Id', load_from='X-GitHub-Request-Id')

        class Path(Schema):
            owner = fields.String(description='Name of repository owner.')
            repo = fields.String(description='Name of repository.')
            number = fields.Integer(description='Number of issue.')


    class Get(object):
        class Header(Schema):
            X_GitHub_Media_Type = fields.String(description='You can check the current version of media type in responses.\n', dump_to='X-GitHub-Media-Type', load_from='X-GitHub-Media-Type')
            Accept = fields.String(description='Is used to set specified media type.')
            X_RateLimit_Limit = fields.Integer(dump_to='X-RateLimit-Limit', load_from='X-RateLimit-Limit')
            X_RateLimit_Remaining = fields.Integer(dump_to='X-RateLimit-Remaining', load_from='X-RateLimit-Remaining')
            X_RateLimit_Reset = fields.Integer(dump_to='X-RateLimit-Reset', load_from='X-RateLimit-Reset')
            X_GitHub_Request_Id = fields.Integer(dump_to='X-GitHub-Request-Id', load_from='X-GitHub-Request-Id')

        class Path(Schema):
            owner = fields.String(description='Name of repository owner.')
            repo = fields.String(description='Name of repository.')
            number = fields.Integer(description='Number of issue.')


    class Post(object):
        class Body(PrimitiveValueSchema):
            v = fields.String(validate=[Regexp(regex=re.compile('.+@.+'))])

        class Header(Schema):
            X_GitHub_Media_Type = fields.String(description='You can check the current version of media type in responses.\n', dump_to='X-GitHub-Media-Type', load_from='X-GitHub-Media-Type')
            Accept = fields.String(description='Is used to set specified media type.')
            X_RateLimit_Limit = fields.Integer(dump_to='X-RateLimit-Limit', load_from='X-RateLimit-Limit')
            X_RateLimit_Remaining = fields.Integer(dump_to='X-RateLimit-Remaining', load_from='X-RateLimit-Remaining')
            X_RateLimit_Reset = fields.Integer(dump_to='X-RateLimit-Reset', load_from='X-RateLimit-Reset')
            X_GitHub_Request_Id = fields.Integer(dump_to='X-GitHub-Request-Id', load_from='X-GitHub-Request-Id')

        class Path(Schema):
            owner = fields.String(description='Name of repository owner.')
            repo = fields.String(description='Name of repository.')
            number = fields.Integer(description='Number of issue.')


    class Put(object):
        class Body(PrimitiveValueSchema):
            v = fields.String(validate=[Regexp(regex=re.compile('.+@.+'))])

        class Header(Schema):
            X_GitHub_Media_Type = fields.String(description='You can check the current version of media type in responses.\n', dump_to='X-GitHub-Media-Type', load_from='X-GitHub-Media-Type')
            Accept = fields.String(description='Is used to set specified media type.')
            X_RateLimit_Limit = fields.Integer(dump_to='X-RateLimit-Limit', load_from='X-RateLimit-Limit')
            X_RateLimit_Remaining = fields.Integer(dump_to='X-RateLimit-Remaining', load_from='X-RateLimit-Remaining')
            X_RateLimit_Reset = fields.Integer(dump_to='X-RateLimit-Reset', load_from='X-RateLimit-Reset')
            X_GitHub_Request_Id = fields.Integer(dump_to='X-GitHub-Request-Id', load_from='X-GitHub-Request-Id')

        class Path(Schema):
            owner = fields.String(description='Name of repository owner.')
            repo = fields.String(description='Name of repository.')
            number = fields.Integer(description='Number of issue.')




class IssuedLabelsOutput(object):
    class Get200(Label):
        """OK"""
        def __init__(self, *args, **kwargs):
            kwargs['many'] = True
            super().__init__(*args, **kwargs)


    class Post201(Label):
        """Created"""
        pass

    class Put201(Label):
        """Created"""
        pass

signal handleするコードのテスト

はじめに

signalをhandleするコード自体は手軽に書ける。

import signal
import sys


def on_sigint(signum, frame):
    print("hmm")
    sys.exit(1)

signal.signal(signal.SIGINT, on_sigint)

しかしこれが確実にtrapされたことを確認するテストを書くのはだるい

面倒くさい理由

面倒くさい理由はいくつかあって、まず、signalをtrapするというのはプログラム全体に影響を及ぼす。 そしてmain threadでしか受け取れないので気軽にthreadingでごまかすということも出来ない。

試行錯誤した結果

しょうがないのでmultiprocessingで頑張る。

import sys
import unittest
import signal


class Ob(object):
    def __str__(self):
        return hex(id(self))


def do_something(calculate, _on_trap=None):
    ob = Ob()

    def on_trap(signum, frame):
        if _on_trap is not None:  # for test
            _on_trap(ob)
        print("cleanup with ", ob)
        sys.exit(1)

    signal.signal(signal.SIGHUP, on_trap)
    signal.signal(signal.SIGINT, on_trap)
    signal.signal(signal.SIGTERM, on_trap)

    # fetch anything?

    calculate(ob)  # do something

    # save db?


class Tests(unittest.TestCase):
    def test_it(self):
        from multiprocessing import Process, Queue
        import time

        q = Queue()
        init = 1
        called = 10
        q.put(init)

        def calculate(ob):
            print("before calculate", ob)
            time.sleep(1)  # waiting for killed
            print("after calculate", ob)

        def _on_trap(ob):
            self.assertEqual(q.get(), init)
            q.put(called)

        p = Process(target=lambda: do_something(calculate, _on_trap=_on_trap))
        p.start()
        time.sleep(0.1)
        p.terminate()  # SIGTERM
        p.join()
        self.assertEqual(q.get(), called)

if __name__ == "__main__":
    unittest.main()
    # before calculate 0x1030b66d8
    # cleanup with  0x1030b66d8

afterは呼ばれていないので途中で中断されている(process.terminate()によるSIGTERM)。 そしてqueueの値はcalledになっている。

before calculate 0x103f6f7b8
cleanup with  0x103f6f7b8
.
----------------------------------------------------------------------
Ran 1 test in 0.182s

OK