diff --git a/src/config.go b/src/config.go index ce626dc..6a2ad99 100755 --- a/src/config.go +++ b/src/config.go @@ -2,6 +2,7 @@ package src import ( "bufio" + "bytes" "context" "encoding/base64" "encoding/json" @@ -17,6 +18,7 @@ import ( ) type Config struct { + ctx context.Context Listen string `json:"LISTEN"` Forwards string `json:"FORWARDS"` Hello string `json:"HELLO_B64"` @@ -25,14 +27,28 @@ type Config struct { func NewConfig(ctx context.Context) (Config, error) { config := Config{ + ctx: ctx, Listen: ":10000", Forwards: "", Hello: base64.StdEncoding.EncodeToString([]byte("*1\r\n$4\r\nping\r\n")), } - b, _ := json.Marshal(config) + + if err := config.loadEnv(); err != nil { + return config, err + } + + if err := config.setForwards(); err != nil { + return config, err + } + + return config, nil +} + +func (c *Config) loadEnv() error { + b, _ := json.Marshal(*c) var m map[string]any if err := json.Unmarshal(b, &m); err != nil { - return config, err + return err } for k := range m { if v := os.Getenv(k); v != "" { @@ -40,52 +56,62 @@ func NewConfig(ctx context.Context) (Config, error) { } } b2, _ := json.Marshal(m) - if err := json.Unmarshal(b2, &config); err != nil { - return config, err - } + return json.Unmarshal(b2, c) +} - hello, err := base64.StdEncoding.DecodeString(config.Hello) +func (c *Config) setForwards() error { + hello, err := base64.StdEncoding.DecodeString(c.Hello) if err != nil { - return config, err + return err } - forwards := strings.Split(config.Forwards, ",") - forwards = slices.DeleteFunc(forwards, func(s string) bool { return s == "" }) - if len(forwards) == 0 { - return config, fmt.Errorf("at least one $FORWARD required") + addresses := strings.Split(c.Forwards, ",") + addresses = slices.DeleteFunc(addresses, func(s string) bool { return s == "" }) + if len(addresses) == 0 { + return fmt.Errorf("at least one $FORWARD required") } - config.forwards = make([]*sync.Pool, len(forwards)) - for i := range forwards { - forward := forwards[i] - config.forwards[i] = &sync.Pool{ + + c.forwards = make([]*sync.Pool, len(addresses)) + for i := range addresses { + address := addresses[i] + c.forwards[i] = &sync.Pool{ New: func() any { - if ctx.Err() != nil { - return nil - } - log.Printf("dialing %q with %q", forward, hello) - v, err := (&net.Dialer{}).DialContext(ctx, "tcp", forward) + conn, err := c.dial(address, hello) if err != nil { - log.Printf("! failed dial %q: %v", forward, err) - return nil + log.Printf("! failed to dial %q: %v", address, err) } - if _, err := v.Write([]byte(hello)); err != nil { - log.Printf("! failed write hello %q: %v", forward, err) - v.Close() - return nil - } - if raw, err := readMessage(bufio.NewReader(v)); err != nil { - log.Printf("! failed read hello %q: %v", forward, err) - v.Close() - return nil - } else { - log.Printf("dial reply: %q", raw) - } - return v + return conn }, } } + return nil +} - return config, nil +func (c *Config) dial(address string, hello []byte) (net.Conn, error) { + if err := c.ctx.Err(); err != nil { + return nil, err + } + + log.Printf("dialing %q with %q", address, hello) + v, err := (&net.Dialer{}).DialContext(c.ctx, "tcp", address) + if err != nil { + return nil, err + } + + if _, err := v.Write([]byte(hello)); err != nil { + v.Close() + return nil, err + } + + if raw, err := readMessage(bufio.NewReader(v)); err != nil { + v.Close() + return nil, err + } else if bytes.HasPrefix(raw, []byte("-")) { + v.Close() + return nil, fmt.Errorf("failed hello: %s", raw) + } + + return v, nil } func (c Config) Close() {