もう少しだけdjangoのprefetch_relatedについて考えてみる(条件付加したrelationのeager loading)

はじめに

あるモデルに対してあるコンテキスト(文脈)に従った条件を加味した関係の元に値を取り出したい場合がある。そのような条件を付加した値を仮想的なフィールドとして扱うことができないかという話。

例えば以下の事がしたい

X,Yというテーブルが存在。これらはMany to Manyの関係になっている。

  • join時の条件を付加しておきたい(e.g. is_valid=Trueの条件の元queryしたい)
  • defaultのorder byを指定しておきたい(e.g. 生成日時で降順に撮りたい)

instanceだけを対象で考える場合

@property で雑にプロパティにしてしまって良いという話かもしれない。そのような条件を満たしたproperty valid_xs を定義してみる。

# x: y = M : N

class Y(models.Model):
    name = models.CharField(max_length=32, null=False, default="")
    ctime = models.DateTimeField(auto_now_add=True, null=False)
    is_valid = models.BooleanField(default=True, null=False)

    xs = ManyToMany(X, related_name="xs")

    @property
    def valid_xs(self):
        return self.xs.all().filter(is_valid=True).order_by("-ctime")

もちろんこのままではN + 1クエリが発生する可能性は残る。

N + 1の問題について

以前からN+1の問題については幾つか記事を書いてきた。基本的には、select_relatedでのtableのjoinもしくはprefetch_relatedでeager loadingをすれば良い。例えば今回の要件では元のquerysetに対して付加的な条件を加えたいだけだったりする。

付加的な条件の追加をその場で行うのはprefetch_relatedにPrefetch objectを渡すことで可能ではある。

prefetch = Prefetch(queryset=Y.objects.filter(is_valid=True).order_by("-ctime"), to_attr="valid_ys")
qs = X.objects.all().prefetch_related("ys", prefetch)

とは言えこのままだと以下の点が面倒に感じる。

  • 先程のpropertyの定義と名前が衝突する
  • 同じようなqueryを二度書く必要が出てくる
  • 実行時の条件やprefetchする際の名前を間違ってしまう場合が存在する可能性がある

propertyの定義と名前が衝突する

prefetch_relatedが付加されない場合でも上手く動いて欲しいため、上で定義したpropertyと同じ名前でprefetch_relatedを使おうとすると以下の様なエラーが発生する。

prefetch = Prefetch("xs", X.objects.all().filter(is_valid=True).order_by("-ctime"), "valid_xs")
for y in Y.objects.all().prefetch_related(prefetch):
    print(y.id, y.name, [x.name for x in y.valid_xs])
# AttributeError: can't set attribute

これはpropertyにsetterを付けることで解決できる。cacheとして使われる値が束縛される時に使われる名前がto_attrで指定した文字列なので、以前に定義したproperty名と重複してしまっているということなので。したがってモデルのpropertyの定義を以下の様に変えれば良い。

class Y(models.Model):
    # snip..

    # # 以下の様に書いてしまうと、bool(qs) or get_queryset() という呼び出しになってしまい、bool(qs)でqsが評価されてしまうので注意
    # @property
    # def valid_xs(self):
    #     return getattr(self, "_valid_xs", None) or self.xs.all().filter(is_valid=True).order_by("-ctime")

    @property
    def valid_xs(self):
        result = getattr(self, "_valid_xs", None)
        if result is None:
            result = self._valid_xs = self.xs.all().filter(is_valid=True).order_by("-ctime")
        return result

    @valid_xs.setter
    def valid_xs(self, value):
        self._valid_xs = value

同じようなqueryを2度書いてしまっている

同じqueryを2度書くというのも嫌かもしれない。今回のクエリーの条件自体は、Xオブジェクト側だけの情報で付加できるものではあるのでXのモデル定義に含めておくと良いのかもしれない。

条件を付加したquerysetをモデルの定義に含めておく方法としては以下の4つくらいが考えられる。

  • filteringする関数の定義
  • modelのclass methodに追加
  • 独自のquerysetの定義
  • 独自のmanagerの定義

まずmanagerの定義は論外。これはquerysetを取得する際の開始時の処理しか定義する事ができないので。managerに定義する位ならquerysetに定義した方が良い。class methodにするかquerysetに定義するかは好みで良いと思う。関数として独立して定義するのとmodelのclass methodに追加するのは実質同じことではあり単に利用する際にmoduleのimportが不要かどうかという話でしかない。

この中で最も自然なのは独自のquerysetを定義することではあるけれど、個人的には単にmodelにclass methodを追加するだけで十分なのではないかと思っている。とりあえず同じようなqueryを2度書く必要があるというのはこれにより解決できる。

filteringする関数の定義

from y.models import Y

def get_valid_ys_set(qs=Y.objects.all()):
    return qs.filter(is_valid=True).order_by("-ctime")

実際に利用する時にはimportして使う必要がある。

class methodに追加する方法

class X(models.Model):
    # snip...
    @classmethod
    def valid_set(cls, qs=None):
        if qs is None:
            qs = cls.objects.all()
        # 以下の様に書いてしまうとbool(qs) で querysetが評価されてしまうので注意
        # qs = qs or cls.objects.all()
        return qs.filter(is_valid=True).order_by("-ctime")

使う時は以下の様になる

class Y(models.Model):
    # snip...
    @property
    def valid_xs(self):
        return getattr(self, "_valid_xs", None) or X.valid_set(self.xs.all())

# prefetchとして
prefetch = Prefetch("xs", X.valid_set(X.objects.all()), "valid_xs")

querysetとして追加

自分でQuerySetクラスを定義し、自分のmanagerを定義したQuerySetを返すように変更しておく。

class XQuerySet(models.QuerySet):
    def valid_set(self):
        return self.filter(is_valid=True).order_by("-ctime")

class X(models.Model):
    objects = XQuerySet.as_manager()
    # or objects = models.Manager().from_queryset(XQuerySet)()
    # snip...

querysetとして追加しておくと、関連するmodelのimportが不要になるというメリットはあるものの、返されるmodelが何であったのか分かりづらくなる気がするので個人的にはclassmethodでの追加の方が好みではある。

# classmethodでの追加の場合、Xモデルを経由しなければ絞り込みの条件を適用できない
from x.models import X
X.valid_set(y.xs.all())

# querysetにmethodが追加されていればimportは不要
y.xs.all().valid_set()

prefetchする際の名前を間違ってしまう場合が存在する可能性がある

これに関してはprefetch objectを生成する処理をメソッド・関数化すれば良いというだけの話かもしれない。 どこに追加するかというのは先程のqueryの条件追加の部分のものと同様。

class Y(models.Model):
    # snip ...
    @classmethod
    def prefetch_valid_xs(cls):
        return Prefetch("ys", queryset=X.valid_set(), to_attr="valid_xs")    

prefetch_relatedでの文字列指定は諦めた方が良い

今まではかたくなに以下の形式でprefetchの設定を追加しようとしていた。

# prefetch_related(<prefetch object>)
Y.objects.all().prefetch_related(Y.prefetch_valid_xs())

以下の様な形式で設定できるようにすることは可能だろうか?

# prefetch_related(<string>)
Y.objects.all().prefetch_related("valid_ys")

結論から言うと辛いので止めておいたほうが良い。1.9の現在のdjangoのこの辺りの処理のコードは決め打ちで内部構造を利用するコードが多いので文字列で指定可能にするための労力に対して得られるメリットが少ない。

(ちなみに、過去にprefetcherオブジェクト(is_cached(),get_prefetch_queryset()を持つオブジェクト)を自前で定義してprefetch_relatedに指定可能にするということを記事にしていたが、あれは最後の手段なので日常的に使うものではない。)

prefetch_relatedのうれしさ

prefetch_relatedの機能は、ある種の関連queryのeager loadingを行うということだけれど、この機能のうれしさについてもう少し詳しく話すとすると、prefetch_relatedに纏わるeager loadingが半ばフレームワークになっていて、prefetch可能なインターフェイスとして定義を揃えておけば、N段ネストした場合についても上手く動作するということがうれしい。

例えば、以下の2つの関連があるとすると

  • A から B の関連
  • B から C の関連

AからBへの関連のN+1,BからCへの関連のN+1の2つを抑制することができたら、この2つの組み合わせたAからCへのN+1も抑制することができるというのがうれしい点。

実際、上で定義した prefetch_valid_xs() について、さらにXにtagのようなモデルがくっついていた場合にも適切にprefetchしてくれる。

例えば、Y -> X -> XTag という関連がある時、

class XTag(models.Model):
    xs = models.ManyToManyField(X, related_name="tags")
    name = models.CharField(max_length=32, null=False, default="", unique=True)

以下のコードで発行されるqueryは3件ですむ。

qs = Y.objects.all().prefetch_related(Y.prefetch_valid_xs(), "valid_xs__tags")
for y in qs:
    for x in y.valid_xs:
        for tag in x.tags.all():
            print(y.name, x.name, tag.name)

追記

setter,getterを陽に定義するのが面倒であれば以下のような関数を作っても良いかもしれない。

def custom_relation_property(getter):
    name = getter.__name__
    cache_name = "_{}".format(name)

    def _getter(self):
        result = getattr(self, cache_name, None)
        if result is None:
            result = getter(self)
            setattr(self, cache_name, result)
        return result

    def _setter(self, value):
        setattr(self, cache_name, value)

    prop = property(_getter, _setter, doc=_getter.__doc__)
    return prop

こんな感じで済む。

class Y(models.Model):
    # snip ...

    @custom_relation_property
    def valid_xs(self):
        return X.valid_set(self.xs.all())