diff --git a/cmd/cli.go b/cmd/cli.go index b82a1d4..6c64f4d 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -8,6 +8,8 @@ import ( "io/ioutil" "local/pt-todo-server/pttodo" "os" + "path" + "syscall" "gopkg.in/yaml.v2" ) @@ -20,13 +22,88 @@ func main() { func _main() error { filepath := flag.String("f", "-", "path to yaml file") + e := flag.Bool("e", false, "edit file") flag.Parse() + if *e { + if err := edit(*filepath); err != nil { + return err + } + } + return dump(os.Stdout, *filepath) +} +func edit(filepath string) error { + var tempFile string + cp := func() error { + f, err := ioutil.TempFile(os.TempDir(), path.Base(filepath)) + if err != nil { + return err + } + g, err := os.Open(filepath) + if err != nil { + return err + } + if _, err := io.Copy(f, g); err != nil { + return err + } + g.Close() + f.Close() + tempFile = f.Name() + return nil + } + vi := func() error { + vibin := "/usr/bin/vi" + cpid, err := syscall.ForkExec( + vibin, + []string{vibin, tempFile}, + &syscall.ProcAttr{ + Dir: "", + Env: os.Environ(), + Files: []uintptr{os.Stdin.Fd(), os.Stdout.Fd(), os.Stderr.Fd()}, + Sys: nil, + }, + ) + if err != nil { + return err + } + proc, err := os.FindProcess(cpid) + if err != nil { + return err + } + state, err := proc.Wait() + if err != nil { + return err + } + if exitCode := state.ExitCode(); exitCode != 0 { + return fmt.Errorf("bad exit code on vim: %d, state: %+v", exitCode, state) + } + return nil + } + verify := func() error { + return dump(io.Discard, tempFile) + } + save := func() error { + return os.Rename(tempFile, filepath) + } + + for _, foo := range []func() error{cp, vi, verify, save} { + if err := foo(); err != nil { + if tempFile != "" { + os.Remove(tempFile) + } + return err + } + } + + return nil +} + +func dump(writer io.Writer, filepath string) error { var reader io.Reader - if *filepath == "-" { + if filepath == "-" { reader = os.Stdin } else { - b, err := ioutil.ReadFile(*filepath) + b, err := ioutil.ReadFile(filepath) if err != nil { return err } @@ -48,6 +125,6 @@ func _main() error { if err != nil { return err } - fmt.Printf("%s\n", b2) + fmt.Fprintf(writer, "%s\n", b2) return nil }