diff --git a/db_test.go b/db_test.go index da22de0..1d32f28 100644 --- a/db_test.go +++ b/db_test.go @@ -110,6 +110,12 @@ func TestImplementations(t *testing.T) { cases = append(cases, memcache) } + if memcacheCluster, err := NewMemcacheCluster("localhost:11211"); err != nil { + t.Errorf("cannot make memcacheCluster: %v", err) + } else { + cases = append(cases, memcacheCluster) + } + validKey := "key" validValue := []byte("value") @@ -122,7 +128,7 @@ func TestImplementations(t *testing.T) { } else if !bytes.Equal(v, validValue) { t.Errorf("wrong get %T: %q vs %q", db, v, validValue) } else { - t.Logf("%18T GET: %s", db, v) + t.Logf("%25T GET: %s", db, v) } if err := db.Close(); err != nil { t.Errorf("cannot close %T: %v", db, err) diff --git a/memcached.go b/memcached.go index b23d4f0..4b1752f 100644 --- a/memcached.go +++ b/memcached.go @@ -10,29 +10,6 @@ type Memcache struct { db *memcache.Client } -type serverSelector struct { - addrs []string -} - -func (ss *serverSelector) PickServer(key string) (net.Addr, error) { - return &netAddr{ - network: "tcp", - addr: ss.addrs[0], - }, nil -} - -func (ss *serverSelector) Each(each func(net.Addr) error) error { - for _, addr := range ss.addrs { - if err := each(&netAddr{ - network: "tcp", - addr: addr, - }); err != nil { - return err - } - } - return nil -} - type netAddr struct { network string addr string @@ -47,8 +24,9 @@ func (a *netAddr) String() string { } func NewMemcache(addr string, addrs ...string) (*Memcache, error) { - ss := &serverSelector{ - addrs: append(addrs, addr), + ss := &memcache.ServerList{} + if err := ss.SetServers(append(addrs, addr)...); err != nil { + return nil, err } if err := ss.Each(func(addr net.Addr) error { conn, err := net.Dial("tcp", addr.String()) diff --git a/memcachedcluster.go b/memcachedcluster.go new file mode 100644 index 0000000..09e6731 --- /dev/null +++ b/memcachedcluster.go @@ -0,0 +1,85 @@ +package storage + +import ( + "net" + + "github.com/bradfitz/gomemcache/memcache" + "github.com/buraksezer/consistent" + "github.com/cespare/xxhash" +) + +type MemcacheCluster struct { + db *memcache.Client +} + +type serverSelector struct { + hash *consistent.Consistent +} + +func (ss *serverSelector) PickServer(key string) (net.Addr, error) { + return &netAddr{ + network: "tcp", + addr: ss.hash.LocateKey([]byte(key)).String(), + }, nil +} + +func (ss *serverSelector) Each(each func(net.Addr) error) error { + for _, member := range ss.hash.GetMembers() { + if err := each(&netAddr{ + network: "tcp", + addr: member.String(), + }); err != nil { + return err + } + } + return nil +} + +type hasher struct{} + +func (h hasher) Sum64(data []byte) uint64 { + return xxhash.Sum64(data) +} + +func NewMemcacheCluster(addr string, addrs ...string) (*MemcacheCluster, error) { + cfg := consistent.Config{ + PartitionCount: 71, + ReplicationFactor: 20, + Load: 1.25, + Hasher: hasher{}, + } + hash := consistent.New(nil, cfg) + for _, addr := range append(addrs, addr) { + hash.Add(&netAddr{addr: addr}) + } + ss := &serverSelector{ + hash: hash, + } + if err := ss.Each(func(addr net.Addr) error { + conn, err := net.Dial("tcp", addr.String()) + if err != nil { + return err + } + return conn.Close() + }); err != nil { + return nil, err + } + db := memcache.NewFromSelector(ss) + return &MemcacheCluster{db: db}, nil +} + +func (mc *MemcacheCluster) Get(key string) ([]byte, error) { + v, err := mc.db.Get(key) + return v.Value, err +} + +func (mc *MemcacheCluster) Set(key string, value []byte) error { + return mc.db.Set(&memcache.Item{ + Key: key, + Value: value, + }) +} + +func (mc *MemcacheCluster) Close() error { + return mc.db.FlushAll() +}