114 lines
2.5 KiB
Go
114 lines
2.5 KiB
Go
package server
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"local/rproxy3/config"
	"local/rproxy3/storage"
	"local/rproxy3/storage/packable"

	"golang.org/x/time/rate"
)
|
|
|
|
// nsRouting is the storage namespace under which Route persists
// src -> destination URL mappings in the server's DB.
const nsRouting = "routing"
|
|
|
|
// listenerScheme identifies which protocol Run serves on.
type listenerScheme int

// Listener schemes; schemeHTTP is the zero value.
const (
	schemeHTTP listenerScheme = iota
	schemeHTTPS
)

// String returns the lowercase URL scheme name ("http" or "https") for a
// known listenerScheme, and the empty string for any other value.
func (ls listenerScheme) String() string {
	names := [...]string{
		schemeHTTP:  "http",
		schemeHTTPS: "https",
	}
	if ls < 0 || int(ls) >= len(names) {
		return ""
	}
	return names[ls]
}
|
|
|
|
type Server struct {
|
|
db storage.DB
|
|
addr string
|
|
username string
|
|
password string
|
|
limiter *rate.Limiter
|
|
}
|
|
|
|
func (s *Server) Route(src, dst string) error {
|
|
log.Printf("Adding route %q -> %q...\n", src, dst)
|
|
u, err := url.Parse(dst)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.db.Set(nsRouting, src, packable.NewURL(u))
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
scheme := schemeHTTP
|
|
if _, _, ok := config.GetSSL(); ok {
|
|
scheme = schemeHTTPS
|
|
}
|
|
log.Printf("Listening for %v on %v...\n", scheme, s.addr)
|
|
switch scheme {
|
|
case schemeHTTP:
|
|
return http.ListenAndServe(s.addr, s)
|
|
case schemeHTTPS:
|
|
c, k, _ := config.GetSSL()
|
|
return http.ListenAndServeTLS(s.addr, c, k, s)
|
|
}
|
|
return errors.New("did not load server")
|
|
}
|
|
|
|
func (s *Server) doAuth(foo http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
rusr, rpwd, ok := config.GetAuth()
|
|
if ok {
|
|
//usr, pwd := getProxyAuth(r)
|
|
usr, pwd, ok := r.BasicAuth()
|
|
if !ok || rusr != usr || rpwd != pwd {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
log.Printf("denying proxy basic auth")
|
|
return
|
|
}
|
|
}
|
|
foo(w, r)
|
|
}
|
|
}
|
|
|
|
func (s *Server) Pre(foo http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, can := context.WithTimeout(r.Context(), time.Second*time.Duration(config.GetTimeout()))
|
|
defer can()
|
|
if err := s.limiter.Wait(ctx); err != nil {
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
s.doAuth(foo)(w, r)
|
|
}
|
|
}
|
|
|
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
s.Pre(s.Proxy)(w, r)
|
|
}
|
|
|
|
// getProxyAuth extracts HTTP Basic credentials from the request's
// Proxy-Authorization header and returns (username, password).
//
// The header must use the "Basic" scheme, matched case-insensitively, followed
// by a base64 (standard encoding) "user:password" string. Only the first
// colon separates the parts, so a password may itself contain colons, as
// RFC 7617 allows. Both results are empty in four cases: the header is
// missing, it uses another scheme, it is not valid base64, or the decoded
// value has no colon.
func getProxyAuth(r *http.Request) (string, string) {
	const prefix = "Basic "
	header := r.Header.Get("Proxy-Authorization")
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return "", ""
	}
	decoded, err := base64.StdEncoding.DecodeString(header[len(prefix):])
	if err != nil {
		// Never fall back to a partially decoded value.
		return "", ""
	}
	creds := string(decoded)
	sep := strings.IndexByte(creds, ':')
	if sep < 0 {
		return "", ""
	}
	return creds[:sep], creds[sep+1:]
}
|