add storage.Threads, storage.ThreadsSince, storage.Thread
parent
5817dce70e
commit
88f6746c85
50
storage.go
50
storage.go
|
|
@ -3,6 +3,8 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
@ -17,12 +19,56 @@ func NewStorage(driver Driver) Storage {
|
||||||
return Storage{driver: driver}
|
return Storage{driver: driver}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s Storage) Threads(ctx context.Context) ([]string, error) {
|
||||||
|
return s.ThreadsSince(ctx, time.Unix(0, 0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Storage) ThreadsSince(ctx context.Context, t time.Time) ([]string, error) {
|
||||||
|
messages, err := s.messagesWhere(ctx, func(m Message) bool {
|
||||||
|
return !t.After(m.Time())
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
threads := map[string]struct{}{}
|
||||||
|
for _, m := range messages {
|
||||||
|
threads[m.Thread] = struct{}{}
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(threads))
|
||||||
|
for k := range threads {
|
||||||
|
result = append(result, k)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Storage) Thread(ctx context.Context, thread string) ([]Message, error) {
|
||||||
|
return s.messagesWhere(ctx, func(m Message) bool {
|
||||||
|
return m.Thread == thread
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Storage) messagesWhere(ctx context.Context, where func(Message) bool) ([]Message, error) {
|
||||||
|
result := make([]Message, 0)
|
||||||
|
err := s.driver.ForEach(ctx, "m", func(_ string, v []byte) error {
|
||||||
|
m := MustDeserialize(v)
|
||||||
|
if !where(m) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result = append(result, m)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
sort.Slice(result, func(i, j int) bool {
|
||||||
|
return result[i].TS < result[j].TS
|
||||||
|
})
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s Storage) Upsert(ctx context.Context, m Message) error {
|
func (s Storage) Upsert(ctx context.Context, m Message) error {
|
||||||
return s.driver.Set(ctx, "storage", m.ID, m.Serialize())
|
return s.driver.Set(ctx, "m", m.ID, m.Serialize())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s Storage) Get(ctx context.Context, id string) (Message, error) {
|
func (s Storage) Get(ctx context.Context, id string) (Message, error) {
|
||||||
b, err := s.driver.Get(ctx, "storage", id)
|
b, err := s.driver.Get(ctx, "m", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Message{}, err
|
return Message{}, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,18 +2,51 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
//func newStorageFromTestdata(t *testing.T) {
|
||||||
|
|
||||||
func TestStorage(t *testing.T) {
|
func TestStorage(t *testing.T) {
|
||||||
ctx, can := context.WithTimeout(context.Background(), time.Second)
|
ctx, can := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer can()
|
defer can()
|
||||||
|
|
||||||
db := NewTestDBIn(t.TempDir())
|
t.Run("Threads", func(t *testing.T) {
|
||||||
defer db.Close()
|
s := NewStorage(NewRAM())
|
||||||
|
|
||||||
s := NewStorage(db)
|
mX1 := Message{ID: "1", Thread: "X", TS: 1}
|
||||||
|
mX2 := Message{ID: "2", Thread: "X", TS: 2}
|
||||||
|
mY1 := Message{ID: "1", Thread: "Y", TS: 3}
|
||||||
|
|
||||||
|
for _, m := range []Message{mX1, mX2, mY1} {
|
||||||
|
if err := s.Upsert(ctx, m); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if threads, err := s.Threads(ctx); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if len(threads) != 2 {
|
||||||
|
t.Error(threads)
|
||||||
|
} else if !slices.Contains(threads, "X") {
|
||||||
|
t.Error(threads, "X")
|
||||||
|
} else if !slices.Contains(threads, "Y") {
|
||||||
|
t.Error(threads, "Y")
|
||||||
|
}
|
||||||
|
|
||||||
|
if threads, err := s.ThreadsSince(ctx, time.Unix(3, 0)); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else if len(threads) != 1 {
|
||||||
|
t.Error(threads)
|
||||||
|
} else if threads[0] != "Y" {
|
||||||
|
t.Error(threads[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Get Upsert", func(t *testing.T) {
|
||||||
|
s := NewStorage(NewRAM())
|
||||||
|
|
||||||
if _, err := s.Get(ctx, "id"); err != ErrNotFound {
|
if _, err := s.Get(ctx, "id"); err != ErrNotFound {
|
||||||
t.Error("failed to get 404", err)
|
t.Error("failed to get 404", err)
|
||||||
|
|
@ -33,4 +66,5 @@ func TestStorage(t *testing.T) {
|
||||||
} else if m != m2 {
|
} else if m != m2 {
|
||||||
t.Error(m2)
|
t.Error(m2)
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue