impl storage GetEventThreads

main
Bel LaPointe 2024-04-19 12:25:40 -06:00
parent 79de56e236
commit 81fe8070ca
2 changed files with 81 additions and 1 deletions

View File

@ -67,6 +67,10 @@ func (s Storage) GetThread(ctx context.Context, ID string) (model.Thread, error)
return v, err
}
func (s Storage) GetEventThreads(ctx context.Context, ID string) ([]model.Thread, error) {
return s.selectThreadsWhere(ctx, "EventID = $1", ID)
}
func (s Storage) GetThreadMessages(ctx context.Context, ID string) ([]model.Message, error) {
return s.selectMessagesWhere(ctx, "ThreadID = $1", ID)
}
@ -75,6 +79,52 @@ func (s Storage) UpsertThread(ctx context.Context, thread model.Thread) error {
return s.upsert(ctx, "threads", thread)
}
func (s Storage) selectThreadsWhere(ctx context.Context, clause string, args ...any) ([]model.Thread, error) {
keys, _, _, _, err := keysArgsKeyargsValues(model.Thread{})
if err != nil {
return nil, err
}
args2 := make([]any, len(args))
for i := range args {
args2[i], _ = json.Marshal(args[i])
}
scanTargets := make([]any, len(keys))
q := fmt.Sprintf(`
SELECT %s FROM threads WHERE %s
`, strings.Join(keys, ", "), clause)
rows, err := s.driver.QueryContext(ctx, q, args2...)
if err != nil {
return nil, err
}
defer rows.Close()
var result []model.Thread
for rows.Next() {
for i := range scanTargets {
scanTargets[i] = &[]byte{}
}
if err := rows.Scan(scanTargets...); err != nil {
return nil, err
}
m := map[string]json.RawMessage{}
for i, k := range keys {
m[k] = *scanTargets[i].(*[]byte)
}
b, _ := json.Marshal(m)
var one model.Thread
if err := json.Unmarshal(b, &one); err != nil {
return nil, err
}
result = append(result, one)
}
return result, rows.Err()
}
func (s Storage) selectMessagesWhere(ctx context.Context, clause string, args ...any) ([]model.Message, error) {
keys, _, _, _, err := keysArgsKeyargsValues(model.Message{})
if err != nil {

View File

@ -88,7 +88,7 @@ func TestStorage(t *testing.T) {
}
})
t.Run("thread messages", func(t *testing.T) {
t.Run("get thread messages", func(t *testing.T) {
thread := fmt.Sprintf("thread-%d", rand.Int())
m := model.NewMessage(
"ID",
@ -117,4 +117,34 @@ func TestStorage(t *testing.T) {
t.Fatalf("wanted msgs like %+v but got %+v", m, msgs[0])
}
})
t.Run("get event threads", func(t *testing.T) {
event := fmt.Sprintf("event-%d", rand.Int())
m := model.NewThread(
"ID",
"URL",
1,
"Channel",
event,
)
if err := s.UpsertThread(ctx, m); err != nil {
t.Fatal("unexpected error on insert:", err)
} else if m2, err := s.GetThread(ctx, m.ID); err != nil {
t.Fatal("unexpected error on upsert-get:", err)
} else if m2 != m {
t.Errorf("expected %+v but got %+v", m, m2)
}
msgs, err := s.GetEventThreads(ctx, event)
if err != nil {
t.Fatal(err)
} else if len(msgs) != 1 {
t.Fatal(msgs)
} else if msgs[0].EventID != m.EventID {
t.Fatal(msgs[0].EventID)
} else if msgs[0] != m {
t.Fatalf("wanted msgs like %+v but got %+v", m, msgs[0])
}
})
}