diff --git a/vicuna-tools.d/main.go b/vicuna-tools.d/main.go index 0fc16c3..763829a 100644 --- a/vicuna-tools.d/main.go +++ b/vicuna-tools.d/main.go @@ -16,6 +16,7 @@ import ( "os/exec" "os/signal" "path" + "regexp" "strconv" "strings" "sync" @@ -27,11 +28,12 @@ var ( Port int Debug bool ChatBot struct { - SessionD string - semaphore sync.Mutex - WD string - Command string - N int + SessionD string + PromptDelimiter string + semaphore sync.Mutex + WD string + Command string + N int } } @@ -88,6 +90,7 @@ func config(ctx context.Context) { flag.IntVar(&Config.Port, "p", 37070, "port to listen on") flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions") + flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU") flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot") flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix") flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen") @@ -117,9 +120,9 @@ func listenAndServe(ctx context.Context) { func handle(w http.ResponseWriter, r *http.Request) { cookie, _ := ParseCookie(r) if err := _handle(w, r); err != nil { - log.Printf("%s: %s %s: %v", cookie.Name, r.Method, r.URL.Path, err) + log.Printf("%s: %s %s: %v", cookie.MyName(), r.Method, r.URL.Path, err) } else { - log.Printf("%s: %s %s", cookie.Name, r.Method, r.URL.Path) + log.Printf("%s: %s %s", cookie.MyName(), r.Method, r.URL.Path) } } @@ -208,8 +211,15 @@ func parseCookieFromCookie(r *http.Request) (Cookie, error) { return result, result.Verify() } +func (cookie Cookie) MyName() string { + return string(bytes.Join( + regexp.MustCompile(`[a-zA-Z]`).FindAll([]byte(cookie.Name), -1), + []byte(""), + )) +} + func (cookie Cookie) Verify() error { - if cookie.Name == "" { + if cookie.MyName() == "" { return fmt.Errorf("incomplete cookie") } return nil @@ -265,7 +275,7 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error { func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error { cookie, _ := ParseCookie(r) - sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) + sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName()) if _, err := os.Stat(path.Join(sessionD)); err == nil { if err := os.RemoveAll(path.Join(sessionD)); err != nil { return err @@ -293,56 +303,81 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { } cookie, _ := ParseCookie(r) - sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) + sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName()) promptF := path.Join(sessionD, "prompt.txt") inputF := path.Join(sessionD, "input.txt") cacheF := path.Join(sessionD, "cache.bin") - reversePrompt := cookie.Name + "///" + reversePrompt := cookie.MyName() + if len(reversePrompt) > 8 { + reversePrompt = reversePrompt[:8] + } + reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter + forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter if err := copyFile(inputF, promptF); err != nil { return err } - if err := appendFile(inputF, reversePrompt+message); err != nil { + if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil { return err } - if _, err := os.Stat(cacheF); os.IsNotExist(err) { - if err := func() error { - commands := strings.Fields(Config.ChatBot.Command) - commands = append(commands, - "--batch-size", "8", - "--prompt-cache", cacheF, - "-f", inputF, - "--n_predict", "1", - ) - command := exec.CommandContext( - r.Context(), - commands[0], - commands[1:]..., - ) - command.Dir = Config.ChatBot.WD - - Config.ChatBot.semaphore.Lock() - defer Config.ChatBot.semaphore.Unlock() - - if b, err := command.CombinedOutput(); err != nil { - return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b) - } - return nil - }(); err != nil { - return err - } + if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil { + return err } + justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt) + if err != nil { + return err + } + if err := os.Rename(inputF, promptF); err != nil { + return err + } + w.Write(bytes.TrimSuffix(justNew, []byte(reversePrompt))) + + return nil +} + +func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error { + if _, err := os.Stat(cacheF); !os.IsNotExist(err) { + return nil + } + commands := strings.Fields(Config.ChatBot.Command) + commands = append(commands, + "--batch-size", "8", + "--prompt-cache", cacheF, + "-f", inputF, + "--n_predict", "1", + ) + command := exec.CommandContext( + ctx, + commands[0], + commands[1:]..., + ) + command.Dir = Config.ChatBot.WD + + Config.ChatBot.semaphore.Lock() + defer Config.ChatBot.semaphore.Unlock() + + if b, err := command.CombinedOutput(); err != nil { + return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b) + } + return nil +} + +func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) { commands := strings.Fields(Config.ChatBot.Command) commands = append(commands, "-f", inputF, "--prompt-cache-all", "--prompt-cache", cacheF, "-n", strconv.Itoa(Config.ChatBot.N), - "--reverse-prompt", reversePrompt, ) + if len(reversePrompt) > 0 { + commands = append(commands, + "--reverse-prompt", reversePrompt, + ) + } command := exec.CommandContext( - r.Context(), + ctx, commands[0], commands[1:]..., ) @@ -351,33 +386,52 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { stdout, err := command.StdoutPipe() if err != nil { - return err + return nil, err } buff := bytes.NewBuffer(nil) go func() { - stdout.Read(make([]byte, 1)) + stdout.Read(make([]byte, 1)) //1 BOS byte io.Copy(buff, stdout) }() - Config.ChatBot.semaphore.Lock() - defer Config.ChatBot.semaphore.Unlock() + if err := func() error { + Config.ChatBot.semaphore.Lock() + defer Config.ChatBot.semaphore.Unlock() + + return command.Run() + }(); err != nil { + return nil, err - if err := command.Run(); err != nil { - return err } - result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt)) + oldAndNew := buff.Bytes() + log.Printf("generated: [%s]", oldAndNew) priorContent, err := os.ReadFile(inputF) if err != nil { - return err + return nil, err } - shortResult := result[len(priorContent):] - if err := os.WriteFile(promptF, result, os.ModePerm); err != nil { - return err + justNew := oldAndNew[len(priorContent):] + idx := bytes.LastIndex( + append(priorContent, justNew...), + []byte(reversePrompt), + ) + log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx) + if idx+len(reversePrompt) > len(priorContent) { + justNew = justNew[:idx+len(reversePrompt)-len(priorContent)] } - w.Write(shortResult) + if err := _appendFile(inputF, string(justNew)); err != nil { + return nil, err + } + log.Printf("newly generated: [%s]", justNew) - return nil + if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) { + more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt) + if err != nil { + return nil, err + } + justNew = append(justNew, more...) + } + return justNew, nil } func copyFile(toF, fromF string) error { @@ -398,16 +452,18 @@ func copyFile(toF, fromF string) error { } func appendFile(toF, msg string) error { + return _appendFile(toF, "\n"+msg+"\n") +} + +func _appendFile(toF, msg string) error { f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) if err != nil { return err } defer f.Close() - f.Write([]byte("\n")) if _, err := f.WriteString(msg); err != nil { return err } - f.Write([]byte("\n")) return nil } diff --git a/vicuna-tools.d/main_integration_test.go b/vicuna-tools.d/main_integration_test.go index f0693b3..3f3deac 100644 --- a/vicuna-tools.d/main_integration_test.go +++ b/vicuna-tools.d/main_integration_test.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -19,78 +18,85 @@ import ( func TestAPIV0ChatBot(t *testing.T) { defer goTestMain(t)() - body := url.Values{} - body.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") - body.Set(`Message`, `I lost keys to the company car in my couch, boss.`) + body := func() url.Values { + result := url.Values{} + result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") + result.Set(`Message`, `I lost keys to the company car in my couch, boss.`) + return result + } t.Run("put over post", func(t *testing.T) { - resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) < 100 { + if len(got) < 20 { t.Error(string(got)) } - resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) + resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode()) got2, err := io.ReadAll(resp2.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) - if len(got2) < 100 { + if len(got2) < 20 { t.Error(string(got2)) } if bytes.Equal(got, got2) { - t.Error("dupe generation") + t.Errorf("dupe generation: %s", stderrStash.Bytes()) + stderrStash.Reset() } }) t.Run("post over post", func(t *testing.T) { - resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) < 100 { + if len(got) < 20 { t.Error(string(got)) } - resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got2, err := io.ReadAll(resp2.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) - if len(got2) < 100 { + if len(got2) < 20 { t.Error(string(got)) } if bytes.Equal(got, got2) { - t.Error("dupe generation") + t.Errorf("dupe generation: %s", stderrStash.Bytes()) + stderrStash.Reset() } }) t.Run("put over zero", func(t *testing.T) { - resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) > 100 { + if len(got) > 20 { t.Error(string(got)) } }) } +var stderrStash = bytes.NewBuffer(nil) + func goTestMain(t *testing.T) func() { - log.SetOutput(io.Discard) + //log.SetOutput(stderrStash) ctx, can := context.WithCancel(context.Background()) ctx, cleanup := contextWithCleanup(ctx) @@ -122,14 +128,16 @@ func httpDo(t *testing.T, method, path, body string) *http.Response { req.Header.Set("Cookie", "root="+cookie.Serialize()) for { req.Body = io.NopCloser(strings.NewReader(body)) - if resp, err := http.DefaultClient.Do(req.Clone(context.Background())); err == nil { + resp, err := http.DefaultClient.Do(req.Clone(context.Background())) + if err == nil { defer resp.Body.Close() b, _ := io.ReadAll(resp.Body) resp.Body = io.NopCloser(bytes.NewReader(b)) return resp } time.Sleep(time.Millisecond * 25) - t.Logf("retrying %s", req.URL.String()) + t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes()) + stderrStash.Reset() } t.Fatalf("failed to ever %s %s", method, path) panic(nil)