From 24e0df9e93b2bfbc4d7ec54f55a4666c46a89d8f Mon Sep 17 00:00:00 2001 From: bel Date: Sun, 22 Mar 2026 20:32:57 -0600 Subject: [PATCH] accept -n for loops to avg --- main.go | 64 +++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/main.go b/main.go index be19446..e6e7899 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "encoding/json" + "flag" "fmt" "html/template" + "math" "os" "os/signal" "strings" @@ -17,13 +19,18 @@ import ( func main() { ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT) defer can() - if err := run(ctx, os.Args[1]); err != nil && ctx.Err() == nil { + if err := run(ctx); err != nil && ctx.Err() == nil { panic(err) } } -func run(ctx context.Context, city string) error { - city = strings.Title(city) +func run(ctx context.Context) error { + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + n := fs.Int("n", 1, "loops") + if err := fs.Parse(os.Args[1:]); err != nil { + return err + } + city := strings.Title(fs.Args()[0]) client, err := genai.NewClient(ctx, nil) if err != nil { @@ -94,30 +101,43 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for } //panic(w.String()) - result, err := client.Models.GenerateContent( - ctx, - "gemini-2.5-flash", - genai.Text(w.String()), - nil, - ) - if err != nil { - return err - } - text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}" + results := make([][]float64, len(descriptors)) + for i := 0; i < *n; i++ { + result, err := client.Models.GenerateContent( + ctx, + "gemini-2.5-flash", + genai.Text(w.String()), + nil, + ) + if err != nil { + return err + } + text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}" - var m map[string]any - if err := json.Unmarshal([]byte(text), &m); err != nil { - return err - } - for _, d := range descriptors { - if _, ok := m[d]; !ok { - return fmt.Errorf("desc %q missing", d) + var m map[string]any + if err := json.Unmarshal([]byte(text), &m); err != nil { + return err + } + for j, d := range descriptors { + if v, ok := m[d]; ok { + k := fmt.Sprint(v) + if k == "1" { + results[j] = append(results[j], 1.0) + } else { + results[j] = append(results[j], 0.0) + } + } } } //log.Println(text) - for _, d := range descriptors { - fmt.Println(m[d]) + for i := range descriptors { + sum := 0.0 + for j := range results[i] { + sum += results[i][j] + } + avg := sum / float64(len(results[i])) + fmt.Println(math.Round(avg)) } return ctx.Err()