package driver import ( "context" "encoding/json" "errors" "local/dndex/storage/entity" "local/storage" "log" "go.mongodb.org/mongo-driver/bson" ) type Storage struct { db storage.DB } func NewStorage(args ...string) *Storage { if len(args) == 0 { args = []string{"map"} } typed := storage.TypeFromString(args[0]) log.Println("Creating storage of type", typed) args = args[1:] db, err := storage.New(typed, args...) if err != nil { panic(err) } return &Storage{ db: db, } } func (s *Storage) Count(ctx context.Context, ns string, filter interface{}) (int, error) { ch, err := s.Find(ctx, ns, filter) n := 0 for _ = range ch { n++ } return n, err } func (s *Storage) Find(ctx context.Context, ns string, filter interface{}) (chan bson.Raw, error) { ch := make(chan bson.Raw) go func() { defer close(ch) if err := s.forEach(ctx, ns, filter, func(_ string, v []byte) error { ch <- v return nil }); err != nil { log.Println(err) } }() return ch, nil } func (s *Storage) Update(ctx context.Context, ns string, filter, operator interface{}) error { return s.forEach(ctx, ns, filter, func(id string, v []byte) error { n := bson.M{} if err := bson.Unmarshal(v, &n); err != nil { return err } n, err := apply(n, operator) if err != nil { return err } v, err = json.Marshal(n) if err != nil { return err } return s.db.Set(id, v, ns) }) } func (s *Storage) Insert(ctx context.Context, ns string, doc interface{}) error { b, err := json.Marshal(doc) if err != nil { return err } m := bson.M{} if err := json.Unmarshal(b, &m); err != nil { return err } idi, ok := m[entity.ID] if !ok { return errors.New("primary key required to insert: did not find " + entity.ID) } id, ok := idi.(string) if !ok { return errors.New("primary key must be a string") } if _, err := s.db.Get(id, ns); err == nil { return errors.New("collision") } return s.db.Set(id, b, ns) } func (s *Storage) Delete(ctx context.Context, ns string, filter interface{}) error { return s.forEach(ctx, ns, filter, func(id string, v []byte) error { return s.db.Set(id, nil, ns) }) } func (s *Storage) forEach(ctx context.Context, ns string, filter interface{}, foo func(string, []byte) error) error { b, err := bson.Marshal(filter) if err != nil { return err } m := bson.M{} if err := bson.Unmarshal(b, &m); err != nil { return err } ids, err := s.db.List([]string{ns}) if err != nil { return err } for _, id := range ids { v, err := s.db.Get(id, ns) if err != nil { return err } else { n := bson.M{} if err := json.Unmarshal(v, &n); err != nil { return err } if matches(n, m) { b, err := bson.Marshal(n) if err != nil { return err } if err := foo(id, b); err != nil { return err } } } } return nil }