ohno
parent
f37098223e
commit
4f08abbb61
|
|
@ -16,6 +16,7 @@ import (
|
|||
"os/exec"
|
||||
"os/signal"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -27,11 +28,12 @@ var (
|
|||
Port int
|
||||
Debug bool
|
||||
ChatBot struct {
|
||||
SessionD string
|
||||
semaphore sync.Mutex
|
||||
WD string
|
||||
Command string
|
||||
N int
|
||||
SessionD string
|
||||
PromptDelimiter string
|
||||
semaphore sync.Mutex
|
||||
WD string
|
||||
Command string
|
||||
N int
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -88,6 +90,7 @@ func config(ctx context.Context) {
|
|||
|
||||
flag.IntVar(&Config.Port, "p", 37070, "port to listen on")
|
||||
flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions")
|
||||
flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU")
|
||||
flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot")
|
||||
flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix")
|
||||
flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen")
|
||||
|
|
@ -117,9 +120,9 @@ func listenAndServe(ctx context.Context) {
|
|||
func handle(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, _ := ParseCookie(r)
|
||||
if err := _handle(w, r); err != nil {
|
||||
log.Printf("%s: %s %s: %v", cookie.Name, r.Method, r.URL.Path, err)
|
||||
log.Printf("%s: %s %s: %v", cookie.MyName(), r.Method, r.URL.Path, err)
|
||||
} else {
|
||||
log.Printf("%s: %s %s", cookie.Name, r.Method, r.URL.Path)
|
||||
log.Printf("%s: %s %s", cookie.MyName(), r.Method, r.URL.Path)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -208,8 +211,15 @@ func parseCookieFromCookie(r *http.Request) (Cookie, error) {
|
|||
return result, result.Verify()
|
||||
}
|
||||
|
||||
func (cookie Cookie) MyName() string {
|
||||
return string(bytes.Join(
|
||||
regexp.MustCompile(`[a-zA-Z]`).FindAll([]byte(cookie.Name), -1),
|
||||
[]byte(""),
|
||||
))
|
||||
}
|
||||
|
||||
func (cookie Cookie) Verify() error {
|
||||
if cookie.Name == "" {
|
||||
if cookie.MyName() == "" {
|
||||
return fmt.Errorf("incomplete cookie")
|
||||
}
|
||||
return nil
|
||||
|
|
@ -265,7 +275,7 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
|
||||
cookie, _ := ParseCookie(r)
|
||||
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name)
|
||||
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
|
||||
if _, err := os.Stat(path.Join(sessionD)); err == nil {
|
||||
if err := os.RemoveAll(path.Join(sessionD)); err != nil {
|
||||
return err
|
||||
|
|
@ -293,56 +303,81 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
cookie, _ := ParseCookie(r)
|
||||
sessionD := path.Join(Config.ChatBot.SessionD, cookie.Name)
|
||||
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
|
||||
|
||||
promptF := path.Join(sessionD, "prompt.txt")
|
||||
inputF := path.Join(sessionD, "input.txt")
|
||||
cacheF := path.Join(sessionD, "cache.bin")
|
||||
reversePrompt := cookie.Name + "///"
|
||||
reversePrompt := cookie.MyName()
|
||||
if len(reversePrompt) > 8 {
|
||||
reversePrompt = reversePrompt[:8]
|
||||
}
|
||||
reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter
|
||||
forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter
|
||||
if err := copyFile(inputF, promptF); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := appendFile(inputF, reversePrompt+message); err != nil {
|
||||
if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := os.Stat(cacheF); os.IsNotExist(err) {
|
||||
if err := func() error {
|
||||
commands := strings.Fields(Config.ChatBot.Command)
|
||||
commands = append(commands,
|
||||
"--batch-size", "8",
|
||||
"--prompt-cache", cacheF,
|
||||
"-f", inputF,
|
||||
"--n_predict", "1",
|
||||
)
|
||||
command := exec.CommandContext(
|
||||
r.Context(),
|
||||
commands[0],
|
||||
commands[1:]...,
|
||||
)
|
||||
command.Dir = Config.ChatBot.WD
|
||||
|
||||
Config.ChatBot.semaphore.Lock()
|
||||
defer Config.ChatBot.semaphore.Unlock()
|
||||
|
||||
if b, err := command.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b)
|
||||
}
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Rename(inputF, promptF); err != nil {
|
||||
return err
|
||||
}
|
||||
w.Write(bytes.TrimSuffix(justNew, []byte(reversePrompt)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error {
|
||||
if _, err := os.Stat(cacheF); !os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
commands := strings.Fields(Config.ChatBot.Command)
|
||||
commands = append(commands,
|
||||
"--batch-size", "8",
|
||||
"--prompt-cache", cacheF,
|
||||
"-f", inputF,
|
||||
"--n_predict", "1",
|
||||
)
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
commands[0],
|
||||
commands[1:]...,
|
||||
)
|
||||
command.Dir = Config.ChatBot.WD
|
||||
|
||||
Config.ChatBot.semaphore.Lock()
|
||||
defer Config.ChatBot.semaphore.Unlock()
|
||||
|
||||
if b, err := command.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("error generating cache with '%s': %w: %s", command.String(), err, b)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) {
|
||||
commands := strings.Fields(Config.ChatBot.Command)
|
||||
commands = append(commands,
|
||||
"-f", inputF,
|
||||
"--prompt-cache-all",
|
||||
"--prompt-cache", cacheF,
|
||||
"-n", strconv.Itoa(Config.ChatBot.N),
|
||||
"--reverse-prompt", reversePrompt,
|
||||
)
|
||||
if len(reversePrompt) > 0 {
|
||||
commands = append(commands,
|
||||
"--reverse-prompt", reversePrompt,
|
||||
)
|
||||
}
|
||||
command := exec.CommandContext(
|
||||
r.Context(),
|
||||
ctx,
|
||||
commands[0],
|
||||
commands[1:]...,
|
||||
)
|
||||
|
|
@ -351,33 +386,52 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
stdout, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
buff := bytes.NewBuffer(nil)
|
||||
go func() {
|
||||
stdout.Read(make([]byte, 1))
|
||||
stdout.Read(make([]byte, 1)) //1 BOS byte
|
||||
io.Copy(buff, stdout)
|
||||
}()
|
||||
|
||||
Config.ChatBot.semaphore.Lock()
|
||||
defer Config.ChatBot.semaphore.Unlock()
|
||||
if err := func() error {
|
||||
Config.ChatBot.semaphore.Lock()
|
||||
defer Config.ChatBot.semaphore.Unlock()
|
||||
|
||||
return command.Run()
|
||||
}(); err != nil {
|
||||
return nil, err
|
||||
|
||||
if err := command.Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := bytes.TrimSuffix(buff.Bytes(), []byte(reversePrompt))
|
||||
oldAndNew := buff.Bytes()
|
||||
log.Printf("generated: [%s]", oldAndNew)
|
||||
priorContent, err := os.ReadFile(inputF)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
shortResult := result[len(priorContent):]
|
||||
if err := os.WriteFile(promptF, result, os.ModePerm); err != nil {
|
||||
return err
|
||||
justNew := oldAndNew[len(priorContent):]
|
||||
idx := bytes.LastIndex(
|
||||
append(priorContent, justNew...),
|
||||
[]byte(reversePrompt),
|
||||
)
|
||||
log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx)
|
||||
if idx+len(reversePrompt) > len(priorContent) {
|
||||
justNew = justNew[:idx+len(reversePrompt)-len(priorContent)]
|
||||
}
|
||||
w.Write(shortResult)
|
||||
if err := _appendFile(inputF, string(justNew)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Printf("newly generated: [%s]", justNew)
|
||||
|
||||
return nil
|
||||
if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) {
|
||||
more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
justNew = append(justNew, more...)
|
||||
}
|
||||
return justNew, nil
|
||||
}
|
||||
|
||||
func copyFile(toF, fromF string) error {
|
||||
|
|
@ -398,16 +452,18 @@ func copyFile(toF, fromF string) error {
|
|||
}
|
||||
|
||||
func appendFile(toF, msg string) error {
|
||||
return _appendFile(toF, "\n"+msg+"\n")
|
||||
}
|
||||
|
||||
func _appendFile(toF, msg string) error {
|
||||
f, err := os.OpenFile(toF, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
f.Write([]byte("\n"))
|
||||
if _, err := f.WriteString(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Write([]byte("\n"))
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
|
@ -19,78 +18,85 @@ import (
|
|||
func TestAPIV0ChatBot(t *testing.T) {
|
||||
defer goTestMain(t)()
|
||||
|
||||
body := url.Values{}
|
||||
body.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
|
||||
body.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
|
||||
body := func() url.Values {
|
||||
result := url.Values{}
|
||||
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
|
||||
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
|
||||
return result
|
||||
}
|
||||
|
||||
t.Run("put over post", func(t *testing.T) {
|
||||
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
|
||||
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("(%d) %s", resp.StatusCode, got)
|
||||
if len(got) < 100 {
|
||||
if len(got) < 20 {
|
||||
t.Error(string(got))
|
||||
}
|
||||
|
||||
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode())
|
||||
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
|
||||
got2, err := io.ReadAll(resp2.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("(%d) %s", resp2.StatusCode, got2)
|
||||
if len(got2) < 100 {
|
||||
if len(got2) < 20 {
|
||||
t.Error(string(got2))
|
||||
}
|
||||
|
||||
if bytes.Equal(got, got2) {
|
||||
t.Error("dupe generation")
|
||||
t.Errorf("dupe generation: %s", stderrStash.Bytes())
|
||||
stderrStash.Reset()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("post over post", func(t *testing.T) {
|
||||
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
|
||||
resp := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("(%d) %s", resp.StatusCode, got)
|
||||
if len(got) < 100 {
|
||||
if len(got) < 20 {
|
||||
t.Error(string(got))
|
||||
}
|
||||
|
||||
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body.Encode())
|
||||
resp2 := httpDo(t, http.MethodPost, "/api/v0/chatbot", body().Encode())
|
||||
got2, err := io.ReadAll(resp2.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("(%d) %s", resp2.StatusCode, got2)
|
||||
if len(got2) < 100 {
|
||||
if len(got2) < 20 {
|
||||
t.Error(string(got))
|
||||
}
|
||||
|
||||
if bytes.Equal(got, got2) {
|
||||
t.Error("dupe generation")
|
||||
t.Errorf("dupe generation: %s", stderrStash.Bytes())
|
||||
stderrStash.Reset()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("put over zero", func(t *testing.T) {
|
||||
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body.Encode())
|
||||
resp := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("(%d) %s", resp.StatusCode, got)
|
||||
if len(got) > 100 {
|
||||
if len(got) > 20 {
|
||||
t.Error(string(got))
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
var stderrStash = bytes.NewBuffer(nil)
|
||||
|
||||
func goTestMain(t *testing.T) func() {
|
||||
log.SetOutput(io.Discard)
|
||||
//log.SetOutput(stderrStash)
|
||||
|
||||
ctx, can := context.WithCancel(context.Background())
|
||||
ctx, cleanup := contextWithCleanup(ctx)
|
||||
|
|
@ -122,14 +128,16 @@ func httpDo(t *testing.T, method, path, body string) *http.Response {
|
|||
req.Header.Set("Cookie", "root="+cookie.Serialize())
|
||||
for {
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
if resp, err := http.DefaultClient.Do(req.Clone(context.Background())); err == nil {
|
||||
resp, err := http.DefaultClient.Do(req.Clone(context.Background()))
|
||||
if err == nil {
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
resp.Body = io.NopCloser(bytes.NewReader(b))
|
||||
return resp
|
||||
}
|
||||
time.Sleep(time.Millisecond * 25)
|
||||
t.Logf("retrying %s", req.URL.String())
|
||||
t.Logf("retrying %s: %v (%s)", req.URL.String(), err, stderrStash.Bytes())
|
||||
stderrStash.Reset()
|
||||
}
|
||||
t.Fatalf("failed to ever %s %s", method, path)
|
||||
panic(nil)
|
||||
|
|
|
|||
Loading…
Reference in New Issue