// This program is a TCP proxy in front of an MPD httpd output. While proxied
// connections are open it asks MPD to play; once the last one ends it queues a
// pause, which also seeks back 15 seconds and saves MPD's status to a file.
// On the next start the saved song and position are restored from that file.
package main

import (
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/textproto"
	"os"
	"os/exec"
	"os/signal"
	"strings"
	"sync/atomic"
	"syscall"
	"time"
)

const (
	// mpdTimeout bounds dialing MPD and each complete command exchange, so a
	// wedged MPD cannot block a caller forever.
	mpdTimeout = 10 * time.Second

	// probeTimeout bounds each attempt by blockUntilPlaying to read from the
	// httpd stream.
	probeTimeout = 10 * time.Second
)

// probeClient is the HTTP client used to probe the httpd stream; unlike
// http.DefaultClient it has a timeout, so a stalled probe is retried.
var probeClient = &http.Client{Timeout: probeTimeout}

// mpd holds the proxy configuration and the connection counter that decides
// when MPD should play or pause.
type mpd struct {
	// conns counts running xfer goroutines (proxy starts two per accepted
	// connection); it is only accessed through sync/atomic.
	conns int32

	// port is the ":port" address the proxy listens on.
	port string

	// httpd is the ":port" address of the HTTP stream clients are piped to.
	httpd string

	// mpd is the ":port" address of MPD's command protocol.
	mpd string

	// stateChange carries requested states ("play"/"pause") from queueState
	// to queueStateChange.
	stateChange chan string

	// stateFilePath is where saveState writes MPD's status as JSON and
	// loadState reads it back.
	stateFilePath string
}

func main() {
	listenPort := flag.String("port", "7000", "port to listen on")
	httpdPort := flag.String("httpd", "8000", "port to forward to")
	mpdPort := flag.String("mpd", "6600", "port of mpd")
	flag.Parse()

	p := &mpd{
		port:          ":" + strings.TrimPrefix(*listenPort, ":"),
		httpd:         ":" + strings.TrimPrefix(*httpdPort, ":"),
		mpd:           ":" + strings.TrimPrefix(*mpdPort, ":"),
		stateChange:   make(chan string, 6),
		stateFilePath: "/mnt/goproxystate",
	}
	go p.proxy()
	go p.periodic()
	catchSignal()
	p.close()
}

// proxy pauses MPD and restores the saved playback position. It then accepts
// connections on p.port and pipes each one to p.httpd in both directions. It
// never returns, and panics if it cannot listen.
func (p *mpd) proxy() {
	go p.queueStateChange()
	p.setState("pause")
	if err := p.loadState(); err != nil {
		log.Printf("load state: %v", err)
	}
	log.Printf("xferring %v -> %v", p.port, p.httpd)
	listen, err := net.Listen("tcp", p.port)
	if err != nil {
		panic(err)
	}
	for {
		conn, err := listen.Accept()
		if err != nil {
			log.Printf("accept err: %v", err)
			continue
		}
		remote, err := net.Dial("tcp", p.httpd)
		if err != nil {
			log.Printf("dial err: %v", err)
			// Nothing will serve this client; release its socket.
			conn.Close()
			continue
		}
		log.Printf("xferring r:%v <-> c:%v", remote.RemoteAddr(), conn.RemoteAddr())
		go p.xfer(conn, remote)
		go p.xfer(remote, conn)
	}
}

// periodic saves MPD's state once an hour, forever, logging failures.
func (p *mpd) periodic() {
	ticker := time.NewTicker(time.Hour)
	defer ticker.Stop()
	for range ticker.C {
		if err := p.saveState(); err != nil {
			log.Printf("MPD failed to save state: %v", err)
		}
	}
}

// queueStateChange debounces requests sent by queueState. After receiving one,
// it keeps absorbing newer ones until a full second passes with none. It then
// applies only the most recent request via setState.
func (p *mpd) queueStateChange() {
	const minSilence = time.Second
	for {
		last := <-p.stateChange
		for settled := false; !settled; {
			select {
			case last = <-p.stateChange:
			case <-time.After(minSilence):
				settled = true
			}
		}
		p.setState(last)
	}
}

// queueState asks queueStateChange to move MPD into state. It blocks if the
// stateChange buffer is full.
func (p *mpd) queueState(state string) {
	log.Printf("queue state %q", state)
	p.stateChange <- state
}

// setState sends state as an MPD command unless MPD's reported "state" already
// equals it. After a successful "pause" it also seeks back 15 seconds and saves
// the state. Failures are logged, not returned.
func (p *mpd) setState(state string) {
	log.Printf("Set state: %v", state)
	defer log.Printf("/Set state: %v", state)
	was, err := p.getStatus()
	if err != nil {
		log.Printf("cannot get state: %v", err)
		return
	}
	log.Printf("WAS: %v", was)
	if was["state"] == state {
		log.Printf("state is already %q: %q", state, was["state"])
		return
	}
	if _, err := p.mpdCommand(state); err != nil {
		log.Printf("MPD %q returned err on cmd: %v", state, err)
		return
	}
	if state != "pause" {
		return
	}
	if _, err := p.mpdCommand("seekcur -15"); err != nil {
		log.Printf("MPD seek -15s returned err: %v", err)
		return
	}
	if err := p.saveState(); err != nil {
		log.Printf("MPD failed to save state: %v", err)
	}
}

// loadState reads the JSON status written by saveState. It then tells MPD to
// play the saved "song", pause, and seek to the saved "elapsed" position.
// It returns the first error encountered, wrapped with what was being done.
func (p *mpd) loadState() error {
	log.Printf("LOAD STATE")
	b, err := ioutil.ReadFile(p.stateFilePath)
	if err != nil {
		return fmt.Errorf("read %s: %w", p.stateFilePath, err)
	}
	var state map[string]string
	if err := json.Unmarshal(b, &state); err != nil {
		return fmt.Errorf("decode %s: %w", p.stateFilePath, err)
	}
	song, elapsed := state["song"], state["elapsed"]
	// Without both values the commands below would be sent with empty
	// arguments.
	if song == "" || elapsed == "" {
		return fmt.Errorf("%s has no song or elapsed position", p.stateFilePath)
	}
	for _, cmd := range []string{"play " + song, "pause", "seekcur " + elapsed} {
		if _, err := p.mpdCommand(cmd); err != nil {
			return fmt.Errorf("mpd %q: %w", cmd, err)
		}
	}
	return nil
}

// saveState writes MPD's current status as JSON to p.stateFilePath. It then
// runs /opt/mpd.save.sh with sh; what that script does is defined outside
// this file.
func (p *mpd) saveState() error {
	status, err := p.getStatus()
	if err != nil {
		return err
	}
	b, err := json.Marshal(status)
	if err != nil {
		return err
	}
	// The file holds plain JSON, so it needs neither execute nor world-write
	// permission (os.ModePerm would request 0777).
	if err := ioutil.WriteFile(p.stateFilePath, b, 0644); err != nil {
		return err
	}
	shPath, err := exec.LookPath("sh")
	if err != nil {
		return err
	}
	if err := exec.Command(shPath, "/opt/mpd.save.sh").Run(); err != nil {
		return fmt.Errorf("run /opt/mpd.save.sh: %w", err)
	}
	return nil
}

// mpdCommand opens a fresh connection to MPD, checks the "OK MPD" greeting and
// sends cmd. It returns the response lines that precede the final "OK".
// An "ACK" reply is returned as an error. The whole exchange is bounded by
// mpdTimeout.
func (p *mpd) mpdCommand(cmd string) ([]string, error) {
	conn, err := net.DialTimeout("tcp", p.mpd, mpdTimeout)
	if err != nil {
		return nil, err
	}
	if err := conn.SetDeadline(time.Now().Add(mpdTimeout)); err != nil {
		conn.Close()
		return nil, err
	}
	mpc := textproto.NewConn(conn)
	defer mpc.Close()

	greeting, err := mpc.ReadLine()
	if err != nil {
		return nil, err
	}
	if !strings.HasPrefix(greeting, "OK MPD") {
		return nil, fmt.Errorf("unexpected MPD greeting %q", greeting)
	}
	// Cmd treats its first argument as a format string; pass cmd as an
	// argument so that a literal "%" in it is sent unchanged.
	id, err := mpc.Cmd("%s", cmd)
	if err != nil {
		return nil, err
	}
	mpc.StartResponse(id)
	defer mpc.EndResponse(id)
	return readResponse(mpc.ReadLine)
}

// readResponse collects the lines of one MPD response using readLine. It
// returns the lines before an "OK" terminator. An "ACK ..." line ends the
// response with an error carrying that line; without this check the caller
// would keep waiting for an "OK" that never comes. An error from readLine is
// returned as is.
func readResponse(readLine func() (string, error)) ([]string, error) {
	var lines []string
	for {
		line, err := readLine()
		if err != nil {
			return nil, err
		}
		switch {
		case line == "OK":
			return lines, nil
		case strings.HasPrefix(line, "ACK "):
			return nil, fmt.Errorf("mpd: %s", line)
		}
		lines = append(lines, line)
	}
}

// getStatus runs MPD's "status" command and returns its "key: value" lines as
// a map.
func (p *mpd) getStatus() (map[string]string, error) {
	lines, err := p.mpdCommand("status")
	if err != nil {
		return nil, err
	}
	return parseStatus(lines), nil
}

// parseStatus turns "key: value" lines into a map, skipping lines without the
// ": " separator. Only the first separator splits, so a value that itself
// contains ": " is kept whole.
func parseStatus(lines []string) map[string]string {
	status := make(map[string]string, len(lines))
	for _, line := range lines {
		kv := strings.SplitN(line, ": ", 2)
		if len(kv) != 2 {
			continue
		}
		status[kv[0]] = kv[1]
	}
	return status
}

// close is called by main after a termination signal; it only logs.
func (p *mpd) close() {
	log.Printf("CLOSE")
}

// inc counts a new xfer and queues "play" when it is the first one. It then
// waits until the httpd stream yields data.
func (p *mpd) inc() {
	if atomic.AddInt32(&p.conns, 1) == 1 {
		p.queueState("play")
	}
	p.blockUntilPlaying()
}

// blockUntilPlaying polls the httpd stream once a second until a probe reads
// at least one byte from it. It has no overall deadline.
func (p *mpd) blockUntilPlaying() {
	url := fmt.Sprintf("http://localhost%s", p.httpd)
	for {
		time.Sleep(time.Second)
		if streamHasData(url) {
			return
		}
	}
}

// streamHasData GETs url and reports whether the response body yields at least
// one byte. The body is closed before returning, so failed probes do not
// accumulate open connections.
func streamHasData(url string) bool {
	resp, err := probeClient.Get(url)
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	buf := make([]byte, 1024)
	// Any bytes at all mean the stream is live; the accompanying error (if
	// any) does not change that.
	n, _ := resp.Body.Read(buf)
	return n > 0
}

// dec uncounts a finished xfer and queues "pause" when it was the last one.
func (p *mpd) dec() {
	if atomic.AddInt32(&p.conns, -1) == 0 {
		p.queueState("pause")
	}
}

// xfer copies src to dst until the copy ends, counting itself in p.conns for
// its whole lifetime. Both ends are closed when it finishes.
func (p *mpd) xfer(dst io.WriteCloser, src io.ReadCloser) {
	p.inc()
	defer p.dec()
	defer dst.Close()
	defer src.Close()
	// The copy error is not acted on: however the copy ends, both ends are
	// closed. With proxy's paired xfers, that also ends the opposite direction.
	io.Copy(dst, src)
	log.Printf("xfer %v -> %v ended", dst, src)
}

// catchSignal blocks until the process receives SIGHUP, SIGINT, SIGTERM or
// SIGQUIT.
func catchSignal() {
	log.Printf("listening for signal")
	// signal.Notify never blocks when delivering, so the channel needs room
	// for one signal; otherwise a signal arriving before the receive is lost.
	sigc := make(chan os.Signal, 1)
	signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
	<-sigc
}