diff --git a/config/config.go b/config/config.go index 9b357a7..1c29dba 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "io/ioutil" "local/args" "os" + "path" "strings" "time" ) @@ -21,6 +22,7 @@ type Config struct { RPS int SysRPS int Delay time.Duration + StaticRoot string } func New() Config { @@ -35,6 +37,7 @@ func New() Config { as.Append(args.INT, "p", "port to listen on", 18114) as.Append(args.STRING, "fileprefix", "path prefix for file service", "/__files__") as.Append(args.STRING, "api-prefix", "path prefix for api", "api") + as.Append(args.STRING, "static-root", "path to the root of a static file server", "./public") as.Append(args.STRING, "fileroot", "path to file hosting root", "/tmp/") as.Append(args.STRING, "database", "database name to use", "db") as.Append(args.STRING, "driver", "database driver args to use, like [local/storage.Type,arg1,arg2...] or [/path/to/boltdb]", "map") @@ -63,5 +66,6 @@ func New() Config { RPS: as.GetInt("rps"), SysRPS: as.GetInt("sys-rps"), APIPrefix: strings.TrimPrefix(as.GetString("api-prefix"), "/"), + StaticRoot: path.Join(as.GetString("static-root")), } } diff --git a/server/rest.go b/server/rest.go index bce926a..58ea42b 100644 --- a/server/rest.go +++ b/server/rest.go @@ -62,6 +62,13 @@ func NewREST(g storage.RateLimitedGraph) (*REST, error) { } } + bar := rest.static + bar = rest.defend(bar) + bar = rest.delay(bar) + if err := rest.router.Add(params, bar); err != nil { + return nil, err + } + return rest, nil } diff --git a/server/static.go b/server/static.go new file mode 100644 index 0000000..21b0a57 --- /dev/null +++ b/server/static.go @@ -0,0 +1,13 @@ +package server + +import ( + "local/dndex/config" + "local/simpleserve/simpleserve" + "net/http" +) + +func (rest *REST) static(w http.ResponseWriter, r *http.Request) { + simpleserve.SetContentTypeIfMedia(w, r) + server := http.FileServer(http.Dir(config.New().StaticRoot)) + server.ServeHTTP(w, r) +} diff --git a/server/static_test.go b/server/static_test.go new file mode 100644 index 0000000..c72541f --- /dev/null +++ b/server/static_test.go @@ -0,0 +1,46 @@ +package server + +import ( + "io/ioutil" + "local/dndex/config" + "net/http" + "net/http/httptest" + "os" + "path" + "testing" +) + +func TestRESTStatic(t *testing.T) { + os.Args = []string{"a"} + d, err := ioutil.TempDir(os.TempDir(), "static*") + if err != nil { + t.Fatal(err) + } + os.Setenv("STATIC_ROOT", d) + if err := ioutil.WriteFile(path.Join(d, "index.html"), []byte("Hello, world"), os.ModePerm); err != nil { + t.Fatal(err) + } + rest, _, clean := testREST(t) + defer clean() + + t.Run("assert nonstatic OK", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, path.Join("/", config.New().APIPrefix, "version"), nil) + w := httptest.NewRecorder() + rest.router.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatal(w.Code) + } + }) + + t.Run("assert static OK", func(t *testing.T) { + r := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + rest.router.ServeHTTP(w, r) + if w.Code != http.StatusOK { + t.Fatal(w.Code) + } + if s := string(w.Body.Bytes()); s != "Hello, world" { + t.Fatal(s) + } + }) +}