diff --git a/replicator/driver_test.go b/replicator/driver_test.go index 45fcede..b7efbd5 100644 --- a/replicator/driver_test.go +++ b/replicator/driver_test.go @@ -10,6 +10,7 @@ import ( func TestDriverInterface(t *testing.T) { var _ Driver = FileTree("") var _ Driver = Map{} + var _ Driver = Must{} } func testDriver(t *testing.T, d Driver) { diff --git a/replicator/must.go b/replicator/must.go new file mode 100644 index 0000000..a1f2771 --- /dev/null +++ b/replicator/must.go @@ -0,0 +1,40 @@ +package replicator + +import ( + "context" + "time" +) + +type Must struct { + driver Driver +} + +func NewMust(driver Driver) Must { + return Must{driver: driver} +} + +func (must Must) KeysSince(ctx context.Context, t time.Time) (chan KeyVersion, *error) { + return must.driver.KeysSince(ctx, t) +} + +func (must Must) Get(ctx context.Context, k Key) (ValueVersion, error) { + got, err := must.driver.Get(ctx, k) + if err != nil { + panic(err) + } + return got, nil +} + +func (must Must) Set(ctx context.Context, k Key, v Value, ver Version) error { + if err := must.driver.Set(ctx, k, v, ver); err != nil { + panic(err) + } + return nil +} + +func (must Must) Del(ctx context.Context, k Key, ver Version) error { + if err := must.driver.Del(ctx, k, ver); err != nil { + panic(err) + } + return nil +} diff --git a/replicator/replicator_test.go b/replicator/replicator_test.go new file mode 100644 index 0000000..300d502 --- /dev/null +++ b/replicator/replicator_test.go @@ -0,0 +1,60 @@ +package replicator + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestReplicatorStream(t *testing.T) { + key := Key{Namespace: "x/y", Key: "z"} + value := Value("hello world") + version := TimeAsVersion(time.Now()) + + cases := map[string]struct { + before func(Replicator) + during func(Replicator) + after func(Replicator) + }{ + "noop": {}, + "one prior op moves": { + before: func(r Replicator) { + r.Src.Set(nil, key, value, version) + }, + after: func(r Replicator) { + }, + }, + } + + for name, d := range cases { + c := d + t.Run(name, func(t *testing.T) { + replicator := NewReplicator(NewMap(), NewMap()) + ctx, can := context.WithTimeout(context.Background(), time.Second*10) + defer can() + + if c.before != nil { + c.before(replicator) + } + + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + if err := replicator.Stream(ctx); err != nil && ctx.Err() == nil { + t.Fatal(err) + } + }() + if c.during != nil { + c.during(replicator) + } + + can() + wg.Wait() + if c.after != nil { + c.after(replicator) + } + }) + } +}