N個のchannelがcloseされるまで読み込む方法について

N個のchannelがcloseされるまで読み込む方法についてのメモ。

1個のchannelがcloseされるまで読み込む

まずはじめにN=1の場合について。これは単にforループで取り出せば良い。

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    ch := make(chan int)
    wg.Add(1)
    go func() {
        defer close(ch)
        defer wg.Done()
        for _, x := range []int{1, 2, 3, 4, 5} {
            ch <- x
            time.Sleep(10 * time.Millisecond)
        }
    }()

    wg.Add(1)
    go func() {
        defer wg.Done()
        var r []int
        for x := range ch {
            r = append(r, x)
        }
        fmt.Println(r)
    }()
    wg.Wait()
    // Output:
    // [1 2 3 4 5]
}

重要なのはここの部分。

for x := range ch {
    r = append(r, x)
}

channelをforループでループするとcloseされるまでの値を取り出すことができる。

forループをselectに変換する

forループで読み込めるのは1つのchannelだけ。以下の様なコードはch0を全部読み込んだ後にch1を読み込むというような挙動になる。このままでは複数のchannelを並行して読み込むことができない。

for x := range ch0 {
    r = append(r, x)
}
for x := range ch1 {
    r = append(r, x)
}

これを防ぐためにはselectを使う。mapの添字アクセスと同様に代入時に2つの値を取るようにした場合に、2番目の値に成功したかどうかの真偽値が入る。この真偽値は対象のchannelがcloseされた時にfalseになる。

なので先程のforループは以下の様なselectに書き換えることができる(returnではなくlabelをつけてbreakでも良いかもしれない)。

func() {
    for {
        select {
        case x, ok := <-ch:
            if !ok {
                return
            }
            r = append(r, x)
        }
    }
}()

2個のchannelがcloseされるまで読み込む

今度はN=2の場合、ようやく意味が出はじめた。ch0とch1の2つのchannelを読み込もうとしている。注意点としてch0は10ミリ秒のsleep,ch1は20ミリ秒のsleepを間に入れている。

先程のselectへの変換の例から似たような形に変換してみる。以下は期待通りには動かないコード。

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup

    ch0 := make(chan int)
    ch1 := make(chan int)
    wg.Add(2)
    go func() {
        defer close(ch0)
        defer wg.Done()
        for _, x := range []int{1, 2, 3, 4, 5} {
            ch0 <- x
            time.Sleep(10 * time.Millisecond)
        }
    }()
    go func() {
        defer close(ch1)
        defer wg.Done()
        for _, x := range []int{-1, -2, -3, -4, -5} {
            ch1 <- x
            time.Sleep(20 * time.Millisecond)
        }
    }()
    wg.Add(1)
    go func() {
        defer wg.Done()
        var r []int
        for n := 2; n > 0; {
            select {
            case x, ok := <-ch0:
                if !ok {
                    n--
                    continue
                }
                r = append(r, x)
            case x, ok := <-ch1:
                if !ok {
                    n--
                    continue
                }
                r = append(r, x)
            }
        }
        fmt.Println(r)
    }()
    wg.Wait()
}

これは上手くいかない。

[1 -1 2 3 -2 4 5 -3]
fatal error: all goroutines are asleep - deadlock!

ch0の方が先に読み込み終わり、一方でch1はまだ読み込み中、にもかかわらずnがch側のcaseを何度も走るためforループの終了条件を抜けてしまう。

例えば以下の様に変更すると動くようにはなる。片側が読み終わったらもう片側を直接forループで読み込む様に変更した形。しかし不格好に見える。

go func() {
    defer wg.Done()
    var r []int
    func() {
        for {
            select {
            case x, ok := <-ch0:
                if !ok {
                    for y := range ch1 {
                        r = append(r, y)
                    }
                    return
                }
                r = append(r, x)
            case x, ok := <-ch1:
                if !ok {
                    for y := range ch0 {
                        r = append(r, y)
                    }
                    return
                }
                r = append(r, x)
            }
        }
    }()
    fmt.Println(r)
}()

動きはする。

[1 -1 2 3 -2 4 5 -3 -4 -5]

nil channel

select対象にnilを代入した場合に対応するcase節(?)では単に無視される。この挙動を使って終了済み(closed)のchannelの処理をスキップさせるコードが書ける。これはnil channelパターンと呼ばれていたりもするらしい。

go func() {
    defer wg.Done()
    var r []int
    for n := 2; n > 0; {
        select {
        case x, ok := <-ch0:
            if !ok {
                ch0 = nil
                n--
                continue
            }
            r = append(r, x)
        case x, ok := <-ch1:
            if !ok {
                ch1 = nil
                n--
                continue
            }
            r = append(r, x)
        }
    }
    fmt.Println(r)
}()

nil channelを使えば片側の終了後にもう片側をforループで読み込むなど凝ったことを考えずに自然に書ける。

[1 -1 2 -2 3 4 -3 5 -4 -5]

N個のchannelがcloseされるまで読み込む(reflect)

nil channelを使ったコードはすごく画一的な形になる。これを上手く使えば読み込むchannelの数がN個の場合にも対応できそうな気がしてきた。

一方で任意個のchannelに対応するためには動的に読み込むchannelの数を増やせる必要がある。ただしselectは構文なので直接手で明示的に書かなければいけない。この悩ましさを解決するためにはもう少し動的な記述が必要になる。つまりreflectパッケージのちからが必要になる。

実際以下の様な形でselectは動的に使うことができる。

package main

import (
    "fmt"
    "reflect"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup

    ch0 := make(chan int)
    ch1 := make(chan int)
    ch2 := make(chan int)
    wg.Add(3)
    go func() {
        defer close(ch0)
        defer wg.Done()
        for _, x := range []int{1, 2, 3, 4, 5} {
            ch0 <- x
            time.Sleep(10 * time.Millisecond)
        }
    }()
    go func() {
        defer close(ch1)
        defer wg.Done()
        for _, x := range []int{-1, -2, -3, -4, -5} {
            ch1 <- x
            time.Sleep(20 * time.Millisecond)
        }
    }()
    go func() {
        defer close(ch2)
        defer wg.Done()
        for _, x := range []int{10, 20, 30} {
            ch1 <- x
            time.Sleep(30 * time.Millisecond)
        }
    }()
    wg.Add(1)
    go func() {
        defer wg.Done()
        var r []int

        cases := []reflect.SelectCase{
            reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch0)},
            reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch1)},
            reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch2)},
        }
        for n := 3; n > 0; {
            i, x, ok := reflect.Select(cases)
            if !ok {
                n--
                cases[i].Chan = reflect.ValueOf(nil) // nil channel
                continue
            }
            r = append(r, int(x.Int()))
        }
        fmt.Println(r)
    }()
    wg.Wait()
}

reflect.SelectCase()でCase節(?)に対応する値を作り、reflect.Select()でselectの条件分岐(+channelの待ち受け)を行う。ここでも同様にnil channelを使う事はできる。

ただし、全ての値はreflect.Valueなので変換が必要。

reflectのちからを借りずにどうにかできないものか。

N個のchannelがcloseされるまで読み込む(reflect無し)

reflect無しにN個のchannelを読み込むことはできないか?

固定回数の繰り返しを任意回の繰り返しに変換するときの定石として再帰がある。同じ理屈で再帰的な記述ができればN個のchannelがcloseされるまで読み切ることができるかもしれない。

線形リストの総和

簡単な再帰の練習に線形リストの総和を求めてみる。ここで線形リストをわざわざ作るのは面倒なのでsliceで代用する。

package main

import (
    "fmt"
)

func sum(xs []int) int {
    if len(xs) == 0 {
        return 0
    }
    return xs[0] + sum(xs[1:])
}

func main() {
    fmt.Println(sum([]int{1, 2, 3, 4, 5}))
    // Output:
    // 15
}

これはもう少し丁寧に書くと以下の様な形で計算される。

1 + sum([2,3,4,5])
1 + (2 + sum([3,4,5]))
1 + (2 + (3 + sum([4,5])))
1 + (2 + (3 + (4 + sum([5]))))
1 + (2 + (3 + (4 + (5 + sum([])))))
1 + (2 + (3 + (4 + (5 + 0))))
1 + (2 + (3 + (4 + 5)))
1 + (2 + (3 + 9))
1 + (2 + 12)
1 + 14
15

ここで最初の部分に注目する。これは最初の値と残りの値のsumとの和。

1 + sum([2,3,4,5])

もちろん再帰は終了するために基底条件を持っていなければ行けなくて、今回の場合は長さが0の場合に0というもの。

sum = 0 if len(xs) == 0
sum = xs[0] + sum(xs[1:])

短く雰囲気を書くとこんな感じ。

再帰的な定義をchannelにも。

先程の総和の計算をchannelにも適用できないだろうか?

同じ理屈で考えるなら、あるchannelを合成するmerge()という関数があり、それは再帰的な定義になっている。chanelの数が0個以上のときに最初のchannelと残りのchannelをmergeしたものを組み合わせる。

+ではないので(+)という表記にしてみた。

merge = chs[0] (+) merge(chs[1:])

基底条件も同様に考えるなら、おそらくchannelの数が0のときにはnilを返せば良さそう。nil channelパターンもあるし。

merge = nil if len(chs) == 0
merge = chs[0] (+) merge(chs[1:])

(追記: 0のときnilではダメ。nilはselectをすり抜ける。closeされたchannelを返す)

merge = new closed channel if len(chs) == 0
merge = chs[0] if len(chs) == 1
merge = chs[0] (+) merge(chs[1:])

ただし注意点として残りの計算の部分は別途goroutineを走らせる必要がある。これをコードにすると実際に動く。キモいのだけれど動く。

go func() {
    defer wg.Done()
    var r []int

    var merge func(chs []<-chan int) <-chan int
    merge = func(chs []<-chan int) <-chan int {
        switch len(chs) {
        case 0:
            ch := make(chan int)
            close(ch)
            return ch
        case 1:
            return chs[0]
        default:
            ch := make(chan int)
            go func() {
                defer close(ch)
                restCH := merge(chs[1:])
                for n := 2; n > 0; {
                    select {
                    case x, ok := <-chs[0]:
                        if !ok {
                            n--
                            chs[0] = nil
                            continue
                        }
                        ch <- x
                    case x, ok := <-restCH:
                        if !ok {
                            n--
                            restCH = nil
                            continue
                        }
                        ch <- x
                    }
                }
            }()
            return ch
        }
    }
    for x := range merge([]<-chan int{ch0, ch1, ch2}) {
        r = append(r, x)
    }
    fmt.Println(r)
}()

そういえば、無名関数での再帰やN個のchannelへの対応で再帰を使うのは並行処理のgoの本などでも紹介されていたのを思い出した(or-done channel)。

ところでこのとき作られるgoroutineの数はreflectのときのそれよりも多い。

最終的なコード

最終的なコード

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup

    ch0 := make(chan int)
    ch1 := make(chan int)
    ch2 := make(chan int)
    wg.Add(3)
    go func() {
        defer close(ch0)
        defer wg.Done()
        for _, x := range []int{1, 2, 3, 4, 5} {
            ch0 <- x
            time.Sleep(10 * time.Millisecond)
        }
    }()
    go func() {
        defer close(ch1)
        defer wg.Done()
        for _, x := range []int{-1, -2, -3, -4, -5} {
            ch1 <- x
            time.Sleep(20 * time.Millisecond)
        }
    }()
    go func() {
        defer close(ch2)
        defer wg.Done()
        for _, x := range []int{10, 20, 30} {
            ch1 <- x
            time.Sleep(30 * time.Millisecond)
        }
    }()
    wg.Add(1)
    go func() {
        defer wg.Done()
        var r []int

        var merge func(chs []<-chan int) <-chan int
        merge = func(chs []<-chan int) <-chan int {
            switch len(chs) {
            case 0:
                ch := make(chan int)
                close(ch)
                return ch
            case 1:
                return chs[0]
            default:
                ch := make(chan int)
                go func() {
                    defer close(ch)
                    restCH := merge(chs[1:])
                    for n := 2; n > 0; {
                        select {
                        case x, ok := <-chs[0]:
                            if !ok {
                                n--
                                chs[0] = nil
                                continue
                            }
                            ch <- x
                        case x, ok := <-restCH:
                            if !ok {
                                n--
                                restCH = nil
                                continue
                            }
                            ch <- x
                        }
                    }
                }()
                return ch
            }
        }
        for x := range merge([]<-chan int{ch0, ch1, ch2}) {
            r = append(r, x)
        }
        fmt.Println(r)
    }()
    wg.Wait()
    // Output:
    // [1 -1 10 2 3 -2 20 4 -3 5 30 -4 -5]
}