accept -n for loops to avg
This commit is contained in:
64
main.go
64
main.go
@@ -4,8 +4,10 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"math"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
@@ -17,13 +19,18 @@ import (
|
||||
func main() {
|
||||
ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT)
|
||||
defer can()
|
||||
if err := run(ctx, os.Args[1]); err != nil && ctx.Err() == nil {
|
||||
if err := run(ctx); err != nil && ctx.Err() == nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func run(ctx context.Context, city string) error {
|
||||
city = strings.Title(city)
|
||||
func run(ctx context.Context) error {
|
||||
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
|
||||
n := fs.Int("n", 1, "loops")
|
||||
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
city := strings.Title(fs.Args()[0])
|
||||
|
||||
client, err := genai.NewClient(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -94,30 +101,43 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
|
||||
}
|
||||
//panic(w.String())
|
||||
|
||||
result, err := client.Models.GenerateContent(
|
||||
ctx,
|
||||
"gemini-2.5-flash",
|
||||
genai.Text(w.String()),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}"
|
||||
results := make([][]float64, len(descriptors))
|
||||
for i := 0; i < *n; i++ {
|
||||
result, err := client.Models.GenerateContent(
|
||||
ctx,
|
||||
"gemini-2.5-flash",
|
||||
genai.Text(w.String()),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
text := "{" + strings.Split(strings.Split(result.Text(), "{")[1], "}")[0] + "}"
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(text), &m); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, d := range descriptors {
|
||||
if _, ok := m[d]; !ok {
|
||||
return fmt.Errorf("desc %q missing", d)
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(text), &m); err != nil {
|
||||
return err
|
||||
}
|
||||
for j, d := range descriptors {
|
||||
if v, ok := m[d]; ok {
|
||||
k := fmt.Sprint(v)
|
||||
if k == "1" {
|
||||
results[j] = append(results[j], 1.0)
|
||||
} else {
|
||||
results[j] = append(results[j], 0.0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//log.Println(text)
|
||||
for _, d := range descriptors {
|
||||
fmt.Println(m[d])
|
||||
for i := range descriptors {
|
||||
sum := 0.0
|
||||
for j := range results[i] {
|
||||
sum += results[i][j]
|
||||
}
|
||||
avg := sum / float64(len(results[i]))
|
||||
fmt.Println(math.Round(avg))
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
|
||||
Reference in New Issue
Block a user