djangoでの集計は辛いという話 -- ORMは用法・用量を守って正しく使いましょう

djangoでの集計は辛いという話 -- ORMは用法・用量を守って正しく使いましょう

djangoのORMの機能の不足にぶち当たり辛いという話。別の言い方をすると、ORMは用法・容量守って正しく使いましょうという感じになるかもしれない。

はじめに

以下のような情報を年齢で丸めた値で集計してヒストグラムのようなものを作りたい。

名前 年齢
foo 10
bar 15
boo 20

結果

rank c
1 2
2 1

SQLでは頑張ればどうにかなる

集計をしたい時など何らかの演算の結果で GROUP BY したい時など結構ある。おそらくきっとある。 例えばヒストグラム的なものを作成したい時など。SQLであれば CASEWHENを書き連ねることを気にしなければどうにかなる。

sqlite> create table person(name string primary key, age int);
sqlite> insert into person values ('foo', 10);
sqlite> insert into person values ('bar', 15);
sqlite> insert into person values ('boo', 20);
sqlite> select case when age < 10 then 0 when 10 <= age and age < 20 then 1 when 20 <= age and age < 30 then 2 else -1 end as rank, count(*) from person group by case when age < 10 then 0 when 10 <= age and age < 20 then 1 when 20 <= age and age < 30 then 2 else -1 end;
    rank = 1
count(*) = 2

    rank = 2
count(*) = 1

djangoでどうするのという話

結論から言うとdjangoのORMで書くのは辛い。 以下のようなmodelがあったとして。

from django.db import models


class Person(models.Model):
    name = models.CharField(max_length=32, default="a", blank=False)
    age = models.PositiveIntegerField(null=False)

COUNT(*) の部分が無ければある程度機能としては揃っていると思い、はじめは、楽観視していた。

というわけで頑張ればそれなりにすぐにできるだろうと思っていた。が、意外と大変だった。以下のところまではそれなりにすぐにたどり着ける。

from django.db.models import Count, Case, When, Value, Q


# case/when
case = Case(
    When(age__lt=10, then=Value(0)),
    When(Q(age__gte=10, age__lt=20), then=Value(1)),
    When(Q(age__gte=20, age__lt=30), then=Value(2)),
    default=Value(-1),
    output_field=models.IntegerField()
)
qs = (
    Person.objects.all()
    .annotate(rank=case)
    .values("rank")
)

# group by rank
qs.query.group_by = ["rank"]  #  実は qs.query.group_by = True でも qs.query.set_group_by() でも良い

すると結果として以下のような結果が返るところまではくる。しかしここから先が辛かった。

[{"rank": 1}, {"rank": 2}]

GROUP BY も辛いという話

ところで GROUP BY に関してわざわざ query objectのqueryを触っているのは理由があり、通常は values() のあとに annotate() を書いてあげれば values() で指定したフィールドで GROUP BY されるのだけれど、この values() で設定されるものに関しては modelで定義されたフィールドであることを暗黙の前提としてコードが書かれている。

なので以下の様には書けない。"rank"というフィールドが存在しないと言われてしまう。

qs.values("rank").annotate(rank=case)

COUNT(*) を含めるのが辛いという話 (これが辛い)

そして、そもそも集計結果の値が存在しなければ、つまり COUNT(*) が付加されていなければ何の意味も無いのだけれど、ここから先は結構辛くて、原因は、djangoのORMが暗黙に SELECT句 に来るフィールドと GROUP BY句 に来るフィールドが同じという仮定を要求してくるため(詳しいことが知りたかったら、 django.db.models.query, django.db.models.sql.query, django.db.models.sql.compiler のあたりを行ったり来たりしながら読んでみて下さい)。

右往左往の結果、一応、期待した通りに COUNT(*) を追加するコードを書くことはできた。 バッドノウハウっぽいのでどこかで共有しようと思いこの記事を書いている。

qs.query.values_select.append("c")
qs.query.add_select(Count("*"))

これは、django.db.models.ValuesIterable 辺りを見ると良い (djangoのORMのqueryは呼び出すメソッドによって、queryが抱える _iterable_class が代わりSQLの結果はこのクラスに転写される)。

class ValuesIterable(BaseIterable):
    """
    Iterable returned by QuerySet.values() that yields a dict
    for each row.
    """

    def __iter__(self):
        queryset = self.queryset
        query = queryset.query
        compiler = query.get_compiler(queryset.db)

        field_names = list(query.values_select)
        extra_names = list(query.extra_select)
        annotation_names = list(query.annotation_select)

        # extra(select=...) cols are always at the start of the row.
        names = extra_names + field_names + annotation_names

        for row in compiler.results_iter():
            yield dict(zip(names, row))

見ての通り SELECT句 に値を追加しようと思ったら、以下のどれかに値を追加できれば良い。

  • field_names
  • extra_names
  • annotation_names

通常のQueryオブジェクトに用意されているメソッドを利用しての追加を考えると、 annotation_namesextra_names に値を追加しようということになるのだけれど、ここに追加しようとした場合にはGROUP BY句 にも付加されるようなSQLが生成されてしまう。

結果として FROM person GROUP BY <caseを使った式>, COUNT(*) というような謎のGROUP BY を作ろうとして失敗する。(また、djangoのORMは定義の指定に失敗すると、GROUP BYid を含めたがるような問題もあり注意が必要)

そんなわけで、生成するSQLのSELECT句に追加する処理転写されるIterableクラスの名前に追加する処理 を無理矢理追加してあげると言うことが必要になる。

全体を繋げたコードは以下の様になる。

from django.db.models import Count, Case, When, Value, Q

def extra_select(qs, **kwargs):
    qs = qs.all()
    for name, col in kwargs.items():
        qs.query.values_select.append(name)
        qs.query.add_select(col)
    return qs


case = Case(
    When(age__lt=10, then=Value(0)),
    When(Q(age__gte=10, age__lt=20), then=Value(1)),
    When(Q(age__gte=20, age__lt=30), then=Value(2)),
    default=Value(-1),
    output_field=models.IntegerField()
)
qs = (
    Person.objects.all()
    .annotate(rank=case)
    .values("rank")
)
qs = extra_select(qs, c=Count("*"))
qs.query.group_by = ["rank"]

print(qs) # => [{"c": 2, "rank": 1}, {"c": 1, "rank": 2}]

これは以下のような期待通りSQLを生成してくれる。

SELECT
  COUNT(*),
  CASE
    WHEN "person"."age" < 10 THEN 0
    WHEN ("person"."age" < 20 AND "person"."age" >= 10) THEN 1
    WHEN ("person"."age" < 30 AND "person"."age" >= 20) THEN 2
    ELSE -1
  END AS "rank"
FROM "person"
GROUP BY
  CASE
    WHEN "person"."age" < 10 THEN 0
    WHEN ("person"."age" < 20 AND "person"."age" >= 10) THEN 1
    WHEN ("person"."age" < 30 AND "person"."age" >= 20) THEN 2
    ELSE -1
  END

djangoのORMは難しいという印象は消えたことが無いですね。

ところで sqlalchemy であれば...

以下の様に書けます。

import sqlalchemy as sa
# Baseとsessionは各自で作成

class Person(Base):
    __tablename__ = "person"

    name = sa.Column(sa.String(255), default="", nullable=False, primary_key=True)
    age = sa.Column(sa.Integer)

# query
case = sa.case(
    [
        (Person.age < 10, 1),
        ((10 <= Person.age) & (Person.age < 20), 2),
        ((20 <= Person.age) & (Person.age < 30), 3)
    ],
    else_=-1)
qs = session.query(sa.func.count("*"), case).group_by(case)
print(qs.all()) # => [(2, 2), (1, 3)]

dictを返したければ

qs = session.query(sa.func.count("*").label("c"), case.label("rank")).group_by(case)
print([row._asdict() for row in qs.all()])  # => [{'c': 2, 'rank': 2}, {'c': 1, 'rank': 3}]

ちなみに、djangoのORMとsqlalchemyとを交互に使っているときには qs.all() の意味が両者の間でほとんど真逆なあたりが一番つらい。

参考