jwtcurl/main.go

259 lines
6.6 KiB
Go
Executable File

package main
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"
"gitlab-app.eng.qops.net/data-store/jwt"
)
// LagReader wraps an io.Reader and delays its first Read by lag. do wraps the
// request body in one when the -lag flag is non-zero, so that writing the body
// is held back.
type LagReader struct {
// lag is waited out once, on the first Read, when it is positive; Read then
// resets it to 0 so later reads are not delayed.
lag time.Duration
// src is the wrapped reader that actually supplies the data.
src io.Reader
}
// main delegates to Main; on failure it logs the returned error and exits
// with status 1.
func main() {
	if err := Main(); err != nil {
		log.Println(err)
		os.Exit(1)
	}
}
// Command-line settings. Main registers each of these with the flag package
// and fills them in via flag.Parse.
var (
	// Repetition counts: copies of -body per request, and request executions.
	bodyRepeat, n int

	// Request target, payload, headers and JWT identity fields.
	path, host, method, body, headers, brandID, userID, issuer, audience, basicAuth, claims, kid string

	// TLS material for the client, and the JWT signing secret.
	ca, cert, key, secret string

	// Behavior and output toggles.
	needJWT, verbose, jsonPP, quiet, responseHeaders bool

	// Client timeout, and the delay before the request body is written.
	timeout, lag time.Duration
)
// Main registers and parses the command-line flags, then performs the
// configured request -n times, returning the first error encountered.
//
// With -q, everything subsequently written through os.Stderr is discarded.
// If -host carries no URL scheme, "http://" is prepended.
func Main() error {
	flag.StringVar(&method, "method", "get", "method for request")
	flag.StringVar(&kid, "kid", "prod/rems-api", "kid for request")
	flag.StringVar(&path, "path", "fieldsetdefinitions/v1/index/surveys/SV_031sm3MMOPSa8Tz/fieldsets?assumeHasPermission=true", "path for request")
	flag.StringVar(&host, "host", "data-platform.service.b1-prv.consul:8080", "host and port for request")
	flag.StringVar(&body, "body", "", "body for request")
	flag.IntVar(&bodyRepeat, "bodyrepeat", 1, "repeat body for request")
	flag.IntVar(&n, "n", 1, "how many times to execute")
	flag.StringVar(&brandID, "brand", "testencresponse", "brandID for request JWT")
	flag.StringVar(&userID, "user", "breel", "userid for request JWT")
	flag.StringVar(&basicAuth, "auth", "", "comma separated user,password for basic auth")
	flag.StringVar(&headers, "headers", "", "headers as k=v,k=v for request")
	flag.StringVar(&issuer, "issuer", "dataprocessing,responseengine,fieldset-definitions,qualtrics,objectstore,svs,monolith,ex,blixt,null,responseengine", "issuer for jwt")
	flag.StringVar(&audience, "audience", "qualtrics", "aud for jwt")
	flag.BoolVar(&needJWT, "jwt", true, "need jwt boolean")
	flag.BoolVar(&jsonPP, "jpp", true, "try json pretty print")
	flag.BoolVar(&verbose, "v", false, "is verbose")
	flag.BoolVar(&responseHeaders, "i", false, "print response headers")
	flag.BoolVar(&quiet, "q", false, "is quiet")
	flag.DurationVar(&timeout, "t", time.Second*10, "request timeout")
	flag.DurationVar(&lag, "lag", time.Second*0, "writing request lag after connecting")
	flag.StringVar(&ca, "ca", "", "ca for server")
	flag.StringVar(&cert, "cert", "", "cert for client")
	flag.StringVar(&key, "key", "", "key for client")
	flag.StringVar(&secret, "secret", "", "secret for jwt")
	flag.StringVar(&claims, "claims", "", "extra claims as k=v,k=v")
	flag.Parse()

	if quiet {
		// Send stderr output (status line, verbose dumps) to the null device.
		// The handle is intentionally kept open for the life of the process.
		devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
		if err != nil {
			return fmt.Errorf("opening %s for -q: %w", os.DevNull, err)
		}
		os.Stderr = devNull
	}

	// Default to plain HTTP only when no scheme is present. A bare "http"
	// prefix check would wrongly accept hosts like "httpbin.org:80".
	if !strings.Contains(host, "://") {
		host = "http://" + host
	}

	for i := 0; i < n; i++ {
		if err := do(); err != nil {
			return err
		}
	}
	return nil
}
// do performs one request described by the package-level flag values.
//
// The response status and elapsed time are written to stderr. Response headers
// (-i, or -v) and the response body (pretty-printed as JSON when -jpp is set
// and the body parses) are printed as well. It returns an error for invalid
// flag input, transport or read failures, and response statuses >= 400 other
// than 404.
func do() error {
	// NOTE(review): a new client (and CA/cert file read) per call prevents
	// connection reuse across -n iterations; consider building it once in Main.
	c := makeClient(timeout, ca, cert, key)

	// A -bodyrepeat below 1 means "read the request body from stdin".
	var reqBody io.Reader
	if bodyRepeat >= 1 {
		reqBody = strings.NewReader(strings.Repeat(body, bodyRepeat))
	} else {
		reqBody = os.Stdin
	}
	req, err := http.NewRequest(
		strings.ToUpper(method),
		host+"/"+strings.Trim(path, "/"),
		reqBody,
	)
	if err != nil {
		return err
	}
	if lag != 0 {
		// Hold back the first body read by -lag.
		req.Body = io.NopCloser(NewLagReader(lag, req.Body))
	} else {
		// Buffer the whole body so ContentLength is known up front; for a
		// stdin body net/http would otherwise treat the length as unknown.
		b, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return fmt.Errorf("reading request body: %w", err)
		}
		req.Body = io.NopCloser(bytes.NewReader(b))
		req.ContentLength = int64(len(b))
	}

	req.Header.Set("Content-Type", "application/json")
	if len(headers) > 0 {
		for _, pair := range strings.Split(headers, ",") {
			if pair == "" {
				// Tolerate empty entries such as a trailing comma.
				continue
			}
			// Everything after the first "=" is the value, so values may
			// themselves contain "=".
			kv := strings.Split(pair, "=")
			k := kv[0]
			v := strings.Join(kv[1:], "=")
			// Values are query-unescaped ("%2C" -> ",", "+" -> " ") so they
			// can carry characters that the k=v,k=v syntax reserves.
			vd, err := url.QueryUnescape(v)
			if err != nil {
				return fmt.Errorf("decoding value of header %q: %w", k, err)
			}
			req.Header.Add(k, vd)
		}
	}
	if req.Header.Get("brandId") == "" {
		req.Header.Set("brandId", brandID)
	}
	if needJWT {
		setJWT(verbose, req, brandID, userID, issuer, secret, claims, kid, audience)
	}
	if basicAuth != "" {
		// Accept "user,password" (as documented) or "user:password". Split
		// only once so the password may contain the separator.
		splits := strings.SplitN(basicAuth, ",", 2)
		if len(splits) == 1 {
			splits = strings.SplitN(basicAuth, ":", 2)
		}
		if len(splits) != 2 {
			return fmt.Errorf("-auth %q: expected user,password or user:password", basicAuth)
		}
		req.SetBasicAuth(splits[0], splits[1])
	}
	if verbose {
		fmt.Fprintf(os.Stderr, "%+v\n", req)
		for k, v := range req.Header {
			fmt.Fprintf(os.Stderr, "\t[%s] = (%d) %+v\n", k, len(v), v)
		}
	}

	start := time.Now()
	resp, err := c.Do(req)
	elapsed := time.Since(start)
	if err != nil {
		return fmt.Errorf("DO failed: %w", err)
	}
	// Close on every path, including a failed body read below.
	defer resp.Body.Close()
	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("READ BODY failed: %w", err)
	}

	fmt.Fprintf(os.Stderr, "(%d / %v) ", resp.StatusCode, elapsed)
	if responseHeaders || verbose {
		f := os.Stdout
		if verbose {
			f = os.Stderr
		}
		fmt.Fprintf(f, "\n")
		// Print every value, so multi-valued headers (e.g. Set-Cookie) are
		// shown in full rather than just their first value.
		for k, vs := range resp.Header {
			for _, v := range vs {
				fmt.Fprintf(f, "%s: %s\n", k, v)
			}
		}
	}
	if jsonPP {
		// json.Indent reformats without decoding. Large integers keep their
		// exact digits (no float64 round trip), and object keys keep the
		// server's order. Non-JSON bodies are printed unchanged.
		var pretty bytes.Buffer
		if err := json.Indent(&pretty, b, "", " "); err == nil {
			b = pretty.Bytes()
		}
	}
	fmt.Printf("%s\n", bytes.TrimSpace(b))
	if resp.StatusCode >= http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
		return fmt.Errorf("Status %v", resp.StatusCode)
	}
	return nil
}
// makeClient builds an HTTP client whose overall request timeout is timeout.
//
// TLS behavior:
//   - If ca is empty, server certificates are NOT verified (InsecureSkipVerify).
//   - Otherwise ca is a PEM file whose certificates become the trusted roots.
//   - If both cert and key are set, they are loaded as the client certificate.
//
// It panics if any of these files cannot be read or parsed, including a CA
// file that contains no PEM certificates.
func makeClient(timeout time.Duration, ca, cert, key string) *http.Client {
	tlsConfig := &tls.Config{}
	if ca == "" {
		tlsConfig.InsecureSkipVerify = true
	} else {
		caBytes, err := ioutil.ReadFile(ca)
		if err != nil {
			panic(err)
		}
		rootCAs := x509.NewCertPool()
		// Without this check an unparseable CA file silently yields an empty
		// pool, and every handshake later fails with "unknown authority".
		if !rootCAs.AppendCertsFromPEM(caBytes) {
			panic(fmt.Errorf("no PEM certificates found in CA file %q", ca))
		}
		tlsConfig.RootCAs = rootCAs
	}
	if cert != "" && key != "" {
		clientCert, err := tls.LoadX509KeyPair(cert, key)
		if err != nil {
			panic(err)
		}
		// tls.Config.BuildNameToCertificate is deprecated and is not needed;
		// crypto/tls selects from Certificates directly.
		tlsConfig.Certificates = []tls.Certificate{clientCert}
	}
	return &http.Client{
		Timeout:   timeout,
		Transport: &http.Transport{TLSClientConfig: tlsConfig},
	}
}
// setJWT configures a jwt.Signer from the given identity fields and calls its
// Sign method on r. The signer uses secret as the key, kid as the key-ID
// header and a 5 minute timeout. IncludeBodyHash is set, so presumably Sign
// reads r's body (verify against the jwt package).
//
// claims is a comma-separated list of k=v pairs added as custom claims.
// Entries without "=" are skipped, and everything after the first "=" is the
// value. It panics if signing fails. When verbose is set, the signer
// configuration is logged with its key cleared.
func setJWT(verbose bool, r *http.Request, brandID, userID string, issuer, secret, claims, kid, audience string) {
	signer := &jwt.Signer{
		Timeout: time.Minute * 5,
		Key:     []byte(secret),
		DefaultHeaders: jwt.Headers{
			KeyID: kid,
		},
		DefaultClaims: jwt.Claims{
			Audience: audience,
			Issuer:   issuer,
			UserID:   userID,
			BrandID:  brandID,
			Custom: map[string]interface{}{
				"IsolationPartitionID": brandID,
				"userType":             "UT_SERVERADMIN",
			},
		},
		IncludeBodyHash: true,
	}
	for _, claim := range strings.Split(claims, ",") {
		// SplitN keeps any "=" inside the value (e.g. base64 padding),
		// matching how -headers values are parsed.
		kv := strings.SplitN(claim, "=", 2)
		if len(kv) < 2 {
			continue
		}
		signer.DefaultClaims.Custom[kv[0]] = kv[1]
	}
	if err := signer.Sign(r, jwt.Claims{}); err != nil {
		panic(err)
	}
	if verbose {
		// Log a copy with the key cleared so -v does not print the secret.
		redacted := *signer
		redacted.Key = nil
		log.Printf("%+v", redacted)
	}
}
// NewLagReader returns a reader over src whose first Read is delayed by lag.
func NewLagReader(lag time.Duration, src io.Reader) *LagReader {
	lr := new(LagReader)
	lr.lag = lag
	lr.src = src
	return lr
}
// Read implements io.Reader. If a positive lag is pending, Read first blocks
// for that long and clears it. It then forwards to the wrapped reader.
func (lr *LagReader) Read(p []byte) (int, error) {
	if lr.lag <= 0 {
		return lr.src.Read(p)
	}
	time.Sleep(lr.lag)
	lr.lag = 0
	return lr.src.Read(p)
}