読者です 読者をやめる 読者になる 読者になる

そう言えば、selfishというツール作っていました

golang

これは何?

gistのuploadを手軽にするやつです。goの勉強のための習作でした。

経緯

以前から結構gistを頻繁に利用していて、特に複数ファイルをuploadしたい場合には、web画面からポチポチとファイルを指定していくのではつらすぎる感じがしてました。なので、今まではgistyというツールを使っていたのですが。おそらく利用するユーザー層と自分とは完全にはマッチしていない感じでちょっと不便だなとは思っていました。

gistにuploadする際に、複数のファイルを指定してuploadしたくはなると思います。これはgistyでもサポートしていて頻繁に利用していました。

$ gisty post x y z 

ところで、このuploadしたxに少し修正を加えた結果をuploadしたいとなった時に便利に取り扱う方法がなかったというのが問題でした。gistyはgistへのuploadと同時に ~/.gisty/ 以下などにuploadした内容をcloneしてくれるので、その中のファイルを変更しpushしてあげるとupdateできた(記憶が)あるのですが。そもそも、自分が変更したかったファイルは別の場所に隔離されたコピー(~/gisty/<gist id>/x などのことを言っている)ではなく、コピー対象となったファイル自体であることがほとんどなので上手くいきません。

また、新規のアップロードか既存のgistの更新なのか判断するためにgistのidを指定するのはどうでしょう?そもそもgistのidを覚えておいたりコピペしたりするのが嫌でした。

このため、今までの利用方法としては、特定のファイルの更新であっても新規にgistをuploadし直してました。(時間当たりのファイルの内容の遷移を追うことはできなくなりますが、そのあたりは利便性との兼ね合いで目をつむってました)

selfish

そんなわけでgoの勉強も兼ねて丁度良いということでgistのuploadをする自分用のツールを作ろうと思ったのでした。基本的には新規作成・更新・削除以上のことをしないのでそれしかサポートしていません。

インストール方法

インストール自体は以下でできます。

$ go get github.com/podhmo/selfish/cmd/selfish

色々調べた結果、 cmd 以下に実行コマンドのコードを置くという方法があるらしいということを知りなるほどと思いました。 (調べたメモ)

使い方

gistのidを管理するのが面倒だったので以下のようにaliasを指定できるようにしました。

$ selfish -alias mytest x y z
create success. (id="5639abca377b5c92061248666d38e6aa")
opening.. "https://gist.github.com/5639abca377b5c92061248666d38e6aa"

上の例では、mytestという名前で管理することになります。これで新規のgistが生成されます。gistyに倣ってgist作成後に作成したgistのページをブラウザで開きます。

また、さらにxに変更があった場合には再度以下の以下のコマンドを実行してください。

$ selfish -alias mytest -silent x y z
update success. (id="5639abca377b5c92061248666d38e6aa")

gist post後にブラウザで開かれるのが邪魔な場合には、 -silent を付けると抑制出来ます。

upload済みのgistを消したい場合には -delete を付けると消せます。

$ selfish -alias mytest -delete
deleted. (id="5639abca377b5c92061248666d38e6aa")

ちなみに -alias を指定しない場合には、 head というaliasで新規作成されます。(ただし -alias head と明示的にaliasを指定しない限り更新はされません)

結果

gistのrevisionが機能するようになった。嬉しい。

細かいこと

sqliteなど使うのはおおげさかなと思い使わない選択をしたのですが、ところで、LIFOみたいなものを雑に保持するのに何が良いんだろう?みたいなことを思ったりしました。

使い捨てのコードのエラー処理について

golang memo

tl;dr

  • panic時ではなくerror時にもfullのstack traceが欲しい
  • pkg/errors が便利

はじめに

しばらくgoを書いていて、使い捨てのコードのエラー処理についてどうすれば良いのか考えたりしていた。ここで言う使い捨てのコードというのは1ファイル位で作れそうな小さなコマンドラインのコマンドのようなものを指している。

まともなアプリケーションコードでは考えることが色々ある気がするけれど。使い捨てのコードなら以下を満たしていれば十分だと思った。

  • 終了ステータスが0以外になる
  • エラーの発生箇所が正確に分かる(stack trace)

前者はテキトーに書いても自然に満たす気がする。ここでは後者をどうするかについて書く。

panic時は問題なし。ただしerror時には問題がある

ここでいうエラー処理は以下の2つを含んでいる。

  • panic時の処理
  • error時の処理

テキトーに書いた場合にどういう状況か整理。

panic時の処理

panic()により中断された場合にはstack traceを出力してくれる。なのでpanic時は何もしなくても期待通りのstack traceが出力される。

package main

func foo() {
    panic("hmm")
}

func main(){
    foo()
}

main() の中で foo() を呼び foo() の内部でpanicが起きたということが分かる。

panic: hmm

goroutine 1 [running]:
panic(0x56d40, 0xc82000a0c0)
    /opt/local/lib/go/src/runtime/panic.go:481 +0x3e6
main.foo()
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/00panic.go:4 +0x65
main.main()
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/00panic.go:8 +0x14
exit status 2

error時の処理

errorの時はどうなるか整理してみる。通常は、内部の関数はerror値を含んだ値を返し、トップレベルでerror値をまじめに取り扱うという感じになると思う。

package main

import (
    "fmt"
)

func foo() error {
    return fmt.Errorf("hmm")
}


func main() {
    err := foo()
    if err != nil {
        panic(err)
    }
}

foo()の処理自体はerror値を返すという形で正常に行われている。そのため、当然ではあるけれど、error message自体は元のエラーのものが表示されるものの、stack traceは元のエラーの発生箇所ではなくトップレベルのものになってしまう。

panic: hmm

goroutine 1 [running]:
panic(0xd5f00, 0xc82000a120)
    /opt/local/lib/go/src/runtime/panic.go:481 +0x3e6
main.main()
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/01error.go:15 +0x59
exit status 2

エラー発生箇所のstack traceを保持したままにしたい

エラー発生箇所のstack traceを保持したままerror値を伝搬させていきたい。 これは pkg/errors を使うとできそう。

+フラグ付きで出力するとstack traceも含めてくれる。

package main

import (
    "fmt"
    "github.com/pkg/errors"
    "os"
)

func foo() error {
    return errors.Errorf("hmm")
}

func main() {
    err := foo()
    if err != nil {
        fmt.Printf("error: %+v\n", err)
        os.Exit(1)
    }
}

以下の様な感じ。今回は foo() の内部でエラーが発生していることが分かる。

error: hmm
main.foo
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/02error.go:10
main.main
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/02error.go:14
runtime.main
    /opt/local/lib/go/src/runtime/proc.go:188
runtime.goexit
    /opt/local/lib/go/src/runtime/asm_amd64.s:1998
exit status 1

appendix

もう少しだけ pkg/errors のことを詳しく。基本的には以下の様にすれば良い。

  • 自分でエラーを発生させる -> fmt.Errorf() のかわりに errors.Errorf() を使う
  • 内部の関数で発生したエラーを伝搬させる -> 直接error値を返すより、 errors.Wrapf() でwrapした値を返す
package main

import (
    "github.com/pkg/errors"
    "fmt"
    "os"
)

func f0() error{
    err := f1()
    if err != nil {
        return errors.Wrapf(err, "f0")
    }
    return err
}
func f1() error{
    err := f2()
    if err != nil {
        return errors.Wrapf(err, "f1")
    }
    return err
}
func f2() error{
    err := f3()
    if err != nil {
        return errors.Wrapf(err, "f2")
    }
    return err
}

// 外部のパッケージでのエラー
func f3() error{
    return fmt.Errorf("*error on a external package*")
}

func main() {
    err := f0()
    if err != nil {
        fmt.Printf("err %+v\n", err)
        os.Exit(1)
    }
}

ちょっと出力が冗長ではあるけれど。完全なstack traceが取れる。

err *error on a external package*
f2
main.f2
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:26
main.f1
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:17
main.f0
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:10
main.main
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:35
runtime.main
    /opt/local/lib/go/src/runtime/proc.go:188
runtime.goexit
    /opt/local/lib/go/src/runtime/asm_amd64.s:1998
f1
main.f1
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:19
main.f0
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:10
main.main
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:35
runtime.main
    /opt/local/lib/go/src/runtime/proc.go:188
runtime.goexit
    /opt/local/lib/go/src/runtime/asm_amd64.s:1998
f0
main.f0
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:12
main.main
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/03nested.go:35
runtime.main
    /opt/local/lib/go/src/runtime/proc.go:188
runtime.goexit
    /opt/local/lib/go/src/runtime/asm_amd64.s:1998
exit status 1

直接 stack trace的な情報を取り出す

あと、stack trace的な情報だけ欲しい場合には以下の様にすれば良さそう。

func main() {
    type causer interface {
        Cause() error
    }
    type stackTracer interface {
        StackTrace() errors.StackTrace
    }

    err := f0()
    if err != nil {
        errs := []stackTracer{}
        for err != nil {
            if err, ok := err.(stackTracer); ok {
                errs = append(errs, err)
            }

            if cause, ok := err.(causer); !ok {
                break
            }
            err = cause.Cause()
        }
        fmt.Println("stack trace")
        for _, frame := range errs[len(errs)-1].StackTrace() {
            fmt.Printf("\t %+v\n", frame)
        }
        os.Exit(1)
    }
}

出力結果

stack trace
     main.f2
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/04stacktracer.go:26
     main.f1
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/04stacktracer.go:17
     main.f0
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/04stacktracer.go:10
     main.main
    /home/podhmo/go-sandbox/examples-errors/example_stacktrace/04stacktracer.go:42
     runtime.main
    /opt/local/lib/go/src/runtime/proc.go:188
     runtime.goexit
    /opt/local/lib/go/src/runtime/asm_amd64.s:1998
exit status 1

golangのfmt系のformatの機能のメモ

golang memo

まじめに調べるなら以下を見たほうが良い。

https://golang.org/pkg/fmt/

reflection使った便利な出力

  • %T 値の方を表示
  • %v 値を良い感じに表示
  • %+v +フラグ付きで冗長出力表示。
  • %#v 値を型名やフィールド名も含めて出力

利用例

type Person struct {
    Name string
    Age int
}

func main(){
    person := Person{Name: "foo", Age: 20}
    fmt.Printf("%%T %T\n", person)
    fmt.Printf("%%v %v\n", person)
    fmt.Printf("%%v %#v\n", person)
}

/*
%T main.Person
%v {foo 20}
%v main.Person{Name:"foo", Age:20}
*/

同一の値を添え字で参照

1-originなことに注意

fmt.Printf("type=%[1]T, value=%[1]v, verbose=%#[1]v\n", person)

/*
type=main.Person, value={foo 20}, verbose=main.Person{Name:"foo", Age:20}
*/

quoteされた文字列の表示

%q が用意されている。

fmt.Printf("string = %q¥n", "foo")

/*
string = "foo"
*/

0-padding

0-paddingだけできる?

fmt.Printf("long=%06d, short=%04[1]d\n", 100)
/*
long=000100, short=0100
*/

golangのequalityの評価について

memo golang

はじめに

他の言語のつもりで比較演算子を使うと想定外の挙動をするということもあったりする。最初テキトウにコードを書いて挙動を確認しようとしていたが後に言語仕様を読めば良いだけだということに気づいた。学び始めの最中に思ったことをメモしておくというのも良いと思ったのでメモしておく。

幾つか驚いたこと

golangは比較をわりと頑張るタイプの言語のようだった。

structの比較は等値ではなく等価

初見の先入観としてstructのような値のコンテナの比較は特に何か特別なこと(e.g. 比較時に呼ばれるメソッドのオーバーライド)をしないかぎり、等値で比較されると思ったが等価だった。

例えば以下の様なこと。

type Point struct {
    x,y int
}

// 等値
pt := Point{x: 10, y: 20}
fmt.Printf("%v¥n", pt == pt) // => true

// 等価
// (先入観でこれはfalseだと思っていた)
fmt.Printf("%v¥n", Point{x: 10, y: 20}, Point{x: 10, y: 20}) // => true

値オブジェクト的な物を作った時に比較が自然に定義されるのでこのような挙動はわりと嬉しい。

また、mapのkeyの評価も同様に行われるので以下の様なコードの実行後のmapの保持する値の数は2。

m := map[Point]int{}
m[Point{X: 10, Y: 20}]++
m[Point{X: 10, Y: 20}]++
m[Point{X: 10, Y: 10}]++
fmt.Printf("%v\n", m) // => map[{10 20}:2 {10 10}:1]

mapにアクセスし値が存在しない場合にはzero値が返されるので単にcounterとして使いたい場合に便利。pythonではnamedtupleを利用した時と同様ということを考えると便利。

from collections import defaultdict

d = defaultdict(int)
d[(10, 20)] += 1
d[(10, 20)] += 1
d[(10, 10)] += 1
print(dict(d))  # => {(10, 20): 2, (10, 10): 1}

from collections import namedtuple

Point = namedtuple("Point", "x y")
d = defaultdict(int)
d[Point(x=10, y=20)] += 1
d[Point(x=10, y=20)] += 1
d[Point(x=10, y=10)] += 1
print(dict(d))  # => {Point(x=10, y=20): 2, Point(x=10, y=10): 1}


# but normal class
class Point2:
    def __init__(self, x, y):
        self.x = x
        self.y = y

d = defaultdict(int)
d[Point2(x=10, y=20)] += 1
d[Point2(x=10, y=20)] += 1
d[Point2(x=10, y=10)] += 1
print(dict(d))  # => {<__main__.Point2 object at 0x109f11828>: 1, <__main__.Point2 object at 0x109ed7550>: 1, <__main__.Point2 object at 0x109f117f0>: 1}

もっとも、namedtupleの罠として同じ形状の定義は同じになってしまうという問題があるので同じ挙動というわけではない。

Point = namedtuple("Point", "x y")
Point2 = namedtuple("Point2", "a b")
# golangではfalse
Point(10, 20) == Point2(10, 20)  # => True

slicesとmapで = を利用しようとするとコンパイルエラー

sliceとmapは = で値を比較しようとするとコンパイルエラーになる。

// slice
xs := [3]int{1, 2, 3}
i, j := xs[:], xs[:]
fmt.Printf("%v == %v, %v\n", i, j, i == j) // compile error

// map
i, j := map[string]int{"x": 1}, map[string]int{"x": 1}
fmt.Printf("%v == %v, %v\n", i, j, i == j) // compile error

後で調べてみると、mapとsliceは比較不能ただしnilとだけ比較可能という感じだった。常にfalseを返す位ならコンパイルエラーにしてしまうというのも、それはそれとして割り切りとしてありなような気がする。

以下はOK。

fmt.Printf("%v¥n", m == nil) // mに値が入っていたらfalse

map,sliceの保持する値を全部比較したければ、reflect.DeepEqual() が使える。

i, j := map[string]int{"x": 1}, map[string]int{"x": 1}
fmt.Printf("deep %v == %v, %v\n", i, j, reflect.DeepEqual(i, j))

nil同士の比較

ところで nil = nil と書いた時に値を見るだけなのか型を意識してチェックするかも気になったので以下の様なコードを書いた。これは期待通り。

type MyInterface interface{}

i, j, k := interface{}(nil), interface{}(nil), MyInterface(nil)
fmt.Printf("%v == %v, %v\n", i, j, i == j) // => true
fmt.Printf("%v == %v, %v\n", i, k, i == k) // => true

また隠した型も調べてくれるみたい。わりと便利

type Point struct {
    x, y int
}
type Point2 struct {
    x, y int
}

type MyInterface interface {
}

i, j, k := Point{x: 10, y: 20}, Point{x: 10, y: 20}, Point2{x: 10, y: 20}

// 同じ形状の値なのでtrue
fmt.Printf("%v == %v, %v\n", i, j, i == j) // i = j true

// 型が違うので比較不能
// fmt.Printf("%v == %v, %v\n", i, k, i == k) // compile error

// 型を変換すれば同じ形状の値なのでtrue
fmt.Printf("%v == %v, %v\n", i, k, i == Point(k)) // i == k true

// 同じ型にあわせて比較しても元の型が違うのでfalse
fmt.Printf("%v == %v, %v\n", i, k, MyInterface(i) == MyInterface(k)) // i == k false
fmt.Printf("%v == %v, %v\n", i, k, interface {}(i) == interface {}(k)) // i == k false

よく考えて見れば

よく考えて見れば、言語仕様を直接読めばよかった気がしないでもない。

まず、x == y についてxがyに対して代入可能(assignable)かどうか調べて、代入可能ならなんか自然な感じで型毎に比較方法が列挙されている。slices,map以外に関数も比較不能。

django-returnfieldsというパッケージを作っていました

django python

django-returnfields というパッケージを作っていました。

これは何?

はじめはapi responseのfilteringをするライブラリとして作っていましたが、いろいろな変更の結果あるAPIのresponseに対してそのsubsetを返すためのoptimizerのようなものになりました。

具体的には以下の機能を持っています。

  • skip_fields, return_fields optionによるresponseのfiltering
  • aggressive optionによるDB queryのoptimize

すごく高速に動作するというよりは、遅くなっている状態を避けようという感じのものなので最適な結果を保証するものではなかったりします。またrichなresponseを返すREST API以外ではあまり意味が無いかもしれません。

responseのfiltering

以下のようなresponseを返すAPIがあるとします。

// /api/users
{
  "id": 1,
  "username": "foo",
  "skills": [
    {
      "id": 1,
      "user": 1,
      "name": "magic"
    },
    {
      "id": 2,
      "user": 1,
      "name": "magik"
    }
  ]
}

return_fields

これに return_fields optionを使って nameだけを取り出す事ができます。

// /api/users/?return_fields="username, skills__name"
{
  "username": "foo",
  "skills": [
    {
      "name": "magic"
    },
    {
      "name": "magik"
    }
  ]
}

skip_fields

かわりに skip_fields optionを使っても同様のことができます。ただしこちらはresponseに含めたくないフィールドを指定します。

// /api/users/?skip_fields=skills__id,skills__user
{
  "username": "foo",
  "skills": [
    {
      "name": "magic"
    },
    {
      "name": "magik"
    }
  ]
}

why return_fields and skip_fields?

通常のdjangoの作法によるとこのような何らかのコレクションに対する絞り込みのoption名には include, exclude のペアを使うことが多いです。ですが、このパッケージでは return_fieldsskip_fields という別の名前を使っています。

理由は、include, exclude の場合には相互排他的な意味を持っているためです。例えば、include=a,b,cのときexclude=aとした時には条件の指定がconflictしてエラーになります。

一方、 return_fields=a,b,c かつ skip_fields=a の意味は全体集合として {a, b, c} を取り、そこから {c}との差集合を取るという意味になるので {b,c} として検索される事になります。また、return_fields の指定がなかった場合には暗黙に可能なかぎり全部のfieldsを出力対象にするということになります。

DB queryのoptimize

aggressive=1 というoptionを付けるとoptimizerとしても機能します。具体的には以下の事をします。

  • modelの定義に基づく prefetch, joinを付加する
  • 選択されたfieldだけをonly,deferで取り出す

例えば上のAPI/users/?format=json&return_fields=username,skills&skip_fields=skills__id,skills__user によるアクセスは以下のようなqueryが実行されますが。

(0.000) SELECT "user"."id", "user"."password", "user"."last_login", "user"."is_superuser", "user"."username", "user"."first_name", "user"."last_name", "user"."email", "user"."is_staff", "user"."is_active", "user"."date_joined" FROM "user"; args=()
(0.000) SELECT "skill"."id", "skill"."name", "skill"."user_id" FROM "skill" WHERE "skill"."user_id" = 1; args=(1,)
(0.000) SELECT "skill"."id", "skill"."name", "skill"."user_id" FROM "skill" WHERE "skill"."user_id" = 2; args=(2,)

以下のようなアクセスの場合には /users/?format=json&aggressive=1&return_fields=username,skills&skip_fields=skills__id,skills__user eager loadingが効きます。そして不要なフィールドがselect句に含まれません。

(0.000) SELECT "user"."id", "user"."username" FROM "user"; args=()
(0.000) SELECT "skill"."id", "skill"."name", "skill"."user_id" FROM "skill" WHERE "skill"."user_id" IN (2, 1); args=(2, 1)

またこの例ではネストが1段だけですがN段のネストにも対応しています。そして、skip_fieldsreturn_fields によりeager loadingの必要性がなくなった場合にはeager loadingが行われません。

あとで詳しく

optimization関係や実際のdjango restframeworkとの組み合わせ(特にpaginationとの組み合わせ)などではもう少し説明する必要があるところがありますがとりあえず今回はここまで。

generic foreignkeyのsub relationをprefetchする方法

django python

はじめに

generic foreignkey 自体のprefetchはできる。しかし、その更に先のrelationをprefetchすることができない。これをどうにかしようと言う苦肉の策を考えてみたという話。

言い訳

djangoのgeneric foreignkey関連のコードを読んでみたところ綺麗にできる方法は無さそうだった。実行時のprefetchの条件を上手く受け渡す方法が存在しなさそうだったので。仕方がないので thread localなcontext objectを作りそこでprefetchの条件を指定できるようにする。

概要

以下の様な形のモデルになっているとする。

Feed -- generic foreign key --> cotent = {A,B,C}

A -- 1:N --> xs = {X}
B -- 1:N --> ys = {Y}
C
X
Y

Feedというモデルが有りこれがgenericなrelationを持っており、A,B,Cのいずれかを保持する。また、AはXモデルとBはYモデルと1:Nの関係になっている。今回はFeedのqueryを取ってくる際にA,B,Cだけでなく、Aに結びつくX,Bに結びつくYも一緒に取ってくるようにしたい。

例えば以下のようなqueryを実行するとする。xs,ysの取得に関してはN+1になってしまう。

def use(content):
    if hasattr(content, "xs"):
        return [content, list(content.xs.all())]
    elif hasattr(content, "ys"):
        return [content, list(content.ys.all())]
    else:
        return [content]

content_list = []
for feed in Feed.objects.all().prefetch_related("content"):
    # content :: {A, B, C} はprefetchされるがそのsub relationであるxs,ysがprefetchされない
    content_list.append(use(feed.content))

この時以下のようにprefetchを指定するとエラーになる。

Feed.objects.all().prefetch_related("content", "content__xs", "content__ys")

generic foreignkeyのsub relationをprefetchする方法

試行錯誤を行なったdjangoのversionは1.9.5だった。

素朴な方法

基本的な方針としては以下のようになる。

GenericForeignKeyのget_content_typeが各relation(ここではA,B,C)を取ってくる際の始端となるオブジェクトになっている。これを書き換え、各relationを取ってくるquerysetを生成するところにprefetch_relatedを追加する処理を加えてあげる。

なので、django.contrib.contenttypes.fields.GenericForeignKey と django.contrib.contenttypes.models.ContentType を自分で定義したサブクラスに置き換える。

class MyContentType(ContentType):
    class Meta:
        proxy = True

    def get_all_objects_for_this_type(self, **kwargs):
        qs = super().get_all_objects_for_this_type(**kwargs)
        return self.attach_prefetch(qs)

    def attach_prefetch(self, qs):
        # ここでprefetch_relatedの設定をする
        model = qs.model
        if issubclass(model, A):
            return qs.prefetch_related("xs")
        elif issubclass(model, B):
            return qs.prefetch_related("ys")
        else:
            return qs

class MyGenericForeignKey(GenericForeignKey):
    def get_content_type(self, *args, **kwargs):
        ct = super().get_content_type(*args, **kwargs)
        ct.__class__ = MyContentType  # これはmethodの挙動を書き換えるための雑な方法。真面目な方法ではない。
        return ct

各relationの取得に、MyContentTypeのget_all_objects_for_this_type()が使われる。これをフックするために、GenericForeignKeyからはMyContentTypeが使われるように get_content_type() を書き換える。内部の実装的にcontent typeのインスタンスの取得はキャッシュされているのでまじめに使う場合には注意が必要。

このようにすると以下のようなquery中でもX,Yを含めてprefetchしてくれるようになる。

content_list = []
for feed in Feed.objects.all().prefetch_related("content"):
    # sub relationであるxs, ysもprefetchされる
    content_list.append(use(feed.content))

ただし、上のようにした場合には、xs,ysのprefetchがデフォルトの動作になってしまう点が問題になる。できれば実行時にprefetchの条件を指定したい。

実行時にprefetchの条件をしていするための苦肉の策

冒頭の方にも書いたがdjangoの現状のコードセットではつらい。なので苦肉の策としてthread local objectにprefetchの条件を格納させることにする。まじめに実装するなら、異なる条件が二重に重なる場合なども考えなければいけなそうではあるけれど。そこまではやっていない。(本来は上手くqueryのhintを指定するカタチで情報を付加できたら良いのだけれど)

import threading
import contextlib

class GFKPrefetchContext:
    def __init__(self):
        self._context = threading.local()
        self._context.attach_prefetch = lambda qs: qs

    @contextlib.contextmanager
    def activate_prefetch(self, fn):
        oldvalue = self._context.attach_prefetch
        self._context.attach_prefetch = fn
        yield
        self._context.attach_prefetch = oldvalue

    def attach_prefetch(self, qs):
        return self._context.attach_prefetch(qs)

gfk_prefetch_context = GFKPrefetchContext()


class ContentTypeWithPrefetch(ContentType):
    class Meta:
        proxy = True

    def get_all_objects_for_this_type(self, **kwargs):
        qs = super().get_all_objects_for_this_type(**kwargs)
        return gfk_prefetch_context.attach_prefetch(qs)


class MyGenericForeignKey(GenericForeignKey):
    def get_content_type(self, *args, **kwargs):
        ct = super().get_content_type(*args, **kwargs)
        ct.__class__ = ContentTypeWithPrefetch
        return ct

以下の様にして使う。

def attach_prefetch(qs):
    model = qs.model
    if issubclass(model, A):
        return qs.prefetch_related("xs")
    elif issubclass(model, B):
        return qs.prefetch_related("ys")
    else:
        return qs

content_list = []
with gfk_prefetch_context.activate_prefetch(attach_prefetch):
    for feed in Feed.objects.all().prefetch_related("content"):
        content_list.append(use(feed.content))

# もちろんsub relationのprefetchを効かせたくなければcontextを指定しなければ良い
for feed in Feed.objects.all().prefetch_related("content"):
    content_list.append(use(feed.content))

補足

想定していたモデルの定義は以下のようなものだった。

class A(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)

    class Meta:
        db_table = "a"

class B(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)

    class Meta:
        db_table = "b"

class C(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)

    class Meta:
        db_table = "c"

class X(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    a = models.ForeignKey(A, related_name="xs")

    class Meta:
        db_table = "x"

class Y(models.Model):
    name = models.CharField(max_length=32, default="", blank=False)
    b = models.ForeignKey(B, related_name="ys")

    class Meta:
        db_table = "y"

class Feed(models.Model):
    class Meta:
        db_table = "feed"
        unique_together = ("content_type", "object_id")

    object_id = models.PositiveIntegerField()
    content_type = models.ForeignKey(ContentType)
    content = MyGenericForeignKey('content_type', 'object_id')

もう少しだけdjangoのprefetch_relatedについて考えてみる(条件付加したrelationのeager loading)

django python

はじめに

あるモデルに対してあるコンテキスト(文脈)に従った条件を加味した関係の元に値を取り出したい場合がある。そのような条件を付加した値を仮想的なフィールドとして扱うことができないかという話。

例えば以下の事がしたい

X,Yというテーブルが存在。これらはMany to Manyの関係になっている。

  • join時の条件を付加しておきたい(e.g. is_valid=Trueの条件の元queryしたい)
  • defaultのorder byを指定しておきたい(e.g. 生成日時で降順に撮りたい)

instanceだけを対象で考える場合

@property で雑にプロパティにしてしまって良いという話かもしれない。そのような条件を満たしたproperty valid_xs を定義してみる。

# x: y = M : N

class Y(models.Model):
    name = models.CharField(max_length=32, null=False, default="")
    ctime = models.DateTimeField(auto_now_add=True, null=False)
    is_valid = models.BooleanField(default=True, null=False)

    xs = ManyToMany(X, related_name="xs")

    @property
    def valid_xs(self):
        return self.xs.all().filter(is_valid=True).order_by("-ctime")

もちろんこのままではN + 1クエリが発生する可能性は残る。

N + 1の問題について

以前からN+1の問題については幾つか記事を書いてきた。基本的には、select_relatedでのtableのjoinもしくはprefetch_relatedでeager loadingをすれば良い。例えば今回の要件では元のquerysetに対して付加的な条件を加えたいだけだったりする。

付加的な条件の追加をその場で行うのはprefetch_relatedにPrefetch objectを渡すことで可能ではある。

prefetch = Prefetch(queryset=Y.objects.filter(is_valid=True).order_by("-ctime"), to_attr="valid_ys")
qs = X.objects.all().prefetch_related("ys", prefetch)

とは言えこのままだと以下の点が面倒に感じる。

  • 先程のpropertyの定義と名前が衝突する
  • 同じようなqueryを二度書く必要が出てくる
  • 実行時の条件やprefetchする際の名前を間違ってしまう場合が存在する可能性がある

propertyの定義と名前が衝突する

prefetch_relatedが付加されない場合でも上手く動いて欲しいため、上で定義したpropertyと同じ名前でprefetch_relatedを使おうとすると以下の様なエラーが発生する。

prefetch = Prefetch("xs", X.objects.all().filter(is_valid=True).order_by("-ctime"), "valid_xs")
for y in Y.objects.all().prefetch_related(prefetch):
    print(y.id, y.name, [x.name for x in y.valid_xs])
# AttributeError: can't set attribute

これはpropertyにsetterを付けることで解決できる。cacheとして使われる値が束縛される時に使われる名前がto_attrで指定した文字列なので、以前に定義したproperty名と重複してしまっているということなので。したがってモデルのpropertyの定義を以下の様に変えれば良い。

class Y(models.Model):
    # snip..

    # # 以下の様に書いてしまうと、bool(qs) or get_queryset() という呼び出しになってしまい、bool(qs)でqsが評価されてしまうので注意
    # @property
    # def valid_xs(self):
    #     return getattr(self, "_valid_xs", None) or self.xs.all().filter(is_valid=True).order_by("-ctime")

    @property
    def valid_xs(self):
        result = getattr(self, "_valid_xs", None)
        if result is None:
            result = self._valid_xs = self.xs.all().filter(is_valid=True).order_by("-ctime")
        return result

    @valid_xs.setter
    def valid_xs(self, value):
        self._valid_xs = value

同じようなqueryを2度書いてしまっている

同じqueryを2度書くというのも嫌かもしれない。今回のクエリーの条件自体は、Xオブジェクト側だけの情報で付加できるものではあるのでXのモデル定義に含めておくと良いのかもしれない。

条件を付加したquerysetをモデルの定義に含めておく方法としては以下の4つくらいが考えられる。

  • filteringする関数の定義
  • modelのclass methodに追加
  • 独自のquerysetの定義
  • 独自のmanagerの定義

まずmanagerの定義は論外。これはquerysetを取得する際の開始時の処理しか定義する事ができないので。managerに定義する位ならquerysetに定義した方が良い。class methodにするかquerysetに定義するかは好みで良いと思う。関数として独立して定義するのとmodelのclass methodに追加するのは実質同じことではあり単に利用する際にmoduleのimportが不要かどうかという話でしかない。

この中で最も自然なのは独自のquerysetを定義することではあるけれど、個人的には単にmodelにclass methodを追加するだけで十分なのではないかと思っている。とりあえず同じようなqueryを2度書く必要があるというのはこれにより解決できる。

filteringする関数の定義

from y.models import Y

def get_valid_ys_set(qs=Y.objects.all()):
    return qs.filter(is_valid=True).order_by("-ctime")

実際に利用する時にはimportして使う必要がある。

class methodに追加する方法

class X(models.Model):
    # snip...
    @classmethod
    def valid_set(cls, qs=None):
        if qs is None:
            qs = cls.objects.all()
        # 以下の様に書いてしまうとbool(qs) で querysetが評価されてしまうので注意
        # qs = qs or cls.objects.all()
        return qs.filter(is_valid=True).order_by("-ctime")

使う時は以下の様になる

class Y(models.Model):
    # snip...
    @property
    def valid_xs(self):
        return getattr(self, "_valid_xs", None) or X.valid_set(self.xs.all())

# prefetchとして
prefetch = Prefetch("xs", X.valid_set(X.objects.all()), "valid_xs")

querysetとして追加

自分でQuerySetクラスを定義し、自分のmanagerを定義したQuerySetを返すように変更しておく。

class XQuerySet(models.QuerySet):
    def valid_set(self):
        return self.filter(is_valid=True).order_by("-ctime")

class X(models.Model):
    objects = XQuerySet.as_manager()
    # or objects = models.Manager().from_queryset(XQuerySet)()
    # snip...

querysetとして追加しておくと、関連するmodelのimportが不要になるというメリットはあるものの、返されるmodelが何であったのか分かりづらくなる気がするので個人的にはclassmethodでの追加の方が好みではある。

# classmethodでの追加の場合、Xモデルを経由しなければ絞り込みの条件を適用できない
from x.models import X
X.valid_set(y.xs.all())

# querysetにmethodが追加されていればimportは不要
y.xs.all().valid_set()

prefetchする際の名前を間違ってしまう場合が存在する可能性がある

これに関してはprefetch objectを生成する処理をメソッド・関数化すれば良いというだけの話かもしれない。 どこに追加するかというのは先程のqueryの条件追加の部分のものと同様。

class Y(models.Model):
    # snip ...
    @classmethod
    def prefetch_valid_xs(cls):
        return Prefetch("ys", queryset=X.valid_set(), to_attr="valid_xs")    

prefetch_relatedでの文字列指定は諦めた方が良い

今まではかたくなに以下の形式でprefetchの設定を追加しようとしていた。

# prefetch_related(<prefetch object>)
Y.objects.all().prefetch_related(Y.prefetch_valid_xs())

以下の様な形式で設定できるようにすることは可能だろうか?

# prefetch_related(<string>)
Y.objects.all().prefetch_related("valid_ys")

結論から言うと辛いので止めておいたほうが良い。1.9の現在のdjangoのこの辺りの処理のコードは決め打ちで内部構造を利用するコードが多いので文字列で指定可能にするための労力に対して得られるメリットが少ない。

(ちなみに、過去にprefetcherオブジェクト(is_cached(),get_prefetch_queryset()を持つオブジェクト)を自前で定義してprefetch_relatedに指定可能にするということを記事にしていたが、あれは最後の手段なので日常的に使うものではない。)

prefetch_relatedのうれしさ

prefetch_relatedの機能は、ある種の関連queryのeager loadingを行うということだけれど、この機能のうれしさについてもう少し詳しく話すとすると、prefetch_relatedに纏わるeager loadingが半ばフレームワークになっていて、prefetch可能なインターフェイスとして定義を揃えておけば、N段ネストした場合についても上手く動作するということがうれしい。

例えば、以下の2つの関連があるとすると

  • A から B の関連
  • B から C の関連

AからBへの関連のN+1,BからCへの関連のN+1の2つを抑制することができたら、この2つの組み合わせたAからCへのN+1も抑制することができるというのがうれしい点。

実際、上で定義した prefetch_valid_xs() について、さらにXにtagのようなモデルがくっついていた場合にも適切にprefetchしてくれる。

例えば、Y -> X -> XTag という関連がある時、

class XTag(models.Model):
    xs = models.ManyToManyField(X, related_name="tags")
    name = models.CharField(max_length=32, null=False, default="", unique=True)

以下のコードで発行されるqueryは3件ですむ。

qs = Y.objects.all().prefetch_related(Y.prefetch_valid_xs(), "valid_xs__tags")
for y in qs:
    for x in y.valid_xs:
        for tag in x.tags.all():
            print(y.name, x.name, tag.name)

追記

setter,getterを陽に定義するのが面倒であれば以下のような関数を作っても良いかもしれない。

def custom_relation_property(getter):
    name = getter.__name__
    cache_name = "_{}".format(name)

    def _getter(self):
        result = getattr(self, cache_name, None)
        if result is None:
            result = getter(self)
            setattr(self, cache_name, result)
        return result

    def _setter(self, value):
        setattr(self, cache_name, value)

    prop = property(_getter, _setter, doc=_getter.__doc__)
    return prop

こんな感じで済む。

class Y(models.Model):
    # snip ...

    @custom_relation_property
    def valid_xs(self):
        return X.valid_set(self.xs.all())