swagger-marshmallow-codegenでカスタマイズ出来るようにした

swagger-marshmallow-codegenで簡単なカスタマイズ出来るようにした。

例えば以下の様なことができるようになった

  • defaultで使うschema classをMySchemaに変える
  • 特定の条件を満たした値のときには自分で作った独自のfieldを使うように変える

ただこれらはすごくwork-aroundっぽい方針で作っているのであんまり綺麗ではないかもしれない。自分でDriverというクラスを作りそのクラスを --driver に渡す感じで使う。例えば以下の様な形。

$ swagger-marshmallow-codegen --driver=_custom.py:MyDriver --logging=DEBUG person.yaml > person.py

defaultで使うschema classをMySchemaに変える

defaultで使うschema classを変えるには codegen_factory を変える。myschema モジュールのMySchemaが使いたい場合には以下の様にする。

from swagger_marshmallow_codegen.driver import Driver

class MyDriver(Driver):
    codegen_factory = Driver.codegen_factory.override(schema_class_path="myschema:MySchema")

特定の条件を満たした値のときには自分で作った独自のfieldを使うように変える

こちらも同様に dispatcher_factory を変える。

例えば format=objectId のものは自分で定義した myschema の ObjectIdを使うように変えるときには以下の様にする。default値を気にせずmappingを変更する場合には、以下だけで良い。

type_map = {
    Pair(type="string", format="objectId"): "myschema:ObjectId",
    **TYPE_MAP,
}


class MyDriver(Driver):
    codegen_factory = Driver.codegen_factory.override(schema_class_path="myschema:MySchema")
    dispatcher_factory = Driver.dispatcher_factory.override(type_map=type_map)

とは言えdefault値の扱いを考えるとこちらは少し頑張らないとだめ。

from swagger_marshmallow_codegen.driver import Driver
from swagger_marshmallow_codegen.dispatcher import TYPE_MAP, Pair, FormatDispatcher, ReprWrapString


class MyDispatcher(FormatDispatcher):
    type_map = {
        Pair(type="string", format="objectId"): "myschema:ObjectId",
        **TYPE_MAP,
    }

    def dispatch_default(self, c, value, field):
        if isinstance(value, bson.ObjectId) or field.get("format") == "objectId":
            c.import_("bson")
            return ReprWrapString("bson.{!r}".format(bson.ObjectId(value)))
        return super().dispatch_default(c, value, field)


class MyDriver(Driver):
    codegen_factory = Driver.codegen_factory.override(schema_class_path="myschema:MySchema")
    dispatcher_factory = MyDispatcher

実行結果

例えば上で定義したものを使うと。以下のようなyaml

definitions:
  person:
    type: object
    properties:
      id:
        type: string
        format: objectId
        default: 5872bad4c54d2d4e78b34c9d
      name:
        type: string
      age:
        type: integer
    required:
      - name

このようなpythonのコードになる。

# -*- coding:utf-8 -*-
from myschema import (
    MySchema,
    ObjectId
)
from marshmallow import fields
import bson


class Person(MySchema):
    id = ObjectId(missing=lambda: bson.ObjectId('5872bad4c54d2d4e78b34c9d'))
    name = fields.String(required=True)
    age = fields.Integer()

参考

一応、参考にするための example も作った。

補足

ちなみにmyschemaのコードは例えば以下のようなもの

import bson
from marshmallow import Schema, fields


class MySchema(Schema):
    class Meta:
        ordered = True
        strict = True


class ObjectId(fields.String):
    default_error_messages = {
        'invalid_object_id': 'Not a valid bson.ObjectId.',
    }

    def _validated(self, value):
        """Format the value or raise a :exc:`ValidationError` if an error occurs."""
        if value is None:
            return None
        if isinstance(value, bson.ObjectId):
            return value
        try:
            return bson.ObjectId(value)
        except (ValueError, AttributeError):
            self.fail('invalid_object_id')

    def _deserialize(self, value, attr, data):
        return self._validated(value)

    def _serialize(self, value, attr, data):
        if not value:
            return value
        return str(value)