master
bel 2023-06-17 13:52:12 -06:00
parent f37098223e
commit 4f08abbb61
2 changed files with 138 additions and 74 deletions

View File

@ -16,6 +16,7 @@ import (
"os/exec" "os/exec"
"os/signal" "os/signal"
"path" "path"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -28,6 +29,7 @@ var (
Debug bool Debug bool
ChatBot struct { ChatBot struct {
SessionD string SessionD string
PromptDelimiter string
semaphore sync.Mutex semaphore sync.Mutex
WD string WD string
Command string Command string
@ -88,6 +90,7 @@ func config(ctx context.Context) {
flag.IntVar(&Config.Port, "p", 37070, "port to listen on") flag.IntVar(&Config.Port, "p", 37070, "port to listen on")
flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions") flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions")
flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU")
flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot") flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot")
flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix") flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix")
flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen") flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen")
@ -117,9 +120,9 @@ func listenAndServe(ctx context.Context) {
func handle(w http.ResponseWriter, r *http.Request) { func handle(w http.ResponseWriter, r *http.Request) {
cookie, _ := ParseCookie(r) cookie, _ := ParseCookie(r)
if err := _handle(w, r); err != nil { if err := _handle(w, r); err != nil {
log.Printf("%s: %s %s: %v", cookie.Name, r.Method, r.URL.Path, err) log.Printf("%s: %s %s: %v", cookie.MyName(), r.Method, r.URL.Path, err)
} else { } else {
log.Printf("%s: %s %s", cookie.Name, r.Method, r.URL.Path) log.Printf("%s: %s %s", cookie.MyName(), r.Method, r.URL.Path)
} }
} }
@ -208,8 +211,15 @@ func parseCookieFromCookie(r *http.Request) (Cookie, error) {
return result, result.Verify() return result, result.Verify()
} }
func (cookie Cookie) MyName() string {
return string(bytes.Join(
regexp.MustCompile(`[a-zA-Z]`).FindAll([]byte(cookie.Name), -1),
[]byte(""),
))
}
func (cookie Cookie) Verify() error { func (cookie Cookie) Verify() error {
if cookie.Name == "" { if cookie.MyName() == "" {
return fmt.Errorf("incomplete cookie") return fmt.Errorf("incomplete cookie")
} }
return nil return nil
@ -265,7 +275,7 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error {
func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error { func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
cookie, _ := ParseCookie(r) cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
if _, err := os.Stat(path.Join(sessionD)); err == nil { if _, err := os.Stat(path.Join(sessionD)); err == nil {
if err := os.RemoveAll(path.Join(sessionD)); err != nil { if err := os.RemoveAll(path.Join(sessionD)); err != nil {
return err return err
@ -293,20 +303,43 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
} }
cookie, _ := ParseCookie(r) cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name) sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
promptF := path.Join(sessionD, "prompt.txt") promptF := path.Join(sessionD, "prompt.txt")
inputF := path.Join(sessionD, "input.txt") inputF := path.Join(sessionD, "input.txt")
cacheF := path.Join(sessionD, "cache.bin") cacheF := path.Join(sessionD, "cache.bin")
reversePrompt := cookie.Name + "///" reversePrompt := cookie.MyName()
if len(reversePrompt) > 8 {
reversePrompt = reversePrompt[:8]
}
reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter
forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter
if err := copyFile(inputF, promptF); err != nil { if err := copyFile(inputF, promptF); err != nil {
return err return err
} }
if err := appendFile(inputF, reversePrompt+message); err != nil { if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil {
return err return err
} }
if _, err := os.Stat(cacheF); os.IsNotExist(err) { if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil {
if err := func() error { return err
}
justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt)
if err != nil {
return err
}
if err := os.Rename(inputF, promptF); err != nil {
return err
}
w.Write(bytes.TrimSuffix(justNew, []byte(reversePrompt)))
return nil
}
func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error {
if _, err := os.Stat(cacheF); !os.IsNotExist(err) {
return nil
}
commands := strings.Fields(Config.ChatBot.Command) commands := strings.Fields(Config.ChatBot.Command)
commands = append(commands, commands = append(commands,
"--batch-size", "8", "--batch-size", "8",
@ -315,7 +348,7 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
"--n_predict", "1", "--n_predict", "1",
) )
command := exec.CommandContext( command := exec.CommandContext(
r.Context(), ctx,
commands[0], commands[0],
commands[1:]..., commands[1:]...,
) )
@ -328,21 +361,23 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b) return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b)
} }
return nil return nil
}(); err != nil {
return err
}
} }
func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) {
commands := strings.Fields(Config.ChatBot.Command) commands := strings.Fields(Config.ChatBot.Command)
commands = append(commands, commands = append(commands,
"-f", inputF, "-f", inputF,
"--prompt-cache-all", "--prompt-cache-all",
"--prompt-cache", cacheF, "--prompt-cache", cacheF,
"-n", strconv.Itoa(Config.ChatBot.N), "-n", strconv.Itoa(Config.ChatBot.N),
)
if len(reversePrompt) > 0 {
commands = append(commands,
"--reverse-prompt", reversePrompt, "--reverse-prompt", reversePrompt,
) )
}
command := exec.CommandContext( command := exec.CommandContext(
r.Context(), ctx,
commands[0], commands[0],
commands[1:]..., commands[1:]...,
) )
@ -351,33 +386,52 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
stdout, err := command.StdoutPipe() stdout, err := command.StdoutPipe()
if err != nil { if err != nil {
return err return nil, err
} }
buff := bytes.NewBuffer(nil) buff := bytes.NewBuffer(nil)
go func() { go func() {
stdout.Read(make([]byte, 1)) stdout.Read(make([]byte, 1)) //1 BOS byte
io.Copy(buff, stdout) io.Copy(buff, stdout)
}() }()
if err := func() error {
Config.ChatBot.semaphore.Lock() Config.ChatBot.semaphore.Lock()
defer Config.ChatBot.semaphore.Unlock() defer Config.ChatBot.semaphore.Unlock()
if err := command.Run(); err != nil { return command.Run()
return err }(); err != nil {
return nil, err
} }
result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt)) oldAndNew := buff.Bytes()
log.Printf("generated: [%s]", oldAndNew)
priorContent, err := os.ReadFile(inputF) priorContent, err := os.ReadFile(inputF)
if err != nil { if err != nil {
return err return nil, err
} }
shortResult := result[len(priorContent):] justNew := oldAndNew[len(priorContent):]
if err := os.WriteFile(promptF, result, os.ModePerm); err != nil { idx := bytes.LastIndex(
return err append(priorContent, justNew...),
[]byte(reversePrompt),
)
log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx)
if idx+len(reversePrompt) > len(priorContent) {
justNew = justNew[:idx+len(reversePrompt)-len(priorContent)]
} }
w.Write(shortResult) if err := _appendFile(inputF, string(justNew)); err != nil {
return nil, err
}
log.Printf("newly generated: [%s]", justNew)
return nil if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) {
more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt)
if err != nil {
return nil, err
}
justNew = append(justNew, more...)
}
return justNew, nil
} }
func copyFile(toF, fromF string) error { func copyFile(toF, fromF string) error {
@ -398,16 +452,18 @@ func copyFile(toF, fromF string) error {
} }
func appendFile(toF, msg string) error { func appendFile(toF, msg string) error {
return _appendFile(toF, "\n"+msg+"\n")
}
func _appendFile(toF, msg string) error {
f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600) f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer f.Close()
f.Write([]byte("\n"))
if _, err := f.WriteString(msg); err != nil { if _, err := f.WriteString(msg); err != nil {
return err return err
} }
f.Write([]byte("\n"))
return nil return nil
} }

View File

@ -7,7 +7,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -19,78 +18,85 @@ import (
func TestAPIV0ChatBot(t *testing.T) { func TestAPIV0ChatBot(t *testing.T) {
defer goTestMain(t)() defer goTestMain(t)()
body := url.Values{} body := func() url.Values {
body.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.") result := url.Values{}
body.Set(`Message`, `I lost keys to the company car in my couch, boss.`) result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
return result
}
t.Run("put over post", func(t *testing.T) { t.Run("put over post", func(t *testing.T) {
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body) got, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("(%d) %s", resp.StatusCode, got) t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) < 100 { if len(got) < 20 {
t.Error(string(got)) t.Error(string(got))
} }
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
got2, err := io.ReadAll(resp2.Body) got2, err := io.ReadAll(resp2.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("(%d) %s", resp2.StatusCode, got2) t.Logf("(%d) %s", resp2.StatusCode, got2)
if len(got2) < 100 { if len(got2) < 20 {
t.Error(string(got2)) t.Error(string(got2))
} }
if bytes.Equal(got, got2) { if bytes.Equal(got, got2) {
t.Error("dupe generation") t.Errorf("dupe generation: %s", stderrStash.Bytes())
stderrStash.Reset()
} }
}) })
t.Run("post over post", func(t *testing.T) { t.Run("post over post", func(t *testing.T) {
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body) got, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("(%d) %s", resp.StatusCode, got) t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) < 100 { if len(got) < 20 {
t.Error(string(got)) t.Error(string(got))
} }
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode()) resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got2, err := io.ReadAll(resp2.Body) got2, err := io.ReadAll(resp2.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("(%d) %s", resp2.StatusCode, got2) t.Logf("(%d) %s", resp2.StatusCode, got2)
if len(got2) < 100 { if len(got2) < 20 {
t.Error(string(got)) t.Error(string(got))
} }
if bytes.Equal(got, got2) { if bytes.Equal(got, got2) {
t.Error("dupe generation") t.Errorf("dupe generation: %s", stderrStash.Bytes())
stderrStash.Reset()
} }
}) })
t.Run("put over zero", func(t *testing.T) { t.Run("put over zero", func(t *testing.T) {
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode()) resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body) got, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Logf("(%d) %s", resp.StatusCode, got) t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) > 100 { if len(got) > 20 {
t.Error(string(got)) t.Error(string(got))
} }
}) })
} }
var stderrStash = bytes.NewBuffer(nil)
func goTestMain(t *testing.T) func() { func goTestMain(t *testing.T) func() {
log.SetOutput(io.Discard) //log.SetOutput(stderrStash)
ctx, can := context.WithCancel(context.Background()) ctx, can := context.WithCancel(context.Background())
ctx, cleanup := contextWithCleanup(ctx) ctx, cleanup := contextWithCleanup(ctx)
@ -122,14 +128,16 @@ func httpDo(t *testing.T, method, path, body string) *http.Response {
req.Header.Set("Cookie", "root="+cookie.Serialize()) req.Header.Set("Cookie", "root="+cookie.Serialize())
for { for {
req.Body = io.NopCloser(strings.NewReader(body)) req.Body = io.NopCloser(strings.NewReader(body))
if resp, err := http.DefaultClient.Do(req.Clone(context.Background())); err == nil { resp, err := http.DefaultClient.Do(req.Clone(context.Background()))
if err == nil {
defer resp.Body.Close() defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body) b, _ := io.ReadAll(resp.Body)
resp.Body = io.NopCloser(bytes.NewReader(b)) resp.Body = io.NopCloser(bytes.NewReader(b))
return resp return resp
} }
time.Sleep(time.Millisecond * 25) time.Sleep(time.Millisecond * 25)
t.Logf("retrying %s", req.URL.String()) t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes())
stderrStash.Reset()
} }
t.Fatalf("failed to ever %s %s", method, path) t.Fatalf("failed to ever %s %s", method, path)
panic(nil) panic(nil)