diff --git a/storage.go b/storage.go index 32ab077..96bec5e 100644 --- a/storage.go +++ b/storage.go @@ -3,6 +3,8 @@ package main import ( "context" "errors" + "sort" + "time" ) var ( @@ -17,12 +19,56 @@ func NewStorage(driver Driver) Storage { return Storage{driver: driver} } +func (s Storage) Threads(ctx context.Context) ([]string, error) { + return s.ThreadsSince(ctx, time.Unix(0, 0)) +} + +func (s Storage) ThreadsSince(ctx context.Context, t time.Time) ([]string, error) { + messages, err := s.messagesWhere(ctx, func(m Message) bool { + return !t.After(m.Time()) + }) + if err != nil { + return nil, err + } + threads := map[string]struct{}{} + for _, m := range messages { + threads[m.Thread] = struct{}{} + } + result := make([]string, 0, len(threads)) + for k := range threads { + result = append(result, k) + } + return result, nil +} + +func (s Storage) Thread(ctx context.Context, thread string) ([]Message, error) { + return s.messagesWhere(ctx, func(m Message) bool { + return m.Thread == thread + }) +} + +func (s Storage) messagesWhere(ctx context.Context, where func(Message) bool) ([]Message, error) { + result := make([]Message, 0) + err := s.driver.ForEach(ctx, "m", func(_ string, v []byte) error { + m := MustDeserialize(v) + if !where(m) { + return nil + } + result = append(result, m) + return nil + }) + sort.Slice(result, func(i, j int) bool { + return result[i].TS < result[j].TS + }) + return result, err +} + func (s Storage) Upsert(ctx context.Context, m Message) error { - return s.driver.Set(ctx, "storage", m.ID, m.Serialize()) + return s.driver.Set(ctx, "m", m.ID, m.Serialize()) } func (s Storage) Get(ctx context.Context, id string) (Message, error) { - b, err := s.driver.Get(ctx, "storage", id) + b, err := s.driver.Get(ctx, "m", id) if err != nil { return Message{}, err } diff --git a/storage_test.go b/storage_test.go index a77523a..304a872 100644 --- a/storage_test.go +++ b/storage_test.go @@ -2,35 +2,69 @@ package main import ( "context" + "slices" "testing" "time" ) +//func newStorageFromTestdata(t *testing.T) { + func TestStorage(t *testing.T) { ctx, can := context.WithTimeout(context.Background(), time.Second) defer can() - db := NewTestDBIn(t.TempDir()) - defer db.Close() + t.Run("Threads", func(t *testing.T) { + s := NewStorage(NewRAM()) - s := NewStorage(db) + mX1 := Message{ID: "1", Thread: "X", TS: 1} + mX2 := Message{ID: "2", Thread: "X", TS: 2} + mY1 := Message{ID: "1", Thread: "Y", TS: 3} - if _, err := s.Get(ctx, "id"); err != ErrNotFound { - t.Error("failed to get 404", err) - } + for _, m := range []Message{mX1, mX2, mY1} { + if err := s.Upsert(ctx, m); err != nil { + t.Fatal(err) + } + } - m := Message{ - ID: "id", - TS: 1, - } + if threads, err := s.Threads(ctx); err != nil { + t.Error(err) + } else if len(threads) != 2 { + t.Error(threads) + } else if !slices.Contains(threads, "X") { + t.Error(threads, "X") + } else if !slices.Contains(threads, "Y") { + t.Error(threads, "Y") + } - if err := s.Upsert(ctx, m); err != nil { - t.Error("failed to upsert", err) - } + if threads, err := s.ThreadsSince(ctx, time.Unix(3, 0)); err != nil { + t.Error(err) + } else if len(threads) != 1 { + t.Error(threads) + } else if threads[0] != "Y" { + t.Error(threads[0]) + } + }) - if m2, err := s.Get(ctx, "id"); err != nil { - t.Error("failed to get", err) - } else if m != m2 { - t.Error(m2) - } + t.Run("Get Upsert", func(t *testing.T) { + s := NewStorage(NewRAM()) + + if _, err := s.Get(ctx, "id"); err != ErrNotFound { + t.Error("failed to get 404", err) + } + + m := Message{ + ID: "id", + TS: 1, + } + + if err := s.Upsert(ctx, m); err != nil { + t.Error("failed to upsert", err) + } + + if m2, err := s.Get(ctx, "id"); err != nil { + t.Error("failed to get", err) + } else if m != m2 { + t.Error(m2) + } + }) }