djangoでprefetch_relatedで使えるようになる 独自のディスクリプタを作ってみる

prefetch_relatedで使えるようになる、独自のディスクリプタの作り方

はじめに

前回の記事 で prefetch_relatedの実装にprefetcher オブジェクトとでも呼ぶようなオブジェクトが必要になるということを説明した。

今度は、prefetcherのインターフェイスを実装した独自のディスクリプタを作ってみようとしてみる。

ディスクリプタ?

ディスクリプタと言うのはこういうオブジェクトのこと

class MyDiscriptor:
    def __get__(self, ob, type_=None):
        if ob is None:
            return "C" # class から呼び出される
        else:
            return "I" # object から呼び出される

例えば以下の様な表示になる

class A:
    x = MyDiscriptor()
    
A.x  # =>  C
A().x # => I

今回の例: 集計した値を集めた関連を持てるようにしてみる

モデル定義

例えば以下の様なモデルがあるとする。

  • Post は 記事
  • Comment は 記事に対するコメント
class Post(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    content = models.TextField(default="", blank=False)


class Comment(models.Model):
    post = models.ForeignKey(Post)
    content = models.CharField(max_length=255, default="", blank=False)

ここでPostとCommentは1:Nの関係。

集計

例えばテキトウにデータを投入した後、書く記事(Post)に対するコメント数(Comment)を集計したい。

Post.objects.bulk_create([
    Post(name="a0"),
    Post(name="a1"),
    Post(name="a2")
])
posts = list(Post.objects.all())
Comment.objects.bulk_create([
    Comment(content="foo", post=posts[0]),
    Comment(content="bar", post=posts[0]),
    Comment(content="boo", post=posts[0]),
    Comment(content="xxx", post=posts[1]),
    Comment(content="yyy", post=posts[1]),
    Comment(content="@@@", post=posts[2]),
])

これは、以下の様な形で取り出せる。

qs = Post.objects.values("id").annotate(c=Count('comment__post_id'))
# (0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "c" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") GROUP BY "post"."id" LIMIT 21; args=()
print(qs)
[{'c': 3, 'id': 1}, {'c': 2, 'id': 2}, {'c': 1, 'id': 3}]

これをprefetch_related経由で行えるようにしてみようと言うのが今回の課題

とりあえず以下の様な要件を設けることにした。

  • <Post instance>.comment_count -- ある記事に対するコメント数を返す
  • <Aricle class>.comment_count -- コメント数に対するprefetcher objectを返す

prefetcher object

prefetcher objectを実装してみる。 get_prefetch_queryset() が重要。ちょっとめんどうなのは名前(name)とキャッシュ名(cache_name)を別にとっている点。理由は後に説明する。基本的には集計用のqueryはidと集計値を保持する辞書を返すことを期待している。

# {id: number, <name>: number} という形式の辞書のコレクションを集計用のqueryに期待している

class AggregatedPrefetcher(object):
    def __init__(self, name, cache_name, gen_query):
        self.name = name
        self.cache_name = cache_name
        self.gen_query = gen_query

    def is_cached(self, instance):
        return False

    def get_prefetch_queryset(self, objs, qs):
        if qs is not None:
            raise ValueError("Aggregated queryset can't be used for this lookup.")

        id_list = [o.id for o in objs]
        result = list(self.gen_query(objs, self.name).filter(id__in=id_list))
        single = True
        return (
            result,
            self.key_from_rel_obj,
            self.key_from_instance,
            single,
            self.cache_name
        )

    def key_from_rel_obj(self, relobj):
        return relobj["id"]

    def key_from_instance(self, obj):
        return obj.id

例えば今回の例だと、AggregatedPrefetcherは以下の様な形で作られることになる。

AggregatedPrefetcher("comment_count", "_comment_count_dict", lambda objs: name, Post.objects.values("id").annotate(**{name: Count('comment__post_id')}))

とは言え、元の要望の通りに、インスタンスから post.comment_count と呼ばれた時には違った処理を行いたい。ディスクリプタを作ることにする。

descriptor for prefetch

ディスクリプタ部分の実装

class AggregatedPrefetchDescriptor(object):
    def __init__(self, name, gen_from_query, gen_from_one):
        cache_name = "_{}_dict".format(name)
        self.prefetcher = AggregatedPrefetcher(name, cache_name, gen_from_query)
        self.gen_from_one = gen_from_one

    def __get__(self, ob, type_=None):
        if ob is None:
            return self.prefetcher
        elif hasattr(ob, self.prefetcher.cache_name):
            return getattr(ob, self.prefetcher.cache_name)[self.prefetcher.name]
        else:
            d = self.gen_from_one(ob, self.prefetcher.name)
            setattr(ob, self.prefetcher.cache_name, d)
            return d[self.prefetcher.name]

ここでようやく、cache_nameとnameを別に取る必要について説明するが、post.comment_count で取得したいのは単なる数値なのだけれど、queryで取れるのは {id: number, comment_count: number} という形の辞書。そしてprefetch_relatedでは取得した値とインスタンスの属性(name)に束縛される値は同じものになってしまう。

このため、 prefetch_relatedでのeager loadingでは先程の形式の辞書を取り、ディスクリプタでのアクセスの過程で辞書からコメント数を取り出すという実装になっている。

例えば以下の様にして使う。

class Post(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    content = models.TextField(default="", blank=False)

    comment_count = AggregatedPrefetchDescriptor(
        "comment_count",
        lambda objs, name: Post.objects.values("id").annotate(**{name: Count('comment__post_id')}),
        lambda ob, name: {"id": ob.id, name: Comment.objects.filter(post=ob).count()}
    )

インスタンスからコメント数を取得しようとする処理と、prefetcherとしての取得の処理が異なるのがちょっとめんどうではある。

データ生成

テキトウに登録する

Post.objects.bulk_create([
    Post(name="a0"),
    Post(name="a1"),
    Post(name="a2")
])
posts = list(Post.objects.all())
Comment.objects.bulk_create([
    Comment(content="foo", post=posts[0]),
    Comment(content="bar", post=posts[0]),
    Comment(content="boo", post=posts[0]),
    Comment(content="xxx", post=posts[1]),
    Comment(content="yyy", post=posts[1]),
    Comment(content="@@@", post=posts[2]),
])

記事a0に3件、記事a1に2件、記事a2に1件。

N+1 query

インスタンスからコメント数を取得することはもちろん可能だがN+1 queryが発生する。 (with_clear_conection() については前回の記事参照)

with with_clear_connection(c, "n + 1"):
    print(len(c.queries))
    comment_count_list = []
    for post in Post.objects.all():
        comment_count_list.append((post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 3 * 1 = 4
    print(comment_count_list)

記事用に1件、各記事に対してコメント数を取りに行こうとするので3件の4件

========================================
n + 1
========================================
0
(0.000) SELECT "post"."id", "post"."name", "post"."content" FROM "post"; args=()
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 1; args=(1,)
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 2; args=(2,)
(0.000) SELECT COUNT(*) AS "__count" FROM "comment" WHERE "comment"."post_id" = 3; args=(3,)
4
[(1, 3), (2, 2), (3, 1)]

prefetch

prefetch_related用に機能を作ったのでもちろん実行できるようになる。

with with_clear_connection(c, "prefetch"):
    print(len(c.queries))
    comment_count_list = []
    qs = Post.objects.all().prefetch_related("comment_count")
    for post in qs:
        comment_count_list.append((post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 1 = 2
    print(comment_count_list)

まとめて取ってきているので2件だけ。

========================================
prefetch
========================================
0
(0.000) SELECT "post"."id", "post"."name", "post"."content" FROM "post"; args=()
(0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "comment_count" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") WHERE "post"."id" IN (1, 2, 3) GROUP BY "post"."id"; args=(1, 2, 3)
2
[(1, 3), (2, 2), (3, 1)]

更にネスト

嬉しいかどうかわからないが分からないがprefetch_relatedで利用できるようになった利点としては、以下の様なコードもOKなこと。

記事(Post)に対してMagazinen(マガジン)モデルが以下の様な関係を持っていたとする。

# magazine : post = 1 : N
# Magazine *- Post *- Comment

class Magazine(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)


class Post(models.Model):
    magazine = models.ForeignKey(Magazine, null=True)
    # .. 以下略

この時、さらにネストした実行も可能になっている。

# データ生成
magazine = Magazine(name="foo")
magazine.save()
magazine.refresh_from_db()
for post in Post.objects.all():
    magazine.post_set.add(post)

with with_clear_connection(c, "prefetch nested 3"):
    print(len(c.queries))
    comment_count_list = []
    qs = Magazine.objects.all().prefetch_related("post_set", "post_set__comment_count")
    for magazine in qs:
        for post in magazine.post_set.all():
            comment_count_list.append((magazine.id, post.id, post.comment_count))
    print(len(c.queries))  # => 1 + 1 = 2
    print(comment_count_list)

ネストが深くなっても呼ばれるqueryは3件(magazine + post + comment)。

========================================
prefetch nested 3
========================================
0
(0.000) SELECT "magazine"."id", "magazine"."name" FROM "magazine"; args=()
(0.000) SELECT "post"."id", "post"."magazine_id", "post"."name", "post"."content" FROM "post" WHERE "post"."magazine_id" IN (1); args=(1,)
(0.000) SELECT "post"."id", COUNT("comment"."post_id") AS "comment_count" FROM "post" LEFT OUTER JOIN "comment" ON ("post"."id" = "comment"."post_id") WHERE "post"."id" IN (1, 2, 3) GROUP BY "post"."id"; args=(1, 2, 3)
3
[(1, 1, 3), (1, 2, 2), (1, 3, 1)]

magazine毎にコメント数を集計したいとなったらloopが必要かまた別のqueryを書く必要が出てくるのだけれど。。