diff --git a/db.go b/db.go index b869d58..cff7d1c 100755 --- a/db.go +++ b/db.go @@ -42,6 +42,8 @@ func New(key Type, params ...string) (db DB, err error) { db, err = rclone.NewRClone(params[0], params[1]) case FILES: db, err = NewFiles(params[0]) + case YAML: + db, err = NewYaml(params[0]) case BOLT: db, err = NewBolt(params[0]) case MINIO: diff --git a/db_test.go b/db_test.go index a30ee57..67122fe 100755 --- a/db_test.go +++ b/db_test.go @@ -14,6 +14,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/google/uuid" ) @@ -198,19 +199,20 @@ type = local for _, db := range cases { t.Run(fmt.Sprintf("%T", db), func(t *testing.T) { log.Printf("Trying %T", db) - t.Logf(" %T: set", db) + t.Logf(" %T: list @[ns1, ns2] against empty", db) if keys, err := db.List([]string{"ns1", "ns2"}); err != nil || len(keys) > 0 { t.Errorf("%T) cannot List() empty: (%T) %+v: %v", db, err, err, keys) } - if keys, err := db.List([]string{path.Join("ns1", "ns2")}); err != nil || len(keys) > 0 { - t.Errorf("%T) cannot List() empty w/ /: (%T) %+v: %v", db, err, err, keys) - } + t.Logf(" %T: set %s @[ns1, ns2]", db, validKey) if err := db.Set(validKey, validValue, "ns1", "ns2"); err != nil { t.Errorf("%T) cannot set: %v", db, err) } - t.Logf(" %T: get", db) + t.Logf(" %T: db: %+v", db, db) + t.Logf(" %T: get %s @[ns1, ns2]", db, validKey) if v, err := db.Get(validKey, "ns1", "ns2"); err != nil { t.Errorf("%T) cannot get: %v", db, err) + log.Printf("%T) cannot get: %v (%+v)", db, err, db) + time.Sleep(time.Second * 10) } else if !bytes.Equal(v, validValue) { t.Errorf("%T) wrong get: %q vs %q", db, v, validValue) } diff --git a/type.go b/type.go index b848f8e..ef4975c 100755 --- a/type.go +++ b/type.go @@ -21,6 +21,7 @@ const ( MINIO = Type(iota) RCLONE = Type(iota) MAPSTREAM = Type(iota) + YAML = Type(iota) ) func (t Type) String() string { @@ -37,6 +38,8 @@ func (t Type) String() string { return "rclone" case COCKROACH: return "cockroach" + case YAML: + return "yaml" case FILES: return "files" case BOLT: diff --git a/yaml.go b/yaml.go index b4c5193..2d4ccdc 100755 --- a/yaml.go +++ b/yaml.go @@ -2,9 +2,12 @@ package storage import ( "bytes" + "encoding/base64" "errors" + "fmt" "io" "io/ioutil" + "local/storage/resolve" "os" "path" "path/filepath" @@ -43,7 +46,19 @@ func (y *Yaml) Namespaces() ([][]string, error) { } func (y *Yaml) List(ns []string, limits ...string) ([]string, error) { - return nil, errors.New("not impl") + namespace := resolve.Namespace(ns) + m, err := y.getMap(namespace) + if err != nil { + return nil, err + } + limits = resolve.Limits(limits) + ks := make([]string, 0, len(m)) + for k := range m { + if k >= limits[0] && k <= limits[1] { + ks = append(ks, k) + } + } + return ks, nil } func (y *Yaml) Get(key string, ns ...string) ([]byte, error) { @@ -55,7 +70,21 @@ func (y *Yaml) Get(key string, ns ...string) ([]byte, error) { } func (y *Yaml) GetStream(key string, ns ...string) (io.Reader, error) { - return nil, errors.New("not impl") + namespace := resolve.Namespace(ns) + m, err := y.getMap(namespace) + if err != nil { + return nil, err + } + v, ok := m[key] + if !ok { + return nil, ErrNotFound + } + s, ok := v.(string) + if !ok { + return nil, ErrNotFound + } + b, err := base64.StdEncoding.DecodeString(s) + return bytes.NewReader(b), err } func (y *Yaml) Set(key string, value []byte, ns ...string) error { @@ -67,30 +96,98 @@ func (y *Yaml) Set(key string, value []byte, ns ...string) error { } func (y *Yaml) Del(key string, ns ...string) error { - return errors.New("not impl") + return y.SetStream(key, nil, ns...) } func (y *Yaml) SetStream(key string, r io.Reader, ns ...string) error { - return errors.New("not impl") + namespace := resolve.Namespace(ns) + var v interface{} = nil + if r != nil { + b, err := ioutil.ReadAll(r) + if err != nil { + return err + } + v = base64.StdEncoding.EncodeToString(b) + } + m, err := y.getMap() + if err != nil { + return err + } + if err := setInMap(m, []string{namespace}, key, v); err != nil { + return err + } + return y.setMap(m) } func (y *Yaml) Close() error { return nil } -func (y *Yaml) getMap() (map[string]interface{}, error) { +func (y *Yaml) getMap(keys ...string) (map[string]interface{}, error) { b, err := y.get() if err != nil { return nil, err } - var m map[string]interface{} - err = yaml.Unmarshal(b, &m) + var mBad map[interface{}]interface{} + if err := yaml.Unmarshal(b, &mBad); err != nil { + return nil, err + } + m, err := mbadToM(mBad) + if err != nil { + return nil, err + } + if m == nil { + m = map[string]interface{}{} + } + for _, k := range keys { + subv, ok := m[k] + if !ok { + subv = map[string]interface{}{} + m[k] = subv + } + subm, ok := subv.(map[string]interface{}) + if !ok { + return nil, ErrNotFound + } + m = subm + } return m, err } +func (y *Yaml) setMap(m map[string]interface{}) error { + b, err := yaml.Marshal(m) + if err != nil { + return err + } + return y.set(b) +} + +func setInMap(m map[string]interface{}, keys []string, key string, v interface{}) error { + if len(keys) == 0 { + m[key] = v + if v == nil { + delete(m, key) + } + return nil + } + subv, ok := m[keys[0]] + if !ok { + subv = map[string]interface{}{} + } + subm, ok := subv.(map[string]interface{}) + if !ok { + return errors.New("clobber") + } + if err := setInMap(subm, keys[1:], key, v); err != nil { + return err + } + m[keys[0]] = subm + return nil +} + func (y *Yaml) get() ([]byte, error) { b, err := ioutil.ReadFile(y.path) - if err == os.ErrNotExist { + if os.IsNotExist(err) { return []byte{}, nil } return b, err @@ -127,3 +224,22 @@ func _keysDFS(m map[string]interface{}) ([][]string, bool, error) { } return keys, hasNonMaps, nil } + +func mbadToM(mBad map[interface{}]interface{}) (map[string]interface{}, error) { + m := map[string]interface{}{} + for k, v := range mBad { + s, ok := k.(string) + if !ok { + s = fmt.Sprint(k) + } + m[s] = v + if m2, ok := v.(map[interface{}]interface{}); ok { + v2, err := mbadToM(m2) + if err != nil { + return nil, err + } + m[s] = v2 + } + } + return m, nil +} diff --git a/yaml_test.go b/yaml_test.go index 3236a61..cac79d1 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -43,6 +43,17 @@ func TestKeysDFS(t *testing.T) { if err != nil { t.Fatal(err) } + for j := range c.want { + found := false + for i := range got { + if fmt.Sprint(got[i]) == fmt.Sprint(c.want[j]) { + found = true + } + } + if !found { + t.Errorf("want %+v among %+v", c.want[j], got) + } + } if fmt.Sprintf("%+v", got) != fmt.Sprintf("%+v", c.want) { t.Fatalf("want: %+v\ngot: %+v", c.want, got) }