djangoでヘテロなリストのprefetch (generic foreign keyのprefetch_related)

djangoヘテロなリストのprefetch

djangoヘテロなリストのprefetch。できないと思ったら普通に出来たのでわりとびっくりしたので記事にしてみた。

ヘテロなリスト?

ここでのヘテロなリストと言うのは以下のようなリストを指している。

xs = [A(), B(), A(), A(), B(), C()]

つまるところ1つの型のオブジェクトのリストではなく複数の型が混在したリストのようなものを指している。このようなリストは複数のことなるアイテムを要素として持つタイムラインのようなページを作ろうとした時に必要になる。例えば、このようなものをgeneric foreignkeyを使って以下の様に実装するということがあるかもしれない。

from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models


class A(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class B(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class C(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class Feed(models.Model):
    object_id = models.PositiveIntegerField()
    content_type = models.ForeignKey(ContentType)
    content = GenericForeignKey('content_type', 'object_id')

    class Meta:
        unique_together = ("content_type", "object_id")

上の例ではfeedというモデルに複数の種類のモデルA,B,Cのいずれかが紐付いているというような定義。色々制約や不都合はあるもののgeneric foreign keyを使うことでテーブルの定義はできる。

データの生成

実データに対するqueryを行うためにテキトウにデータを投入しておく。Feedの生成部分は1つの種類毎に分けても良いが、一緒くたにまとめて渡し、複数の種類のオブジェクトが混在した状態でも上手く動いた事に驚いたので例として載せている。

A.objects.bulk_create([
    A(name="a0", id=1),
    A(name="a1", id=2),
    A(name="a2", id=3)
])
B.objects.bulk_create([
    B(name="b0", id=10),
    B(name="b1", id=20),
    B(name="b2", id=30)
])
C.objects.bulk_create([
    C(name="c0", id=100),
    C(name="c1", id=200),
    C(name="c2", id=300)
])
Feed.objects.bulk_create(
    [Feed(content=a) for a in A.objects.all()]
    + [Feed(content=b) for b in B.objects.all()]
    + [Feed(content=c) for c in C.objects.all()]
)

ちなみにsqliteでは以下のようなSQLがfeedのデータの生成時に実行された。 (a,b,cの登録は省いている)

INSERT INTO "feed" ("object_id", "content_type_id") SELECT 1, 1 UNION ALL SELECT 2, 1 UNION ALL SELECT 3, 1 UNION ALL SELECT 10, 2 UNION ALL SELECT 20, 2 UNION ALL SELECT 30, 2 UNION ALL SELECT 100, 3 UNION ALL SELECT 200, 3 UNION ALL SELECT 300, 3; args=(1, 1, 2, 1, 3, 1, 10, 2, 20, 2, 30, 2, 100, 3, 200, 3, 300, 3)

N+1 query

もちろん、feedの一覧を取って、そのfeedから結びついているobjectを雑に取ろうとした場合にはN+1 のqueryが発生する。

# query数を計測するためのshorthandの関数を用意しておく
@contextlib.contextmanager
def with_clear_connection(c, message):
    print("\n========================================")
    print(message)
    print("========================================")
    c.queries_log.clear()
    yield

例えば以下のように単純なループでfeedに紐づくA,B,Cの各モデルを取得する。

c = connections["default"]
with with_clear_connection(c, "n + 1"):
    print(len(c.queries))
    content_list = []
    for feed in Feed.objects.all():
        content_list.append(feed.content)
    print(len(c.queries))  # => 1 + 9 * 1 = 10
    print([(o.__class__.__name__, o.id) for o in content_list])

先程のfeedデータの生成で9件のデータを登録したのでN+1 queryが発生し全部で10件のsqlが発行された。

0
(0.000) SELECT "feed"."id", "feed"."object_id", "feed"."content_type_id" FROM "feed"; args=()
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 1; args=(1,)
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 2; args=(2,)
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 3; args=(3,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 10; args=(10,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 20; args=(20,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 30; args=(30,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 100; args=(100,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 200; args=(200,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 300; args=(300,)
10
[('A', 1), ('A', 2), ('A', 3), ('B', 10), ('B', 20), ('B', 30), ('C', 100), ('C', 200), ('C', 300)]

prefetch_related

どうせdjangoさんのことだから、そのままではprefetch_relatedは動かないだろうと、少し斜に構えた目で見ながら以下のようなコードを書いた。

with with_clear_connection(c, "prefetch"):
    print(len(c.queries))
    content_list = []
    for feed in Feed.objects.all().prefetch_related("content"):
        content_list.append(feed.content)
    print(len(c.queries))  # => 1 + 3 * 1 = 4
    print([(o.__class__.__name__, o.id) for o in content_list])

すると、普通に何事もなかったかのように動いたのでびっくりした。もちろん実行されたSQLはfeedの分の1回と3種類のmodelの分の3回のみ。

========================================
prefetch
========================================
0
(0.000) SELECT "feed"."id", "feed"."object_id", "feed"."content_type_id" FROM "feed"; args=()
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" IN (1, 2, 3); args=(1, 2, 3)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" IN (10, 20, 30); args=(10, 20, 30)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" IN (200, 100, 300); args=(200, 100, 300)
4
[('A', 1), ('A', 2), ('A', 3), ('B', 10), ('B', 20), ('B', 30), ('C', 100), ('C', 200), ('C', 300)]

内部の詳細

さて、実装がどうなっているのか気になったので、少し覗いてみよう。大まかな概要は以下の様になっている。

内部(django.db.models.query)では、prefetcherとでも呼ぶようなオブジェクトが対象となるクラスから取得できることが期待されている。どうやって取り出すかというと、prefetch_relatedで指定した文字列部分を名前としてアクセス。

つまり、上の例では、以下の様にしてprefetcherオブジェクトが取れることが期待されている。

Feed.content

そしてこのprefetcherオブジェクトと言うのは、 get_prefetch_queryset() というメソッドを持つことが期待されている(is_cached() というメソッドも必要になるが今回はこれは無視することにする。雑に説明するとこちらは select_related() 等で既に取得済みだったら使い回せれば良いね。みたいなもの)。まじめにgeneric foreign keyでのprefetcherの実装を知りたいな django.contrib.contenttypes.fields の辺りを見れば良い。

色々と長々と書かれているが重要な部分だけを取り出すと以下の通り。

class Prefetcher(object):
    def get_prefetch_queryset(self, objs, qs):
        # objsはprefetch_relatedを貼ったqueryのobjectのコレクション(ここではFeedのコレクション)
        # qsはgeneric foreign keyのprefetchでは使われない

        # fk_dictはcontent_id(上の例ではfeed)のidとobject_id(上の例ではA,B,Cのid)の辞書
        # ret_valはlist
        for ct_id, fkeys in fk_dict.items():
            instance = instance_dict[ct_id]
            ct = self.get_content_type(id=ct_id, using=instance._state.db)
            ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

        # ...    
        # prefetcherのget_prefetch_querysetは5つの戻り値を返す
        return (ret_val,
            lambda obj: (obj._get_pk_val(), obj.__class__),
            gfk_key,
            True,
            self.cache_attr)

真ん中のループは雑に説明すると以下の様なことが行われている。

ret_val.extend(A.objects.filter(pk__in=[1,2,3]))
ret_val.extend(B.objects.filter(pk__in=[10,20,30]))
ret_val.extend(B.objects.filter(pk__in=[100,200,300]))

関連オブジェクトを一気に取り出してリストに詰め込んでいる。

そして、get_prefetch_queryset() の戻り値のそれぞれの意味は以下の様になっている。

  • 1つ目: prefetchしときたい関連オブジェクトのコレクション
  • 2つ目: 親オブジェクトの関連オブジェクト(上の例ではA,B,Cのインスタンス)からprefetchの際に使われるkeyを生成
  • 3つ目: 親オブジェクト(上の例ではFeedのインスタンス)からprefetchの際に使われるkeyの生成
  • 4つ目: 値がcollectionだった場合にはFalseそうでない場合にはTrue(feed.contentは1つなのでTrue)。
  • 5つ目: cacheを保持するため場所の名前。

ここで重要なのは1つ目と2つ目3つ目。prefetch_relatedの実装はどうなっているかというと、キャッシ ュ用に作った辞書を通してeager loadingされたobjectをキャッシュしているに過ぎない。

例えば今回の例だと以下の様な形でcacheが保存されている。

# (<object_id>, <model class>) => <Related Object>
# 例えばA(id=1,name="a0")は以下の様な形
{
    (1, A): a0
}

そして 関連オブジェクトからkeyの生成は以下で済み(こちらでcacheの辞書にprefetchしたオブジェクトを登録)。

# a0はAクラスのオブジェクト
(a0.id, a0.__class__) # => (1, A)

親オブジェクトからkeyの生成は以下で済む(こちらでprefetchされたオブジェクトを辞書から取得)。

# f0はFeedクラスのオブジェクト(Aクラスと結びついているとする)
# djangoのcontent_type tableを経由して該当の関連オブジェクトのmodelのクラスを取得する
model = self.get_content_type(id=ct_id, using=obj._state.db).model_class()
(feed0.object_id, model) # => (1, A)

反対側の関連の例もあるけれど。今回はこれでおしまい。