package main import ( "context" "embed" "encoding/base64" "encoding/json" "flag" "fmt" "io/fs" "log" "net/http" "os" "os/signal" "path" "regexp" "strings" "syscall" "gitea.inhome.blapointe.com/local/gziphttp" "golang.org/x/time/rate" ) type Config struct { Addr string RPS int fsDB string } type Handler struct { cfg Config limiter *rate.Limiter db DB } type DB interface { GetQuestions() ([]Question, error) GetQuestion(string) (Question, error) InsertAnswer(string, string, Answer) error GetAnswers(string) ([]Answer, error) } type fsDB string type Question struct { ID string Live bool Closed bool Text string Options []string } type Answer struct { Text string } type Session struct { User struct { ID string Name string } } func main() { ctx, can := signal.NotifyContext(context.Background(), syscall.SIGINT) defer can() if err := run(ctx); err != nil { panic(err) } } func run(ctx context.Context) error { cfg, err := newConfig() if err != nil { return err } return runHTTP(ctx, cfg) } func newConfig() (Config, error) { cfg := Config{} fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) fs.StringVar(&cfg.Addr, "addr", ":8080", "address to listen on") fs.IntVar(&cfg.RPS, "rps", 100, "requests per second to serve") fs.StringVar(&cfg.fsDB, "fs-db", "/tmp/live-audience.d", "api dir to serve") err := fs.Parse(os.Args[1:]) return cfg, err } func (cfg Config) NewHandler() Handler { return Handler{ cfg: cfg, limiter: rate.NewLimiter(rate.Limit(cfg.RPS), 10), db: fsDB(cfg.fsDB), } } func (s Session) Empty() bool { return s == (Session{}) } func runHTTP(ctx context.Context, cfg Config) error { server := &http.Server{ Addr: cfg.Addr, Handler: cfg.NewHandler(), } go func() { <-ctx.Done() server.Close() }() log.Println("listening on", cfg.Addr) if err := server.ListenAndServe(); err != nil && ctx.Err() == nil { return err } return nil } func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := h.serveHTTP(w, r); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } func (h Handler) serveHTTP(w http.ResponseWriter, r *http.Request) error { h.limiter.Wait(r.Context()) session, err := h.auth(r) if err != nil { return err } if session.Empty() { w.Header().Set("WWW-Authenticate", "Basic realm=xyz") w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`IDENTIFY YOURSELF!`)) return nil } return h.handle(session, w, r) } func (h Handler) auth(r *http.Request) (Session, error) { user, pass, ok := r.BasicAuth() if !ok { return Session{}, nil } session := Session{} session.User.ID = base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", user, pass))) session.User.Name = user return session, nil } //go:embed internal/public var _public embed.FS var public = func() http.FileSystem { d, err := fs.Sub(_public, "internal/public") if err != nil { panic(err) } return http.FS(d) }() func (h Handler) handle(session Session, w http.ResponseWriter, r *http.Request) error { if !strings.HasPrefix(r.URL.Path, "/api/") { if gziphttp.Can(r) { w = gziphttp.New(w) } w.Header().Set("Cache-Control", "private, max-age=60") http.FileServer(public).ServeHTTP(w, r) return nil } handlers := map[string]func(Session, http.ResponseWriter, *http.Request) error{ `^/api/v1/questions$`: h.handleAPIV1Questions, `^/api/v1/questions/[^/]*$`: h.handleAPIV1Question, `^/api/v1/questions/[^/]*/answers$`: h.handleAPIV1QuestionsAnswers, } for k, v := range handlers { if regexp.MustCompile(k).MatchString(r.URL.Path) { return v(session, w, r) } } http.NotFound(w, r) return nil } func (h Handler) handleAPIV1Question(session Session, w http.ResponseWriter, r *http.Request) error { qid := path.Base(r.URL.Path) q, err := h.db.GetQuestion(qid) if err != nil { return err } return json.NewEncoder(w).Encode(q) } func (h Handler) handleAPIV1Questions(session Session, w http.ResponseWriter, r *http.Request) error { qs, err := h.db.GetQuestions() if err != nil { return err } return json.NewEncoder(w).Encode(qs) } func (h Handler) handleAPIV1QuestionsAnswers(session Session, w http.ResponseWriter, r *http.Request) error { switch r.Method { case http.MethodGet: return h.handleAPIV1QuestionsAnswersGet(session, w, r) case http.MethodPost: return h.handleAPIV1QuestionsAnswersPost(session, w, r) } http.NotFound(w, r) return nil } func (h Handler) handleAPIV1QuestionsAnswersGet(session Session, w http.ResponseWriter, r *http.Request) error { qid := path.Base(path.Dir(r.URL.Path)) as, err := h.db.GetAnswers(qid) if err != nil { return err } return json.NewEncoder(w).Encode(as) } func (h Handler) handleAPIV1QuestionsAnswersPost(session Session, w http.ResponseWriter, r *http.Request) error { qid := path.Base(path.Dir(r.URL.Path)) uid := session.User.ID var a Answer if err := json.NewDecoder(r.Body).Decode(&a); err != nil { return fmt.Errorf("failed to read answer: %w", err) } return h.db.InsertAnswer(qid, uid, a) } func (db fsDB) GetQuestion(qid string) (Question, error) { p := db.path(qid) b, err := os.ReadFile(p) if err != nil { return Question{}, err } var q Question if err := json.Unmarshal(b, &q); err != nil { return Question{}, fmt.Errorf("failed to parse %s as question: %w", b, err) } q.ID = qid return q, nil } func (db fsDB) InsertAnswer(qid, uid string, a Answer) error { p := path.Join(db.path(qid)+".d", uid) os.MkdirAll(path.Dir(p), os.ModePerm) b, err := json.Marshal(a) if err != nil { return err } if _, err := os.Stat(p); !os.IsNotExist(err) { return nil } return os.WriteFile(p, b, os.ModePerm) } func (db fsDB) GetQuestions() ([]Question, error) { p := db.path("") entries, err := os.ReadDir(p) if err != nil { return nil, err } results := []Question{} for _, entry := range entries { if strings.HasPrefix(path.Base(entry.Name()), ".") { continue } qid := path.Base(entry.Name()) q, err := db.GetQuestion(qid) if err != nil { return nil, err } results = append(results, q) } return results, nil } func (db fsDB) GetAnswers(qid string) ([]Answer, error) { p := db.path(qid) + ".d" entries, err := os.ReadDir(p) if os.IsNotExist(err) { return nil, nil } if err != nil { return nil, err } results := []Answer{} for _, entry := range entries { if strings.HasPrefix(path.Base(entry.Name()), ".") { continue } b, err := os.ReadFile(path.Join(p, entry.Name())) if err != nil { return nil, err } var a Answer if err := json.Unmarshal(b, &a); err != nil { return nil, fmt.Errorf("failed to parse %s as answer: %w", path.Join(p, entry.Name()), err) } results = append(results, a) } return results, nil } func (db fsDB) path(q string) string { return path.Join(string(db), q) }