diff --git a/vicuna-tools.d/.main_integration_test.go.swp b/vicuna-tools.d/.main_integration_test.go.swp new file mode 100644 index 0000000..609fb64 Binary files /dev/null and b/vicuna-tools.d/.main_integration_test.go.swp differ diff --git a/vicuna-tools.d/main.go b/vicuna-tools.d/main.go index 0475230..c2b34e5 100644 --- a/vicuna-tools.d/main.go +++ b/vicuna-tools.d/main.go @@ -28,12 +28,11 @@ var ( Port int Debug bool ChatBot struct { - SessionD string - PromptDelimiter string - semaphore sync.Mutex - WD string - Command string - N int + SessionD string + semaphore sync.Mutex + WD string + Command string + N int } } @@ -90,7 +89,6 @@ func config(ctx context.Context) { flag.IntVar(&Config.Port, "p", 37070, "port to listen on") flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions") - flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME") flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot") flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix") flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen") @@ -293,17 +291,23 @@ func handleAPIChatBotGet(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + reversePrompt, err := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt")) + if err != nil { + return err + } prompt, err := os.ReadFile(path.Join(sessionD, "initial-prompt.txt")) if err != nil { return err } return json.NewEncoder(w).Encode(struct { - Messages string - Prompt string + Messages string + Prompt string + ReversePrompt string }{ - Messages: string(bytes.TrimPrefix(messages, prompt)), - Prompt: string(prompt), + Messages: string(bytes.TrimPrefix(messages, prompt)), + Prompt: string(prompt), + ReversePrompt: string(reversePrompt), }) } @@ -323,6 +327,10 @@ func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error { if len(prompt) == 0 { return errors.New("no prompt") } + reverseprompt := r.PostForm.Get("ReversePrompt") + if err := os.WriteFile(path.Join(sessionD, "reverse-prompt.txt"), []byte(reverseprompt), os.ModePerm); err != nil { + return err + } if err := os.WriteFile(path.Join(sessionD, "initial-prompt.txt"), []byte(prompt), os.ModePerm); err != nil { return err } @@ -345,11 +353,10 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { promptF := path.Join(sessionD, "prompt.txt") inputF := path.Join(sessionD, "input.txt") cacheF := path.Join(sessionD, "cache.bin") - reversePrompt := cookie.MyName() - if len(reversePrompt) > 8 { - reversePrompt = reversePrompt[:8] - } - reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter + reversePrompt := func() string { + b, _ := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt")) + return string(b) + }() if err := copyFile(inputF, promptF); err != nil { return err } @@ -436,7 +443,7 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr } buff := bytes.NewBuffer(nil) go func() { - stdout.Read(make([]byte, 1)) //1 BOS byte + //stdout.Read(make([]byte, 1)) //1 BOS byte io.Copy(buff, stdout) }() @@ -458,13 +465,15 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr justNew := oldAndNew[len(priorContent):] trimmedReversePrompt := strings.TrimSpace(reversePrompt) - if idx := bytes.Index(justNew, []byte(trimmedReversePrompt)); idx > -1 { - justNew = justNew[:idx+len(trimmedReversePrompt)] - } else if idx := bytes.LastIndex( - append(priorContent, justNew...), - []byte(trimmedReversePrompt), - ); idx+len(trimmedReversePrompt) > len(priorContent) { - justNew = justNew[:idx+len(trimmedReversePrompt)-len(priorContent)] + if len(trimmedReversePrompt) > 0 { + if idx := bytes.Index(justNew, []byte(trimmedReversePrompt)); idx > -1 { + justNew = justNew[:idx+len(trimmedReversePrompt)] + } else if idx := bytes.LastIndex( + append(priorContent, justNew...), + []byte(trimmedReversePrompt), + ); idx+len(trimmedReversePrompt) > len(priorContent) { + justNew = justNew[:idx+len(trimmedReversePrompt)-len(priorContent)] + } } if err := _appendFile(inputF, string(justNew)); err != nil { return nil, err diff --git a/vicuna-tools.d/main_integration_test.go b/vicuna-tools.d/main_integration_test.go index 8814103..3f0d3f8 100644 --- a/vicuna-tools.d/main_integration_test.go +++ b/vicuna-tools.d/main_integration_test.go @@ -5,6 +5,7 @@ package main import ( "bytes" "context" + "encoding/json" "fmt" "io" "net/http" @@ -20,7 +21,8 @@ func TestAPIV0ChatBot(t *testing.T) { body := func() url.Values { result := url.Values{} - result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") + result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'. Middle manager always replies with non-zero verbal responses.\n\nMIDDLE_MANAGER: What a lovely day.") + result.Set(`ReversePrompt`, "AN_EMPLOYEE: ") result.Set(`Message`, `I lost keys to the company car in my couch, boss.`) return result } @@ -54,16 +56,25 @@ func TestAPIV0ChatBot(t *testing.T) { } resp3 := httpDo(t, http.MethodGet, "/api/v0/chatbot", "") - got3, err := io.ReadAll(resp3.Body) - if err != nil { + var result struct { + Messages string + ReversePrompt string + Prompt string + } + if err := json.NewDecoder(resp3.Body).Decode(&result); err != nil { t.Fatal(err) } - - if !bytes.Contains(got3, got) { - t.Error("forgot got") + if !bytes.Contains([]byte(result.Messages), got) { + t.Errorf("forgot got: %s does not contain %s", result.Messages, got) } - if !bytes.Contains(got3, got2) { - t.Error("forgot got2") + if !bytes.Contains([]byte(result.Messages), got2) { + t.Errorf("forgot got2: %s does not contain %s", result.Messages, got2) + } + if result.Prompt != body().Get("Prompt") { + t.Error(result.Prompt) + } + if result.ReversePrompt != body().Get("ReversePrompt") { + t.Error(result.ReversePrompt) } }) diff --git a/vicuna-tools.d/template.d/chatbot.html b/vicuna-tools.d/template.d/chatbot.html index c814e0e..b44adf2 100644 --- a/vicuna-tools.d/template.d/chatbot.html +++ b/vicuna-tools.d/template.d/chatbot.html @@ -7,6 +7,7 @@ var data = JSON.parse(body) document.getElementById("stream-log").innerHTML = data["Messages"] document.getElementById("stream-prompt").innerHTML = data["Prompt"] + document.getElementById("stream-reverse-prompt").innerHTML = data["ReversePrompt"] }, null) } @@ -68,6 +69,7 @@ Set up a new session
+
diff --git a/vicuna-tools.d/vicuna-tools.d b/vicuna-tools.d/vicuna-tools.d new file mode 100755 index 0000000..ec3347e Binary files /dev/null and b/vicuna-tools.d/vicuna-tools.d differ