newline battles continue

This commit is contained in:
bel
2020-01-19 20:41:30 +00:00
parent 98adb53caf
commit 991c27d044
1457 changed files with 525871 additions and 6 deletions

View File

@@ -0,0 +1,94 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// AbortTransaction represents the abortTransaction() command
type AbortTransaction struct {
Session *session.Client
err error
result result.TransactionResult
}
// Encode will encode this command into a wiremessage for the given server description.
func (at *AbortTransaction) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := at.encode(desc)
return cmd.Encode(desc)
}
func (at *AbortTransaction) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"abortTransaction", bsonx.Int32(1)}}
if at.Session.RecoveryToken != nil {
tokenDoc, _ := bsonx.ReadDoc(at.Session.RecoveryToken)
cmd = append(cmd, bsonx.Elem{"recoveryToken", bsonx.Document(tokenDoc)})
}
return &Write{
DB: "admin",
Command: cmd,
Session: at.Session,
WriteConcern: at.Session.CurrentWc,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (at *AbortTransaction) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *AbortTransaction {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
at.err = err
return at
}
return at.decode(desc, rdr)
}
func (at *AbortTransaction) decode(desc description.SelectedServer, rdr bson.Raw) *AbortTransaction {
at.err = bson.Unmarshal(rdr, &at.result)
if at.err == nil && at.result.WriteConcernError != nil {
at.err = Error{
Name: at.result.WriteConcernError.Name,
Code: int32(at.result.WriteConcernError.Code),
Message: at.result.WriteConcernError.ErrMsg,
}
}
return at
}
// Result returns the result of a decoded wire message and server description.
func (at *AbortTransaction) Result() (result.TransactionResult, error) {
if at.err != nil {
return result.TransactionResult{}, at.err
}
return at.result, nil
}
// Err returns the error set on this command
func (at *AbortTransaction) Err() error {
return at.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (at *AbortTransaction) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.TransactionResult, error) {
cmd := at.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.TransactionResult{}, err
}
return at.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,165 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Aggregate represents the aggregate command.
//
// The aggregate command performs an aggregation.
type Aggregate struct {
NS Namespace
Pipeline bsonx.Arr
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
WriteConcern *writeconcern.WriteConcern
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (a *Aggregate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := a.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) {
if err := a.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"aggregate", bsonx.String(a.NS.Collection)},
{"pipeline", bsonx.Array(a.Pipeline)},
}
cursor := bsonx.Doc{}
hasOutStage := a.HasDollarOut()
for _, opt := range a.Opts {
switch opt.Key {
case "batchSize":
if opt.Value.Int32() == 0 && hasOutStage {
continue
}
cursor = append(cursor, opt)
default:
command = append(command, opt)
}
}
command = append(command, bsonx.Elem{"cursor", bsonx.Document(cursor)})
// add write concern because it won't be added by the Read command's Encode()
if desc.WireVersion.Max >= 5 && hasOutStage && a.WriteConcern != nil {
t, data, err := a.WriteConcern.MarshalBSONValue()
if err != nil {
return nil, err
}
var xval bsonx.Val
err = xval.UnmarshalBSONValue(t, data)
if err != nil {
return nil, err
}
command = append(command, bsonx.Elem{Key: "writeConcern", Value: xval})
}
return &Read{
DB: a.NS.DB,
Command: command,
ReadPref: a.ReadPref,
ReadConcern: a.ReadConcern,
Clock: a.Clock,
Session: a.Session,
}, nil
}
// HasDollarOut returns true if the Pipeline field contains a $out stage.
func (a *Aggregate) HasDollarOut() bool {
if a.Pipeline == nil {
return false
}
if len(a.Pipeline) == 0 {
return false
}
val := a.Pipeline[len(a.Pipeline)-1]
doc, ok := val.DocumentOK()
if !ok || len(doc) != 1 {
return false
}
return doc[0].Key == "$out"
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
a.err = err
return a
}
return a.decode(desc, rdr)
}
func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate {
a.result = rdr
if val, err := rdr.LookupErr("writeConcernError"); err == nil {
var wce result.WriteConcernError
_ = val.Unmarshal(&wce)
a.err = wce
}
return a
}
// Result returns the result of a decoded wire message and server description.
func (a *Aggregate) Result() (bson.Raw, error) {
if a.err != nil {
return nil, a.err
}
return a.result, nil
}
// Err returns the error set on this command.
func (a *Aggregate) Err() error { return a.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := a.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return a.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,95 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// BuildInfo represents the buildInfo command.
//
// The buildInfo command is used for getting the build information for a
// MongoDB server.
type BuildInfo struct {
err error
res result.BuildInfo
}
// Encode will encode this command into a wire message for the given server description.
func (bi *BuildInfo) Encode() (wiremessage.WireMessage, error) {
// This can probably just be a global variable that we reuse.
cmd := bsonx.Doc{{"buildInfo", bsonx.Int32(1)}}
rdr, err := cmd.MarshalBSON()
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: "admin.$cmd",
Flags: wiremessage.SlaveOK,
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (bi *BuildInfo) Decode(wm wiremessage.WireMessage) *BuildInfo {
reply, ok := wm.(wiremessage.Reply)
if !ok {
bi.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return bi
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
bi.err = err
return bi
}
err = bson.Unmarshal(rdr, &bi.res)
if err != nil {
bi.err = err
return bi
}
return bi
}
// Result returns the result of a decoded wire message and server description.
func (bi *BuildInfo) Result() (result.BuildInfo, error) {
if bi.err != nil {
return result.BuildInfo{}, bi.err
}
return bi.res, nil
}
// Err returns the error set on this command.
func (bi *BuildInfo) Err() error { return bi.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (bi *BuildInfo) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.BuildInfo, error) {
wm, err := bi.Encode()
if err != nil {
return result.BuildInfo{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return result.BuildInfo{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return result.BuildInfo{}, err
}
return bi.Decode(wm).Result()
}

View File

@@ -0,0 +1,708 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command // import "go.mongodb.org/mongo-driver/x/network/command"
import (
"errors"
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// WriteBatch represents a single batch for a write operation.
type WriteBatch struct {
*Write
numDocs int
}
// DecodeError attempts to decode the wiremessage as an error
func DecodeError(wm wiremessage.WireMessage) error {
var rdr bson.Raw
switch msg := wm.(type) {
case wiremessage.Msg:
for _, section := range msg.Sections {
switch converted := section.(type) {
case wiremessage.SectionBody:
rdr = converted.Document
}
}
case wiremessage.Reply:
if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
return nil
}
rdr = msg.Documents[0]
}
err := rdr.Validate()
if err != nil {
return nil
}
extractedError := extractError(rdr)
// If parsed successfully return the error
if _, ok := extractedError.(Error); ok {
return err
}
return nil
}
// helper method to extract an error from a reader if there is one; first returned item is the
// error if it exists, the second holds parsing errors
func extractError(rdr bson.Raw) error {
var errmsg, codeName string
var code int32
var labels []string
elems, err := rdr.Elements()
if err != nil {
return err
}
for _, elem := range elems {
switch elem.Key() {
case "ok":
switch elem.Value().Type {
case bson.TypeInt32:
if elem.Value().Int32() == 1 {
return nil
}
case bson.TypeInt64:
if elem.Value().Int64() == 1 {
return nil
}
case bson.TypeDouble:
if elem.Value().Double() == 1 {
return nil
}
}
case "errmsg":
if str, okay := elem.Value().StringValueOK(); okay {
errmsg = str
}
case "codeName":
if str, okay := elem.Value().StringValueOK(); okay {
codeName = str
}
case "code":
if c, okay := elem.Value().Int32OK(); okay {
code = c
}
case "errorLabels":
if arr, okay := elem.Value().ArrayOK(); okay {
elems, err := arr.Elements()
if err != nil {
continue
}
for _, elem := range elems {
if str, ok := elem.Value().StringValueOK(); ok {
labels = append(labels, str)
}
}
}
}
}
if errmsg == "" {
errmsg = "command failed"
}
return Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
}
}
func responseClusterTime(response bson.Raw) bson.Raw {
clusterTime, err := response.LookupErr("$clusterTime")
if err != nil {
// $clusterTime not included by the server
return nil
}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime")
doc = append(doc, clusterTime.Value...)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
return doc
}
func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
clusterTime := responseClusterTime(response)
if clusterTime == nil {
return nil
}
if sess != nil {
err := sess.AdvanceClusterTime(clusterTime)
if err != nil {
return err
}
}
if clock != nil {
clock.AdvanceClusterTime(clusterTime)
}
return nil
}
func updateOperationTime(sess *session.Client, response bson.Raw) error {
if sess == nil {
return nil
}
opTimeElem, err := response.LookupErr("operationTime")
if err != nil {
// operationTime not included by the server
return nil
}
t, i := opTimeElem.Timestamp()
return sess.AdvanceOperationTime(&primitive.Timestamp{
T: t,
I: i,
})
}
func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
if cmd == nil {
return bson.Raw{5, 0, 0, 0, 0}, nil
}
return cmd.MarshalBSON()
}
// adds session related fields to a BSON doc representing a command
func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
return cmd, nil
}
if client.Terminated {
return cmd, session.ErrSessionEnded
}
if _, err := cmd.LookupElementErr("lsid"); err != nil {
cmd = cmd.Delete("lsid")
}
cmd = append(cmd, bsonx.Elem{"lsid", bsonx.Document(client.SessionID)})
if client.TransactionRunning() ||
client.RetryingCommit {
cmd = addTransaction(cmd, client)
}
client.ApplyCommand(desc.Server) // advance the state machine based on a command executing
return cmd, nil
}
// if in a transaction, add the transaction fields
func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(client.TxnNumber)})
if client.TransactionStarting() {
// When starting transaction, always transition to the next state, even on error
cmd = append(cmd, bsonx.Elem{"startTransaction", bsonx.Boolean(true)})
}
return append(cmd, bsonx.Elem{"autocommit", bsonx.Boolean(false)})
}
func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
if (clock == nil && sess == nil) || !description.SessionsSupported(desc.WireVersion) {
return cmd
}
var clusterTime bson.Raw
if clock != nil {
clusterTime = clock.GetClusterTime()
}
if sess != nil {
if clusterTime == nil {
clusterTime = sess.ClusterTime
} else {
clusterTime = session.MaxClusterTime(clusterTime, sess.ClusterTime)
}
}
if clusterTime == nil {
return cmd
}
d, err := bsonx.ReadDoc(clusterTime)
if err != nil {
return cmd // broken clusterTime
}
cmd = cmd.Delete("$clusterTime")
return append(cmd, d...)
}
// add a read concern to a BSON doc representing a command
func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
// Starting transaction's read concern overrides all others
if sess != nil && sess.TransactionStarting() && sess.CurrentRc != nil {
rc = sess.CurrentRc
}
// start transaction must append afterclustertime IF causally consistent and operation time exists
if rc == nil && sess != nil && sess.TransactionStarting() && sess.Consistent && sess.OperationTime != nil {
rc = readconcern.New()
}
if rc == nil {
return cmd, nil
}
t, data, err := rc.MarshalBSONValue()
if err != nil {
return cmd, err
}
var rcDoc bsonx.Doc
err = rcDoc.UnmarshalBSONValue(t, data)
if err != nil {
return cmd, err
}
if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil {
rcDoc = append(rcDoc, bsonx.Elem{"afterClusterTime", bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)})
}
cmd = cmd.Delete("readConcern")
if len(rcDoc) != 0 {
cmd = append(cmd, bsonx.Elem{"readConcern", bsonx.Document(rcDoc)})
}
return cmd, nil
}
// add a write concern to a BSON doc representing a command
func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
if wc == nil {
return cmd, nil
}
t, data, err := wc.MarshalBSONValue()
if err != nil {
if err == writeconcern.ErrEmptyWriteConcern {
return cmd, nil
}
return cmd, err
}
var xval bsonx.Val
err = xval.UnmarshalBSONValue(t, data)
if err != nil {
return cmd, err
}
// delete if doc already has write concern
cmd = cmd.Delete("writeConcern")
return append(cmd, bsonx.Elem{Key: "writeConcern", Value: xval}), nil
}
// Get the error labels from a command response
func getErrorLabels(rdr *bson.Raw) ([]string, error) {
var labels []string
labelsElem, err := rdr.LookupErr("errorLabels")
if err != bsoncore.ErrElementNotFound {
return nil, err
}
if labelsElem.Type == bsontype.Array {
labelsIt, err := labelsElem.Array().Elements()
if err != nil {
return nil, err
}
for _, elem := range labelsIt {
labels = append(labels, elem.Value().StringValue())
}
}
return labels, nil
}
// Remove command arguments for insert, update, and delete commands from the BSON document so they can be encoded
// as a Section 1 payload in OP_MSG
func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
var array bsonx.Arr
var id string
keys := []string{"documents", "updates", "deletes"}
for _, key := range keys {
val, err := cmd.LookupErr(key)
if err != nil {
continue
}
array = val.Array()
cmd = cmd.Delete(key)
id = key
break
}
return cmd, array, id
}
// Add the $db and $readPreference keys to the command
// If the command has no read preference, pass nil for rpDoc
func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
cmd = append(cmd, bsonx.Elem{"$db", bsonx.String(dbName)})
if rpDoc != nil {
cmd = append(cmd, bsonx.Elem{"$readPreference", bsonx.Document(rpDoc)})
}
return cmd.MarshalBSON() // bsonx.Doc.MarshalBSON never returns an error.
}
func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
docSequence := wiremessage.SectionDocumentSequence{
PayloadType: wiremessage.DocumentSequence,
Identifier: identifier,
Documents: make([]bson.Raw, 0, len(arr)),
}
for _, val := range arr {
d, _ := val.Document().MarshalBSON()
docSequence.Documents = append(docSequence.Documents, d)
}
docSequence.Size = int32(docSequence.PayloadLen())
return docSequence, nil
}
func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
batches := [][]bsonx.Doc{}
if targetBatchSize > reservedCommandBufferBytes {
targetBatchSize -= reservedCommandBufferBytes
}
if maxCount <= 0 {
maxCount = 1
}
startAt := 0
splitInserts:
for {
size := 0
batch := []bsonx.Doc{}
assembleBatch:
for idx := startAt; idx < len(docs); idx++ {
raw, _ := docs[idx].MarshalBSON()
if len(raw) > targetBatchSize {
return nil, ErrDocumentTooLarge
}
if size+len(raw) > targetBatchSize {
break assembleBatch
}
size += len(raw)
batch = append(batch, docs[idx])
startAt++
if len(batch) == maxCount {
break assembleBatch
}
}
batches = append(batches, batch)
if startAt == len(docs) {
break splitInserts
}
}
return batches, nil
}
func encodeBatch(
docs []bsonx.Doc,
opts []bsonx.Elem,
cmdKind WriteCommandKind,
collName string,
) (bsonx.Doc, error) {
var cmdName string
var docString string
switch cmdKind {
case InsertCommand:
cmdName = "insert"
docString = "documents"
case UpdateCommand:
cmdName = "update"
docString = "updates"
case DeleteCommand:
cmdName = "delete"
docString = "deletes"
}
cmd := bsonx.Doc{{cmdName, bsonx.String(collName)}}
vals := make(bsonx.Arr, 0, len(docs))
for _, doc := range docs {
vals = append(vals, bsonx.Document(doc))
}
cmd = append(cmd, bsonx.Elem{docString, bsonx.Array(vals)})
cmd = append(cmd, opts...)
return cmd, nil
}
// converts batches of Write Commands to wire messages
func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
wms := make([]wiremessage.WireMessage, len(batches))
for _, cmd := range batches {
wm, err := cmd.Encode(desc)
if err != nil {
return nil, err
}
wms = append(wms, wm)
}
return wms, nil
}
// Roundtrips the write batches, returning the result structs (as interface),
// the write batches that weren't round tripped and any errors
func roundTripBatches(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
batches []*WriteBatch,
continueOnError bool,
sess *session.Client,
cmdKind WriteCommandKind,
) (interface{}, []*WriteBatch, error) {
var res interface{}
var upsertIndex int64 // the operation index for the upserted IDs map
// hold onto txnNumber, reset it when loop exits to ensure reuse of same
// transaction number if retry is needed
var txnNumber int64
if sess != nil && sess.RetryWrite {
txnNumber = sess.TxnNumber
}
for j, cmd := range batches {
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber + int64(j)
}
return res, batches, err
}
// TODO can probably DRY up this code
switch cmdKind {
case InsertCommand:
if res == nil {
res = result.Insert{}
}
conv, _ := res.(result.Insert)
insertCmd := &Insert{}
r, err := insertCmd.decode(desc, rdr).Result()
if err != nil {
return res, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
case UpdateCommand:
if res == nil {
res = result.Update{}
}
conv, _ := res.(result.Update)
updateCmd := &Update{}
r, err := updateCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.MatchedCount += r.MatchedCount
conv.ModifiedCount += r.ModifiedCount
for _, upsert := range r.Upserted {
conv.Upserted = append(conv.Upserted, result.Upsert{
Index: upsert.Index + upsertIndex,
ID: upsert.ID,
})
}
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
upsertIndex += int64(cmd.numDocs)
case DeleteCommand:
if res == nil {
res = result.Delete{}
}
conv, _ := res.(result.Delete)
deleteCmd := &Delete{}
r, err := deleteCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
}
// Increment txnNumber for each batch
if sess != nil && sess.RetryWrite {
sess.IncrementTxnNumber()
batches = batches[1:] // if batch encoded successfully, remove it from the slice
}
}
if sess != nil && sess.RetryWrite {
// if retryable write succeeded, transaction number will be incremented one extra time,
// so we decrement it here
sess.TxnNumber--
}
return res, batches, nil
}
// get the firstBatch, cursor ID, and namespace from a bson.Raw
func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
cur, err := result.LookupErr("cursor")
if err != nil {
return nil, Namespace{}, 0, err
}
if cur.Type != bson.TypeEmbeddedDocument {
return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
}
elems, err := cur.Document().Elements()
if err != nil {
return nil, Namespace{}, 0, err
}
var ok bool
var arr bson.Raw
var namespace Namespace
var cursorID int64
for _, elem := range elems {
switch elem.Key() {
case "firstBatch":
arr, ok = elem.Value().ArrayOK()
if !ok {
return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
}
if err != nil {
return nil, Namespace{}, 0, err
}
case "ns":
if elem.Value().Type != bson.TypeString {
return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
}
namespace = ParseNamespace(elem.Value().StringValue())
err = namespace.Validate()
if err != nil {
return nil, Namespace{}, 0, err
}
case "id":
cursorID, ok = elem.Value().Int64OK()
if !ok {
return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
}
}
}
vals, err := arr.Values()
if err != nil {
return nil, Namespace{}, 0, err
}
return vals, namespace, cursorID, nil
}
func getBatchSize(opts []bsonx.Elem) int32 {
for _, opt := range opts {
if opt.Key == "batchSize" {
return opt.Value.Int32()
}
}
return 0
}
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern.
var ErrUnacknowledgedWrite = errors.New("unacknowledged write")
// WriteCommandKind is the type of command represented by a Write
type WriteCommandKind int8
// These constants represent the valid types of write commands.
const (
InsertCommand WriteCommandKind = iota
UpdateCommand
DeleteCommand
)

View File

@@ -0,0 +1,94 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CommitTransaction represents the commitTransaction() command
type CommitTransaction struct {
Session *session.Client
err error
result result.TransactionResult
}
// Encode will encode this command into a wiremessage for the given server description.
func (ct *CommitTransaction) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := ct.encode(desc)
return cmd.Encode(desc)
}
func (ct *CommitTransaction) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"commitTransaction", bsonx.Int32(1)}}
if ct.Session.RecoveryToken != nil {
tokenDoc, _ := bsonx.ReadDoc(ct.Session.RecoveryToken)
cmd = append(cmd, bsonx.Elem{"recoveryToken", bsonx.Document(tokenDoc)})
}
return &Write{
DB: "admin",
Command: cmd,
Session: ct.Session,
WriteConcern: ct.Session.CurrentWc,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (ct *CommitTransaction) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *CommitTransaction {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ct.err = err
return ct
}
return ct.decode(desc, rdr)
}
func (ct *CommitTransaction) decode(desc description.SelectedServer, rdr bson.Raw) *CommitTransaction {
ct.err = bson.Unmarshal(rdr, &ct.result)
if ct.err == nil && ct.result.WriteConcernError != nil {
ct.err = Error{
Name: ct.result.WriteConcernError.Name,
Code: int32(ct.result.WriteConcernError.Code),
Message: ct.result.WriteConcernError.ErrMsg,
}
}
return ct
}
// Result returns the result of a decoded wire message and server description.
func (ct *CommitTransaction) Result() (result.TransactionResult, error) {
if ct.err != nil {
return result.TransactionResult{}, ct.err
}
return ct.result, nil
}
// Err returns the error set on this command
func (ct *CommitTransaction) Err() error {
return ct.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (ct *CommitTransaction) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.TransactionResult, error) {
cmd := ct.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.TransactionResult{}, err
}
return ct.decode(desc, rdr).Result()
}

128
vendor/go.mongodb.org/mongo-driver/x/network/command/count.go generated vendored Executable file
View File

@@ -0,0 +1,128 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Count represents the count command.
//
// The count command counts how many documents in a collection match the given query.
type Count struct {
NS Namespace
Query bsonx.Doc
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result int64
err error
}
// Encode will encode this command into a wire message for the given server description.
func (c *Count) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := c.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (c *Count) encode(desc description.SelectedServer) (*Read, error) {
if err := c.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"count", bsonx.String(c.NS.Collection)}, {"query", bsonx.Document(c.Query)}}
command = append(command, c.Opts...)
return &Read{
Clock: c.Clock,
DB: c.NS.DB,
ReadPref: c.ReadPref,
Command: command,
ReadConcern: c.ReadConcern,
Session: c.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (c *Count) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Count {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
c.err = err
return c
}
return c.decode(desc, rdr)
}
func (c *Count) decode(desc description.SelectedServer, rdr bson.Raw) *Count {
val, err := rdr.LookupErr("n")
switch {
case err == bsoncore.ErrElementNotFound:
c.err = errors.New("invalid response from server, no 'n' field")
return c
case err != nil:
c.err = err
return c
}
switch val.Type {
case bson.TypeInt32:
c.result = int64(val.Int32())
case bson.TypeInt64:
c.result = val.Int64()
case bson.TypeDouble:
c.result = int64(val.Double())
default:
c.err = errors.New("invalid response from server, value field is not a number")
}
return c
}
// Result returns the result of a decoded wire message and server description.
func (c *Count) Result() (int64, error) {
if c.err != nil {
return 0, c.err
}
return c.result, nil
}
// Err returns the error set on this command.
func (c *Count) Err() error { return c.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (c *Count) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
cmd, err := c.encode(desc)
if err != nil {
return 0, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return 0, err
}
return c.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,134 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CountDocuments represents the CountDocuments command.
//
// The countDocuments command counts how many documents in a collection match the given query.
type CountDocuments struct {
NS Namespace
Pipeline bsonx.Arr
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result int64
err error
}
// Encode will encode this command into a wire message for the given server description.
func (c *CountDocuments) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
if err := c.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"aggregate", bsonx.String(c.NS.Collection)}, {"pipeline", bsonx.Array(c.Pipeline)}}
command = append(command, bsonx.Elem{"cursor", bsonx.Document(bsonx.Doc{})})
command = append(command, c.Opts...)
return (&Read{DB: c.NS.DB, ReadPref: c.ReadPref, Command: command, Session: c.Session}).Encode(desc)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, wm wiremessage.WireMessage) *CountDocuments {
rdr, err := (&Read{Session: c.Session}).Decode(desc, wm).Result()
if err != nil {
c.err = err
return c
}
cursor, err := rdr.LookupErr("cursor")
if err != nil || cursor.Type != bsontype.EmbeddedDocument {
c.err = errors.New("Invalid response from server, no 'cursor' field")
return c
}
batch, err := cursor.Document().LookupErr("firstBatch")
if err != nil || batch.Type != bsontype.Array {
c.err = errors.New("Invalid response from server, no 'firstBatch' field")
return c
}
elem, err := batch.Array().IndexErr(0)
if err != nil || elem.Value().Type != bsontype.EmbeddedDocument {
c.result = 0
return c
}
val, err := elem.Value().Document().LookupErr("n")
if err != nil {
c.err = errors.New("Invalid response from server, no 'n' field")
return c
}
switch val.Type {
case bsontype.Int32:
c.result = int64(val.Int32())
case bsontype.Int64:
c.result = val.Int64()
default:
c.err = errors.New("Invalid response from server, value field is not a number")
}
return c
}
// Result returns the result of a decoded wire message and server description.
func (c *CountDocuments) Result() (int64, error) {
if c.err != nil {
return 0, c.err
}
return c.result, nil
}
// Err returns the error set on this command.
func (c *CountDocuments) Err() error { return c.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
wm, err := c.Encode(desc)
if err != nil {
return 0, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return 0, err
}
// Connection errors are transient
c.Session.ClearPinnedServer()
return 0, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return 0, err
}
// Connection errors are transient
c.Session.ClearPinnedServer()
return 0, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
return c.Decode(ctx, desc, wm).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CreateIndexes represents the createIndexes command.
//
// The createIndexes command creates indexes for a namespace.
type CreateIndexes struct {
NS Namespace
Indexes bsonx.Arr
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.CreateIndexes
err error
}
// Encode will encode this command into a wire message for the given server description.
func (ci *CreateIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := ci.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (ci *CreateIndexes) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{
{"createIndexes", bsonx.String(ci.NS.Collection)},
{"indexes", bsonx.Array(ci.Indexes)},
}
cmd = append(cmd, ci.Opts...)
write := &Write{
Clock: ci.Clock,
DB: ci.NS.DB,
Command: cmd,
Session: ci.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = ci.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (ci *CreateIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *CreateIndexes {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ci.err = err
return ci
}
return ci.decode(desc, rdr)
}
func (ci *CreateIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *CreateIndexes {
ci.err = bson.Unmarshal(rdr, &ci.result)
return ci
}
// Result returns the result of a decoded wire message and server description.
func (ci *CreateIndexes) Result() (result.CreateIndexes, error) {
if ci.err != nil {
return result.CreateIndexes{}, ci.err
}
return ci.result, nil
}
// Err returns the error set on this command.
func (ci *CreateIndexes) Err() error { return ci.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (ci *CreateIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.CreateIndexes, error) {
cmd, err := ci.encode(desc)
if err != nil {
return result.CreateIndexes{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.CreateIndexes{}, err
}
return ci.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,154 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Delete represents the delete command.
//
// The delete command executes a delete with a given set of delete documents
// and options.
type Delete struct {
ContinueOnError bool
NS Namespace
Deletes []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
batches []*WriteBatch
result result.Delete
err error
}
// Encode will encode this command into a wire message for the given server description.
func (d *Delete) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := d.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(d.batches, desc)
}
func (d *Delete) encode(desc description.SelectedServer) error {
batches, err := splitBatches(d.Deletes, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := d.encodeBatch(docs, desc)
if err != nil {
return err
}
d.batches = append(d.batches, cmd)
}
return nil
}
func (d *Delete) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
copyDocs := make([]bsonx.Doc, 0, len(docs))
for _, doc := range docs {
copyDocs = append(copyDocs, doc.Copy())
}
var options []bsonx.Elem
for _, opt := range d.Opts {
if opt.Key == "collation" {
for idx := range copyDocs {
copyDocs[idx] = append(copyDocs[idx], opt)
}
} else {
options = append(options, opt)
}
}
command, err := encodeBatch(copyDocs, options, DeleteCommand, d.NS.Collection)
if err != nil {
return nil, err
}
return &WriteBatch{
&Write{
Clock: d.Clock,
DB: d.NS.DB,
Command: command,
WriteConcern: d.WriteConcern,
Session: d.Session,
},
len(docs),
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (d *Delete) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Delete {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
d.err = err
return d
}
return d.decode(desc, rdr)
}
func (d *Delete) decode(desc description.SelectedServer, rdr bson.Raw) *Delete {
d.err = bson.Unmarshal(rdr, &d.result)
return d
}
// Result returns the result of a decoded wire message and server description.
func (d *Delete) Result() (result.Delete, error) {
if d.err != nil {
return result.Delete{}, d.err
}
return d.result, nil
}
// Err returns the error set on this command.
func (d *Delete) Err() error { return d.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (d *Delete) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Delete, error) {
if d.batches == nil {
if err := d.encode(desc); err != nil {
return result.Delete{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
d.batches,
d.ContinueOnError,
d.Session,
DeleteCommand,
)
if batches != nil {
d.batches = batches
}
if err != nil {
return result.Delete{}, err
}
return r.(result.Delete), nil
}

View File

@@ -0,0 +1,115 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Distinct represents the disctinct command.
//
// The distinct command returns the distinct values for a specified field
// across a single collection.
type Distinct struct {
NS Namespace
Field string
Query bsonx.Doc
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result result.Distinct
err error
}
// Encode will encode this command into a wire message for the given server description.
func (d *Distinct) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := d.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
// Encode will encode this command into a wire message for the given server description.
func (d *Distinct) encode(desc description.SelectedServer) (*Read, error) {
if err := d.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"distinct", bsonx.String(d.NS.Collection)}, {"key", bsonx.String(d.Field)}}
if d.Query != nil {
command = append(command, bsonx.Elem{"query", bsonx.Document(d.Query)})
}
command = append(command, d.Opts...)
return &Read{
Clock: d.Clock,
DB: d.NS.DB,
ReadPref: d.ReadPref,
Command: command,
ReadConcern: d.ReadConcern,
Session: d.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (d *Distinct) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Distinct {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
d.err = err
return d
}
return d.decode(desc, rdr)
}
func (d *Distinct) decode(desc description.SelectedServer, rdr bson.Raw) *Distinct {
d.err = bson.Unmarshal(rdr, &d.result)
return d
}
// Result returns the result of a decoded wire message and server description.
func (d *Distinct) Result() (result.Distinct, error) {
if d.err != nil {
return result.Distinct{}, d.err
}
return d.result, nil
}
// Err returns the error set on this command.
func (d *Distinct) Err() error { return d.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (d *Distinct) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Distinct, error) {
cmd, err := d.encode(desc)
if err != nil {
return result.Distinct{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.Distinct{}, err
}
return d.decode(desc, rdr).Result()
}

16
vendor/go.mongodb.org/mongo-driver/x/network/command/doc.go generated vendored Executable file
View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package command contains abstractions for operations that can be performed against a MongoDB
// deployment. The types in this package are meant to provide a general set of commands that a
// user can run against a MongoDB database without knowing the version of the database.
//
// Each type consists of two levels of interaction. The lowest level are the Encode and Decode
// methods. These are meant to be symmetric eventually, but currently only support the driver
// side of commands. The higher level is the RoundTrip method. This only makes sense from the
// driver side of commands and this method handles the encoding of the request and decoding of
// the response using the given wiremessage.ReadWriter.
package command

View File

@@ -0,0 +1,101 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropCollection represents the drop command.
//
// The dropCollections command drops collection for a database.
type DropCollection struct {
DB string
Collection string
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (dc *DropCollection) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := dc.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (dc *DropCollection) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{{"drop", bsonx.String(dc.Collection)}}
write := &Write{
Clock: dc.Clock,
DB: dc.DB,
Command: cmd,
Session: dc.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = dc.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (dc *DropCollection) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropCollection {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
dc.err = err
return dc
}
return dc.decode(desc, rdr)
}
func (dc *DropCollection) decode(desc description.SelectedServer, rdr bson.Raw) *DropCollection {
dc.result = rdr
return dc
}
// Result returns the result of a decoded wire message and server description.
func (dc *DropCollection) Result() (bson.Raw, error) {
if dc.err != nil {
return nil, dc.err
}
return dc.result, nil
}
// Err returns the error set on this command.
func (dc *DropCollection) Err() error { return dc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (dc *DropCollection) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := dc.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return dc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,100 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropDatabase represents the DropDatabase command.
//
// The DropDatabases command drops database.
type DropDatabase struct {
DB string
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (dd *DropDatabase) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := dd.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (dd *DropDatabase) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{{"dropDatabase", bsonx.Int32(1)}}
write := &Write{
Clock: dd.Clock,
DB: dd.DB,
Command: cmd,
Session: dd.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = dd.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (dd *DropDatabase) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropDatabase {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
dd.err = err
return dd
}
return dd.decode(desc, rdr)
}
func (dd *DropDatabase) decode(desc description.SelectedServer, rdr bson.Raw) *DropDatabase {
dd.result = rdr
return dd
}
// Result returns the result of a decoded wire message and server description.
func (dd *DropDatabase) Result() (bson.Raw, error) {
if dd.err != nil {
return nil, dd.err
}
return dd.result, nil
}
// Err returns the error set on this command.
func (dd *DropDatabase) Err() error { return dd.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (dd *DropDatabase) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := dd.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return dd.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropIndexes represents the dropIndexes command.
//
// The dropIndexes command drops indexes for a namespace.
type DropIndexes struct {
NS Namespace
Index string
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (di *DropIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := di.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (di *DropIndexes) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{
{"dropIndexes", bsonx.String(di.NS.Collection)},
{"index", bsonx.String(di.Index)},
}
cmd = append(cmd, di.Opts...)
write := &Write{
Clock: di.Clock,
DB: di.NS.DB,
Command: cmd,
Session: di.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = di.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (di *DropIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropIndexes {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
di.err = err
return di
}
return di.decode(desc, rdr)
}
func (di *DropIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *DropIndexes {
di.result = rdr
return di
}
// Result returns the result of a decoded wire message and server description.
func (di *DropIndexes) Result() (bson.Raw, error) {
if di.err != nil {
return nil, di.err
}
return di.result, nil
}
// Err returns the error set on this command.
func (di *DropIndexes) Err() error { return di.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (di *DropIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := di.encode(desc)
if err != nil {
return nil, err
}
di.result, err = cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return di.Result()
}

View File

@@ -0,0 +1,138 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// must be sent to admin db
// { endSessions: [ {id: uuid}, ... ], $clusterTime: ... }
// only send $clusterTime when gossiping the cluster time
// send 10k sessions at a time
// EndSessions represents an endSessions command.
type EndSessions struct {
Clock *session.ClusterClock
SessionIDs []bsonx.Doc
results []result.EndSessions
errors []error
}
// BatchSize is the max number of sessions to be included in 1 endSessions command.
const BatchSize = 10000
func (es *EndSessions) split() [][]bsonx.Doc {
batches := [][]bsonx.Doc{}
docIndex := 0
totalNumDocs := len(es.SessionIDs)
createBatches:
for {
batch := []bsonx.Doc{}
for i := 0; i < BatchSize; i++ {
if docIndex == totalNumDocs {
break createBatches
}
batch = append(batch, es.SessionIDs[docIndex])
docIndex++
}
batches = append(batches, batch)
}
return batches
}
func (es *EndSessions) encodeBatch(batch []bsonx.Doc, desc description.SelectedServer) *Write {
vals := make(bsonx.Arr, 0, len(batch))
for _, doc := range batch {
vals = append(vals, bsonx.Document(doc))
}
cmd := bsonx.Doc{{"endSessions", bsonx.Array(vals)}}
return &Write{
Clock: es.Clock,
DB: "admin",
Command: cmd,
}
}
// Encode will encode this command into a series of wire messages for the given server description.
func (es *EndSessions) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
cmds := es.encode(desc)
wms := make([]wiremessage.WireMessage, len(cmds))
for _, cmd := range cmds {
wm, err := cmd.Encode(desc)
if err != nil {
return nil, err
}
wms = append(wms, wm)
}
return wms, nil
}
func (es *EndSessions) encode(desc description.SelectedServer) []*Write {
out := []*Write{}
batches := es.split()
for _, batch := range batches {
out = append(out, es.encodeBatch(batch, desc))
}
return out
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (es *EndSessions) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *EndSessions {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
es.errors = append(es.errors, err)
return es
}
return es.decode(desc, rdr)
}
func (es *EndSessions) decode(desc description.SelectedServer, rdr bson.Raw) *EndSessions {
var res result.EndSessions
es.errors = append(es.errors, bson.Unmarshal(rdr, &res))
es.results = append(es.results, res)
return es
}
// Result returns the results of the decoded wire messages.
func (es *EndSessions) Result() ([]result.EndSessions, []error) {
return es.results, es.errors
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (es *EndSessions) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) ([]result.EndSessions, []error) {
cmds := es.encode(desc)
for _, cmd := range cmds {
rdr, _ := cmd.RoundTrip(ctx, desc, rw) // ignore any errors returned by the command
es.decode(desc, rdr)
}
return es.Result()
}

View File

@@ -0,0 +1,141 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/network/result"
)
var (
// ErrUnknownCommandFailure occurs when a command fails for an unknown reason.
ErrUnknownCommandFailure = errors.New("unknown command failure")
// ErrNoCommandResponse occurs when the server sent no response document to a command.
ErrNoCommandResponse = errors.New("no command response document")
// ErrMultiDocCommandResponse occurs when the server sent multiple documents in response to a command.
ErrMultiDocCommandResponse = errors.New("command returned multiple documents")
// ErrNoDocCommandResponse occurs when the server indicated a response existed, but none was found.
ErrNoDocCommandResponse = errors.New("command returned no documents")
// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a
// server is passed to an insert command.
ErrDocumentTooLarge = errors.New("an inserted document is too large")
// ErrNonPrimaryRP occurs when a nonprimary read preference is used with a transaction.
ErrNonPrimaryRP = errors.New("read preference in a transaction must be primary")
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
// TransientTransactionError is an error label for transient errors with transactions.
TransientTransactionError = "TransientTransactionError"
// NetworkError is an error label for network errors.
NetworkError = "NetworkError"
// ReplyDocumentMismatch is an error label for OP_QUERY field mismatch errors.
ReplyDocumentMismatch = "malformed OP_REPLY: NumberReturned does not match number of documents returned"
)
var retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001}
// QueryFailureError is an error representing a command failure as a document.
type QueryFailureError struct {
Message string
Response bson.Raw
}
// Error implements the error interface.
func (e QueryFailureError) Error() string {
return fmt.Sprintf("%s: %v", e.Message, e.Response)
}
// ResponseError is an error parsing the response to a command.
type ResponseError struct {
Message string
Wrapped error
}
// NewCommandResponseError creates a CommandResponseError.
func NewCommandResponseError(msg string, err error) ResponseError {
return ResponseError{Message: msg, Wrapped: err}
}
// Error implements the error interface.
func (e ResponseError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
}
return fmt.Sprintf("%s", e.Message)
}
// Error is a command execution error from the database.
type Error struct {
Code int32
Message string
Labels []string
Name string
}
// Error implements the error interface.
func (e Error) Error() string {
if e.Name != "" {
return fmt.Sprintf("(%v) %v", e.Name, e.Message)
}
return e.Message
}
// HasErrorLabel returns true if the error contains the specified label.
func (e Error) HasErrorLabel(label string) bool {
if e.Labels != nil {
for _, l := range e.Labels {
if l == label {
return true
}
}
}
return false
}
// Retryable returns true if the error is retryable
func (e Error) Retryable() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
if strings.Contains(e.Message, "not master") || strings.Contains(e.Message, "node is recovering") {
return true
}
return false
}
// IsWriteConcernErrorRetryable returns true if the write concern error is retryable.
func IsWriteConcernErrorRetryable(wce *result.WriteConcernError) bool {
for _, code := range retryableCodes {
if int32(wce.Code) == code {
return true
}
}
if strings.Contains(wce.ErrMsg, "not master") || strings.Contains(wce.ErrMsg, "node is recovering") {
return true
}
return false
}
// IsNotFound indicates if the error is from a namespace not being found.
func IsNotFound(err error) bool {
e, ok := err.(Error)
// need message check because legacy servers don't include the error code
return ok && (e.Code == 26 || e.Message == "ns not found")
}

113
vendor/go.mongodb.org/mongo-driver/x/network/command/find.go generated vendored Executable file
View File

@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Find represents the find command.
//
// The find command finds documents within a collection that match a filter.
type Find struct {
NS Namespace
Filter bsonx.Doc
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *Find) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *Find) encode(desc description.SelectedServer) (*Read, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"find", bsonx.String(f.NS.Collection)}}
if f.Filter != nil {
command = append(command, bsonx.Elem{"filter", bsonx.Document(f.Filter)})
}
command = append(command, f.Opts...)
return &Read{
Clock: f.Clock,
DB: f.NS.DB,
ReadPref: f.ReadPref,
Command: command,
ReadConcern: f.ReadConcern,
Session: f.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *Find) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Find {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *Find) decode(desc description.SelectedServer, rdr bson.Raw) *Find {
f.result = rdr
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *Find) Result() (bson.Raw, error) {
if f.err != nil {
return nil, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *Find) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *Find) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,54 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/result"
)
// unmarshalFindAndModifyResult turns the provided bson.Reader into a findAndModify result.
func unmarshalFindAndModifyResult(rdr bson.Raw) (result.FindAndModify, error) {
var res result.FindAndModify
val, err := rdr.LookupErr("value")
switch {
case err == bsoncore.ErrElementNotFound:
return result.FindAndModify{}, errors.New("invalid response from server, no value field")
case err != nil:
return result.FindAndModify{}, err
}
switch val.Type {
case bson.TypeNull:
case bson.TypeEmbeddedDocument:
res.Value = val.Document()
default:
return result.FindAndModify{}, errors.New("invalid response from server, 'value' field is not a document")
}
if val, err := rdr.LookupErr("lastErrorObject", "updatedExisting"); err == nil {
b, ok := val.BooleanOK()
if ok {
res.LastErrorObject.UpdatedExisting = b
}
}
if val, err := rdr.LookupErr("lastErrorObject", "upserted"); err == nil {
oid, ok := val.ObjectIDOK()
if ok {
res.LastErrorObject.Upserted = oid
}
}
if val, err := rdr.LookupErr("writeConcernError"); err == nil {
_ = val.Unmarshal(&res.WriteConcernError)
}
return res, nil
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndDelete represents the findOneAndDelete operation.
//
// The findOneAndDelete command deletes a single document that matches a query and returns it.
type FindOneAndDelete struct {
NS Namespace
Query bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndDelete) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndDelete) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"remove", bsonx.Boolean(true)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndDelete) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndDelete {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndDelete) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndDelete {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndDelete) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndDelete) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndDelete) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndReplace represents the findOneAndReplace operation.
//
// The findOneAndReplace command modifies and returns a single document.
type FindOneAndReplace struct {
NS Namespace
Query bsonx.Doc
Replacement bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndReplace) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndReplace) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"update", bsonx.Document(f.Replacement)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndReplace) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndReplace {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndReplace) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndReplace {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndReplace) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndReplace) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndReplace) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndUpdate represents the findOneAndUpdate operation.
//
// The findOneAndUpdate command modifies and returns a single document.
type FindOneAndUpdate struct {
NS Namespace
Query bsonx.Doc
Update bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndUpdate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndUpdate) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"update", bsonx.Document(f.Update)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndUpdate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndUpdate {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndUpdate) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndUpdate {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndUpdate) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndUpdate) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndUpdate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,108 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// GetMore represents the getMore command.
//
// The getMore command retrieves additional documents from a cursor.
type GetMore struct {
ID int64
NS Namespace
Opts []bsonx.Elem
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (gm *GetMore) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := gm.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (gm *GetMore) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{
{"getMore", bsonx.Int64(gm.ID)},
{"collection", bsonx.String(gm.NS.Collection)},
}
for _, opt := range gm.Opts {
switch opt.Key {
case "maxAwaitTimeMS":
cmd = append(cmd, bsonx.Elem{"maxTimeMs", opt.Value})
default:
cmd = append(cmd, opt)
}
}
return &Read{
Clock: gm.Clock,
DB: gm.NS.DB,
Command: cmd,
Session: gm.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (gm *GetMore) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *GetMore {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
gm.err = err
return gm
}
return gm.decode(desc, rdr)
}
func (gm *GetMore) decode(desc description.SelectedServer, rdr bson.Raw) *GetMore {
gm.result = rdr
return gm
}
// Result returns the result of a decoded wire message and server description.
func (gm *GetMore) Result() (bson.Raw, error) {
if gm.err != nil {
return nil, gm.err
}
return gm.result, nil
}
// Err returns the error set on this command.
func (gm *GetMore) Err() error { return gm.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (gm *GetMore) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := gm.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return gm.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// GetLastError represents the getLastError command.
//
// The getLastError command is used for getting the last
// error from the last command on a connection.
//
// Since GetLastError only makes sense in the context of
// a single connection, there is no Dispatch method.
type GetLastError struct {
Clock *session.ClusterClock
Session *session.Client
err error
res result.GetLastError
}
// Encode will encode this command into a wire message for the given server description.
func (gle *GetLastError) Encode() (wiremessage.WireMessage, error) {
encoded, err := gle.encode()
if err != nil {
return nil, err
}
return encoded.Encode(description.SelectedServer{})
}
func (gle *GetLastError) encode() (*Read, error) {
// This can probably just be a global variable that we reuse.
cmd := bsonx.Doc{{"getLastError", bsonx.Int32(1)}}
return &Read{
Clock: gle.Clock,
DB: "admin",
ReadPref: readpref.Secondary(),
Session: gle.Session,
Command: cmd,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (gle *GetLastError) Decode(wm wiremessage.WireMessage) *GetLastError {
reply, ok := wm.(wiremessage.Reply)
if !ok {
gle.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return gle
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
gle.err = err
return gle
}
return gle.decode(rdr)
}
func (gle *GetLastError) decode(rdr bson.Raw) *GetLastError {
err := bson.Unmarshal(rdr, &gle.res)
if err != nil {
gle.err = err
return gle
}
return gle
}
// Result returns the result of a decoded wire message and server description.
func (gle *GetLastError) Result() (result.GetLastError, error) {
if gle.err != nil {
return result.GetLastError{}, gle.err
}
return gle.res, nil
}
// Err returns the error set on this command.
func (gle *GetLastError) Err() error { return gle.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (gle *GetLastError) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.GetLastError, error) {
cmd, err := gle.encode()
if err != nil {
return result.GetLastError{}, err
}
rdr, err := cmd.RoundTrip(ctx, description.SelectedServer{}, rw)
if err != nil {
return result.GetLastError{}, err
}
return gle.decode(rdr).Result()
}

View File

@@ -0,0 +1,117 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"runtime"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Handshake represents a generic MongoDB Handshake. It calls isMaster and
// buildInfo.
//
// The isMaster and buildInfo commands are used to build a server description.
type Handshake struct {
Client bsonx.Doc
Compressors []string
SaslSupportedMechs string
ismstr result.IsMaster
err error
}
// Encode will encode the handshake commands into a wire message containing isMaster
func (h *Handshake) Encode() (wiremessage.WireMessage, error) {
var wm wiremessage.WireMessage
ismstr, err := (&IsMaster{
Client: h.Client,
Compressors: h.Compressors,
SaslSupportedMechs: h.SaslSupportedMechs,
}).Encode()
if err != nil {
return wm, err
}
wm = ismstr
return wm, nil
}
// Decode will decode the wire messages.
// Errors during decoding are deferred until either the Result or Err methods
// are called.
func (h *Handshake) Decode(wm wiremessage.WireMessage) *Handshake {
h.ismstr, h.err = (&IsMaster{}).Decode(wm).Result()
if h.err != nil {
return h
}
return h
}
// Result returns the result of decoded wire messages.
func (h *Handshake) Result(addr address.Address) (description.Server, error) {
if h.err != nil {
return description.Server{}, h.err
}
return description.NewServer(addr, h.ismstr), nil
}
// Err returns the error set on this Handshake.
func (h *Handshake) Err() error { return h.err }
// Handshake implements the connection.Handshaker interface. It is identical
// to the RoundTrip methods on other types in this package. It will execute
// the isMaster command.
func (h *Handshake) Handshake(ctx context.Context, addr address.Address, rw wiremessage.ReadWriter) (description.Server, error) {
wm, err := h.Encode()
if err != nil {
return description.Server{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return description.Server{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return description.Server{}, err
}
return h.Decode(wm).Result(addr)
}
// ClientDoc creates a client information document for use in an isMaster
// command.
func ClientDoc(app string) bsonx.Doc {
doc := bsonx.Doc{
{"driver",
bsonx.Document(bsonx.Doc{
{"name", bsonx.String("mongo-go-driver")},
{"version", bsonx.String(version.Driver)},
}),
},
{"os",
bsonx.Document(bsonx.Doc{
{"type", bsonx.String(runtime.GOOS)},
{"architecture", bsonx.String(runtime.GOARCH)},
}),
},
{"platform", bsonx.String(runtime.Version())},
}
if app != "" {
doc = append(doc, bsonx.Elem{"application", bsonx.Document(bsonx.Doc{{"name", bsonx.String(app)}})})
}
return doc
}

View File

@@ -0,0 +1,158 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// this is the amount of reserved buffer space in a message that the
// driver reserves for command overhead.
const reservedCommandBufferBytes = 16 * 10 * 10 * 10
// Insert represents the insert command.
//
// The insert command inserts a set of documents into the database.
//
// Since the Insert command does not return any value other than ok or
// an error, this type has no Err method.
type Insert struct {
ContinueOnError bool
Clock *session.ClusterClock
NS Namespace
Docs []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Session *session.Client
batches []*WriteBatch
result result.Insert
err error
}
// Encode will encode this command into a wire message for the given server description.
func (i *Insert) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := i.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(i.batches, desc)
}
func (i *Insert) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
command, err := encodeBatch(docs, i.Opts, InsertCommand, i.NS.Collection)
if err != nil {
return nil, err
}
for _, opt := range i.Opts {
if opt.Key == "ordered" && !opt.Value.Boolean() {
i.ContinueOnError = true
break
}
}
return &WriteBatch{
&Write{
Clock: i.Clock,
DB: i.NS.DB,
Command: command,
WriteConcern: i.WriteConcern,
Session: i.Session,
},
len(docs),
}, nil
}
func (i *Insert) encode(desc description.SelectedServer) error {
batches, err := splitBatches(i.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := i.encodeBatch(docs, desc)
if err != nil {
return err
}
i.batches = append(i.batches, cmd)
}
return nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (i *Insert) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Insert {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
i.err = err
return i
}
return i.decode(desc, rdr)
}
func (i *Insert) decode(desc description.SelectedServer, rdr bson.Raw) *Insert {
i.err = bson.Unmarshal(rdr, &i.result)
return i
}
// Result returns the result of a decoded wire message and server description.
func (i *Insert) Result() (result.Insert, error) {
if i.err != nil {
return result.Insert{}, i.err
}
return i.result, nil
}
// Err returns the error set on this command.
func (i *Insert) Err() error { return i.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
//func (i *Insert) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Insert, error) {
func (i *Insert) RoundTrip(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
) (result.Insert, error) {
if i.batches == nil {
err := i.encode(desc)
if err != nil {
return result.Insert{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
i.batches,
i.ContinueOnError,
i.Session,
InsertCommand,
)
// if there are leftover batches, save them for retry
if batches != nil {
i.batches = batches
}
if err != nil {
return result.Insert{}, err
}
res := r.(result.Insert)
return res, nil
}

View File

@@ -0,0 +1,121 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// IsMaster represents the isMaster command.
//
// The isMaster command is used for setting up a connection to MongoDB and
// for monitoring a MongoDB server.
//
// Since IsMaster can only be run on a connection, there is no Dispatch method.
type IsMaster struct {
Client bsonx.Doc
Compressors []string
SaslSupportedMechs string
err error
res result.IsMaster
}
// Encode will encode this command into a wire message for the given server description.
func (im *IsMaster) Encode() (wiremessage.WireMessage, error) {
cmd := bsonx.Doc{{"isMaster", bsonx.Int32(1)}}
if im.Client != nil {
cmd = append(cmd, bsonx.Elem{"client", bsonx.Document(im.Client)})
}
if im.SaslSupportedMechs != "" {
cmd = append(cmd, bsonx.Elem{"saslSupportedMechs", bsonx.String(im.SaslSupportedMechs)})
}
// always send compressors even if empty slice
array := bsonx.Arr{}
for _, compressor := range im.Compressors {
array = append(array, bsonx.String(compressor))
}
cmd = append(cmd, bsonx.Elem{"compression", bsonx.Array(array)})
rdr, err := cmd.MarshalBSON()
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: "admin.$cmd",
Flags: wiremessage.SlaveOK,
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (im *IsMaster) Decode(wm wiremessage.WireMessage) *IsMaster {
reply, ok := wm.(wiremessage.Reply)
if !ok {
im.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return im
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
im.err = err
return im
}
err = bson.Unmarshal(rdr, &im.res)
if err != nil {
im.err = err
return im
}
// Reconstructs the $clusterTime doc after decode
if im.res.ClusterTime != nil {
im.res.ClusterTime = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", im.res.ClusterTime))
}
return im
}
// Result returns the result of a decoded wire message and server description.
func (im *IsMaster) Result() (result.IsMaster, error) {
if im.err != nil {
return result.IsMaster{}, im.err
}
return im.res, nil
}
// Err returns the error set on this command.
func (im *IsMaster) Err() error { return im.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (im *IsMaster) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.IsMaster, error) {
wm, err := im.Encode()
if err != nil {
return result.IsMaster{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return result.IsMaster{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return result.IsMaster{}, err
}
return im.Decode(wm).Result()
}

View File

@@ -0,0 +1,103 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// KillCursors represents the killCursors command.
//
// The killCursors command kills a set of cursors.
type KillCursors struct {
Clock *session.ClusterClock
NS Namespace
IDs []int64
result result.KillCursors
err error
}
// Encode will encode this command into a wire message for the given server description.
func (kc *KillCursors) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := kc.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (kc *KillCursors) encode(desc description.SelectedServer) (*Read, error) {
idVals := make([]bsonx.Val, 0, len(kc.IDs))
for _, id := range kc.IDs {
idVals = append(idVals, bsonx.Int64(id))
}
cmd := bsonx.Doc{
{"killCursors", bsonx.String(kc.NS.Collection)},
{"cursors", bsonx.Array(idVals)},
}
return &Read{
Clock: kc.Clock,
DB: kc.NS.DB,
Command: cmd,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (kc *KillCursors) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *KillCursors {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
kc.err = err
return kc
}
return kc.decode(desc, rdr)
}
func (kc *KillCursors) decode(desc description.SelectedServer, rdr bson.Raw) *KillCursors {
err := bson.Unmarshal(rdr, &kc.result)
if err != nil {
kc.err = err
return kc
}
return kc
}
// Result returns the result of a decoded wire message and server description.
func (kc *KillCursors) Result() (result.KillCursors, error) {
if kc.err != nil {
return result.KillCursors{}, kc.err
}
return kc.result, nil
}
// Err returns the error set on this command.
func (kc *KillCursors) Err() error { return kc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (kc *KillCursors) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.KillCursors, error) {
cmd, err := kc.encode(desc)
if err != nil {
return result.KillCursors{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.KillCursors{}, err
}
return kc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,102 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ListCollections represents the listCollections command.
//
// The listCollections command lists the collections in a database.
type ListCollections struct {
Clock *session.ClusterClock
DB string
Filter bsonx.Doc
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (lc *ListCollections) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := lc.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (lc *ListCollections) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listCollections", bsonx.Int32(1)}}
if lc.Filter != nil {
cmd = append(cmd, bsonx.Elem{"filter", bsonx.Document(lc.Filter)})
}
cmd = append(cmd, lc.Opts...)
return &Read{
Clock: lc.Clock,
DB: lc.DB,
Command: cmd,
ReadPref: lc.ReadPref,
Session: lc.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decolcng
// are deferred until either the Result or Err methods are called.
func (lc *ListCollections) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListCollections {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
lc.err = err
return lc
}
return lc.decode(desc, rdr)
}
func (lc *ListCollections) decode(desc description.SelectedServer, rdr bson.Raw) *ListCollections {
lc.result = rdr
return lc
}
// Result returns the result of a decoded wire message and server description.
func (lc *ListCollections) Result() (bson.Raw, error) {
if lc.err != nil {
return nil, lc.err
}
return lc.result, nil
}
// Err returns the error set on this command.
func (lc *ListCollections) Err() error { return lc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (lc *ListCollections) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := lc.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return lc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ListDatabases represents the listDatabases command.
//
// The listDatabases command lists the databases in a MongoDB deployment.
type ListDatabases struct {
Clock *session.ClusterClock
Filter bsonx.Doc
Opts []bsonx.Elem
Session *session.Client
result result.ListDatabases
err error
}
// Encode will encode this command into a wire message for the given server description.
func (ld *ListDatabases) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := ld.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (ld *ListDatabases) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listDatabases", bsonx.Int32(1)}}
if ld.Filter != nil {
cmd = append(cmd, bsonx.Elem{"filter", bsonx.Document(ld.Filter)})
}
cmd = append(cmd, ld.Opts...)
return &Read{
Clock: ld.Clock,
DB: "admin",
Command: cmd,
Session: ld.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (ld *ListDatabases) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListDatabases {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
ld.err = err
return ld
}
return ld.decode(desc, rdr)
}
func (ld *ListDatabases) decode(desc description.SelectedServer, rdr bson.Raw) *ListDatabases {
ld.err = bson.Unmarshal(rdr, &ld.result)
return ld
}
// Result returns the result of a decoded wire message and server description.
func (ld *ListDatabases) Result() (result.ListDatabases, error) {
if ld.err != nil {
return result.ListDatabases{}, ld.err
}
return ld.result, nil
}
// Err returns the error set on this command.
func (ld *ListDatabases) Err() error { return ld.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (ld *ListDatabases) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.ListDatabases, error) {
cmd, err := ld.encode(desc)
if err != nil {
return result.ListDatabases{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.ListDatabases{}, err
}
return ld.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ErrEmptyCursor is a signaling error when a cursor for list indexes is empty.
var ErrEmptyCursor = errors.New("empty cursor")
// ListIndexes represents the listIndexes command.
//
// The listIndexes command lists the indexes for a namespace.
type ListIndexes struct {
Clock *session.ClusterClock
NS Namespace
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (li *ListIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := li.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (li *ListIndexes) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listIndexes", bsonx.String(li.NS.Collection)}}
cmd = append(cmd, li.Opts...)
return &Read{
Clock: li.Clock,
DB: li.NS.DB,
Command: cmd,
Session: li.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoling
// are deferred until either the Result or Err methods are called.
func (li *ListIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListIndexes {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
if IsNotFound(err) {
li.err = ErrEmptyCursor
return li
}
li.err = err
return li
}
return li.decode(desc, rdr)
}
func (li *ListIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *ListIndexes {
li.result = rdr
return li
}
// Result returns the result of a decoded wire message and server description.
func (li *ListIndexes) Result() (bson.Raw, error) {
if li.err != nil {
return nil, li.err
}
return li.result, nil
}
// Err returns the error set on this command.
func (li *ListIndexes) Err() error { return li.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (li *ListIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := li.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
if IsNotFound(err) {
return nil, ErrEmptyCursor
}
return nil, err
}
return li.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,79 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"strings"
)
// Namespace encapsulates a database and collection name, which together
// uniquely identifies a collection within a MongoDB cluster.
type Namespace struct {
DB string
Collection string
}
// NewNamespace returns a new Namespace for the
// given database and collection.
func NewNamespace(db, collection string) Namespace { return Namespace{DB: db, Collection: collection} }
// ParseNamespace parses a namespace string into a Namespace.
//
// The namespace string must contain at least one ".", the first of which is the separator
// between the database and collection names. If not, the default (invalid) Namespace is returned.
func ParseNamespace(name string) Namespace {
index := strings.Index(name, ".")
if index == -1 {
return Namespace{}
}
return Namespace{
DB: name[:index],
Collection: name[index+1:],
}
}
// FullName returns the full namespace string, which is the result of joining the database
// name and the collection name with a "." character.
func (ns *Namespace) FullName() string {
return strings.Join([]string{ns.DB, ns.Collection}, ".")
}
// Validate validates the namespace.
func (ns *Namespace) Validate() error {
if err := ns.validateDB(); err != nil {
return err
}
return ns.validateCollection()
}
// validateDB ensures the database name is not an empty string, contain a ".",
// or contain a " ".
func (ns *Namespace) validateDB() error {
if ns.DB == "" {
return errors.New("database name cannot be empty")
}
if strings.Contains(ns.DB, " ") {
return errors.New("database name cannot contain ' '")
}
if strings.Contains(ns.DB, ".") {
return errors.New("database name cannot contain '.'")
}
return nil
}
// validateCollection ensures the collection name is not an empty string.
func (ns *Namespace) validateCollection() error {
if ns.Collection == "" {
return errors.New("collection name cannot be empty")
}
return nil
}

View File

@@ -0,0 +1,53 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
func decodeCommandOpMsg(msg wiremessage.Msg) (bson.Raw, error) {
var mainDoc bsonx.Doc
for _, section := range msg.Sections {
switch converted := section.(type) {
case wiremessage.SectionBody:
err := mainDoc.UnmarshalBSON(converted.Document)
if err != nil {
return nil, err
}
case wiremessage.SectionDocumentSequence:
arr := bsonx.Arr{}
for _, doc := range converted.Documents {
newDoc := bsonx.Doc{}
err := newDoc.UnmarshalBSON(doc)
if err != nil {
return nil, err
}
arr = append(arr, bsonx.Document(newDoc))
}
mainDoc = append(mainDoc, bsonx.Elem{converted.Identifier, bsonx.Array(arr)})
}
}
byteArray, err := mainDoc.MarshalBSON()
if err != nil {
return nil, err
}
rdr := bson.Raw(byteArray)
err = rdr.Validate()
if err != nil {
return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
}
return rdr, extractError(rdr)
}

View File

@@ -0,0 +1,43 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// decodeCommandOpReply handles decoding the OP_REPLY response to an OP_QUERY
// command.
func decodeCommandOpReply(reply wiremessage.Reply) (bson.Raw, error) {
if reply.NumberReturned == 0 {
return nil, ErrNoDocCommandResponse
}
if reply.NumberReturned > 1 {
return nil, ErrMultiDocCommandResponse
}
if len(reply.Documents) != 1 {
return nil, NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of documents returned", nil)
}
rdr := reply.Documents[0]
err := rdr.Validate()
if err != nil {
return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
}
if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
return nil, QueryFailureError{
Message: "command failure",
Response: reply.Documents[0],
}
}
err = extractError(rdr)
if err != nil {
return nil, err
}
return rdr, nil
}

294
vendor/go.mongodb.org/mongo-driver/x/network/command/read.go generated vendored Executable file
View File

@@ -0,0 +1,294 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Read represents a generic database read command.
type Read struct {
DB string
Command bsonx.Doc
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
func (r *Read) createReadPref(serverKind description.ServerKind, topologyKind description.TopologyKind, isOpQuery bool) bsonx.Doc {
doc := bsonx.Doc{}
rp := r.ReadPref
if rp == nil {
if topologyKind == description.Single && serverKind != description.Mongos {
return append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
}
return nil
}
switch rp.Mode() {
case readpref.PrimaryMode:
if serverKind == description.Mongos {
return nil
}
if topologyKind == description.Single {
return append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
}
doc = append(doc, bsonx.Elem{"mode", bsonx.String("primary")})
case readpref.PrimaryPreferredMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
case readpref.SecondaryPreferredMode:
_, ok := r.ReadPref.MaxStaleness()
if serverKind == description.Mongos && isOpQuery && !ok && len(r.ReadPref.TagSets()) == 0 {
return nil
}
doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondaryPreferred")})
case readpref.SecondaryMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondary")})
case readpref.NearestMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("nearest")})
}
sets := make([]bsonx.Val, 0, len(r.ReadPref.TagSets()))
for _, ts := range r.ReadPref.TagSets() {
if len(ts) == 0 {
continue
}
set := bsonx.Doc{}
for _, t := range ts {
set = append(set, bsonx.Elem{t.Name, bsonx.String(t.Value)})
}
sets = append(sets, bsonx.Document(set))
}
if len(sets) > 0 {
doc = append(doc, bsonx.Elem{"tags", bsonx.Array(sets)})
}
if d, ok := r.ReadPref.MaxStaleness(); ok {
doc = append(doc, bsonx.Elem{"maxStalenessSeconds", bsonx.Int32(int32(d.Seconds()))})
}
return doc
}
// addReadPref will add a read preference to the query document.
//
// NOTE: This method must always return either a valid bson.Reader or an error.
func (r *Read) addReadPref(rp *readpref.ReadPref, serverKind description.ServerKind, topologyKind description.TopologyKind, query bson.Raw) (bson.Raw, error) {
doc := r.createReadPref(serverKind, topologyKind, true)
if doc == nil {
return query, nil
}
qdoc := bsonx.Doc{}
err := bson.Unmarshal(query, &qdoc)
if err != nil {
return query, err
}
return bsonx.Doc{
{"$query", bsonx.Document(qdoc)},
{"$readPreference", bsonx.Document(doc)},
}.MarshalBSON()
}
// Encode r as OP_MSG
func (r *Read) encodeOpMsg(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
msg := wiremessage.Msg{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
Sections: make([]wiremessage.Section, 0),
}
readPrefDoc := r.createReadPref(desc.Server.Kind, desc.Kind, false)
fullDocRdr, err := opmsgAddGlobals(cmd, r.DB, readPrefDoc)
if err != nil {
return nil, err
}
// type 0 doc
msg.Sections = append(msg.Sections, wiremessage.SectionBody{
PayloadType: wiremessage.SingleDocument,
Document: fullDocRdr,
})
// no flags to add
return msg, nil
}
func (r *Read) slaveOK(desc description.SelectedServer) wiremessage.QueryFlag {
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
return wiremessage.SlaveOK
}
if r.ReadPref == nil {
// assume primary
return 0
}
if r.ReadPref.Mode() != readpref.PrimaryMode {
return wiremessage.SlaveOK
}
return 0
}
// Encode c as OP_QUERY
func (r *Read) encodeOpQuery(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
rdr, err := marshalCommand(cmd)
if err != nil {
return nil, err
}
if desc.Server.Kind == description.Mongos {
rdr, err = r.addReadPref(r.ReadPref, desc.Server.Kind, desc.Kind, rdr)
if err != nil {
return nil, err
}
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: r.DB + ".$cmd",
Flags: r.slaveOK(desc),
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
func (r *Read) decodeOpMsg(wm wiremessage.WireMessage) {
msg, ok := wm.(wiremessage.Msg)
if !ok {
r.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
r.result, r.err = decodeCommandOpMsg(msg)
}
func (r *Read) decodeOpReply(wm wiremessage.WireMessage) {
reply, ok := wm.(wiremessage.Reply)
if !ok {
r.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
r.result, r.err = decodeCommandOpReply(reply)
}
// Encode will encode this command into a wire message for the given server description.
func (r *Read) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := r.Command.Copy()
cmd, err := addReadConcern(cmd, desc, r.ReadConcern, r.Session)
if err != nil {
return nil, err
}
cmd, err = addSessionFields(cmd, desc, r.Session)
if err != nil {
return nil, err
}
cmd = addClusterTime(cmd, desc, r.Session, r.Clock)
if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
return r.encodeOpQuery(desc, cmd)
}
return r.encodeOpMsg(desc, cmd)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (r *Read) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Read {
switch wm.(type) {
case wiremessage.Reply:
r.decodeOpReply(wm)
default:
r.decodeOpMsg(wm)
}
if r.err != nil {
// decode functions set error if an invalid response document was returned or if the OK flag in the response was 0
// if the OK flag was 0, a type Error is returned. otherwise, a special type is returned
cerr, ok := r.err.(Error)
if !ok {
return r // for missing/invalid response docs, don't update cluster times
}
if cerr.HasErrorLabel(TransientTransactionError) {
r.Session.ClearPinnedServer()
}
}
_ = updateClusterTimes(r.Session, r.Clock, r.result)
_ = updateOperationTime(r.Session, r.result)
r.Session.UpdateRecoveryToken(r.result)
return r
}
// Result returns the result of a decoded wire message and server description.
func (r *Read) Result() (bson.Raw, error) {
if r.err != nil {
return nil, r.err
}
return r.result, nil
}
// Err returns the error set on this command.
func (r *Read) Err() error {
return r.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (r *Read) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
wm, err := r.Encode(desc)
if err != nil {
return nil, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
r.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
r.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if r.Session != nil {
err = r.Session.UpdateUseTime()
if err != nil {
return nil, err
}
}
return r.Decode(desc, wm).Result()
}

View File

@@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// StartSession represents a startSession command
type StartSession struct {
Clock *session.ClusterClock
result result.StartSession
err error
}
// Encode will encode this command into a wiremessage for the given server description.
func (ss *StartSession) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := ss.encode(desc)
return cmd.Encode(desc)
}
func (ss *StartSession) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"startSession", bsonx.Int32(1)}}
return &Write{
Clock: ss.Clock,
DB: "admin",
Command: cmd,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (ss *StartSession) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *StartSession {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ss.err = err
return ss
}
return ss.decode(desc, rdr)
}
func (ss *StartSession) decode(desc description.SelectedServer, rdr bson.Raw) *StartSession {
ss.err = bson.Unmarshal(rdr, &ss.result)
return ss
}
// Result returns the result of a decoded wire message and server description.
func (ss *StartSession) Result() (result.StartSession, error) {
if ss.err != nil {
return result.StartSession{}, ss.err
}
return ss.result, nil
}
// Err returns the error set on this command
func (ss *StartSession) Err() error {
return ss.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (ss *StartSession) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.StartSession, error) {
cmd := ss.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.StartSession{}, err
}
return ss.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,161 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Update represents the update command.
//
// The update command updates a set of documents with the database.
type Update struct {
ContinueOnError bool
Clock *session.ClusterClock
NS Namespace
Docs []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Session *session.Client
batches []*WriteBatch
result result.Update
err error
}
// Encode will encode this command into a wire message for the given server description.
func (u *Update) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := u.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(u.batches, desc)
}
func (u *Update) encode(desc description.SelectedServer) error {
batches, err := splitBatches(u.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := u.encodeBatch(docs, desc)
if err != nil {
return err
}
u.batches = append(u.batches, cmd)
}
return nil
}
func (u *Update) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
copyDocs := make([]bsonx.Doc, 0, len(docs)) // copy of all the documents
for _, doc := range docs {
newDoc := doc.Copy()
copyDocs = append(copyDocs, newDoc)
}
var options []bsonx.Elem
for _, opt := range u.Opts {
switch opt.Key {
case "upsert", "collation", "arrayFilters":
// options that are encoded on each individual document
for idx := range copyDocs {
copyDocs[idx] = append(copyDocs[idx], opt)
}
default:
options = append(options, opt)
}
}
command, err := encodeBatch(copyDocs, options, UpdateCommand, u.NS.Collection)
if err != nil {
return nil, err
}
return &WriteBatch{
&Write{
Clock: u.Clock,
DB: u.NS.DB,
Command: command,
WriteConcern: u.WriteConcern,
Session: u.Session,
},
len(docs),
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Update {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
u.err = err
return u
}
return u.decode(desc, rdr)
}
func (u *Update) decode(desc description.SelectedServer, rdr bson.Raw) *Update {
u.err = bson.Unmarshal(rdr, &u.result)
return u
}
// Result returns the result of a decoded wire message and server description.
func (u *Update) Result() (result.Update, error) {
if u.err != nil {
return result.Update{}, u.err
}
return u.result, nil
}
// Err returns the error set on this command.
func (u *Update) Err() error { return u.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (u *Update) RoundTrip(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
) (result.Update, error) {
if u.batches == nil {
err := u.encode(desc)
if err != nil {
return result.Update{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
u.batches,
u.ContinueOnError,
u.Session,
UpdateCommand,
)
// if there are leftover batches, save them for retry
if batches != nil {
u.batches = batches
}
if err != nil {
return result.Update{}, err
}
return r.(result.Update), nil
}

252
vendor/go.mongodb.org/mongo-driver/x/network/command/write.go generated vendored Executable file
View File

@@ -0,0 +1,252 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Write represents a generic write database command.
// This can be used to send arbitrary write commands to the database.
type Write struct {
DB string
Command bsonx.Doc
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode c as OP_MSG
func (w *Write) encodeOpMsg(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
var arr bsonx.Arr
var identifier string
cmd, arr, identifier = opmsgRemoveArray(cmd)
msg := wiremessage.Msg{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
Sections: make([]wiremessage.Section, 0),
}
fullDocRdr, err := opmsgAddGlobals(cmd, w.DB, nil)
if err != nil {
return nil, err
}
// type 0 doc
msg.Sections = append(msg.Sections, wiremessage.SectionBody{
PayloadType: wiremessage.SingleDocument,
Document: fullDocRdr,
})
// type 1 doc
if identifier != "" {
docSequence, err := opmsgCreateDocSequence(arr, identifier)
if err != nil {
return nil, err
}
msg.Sections = append(msg.Sections, docSequence)
}
// flags
if !writeconcern.AckWrite(w.WriteConcern) {
msg.FlagBits |= wiremessage.MoreToCome
}
return msg, nil
}
// Encode w as OP_QUERY
func (w *Write) encodeOpQuery(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
rdr, err := marshalCommand(cmd)
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: w.DB + ".$cmd",
Flags: w.slaveOK(desc),
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
func (w *Write) slaveOK(desc description.SelectedServer) wiremessage.QueryFlag {
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
return wiremessage.SlaveOK
}
return 0
}
func (w *Write) decodeOpReply(wm wiremessage.WireMessage) {
reply, ok := wm.(wiremessage.Reply)
if !ok {
w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
w.result, w.err = decodeCommandOpReply(reply)
}
func (w *Write) decodeOpMsg(wm wiremessage.WireMessage) {
msg, ok := wm.(wiremessage.Msg)
if !ok {
w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
w.result, w.err = decodeCommandOpMsg(msg)
}
// Encode will encode this command into a wire message for the given server description.
func (w *Write) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := w.Command.Copy()
var err error
if w.Session != nil && w.Session.TransactionStarting() {
// Starting transactions have a read concern, even in writes.
cmd, err = addReadConcern(cmd, desc, nil, w.Session)
if err != nil {
return nil, err
}
}
cmd, err = addWriteConcern(cmd, w.WriteConcern)
if err != nil {
return nil, err
}
if !writeconcern.AckWrite(w.WriteConcern) {
// unack write with explicit session --> raise an error
// unack write with implicit session --> do not send session ID (implicit session shouldn't have been created
// in the first place)
if w.Session != nil && w.Session.SessionType == session.Explicit {
return nil, errors.New("explicit sessions cannot be used with unacknowledged writes")
}
} else {
// only encode session ID for acknowledged writes
cmd, err = addSessionFields(cmd, desc, w.Session)
if err != nil {
return nil, err
}
}
if w.Session != nil && w.Session.RetryWrite && cmd.IndexOf("txnNumber") == -1 {
cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(w.Session.TxnNumber)})
}
cmd = addClusterTime(cmd, desc, w.Session, w.Clock)
if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
return w.encodeOpQuery(desc, cmd)
}
return w.encodeOpMsg(desc, cmd)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (w *Write) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Write {
switch wm.(type) {
case wiremessage.Reply:
w.decodeOpReply(wm)
default:
w.decodeOpMsg(wm)
}
if w.err != nil {
cerr, ok := w.err.(Error)
if !ok {
return w
}
if cerr.HasErrorLabel(TransientTransactionError) {
w.Session.ClearPinnedServer()
}
}
_ = updateClusterTimes(w.Session, w.Clock, w.result)
w.Session.UpdateRecoveryToken(w.result)
if writeconcern.AckWrite(w.WriteConcern) {
// don't update session operation time for unacknowledged write
_ = updateOperationTime(w.Session, w.result)
}
return w
}
// Result returns the result of a decoded wire message and server description.
func (w *Write) Result() (bson.Raw, error) {
if w.err != nil {
return nil, w.err
}
return w.result, nil
}
// Err returns the error set on this command.
func (w *Write) Err() error {
return w.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriteCloser.
func (w *Write) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
wm, err := w.Encode(desc)
if err != nil {
return nil, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
w.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if msg, ok := wm.(wiremessage.Msg); ok {
// don't expect response if using OP_MSG for an unacknowledged write
if msg.FlagBits&wiremessage.MoreToCome > 0 {
return nil, ErrUnacknowledgedWrite
}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
w.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if w.Session != nil {
err = w.Session.UpdateUseTime()
if err != nil {
return nil, err
}
}
return w.Decode(desc, wm).Result()
}