rproxy3/server/server.go

114 lines
2.5 KiB
Go

package server
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"golang.org/x/time/rate"
	"local/rproxy3/config"
	"local/rproxy3/storage"
	"local/rproxy3/storage/packable"
)
// nsRouting is the storage namespace under which Route saves its
// source -> destination URL entries.
const nsRouting = "routing"
// listenerScheme enumerates the protocols the server can listen with.
type listenerScheme int

// Supported listener schemes. schemeHTTP is the zero value.
const (
	schemeHTTP listenerScheme = iota
	schemeHTTPS
)

// String returns the lowercase protocol name for ls ("http" or "https").
// Any value outside the declared constants yields the empty string.
func (ls listenerScheme) String() string {
	if ls == schemeHTTP {
		return "http"
	}
	if ls == schemeHTTPS {
		return "https"
	}
	return ""
}
// Server is the reverse-proxy HTTP handler together with the state it needs
// to listen, rate-limit, and look up routes.
type Server struct {
// db is the storage backend; Route writes its entries into it under the
// nsRouting namespace.
db storage.DB
// addr is the listen address handed to http.ListenAndServe(TLS) by Run.
addr string
// username and password are not read by any method visible in this file;
// doAuth takes its credentials from config.GetAuth instead.
// NOTE(review): confirm whether these fields are still used elsewhere.
username string
password string
// limiter is the rate limiter that Pre waits on before handling a request.
limiter *rate.Limiter
}
// Route registers dst as the destination for src.
//
// dst is parsed as a URL; a parse failure is returned as-is and nothing is
// stored. Otherwise the parsed URL is written to the server's db under the
// nsRouting namespace with src as the key, and the result of that write is
// returned.
func (s *Server) Route(src, dst string) error {
	log.Printf("Adding route %q -> %q...\n", src, dst)

	target, parseErr := url.Parse(dst)
	if parseErr == nil {
		return s.db.Set(nsRouting, src, packable.NewURL(target))
	}
	return parseErr
}
// Run starts listening on s.addr with s as the handler and blocks until the
// listener fails, returning that error.
//
// The SSL configuration is read exactly once: when config.GetSSL reports a
// certificate/key pair the server listens with TLS, otherwise with plain
// HTTP. Reading it once guarantees that the scheme that is logged and the
// certificate/key that are actually used come from the same lookup.
func (s *Server) Run() error {
	cert, key, useTLS := config.GetSSL()

	scheme := schemeHTTP
	if useTLS {
		scheme = schemeHTTPS
	}
	log.Printf("Listening for %v on %v...\n", scheme, s.addr)

	switch scheme {
	case schemeHTTP:
		return http.ListenAndServe(s.addr, s)
	case schemeHTTPS:
		return http.ListenAndServeTLS(s.addr, cert, key, s)
	}
	// Only reachable if a scheme is added without a matching case above.
	return errors.New("did not load server")
}
// doAuth wraps foo with HTTP basic authentication.
//
// If config.GetAuth reports that no credentials are configured, foo is called
// unconditionally. Otherwise the request must carry basic credentials (as
// parsed by r.BasicAuth, i.e. the Authorization header) that match the
// configured username and password. On a missing or wrong credential the
// handler answers 401 with a WWW-Authenticate challenge, logs the denial, and
// does not call foo.
func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		wantUser, wantPass, enabled := config.GetAuth()
		if !enabled {
			foo(w, r)
			return
		}

		gotUser, gotPass, present := r.BasicAuth()

		// Compare both values in constant time, and evaluate both before
		// branching, so response timing does not reveal which part of the
		// credential was wrong. (Length differences can still be observable.)
		userOK := subtle.ConstantTimeCompare([]byte(gotUser), []byte(wantUser)) == 1
		passOK := subtle.ConstantTimeCompare([]byte(gotPass), []byte(wantPass)) == 1

		if !present || !userOK || !passOK {
			// A 401 response must tell the client which scheme to use;
			// without this header browsers never prompt for credentials.
			w.Header().Set("WWW-Authenticate", `Basic realm="rproxy3"`)
			w.WriteHeader(http.StatusUnauthorized)
			log.Printf("denying proxy basic auth")
			return
		}
		foo(w, r)
	}
}
// Pre wraps foo with rate limiting followed by authentication.
//
// Each request first waits on s.limiter. The wait is bounded by a context
// derived from the request's context with a timeout of config.GetTimeout()
// seconds; that context is used only for the limiter wait and is cancelled
// when the returned handler finishes. If the wait fails the handler replies
// 429 Too Many Requests; otherwise the request is passed to s.doAuth(foo).
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		maxWait := time.Duration(config.GetTimeout()) * time.Second
		waitCtx, cancel := context.WithTimeout(r.Context(), maxWait)
		defer cancel()

		waitErr := s.limiter.Wait(waitCtx)
		if waitErr == nil {
			s.doAuth(foo)(w, r)
			return
		}
		w.WriteHeader(http.StatusTooManyRequests)
	}
}
// ServeHTTP makes *Server usable as an http.Handler (Run passes s to the
// net/http listeners). Every request goes through the Pre wrapper and then
// on to s.Proxy.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	handler := s.Pre(s.Proxy)
	handler(w, r)
}
// getProxyAuth extracts basic credentials from the request's
// Proxy-Authorization header and returns them as (username, password).
//
// It follows the same rules net/http applies to the Authorization header in
// Request.BasicAuth:
//   - the "Basic" scheme is matched case-insensitively and is required;
//   - the payload must be valid standard base64;
//   - the decoded text is split at the FIRST colon only, so a password may
//     itself contain ':'.
//
// If the header is absent or fails any of these checks, both results are
// empty strings.
func getProxyAuth(r *http.Request) (string, string) {
	const scheme = "Basic "

	header := r.Header.Get("Proxy-Authorization")
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return "", ""
	}

	decoded, err := base64.StdEncoding.DecodeString(header[len(scheme):])
	if err != nil {
		// DecodeString can return partially decoded bytes alongside the
		// error; never treat those as credentials.
		return "", ""
	}

	creds := string(decoded)
	sep := strings.IndexByte(creds, ':')
	if sep < 0 {
		return "", ""
	}
	return creds[:sep], creds[sep+1:]
}