accept -n for loops to avg

This commit is contained in:
bel
2026-03-22 20:32:57 -06:00
parent 7645785859
commit 24e0df9e93

36
main.go
View File

@@ -4,8 +4,10 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"flag"
"fmt" "fmt"
"html/template" "html/template"
"math"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
@@ -17,13 +19,18 @@ import (
func main() { func main() {
ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT) ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT)
defer can() defer can()
if err := run(ctx, os.Args[1]); err != nil && ctx.Err() == nil { if err := run(ctx); err != nil && ctx.Err() == nil {
panic(err) panic(err)
} }
} }
func run(ctx context.Context, city string) error { func run(ctx context.Context) error {
city = strings.Title(city) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
n := fs.Int("n", 1, "loops")
if err := fs.Parse(os.Args[1:]); err != nil {
return err
}
city := strings.Title(fs.Args()[0])
client, err := genai.NewClient(ctx, nil) client, err := genai.NewClient(ctx, nil)
if err != nil { if err != nil {
@@ -94,6 +101,8 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
} }
//panic(w.String()) //panic(w.String())
results := make([][]float64, len(descriptors))
for i := 0; i < *n; i++ {
result, err := client.Models.GenerateContent( result, err := client.Models.GenerateContent(
ctx, ctx,
"gemini-2.5-flash", "gemini-2.5-flash",
@@ -109,15 +118,26 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
if err := json.Unmarshal([]byte(text), &m); err != nil { if err := json.Unmarshal([]byte(text), &m); err != nil {
return err return err
} }
for _, d := range descriptors { for j, d := range descriptors {
if _, ok := m[d]; !ok { if v, ok := m[d]; ok {
return fmt.Errorf("desc %q missing", d) k := fmt.Sprint(v)
if k == "1" {
results[j] = append(results[j], 1.0)
} else {
results[j] = append(results[j], 0.0)
}
}
} }
} }
//log.Println(text) //log.Println(text)
for _, d := range descriptors { for i := range descriptors {
fmt.Println(m[d]) sum := 0.0
for j := range results[i] {
sum += results[i][j]
}
avg := sum / float64(len(results[i]))
fmt.Println(math.Round(avg))
} }
return ctx.Err() return ctx.Err()