Compare commits

...

10 Commits

Author SHA1 Message Date
bel
429bdfff78 todo 2023-06-17 16:02:50 -06:00
bel
a29101296b integration tests k 2023-06-17 15:29:56 -06:00
bel
242bf0a746 i gotta pass prompt... 2023-06-17 15:18:21 -06:00
bel
68b9010ed5 not sure why stderr to discard makes things funnier.... 2023-06-17 15:05:35 -06:00
bel
30bd4311a5 wip 2023-06-17 15:02:54 -06:00
bel
8abea5c8c1 ui disables while loading 2023-06-17 14:48:58 -06:00
bel
0a3a7a0616 almost 2023-06-17 14:36:41 -06:00
bel
1a7f783fe0 ok submitting a prompt 2023-06-17 14:31:47 -06:00
bel
287b72639d break recurse generation if current output contains a delimiter using first OR last when combined with prev if not contains 2023-06-17 14:01:13 -06:00
bel
68a87dc3db passing 2023-06-17 13:58:33 -06:00
4 changed files with 203 additions and 34 deletions

View File

@@ -28,12 +28,11 @@ var (
Port int
Debug bool
ChatBot struct {
SessionD string
PromptDelimiter string
semaphore sync.Mutex
WD string
Command string
N int
SessionD string
semaphore sync.Mutex
WD string
Command string
N int
}
}
@@ -90,7 +89,6 @@ func config(ctx context.Context) {
flag.IntVar(&Config.Port, "p", 37070, "port to listen on")
flag.StringVar(&Config.ChatBot.SessionD, "chatbot-session-d", d, "dir to store chat bot sessions")
flag.StringVar(&Config.ChatBot.PromptDelimiter, "chatbot-rp", "> ", "prompt delimiter prefixed by NAME/YOU")
flag.StringVar(&Config.ChatBot.WD, "chatbot-working-d", "./llama.cpp", "working directory for chatbot")
flag.StringVar(&Config.ChatBot.Command, "chatbot-cmd", "./main -m ./models/ggml-vic7b-uncensored-q5_1.bin --repeat_penalty 1.0", "chatbot cmd prefix")
flag.IntVar(&Config.ChatBot.N, "chatbot-n", 256, "chatbot items to gen")
@@ -127,6 +125,10 @@ func handle(w http.ResponseWriter, r *http.Request) {
}
func _handle(w http.ResponseWriter, r *http.Request) error {
//ctx, can := context.WithTimeout(r.Context(), time.Minute)
//defer can()
//r = r.WithContext(ctx)
first := strings.Split(strings.TrimLeft(r.URL.Path, "/"), "/")[0]
switch first {
case "login":
@@ -171,7 +173,12 @@ func handleUI(w http.ResponseWriter, r *http.Request) error {
w.Write(htmlIndex)
return nil
case "/chatbot":
w.Write(htmlChatBot)
if !Config.Debug {
w.Write(htmlChatBot)
} else {
b, _ := os.ReadFile("./template.d/chatbot.html")
w.Write(b)
}
return nil
default:
return handleNotFound(w, r)
@@ -258,6 +265,7 @@ func handleNotFound(w http.ResponseWriter, r *http.Request) error {
}
func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error {
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
err := r.ParseForm()
if err != nil {
return err
@@ -268,11 +276,41 @@ func handleAPIChatBot(w http.ResponseWriter, r *http.Request) error {
return handleAPIChatBotPost(w, r)
case http.MethodPut:
return handleAPIChatBotPut(w, r)
case http.MethodGet:
return handleAPIChatBotGet(w, r)
default:
return handleNotFound(w, r)
}
}
func handleAPIChatBotGet(w http.ResponseWriter, r *http.Request) error {
cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
messages, err := os.ReadFile(path.Join(sessionD, "prompt.txt"))
if err != nil {
return err
}
reversePrompt, err := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt"))
if err != nil {
return err
}
prompt, err := os.ReadFile(path.Join(sessionD, "initial-prompt.txt"))
if err != nil {
return err
}
return json.NewEncoder(w).Encode(struct {
Messages string
Prompt string
ReversePrompt string
}{
Messages: string(bytes.TrimPrefix(messages, prompt)),
Prompt: string(prompt),
ReversePrompt: string(reversePrompt),
})
}
func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
cookie, _ := ParseCookie(r)
sessionD := path.Join(Config.ChatBot.SessionD, cookie.MyName())
@@ -289,6 +327,13 @@ func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
if len(prompt) == 0 {
return errors.New("no prompt")
}
reverseprompt := r.PostForm.Get("ReversePrompt")
if err := os.WriteFile(path.Join(sessionD, "reverse-prompt.txt"), []byte(reverseprompt), os.ModePerm); err != nil {
return err
}
if err := os.WriteFile(path.Join(sessionD, "initial-prompt.txt"), []byte(prompt), os.ModePerm); err != nil {
return err
}
if err := os.WriteFile(path.Join(sessionD, "prompt.txt"), []byte(prompt), os.ModePerm); err != nil {
return err
}
@@ -298,7 +343,7 @@ func handleAPIChatBotPost(w http.ResponseWriter, r *http.Request) error {
func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
message := r.PostForm.Get("Message")
if len(message) == 0 {
if len(message) == 0 && r.URL.Query.Get("more") != "true" {
return errors.New("empty Message")
}
@@ -308,23 +353,21 @@ func handleAPIChatBotPut(w http.ResponseWriter, r *http.Request) error {
promptF := path.Join(sessionD, "prompt.txt")
inputF := path.Join(sessionD, "input.txt")
cacheF := path.Join(sessionD, "cache.bin")
reversePrompt := cookie.MyName()
if len(reversePrompt) > 8 {
reversePrompt = reversePrompt[:8]
}
reversePrompt = reversePrompt + Config.ChatBot.PromptDelimiter
forwardPrompt := "YOU" + Config.ChatBot.PromptDelimiter
reversePrompt := func() string {
b, _ := os.ReadFile(path.Join(sessionD, "reverse-prompt.txt"))
return string(b)
}()
if err := copyFile(inputF, promptF); err != nil {
return err
}
if err := chatBotGenerateInitCacheF(r.Context(), cacheF, inputF); err != nil {
return err
}
if err := appendFile(inputF, reversePrompt+message+"\n"+forwardPrompt); err != nil {
if err := appendFile(inputF, reversePrompt+message+"\n"); err != nil {
return err
}
justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt)
justNew, err := chatBotGenerateAndFillInputF(r.Context(), cacheF, inputF, reversePrompt, 0)
if err != nil {
return err
}
@@ -363,7 +406,12 @@ func chatBotGenerateInitCacheF(ctx context.Context, cacheF, inputF string) error
return nil
}
func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string) ([]byte, error) {
func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePrompt string, depth int) ([]byte, error) {
if depth >= 3 {
justNew := []byte("...\n" + reversePrompt)
return justNew, _appendFile(inputF, string(justNew))
}
commands := strings.Fields(Config.ChatBot.Command)
commands = append(commands,
"-f", inputF,
@@ -382,7 +430,12 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
commands[1:]...,
)
command.Dir = Config.ChatBot.WD
command.Stderr = log.Writer()
logf, err := os.Create(path.Join(Config.ChatBot.SessionD, "chatbot.stderr"))
if err != nil {
return nil, err
}
defer logf.Close()
command.Stderr = logf
stdout, err := command.StdoutPipe()
if err != nil {
@@ -390,7 +443,7 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
}
buff := bytes.NewBuffer(nil)
go func() {
stdout.Read(make([]byte, 1)) //1 BOS byte
//stdout.Read(make([]byte, 1)) //1 BOS byte
io.Copy(buff, stdout)
}()
@@ -405,27 +458,31 @@ func chatBotGenerateAndFillInputF(ctx context.Context, cacheF, inputF, reversePr
}
oldAndNew := buff.Bytes()
log.Printf("generated: [%s]", oldAndNew)
priorContent, err := os.ReadFile(inputF)
if err != nil {
return nil, err
}
justNew := oldAndNew[len(priorContent):]
idx := bytes.LastIndex(
append(priorContent, justNew...),
[]byte(reversePrompt),
)
log.Printf("found lastindex(%d priorContent + %d newContent, %s) = %v", len(priorContent), len(justNew), reversePrompt, idx)
if idx+len(reversePrompt) > len(priorContent) {
justNew = justNew[:idx+len(reversePrompt)-len(priorContent)]
trimmedReversePrompt := strings.TrimSpace(reversePrompt)
if len(trimmedReversePrompt) > 0 {
if idx := bytes.Index(justNew, []byte(trimmedReversePrompt)); idx > -1 {
justNew = justNew[:idx+len(trimmedReversePrompt)]
} else if idx := bytes.LastIndex(
append(priorContent, justNew...),
[]byte(trimmedReversePrompt),
); idx+len(trimmedReversePrompt) > len(priorContent) {
justNew = justNew[:idx+len(trimmedReversePrompt)-len(priorContent)]
}
}
panic("TODO TRIM THE GARBO AT THE END")
if err := _appendFile(inputF, string(justNew)); err != nil {
return nil, err
}
log.Printf("newly generated: [%s]", justNew)
if !bytes.HasSuffix(append(priorContent, justNew...), []byte(reversePrompt)) {
more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt)
log.Printf("%s generated %q", trimmedReversePrompt, justNew)
if !bytes.HasSuffix(append(priorContent, justNew...), []byte(trimmedReversePrompt)) {
more, err := chatBotGenerateAndFillInputF(ctx, cacheF, inputF, reversePrompt, depth+1)
if err != nil {
return nil, err
}

View File

@@ -5,6 +5,7 @@ package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -20,7 +21,8 @@ func TestAPIV0ChatBot(t *testing.T) {
body := func() url.Values {
result := url.Values{}
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'.")
result.Set(`Prompt`, "Text transcript of a never ending dialogue between a middle manager and his direct reports. The middle manager works in a middle sized corporation and must tell employees what he thinks of employees' work. Middle manager always replies to bad news with an overly optimistic observation prefixed with 'Perhaps, but have you considered'. Middle manager always replies with non-zero verbal responses.\n\nMIDDLE_MANAGER: What a lovely day.")
result.Set(`ReversePrompt`, "AN_EMPLOYEE: ")
result.Set(`Message`, `I lost keys to the company car in my couch, boss.`)
return result
}
@@ -36,7 +38,9 @@ func TestAPIV0ChatBot(t *testing.T) {
t.Error(string(got))
}
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body().Encode())
body2 := body()
body2.Set("Message", "I don't know, boss....")
resp2 := httpDo(t, http.MethodPut, "/api/v0/chatbot", body2.Encode())
got2, err := io.ReadAll(resp2.Body)
if err != nil {
t.Fatal(err)
@@ -50,6 +54,28 @@ func TestAPIV0ChatBot(t *testing.T) {
t.Errorf("dupe generation: %s", stderrStash.Bytes())
stderrStash.Reset()
}
resp3 := httpDo(t, http.MethodGet, "/api/v0/chatbot", "")
var result struct {
Messages string
ReversePrompt string
Prompt string
}
if err := json.NewDecoder(resp3.Body).Decode(&result); err != nil {
t.Fatal(err)
}
if !bytes.Contains([]byte(result.Messages), got) {
t.Errorf("forgot got: %s does not contain %s", result.Messages, got)
}
if !bytes.Contains([]byte(result.Messages), got2) {
t.Errorf("forgot got2: %s does not contain %s", result.Messages, got2)
}
if result.Prompt != body().Get("Prompt") {
t.Error(result.Prompt)
}
if result.ReversePrompt != body().Get("ReversePrompt") {
t.Error(result.ReversePrompt)
}
})
t.Run("post over post", func(t *testing.T) {

View File

@@ -1,8 +1,94 @@
<html>
<header>
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/water.css@2/out/dark.css">
<script>
function loadStream() {
http("GET", "/api/v0/chatbot", (body, status) => {
var data = JSON.parse(body)
document.getElementById("stream-log").innerHTML = data["Messages"]
document.getElementById("stream-prompt").innerHTML = data["Prompt"]
document.getElementById("stream-reverse-prompt").innerHTML = data["ReversePrompt"]
}, null)
}
function appendStreamLog(message) {
if (!message)
return
document.getElementById("stream-log").innerHTML += "\n" + message
var textarea = document.getElementById("stream-log")
textarea.scrollTop = textarea.scrollHeight
}
function startStream(newPrompt) {
body = new URLSearchParams(new FormData(newPrompt)).toString()
http("POST", "/api/v0/chatbot", (body, status) => {
if (status != 200) {
log(body)
return
}
}, body)
}
function pushStream(newMessage) {
body = new URLSearchParams(new FormData(newMessage)).toString()
for(var e of newMessage.elements)
if(!e.attributes.readonly)
e.disabled = true
appendStreamLog(document.getElementsByName("Message")[0].value)
http("PUT", "/api/v0/chatbot", (body, status) => {
for(var e of newMessage.elements)
e.disabled = false
if (status != 200) {
log(body)
return
}
appendStreamLog(body)
document.getElementsByName("Message")[0].value = ""
}, body)
}
function log() {
console.log(arguments)
document.getElementById("debug-log").innerHTML += "\n" + JSON.stringify(arguments)
}
function http(method, remote, callback, body) {
var xmlhttp = new XMLHttpRequest();
xmlhttp.onreadystatechange = function() {
if (xmlhttp.readyState == XMLHttpRequest.DONE) {
callback(xmlhttp.responseText, xmlhttp.status)
}
};
xmlhttp.open(method, remote, true);
if (typeof body == "undefined") {
body = null
}
xmlhttp.send(body);
}
</script>
</header>
<body>
<body onload="loadStream()">
TODO ?more=true
<details>
<summary>Set up a new session</summary>
<form id="prompt" onsubmit="startStream(this); return false;">
<textarea id="stream-prompt" name="Prompt"></textarea>
<input type="text" id="stream-reverse-prompt" name="ReversePrompt"/>
<button type="submit">Start with prompt</button>
</form>
</details>
<details open=true>
<summary>Use your session</summary>
<form id="stream" onsubmit="pushStream(this); return false;">
<textarea id="stream-log" readonly=true></textarea>
<div style="display: flex; flex-direction: row;">
<input style="flex-grow: 1;" type="text" name="Message"/>
<button type="submit">Send</button>
</div>
</form>
</details>
<pre id="debug-log">
</pre>
</body>
<footer>
</footer>

BIN
vicuna-tools.d/vicuna-tools.d Executable file

Binary file not shown.