master
bel 2023-06-17 13:52:12 -06:00
parent f37098223e
commit 4f08abbb61
2 changed files with 138 additions and 74 deletions

View File

@ -16,6 +16,7 @@ import (
"os/exec"
"os/signal"
"path"
"regexp"
"strconv"
"strings"
"sync"
@ -27,11 +28,12 @@ var (
Port int
Debug bool
ChatBot struct {
SessionD string
semaphore sync.Mutex
WD string
Command string
N int
SessionD string
PromptDelimiter string
semaphore sync.Mutex
WD string
Command string
N int
}
}
@ -88,6 +90,7 @@ func config(ctx context.Context) {
flag.IntVar(&Config.Port, "p", 37070, "port to listen on")
flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions")
flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU")
flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot")
flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix")
flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen")
@ -117,9 +120,9 @@ func listenAndServe(ctx context.Context) {
// handle is the top-level HTTP entry point: it delegates the request to
// _handle and then writes exactly one access-log line containing the caller's
// name, the HTTP method and the URL path, plus the error when _handle failed.
//
// Only the sanitized cookie.MyName() is logged. Logging the raw cookie.Name as
// well produced two lines per request and put client-controlled text into the
// log.
func handle(w http.ResponseWriter, r *http.Request) {
	// The parse error is discarded here.
	// NOTE(review): presumably an unparsable cookie yields a zero Cookie and
	// therefore an empty name in the log line — confirm against ParseCookie.
	cookie, _ := ParseCookie(r)

	if err := _handle(w, r); err != nil {
		log.Printf("%s: %s %s: %v", cookie.MyName(), r.Method, r.URL.Path, err)
		return
	}
	log.Printf("%s: %s %s", cookie.MyName(), r.Method, r.URL.Path)
}
@ -208,8 +211,15 @@ func parseCookieFromCookie(r *http.Request) (Cookie, error) {
return result, result.Verify()
}
// cookieNameLetter matches a single ASCII letter. It is compiled once at
// package scope because MyName runs on every request.
var cookieNameLetter = regexp.MustCompile(`[a-zA-Z]`)

// MyName returns cookie.Name with everything except the ASCII letters a-z and
// A-Z removed, keeping the remaining letters in their original order. Digits,
// punctuation, whitespace, path separators and non-ASCII characters are all
// dropped, so the result may be empty.
func (cookie Cookie) MyName() string {
	return strings.Join(cookieNameLetter.FindAllString(cookie.Name, -1), "")
}
func (cookie Cookie) Verify() error {
if cookie.Name == "" {
if cookie.MyName() == "" {
return fmt.Errorf("incomplete cookie")
}
return nil
@ -265,7 +275,7 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error {
func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
if _, err := os.Stat(path.Join(sessionD)); err == nil {
if err := os.RemoveAll(path.Join(sessionD)); err != nil {
return err
@ -293,56 +303,81 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
}
cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
promptF := path.Join(sessionD, "prompt.txt")
inputF := path.Join(sessionD, "input.txt")
cacheF := path.Join(sessionD, "cache.bin")
reversePrompt := cookie.Name + "///"
reversePrompt := cookie.MyName()
if len(reversePrompt) > 8 {
reversePrompt = reversePrompt[:8]
}
reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter
forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter
if err := copyFile(inputF, promptF); err != nil {
return err
}
if err := appendFile(inputF, reversePrompt+message); err != nil {
if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil {
return err
}
if _, err := os.Stat(cacheF); os.IsNotExist(err) {
if err := func() error {
commands := strings.Fields(Config.ChatBot.Command)
commands = append(commands,
"--batch-size", "8",
"--prompt-cache", cacheF,
"-f", inputF,
"--n_predict", "1",
)
command := exec.CommandContext(
r.Context(),
commands[0],
commands[1:]...,
)
command.Dir = Config.ChatBot.WD
Config.ChatBot.semaphore.Lock()
defer Config.ChatBot.semaphore.Unlock()
if b, err := command.CombinedOutput(); err != nil {
return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b)
}
return nil
}(); err != nil {
return err
}
if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil {
return err
}
justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt)
if err != nil {
return err
}
if err := os.Rename(inputF, promptF); err != nil {
return err
}
w.Write(bytes.TrimSuffix(justNew, []byte(reversePrompt)))
return nil
}
// chatBotGenerateInitCacheF creates the prompt cache file cacheF from the
// prompt text in inputF by running the configured chatbot command once with
// "--n_predict 1". It does nothing when os.Stat does not report cacheF as
// missing, which includes the case where Stat fails for some other reason.
//
// The command runs in Config.ChatBot.WD, is bound to ctx, and is serialized
// with other chatbot invocations through Config.ChatBot.semaphore. On failure
// the returned error wraps the command error and carries the command line and
// its combined stdout/stderr.
func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error {
	_, statErr := os.Stat(cacheF)
	if !os.IsNotExist(statErr) {
		return nil
	}

	argv := append(
		strings.Fields(Config.ChatBot.Command),
		"--batch-size", "8",
		"--prompt-cache", cacheF,
		"-f", inputF,
		"--n_predict", "1",
	)
	cmd := exec.CommandContext(ctx, argv[0], argv[1:]...)
	cmd.Dir = Config.ChatBot.WD

	// Held until return so only one chatbot process runs at a time.
	Config.ChatBot.semaphore.Lock()
	defer Config.ChatBot.semaphore.Unlock()

	out, runErr := cmd.CombinedOutput()
	if runErr != nil {
		return fmt.Errorf("error generating cache with '%s': %w: %s", cmd.String(), runErr, out)
	}
	return nil
}
func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) {
commands := strings.Fields(Config.ChatBot.Command)
commands = append(commands,
"-f", inputF,
"--prompt-cache-all",
"--prompt-cache", cacheF,
"-n", strconv.Itoa(Config.ChatBot.N),
"--reverse-prompt", reversePrompt,
)
if len(reversePrompt) > 0 {
commands = append(commands,
"--reverse-prompt", reversePrompt,
)
}
command := exec.CommandContext(
r.Context(),
ctx,
commands[0],
commands[1:]...,
)
@ -351,33 +386,52 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
stdout, err := command.StdoutPipe()
if err != nil {
return err
return nil, err
}
buff := bytes.NewBuffer(nil)
go func() {
stdout.Read(make([]byte, 1))
stdout.Read(make([]byte, 1)) //1 BOS byte
io.Copy(buff, stdout)
}()
Config.ChatBot.semaphore.Lock()
defer Config.ChatBot.semaphore.Unlock()
if err := func() error {
Config.ChatBot.semaphore.Lock()
defer Config.ChatBot.semaphore.Unlock()
return command.Run()
}(); err != nil {
return nil, err
if err := command.Run(); err != nil {
return err
}
result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt))
oldAndNew := buff.Bytes()
log.Printf("generated: [%s]", oldAndNew)
priorContent, err := os.ReadFile(inputF)
if err != nil {
return err
return nil, err
}
shortResult := result[len(priorContent):]
if err := os.WriteFile(promptF, result, os.ModePerm); err != nil {
return err
justNew := oldAndNew[len(priorContent):]
idx := bytes.LastIndex(
append(priorContent, justNew...),
[]byte(reversePrompt),
)
log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx)
if idx+len(reversePrompt) > len(priorContent) {
justNew = justNew[:idx+len(reversePrompt)-len(priorContent)]
}
w.Write(shortResult)
if err := _appendFile(inputF, string(justNew)); err != nil {
return nil, err
}
log.Printf("newly generated: [%s]", justNew)
return nil
if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) {
more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt)
if err != nil {
return nil, err
}
justNew = append(justNew, more...)
}
return justNew, nil
}
func copyFile(toF, fromF string) error {
@ -398,16 +452,18 @@ func copyFile(toF, fromF string) error {
}
// appendFile appends msg to the file toF as its own block, preceded and
// followed by a single newline. See _appendFile for how the file is opened.
func appendFile(toF, msg string) error {
	return _appendFile(toF, "\n"+msg+"\n")
}

// _appendFile appends msg to the file toF exactly as given, creating the file
// with mode 0600 when it does not exist. It returns the first error from
// opening, writing or closing the file.
//
// No newlines are added here: appendFile supplies its own framing, and
// chatBotGenerateAndFillInputF relies on generated text being appended
// byte-for-byte.
func _appendFile(toF, msg string) error {
	f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
	if err != nil {
		return err
	}
	if _, err := f.WriteString(msg); err != nil {
		// The write error is the one worth reporting; Close is best-effort here.
		f.Close()
		return err
	}
	// A failed Close can mean the data never reached the file, so surface it.
	return f.Close()
}

View File

@ -7,7 +7,6 @@ import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
@ -19,78 +18,85 @@ import (
func TestAPIV0ChatBot(t *testing.T) {
defer goTestMain(t)()
body := url.Values{}
body.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
body.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
body := func() url.Values {
result := url.Values{}
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
return result
}
t.Run("put over post", func(t *testing.T) {
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) < 100 {
if len(got) < 20 {
t.Error(string(got))
}
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode())
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
got2, err := io.ReadAll(resp2.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("(%d) %s", resp2.StatusCode, got2)
if len(got2) < 100 {
if len(got2) < 20 {
t.Error(string(got2))
}
if bytes.Equal(got, got2) {
t.Error("dupe generation")
t.Errorf("dupe generation: %s", stderrStash.Bytes())
stderrStash.Reset()
}
})
t.Run("post over post", func(t *testing.T) {
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) < 100 {
if len(got) < 20 {
t.Error(string(got))
}
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
got2, err := io.ReadAll(resp2.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("(%d) %s", resp2.StatusCode, got2)
if len(got2) < 100 {
if len(got2) < 20 {
t.Error(string(got))
}
if bytes.Equal(got, got2) {
t.Error("dupe generation")
t.Errorf("dupe generation: %s", stderrStash.Bytes())
stderrStash.Reset()
}
})
t.Run("put over zero", func(t *testing.T) {
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode())
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
t.Logf("(%d) %s", resp.StatusCode, got)
if len(got) > 100 {
if len(got) > 20 {
t.Error(string(got))
}
})
}
// stderrStash is a shared in-memory buffer whose contents the tests in this
// file attach to failure and retry messages, resetting it after each report.
// NOTE(review): the only hookup visible here that would fill it,
// log.SetOutput(stderrStash) in goTestMain, is commented out — confirm that
// something still writes to it.
var stderrStash = bytes.NewBuffer(nil)
func goTestMain(t *testing.T) func() {
log.SetOutput(io.Discard)
//log.SetOutput(stderrStash)
ctx, can := context.WithCancel(context.Background())
ctx, cleanup := contextWithCleanup(ctx)
@ -122,14 +128,16 @@ func httpDo(t *testing.T, method, path, body string) *http.Response {
req.Header.Set("Cookie", "root="+cookie.Serialize())
for {
req.Body = io.NopCloser(strings.NewReader(body))
if resp, err := http.DefaultClient.Do(req.Clone(context.Background())); err == nil {
resp, err := http.DefaultClient.Do(req.Clone(context.Background()))
if err == nil {
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
resp.Body = io.NopCloser(bytes.NewReader(b))
return resp
}
time.Sleep(time.Millisecond * 25)
t.Logf("retrying %s", req.URL.String())
t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes())
stderrStash.Reset()
}
t.Fatalf("failed to ever %s %s", method, path)
panic(nil)