diff --git a/main.go b/main.go index d8855d8..f5c5f5c 100644 --- a/main.go +++ b/main.go @@ -89,6 +89,7 @@ func newHandler(cfg Config) http.HandlerFunc { mux.Handle("GET /api/v1/version", http.HandlerFunc(newHandlerGetAPIV1Version)) mux.Handle("POST /api/v1/events/slack", http.HandlerFunc(newHandlerPostAPIV1EventsSlack(cfg))) mux.Handle("PUT /api/v1/rpc/scrapeslack", http.HandlerFunc(newHandlerPutAPIV1RPCScrapeSlack(cfg))) + mux.Handle("GET /api/v1/rpc/aievent", http.HandlerFunc(newHandlerGetAPIV1RPCAIEvent(cfg))) return func(w http.ResponseWriter, r *http.Request) { if cfg.Debug { @@ -107,6 +108,16 @@ func newHandlerGetAPIV1Version(w http.ResponseWriter, _ *http.Request) { json.NewEncoder(w).Encode(map[string]any{"version": Version}) } +func newHandlerGetAPIV1RPCAIEvent(cfg Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !basicAuth(cfg, w, r) { + return + } + + http.Error(w, "not impl", http.StatusNotImplemented) + } +} + func newHandlerPutAPIV1RPCScrapeSlack(cfg Config) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if !basicAuth(cfg, w, r) { diff --git a/storage.go b/storage.go index fe166f8..a17e81f 100644 --- a/storage.go +++ b/storage.go @@ -67,10 +67,60 @@ func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error) return v, err } +func (s Storage) GetThreadMessages(ctx context.Context, ID string) ([]model.Message, error) { + return s.selectMessagesWhere(ctx, "ThreadID = $1", ID) +} + func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error { return s.upsert(ctx, "threads", thread) } +func (s Storage) selectMessagesWhere(ctx context.Context, clause string, args ...any) ([]model.Message, error) { + keys, _, _, _, err := keysArgsKeyargsValues(model.Message{}) + if err != nil { + return nil, err + } + args2 := make([]any, len(args)) + for i := range args { + args2[i], _ = json.Marshal(args[i]) + } + scanTargets := make([]any, len(keys)) + + q := fmt.Sprintf(` + SELECT %s FROM messages WHERE %s + `, strings.Join(keys, ", "), clause) + rows, err := s.driver.QueryContext(ctx, q, args2...) + if err != nil { + return nil, err + } + defer rows.Close() + + var result []model.Message + for rows.Next() { + for i := range scanTargets { + scanTargets[i] = &[]byte{} + } + + if err := rows.Scan(scanTargets...); err != nil { + return nil, err + } + + m := map[string]json.RawMessage{} + for i, k := range keys { + m[k] = *scanTargets[i].(*[]byte) + } + b, _ := json.Marshal(m) + + var one model.Message + if err := json.Unmarshal(b, &one); err != nil { + return nil, err + } + result = append(result, one) + } + + return result, rows.Err() +} + func (s Storage) selectOne(ctx context.Context, table string, v any, clause string, args ...any) error { if questions := strings.Count(clause, "$"); questions != len(args) { return fmt.Errorf("expected %v args for clause but found %v", questions, len(args)) diff --git a/storage_test.go b/storage_test.go index 02114bf..afefd30 100644 --- a/storage_test.go +++ b/storage_test.go @@ -2,6 +2,8 @@ package main import ( "context" + "fmt" + "math/rand" "testing" "time" @@ -85,4 +87,34 @@ func TestStorage(t *testing.T) { t.Fatal("unexpected result from get:", got) } }) + + t.Run("thread messages", func(t *testing.T) { + thread := fmt.Sprintf("thread-%d", rand.Int()) + m := model.NewMessage( + "ID", + 1, + "Author", + "Plaintext", + thread, + ) + + if err := s.UpsertMessage(ctx, m); err != nil { + t.Fatal("unexpected error on insert:", err) + } else if m2, err := s.GetMessage(ctx, m.ID); err != nil { + t.Fatal("unexpected error on upsert-get:", err) + } else if m2 != m { + t.Errorf("expected %+v but got %+v", m, m2) + } + + msgs, err := s.GetThreadMessages(ctx, thread) + if err != nil { + t.Fatal(err) + } else if len(msgs) != 1 { + t.Fatal(msgs) + } else if msgs[0].ThreadID != m.ThreadID { + t.Fatal(msgs[0].ThreadID) + } else if msgs[0] != m { + t.Fatalf("wanted msgs like %+v but got %+v", m, msgs[0]) + } + }) }