wsgirefでhttp用のreverse proxyを書いてみる

goだとnet/http/httputilに便利なコードが置いてあるのだけれど。pythonだとそういうものがない。 なのでとりあえずwsgirefでどうするかを考えてみた(hip by hopなどは対応していない)。

挙動としては、requestされたら、テキトウなurlにrequestsで通信する。返ってきたresponseに対して加工を加えるようなもの。

以下のような形。返す時にちょっとした加工を加える。

<client> -> <proxy>
                     ->  <server>
<client> <-

serverは全部JSON APIを想定。

request/response

なんだかんだでrequest/response形式が楽。そういう形のオブジェクトを定義してみる。requests.models.Responseをそのまま使っても良いのだけれど。何が行えるかぱっと見で分かりたかったのでprotocolにした(Protocolを使っている時点でtyping_extensionsが必要になるじゃんという話はある。現状では)。

response

import typing as t
import typing_extensions as tx

class Response(tx.Protocol):
    status_code: int
    reason: str
    headers: t.Dict[str, str]

    @property
    def content(self) -> bytes:
        ...

    def json(self) -> dict:
        ...

request

class Request:
    def __init__(self, environ):
        self.environ = environ

    @property
    def wsgi_input(self):
        return self.environ["wsgi.input"]

    @property
    def path(self):
        return self.environ["PATH_INFO"]

    @property
    def method(self):
        return self.environ["REQUEST_METHOD"]

    @property
    def query_string(self):
        return self.environ.get("QUERY_STRING")

    @property
    def headers(self):
        environ = self.environ
        return {k[5:].replace("_", "-"): environ[k] for k in environ if k.startswith("HTTP_")}

    @property
    def content_type(self):
        return self.environ.get("CONTENT_TYPE")

    @property
    def content_length(self):
        v = self.environ.get("CONTENT_LENGTH") or None
        if v is None:
            return None
        return int(v)

    @property
    def data(self):
        if not self.content_length:
            return None
        return self.wsgi_input.read(self.content_length)

wsgi app

それっぽいwsgi appを書く。wsgi appについてはwsgirefのドキュメントを見ると良いかも?

class Proxy:
    def __init__(
        self,
        request: t.Callable[[Request], Response],
        response: t.Callable[[Response], bytes],
    ):
        self.request = request
        self.response = response

    def __call__(
        self,
        environ: dict,
        start_response: t.Callable[[str, t.List[t.Tuple[str, str]]], None],
    ) -> None:
        response = self.request(Request(environ))
        if response.status_code == 200:
            content = self.response(response)
        else:
            content = response.content
        start_response(f'{response.status_code} {response.reason}', list(response.headers.items()))
        return [content]

main。ここを書き換えて使う想定。一点だけ、responseを書き換えるならContent-Lengthをそのままにするとまずい(元のサイズより大きくなった時、responseを全部見ないで閉じてしまう)。まじめに計算しても良いけれど。怠惰なので取り除く。

def main(port=4444):
    def request(req: Request) -> Response:
        url = f"http://localhost:5000{req.path}"
        if req.query_string:
            url = f"{url}?{req.query_string}"
        return requests.request(req.method, url, data=req.data, headers=req.headers)

    def response(res: Response) -> bytes:
        if not res.headers.get("Content-Type", "").lstrip().startswith("application/json"):
            return res.content
        res.headers.pop("Content-Length", None) # ココ重要

        # ここで加工
        body = res.json()
        for k in list(body.keys()):
            body[k] = f"**{body[k]}**"  # とりあえず**を付けてみる

        return json.dumps(body).encode("utf-8")

    proxy = Proxy(request, response)
    with make_server('', port, proxy) as httpd
        httpd.serve_forever()

つなげて動かす

以下の様な雑なserverを立てる

import json
from wsgiref.simple_server import make_server


def app(environ, start_response):
    status = '200 OK'
    headers = [('Content-type', 'application/json; charset=utf-8')]
    start_response(status, headers)
    data = {
        "name": "foo",
        "age": "20",
    }
    return [json.dumps(data).encode("utf-8")]


def main(port=4445):
    with make_server('', port, app) as httpd
        httpd.serve_forever()
    httpd = make_server('', port, app)

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=4444)
    args = parser.parse_args()
    main(port=args.port)

serverとproxyを立ち上げておく

$ python proxy.py --port=5001 &
$ python server.py --port=5000 &

確認

proxy無し

$ http :5000/api/foo
127.0.0.1 - - [14/Aug/2018 00:00:39] "GET /api/foo HTTP/1.1" 200 28
HTTP/1.0 200 OK
Content-Length: 28
Content-type: application/json; charset=utf-8
Date: Mon, 13 Aug 2018 15:00:39 GMT
Server: WSGIServer/0.2 CPython/3.7.0

{
    "age": "20",
    "name": "foo"
}

proxy有り

$ http :5001/api/foo
127.0.0.1 - - [13/Aug/2018 23:58:28] "GET /api/foo HTTP/1.1" 200 28
127.0.0.1 - - [13/Aug/2018 23:58:28] "GET /api/foo HTTP/1.1" 200 36
HTTP/1.0 200 OK
Content-Length: 36
Content-type: application/json; charset=utf-8
Date: Mon, 13 Aug 2018 14:58:28 GMT
Server: WSGIServer/0.2 CPython/3.7.0

{
    "age": "**20**",
    "name": "**foo**"
}

POSTなども上手くいっている?

昔作ったreqtrace を使うと手軽そう。これは内部で行われている通信をdumpしてくれるもの(proxyが行っているrequestをdumpしたい)。

確認してみた。

proxyを立ち上げる時に以下の様にしておく。

# git clone git@github.com:podhmo/reqtrace
# pip install -e ./reqtrace
$ python -m reqtrace proxy.py -- --port=5001 &

POST(application/json)

$ echo '{"foo": "boo"}' | http -b --json POST :5001 headerX:headerV qsK==qsV
127.0.0.1 - - [14/Aug/2018 00:05:23] "POST /?qsK=qsV HTTP/1.1" 200 28
INFO:reqtrace.tracelib.hooks:traced http://localhost:5000/?qsK=qsV
127.0.0.1 - - [14/Aug/2018 00:05:23] "POST /?qsK=qsV HTTP/1.1" 200 36
{
    "age": "**20**",
    "name": "**foo**"
}

このときのrequest以下のようなもの

  • headerも送られている
  • query stringも送られている
  • json bodyも送られている。
{
  "request": {
    "body": "{\"foo\": \"boo\"}\n",
    "headers": {
      "ACCEPT": "application/json, */*",
      "ACCEPT-ENCODING": "gzip, deflate",
      "CONNECTION": "keep-alive",
      "Content-Length": "15",
      "HEADERX": "headerV",
      "HOST": "localhost:5001",
      "USER-AGENT": "HTTPie/0.9.8"
    },
    "host": "localhost:5000",
    "method": "POST",
    "path": "/",
    "queries": [
      [
        "qsK",
        "qsV"
      ]
    ],
    "url": "http://localhost:5000/?qsK=qsV"
  },
  "response": {
    "body": {
      "age": "20",
      "name": "foo"
    },
    "headers": {
      "Content-Length": "28",
      "Content-type": "application/json; charset=utf-8",
      "Date": "Mon, 13 Aug 2018 15:05:23 GMT",
      "Server": "WSGIServer/0.2 CPython/3.7.0"
    },
    "status_code": 200
  }
}

今度は通常のPOSTの場合

$ http -b --form POST :5001 headerX:headerV qsK==qsV foo=boo
127.0.0.1 - - [14/Aug/2018 00:08:41] "POST /?qsK=qsV HTTP/1.1" 200 28
INFO:reqtrace.tracelib.hooks:traced http://localhost:5000/?qsK=qsV
127.0.0.1 - - [14/Aug/2018 00:08:41] "POST /?qsK=qsV HTTP/1.1" 200 36
{
    "age": "**20**",
    "name": "**foo**"
}

この時のrequestと先ほどのdiff

--- roundtrips/0000post_http:__localhost:5000_?qsK=qsV.json  2018-08-14 00:05:23.024385238 +0900
+++ roundtrips/0001post_http:__localhost:5000_?qsK=qsV.json   2018-08-14 00:08:41.018282933 +0900
@@ -1,11 +1,11 @@
 {
   "request": {
-    "body": "{\"foo\": \"boo\"}\n",
+    "body": "foo=boo",
     "headers": {
-      "ACCEPT": "application/json, */*",
+      "ACCEPT": "*/*",
       "ACCEPT-ENCODING": "gzip, deflate",
       "CONNECTION": "keep-alive",
-      "Content-Length": "15",
+      "Content-Length": "7",
       "HEADERX": "headerV",
       "HOST": "localhost:5001",
       "USER-AGENT": "HTTPie/0.9.8"
@@ -29,7 +29,7 @@
     "headers": {
       "Content-Length": "28",
       "Content-type": "application/json; charset=utf-8",
-      "Date": "Mon, 13 Aug 2018 15:05:23 GMT",
+      "Date": "Mon, 13 Aug 2018 15:08:41 GMT",
       "Server": "WSGIServer/0.2 CPython/3.7.0"
     },
     "status_code": 200

gist

gist