diff --git a/vicuna-tools.d/main.go b/vicuna-tools.d/main.go index c7c975f..0fc16c3 100644 --- a/vicuna-tools.d/main.go +++ b/vicuna-tools.d/main.go @@ -367,10 +367,15 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { } result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt)) + priorContent, err := os.ReadFile(inputF) + if err != nil { + return err + } + shortResult := result[len(priorContent):] if err := os.WriteFile(promptF, result, os.ModePerm); err != nil { return err } - w.Write(result) + w.Write(shortResult) return nil } diff --git a/vicuna-tools.d/main_integration_test.go b/vicuna-tools.d/main_integration_test.go index 100b9b1..f0693b3 100644 --- a/vicuna-tools.d/main_integration_test.go +++ b/vicuna-tools.d/main_integration_test.go @@ -30,6 +30,9 @@ func TestAPIV0ChatBot(t *testing.T) { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) + if len(got) < 100 { + t.Error(string(got)) + } resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) got2, err := io.ReadAll(resp2.Body) @@ -37,11 +40,12 @@ func TestAPIV0ChatBot(t *testing.T) { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) + if len(got2) < 100 { + t.Error(string(got2)) + } - if len(got) == len(got2) { - t.Error("nothing new as of put") - } else if !bytes.HasPrefix(got2, got) { - t.Error("put was not a continuation") + if bytes.Equal(got, got2) { + t.Error("dupe generation") } }) @@ -52,6 +56,9 @@ func TestAPIV0ChatBot(t *testing.T) { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) + if len(got) < 100 { + t.Error(string(got)) + } resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) got2, err := io.ReadAll(resp2.Body) @@ -59,9 +66,12 @@ func TestAPIV0ChatBot(t *testing.T) { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) + if len(got2) < 100 { + t.Error(string(got)) + } - if bytes.HasPrefix(got2, got) { - t.Error("post over post was a continuation") + if bytes.Equal(got, got2) { + t.Error("dupe generation") } }) @@ -72,6 +82,9 @@ func TestAPIV0ChatBot(t *testing.T) { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) + if len(got) > 100 { + t.Error(string(got)) + } }) } @@ -83,6 +96,7 @@ func goTestMain(t *testing.T) func() { ctx, cleanup := contextWithCleanup(ctx) config(ctx) Config.Port += 10 + Config.ChatBot.N = 32 wg := &sync.WaitGroup{} wg.Add(1) go func() {