djangoでヘテロなリストのprefetch (generic foreign keyのprefetch_related)

djangoヘテロなリストのprefetch

djangoヘテロなリストのprefetch。できないと思ったら普通に出来たのでわりとびっくりしたので記事にしてみた。

ヘテロなリスト?

ここでのヘテロなリストと言うのは以下のようなリストを指している。

xs = [A(), B(), A(), A(), B(), C()]

つまるところ1つの型のオブジェクトのリストではなく複数の型が混在したリストのようなものを指している。このようなリストは複数のことなるアイテムを要素として持つタイムラインのようなページを作ろうとした時に必要になる。例えば、このようなものをgeneric foreignkeyを使って以下の様に実装するということがあるかもしれない。

from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models


class A(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class B(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class C(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class Feed(models.Model):
    object_id = models.PositiveIntegerField()
    content_type = models.ForeignKey(ContentType)
    content = GenericForeignKey('content_type', 'object_id')

    class Meta:
        unique_together = ("content_type", "object_id")

上の例ではfeedというモデルに複数の種類のモデルA,B,Cのいずれかが紐付いているというような定義。色々制約や不都合はあるもののgeneric foreign keyを使うことでテーブルの定義はできる。

データの生成

実データに対するqueryを行うためにテキトウにデータを投入しておく。Feedの生成部分は1つの種類毎に分けても良いが、一緒くたにまとめて渡し、複数の種類のオブジェクトが混在した状態でも上手く動いた事に驚いたので例として載せている。

A.objects.bulk_create([
    A(name="a0", id=1),
    A(name="a1", id=2),
    A(name="a2", id=3)
])
B.objects.bulk_create([
    B(name="b0", id=10),
    B(name="b1", id=20),
    B(name="b2", id=30)
])
C.objects.bulk_create([
    C(name="c0", id=100),
    C(name="c1", id=200),
    C(name="c2", id=300)
])
Feed.objects.bulk_create(
    [Feed(content=a) for a in A.objects.all()]
    + [Feed(content=b) for b in B.objects.all()]
    + [Feed(content=c) for c in C.objects.all()]
)

ちなみにsqliteでは以下のようなSQLがfeedのデータの生成時に実行された。 (a,b,cの登録は省いている)

INSERT INTO "feed" ("object_id", "content_type_id") SELECT 1, 1 UNION ALL SELECT 2, 1 UNION ALL SELECT 3, 1 UNION ALL SELECT 10, 2 UNION ALL SELECT 20, 2 UNION ALL SELECT 30, 2 UNION ALL SELECT 100, 3 UNION ALL SELECT 200, 3 UNION ALL SELECT 300, 3; args=(1, 1, 2, 1, 3, 1, 10, 2, 20, 2, 30, 2, 100, 3, 200, 3, 300, 3)

N+1 query

もちろん、feedの一覧を取って、そのfeedから結びついているobjectを雑に取ろうとした場合にはN+1 のqueryが発生する。

# query数を計測するためのshorthandの関数を用意しておく
@contextlib.contextmanager
def with_clear_connection(c, message):
    print("\n========================================")
    print(message)
    print("========================================")
    c.queries_log.clear()
    yield

例えば以下のように単純なループでfeedに紐づくA,B,Cの各モデルを取得する。

c = connections["default"]
with with_clear_connection(c, "n + 1"):
    print(len(c.queries))
    content_list = []
    for feed in Feed.objects.all():
        content_list.append(feed.content)
    print(len(c.queries))  # => 1 + 9 * 1 = 10
    print([(o.__class__.__name__, o.id) for o in content_list])

先程のfeedデータの生成で9件のデータを登録したのでN+1 queryが発生し全部で10件のsqlが発行された。

0
(0.000) SELECT "feed"."id", "feed"."object_id", "feed"."content_type_id" FROM "feed"; args=()
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 1; args=(1,)
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 2; args=(2,)
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" = 3; args=(3,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 10; args=(10,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 20; args=(20,)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" = 30; args=(30,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 100; args=(100,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 200; args=(200,)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" = 300; args=(300,)
10
[('A', 1), ('A', 2), ('A', 3), ('B', 10), ('B', 20), ('B', 30), ('C', 100), ('C', 200), ('C', 300)]

prefetch_related

どうせdjangoさんのことだから、そのままではprefetch_relatedは動かないだろうと、少し斜に構えた目で見ながら以下のようなコードを書いた。

with with_clear_connection(c, "prefetch"):
    print(len(c.queries))
    content_list = []
    for feed in Feed.objects.all().prefetch_related("content"):
        content_list.append(feed.content)
    print(len(c.queries))  # => 1 + 3 * 1 = 4
    print([(o.__class__.__name__, o.id) for o in content_list])

すると、普通に何事もなかったかのように動いたのでびっくりした。もちろん実行されたSQLはfeedの分の1回と3種類のmodelの分の3回のみ。

========================================
prefetch
========================================
0
(0.000) SELECT "feed"."id", "feed"."object_id", "feed"."content_type_id" FROM "feed"; args=()
(0.000) SELECT "a"."id", "a"."name" FROM "a" WHERE "a"."id" IN (1, 2, 3); args=(1, 2, 3)
(0.000) SELECT "b"."id", "b"."name" FROM "b" WHERE "b"."id" IN (10, 20, 30); args=(10, 20, 30)
(0.000) SELECT "c"."id", "c"."name" FROM "c" WHERE "c"."id" IN (200, 100, 300); args=(200, 100, 300)
4
[('A', 1), ('A', 2), ('A', 3), ('B', 10), ('B', 20), ('B', 30), ('C', 100), ('C', 200), ('C', 300)]

内部の詳細

さて、実装がどうなっているのか気になったので、少し覗いてみよう。大まかな概要は以下の様になっている。

内部(django.db.models.query)では、prefetcherとでも呼ぶようなオブジェクトが対象となるクラスから取得できることが期待されている。どうやって取り出すかというと、prefetch_relatedで指定した文字列部分を名前としてアクセス。

つまり、上の例では、以下の様にしてprefetcherオブジェクトが取れることが期待されている。

Feed.content

そしてこのprefetcherオブジェクトと言うのは、 get_prefetch_queryset() というメソッドを持つことが期待されている(is_cached() というメソッドも必要になるが今回はこれは無視することにする。雑に説明するとこちらは select_related() 等で既に取得済みだったら使い回せれば良いね。みたいなもの)。まじめにgeneric foreign keyでのprefetcherの実装を知りたいな django.contrib.contenttypes.fields の辺りを見れば良い。

色々と長々と書かれているが重要な部分だけを取り出すと以下の通り。

class Prefetcher(object):
    def get_prefetch_queryset(self, objs, qs):
        # objsはprefetch_relatedを貼ったqueryのobjectのコレクション(ここではFeedのコレクション)
        # qsはgeneric foreign keyのprefetchでは使われない

        # fk_dictはcontent_id(上の例ではfeed)のidとobject_id(上の例ではA,B,Cのid)の辞書
        # ret_valはlist
        for ct_id, fkeys in fk_dict.items():
            instance = instance_dict[ct_id]
            ct = self.get_content_type(id=ct_id, using=instance._state.db)
            ret_val.extend(ct.get_all_objects_for_this_type(pk__in=fkeys))

        # ...    
        # prefetcherのget_prefetch_querysetは5つの戻り値を返す
        return (ret_val,
            lambda obj: (obj._get_pk_val(), obj.__class__),
            gfk_key,
            True,
            self.cache_attr)

真ん中のループは雑に説明すると以下の様なことが行われている。

ret_val.extend(A.objects.filter(pk__in=[1,2,3]))
ret_val.extend(B.objects.filter(pk__in=[10,20,30]))
ret_val.extend(B.objects.filter(pk__in=[100,200,300]))

関連オブジェクトを一気に取り出してリストに詰め込んでいる。

そして、get_prefetch_queryset() の戻り値のそれぞれの意味は以下の様になっている。

  • 1つ目: prefetchしときたい関連オブジェクトのコレクション
  • 2つ目: 親オブジェクトの関連オブジェクト(上の例ではA,B,Cのインスタンス)からprefetchの際に使われるkeyを生成
  • 3つ目: 親オブジェクト(上の例ではFeedのインスタンス)からprefetchの際に使われるkeyの生成
  • 4つ目: 値がcollectionだった場合にはFalseそうでない場合にはTrue(feed.contentは1つなのでTrue)。
  • 5つ目: cacheを保持するため場所の名前。

ここで重要なのは1つ目と2つ目3つ目。prefetch_relatedの実装はどうなっているかというと、キャッシ ュ用に作った辞書を通してeager loadingされたobjectをキャッシュしているに過ぎない。

例えば今回の例だと以下の様な形でcacheが保存されている。

# (<object_id>, <model class>) => <Related Object>
# 例えばA(id=1,name="a0")は以下の様な形
{
    (1, A): a0
}

そして 関連オブジェクトからkeyの生成は以下で済み(こちらでcacheの辞書にprefetchしたオブジェクトを登録)。

# a0はAクラスのオブジェクト
(a0.id, a0.__class__) # => (1, A)

親オブジェクトからkeyの生成は以下で済む(こちらでprefetchされたオブジェクトを辞書から取得)。

# f0はFeedクラスのオブジェクト(Aクラスと結びついているとする)
# djangoのcontent_type tableを経由して該当の関連オブジェクトのmodelのクラスを取得する
model = self.get_content_type(id=ct_id, using=obj._state.db).model_class()
(feed0.object_id, model) # => (1, A)

反対側の関連の例もあるけれど。今回はこれでおしまい。

djangoでprefetch_relatedで使えるようになる 独自のディスクリプタを作ってみる

prefetch_relatedで使えるようになる、独自のディスクリプタの作り方

はじめに

前回の記事 で prefetch_relatedの実装にprefetcher オブジェクトとでも呼ぶようなオブジェクトが必要になるということを説明した。

今度は、prefetcherのインターフェイスを実装した独自のディスクリプタを作ってみようとしてみる。

ディスクリプタ?

ディスクリプタと言うのはこういうオブジェクトのこと

class MyDiscriptor:
    def __get__(self, ob, type_=None):
        if ob is None:
            return "C" # class から呼び出される
        else:
            return "I" # object から呼び出される

例えば以下の様な表示になる

class A:
    x = MyDiscriptor()
    
A.x  # =>  C
A().x # => I

今回の例: 集計した値を集めた関連を持てるようにしてみる

モデル定義

例えば以下の様なモデルがあるとする。

  • Post は 記事
  • Comment は 記事に対するコメント
class Post(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    content = models.TextField(default="", blank=False)


class Comment(models.Model):
    post = models.ForeignKey(Post)
    content = models.CharField(max_length=255, default="", blank=False)

ここでPostとCommentは1:Nの関係。

集計

例えばテキトウにデータを投入した後、書く記事(Post)に対するコメント数(Comment)を集計したい。

Post.objects.bulk_create([
    Post(name="a0"),
    Post(name="a1"),
    Post(name="a2")
])
posts = list(Post.objects.all())
Comment.objects.bulk_create([
    Comment(content="foo", post=posts[0]),
    Comment(content="bar", post=posts[0]),
    Comment(content="boo", post=posts[0]),
    Comment(content="xxx", post=posts[1]),
    Comment(content="yyy", post=posts[1]),
    Comment(content="@@@", post=posts[2]),
])

これは、以下の様な形で取り出せる。

qs = Post.objects.values("id").annotate(c=Count('comment__post_id'))
# (0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "c" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") GROUP BY "post"."id" LIMIT 21; args=()
print(qs)
[{'c': 3, 'id': 1}, {'c': 2, 'id': 2}, {'c': 1, 'id': 3}]

これをprefetch_related経由で行えるようにしてみようと言うのが今回の課題

とりあえず以下の様な要件を設けることにした。

  • <Post instance>.comment_count -- ある記事に対するコメント数を返す
  • <Aricle class>.comment_count -- コメント数に対するprefetcher objectを返す

prefetcher object

prefetcher objectを実装してみる。 get_prefetch_queryset() が重要。ちょっとめんどうなのは名前(name)とキャッシュ名(cache_name)を別にとっている点。理由は後に説明する。基本的には集計用のqueryはidと集計値を保持する辞書を返すことを期待している。

# {id: number, <name>: number} という形式の辞書のコレクションを集計用のqueryに期待している

class AggregatedPrefetcher(object):
    def __init__(self, name, cache_name, gen_query):
        self.name = name
        self.cache_name = cache_name
        self.gen_query = gen_query

    def is_cached(self, instance):
        return False

    def get_prefetch_queryset(self, objs, qs):
        if qs is not None:
            raise ValueError("Aggregated queryset can't be used for this lookup.")

        id_list = [o.id for o in objs]
        result = list(self.gen_query(objs, self.name).filter(id__in=id_list))
        single = True
        return (
            result,
            self.key_from_rel_obj,
            self.key_from_instance,
            single,
            self.cache_name
        )

    def key_from_rel_obj(self, relobj):
        return relobj["id"]

    def key_from_instance(self, obj):
        return obj.id

例えば今回の例だと、AggregatedPrefetcherは以下の様な形で作られることになる。

AggregatedPrefetcher("comment_count", "_comment_count_dict", lambda objs: name, Post.objects.values("id").annotate(**{name: Count('comment__post_id')}))

とは言え、元の要望の通りに、インスタンスから post.comment_count と呼ばれた時には違った処理を行いたい。ディスクリプタを作ることにする。

descriptor for prefetch

ディスクリプタ部分の実装

class AggregatedPrefetchDescriptor(object):
    def __init__(self, name, gen_from_query, gen_from_one):
        cache_name = "_{}_dict".format(name)
        self.prefetcher = AggregatedPrefetcher(name, cache_name, gen_from_query)
        self.gen_from_one = gen_from_one

    def __get__(self, ob, type_=None):
        if ob is None:
            return self.prefetcher
        elif hasattr(ob, self.prefetcher.cache_name):
            return getattr(ob, self.prefetcher.cache_name)[self.prefetcher.name]
        else:
            d = self.gen_from_one(ob, self.prefetcher.name)
            setattr(ob, self.prefetcher.cache_name, d)
            return d[self.prefetcher.name]

ここでようやく、cache_nameとnameを別に取る必要について説明するが、post.comment_count で取得したいのは単なる数値なのだけれど、queryで取れるのは {id: number, comment_count: number} という形の辞書。そしてprefetch_relatedでは取得した値とインスタンスの属性(name)に束縛される値は同じものになってしまう。

このため、 prefetch_relatedでのeager loadingでは先程の形式の辞書を取り、ディスクリプタでのアクセスの過程で辞書からコメント数を取り出すという実装になっている。

例えば以下の様にして使う。

class Post(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    content = models.TextField(default="", blank=False)

    comment_count = AggregatedPrefetchDescriptor(
        "comment_count",
        lambda objs, name: Post.objects.values("id").annotate(**{name: Count('comment__post_id')}),
        lambda ob, name: {"id": ob.id, name: Comment.objects.filter(post=ob).count()}
    )

インスタンスからコメント数を取得しようとする処理と、prefetcherとしての取得の処理が異なるのがちょっとめんどうではある。

データ生成

テキトウに登録する

Post.objects.bulk_create([
    Post(name="a0"),
    Post(name="a1"),
    Post(name="a2")
])
posts = list(Post.objects.all())
Comment.objects.bulk_create([
    Comment(content="foo", post=posts[0]),
    Comment(content="bar", post=posts[0]),
    Comment(content="boo", post=posts[0]),
    Comment(content="xxx", post=posts[1]),
    Comment(content="yyy", post=posts[1]),
    Comment(content="@@@", post=posts[2]),
])

記事a0に3件、記事a1に2件、記事a2に1件。

N+1 query

インスタンスからコメント数を取得することはもちろん可能だがN+1 queryが発生する。 (with_clear_conection() については前回の記事参照)

with with_clear_connection(c, "n + 1"):
    print(len(c.queries))
    comment_count_list = []
    for post in Post.objects.all():
        comment_count_list.append((post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 3 * 1 = 4
    print(comment_count_list)

記事用に1件、各記事に対してコメント数を取りに行こうとするので3件の4件

========================================
n + 1
========================================
0
(0.000) SELECT "post"."id", "post"."name", "post"."content" FROM "post"; args=()
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 1; args=(1,)
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 2; args=(2,)
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 3; args=(3,)
4
[(1, 3), (2, 2), (3, 1)]

prefetch

prefetch_related用に機能を作ったのでもちろん実行できるようになる。

with with_clear_connection(c, "prefetch"):
    print(len(c.queries))
    comment_count_list = []
    qs = Post.objects.all().prefetch_related("comment_count")
    for post in qs:
        comment_count_list.append((post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 1 = 2
    print(comment_count_list)

まとめて取ってきているので2件だけ。

========================================
prefetch
========================================
0
(0.000) SELECT "post"."id", "post"."name", "post"."content" FROM "post"; args=()
(0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "comment_count" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") WHERE "post"."id" IN (1, 2, 3) GROUP BY "post"."id"; args=(1, 2, 3)
2
[(1, 3), (2, 2), (3, 1)]

更にネスト

嬉しいかどうかわからないが分からないがprefetch_relatedで利用できるようになった利点としては、以下の様なコードもOKなこと。

記事(Post)に対してMagazinen(マガジン)モデルが以下の様な関係を持っていたとする。

# magazine : post = 1 : N
# Magazine *- Post *- Comment

class Magazine(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class Post(models.Model):
    magazine = models.ForeignKey(Magazine, null=True)
    # .. 以下略

この時、さらにネストした実行も可能になっている。

# データ生成
magazine = Magazine(name="foo")
magazine.save()
magazine.refresh_from_db()
for post in Post.objects.all():
    magazine.post_set.add(post)

with with_clear_connection(c, "prefetch nested 3"):
    print(len(c.queries))
    comment_count_list = []
    qs = Magazine.objects.all().prefetch_related("post_set", "post_set__comment_count")
    for magazine in qs:
        for post in magazine.post_set.all():
            comment_count_list.append((magazine.id, post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 1 = 2
    print(comment_count_list)

ネストが深くなっても呼ばれるqueryは3件(magazine + post + comment)。

========================================
prefetch nested 3
========================================
0
(0.000) SELECT "magazine"."id", "magazine"."name" FROM "magazine"; args=()
(0.000) SELECT "post"."id", "post"."magazine_id", "post"."name", "post"."content" FROM "post" WHERE "post"."magazine_id" IN (1); args=(1,)
(0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "comment_count" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") WHERE "post"."id" IN (1, 2, 3) GROUP BY "post"."id"; args=(1, 2, 3)
3
[(1, 1, 3), (1, 2, 2), (1, 3, 1)]

magazine毎にコメント数を集計したいとなったらloopが必要かまた別のqueryを書く必要が出てくるのだけれど。。