go test pls

master
Bel LaPointe 2018-09-30 18:59:33 -06:00
parent e1c7e2bcfd
commit 6d1ebe9265
2 changed files with 130 additions and 0 deletions

97
main.go
View File

@ -1,12 +1,18 @@
package main
import (
"encoding/binary"
"encoding/csv"
"flag"
"local1/logger"
"log"
"net"
"net/http"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
"google.golang.org/appengine"
@ -22,7 +28,14 @@ func main() {
directory := flag.String("d", path.Join(exePath, "public"), "the directory of static file to host")
flag.Parse()
getIPs()
http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
remoteIP := strings.Split(r.RemoteAddr, ":")[0]
if notUSA(remoteIP) {
logger.Log(remoteIP, "NOT USA")
return
}
if r.URL.Scheme == "http" || strings.HasPrefix(r.Host, "http:") {
r.URL.Scheme = "https"
http.Redirect(w, r, r.URL.String(), http.StatusTemporaryRedirect)
@ -35,3 +48,87 @@ func main() {
appengine.Main()
}
var globalIPs []uint64
func getIPs() []uint64 {
if globalIPs != nil {
return globalIPs
}
globalIPs = make([]uint64, 0)
f, err := os.Open("private/ipv4.csv")
if err != nil {
panic(err)
}
ipv4r := csv.NewReader(f)
ipv4all, err := ipv4r.ReadAll()
if err != nil {
panic(err)
}
logger.Log("IPV4s")
for i := range ipv4all {
if ipv4all[i][2] == "US" || ipv4all[i][2] == "-" {
start, err := strconv.ParseUint(ipv4all[i][0], 10, 64)
if err != nil {
panic(err)
}
stop, err := strconv.ParseUint(ipv4all[i][1], 10, 64)
if err != nil {
panic(err)
}
globalIPs = append(globalIPs, uint64(start), uint64(stop))
}
}
g, err := os.Open("private/ipv6.csv")
if err != nil {
panic(err)
}
ipv6r := csv.NewReader(g)
ipv6all, err := ipv6r.ReadAll()
if err != nil {
panic(err)
}
logger.Log("IPV6s")
for i := range ipv6all {
if ipv6all[i][2] == "US" {
start, err := strconv.ParseUint(ipv6all[i][0], 10, 64)
if err != nil {
continue
}
stop, err := strconv.ParseUint(ipv6all[i][1], 10, 64)
if err != nil {
continue
}
globalIPs = append(globalIPs, uint64(start), uint64(stop))
}
}
sort.Slice(globalIPs, func(i, j int) bool {
return globalIPs[i] > globalIPs[j]
})
return globalIPs
}
func notUSA(ip string) bool {
dec := toDec(ip)
ips := getIPs()
n := sort.Search(len(ips), func(i int) bool {
return ips[i] > dec
})
logger.Log(ip, dec, ips[0], n, len(ips))
return n%2 == 1
}
func toDec(ips string) uint64 {
ip := net.ParseIP(ips)
if ip == nil {
return uint64(0)
}
if len(ip) == 16 {
return uint64(binary.BigEndian.Uint32(ip[12:16]))
}
return uint64(binary.BigEndian.Uint32(ip))
}

33
main_test.go Normal file
View File

@ -0,0 +1,33 @@
package main
import "testing"
func Test_notUSA(t *testing.T) {
cases := []struct {
ip string
ok bool
}{
{
ip: "8.8.8.8",
ok: false,
},
{
ip: "192.168.0.86",
ok: false,
},
{
ip: "127.0.0.1",
ok: false,
},
{
ip: "223.144.0.0",
ok: true,
},
}
for _, c := range cases {
if notUSA(c.ip) != c.ok {
t.Errorf("WRONG VALIDATION for %v, expected %v", c.ip, c.ok)
}
}
}