package server import ( "context" "fmt" "net/http" "net/http/httptest" "strings" "testing" "gitea.bel.blue/local/rproxy3/config" "gitea.bel.blue/local/rproxy3/storage" "golang.org/x/time/rate" ) func TestServerStart(t *testing.T) { return // depends on etc hosts server := mockServer() p := config.Proxy{ To: "http://hello.localhost" + server.addr, } if err := server.Route("world", p); err != nil { t.Fatalf("cannot add route: %v", err) } go func() { if err := server.Run(); err != nil { t.Fatalf("err running server: %v", err) } }() r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil) if _, err := (&http.Client{}).Do(r); err != nil { t.Errorf("failed to get: %v", err) } } func mockServer() *Server { portServer := httptest.NewServer(nil) port := strings.Split(portServer.URL, ":")[2] portServer.Close() s := &Server{ db: storage.NewMap(), addr: ":" + port, limiter: rate.NewLimiter(rate.Limit(50), 50), } if err := s.Routes(); err != nil { panic(fmt.Sprintf("cannot initiate server routes; %v", err)) } return s } func TestServerRoute(t *testing.T) { server := mockServer() p := config.Proxy{ To: "http://hello.localhost" + server.addr, } if err := server.Route("world", p); err != nil { t.Fatalf("cannot add route: %v", err) } w := httptest.NewRecorder() r, _ := http.NewRequest("GET", "http://world.localhost"+server.addr, nil) r = r.WithContext(context.Background()) server.ServeHTTP(w, r) if w.Code != 502 { t.Fatalf("cannot proxy from 'world' to 'hello', status %v", w.Code) } } func TestCORS(t *testing.T) { t.Run(http.MethodOptions, func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodOptions, "/", nil) w2, did := _doCORS(w, r) w2.WriteHeader(300) if !did { t.Error("didnt do on options") } if w.Header().Get("Access-Control-Allow-Origin") != "*" { t.Error("didnt set origina") } if w.Header().Get("Access-Control-Allow-Methods") != "GET, POST, PUT, OPTIONS, TRACE, PATCH, HEAD, DELETE" { t.Error("didnt set allow methods") } }) t.Run(http.MethodGet, func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) w2, did := _doCORS(w, r) w2.Header().Set("a", "b") w2.Header().Set("Access-Control-Allow-Origin", "NO") w2.WriteHeader(300) if did { t.Error("did cors on options") } if w.Header().Get("Access-Control-Allow-Origin") != "*" { t.Error("didnt set origina") } else if len(w.Header()["Access-Control-Allow-Origin"]) != 1 { t.Error(w.Header()) } if w.Header().Get("Access-Control-Allow-Methods") != "" { t.Error("did set allow methods") } }) } func TestAssertFrom(t *testing.T) { cases := map[string]struct { from string remote string err bool }{ "empty": {}, "ipv6 localhost": { from: "::1/128", remote: "::1:12345", }, "ipv4 localhost": { from: "127.0.0.1/32", remote: "127.0.0.1:12345", }, } for name, d := range cases { c := d t.Run(name, func(t *testing.T) { err := assertFrom(c.from, c.remote) got := err != nil if got != c.err { t.Errorf("expected err=%v but got %v", c.err, err) } }) } }