integration tests k

master
bel 2023-06-17 15:29:56 -06:00
parent 242bf0a746
commit a29101296b
5 changed files with 54 additions and 32 deletions

Binary file not shown.

View File

@ -29,7 +29,6 @@ var (
Debug bool Debug bool
ChatBot struct { ChatBot struct {
SessionD string SessionD string
PromptDelimiter string
semaphore sync.Mutex semaphore sync.Mutex
WD string WD string
Command string Command string
@ -90,7 +89,6 @@ func config(ctx context.Context) {
flag.IntVar(&Config.Port, "p", 37070, "port to listen on") flag.IntVar(&Config.Port, "p", 37070, "port to listen on")
flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions") flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions")
flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME")
flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot") flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot")
flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix") flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix")
flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen") flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen")
@ -293,6 +291,10 @@ func handleAPIChatBotGet(w http.ResponseWriter, r *http.Request) error {
if err != nil { if err != nil {
return err return err
} }
reversePrompt, err := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt"))
if err != nil {
return err
}
prompt, err := os.ReadFile(path.Join(sessionD, "initial-prompt.txt")) prompt, err := os.ReadFile(path.Join(sessionD, "initial-prompt.txt"))
if err != nil { if err != nil {
return err return err
@ -301,9 +303,11 @@ func handleAPIChatBotGet(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(struct { return json.NewEncoder(w).Encode(struct {
Messages string Messages string
Prompt string Prompt string
ReversePrompt string
}{ }{
Messages: string(bytes.TrimPrefix(messages, prompt)), Messages: string(bytes.TrimPrefix(messages, prompt)),
Prompt: string(prompt), Prompt: string(prompt),
ReversePrompt: string(reversePrompt),
}) })
} }
@ -323,6 +327,10 @@ func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
if len(prompt) == 0 { if len(prompt) == 0 {
return errors.New("no prompt") return errors.New("no prompt")
} }
reverseprompt := r.PostForm.Get("ReversePrompt")
if err := os.WriteFile(path.Join(sessionD, "reverse-prompt.txt"), []byte(reverseprompt), os.ModePerm); err != nil {
return err
}
if err := os.WriteFile(path.Join(sessionD, "initial-prompt.txt"), []byte(prompt), os.ModePerm); err != nil { if err := os.WriteFile(path.Join(sessionD, "initial-prompt.txt"), []byte(prompt), os.ModePerm); err != nil {
return err return err
} }
@ -345,11 +353,10 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
promptF := path.Join(sessionD, "prompt.txt") promptF := path.Join(sessionD, "prompt.txt")
inputF := path.Join(sessionD, "input.txt") inputF := path.Join(sessionD, "input.txt")
cacheF := path.Join(sessionD, "cache.bin") cacheF := path.Join(sessionD, "cache.bin")
reversePrompt := cookie.MyName() reversePrompt := func() string {
if len(reversePrompt) > 8 { b, _ := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt"))
reversePrompt = reversePrompt[:8] return string(b)
} }()
reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter
if err := copyFile(inputF, promptF); err != nil { if err := copyFile(inputF, promptF); err != nil {
return err return err
} }
@ -436,7 +443,7 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
} }
buff := bytes.NewBuffer(nil) buff := bytes.NewBuffer(nil)
go func() { go func() {
stdout.Read(make([]byte, 1)) //1 BOS byte //stdout.Read(make([]byte, 1)) //1 BOS byte
io.Copy(buff, stdout) io.Copy(buff, stdout)
}() }()
@ -458,6 +465,7 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
justNew := oldAndNew[len(priorContent):] justNew := oldAndNew[len(priorContent):]
trimmedReversePrompt := strings.TrimSpace(reversePrompt) trimmedReversePrompt := strings.TrimSpace(reversePrompt)
if len(trimmedReversePrompt) > 0 {
if idx := bytes.Index(justNew, []byte(trimmedReversePrompt)); idx > -1 { if idx := bytes.Index(justNew, []byte(trimmedReversePrompt)); idx > -1 {
justNew = justNew[:idx+len(trimmedReversePrompt)] justNew = justNew[:idx+len(trimmedReversePrompt)]
} else if idx := bytes.LastIndex( } else if idx := bytes.LastIndex(
@ -466,6 +474,7 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
); idx+len(trimmedReversePrompt) > len(priorContent) { ); idx+len(trimmedReversePrompt) > len(priorContent) {
justNew = justNew[:idx+len(trimmedReversePrompt)-len(priorContent)] justNew = justNew[:idx+len(trimmedReversePrompt)-len(priorContent)]
} }
}
if err := _appendFile(inputF, string(justNew)); err != nil { if err := _appendFile(inputF, string(justNew)); err != nil {
return nil, err return nil, err
} }

View File

@ -5,6 +5,7 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -20,7 +21,8 @@ func TestAPIV0ChatBot(t *testing.T) {
body := func() url.Values { body := func() url.Values {
result := url.Values{} result := url.Values{}
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'. Middle manager always replies with non-zero verbal responses.\n\nMIDDLE_MANAGER: What a lovely day.")
result.Set(`ReversePrompt`, "AN_EMPLOYEE: ")
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`) result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
return result return result
} }
@ -54,16 +56,25 @@ func TestAPIV0ChatBot(t *testing.T) {
} }
resp3 := httpDo(t, http.MethodGet, "/api/v0/chatbot", "") resp3 := httpDo(t, http.MethodGet, "/api/v0/chatbot", "")
got3, err := io.ReadAll(resp3.Body) var result struct {
if err != nil { Messages string
ReversePrompt string
Prompt string
}
if err := json.NewDecoder(resp3.Body).Decode(&result); err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Contains([]byte(result.Messages), got) {
if !bytes.Contains(got3, got) { t.Errorf("forgot got: %s does not contain %s", result.Messages, got)
t.Error("forgot got")
} }
if !bytes.Contains(got3, got2) { if !bytes.Contains([]byte(result.Messages), got2) {
t.Error("forgot got2") t.Errorf("forgot got2: %s does not contain %s", result.Messages, got2)
}
if result.Prompt != body().Get("Prompt") {
t.Error(result.Prompt)
}
if result.ReversePrompt != body().Get("ReversePrompt") {
t.Error(result.ReversePrompt)
} }
}) })

View File

@ -7,6 +7,7 @@
var data = JSON.parse(body) var data = JSON.parse(body)
document.getElementById("stream-log").innerHTML = data["Messages"] document.getElementById("stream-log").innerHTML = data["Messages"]
document.getElementById("stream-prompt").innerHTML = data["Prompt"] document.getElementById("stream-prompt").innerHTML = data["Prompt"]
document.getElementById("stream-reverse-prompt").innerHTML = data["ReversePrompt"]
}, null) }, null)
} }
@ -68,6 +69,7 @@
<summary>Set up a new session</summary> <summary>Set up a new session</summary>
<form id="prompt" onsubmit="startStream(this); return false;"> <form id="prompt" onsubmit="startStream(this); return false;">
<textarea id="stream-prompt" name="Prompt"></textarea> <textarea id="stream-prompt" name="Prompt"></textarea>
<input type="text" id="stream-reverse-prompt" name="ReversePrompt"/>
<button type="submit">Start with prompt</button> <button type="submit">Start with prompt</button>
</form> </form>
</details> </details>

BIN
vicuna-tools.d/vicuna-tools.d Executable file

Binary file not shown.