From ba6d6483e0e0f6c1f0b2e2097c2a435e7d04f680 Mon Sep 17 00:00:00 2001 From: bel Date: Sun, 29 Oct 2023 10:09:14 -0600 Subject: [PATCH] i think files.TempLastNLines are symlink friendly... --- src/ledger/file.go | 41 +++++++++++++++++++++++++++++++++-------- src/ledger/file_test.go | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/ledger/file.go b/src/ledger/file.go index 264d047..925e17b 100644 --- a/src/ledger/file.go +++ b/src/ledger/file.go @@ -36,26 +36,51 @@ func (files Files) TempGetLastNLines(n int) ([]string, error) { func (files Files) TempSetLastNLines(n int, lines []string) error { p := files.paths()[0] - w, err := ioutil.TempFile(os.TempDir(), path.Base(p)) + + newFile, err := func() (string, error) { + w, err := ioutil.TempFile(os.TempDir(), path.Base(p)) + if err != nil { + return "", err + } + defer w.Close() + + r, err := os.Open(p) + if err != nil { + return "", err + } + defer r.Close() + + if _, err := peekLastNLines(w, bufio.NewReader(r), n); err != nil { + return "", err + } + for i := range lines { + if _, err := fmt.Fprintln(w, lines[i]); err != nil { + return "", err + } + } + if err := w.Close(); err != nil { + return "", err + } + return w.Name(), nil + }() if err != nil { return err } - defer w.Close() - r, err := os.Open(p) + r, err := os.Open(newFile) if err != nil { return err } defer r.Close() - if _, err := peekLastNLines(w, bufio.NewReader(r), n); err != nil { + w, err := os.Create(p) + if err != nil { return err } - for i := range lines { - fmt.Fprintln(w, lines[i]) - } + defer w.Close() - return os.Rename(w.Name(), p) + _, err = io.Copy(w, r) + return err } func peekLastNLines(w io.Writer, r *bufio.Reader, n int) ([]string, error) { diff --git a/src/ledger/file_test.go b/src/ledger/file_test.go index 1628a88..8150691 100644 --- a/src/ledger/file_test.go +++ b/src/ledger/file_test.go @@ -1,6 +1,7 @@ package ledger import ( + "bytes" "encoding/base64" "fmt" "os" @@ -401,22 +402,39 @@ func TestFilesTempSetLastNLines(t *testing.T) { c := d t.Run(name, func(t *testing.T) { p := path.Join(t.TempDir(), base64.URLEncoding.EncodeToString([]byte(t.Name()))) - os.WriteFile(p, []byte(c.given), os.ModePerm) - files := Files([]string{p}) - if err := files.TempSetLastNLines(c.n, c.input); err != nil { - s := err.Error() - if _, err := os.Stat(s); err == nil { - got, _ := os.ReadFile(s) - if string(got) != c.want { - t.Errorf("want\n%s, got\n%s", c.want, got) - } - } + realp := p + ".real" + + os.WriteFile(realp, []byte(c.given), os.ModePerm) + if err := os.Symlink(realp, p); err != nil { t.Fatal(err) } + if stat, err := os.Lstat(p); err != nil { + t.Error(err) + } else if stat.Mode().IsRegular() { + t.Error("p is already a regular file") + } + + files := Files([]string{p}) + if err := files.TempSetLastNLines(c.n, c.input); err != nil { + t.Fatal(err) + } + got, _ := os.ReadFile(p) if string(got) != c.want { t.Errorf("want\n%s, got\n%s", c.want, got) } + + realb, _ := os.ReadFile(realp) + b, _ := os.ReadFile(realp) + if !bytes.Equal(b, realb) { + t.Errorf("%s no longer links to %s", p, realp) + } + + if stat, err := os.Lstat(p); err != nil { + t.Error(err) + } else if stat.Mode().IsRegular() { + t.Error("p is now a regular file") + } }) } }