accept -n for loops to avg

This commit is contained in:
bel
2026-03-22 20:32:57 -06:00
parent 7645785859
commit 24e0df9e93

64
main.go
View File

@@ -4,8 +4,10 @@ import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"html/template"
"math"
"os"
"os/signal"
"strings"
@@ -17,13 +19,18 @@ import (
func main() {
ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT)
defer can()
if err := run(ctx, os.Args[1]); err != nil && ctx.Err() == nil {
if err := run(ctx); err != nil && ctx.Err() == nil {
panic(err)
}
}
func run(ctx context.Context, city string) error {
city = strings.Title(city)
func run(ctx context.Context) error {
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
n := fs.Int("n", 1, "loops")
if err := fs.Parse(os.Args[1:]); err != nil {
return err
}
city := strings.Title(fs.Args()[0])
client, err := genai.NewClient(ctx, nil)
if err != nil {
@@ -94,30 +101,43 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
}
//panic(w.String())
result, err := client.Models.GenerateContent(
ctx,
"gemini-2.5-flash",
genai.Text(w.String()),
nil,
)
if err != nil {
return err
}
text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}"
results := make([][]float64, len(descriptors))
for i := 0; i < *n; i++ {
result, err := client.Models.GenerateContent(
ctx,
"gemini-2.5-flash",
genai.Text(w.String()),
nil,
)
if err != nil {
return err
}
text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}"
var m map[string]any
if err := json.Unmarshal([]byte(text), &m); err != nil {
return err
}
for _, d := range descriptors {
if _, ok := m[d]; !ok {
return fmt.Errorf("desc %q missing", d)
var m map[string]any
if err := json.Unmarshal([]byte(text), &m); err != nil {
return err
}
for j, d := range descriptors {
if v, ok := m[d]; ok {
k := fmt.Sprint(v)
if k == "1" {
results[j] = append(results[j], 1.0)
} else {
results[j] = append(results[j], 0.0)
}
}
}
}
//log.Println(text)
for _, d := range descriptors {
fmt.Println(m[d])
for i := range descriptors {
sum := 0.0
for j := range results[i] {
sum += results[i][j]
}
avg := sum / float64(len(results[i]))
fmt.Println(math.Round(avg))
}
return ctx.Err()