golangのsort packageを見ながらcollection系の構造をどうやって管理するか把握しようとした

golangのsort packageを見ながらcollection系の構造をどうやって管理するか把握しようとした

はじめに

golanggenericsがない。genericsがない状況でcollection系のデータ構造をどうするのかという話。 結論から言うとinterfaceを定義して頑張る。

sort packageを見る

以下を見る

sort - The Go Programming Language

Sortableのようなpackageを用意してあげれば良い。

以下のようなinterfaceを定義しておけば良い。

type Sortable interface {
    Len() int
    Less(i, j int) bool
    Swap(i, j int)
}

実際の内部構造

内部でintrosortしている

// quickSort, heapSort, insertionSortが内部で定義されていて使い分けられている。
func Sort(data Interface) {
    // Switch to heapsort if depth of 2*ceil(lg(n+1)) is reached.
    n := data.Len()
    maxDepth := 0
    for i := n; i > 0; i >>= 1 {
        maxDepth++
    }
    maxDepth *= 2
    quickSort(data, 0, n, maxDepth)
}

インデックスアクセスなどはinterfaceで要求していないし Less() でチェックしつつ Swap() で順序を変えつつ再帰でSortしているっぽい。

使い方

  • 必要なinterfaceを満たしたstructを定義
  • sort条件毎にtype alias的なものを作成
$ go run /tmp/sort-example.go 
[foo(20) bar(10) boo(15)]
[bar(10) boo(15) foo(20)]

sort-example.go

package main

// sort - The Go Programming Language
// https://golang.org/pkg/sort/

import (
    "fmt"
    "sort" // 後で "github.com/podhmo/mysort" に変える
)

type Person struct {
    Name string
    Age  int
}

func (p Person) String() string {
    return fmt.Sprintf("%s(%d)", p.Name, p.Age)
}

type ByAge []Person

// for sort function
func (a ByAge) Len() int {
    return len(a)
}
func (a ByAge) Swap(i int, j int) {
    a[i], a[j] = a[j], a[i]
}
func (a ByAge) Less(i int, j int) bool {
    return a[i].Age < a[j].Age
}

func main() {
    people := []Person{
        Person{Name: "foo", Age: 20},
        Person{Name: "bar", Age: 10},
        Person{Name: "boo", Age: 15},
    }
    fmt.Println(people)
    mysort.Sort(ByAge(people))
    fmt.Println(people)
}

自分でsort package相当のことをやってみる。

同じように Sortable interfaceを要求するpackageを書いてあげれば。同様のことはできるはず。 アルゴリズム自体は気にしないのでテキトウなので良い。

gopath
mkdir github.com/podhmo/mysort
editor github.com/podhmo/mysort/mysort.go
go build github.com/podhmo/mysort
go test github.com/podhmo/mysort
ok      github.com/podhmo/mysort    0.005s

mysort.go

package mysort

type Sortable interface {
    Len() int
    Less(i, j int) bool
    Swap(i, j int)
}

func Sort(data Sortable) {
    n := data.Len();
    for i := 0; i < n; i++ {
        for j := i + 1; j < n; j++ {
            if ! data.Less(i, j) {
                data.Swap(i, j)
            }
        }
    }
}

test

mysort_test.go

package mysort

import (
    "fmt"
    "testing"
)

type Person struct {
    Name string
    Age  int
}

func (p Person) String() string {
    return fmt.Sprintf("%s(%d)", p.Name, p.Age)
}

type ByAge []Person

// for sort function
func (a ByAge) Len() int {
    return len(a)
}
func (a ByAge) Swap(i int, j int) {
    a[i], a[j] = a[j], a[i]
}
func (a ByAge) Less(i int, j int) bool {
    return a[i].Age < a[j].Age
}

type ByName []Person

// for sort function
func (a ByName) Len() int {
    return len(a)
}
func (a ByName) Swap(i int, j int) {
    a[i], a[j] = a[j], a[i]
}
func (a ByName) Less(i int, j int) bool {
    return a[i].Name < a[j].Name
}

func TestSortByAge(t *testing.T) {
    foo := Person{Name: "foo", Age: 20}
    bar := Person{Name: "bar", Age: 10}
    boo := Person{Name: "boo", Age: 15}
    cases := []struct {
        in, want []Person
    }{
        {[]Person{foo, bar, boo}, []Person{bar, boo, foo}},
    }
    for _, c := range cases {
        Sort(ByAge(c.in))
        for i := range c.in {
            if c.in[i].Name != c.want[i].Name || c.in[i].Age != c.want[i].Age {
                t.Errorf("Sort: (%q) != %q", c.in, c.want)
            }
        }
    }
}

func TestSortByName(t *testing.T) {
    foo := Person{Name: "foo", Age: 20}
    bar := Person{Name: "bar", Age: 10}
    boo := Person{Name: "boo", Age: 15}
    cases := []struct {
        in, want []Person
    }{
        {[]Person{foo, bar, boo}, []Person{bar, boo, foo}},
    }
    for _, c := range cases {
        Sort(ByName(c.in))
        for i := range c.in {
            if c.in[i].Name != c.want[i].Name || c.in[i].Name != c.want[i].Name {
                t.Errorf("Sort: (%q) != %q", c.in, c.want)
            }
        }
    }
}