live-studio-audience/cmd/server/main.go

310 lines
6.7 KiB
Go

package main
import (
"context"
"embed"
"encoding/base64"
"encoding/json"
"flag"
"fmt"
"io/fs"
"log"
"net/http"
"os"
"os/signal"
"path"
"regexp"
"strings"
"syscall"
"golang.org/x/time/rate"
)
// Config holds the server settings that newConfig populates from
// command-line flags.
type Config struct {
// Addr is the address the HTTP server listens on (flag "addr").
Addr string
// RPS is the requests-per-second rate given to the rate limiter (flag "rps").
RPS int
// fsDB is the directory that backs the filesystem database (flag "fs-db").
fsDB string
}
// Handler is the server's http.Handler. It rate-limits and authenticates each
// request, then serves either the embedded static files or the JSON API.
type Handler struct {
// cfg is the configuration the handler was built from.
cfg Config
// limiter is waited on at the start of every request.
limiter *rate.Limiter
// db stores the questions and answers served by the API.
db DB
}
// DB is the storage the API handlers read questions from and write answers to.
type DB interface {
// GetQuestions returns every stored question.
GetQuestions() ([]Question, error)
// GetQuestion returns the question with the given question ID.
GetQuestion(string) (Question, error)
// InsertAnswer stores an answer; the arguments are the question ID, the
// user ID, and the answer itself.
InsertAnswer(string, string, Answer) error
// GetAnswers returns the answers stored for the given question ID.
GetAnswers(string) ([]Answer, error)
}
// fsDB is a DB implementation backed by plain files. Its string value is the
// root directory; each question is a JSON file named after its ID, and the
// answers to a question live in a sibling "<id>.d" directory, one file per user.
type fsDB string
// Question is a single audience question as exchanged over the JSON API and
// as stored by the DB.
type Question struct {
// ID identifies the question; fsDB overwrites it with the file name on read.
ID string
// Live and Closed are carried through as-is; nothing in this file reads them.
// NOTE(review): presumably they drive client-side display — confirm.
Live bool
Closed bool
// Text is the question text.
Text string
// Options lists the selectable options, if any.
Options []string
}
// Answer is one user's answer to a question, decoded from the POST body and
// stored as JSON.
type Answer struct {
// Text is the answer content.
Text string
}
// Session describes the caller of a request as derived from its HTTP basic
// auth credentials. The zero value means "not authenticated".
type Session struct {
User struct {
// ID is a base64 encoding of the "user:password" pair; it keys stored answers.
ID string
// Name is the basic auth user name.
Name string
}
}
// main runs the server with a root context that is cancelled on SIGINT and
// panics if run reports an error.
func main() {
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT)
	defer stop()

	err := run(ctx)
	if err == nil {
		return
	}
	panic(err)
}
// run parses the configuration and serves HTTP until ctx is cancelled or the
// server fails. It returns the first error encountered, if any.
func run(ctx context.Context) error {
	cfg, err := newConfig()
	if err == nil {
		err = runHTTP(ctx, cfg)
	}
	return err
}
// newConfig builds a Config from the process's command-line arguments.
// Flags: -addr (default ":8080"), -rps (default 100) and -fs-db
// (default "/tmp/live-audience.d"). On a parse error the partially filled
// Config is returned together with that error.
func newConfig() (Config, error) {
	var cfg Config

	// Named "flags" rather than "fs" so the io/fs package is not shadowed.
	flags := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
	flags.StringVar(&cfg.Addr, "addr", ":8080", "address to listen on")
	flags.IntVar(&cfg.RPS, "rps", 100, "requests per second to serve")
	flags.StringVar(&cfg.fsDB, "fs-db", "/tmp/live-audience.d", "api dir to serve")

	if err := flags.Parse(os.Args[1:]); err != nil {
		return cfg, err
	}
	return cfg, nil
}
// NewHandler builds a Handler from cfg: a rate limiter allowing cfg.RPS
// requests per second with a burst of 10, and an fsDB rooted at cfg.fsDB.
func (cfg Config) NewHandler() Handler {
	var h Handler
	h.cfg = cfg
	h.limiter = rate.NewLimiter(rate.Limit(cfg.RPS), 10)
	h.db = fsDB(cfg.fsDB)
	return h
}
// Empty reports whether s is the zero Session, i.e. no user is attached.
func (s Session) Empty() bool {
	var zero Session
	return s == zero
}
// runHTTP serves cfg.NewHandler() on cfg.Addr until ctx is cancelled, at
// which point the server is closed. An error from ListenAndServe is returned
// only if ctx has not been cancelled; otherwise runHTTP returns nil.
func runHTTP(ctx context.Context, cfg Config) error {
	srv := new(http.Server)
	srv.Addr = cfg.Addr
	srv.Handler = cfg.NewHandler()

	// Close the server as soon as the context ends so ListenAndServe returns.
	go func() {
		<-ctx.Done()
		srv.Close()
	}()

	log.Println("listening on", cfg.Addr)
	err := srv.ListenAndServe()
	if err == nil || ctx.Err() != nil {
		return nil
	}
	return err
}
// ServeHTTP implements http.Handler. It delegates to serveHTTP and reports
// any returned error to the client as a 500 with the error text as body.
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	err := h.serveHTTP(w, r)
	if err == nil {
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}
// serveHTTP applies rate limiting and authentication before dispatching the
// request to handle.
//
// It waits on the limiter first; if that wait fails (for example because the
// request context was cancelled) the request is answered with 429 instead of
// being processed anyway. Requests without basic auth credentials receive a
// 401 challenge. Any error returned is turned into a 500 by ServeHTTP.
func (h Handler) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	if err := h.limiter.Wait(r.Context()); err != nil {
		// The limiter could not grant a slot; do not serve the request.
		http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
		return nil
	}

	session, err := h.auth(r)
	if err != nil {
		return err
	}
	if session.Empty() {
		w.Header().Set("WWW-Authenticate", "Basic realm=xyz")
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(`IDENTIFY YOURSELF!`))
		return nil
	}
	return h.handle(session, w, r)
}
// auth derives a Session from the request's HTTP basic auth header.
//
// When no (valid) basic auth header is present it returns the zero Session
// and a nil error. Otherwise User.Name is the supplied user name and User.ID
// is a base64 encoding of "user:password". Credentials are not verified
// against anything; they only identify the caller.
//
// The ID uses the URL-safe base64 alphabet because it is later used as a file
// name: the standard alphabet can emit '/', which would split the ID into
// nested directories.
func (h Handler) auth(r *http.Request) (Session, error) {
	user, pass, ok := r.BasicAuth()
	if !ok {
		return Session{}, nil
	}

	var session Session
	session.User.ID = base64.URLEncoding.EncodeToString([]byte(user + ":" + pass))
	session.User.Name = user
	return session, nil
}
//go:embed internal/public
var _public embed.FS
var public = func() http.FileSystem {
d, err := fs.Sub(_public, "internal/public")
if err != nil {
panic(err)
}
return http.FS(d)
}()
// apiRoute pairs a compiled URL path pattern with the Handler method that
// serves matching requests.
type apiRoute struct {
	pattern *regexp.Regexp
	serve   func(Handler, Session, http.ResponseWriter, *http.Request) error
}

// apiRoutes is the API routing table. The patterns are compiled once at
// package initialisation and are mutually exclusive, so at most one matches.
var apiRoutes = []apiRoute{
	{regexp.MustCompile(`^/api/v1/questions$`), Handler.handleAPIV1Questions},
	{regexp.MustCompile(`^/api/v1/questions/[^/]*$`), Handler.handleAPIV1Question},
	{regexp.MustCompile(`^/api/v1/questions/[^/]*/answers$`), Handler.handleAPIV1QuestionsAnswers},
}

// handle routes an authenticated request.
//
// Paths outside "/api/" are served from the embedded static files with a
// short private cache lifetime. API paths are matched against apiRoutes and
// passed to the matching handler; unmatched API paths get a 404. The returned
// error is whatever the selected handler returns.
func (h Handler) handle(session Session, w http.ResponseWriter, r *http.Request) error {
	if !strings.HasPrefix(r.URL.Path, "/api/") {
		w.Header().Set("Cache-Control", "private, max-age=60")
		http.FileServer(public).ServeHTTP(w, r)
		return nil
	}

	for _, route := range apiRoutes {
		if route.pattern.MatchString(r.URL.Path) {
			return route.serve(h, session, w, r)
		}
	}

	http.NotFound(w, r)
	return nil
}
// handleAPIV1Question serves /api/v1/questions/{id}: it looks up the question
// named by the last path element and writes it as JSON.
//
// A question that does not exist yields a 404 rather than an internal server
// error. Any other lookup or encoding error is returned to the caller.
func (h Handler) handleAPIV1Question(session Session, w http.ResponseWriter, r *http.Request) error {
	qid := path.Base(r.URL.Path)
	q, err := h.db.GetQuestion(qid)
	if os.IsNotExist(err) {
		// The store reports a missing file; that is the client's problem, not ours.
		http.NotFound(w, r)
		return nil
	}
	if err != nil {
		return err
	}
	return json.NewEncoder(w).Encode(q)
}
// handleAPIV1Questions serves /api/v1/questions by writing every stored
// question as a JSON array. Lookup and encoding errors are returned.
func (h Handler) handleAPIV1Questions(session Session, w http.ResponseWriter, r *http.Request) error {
	questions, err := h.db.GetQuestions()
	if err == nil {
		err = json.NewEncoder(w).Encode(questions)
	}
	return err
}
// handleAPIV1QuestionsAnswers serves /api/v1/questions/{id}/answers,
// dispatching GET and POST to their dedicated handlers. Any other method is
// answered with a 404.
func (h Handler) handleAPIV1QuestionsAnswers(session Session, w http.ResponseWriter, r *http.Request) error {
	if r.Method == http.MethodGet {
		return h.handleAPIV1QuestionsAnswersGet(session, w, r)
	}
	if r.Method == http.MethodPost {
		return h.handleAPIV1QuestionsAnswersPost(session, w, r)
	}

	http.NotFound(w, r)
	return nil
}
// handleAPIV1QuestionsAnswersGet writes the answers stored for the question
// named by the second-to-last path element as JSON. Lookup and encoding
// errors are returned.
func (h Handler) handleAPIV1QuestionsAnswersGet(session Session, w http.ResponseWriter, r *http.Request) error {
	// The path is .../questions/{id}/answers, so the ID is the parent's base name.
	parent := path.Dir(r.URL.Path)
	qid := path.Base(parent)

	answers, err := h.db.GetAnswers(qid)
	if err == nil {
		err = json.NewEncoder(w).Encode(answers)
	}
	return err
}
// handleAPIV1QuestionsAnswersPost stores the JSON answer in the request body
// for the question named by the second-to-last path element, keyed by the
// session's user ID.
//
// A body that cannot be decoded as an Answer is a client error and is
// answered with 400. Storage errors are returned to the caller.
func (h Handler) handleAPIV1QuestionsAnswersPost(session Session, w http.ResponseWriter, r *http.Request) error {
	qid := path.Base(path.Dir(r.URL.Path))
	uid := session.User.ID

	var a Answer
	if err := json.NewDecoder(r.Body).Decode(&a); err != nil {
		http.Error(w, fmt.Sprintf("failed to read answer: %v", err), http.StatusBadRequest)
		return nil
	}
	return h.db.InsertAnswer(qid, uid, a)
}
// GetQuestion reads the question stored in the file named qid under the
// database root and decodes it from JSON. The returned question's ID is
// always set to qid, regardless of what the file contains.
//
// Read errors are returned unwrapped, so callers can test them with
// os.IsNotExist. Parse errors name the offending file path, not its contents.
func (db fsDB) GetQuestion(qid string) (Question, error) {
	p := db.path(qid)
	b, err := os.ReadFile(p)
	if err != nil {
		return Question{}, err
	}

	var q Question
	if err := json.Unmarshal(b, &q); err != nil {
		return Question{}, fmt.Errorf("failed to parse %s as question: %w", p, err)
	}
	q.ID = qid
	return q, nil
}
// InsertAnswer stores a as JSON in the file "<qid>.d/<uid>" under the
// database root, creating the directory if needed.
//
// The first answer per user wins: if the file already exists the call is a
// no-op and returns nil. The file is created exclusively, so two concurrent
// submissions by the same user cannot both write. Failures to create the
// directory or to create or write the file are returned.
func (db fsDB) InsertAnswer(qid, uid string, a Answer) error {
	b, err := json.Marshal(a)
	if err != nil {
		return err
	}

	p := path.Join(db.path(qid)+".d", uid)
	if err := os.MkdirAll(path.Dir(p), os.ModePerm); err != nil {
		return err
	}

	// O_EXCL makes "create only if absent" a single atomic step.
	f, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_EXCL, os.ModePerm)
	if os.IsExist(err) {
		return nil
	}
	if err != nil {
		return err
	}

	if _, err := f.Write(b); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}
// GetQuestions returns every question stored directly in the database root.
// Directories and entries whose name starts with "." are skipped. The result
// is a non-nil (possibly empty) slice; the first read or parse failure aborts
// the listing and is returned.
func (db fsDB) GetQuestions() ([]Question, error) {
	entries, err := os.ReadDir(db.path(""))
	if err != nil {
		return nil, err
	}

	questions := []Question{}
	for i := range entries {
		name := path.Base(entries[i].Name())
		hidden := strings.HasPrefix(name, ".")
		if hidden || entries[i].IsDir() {
			continue
		}

		question, err := db.GetQuestion(name)
		if err != nil {
			return nil, err
		}
		questions = append(questions, question)
	}
	return questions, nil
}
// GetAnswers returns the answers stored in the "<qid>.d" directory under the
// database root, one per file.
//
// The result is always a non-nil slice on success, including when the
// directory does not exist yet, so it encodes to a JSON array either way.
// Sub-directories and entries whose name starts with "." are skipped. The
// first read or parse failure aborts the listing and is returned.
func (db fsDB) GetAnswers(qid string) ([]Answer, error) {
	dir := db.path(qid) + ".d"
	entries, err := os.ReadDir(dir)
	if os.IsNotExist(err) {
		// Nobody has answered yet.
		return []Answer{}, nil
	}
	if err != nil {
		return nil, err
	}

	results := make([]Answer, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || strings.HasPrefix(entry.Name(), ".") {
			continue
		}

		file := path.Join(dir, entry.Name())
		b, err := os.ReadFile(file)
		if err != nil {
			return nil, err
		}

		var a Answer
		if err := json.Unmarshal(b, &a); err != nil {
			return nil, fmt.Errorf("failed to parse %s as answer: %w", file, err)
		}
		results = append(results, a)
	}
	return results, nil
}
// path returns the location of q inside the database root directory; an
// empty q yields the root itself.
func (db fsDB) path(q string) string {
	root := string(db)
	return path.Join(root, q)
}