From 6d1ebe926531c2910638d22ea449ea3541b382c6 Mon Sep 17 00:00:00 2001 From: Bel LaPointe Date: Sun, 30 Sep 2018 18:59:33 -0600 Subject: [PATCH] go test pls --- main.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++++ main_test.go | 33 ++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 main_test.go diff --git a/main.go b/main.go index cdd9fae..37ab55b 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,18 @@ package main import ( + "encoding/binary" + "encoding/csv" "flag" + "local1/logger" "log" + "net" "net/http" "os" "path" "path/filepath" + "sort" + "strconv" "strings" "google.golang.org/appengine" @@ -22,7 +28,14 @@ func main() { directory := flag.String("d", path.Join(exePath, "public"), "the directory of static file to host") flag.Parse() + getIPs() + http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + remoteIP := strings.Split(r.RemoteAddr, ":")[0] + if notUSA(remoteIP) { + logger.Log(remoteIP, "NOT USA") + return + } if r.URL.Scheme == "http" || strings.HasPrefix(r.Host, "http:") { r.URL.Scheme = "https" http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect) @@ -35,3 +48,87 @@ func main() { appengine.Main() } + +var globalIPs []uint64 + +func getIPs() []uint64 { + if globalIPs != nil { + return globalIPs + } + + globalIPs = make([]uint64, 0) + + f, err := os.Open("private/ipv4.csv") + if err != nil { + panic(err) + } + ipv4r := csv.NewReader(f) + ipv4all, err := ipv4r.ReadAll() + if err != nil { + panic(err) + } + logger.Log("IPV4s") + for i := range ipv4all { + if ipv4all[i][2] == "US" || ipv4all[i][2] == "-" { + start, err := strconv.ParseUint(ipv4all[i][0], 10, 64) + if err != nil { + panic(err) + } + stop, err := strconv.ParseUint(ipv4all[i][1], 10, 64) + if err != nil { + panic(err) + } + globalIPs = append(globalIPs, uint64(start), uint64(stop)) + } + } + + g, err := os.Open("private/ipv6.csv") + if err != nil { + panic(err) + } + + ipv6r := csv.NewReader(g) + ipv6all, err := ipv6r.ReadAll() + if err != nil { + panic(err) + } + logger.Log("IPV6s") + for i := range ipv6all { + if ipv6all[i][2] == "US" { + start, err := strconv.ParseUint(ipv6all[i][0], 10, 64) + if err != nil { + continue + } + stop, err := strconv.ParseUint(ipv6all[i][1], 10, 64) + if err != nil { + continue + } + globalIPs = append(globalIPs, uint64(start), uint64(stop)) + } + } + sort.Slice(globalIPs, func(i, j int) bool { + return globalIPs[i] > globalIPs[j] + }) + return globalIPs +} + +func notUSA(ip string) bool { + dec := toDec(ip) + ips := getIPs() + n := sort.Search(len(ips), func(i int) bool { + return ips[i] > dec + }) + logger.Log(ip, dec, ips[0], n, len(ips)) + return n%2 == 1 +} + +func toDec(ips string) uint64 { + ip := net.ParseIP(ips) + if ip == nil { + return uint64(0) + } + if len(ip) == 16 { + return uint64(binary.BigEndian.Uint32(ip[12:16])) + } + return uint64(binary.BigEndian.Uint32(ip)) +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..a2db503 --- /dev/null +++ b/main_test.go @@ -0,0 +1,33 @@ +package main + +import "testing" + +func Test_notUSA(t *testing.T) { + cases := []struct { + ip string + ok bool + }{ + { + ip: "8.8.8.8", + ok: false, + }, + { + ip: "192.168.0.86", + ok: false, + }, + { + ip: "127.0.0.1", + ok: false, + }, + { + ip: "223.144.0.0", + ok: true, + }, + } + + for _, c := range cases { + if notUSA(c.ip) != c.ok { + t.Errorf("WRONG VALIDATION for %v, expected %v", c.ip, c.ok) + } + } +}