diff --git a/storage.go b/storage.go index a17e81f..1343407 100644 --- a/storage.go +++ b/storage.go @@ -67,6 +67,10 @@ func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) return v, err } +func (s Storage) GetEventThreads(ctx context.Context, ID string) ([]model.Thread, error) { + return s.selectThreadsWhere(ctx, "EventID = $1", ID) +} + func (s Storage) GetThreadMessages(ctx context.Context, ID string) ([]model.Message, error) { return s.selectMessagesWhere(ctx, "ThreadID = $1", ID) } @@ -75,6 +79,52 @@ func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error { return s.upsert(ctx, "threads", thread) } +func (s Storage) selectThreadsWhere(ctx context.Context, clause string, args ...any) ([]model.Thread, error) { + keys, _, _, _, err := keysArgsKeyargsValues(model.Thread{}) + if err != nil { + return nil, err + } + args2 := make([]any, len(args)) + for i := range args { + args2[i], _ = json.Marshal(args[i]) + } + scanTargets := make([]any, len(keys)) + + q := fmt.Sprintf(` + SELECT %s FROM threads WHERE %s + `, strings.Join(keys, ", "), clause) + rows, err := s.driver.QueryContext(ctx, q, args2...) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []model.Thread + for rows.Next() { + for i := range scanTargets { + scanTargets[i] = &[]byte{} + } + + if err := rows.Scan(scanTargets...); err != nil { + return nil, err + } + + m := map[string]json.RawMessage{} + for i, k := range keys { + m[k] = *scanTargets[i].(*[]byte) + } + b, _ := json.Marshal(m) + + var one model.Thread + if err := json.Unmarshal(b, &one); err != nil { + return nil, err + } + result = append(result, one) + } + + return result, rows.Err() +} + func (s Storage) selectMessagesWhere(ctx context.Context, clause string, args ...any) ([]model.Message, error) { keys, _, _, _, err := keysArgsKeyargsValues(model.Message{}) if err != nil { diff --git a/storage_test.go b/storage_test.go index afefd30..180195d 100644 --- a/storage_test.go +++ b/storage_test.go @@ -88,7 +88,7 @@ func TestStorage(t *testing.T) { } }) - t.Run("thread messages", func(t *testing.T) { + t.Run("get thread messages", func(t *testing.T) { thread := fmt.Sprintf("thread-%d", rand.Int()) m := model.NewMessage( "ID", @@ -117,4 +117,34 @@ func TestStorage(t *testing.T) { t.Fatalf("wanted msgs like %+v but got %+v", m, msgs[0]) } }) + + t.Run("get event threads", func(t *testing.T) { + event := fmt.Sprintf("event-%d", rand.Int()) + m := model.NewThread( + "ID", + "URL", + 1, + "Channel", + event, + ) + + if err := s.UpsertThread(ctx, m); err != nil { + t.Fatal("unexpected error on insert:", err) + } else if m2, err := s.GetThread(ctx, m.ID); err != nil { + t.Fatal("unexpected error on upsert-get:", err) + } else if m2 != m { + t.Errorf("expected %+v but got %+v", m, m2) + } + + msgs, err := s.GetEventThreads(ctx, event) + if err != nil { + t.Fatal(err) + } else if len(msgs) != 1 { + t.Fatal(msgs) + } else if msgs[0].EventID != m.EventID { + t.Fatal(msgs[0].EventID) + } else if msgs[0] != m { + t.Fatalf("wanted msgs like %+v but got %+v", m, msgs[0]) + } + }) }