From 4f08abbb61726e43a8dff6d923bf53a2f31618ce Mon Sep 17 00:00:00 2001 From: bel Date: Sat, 17 Jun 2023 13:52:12 -0600 Subject: [PATCH] ohno --- vicuna-tools.d/main.go | 166 ++++++++++++++++-------- vicuna-tools.d/main_integration_test.go | 46 ++++--- 2 files changed, 138 insertions(+), 74 deletions(-) diff --git a/vicuna-tools.d/main.go b/vicuna-tools.d/main.go index 0fc16c3..763829a 100644 --- a/vicuna-tools.d/main.go +++ b/vicuna-tools.d/main.go @@ -16,6 +16,7 @@ import ( "os/exec" "os/signal" "path" + "regexp" "strconv" "strings" "sync" @@ -27,11 +28,12 @@ var ( Port int Debug bool ChatBot struct { - SessionD string - semaphore sync.Mutex - WD string - Command string - N int + SessionD string + PromptDelimiter string + semaphore sync.Mutex + WD string + Command string + N int } } @@ -88,6 +90,7 @@ func config(ctx context.Context) { flag.IntVar(&Config.Port, "p", 37070, "port to listen on") flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions") + flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU") flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot") flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix") flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen") @@ -117,9 +120,9 @@ func listenAndServe(ctx context.Context) { func handle(w http.ResponseWriter, r *http.Request) { cookie, _ := ParseCookie(r) if err := _handle(w, r); err != nil { - log.Printf("%s: %s %s: %v", cookie.Name, r.Method, r.URL.Path, err) + log.Printf("%s: %s %s: %v", cookie.MyName(), r.Method, r.URL.Path, err) } else { - log.Printf("%s: %s %s", cookie.Name, r.Method, r.URL.Path) + log.Printf("%s: %s %s", cookie.MyName(), r.Method, r.URL.Path) } } @@ -208,8 +211,15 @@ func parseCookieFromCookie(r *http.Request) (Cookie, error) { return result, result.Verify() } +func (cookie Cookie) MyName() string { + return string(bytes.Join( + regexp.MustCompile(`[a-zA-Z]`).FindAll([]byte(cookie.Name), -1), + []byte(""), + )) +} + func (cookie Cookie) Verify() error { - if cookie.Name == "" { + if cookie.MyName() == "" { return fmt.Errorf("incomplete cookie") } return nil @@ -265,7 +275,7 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error { func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error { cookie, _ := ParseCookie(r) - sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) + sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName()) if _, err := os.Stat(path.Join(sessionD)); err == nil { if err := os.RemoveAll(path.Join(sessionD)); err != nil { return err @@ -293,56 +303,81 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { } cookie, _ := ParseCookie(r) - sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) + sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName()) promptF := path.Join(sessionD, "prompt.txt") inputF := path.Join(sessionD, "input.txt") cacheF := path.Join(sessionD, "cache.bin") - reversePrompt := cookie.Name + "///" + reversePrompt := cookie.MyName() + if len(reversePrompt) > 8 { + reversePrompt = reversePrompt[:8] + } + reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter + forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter if err := copyFile(inputF, promptF); err != nil { return err } - if err := appendFile(inputF, reversePrompt+message); err != nil { + if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil { return err } - if _, err := os.Stat(cacheF); os.IsNotExist(err) { - if err := func() error { - commands := strings.Fields(Config.ChatBot.Command) - commands = append(commands, - "--batch-size", "8", - "--prompt-cache", cacheF, - "-f", inputF, - "--n_predict", "1", - ) - command := exec.CommandContext( - r.Context(), - commands[0], - commands[1:]..., - ) - command.Dir = Config.ChatBot.WD - - Config.ChatBot.semaphore.Lock() - defer Config.ChatBot.semaphore.Unlock() - - if b, err := command.CombinedOutput(); err != nil { - return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b) - } - return nil - }(); err != nil { - return err - } + if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil { + return err } + justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt) + if err != nil { + return err + } + if err := os.Rename(inputF, promptF); err != nil { + return err + } + w.Write(bytes.TrimSuffix(justNew, []byte(reversePrompt))) + + return nil +} + +func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error { + if _, err := os.Stat(cacheF); !os.IsNotExist(err) { + return nil + } + commands := strings.Fields(Config.ChatBot.Command) + commands = append(commands, + "--batch-size", "8", + "--prompt-cache", cacheF, + "-f", inputF, + "--n_predict", "1", + ) + command := exec.CommandContext( + ctx, + commands[0], + commands[1:]..., + ) + command.Dir = Config.ChatBot.WD + + Config.ChatBot.semaphore.Lock() + defer Config.ChatBot.semaphore.Unlock() + + if b, err := command.CombinedOutput(); err != nil { + return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b) + } + return nil +} + +func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) { commands := strings.Fields(Config.ChatBot.Command) commands = append(commands, "-f", inputF, "--prompt-cache-all", "--prompt-cache", cacheF, "-n", strconv.Itoa(Config.ChatBot.N), - "--reverse-prompt", reversePrompt, ) + if len(reversePrompt) > 0 { + commands = append(commands, + "--reverse-prompt", reversePrompt, + ) + } command := exec.CommandContext( - r.Context(), + ctx, commands[0], commands[1:]..., ) @@ -351,33 +386,52 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error { stdout, err := command.StdoutPipe() if err != nil { - return err + return nil, err } buff := bytes.NewBuffer(nil) go func() { - stdout.Read(make([]byte, 1)) + stdout.Read(make([]byte, 1)) //1 BOS byte io.Copy(buff, stdout) }() - Config.ChatBot.semaphore.Lock() - defer Config.ChatBot.semaphore.Unlock() + if err := func() error { + Config.ChatBot.semaphore.Lock() + defer Config.ChatBot.semaphore.Unlock() + + return command.Run() + }(); err != nil { + return nil, err - if err := command.Run(); err != nil { - return err } - result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt)) + oldAndNew := buff.Bytes() + log.Printf("generated: [%s]", oldAndNew) priorContent, err := os.ReadFile(inputF) if err != nil { - return err + return nil, err } - shortResult := result[len(priorContent):] - if err := os.WriteFile(promptF, result, os.ModePerm); err != nil { - return err + justNew := oldAndNew[len(priorContent):] + idx := bytes.LastIndex( + append(priorContent, justNew...), + []byte(reversePrompt), + ) + log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx) + if idx+len(reversePrompt) > len(priorContent) { + justNew = justNew[:idx+len(reversePrompt)-len(priorContent)] } - w.Write(shortResult) + if err := _appendFile(inputF, string(justNew)); err != nil { + return nil, err + } + log.Printf("newly generated: [%s]", justNew) - return nil + if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) { + more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt) + if err != nil { + return nil, err + } + justNew = append(justNew, more...) + } + return justNew, nil } func copyFile(toF, fromF string) error { @@ -398,16 +452,18 @@ func copyFile(toF, fromF string) error { } func appendFile(toF, msg string) error { + return _appendFile(toF, "\n"+msg+"\n") +} + +func _appendFile(toF, msg string) error { f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) if err != nil { return err } defer f.Close() - f.Write([]byte("\n")) if _, err := f.WriteString(msg); err != nil { return err } - f.Write([]byte("\n")) return nil } diff --git a/vicuna-tools.d/main_integration_test.go b/vicuna-tools.d/main_integration_test.go index f0693b3..3f3deac 100644 --- a/vicuna-tools.d/main_integration_test.go +++ b/vicuna-tools.d/main_integration_test.go @@ -7,7 +7,6 @@ import ( "context" "fmt" "io" - "log" "net/http" "net/url" "strings" @@ -19,78 +18,85 @@ import ( func TestAPIV0ChatBot(t *testing.T) { defer goTestMain(t)() - body := url.Values{} - body.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") - body.Set(`Message`, `I lost keys to the company car in my couch, boss.`) + body := func() url.Values { + result := url.Values{} + result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") + result.Set(`Message`, `I lost keys to the company car in my couch, boss.`) + return result + } t.Run("put over post", func(t *testing.T) { - resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) < 100 { + if len(got) < 20 { t.Error(string(got)) } - resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) + resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode()) got2, err := io.ReadAll(resp2.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) - if len(got2) < 100 { + if len(got2) < 20 { t.Error(string(got2)) } if bytes.Equal(got, got2) { - t.Error("dupe generation") + t.Errorf("dupe generation: %s", stderrStash.Bytes()) + stderrStash.Reset() } }) t.Run("post over post", func(t *testing.T) { - resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) < 100 { + if len(got) < 20 { t.Error(string(got)) } - resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) + resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode()) got2, err := io.ReadAll(resp2.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp2.StatusCode, got2) - if len(got2) < 100 { + if len(got2) < 20 { t.Error(string(got)) } if bytes.Equal(got, got2) { - t.Error("dupe generation") + t.Errorf("dupe generation: %s", stderrStash.Bytes()) + stderrStash.Reset() } }) t.Run("put over zero", func(t *testing.T) { - resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) + resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode()) got, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } t.Logf("(%d) %s", resp.StatusCode, got) - if len(got) > 100 { + if len(got) > 20 { t.Error(string(got)) } }) } +var stderrStash = bytes.NewBuffer(nil) + func goTestMain(t *testing.T) func() { - log.SetOutput(io.Discard) + //log.SetOutput(stderrStash) ctx, can := context.WithCancel(context.Background()) ctx, cleanup := contextWithCleanup(ctx) @@ -122,14 +128,16 @@ func httpDo(t *testing.T, method, path, body string) *http.Response { req.Header.Set("Cookie", "root="+cookie.Serialize()) for { req.Body = io.NopCloser(strings.NewReader(body)) - if resp, err := http.DefaultClient.Do(req.Clone(context.Background())); err == nil { + resp, err := http.DefaultClient.Do(req.Clone(context.Background())) + if err == nil { defer resp.Body.Close() b, _ := io.ReadAll(resp.Body) resp.Body = io.NopCloser(bytes.NewReader(b)) return resp } time.Sleep(time.Millisecond * 25) - t.Logf("retrying %s", req.URL.String()) + t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes()) + stderrStash.Reset() } t.Fatalf("failed to ever %s %s", method, path) panic(nil)