accept -n for loops to avg
This commit is contained in:
36
main.go
36
main.go
@@ -4,8 +4,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,13 +19,18 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT)
|
ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT)
|
||||||
defer can()
|
defer can()
|
||||||
if err := run(ctx, os.Args[1]); err != nil && ctx.Err() == nil {
|
if err := run(ctx); err != nil && ctx.Err() == nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func run(ctx context.Context, city string) error {
|
func run(ctx context.Context) error {
|
||||||
city = strings.Title(city)
|
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
|
||||||
|
n := fs.Int("n", 1, "loops")
|
||||||
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
city := strings.Title(fs.Args()[0])
|
||||||
|
|
||||||
client, err := genai.NewClient(ctx, nil)
|
client, err := genai.NewClient(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,6 +101,8 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
|
|||||||
}
|
}
|
||||||
//panic(w.String())
|
//panic(w.String())
|
||||||
|
|
||||||
|
results := make([][]float64, len(descriptors))
|
||||||
|
for i := 0; i < *n; i++ {
|
||||||
result, err := client.Models.GenerateContent(
|
result, err := client.Models.GenerateContent(
|
||||||
ctx,
|
ctx,
|
||||||
"gemini-2.5-flash",
|
"gemini-2.5-flash",
|
||||||
@@ -109,15 +118,26 @@ For the city {{.City}}, answer the following descriptors '1' for true or '0' for
|
|||||||
if err := json.Unmarshal([]byte(text), &m); err != nil {
|
if err := json.Unmarshal([]byte(text), &m); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, d := range descriptors {
|
for j, d := range descriptors {
|
||||||
if _, ok := m[d]; !ok {
|
if v, ok := m[d]; ok {
|
||||||
return fmt.Errorf("desc %q missing", d)
|
k := fmt.Sprint(v)
|
||||||
|
if k == "1" {
|
||||||
|
results[j] = append(results[j], 1.0)
|
||||||
|
} else {
|
||||||
|
results[j] = append(results[j], 0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//log.Println(text)
|
//log.Println(text)
|
||||||
for _, d := range descriptors {
|
for i := range descriptors {
|
||||||
fmt.Println(m[d])
|
sum := 0.0
|
||||||
|
for j := range results[i] {
|
||||||
|
sum += results[i][j]
|
||||||
|
}
|
||||||
|
avg := sum / float64(len(results[i]))
|
||||||
|
fmt.Println(math.Round(avg))
|
||||||
}
|
}
|
||||||
|
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
|
|||||||
Reference in New Issue
Block a user