171 lines
4.4 KiB
Go
171 lines
4.4 KiB
Go
//go:build integration
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestAPIV0ChatBot(t *testing.T) {
|
|
defer goTestMain(t)()
|
|
|
|
body := func() url.Values {
|
|
result := url.Values{}
|
|
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'. Middle manager always replies with non-zero verbal responses.\n\nMIDDLE_MANAGER: What a lovely day.")
|
|
result.Set(`ReversePrompt`, "AN_EMPLOYEE: ")
|
|
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
|
|
return result
|
|
}
|
|
|
|
t.Run("put over post", func(t *testing.T) {
|
|
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
|
got, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("(%d) %s", resp.StatusCode, got)
|
|
if len(got) < 20 {
|
|
t.Error(string(got))
|
|
}
|
|
|
|
body2 := body()
|
|
body2.Set("Message", "I don't know, boss....")
|
|
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body2.Encode())
|
|
got2, err := io.ReadAll(resp2.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("(%d) %s", resp2.StatusCode, got2)
|
|
if len(got2) < 20 {
|
|
t.Error(string(got2))
|
|
}
|
|
|
|
if bytes.Equal(got, got2) {
|
|
t.Errorf("dupe generation: %s", stderrStash.Bytes())
|
|
stderrStash.Reset()
|
|
}
|
|
|
|
resp3 := httpDo(t, http.MethodGet, "/api/v0/chatbot", "")
|
|
var result struct {
|
|
Messages string
|
|
ReversePrompt string
|
|
Prompt string
|
|
}
|
|
if err := json.NewDecoder(resp3.Body).Decode(&result); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Contains([]byte(result.Messages), got) {
|
|
t.Errorf("forgot got: %s does not contain %s", result.Messages, got)
|
|
}
|
|
if !bytes.Contains([]byte(result.Messages), got2) {
|
|
t.Errorf("forgot got2: %s does not contain %s", result.Messages, got2)
|
|
}
|
|
if result.Prompt != body().Get("Prompt") {
|
|
t.Error(result.Prompt)
|
|
}
|
|
if result.ReversePrompt != body().Get("ReversePrompt") {
|
|
t.Error(result.ReversePrompt)
|
|
}
|
|
})
|
|
|
|
t.Run("post over post", func(t *testing.T) {
|
|
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
|
got, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("(%d) %s", resp.StatusCode, got)
|
|
if len(got) < 20 {
|
|
t.Error(string(got))
|
|
}
|
|
|
|
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
|
got2, err := io.ReadAll(resp2.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("(%d) %s", resp2.StatusCode, got2)
|
|
if len(got2) < 20 {
|
|
t.Error(string(got))
|
|
}
|
|
|
|
if bytes.Equal(got, got2) {
|
|
t.Errorf("dupe generation: %s", stderrStash.Bytes())
|
|
stderrStash.Reset()
|
|
}
|
|
})
|
|
|
|
t.Run("put over zero", func(t *testing.T) {
|
|
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
|
|
got, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Logf("(%d) %s", resp.StatusCode, got)
|
|
if len(got) > 20 {
|
|
t.Error(string(got))
|
|
}
|
|
})
|
|
|
|
}
|
|
|
|
// stderrStash is an in-memory sink for diagnostic output. The tests print its
// contents and then Reset it when reporting connection retries and duplicate
// generations. Nothing in this file writes to it while the log.SetOutput call
// in goTestMain stays commented out, so it is normally empty.
var stderrStash = new(bytes.Buffer)
|
|
|
|
func goTestMain(t *testing.T) func() {
|
|
//log.SetOutput(stderrStash)
|
|
|
|
ctx, can := context.WithCancel(context.Background())
|
|
ctx, cleanup := contextWithCleanup(ctx)
|
|
config(ctx)
|
|
Config.Port += 10
|
|
Config.ChatBot.N = 32
|
|
wg := &sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
listenAndServe(ctx)
|
|
}()
|
|
httpDo(t, http.MethodGet, "/", "")
|
|
return func() {
|
|
cleanup()
|
|
can()
|
|
wg.Wait()
|
|
}
|
|
}
|
|
|
|
func httpDo(t *testing.T, method, path, body string) *http.Response {
|
|
req, _ := http.NewRequest(
|
|
method,
|
|
fmt.Sprintf("http://localhost:%d/%s", Config.Port, strings.TrimLeft(path, "/")),
|
|
nil,
|
|
)
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
cookie := Cookie{Name: t.Name()}
|
|
req.Header.Set("Cookie", "root="+cookie.Serialize())
|
|
for {
|
|
req.Body = io.NopCloser(strings.NewReader(body))
|
|
resp, err := http.DefaultClient.Do(req.Clone(context.Background()))
|
|
if err == nil {
|
|
defer resp.Body.Close()
|
|
b, _ := io.ReadAll(resp.Body)
|
|
resp.Body = io.NopCloser(bytes.NewReader(b))
|
|
return resp
|
|
}
|
|
time.Sleep(time.Millisecond * 25)
|
|
t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes())
|
|
stderrStash.Reset()
|
|
}
|
|
t.Fatalf("failed to ever %s %s", method, path)
|
|
panic(nil)
|
|
}
|