diff --git a/db.go b/db.go index 16a34d6..3aaa7c6 100644 --- a/db.go +++ b/db.go @@ -1,9 +1,6 @@ package main import ( - "errors" - "fmt" - "strconv" "time" "github.com/google/uuid" @@ -22,7 +19,6 @@ type ( user struct { History map[string][]History } - duration time.Duration ) func (db db) HistoryOf(user string) map[string][]History { @@ -88,57 +84,3 @@ func (db db) lastTS(user, q string) time.Time { } return time.Unix(0, max) } - -func (d duration) MarshalYAML() (interface{}, error) { - return d.String(), nil -} - -func (d duration) String() string { - result := "" - if weeks := time.Duration(d) / (time.Hour * 24 * 7); weeks > 0 { - result += fmt.Sprintf("%dw", weeks) - d -= duration(weeks * time.Hour * 24 * 7) - } - if days := time.Duration(d) / (time.Hour * 24); days > 0 { - result += fmt.Sprintf("%dd", days) - d -= duration(days * time.Hour * 24) - } - return result + time.Duration(d).String() -} - -func (d *duration) UnmarshalYAML(unmarshal func(interface{}) error) error { - var s string - if err := unmarshal(&s); err != nil { - return err - } - count := "" - var ttl time.Duration - for i := range s { - if s[i] < '0' || s[i] > '9' { - n, err := strconv.Atoi(count) - if err != nil { - return err - } - count += s[i : i+1] - switch s[i] { - case 'w': - count = fmt.Sprintf("%dh", n*24*7) - case 'd': - count = fmt.Sprintf("%dh", n*24) - } - d, err := time.ParseDuration(count) - if err != nil { - return err - } - ttl += d - count = "" - } else { - count += s[i : i+1] - } - } - if count != "" { - return errors.New(count) - } - *d = duration(ttl) - return nil -} diff --git a/duration.go b/duration.go new file mode 100644 index 0000000..eb11283 --- /dev/null +++ b/duration.go @@ -0,0 +1,64 @@ +package main + +import ( + "errors" + "fmt" + "strconv" + "time" +) + +type duration time.Duration + +func (d duration) MarshalYAML() (interface{}, error) { + return d.String(), nil +} + +func (d duration) String() string { + result := "" + if weeks := time.Duration(d) / (time.Hour * 24 * 7); weeks > 0 { + result += fmt.Sprintf("%dw", weeks) + d -= duration(weeks * time.Hour * 24 * 7) + } + if days := time.Duration(d) / (time.Hour * 24); days > 0 { + result += fmt.Sprintf("%dd", days) + d -= duration(days * time.Hour * 24) + } + return result + time.Duration(d).String() +} + +func (d *duration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + count := "" + var ttl time.Duration + for i := range s { + if s[i] < '0' || s[i] > '9' { + n, err := strconv.Atoi(count) + if err != nil { + return err + } + count += s[i : i+1] + switch s[i] { + case 'w': + count = fmt.Sprintf("%dh", n*24*7) + case 'd': + count = fmt.Sprintf("%dh", n*24) + } + d, err := time.ParseDuration(count) + if err != nil { + return err + } + ttl += d + count = "" + } else { + count += s[i : i+1] + } + } + if count != "" { + return errors.New(count) + } + *d = duration(ttl) + return nil +} diff --git a/duration_test.go b/duration_test.go new file mode 100644 index 0000000..0010062 --- /dev/null +++ b/duration_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" + "time" + + "gopkg.in/yaml.v2" +) + +func TestDurationYAML(t *testing.T) { + cases := map[string]time.Duration{ + "": 0, + "1s": time.Second, + "1m1s": time.Minute + time.Second, + "1h1s": time.Hour + time.Second, + "1d": 24 * time.Hour, + "1w": 7 * 24 * time.Hour, + "1w2d3h4m5s": 7*24*time.Hour + 2*24*time.Hour + 3*time.Hour + 4*time.Minute + 5*time.Second, + } + + for name, wantd := range cases { + want := wantd + t.Run(name, func(t *testing.T) { + var d1, d2 duration + if err := yaml.Unmarshal([]byte(name), &d1); err != nil { + t.Error(err) + } else if time.Duration(d1) != want { + t.Error(d1) + } else if b, err := yaml.Marshal(d1); err != nil { + t.Error(err) + } else if err := yaml.Unmarshal(b, &d2); err != nil { + t.Error(err) + } else if d1 != d2 { + t.Error(d1, d2) + } + }) + } +}