diff --git a/db.go b/db.go index f2c94c4..f407a39 100644 --- a/db.go +++ b/db.go @@ -1,6 +1,10 @@ package main import ( + "errors" + "fmt" + "log" + "strconv" "time" "github.com/google/uuid" @@ -8,17 +12,18 @@ import ( type ( db struct { - Knowledge Knowledge - Users map[string]User - Cadence []Duration + Knowledge knowledge + Users map[string]user + Cadence []duration } - Knowledge struct { + knowledge struct { Questions map[string]Question Answers map[string]Answer } - User struct { + user struct { History map[string][]History } + duration time.Duration ) func (db db) HistoryOf(user string) map[string][]History { @@ -84,3 +89,45 @@ func (db db) lastTS(user, q string) time.Time { } return time.Unix(0, max) } + +func (d duration) MarshalYAML() (interface{}, error) { + log.Println("marshalling duration", time.Duration(d)) + return time.Duration(d).String(), nil +} + +func (d *duration) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err + } + count := "" + var ttl time.Duration + for i := range s { + if s[i] < '0' || s[i] > '9' { + n, err := strconv.Atoi(count) + if err != nil { + return err + } + count += s[i : i+1] + switch s[i] { + case 'w': + count = fmt.Sprintf("%dh", n*24*7) + case 'd': + count = fmt.Sprintf("%dh", n*24) + } + d, err := time.ParseDuration(count) + if err != nil { + return err + } + ttl += d + count = "" + } else { + count += s[i : i+1] + } + } + if count != "" { + return errors.New(count) + } + *d = duration(ttl) + return nil +} diff --git a/main.go b/main.go index cbb5ca2..33633fb 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,9 @@ package main import ( "bufio" - "errors" "fmt" "log" "os" - "strconv" "time" "gopkg.in/yaml.v2" @@ -36,7 +34,6 @@ type ( TS int64 Pass bool } - Duration time.Duration ) func main() { @@ -91,43 +88,3 @@ func NewDB() (DB, error) { } return db, nil } - -func (d Duration) MarshalYAML() (interface{}, error) { - return time.Duration(d).String(), nil -} - -func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { - var s string - if err := unmarshal(&s); err != nil { - return err - } - count := "" - var ttl time.Duration - for i := range s { - if s[i] < '0' || s[i] > '9' { - n, err := strconv.Atoi(count) - if err != nil { - return err - } - count += s[i : i+1] - switch s[i] { - case 'w': - count = fmt.Sprintf("%dh", n*24*7) - case 'd': - count = fmt.Sprintf("%dh", n*24) - } - d, err := time.ParseDuration(count) - if err != nil { - return err - } - ttl += d - count = "" - } else { - count += s[i : i+1] - } - } - if count != "" { - return errors.New(count) - } - return nil -}