newline battles continue

This commit is contained in:
bel
2020-01-19 20:41:30 +00:00
parent 98adb53caf
commit 573696774e
1456 changed files with 501133 additions and 6 deletions

View File

@@ -0,0 +1,94 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// AbortTransaction represents the abortTransaction() command
type AbortTransaction struct {
Session *session.Client
err error
result result.TransactionResult
}
// Encode will encode this command into a wiremessage for the given server description.
func (at *AbortTransaction) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := at.encode(desc)
return cmd.Encode(desc)
}
func (at *AbortTransaction) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"abortTransaction", bsonx.Int32(1)}}
if at.Session.RecoveryToken != nil {
tokenDoc, _ := bsonx.ReadDoc(at.Session.RecoveryToken)
cmd = append(cmd, bsonx.Elem{"recoveryToken", bsonx.Document(tokenDoc)})
}
return &Write{
DB: "admin",
Command: cmd,
Session: at.Session,
WriteConcern: at.Session.CurrentWc,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (at *AbortTransaction) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *AbortTransaction {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
at.err = err
return at
}
return at.decode(desc, rdr)
}
func (at *AbortTransaction) decode(desc description.SelectedServer, rdr bson.Raw) *AbortTransaction {
at.err = bson.Unmarshal(rdr, &at.result)
if at.err == nil && at.result.WriteConcernError != nil {
at.err = Error{
Name: at.result.WriteConcernError.Name,
Code: int32(at.result.WriteConcernError.Code),
Message: at.result.WriteConcernError.ErrMsg,
}
}
return at
}
// Result returns the result of a decoded wire message and server description.
func (at *AbortTransaction) Result() (result.TransactionResult, error) {
if at.err != nil {
return result.TransactionResult{}, at.err
}
return at.result, nil
}
// Err returns the error set on this command
func (at *AbortTransaction) Err() error {
return at.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (at *AbortTransaction) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.TransactionResult, error) {
cmd := at.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.TransactionResult{}, err
}
return at.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,165 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Aggregate represents the aggregate command.
//
// The aggregate command performs an aggregation.
type Aggregate struct {
NS Namespace
Pipeline bsonx.Arr
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
WriteConcern *writeconcern.WriteConcern
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (a *Aggregate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := a.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (a *Aggregate) encode(desc description.SelectedServer) (*Read, error) {
if err := a.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"aggregate", bsonx.String(a.NS.Collection)},
{"pipeline", bsonx.Array(a.Pipeline)},
}
cursor := bsonx.Doc{}
hasOutStage := a.HasDollarOut()
for _, opt := range a.Opts {
switch opt.Key {
case "batchSize":
if opt.Value.Int32() == 0 && hasOutStage {
continue
}
cursor = append(cursor, opt)
default:
command = append(command, opt)
}
}
command = append(command, bsonx.Elem{"cursor", bsonx.Document(cursor)})
// add write concern because it won't be added by the Read command's Encode()
if desc.WireVersion.Max >= 5 && hasOutStage && a.WriteConcern != nil {
t, data, err := a.WriteConcern.MarshalBSONValue()
if err != nil {
return nil, err
}
var xval bsonx.Val
err = xval.UnmarshalBSONValue(t, data)
if err != nil {
return nil, err
}
command = append(command, bsonx.Elem{Key: "writeConcern", Value: xval})
}
return &Read{
DB: a.NS.DB,
Command: command,
ReadPref: a.ReadPref,
ReadConcern: a.ReadConcern,
Clock: a.Clock,
Session: a.Session,
}, nil
}
// HasDollarOut returns true if the Pipeline field contains a $out stage.
func (a *Aggregate) HasDollarOut() bool {
if a.Pipeline == nil {
return false
}
if len(a.Pipeline) == 0 {
return false
}
val := a.Pipeline[len(a.Pipeline)-1]
doc, ok := val.DocumentOK()
if !ok || len(doc) != 1 {
return false
}
return doc[0].Key == "$out"
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (a *Aggregate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Aggregate {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
a.err = err
return a
}
return a.decode(desc, rdr)
}
func (a *Aggregate) decode(desc description.SelectedServer, rdr bson.Raw) *Aggregate {
a.result = rdr
if val, err := rdr.LookupErr("writeConcernError"); err == nil {
var wce result.WriteConcernError
_ = val.Unmarshal(&wce)
a.err = wce
}
return a
}
// Result returns the result of a decoded wire message and server description.
func (a *Aggregate) Result() (bson.Raw, error) {
if a.err != nil {
return nil, a.err
}
return a.result, nil
}
// Err returns the error set on this command.
func (a *Aggregate) Err() error { return a.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (a *Aggregate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := a.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return a.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,95 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// BuildInfo represents the buildInfo command.
//
// The buildInfo command is used for getting the build information for a
// MongoDB server.
type BuildInfo struct {
err error
res result.BuildInfo
}
// Encode will encode this command into a wire message for the given server description.
func (bi *BuildInfo) Encode() (wiremessage.WireMessage, error) {
// This can probably just be a global variable that we reuse.
cmd := bsonx.Doc{{"buildInfo", bsonx.Int32(1)}}
rdr, err := cmd.MarshalBSON()
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: "admin.$cmd",
Flags: wiremessage.SlaveOK,
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (bi *BuildInfo) Decode(wm wiremessage.WireMessage) *BuildInfo {
reply, ok := wm.(wiremessage.Reply)
if !ok {
bi.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return bi
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
bi.err = err
return bi
}
err = bson.Unmarshal(rdr, &bi.res)
if err != nil {
bi.err = err
return bi
}
return bi
}
// Result returns the result of a decoded wire message and server description.
func (bi *BuildInfo) Result() (result.BuildInfo, error) {
if bi.err != nil {
return result.BuildInfo{}, bi.err
}
return bi.res, nil
}
// Err returns the error set on this command.
func (bi *BuildInfo) Err() error { return bi.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (bi *BuildInfo) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.BuildInfo, error) {
wm, err := bi.Encode()
if err != nil {
return result.BuildInfo{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return result.BuildInfo{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return result.BuildInfo{}, err
}
return bi.Decode(wm).Result()
}

View File

@@ -0,0 +1,708 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command // import "go.mongodb.org/mongo-driver/x/network/command"
import (
"errors"
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// WriteBatch represents a single batch for a write operation.
type WriteBatch struct {
*Write
numDocs int
}
// DecodeError attempts to decode the wiremessage as an error
func DecodeError(wm wiremessage.WireMessage) error {
var rdr bson.Raw
switch msg := wm.(type) {
case wiremessage.Msg:
for _, section := range msg.Sections {
switch converted := section.(type) {
case wiremessage.SectionBody:
rdr = converted.Document
}
}
case wiremessage.Reply:
if msg.ResponseFlags&wiremessage.QueryFailure != wiremessage.QueryFailure {
return nil
}
rdr = msg.Documents[0]
}
err := rdr.Validate()
if err != nil {
return nil
}
extractedError := extractError(rdr)
// If parsed successfully return the error
if _, ok := extractedError.(Error); ok {
return err
}
return nil
}
// helper method to extract an error from a reader if there is one; first returned item is the
// error if it exists, the second holds parsing errors
func extractError(rdr bson.Raw) error {
var errmsg, codeName string
var code int32
var labels []string
elems, err := rdr.Elements()
if err != nil {
return err
}
for _, elem := range elems {
switch elem.Key() {
case "ok":
switch elem.Value().Type {
case bson.TypeInt32:
if elem.Value().Int32() == 1 {
return nil
}
case bson.TypeInt64:
if elem.Value().Int64() == 1 {
return nil
}
case bson.TypeDouble:
if elem.Value().Double() == 1 {
return nil
}
}
case "errmsg":
if str, okay := elem.Value().StringValueOK(); okay {
errmsg = str
}
case "codeName":
if str, okay := elem.Value().StringValueOK(); okay {
codeName = str
}
case "code":
if c, okay := elem.Value().Int32OK(); okay {
code = c
}
case "errorLabels":
if arr, okay := elem.Value().ArrayOK(); okay {
elems, err := arr.Elements()
if err != nil {
continue
}
for _, elem := range elems {
if str, ok := elem.Value().StringValueOK(); ok {
labels = append(labels, str)
}
}
}
}
}
if errmsg == "" {
errmsg = "command failed"
}
return Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
}
}
func responseClusterTime(response bson.Raw) bson.Raw {
clusterTime, err := response.LookupErr("$clusterTime")
if err != nil {
// $clusterTime not included by the server
return nil
}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendHeader(doc, clusterTime.Type, "$clusterTime")
doc = append(doc, clusterTime.Value...)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
return doc
}
func updateClusterTimes(sess *session.Client, clock *session.ClusterClock, response bson.Raw) error {
clusterTime := responseClusterTime(response)
if clusterTime == nil {
return nil
}
if sess != nil {
err := sess.AdvanceClusterTime(clusterTime)
if err != nil {
return err
}
}
if clock != nil {
clock.AdvanceClusterTime(clusterTime)
}
return nil
}
func updateOperationTime(sess *session.Client, response bson.Raw) error {
if sess == nil {
return nil
}
opTimeElem, err := response.LookupErr("operationTime")
if err != nil {
// operationTime not included by the server
return nil
}
t, i := opTimeElem.Timestamp()
return sess.AdvanceOperationTime(&primitive.Timestamp{
T: t,
I: i,
})
}
func marshalCommand(cmd bsonx.Doc) (bson.Raw, error) {
if cmd == nil {
return bson.Raw{5, 0, 0, 0, 0}, nil
}
return cmd.MarshalBSON()
}
// adds session related fields to a BSON doc representing a command
func addSessionFields(cmd bsonx.Doc, desc description.SelectedServer, client *session.Client) (bsonx.Doc, error) {
if client == nil || !description.SessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutes == 0 {
return cmd, nil
}
if client.Terminated {
return cmd, session.ErrSessionEnded
}
if _, err := cmd.LookupElementErr("lsid"); err != nil {
cmd = cmd.Delete("lsid")
}
cmd = append(cmd, bsonx.Elem{"lsid", bsonx.Document(client.SessionID)})
if client.TransactionRunning() ||
client.RetryingCommit {
cmd = addTransaction(cmd, client)
}
client.ApplyCommand(desc.Server) // advance the state machine based on a command executing
return cmd, nil
}
// if in a transaction, add the transaction fields
func addTransaction(cmd bsonx.Doc, client *session.Client) bsonx.Doc {
cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(client.TxnNumber)})
if client.TransactionStarting() {
// When starting transaction, always transition to the next state, even on error
cmd = append(cmd, bsonx.Elem{"startTransaction", bsonx.Boolean(true)})
}
return append(cmd, bsonx.Elem{"autocommit", bsonx.Boolean(false)})
}
func addClusterTime(cmd bsonx.Doc, desc description.SelectedServer, sess *session.Client, clock *session.ClusterClock) bsonx.Doc {
if (clock == nil && sess == nil) || !description.SessionsSupported(desc.WireVersion) {
return cmd
}
var clusterTime bson.Raw
if clock != nil {
clusterTime = clock.GetClusterTime()
}
if sess != nil {
if clusterTime == nil {
clusterTime = sess.ClusterTime
} else {
clusterTime = session.MaxClusterTime(clusterTime, sess.ClusterTime)
}
}
if clusterTime == nil {
return cmd
}
d, err := bsonx.ReadDoc(clusterTime)
if err != nil {
return cmd // broken clusterTime
}
cmd = cmd.Delete("$clusterTime")
return append(cmd, d...)
}
// add a read concern to a BSON doc representing a command
func addReadConcern(cmd bsonx.Doc, desc description.SelectedServer, rc *readconcern.ReadConcern, sess *session.Client) (bsonx.Doc, error) {
// Starting transaction's read concern overrides all others
if sess != nil && sess.TransactionStarting() && sess.CurrentRc != nil {
rc = sess.CurrentRc
}
// start transaction must append afterclustertime IF causally consistent and operation time exists
if rc == nil && sess != nil && sess.TransactionStarting() && sess.Consistent && sess.OperationTime != nil {
rc = readconcern.New()
}
if rc == nil {
return cmd, nil
}
t, data, err := rc.MarshalBSONValue()
if err != nil {
return cmd, err
}
var rcDoc bsonx.Doc
err = rcDoc.UnmarshalBSONValue(t, data)
if err != nil {
return cmd, err
}
if description.SessionsSupported(desc.WireVersion) && sess != nil && sess.Consistent && sess.OperationTime != nil {
rcDoc = append(rcDoc, bsonx.Elem{"afterClusterTime", bsonx.Timestamp(sess.OperationTime.T, sess.OperationTime.I)})
}
cmd = cmd.Delete("readConcern")
if len(rcDoc) != 0 {
cmd = append(cmd, bsonx.Elem{"readConcern", bsonx.Document(rcDoc)})
}
return cmd, nil
}
// add a write concern to a BSON doc representing a command
func addWriteConcern(cmd bsonx.Doc, wc *writeconcern.WriteConcern) (bsonx.Doc, error) {
if wc == nil {
return cmd, nil
}
t, data, err := wc.MarshalBSONValue()
if err != nil {
if err == writeconcern.ErrEmptyWriteConcern {
return cmd, nil
}
return cmd, err
}
var xval bsonx.Val
err = xval.UnmarshalBSONValue(t, data)
if err != nil {
return cmd, err
}
// delete if doc already has write concern
cmd = cmd.Delete("writeConcern")
return append(cmd, bsonx.Elem{Key: "writeConcern", Value: xval}), nil
}
// Get the error labels from a command response
func getErrorLabels(rdr *bson.Raw) ([]string, error) {
var labels []string
labelsElem, err := rdr.LookupErr("errorLabels")
if err != bsoncore.ErrElementNotFound {
return nil, err
}
if labelsElem.Type == bsontype.Array {
labelsIt, err := labelsElem.Array().Elements()
if err != nil {
return nil, err
}
for _, elem := range labelsIt {
labels = append(labels, elem.Value().StringValue())
}
}
return labels, nil
}
// Remove command arguments for insert, update, and delete commands from the BSON document so they can be encoded
// as a Section 1 payload in OP_MSG
func opmsgRemoveArray(cmd bsonx.Doc) (bsonx.Doc, bsonx.Arr, string) {
var array bsonx.Arr
var id string
keys := []string{"documents", "updates", "deletes"}
for _, key := range keys {
val, err := cmd.LookupErr(key)
if err != nil {
continue
}
array = val.Array()
cmd = cmd.Delete(key)
id = key
break
}
return cmd, array, id
}
// Add the $db and $readPreference keys to the command
// If the command has no read preference, pass nil for rpDoc
func opmsgAddGlobals(cmd bsonx.Doc, dbName string, rpDoc bsonx.Doc) (bson.Raw, error) {
cmd = append(cmd, bsonx.Elem{"$db", bsonx.String(dbName)})
if rpDoc != nil {
cmd = append(cmd, bsonx.Elem{"$readPreference", bsonx.Document(rpDoc)})
}
return cmd.MarshalBSON() // bsonx.Doc.MarshalBSON never returns an error.
}
func opmsgCreateDocSequence(arr bsonx.Arr, identifier string) (wiremessage.SectionDocumentSequence, error) {
docSequence := wiremessage.SectionDocumentSequence{
PayloadType: wiremessage.DocumentSequence,
Identifier: identifier,
Documents: make([]bson.Raw, 0, len(arr)),
}
for _, val := range arr {
d, _ := val.Document().MarshalBSON()
docSequence.Documents = append(docSequence.Documents, d)
}
docSequence.Size = int32(docSequence.PayloadLen())
return docSequence, nil
}
func splitBatches(docs []bsonx.Doc, maxCount, targetBatchSize int) ([][]bsonx.Doc, error) {
batches := [][]bsonx.Doc{}
if targetBatchSize > reservedCommandBufferBytes {
targetBatchSize -= reservedCommandBufferBytes
}
if maxCount <= 0 {
maxCount = 1
}
startAt := 0
splitInserts:
for {
size := 0
batch := []bsonx.Doc{}
assembleBatch:
for idx := startAt; idx < len(docs); idx++ {
raw, _ := docs[idx].MarshalBSON()
if len(raw) > targetBatchSize {
return nil, ErrDocumentTooLarge
}
if size+len(raw) > targetBatchSize {
break assembleBatch
}
size += len(raw)
batch = append(batch, docs[idx])
startAt++
if len(batch) == maxCount {
break assembleBatch
}
}
batches = append(batches, batch)
if startAt == len(docs) {
break splitInserts
}
}
return batches, nil
}
func encodeBatch(
docs []bsonx.Doc,
opts []bsonx.Elem,
cmdKind WriteCommandKind,
collName string,
) (bsonx.Doc, error) {
var cmdName string
var docString string
switch cmdKind {
case InsertCommand:
cmdName = "insert"
docString = "documents"
case UpdateCommand:
cmdName = "update"
docString = "updates"
case DeleteCommand:
cmdName = "delete"
docString = "deletes"
}
cmd := bsonx.Doc{{cmdName, bsonx.String(collName)}}
vals := make(bsonx.Arr, 0, len(docs))
for _, doc := range docs {
vals = append(vals, bsonx.Document(doc))
}
cmd = append(cmd, bsonx.Elem{docString, bsonx.Array(vals)})
cmd = append(cmd, opts...)
return cmd, nil
}
// converts batches of Write Commands to wire messages
func batchesToWireMessage(batches []*WriteBatch, desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
wms := make([]wiremessage.WireMessage, len(batches))
for _, cmd := range batches {
wm, err := cmd.Encode(desc)
if err != nil {
return nil, err
}
wms = append(wms, wm)
}
return wms, nil
}
// Roundtrips the write batches, returning the result structs (as interface),
// the write batches that weren't round tripped and any errors
func roundTripBatches(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
batches []*WriteBatch,
continueOnError bool,
sess *session.Client,
cmdKind WriteCommandKind,
) (interface{}, []*WriteBatch, error) {
var res interface{}
var upsertIndex int64 // the operation index for the upserted IDs map
// hold onto txnNumber, reset it when loop exits to ensure reuse of same
// transaction number if retry is needed
var txnNumber int64
if sess != nil && sess.RetryWrite {
txnNumber = sess.TxnNumber
}
for j, cmd := range batches {
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber + int64(j)
}
return res, batches, err
}
// TODO can probably DRY up this code
switch cmdKind {
case InsertCommand:
if res == nil {
res = result.Insert{}
}
conv, _ := res.(result.Insert)
insertCmd := &Insert{}
r, err := insertCmd.decode(desc, rdr).Result()
if err != nil {
return res, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
case UpdateCommand:
if res == nil {
res = result.Update{}
}
conv, _ := res.(result.Update)
updateCmd := &Update{}
r, err := updateCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.MatchedCount += r.MatchedCount
conv.ModifiedCount += r.ModifiedCount
for _, upsert := range r.Upserted {
conv.Upserted = append(conv.Upserted, result.Upsert{
Index: upsert.Index + upsertIndex,
ID: upsert.ID,
})
}
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
upsertIndex += int64(cmd.numDocs)
case DeleteCommand:
if res == nil {
res = result.Delete{}
}
conv, _ := res.(result.Delete)
deleteCmd := &Delete{}
r, err := deleteCmd.decode(desc, rdr).Result()
if err != nil {
return conv, batches, err
}
conv.WriteErrors = append(conv.WriteErrors, r.WriteErrors...)
if r.WriteConcernError != nil {
conv.WriteConcernError = r.WriteConcernError
if sess != nil && sess.RetryWrite {
sess.TxnNumber = txnNumber
return conv, batches, nil // report writeconcernerror for retry
}
}
conv.N += r.N
if !continueOnError && len(conv.WriteErrors) > 0 {
return conv, batches, nil
}
res = conv
}
// Increment txnNumber for each batch
if sess != nil && sess.RetryWrite {
sess.IncrementTxnNumber()
batches = batches[1:] // if batch encoded successfully, remove it from the slice
}
}
if sess != nil && sess.RetryWrite {
// if retryable write succeeded, transaction number will be incremented one extra time,
// so we decrement it here
sess.TxnNumber--
}
return res, batches, nil
}
// get the firstBatch, cursor ID, and namespace from a bson.Raw
func getCursorValues(result bson.Raw) ([]bson.RawValue, Namespace, int64, error) {
cur, err := result.LookupErr("cursor")
if err != nil {
return nil, Namespace{}, 0, err
}
if cur.Type != bson.TypeEmbeddedDocument {
return nil, Namespace{}, 0, fmt.Errorf("cursor should be an embedded document but it is a BSON %s", cur.Type)
}
elems, err := cur.Document().Elements()
if err != nil {
return nil, Namespace{}, 0, err
}
var ok bool
var arr bson.Raw
var namespace Namespace
var cursorID int64
for _, elem := range elems {
switch elem.Key() {
case "firstBatch":
arr, ok = elem.Value().ArrayOK()
if !ok {
return nil, Namespace{}, 0, fmt.Errorf("firstBatch should be an array but it is a BSON %s", elem.Value().Type)
}
if err != nil {
return nil, Namespace{}, 0, err
}
case "ns":
if elem.Value().Type != bson.TypeString {
return nil, Namespace{}, 0, fmt.Errorf("namespace should be a string but it is a BSON %s", elem.Value().Type)
}
namespace = ParseNamespace(elem.Value().StringValue())
err = namespace.Validate()
if err != nil {
return nil, Namespace{}, 0, err
}
case "id":
cursorID, ok = elem.Value().Int64OK()
if !ok {
return nil, Namespace{}, 0, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
}
}
}
vals, err := arr.Values()
if err != nil {
return nil, Namespace{}, 0, err
}
return vals, namespace, cursorID, nil
}
func getBatchSize(opts []bsonx.Elem) int32 {
for _, opt := range opts {
if opt.Key == "batchSize" {
return opt.Value.Int32()
}
}
return 0
}
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern.
var ErrUnacknowledgedWrite = errors.New("unacknowledged write")
// WriteCommandKind is the type of command represented by a Write
type WriteCommandKind int8
// These constants represent the valid types of write commands.
const (
InsertCommand WriteCommandKind = iota
UpdateCommand
DeleteCommand
)

View File

@@ -0,0 +1,94 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CommitTransaction represents the commitTransaction() command
type CommitTransaction struct {
Session *session.Client
err error
result result.TransactionResult
}
// Encode will encode this command into a wiremessage for the given server description.
func (ct *CommitTransaction) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := ct.encode(desc)
return cmd.Encode(desc)
}
func (ct *CommitTransaction) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"commitTransaction", bsonx.Int32(1)}}
if ct.Session.RecoveryToken != nil {
tokenDoc, _ := bsonx.ReadDoc(ct.Session.RecoveryToken)
cmd = append(cmd, bsonx.Elem{"recoveryToken", bsonx.Document(tokenDoc)})
}
return &Write{
DB: "admin",
Command: cmd,
Session: ct.Session,
WriteConcern: ct.Session.CurrentWc,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (ct *CommitTransaction) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *CommitTransaction {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ct.err = err
return ct
}
return ct.decode(desc, rdr)
}
func (ct *CommitTransaction) decode(desc description.SelectedServer, rdr bson.Raw) *CommitTransaction {
ct.err = bson.Unmarshal(rdr, &ct.result)
if ct.err == nil && ct.result.WriteConcernError != nil {
ct.err = Error{
Name: ct.result.WriteConcernError.Name,
Code: int32(ct.result.WriteConcernError.Code),
Message: ct.result.WriteConcernError.ErrMsg,
}
}
return ct
}
// Result returns the result of a decoded wire message and server description.
func (ct *CommitTransaction) Result() (result.TransactionResult, error) {
if ct.err != nil {
return result.TransactionResult{}, ct.err
}
return ct.result, nil
}
// Err returns the error set on this command
func (ct *CommitTransaction) Err() error {
return ct.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (ct *CommitTransaction) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.TransactionResult, error) {
cmd := ct.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.TransactionResult{}, err
}
return ct.decode(desc, rdr).Result()
}

128
vendor/go.mongodb.org/mongo-driver/x/network/command/count.go generated vendored Executable file
View File

@@ -0,0 +1,128 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Count represents the count command.
//
// The count command counts how many documents in a collection match the given query.
type Count struct {
NS Namespace
Query bsonx.Doc
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result int64
err error
}
// Encode will encode this command into a wire message for the given server description.
func (c *Count) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := c.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (c *Count) encode(desc description.SelectedServer) (*Read, error) {
if err := c.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"count", bsonx.String(c.NS.Collection)}, {"query", bsonx.Document(c.Query)}}
command = append(command, c.Opts...)
return &Read{
Clock: c.Clock,
DB: c.NS.DB,
ReadPref: c.ReadPref,
Command: command,
ReadConcern: c.ReadConcern,
Session: c.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (c *Count) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Count {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
c.err = err
return c
}
return c.decode(desc, rdr)
}
func (c *Count) decode(desc description.SelectedServer, rdr bson.Raw) *Count {
val, err := rdr.LookupErr("n")
switch {
case err == bsoncore.ErrElementNotFound:
c.err = errors.New("invalid response from server, no 'n' field")
return c
case err != nil:
c.err = err
return c
}
switch val.Type {
case bson.TypeInt32:
c.result = int64(val.Int32())
case bson.TypeInt64:
c.result = val.Int64()
case bson.TypeDouble:
c.result = int64(val.Double())
default:
c.err = errors.New("invalid response from server, value field is not a number")
}
return c
}
// Result returns the result of a decoded wire message and server description.
func (c *Count) Result() (int64, error) {
if c.err != nil {
return 0, c.err
}
return c.result, nil
}
// Err returns the error set on this command.
func (c *Count) Err() error { return c.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (c *Count) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
cmd, err := c.encode(desc)
if err != nil {
return 0, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return 0, err
}
return c.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,134 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CountDocuments represents the CountDocuments command.
//
// The countDocuments command counts how many documents in a collection match the given query.
type CountDocuments struct {
NS Namespace
Pipeline bsonx.Arr
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result int64
err error
}
// Encode will encode this command into a wire message for the given server description.
func (c *CountDocuments) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
if err := c.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"aggregate", bsonx.String(c.NS.Collection)}, {"pipeline", bsonx.Array(c.Pipeline)}}
command = append(command, bsonx.Elem{"cursor", bsonx.Document(bsonx.Doc{})})
command = append(command, c.Opts...)
return (&Read{DB: c.NS.DB, ReadPref: c.ReadPref, Command: command, Session: c.Session}).Encode(desc)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (c *CountDocuments) Decode(ctx context.Context, desc description.SelectedServer, wm wiremessage.WireMessage) *CountDocuments {
rdr, err := (&Read{Session: c.Session}).Decode(desc, wm).Result()
if err != nil {
c.err = err
return c
}
cursor, err := rdr.LookupErr("cursor")
if err != nil || cursor.Type != bsontype.EmbeddedDocument {
c.err = errors.New("Invalid response from server, no 'cursor' field")
return c
}
batch, err := cursor.Document().LookupErr("firstBatch")
if err != nil || batch.Type != bsontype.Array {
c.err = errors.New("Invalid response from server, no 'firstBatch' field")
return c
}
elem, err := batch.Array().IndexErr(0)
if err != nil || elem.Value().Type != bsontype.EmbeddedDocument {
c.result = 0
return c
}
val, err := elem.Value().Document().LookupErr("n")
if err != nil {
c.err = errors.New("Invalid response from server, no 'n' field")
return c
}
switch val.Type {
case bsontype.Int32:
c.result = int64(val.Int32())
case bsontype.Int64:
c.result = val.Int64()
default:
c.err = errors.New("Invalid response from server, value field is not a number")
}
return c
}
// Result returns the result of a decoded wire message and server description.
func (c *CountDocuments) Result() (int64, error) {
if c.err != nil {
return 0, c.err
}
return c.result, nil
}
// Err returns the error set on this command.
func (c *CountDocuments) Err() error { return c.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (c *CountDocuments) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (int64, error) {
wm, err := c.Encode(desc)
if err != nil {
return 0, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return 0, err
}
// Connection errors are transient
c.Session.ClearPinnedServer()
return 0, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return 0, err
}
// Connection errors are transient
c.Session.ClearPinnedServer()
return 0, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
return c.Decode(ctx, desc, wm).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// CreateIndexes represents the createIndexes command.
//
// The createIndexes command creates indexes for a namespace.
type CreateIndexes struct {
NS Namespace
Indexes bsonx.Arr
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.CreateIndexes
err error
}
// Encode will encode this command into a wire message for the given server description.
func (ci *CreateIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := ci.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (ci *CreateIndexes) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{
{"createIndexes", bsonx.String(ci.NS.Collection)},
{"indexes", bsonx.Array(ci.Indexes)},
}
cmd = append(cmd, ci.Opts...)
write := &Write{
Clock: ci.Clock,
DB: ci.NS.DB,
Command: cmd,
Session: ci.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = ci.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (ci *CreateIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *CreateIndexes {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ci.err = err
return ci
}
return ci.decode(desc, rdr)
}
func (ci *CreateIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *CreateIndexes {
ci.err = bson.Unmarshal(rdr, &ci.result)
return ci
}
// Result returns the result of a decoded wire message and server description.
func (ci *CreateIndexes) Result() (result.CreateIndexes, error) {
if ci.err != nil {
return result.CreateIndexes{}, ci.err
}
return ci.result, nil
}
// Err returns the error set on this command.
func (ci *CreateIndexes) Err() error { return ci.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (ci *CreateIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.CreateIndexes, error) {
cmd, err := ci.encode(desc)
if err != nil {
return result.CreateIndexes{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.CreateIndexes{}, err
}
return ci.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,154 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Delete represents the delete command.
//
// The delete command executes a delete with a given set of delete documents
// and options.
type Delete struct {
ContinueOnError bool
NS Namespace
Deletes []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
batches []*WriteBatch
result result.Delete
err error
}
// Encode will encode this command into a wire message for the given server description.
func (d *Delete) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := d.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(d.batches, desc)
}
func (d *Delete) encode(desc description.SelectedServer) error {
batches, err := splitBatches(d.Deletes, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := d.encodeBatch(docs, desc)
if err != nil {
return err
}
d.batches = append(d.batches, cmd)
}
return nil
}
func (d *Delete) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
copyDocs := make([]bsonx.Doc, 0, len(docs))
for _, doc := range docs {
copyDocs = append(copyDocs, doc.Copy())
}
var options []bsonx.Elem
for _, opt := range d.Opts {
if opt.Key == "collation" {
for idx := range copyDocs {
copyDocs[idx] = append(copyDocs[idx], opt)
}
} else {
options = append(options, opt)
}
}
command, err := encodeBatch(copyDocs, options, DeleteCommand, d.NS.Collection)
if err != nil {
return nil, err
}
return &WriteBatch{
&Write{
Clock: d.Clock,
DB: d.NS.DB,
Command: command,
WriteConcern: d.WriteConcern,
Session: d.Session,
},
len(docs),
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (d *Delete) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Delete {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
d.err = err
return d
}
return d.decode(desc, rdr)
}
func (d *Delete) decode(desc description.SelectedServer, rdr bson.Raw) *Delete {
d.err = bson.Unmarshal(rdr, &d.result)
return d
}
// Result returns the result of a decoded wire message and server description.
func (d *Delete) Result() (result.Delete, error) {
if d.err != nil {
return result.Delete{}, d.err
}
return d.result, nil
}
// Err returns the error set on this command.
func (d *Delete) Err() error { return d.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (d *Delete) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Delete, error) {
if d.batches == nil {
if err := d.encode(desc); err != nil {
return result.Delete{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
d.batches,
d.ContinueOnError,
d.Session,
DeleteCommand,
)
if batches != nil {
d.batches = batches
}
if err != nil {
return result.Delete{}, err
}
return r.(result.Delete), nil
}

View File

@@ -0,0 +1,115 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Distinct represents the disctinct command.
//
// The distinct command returns the distinct values for a specified field
// across a single collection.
type Distinct struct {
NS Namespace
Field string
Query bsonx.Doc
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result result.Distinct
err error
}
// Encode will encode this command into a wire message for the given server description.
func (d *Distinct) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := d.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
// Encode will encode this command into a wire message for the given server description.
func (d *Distinct) encode(desc description.SelectedServer) (*Read, error) {
if err := d.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"distinct", bsonx.String(d.NS.Collection)}, {"key", bsonx.String(d.Field)}}
if d.Query != nil {
command = append(command, bsonx.Elem{"query", bsonx.Document(d.Query)})
}
command = append(command, d.Opts...)
return &Read{
Clock: d.Clock,
DB: d.NS.DB,
ReadPref: d.ReadPref,
Command: command,
ReadConcern: d.ReadConcern,
Session: d.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (d *Distinct) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Distinct {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
d.err = err
return d
}
return d.decode(desc, rdr)
}
func (d *Distinct) decode(desc description.SelectedServer, rdr bson.Raw) *Distinct {
d.err = bson.Unmarshal(rdr, &d.result)
return d
}
// Result returns the result of a decoded wire message and server description.
func (d *Distinct) Result() (result.Distinct, error) {
if d.err != nil {
return result.Distinct{}, d.err
}
return d.result, nil
}
// Err returns the error set on this command.
func (d *Distinct) Err() error { return d.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (d *Distinct) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Distinct, error) {
cmd, err := d.encode(desc)
if err != nil {
return result.Distinct{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.Distinct{}, err
}
return d.decode(desc, rdr).Result()
}

16
vendor/go.mongodb.org/mongo-driver/x/network/command/doc.go generated vendored Executable file
View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package command contains abstractions for operations that can be performed against a MongoDB
// deployment. The types in this package are meant to provide a general set of commands that a
// user can run against a MongoDB database without knowing the version of the database.
//
// Each type consists of two levels of interaction. The lowest level are the Encode and Decode
// methods. These are meant to be symmetric eventually, but currently only support the driver
// side of commands. The higher level is the RoundTrip method. This only makes sense from the
// driver side of commands and this method handles the encoding of the request and decoding of
// the response using the given wiremessage.ReadWriter.
package command

View File

@@ -0,0 +1,101 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropCollection represents the drop command.
//
// The dropCollections command drops collection for a database.
type DropCollection struct {
DB string
Collection string
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (dc *DropCollection) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := dc.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (dc *DropCollection) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{{"drop", bsonx.String(dc.Collection)}}
write := &Write{
Clock: dc.Clock,
DB: dc.DB,
Command: cmd,
Session: dc.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = dc.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (dc *DropCollection) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropCollection {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
dc.err = err
return dc
}
return dc.decode(desc, rdr)
}
func (dc *DropCollection) decode(desc description.SelectedServer, rdr bson.Raw) *DropCollection {
dc.result = rdr
return dc
}
// Result returns the result of a decoded wire message and server description.
func (dc *DropCollection) Result() (bson.Raw, error) {
if dc.err != nil {
return nil, dc.err
}
return dc.result, nil
}
// Err returns the error set on this command.
func (dc *DropCollection) Err() error { return dc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (dc *DropCollection) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := dc.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return dc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,100 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropDatabase represents the DropDatabase command.
//
// The DropDatabases command drops database.
type DropDatabase struct {
DB string
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (dd *DropDatabase) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := dd.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (dd *DropDatabase) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{{"dropDatabase", bsonx.Int32(1)}}
write := &Write{
Clock: dd.Clock,
DB: dd.DB,
Command: cmd,
Session: dd.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = dd.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (dd *DropDatabase) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropDatabase {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
dd.err = err
return dd
}
return dd.decode(desc, rdr)
}
func (dd *DropDatabase) decode(desc description.SelectedServer, rdr bson.Raw) *DropDatabase {
dd.result = rdr
return dd
}
// Result returns the result of a decoded wire message and server description.
func (dd *DropDatabase) Result() (bson.Raw, error) {
if dd.err != nil {
return nil, dd.err
}
return dd.result, nil
}
// Err returns the error set on this command.
func (dd *DropDatabase) Err() error { return dd.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (dd *DropDatabase) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := dd.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return dd.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// DropIndexes represents the dropIndexes command.
//
// The dropIndexes command drops indexes for a namespace.
type DropIndexes struct {
NS Namespace
Index string
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (di *DropIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := di.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (di *DropIndexes) encode(desc description.SelectedServer) (*Write, error) {
cmd := bsonx.Doc{
{"dropIndexes", bsonx.String(di.NS.Collection)},
{"index", bsonx.String(di.Index)},
}
cmd = append(cmd, di.Opts...)
write := &Write{
Clock: di.Clock,
DB: di.NS.DB,
Command: cmd,
Session: di.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 5 {
write.WriteConcern = di.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (di *DropIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *DropIndexes {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
di.err = err
return di
}
return di.decode(desc, rdr)
}
func (di *DropIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *DropIndexes {
di.result = rdr
return di
}
// Result returns the result of a decoded wire message and server description.
func (di *DropIndexes) Result() (bson.Raw, error) {
if di.err != nil {
return nil, di.err
}
return di.result, nil
}
// Err returns the error set on this command.
func (di *DropIndexes) Err() error { return di.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (di *DropIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := di.encode(desc)
if err != nil {
return nil, err
}
di.result, err = cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return di.Result()
}

View File

@@ -0,0 +1,138 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// must be sent to admin db
// { endSessions: [ {id: uuid}, ... ], $clusterTime: ... }
// only send $clusterTime when gossiping the cluster time
// send 10k sessions at a time
// EndSessions represents an endSessions command.
type EndSessions struct {
Clock *session.ClusterClock
SessionIDs []bsonx.Doc
results []result.EndSessions
errors []error
}
// BatchSize is the max number of sessions to be included in 1 endSessions command.
const BatchSize = 10000
func (es *EndSessions) split() [][]bsonx.Doc {
batches := [][]bsonx.Doc{}
docIndex := 0
totalNumDocs := len(es.SessionIDs)
createBatches:
for {
batch := []bsonx.Doc{}
for i := 0; i < BatchSize; i++ {
if docIndex == totalNumDocs {
break createBatches
}
batch = append(batch, es.SessionIDs[docIndex])
docIndex++
}
batches = append(batches, batch)
}
return batches
}
func (es *EndSessions) encodeBatch(batch []bsonx.Doc, desc description.SelectedServer) *Write {
vals := make(bsonx.Arr, 0, len(batch))
for _, doc := range batch {
vals = append(vals, bsonx.Document(doc))
}
cmd := bsonx.Doc{{"endSessions", bsonx.Array(vals)}}
return &Write{
Clock: es.Clock,
DB: "admin",
Command: cmd,
}
}
// Encode will encode this command into a series of wire messages for the given server description.
func (es *EndSessions) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
cmds := es.encode(desc)
wms := make([]wiremessage.WireMessage, len(cmds))
for _, cmd := range cmds {
wm, err := cmd.Encode(desc)
if err != nil {
return nil, err
}
wms = append(wms, wm)
}
return wms, nil
}
func (es *EndSessions) encode(desc description.SelectedServer) []*Write {
out := []*Write{}
batches := es.split()
for _, batch := range batches {
out = append(out, es.encodeBatch(batch, desc))
}
return out
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (es *EndSessions) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *EndSessions {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
es.errors = append(es.errors, err)
return es
}
return es.decode(desc, rdr)
}
func (es *EndSessions) decode(desc description.SelectedServer, rdr bson.Raw) *EndSessions {
var res result.EndSessions
es.errors = append(es.errors, bson.Unmarshal(rdr, &res))
es.results = append(es.results, res)
return es
}
// Result returns the results of the decoded wire messages.
func (es *EndSessions) Result() ([]result.EndSessions, []error) {
return es.results, es.errors
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (es *EndSessions) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) ([]result.EndSessions, []error) {
cmds := es.encode(desc)
for _, cmd := range cmds {
rdr, _ := cmd.RoundTrip(ctx, desc, rw) // ignore any errors returned by the command
es.decode(desc, rdr)
}
return es.Result()
}

View File

@@ -0,0 +1,141 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/network/result"
)
var (
// ErrUnknownCommandFailure occurs when a command fails for an unknown reason.
ErrUnknownCommandFailure = errors.New("unknown command failure")
// ErrNoCommandResponse occurs when the server sent no response document to a command.
ErrNoCommandResponse = errors.New("no command response document")
// ErrMultiDocCommandResponse occurs when the server sent multiple documents in response to a command.
ErrMultiDocCommandResponse = errors.New("command returned multiple documents")
// ErrNoDocCommandResponse occurs when the server indicated a response existed, but none was found.
ErrNoDocCommandResponse = errors.New("command returned no documents")
// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a
// server is passed to an insert command.
ErrDocumentTooLarge = errors.New("an inserted document is too large")
// ErrNonPrimaryRP occurs when a nonprimary read preference is used with a transaction.
ErrNonPrimaryRP = errors.New("read preference in a transaction must be primary")
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
// TransientTransactionError is an error label for transient errors with transactions.
TransientTransactionError = "TransientTransactionError"
// NetworkError is an error label for network errors.
NetworkError = "NetworkError"
// ReplyDocumentMismatch is an error label for OP_QUERY field mismatch errors.
ReplyDocumentMismatch = "malformed OP_REPLY: NumberReturned does not match number of documents returned"
)
var retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001}
// QueryFailureError is an error representing a command failure as a document.
type QueryFailureError struct {
Message string
Response bson.Raw
}
// Error implements the error interface.
func (e QueryFailureError) Error() string {
return fmt.Sprintf("%s: %v", e.Message, e.Response)
}
// ResponseError is an error parsing the response to a command.
type ResponseError struct {
Message string
Wrapped error
}
// NewCommandResponseError creates a CommandResponseError.
func NewCommandResponseError(msg string, err error) ResponseError {
return ResponseError{Message: msg, Wrapped: err}
}
// Error implements the error interface.
func (e ResponseError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
}
return fmt.Sprintf("%s", e.Message)
}
// Error is a command execution error from the database.
type Error struct {
Code int32
Message string
Labels []string
Name string
}
// Error implements the error interface.
func (e Error) Error() string {
if e.Name != "" {
return fmt.Sprintf("(%v) %v", e.Name, e.Message)
}
return e.Message
}
// HasErrorLabel returns true if the error contains the specified label.
func (e Error) HasErrorLabel(label string) bool {
if e.Labels != nil {
for _, l := range e.Labels {
if l == label {
return true
}
}
}
return false
}
// Retryable returns true if the error is retryable
func (e Error) Retryable() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
if strings.Contains(e.Message, "not master") || strings.Contains(e.Message, "node is recovering") {
return true
}
return false
}
// IsWriteConcernErrorRetryable returns true if the write concern error is retryable.
func IsWriteConcernErrorRetryable(wce *result.WriteConcernError) bool {
for _, code := range retryableCodes {
if int32(wce.Code) == code {
return true
}
}
if strings.Contains(wce.ErrMsg, "not master") || strings.Contains(wce.ErrMsg, "node is recovering") {
return true
}
return false
}
// IsNotFound indicates if the error is from a namespace not being found.
func IsNotFound(err error) bool {
e, ok := err.(Error)
// need message check because legacy servers don't include the error code
return ok && (e.Code == 26 || e.Message == "ns not found")
}

113
vendor/go.mongodb.org/mongo-driver/x/network/command/find.go generated vendored Executable file
View File

@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Find represents the find command.
//
// The find command finds documents within a collection that match a filter.
type Find struct {
NS Namespace
Filter bsonx.Doc
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *Find) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *Find) encode(desc description.SelectedServer) (*Read, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{{"find", bsonx.String(f.NS.Collection)}}
if f.Filter != nil {
command = append(command, bsonx.Elem{"filter", bsonx.Document(f.Filter)})
}
command = append(command, f.Opts...)
return &Read{
Clock: f.Clock,
DB: f.NS.DB,
ReadPref: f.ReadPref,
Command: command,
ReadConcern: f.ReadConcern,
Session: f.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *Find) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Find {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *Find) decode(desc description.SelectedServer, rdr bson.Raw) *Find {
f.result = rdr
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *Find) Result() (bson.Raw, error) {
if f.err != nil {
return nil, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *Find) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *Find) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,54 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/result"
)
// unmarshalFindAndModifyResult turns the provided bson.Reader into a findAndModify result.
func unmarshalFindAndModifyResult(rdr bson.Raw) (result.FindAndModify, error) {
var res result.FindAndModify
val, err := rdr.LookupErr("value")
switch {
case err == bsoncore.ErrElementNotFound:
return result.FindAndModify{}, errors.New("invalid response from server, no value field")
case err != nil:
return result.FindAndModify{}, err
}
switch val.Type {
case bson.TypeNull:
case bson.TypeEmbeddedDocument:
res.Value = val.Document()
default:
return result.FindAndModify{}, errors.New("invalid response from server, 'value' field is not a document")
}
if val, err := rdr.LookupErr("lastErrorObject", "updatedExisting"); err == nil {
b, ok := val.BooleanOK()
if ok {
res.LastErrorObject.UpdatedExisting = b
}
}
if val, err := rdr.LookupErr("lastErrorObject", "upserted"); err == nil {
oid, ok := val.ObjectIDOK()
if ok {
res.LastErrorObject.Upserted = oid
}
}
if val, err := rdr.LookupErr("writeConcernError"); err == nil {
_ = val.Unmarshal(&res.WriteConcernError)
}
return res, nil
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndDelete represents the findOneAndDelete operation.
//
// The findOneAndDelete command deletes a single document that matches a query and returns it.
type FindOneAndDelete struct {
NS Namespace
Query bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndDelete) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndDelete) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"remove", bsonx.Boolean(true)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndDelete) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndDelete {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndDelete) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndDelete {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndDelete) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndDelete) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndDelete) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndReplace represents the findOneAndReplace operation.
//
// The findOneAndReplace command modifies and returns a single document.
type FindOneAndReplace struct {
NS Namespace
Query bsonx.Doc
Replacement bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndReplace) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndReplace) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"update", bsonx.Document(f.Replacement)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndReplace) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndReplace {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndReplace) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndReplace {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndReplace) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndReplace) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndReplace) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// FindOneAndUpdate represents the findOneAndUpdate operation.
//
// The findOneAndUpdate command modifies and returns a single document.
type FindOneAndUpdate struct {
NS Namespace
Query bsonx.Doc
Update bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result result.FindAndModify
err error
}
// Encode will encode this command into a wire message for the given server description.
func (f *FindOneAndUpdate) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := f.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (f *FindOneAndUpdate) encode(desc description.SelectedServer) (*Write, error) {
if err := f.NS.Validate(); err != nil {
return nil, err
}
command := bsonx.Doc{
{"findAndModify", bsonx.String(f.NS.Collection)},
{"query", bsonx.Document(f.Query)},
{"update", bsonx.Document(f.Update)},
}
command = append(command, f.Opts...)
write := &Write{
Clock: f.Clock,
DB: f.NS.DB,
Command: command,
Session: f.Session,
}
if desc.WireVersion != nil && desc.WireVersion.Max >= 4 {
write.WriteConcern = f.WriteConcern
}
return write, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (f *FindOneAndUpdate) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *FindOneAndUpdate {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
f.err = err
return f
}
return f.decode(desc, rdr)
}
func (f *FindOneAndUpdate) decode(desc description.SelectedServer, rdr bson.Raw) *FindOneAndUpdate {
f.result, f.err = unmarshalFindAndModifyResult(rdr)
return f
}
// Result returns the result of a decoded wire message and server description.
func (f *FindOneAndUpdate) Result() (result.FindAndModify, error) {
if f.err != nil {
return result.FindAndModify{}, f.err
}
return f.result, nil
}
// Err returns the error set on this command.
func (f *FindOneAndUpdate) Err() error { return f.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (f *FindOneAndUpdate) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.FindAndModify, error) {
cmd, err := f.encode(desc)
if err != nil {
return result.FindAndModify{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.FindAndModify{}, err
}
return f.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,108 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// GetMore represents the getMore command.
//
// The getMore command retrieves additional documents from a cursor.
type GetMore struct {
ID int64
NS Namespace
Opts []bsonx.Elem
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (gm *GetMore) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd, err := gm.encode(desc)
if err != nil {
return nil, err
}
return cmd.Encode(desc)
}
func (gm *GetMore) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{
{"getMore", bsonx.Int64(gm.ID)},
{"collection", bsonx.String(gm.NS.Collection)},
}
for _, opt := range gm.Opts {
switch opt.Key {
case "maxAwaitTimeMS":
cmd = append(cmd, bsonx.Elem{"maxTimeMs", opt.Value})
default:
cmd = append(cmd, opt)
}
}
return &Read{
Clock: gm.Clock,
DB: gm.NS.DB,
Command: cmd,
Session: gm.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (gm *GetMore) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *GetMore {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
gm.err = err
return gm
}
return gm.decode(desc, rdr)
}
func (gm *GetMore) decode(desc description.SelectedServer, rdr bson.Raw) *GetMore {
gm.result = rdr
return gm
}
// Result returns the result of a decoded wire message and server description.
func (gm *GetMore) Result() (bson.Raw, error) {
if gm.err != nil {
return nil, gm.err
}
return gm.result, nil
}
// Err returns the error set on this command.
func (gm *GetMore) Err() error { return gm.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (gm *GetMore) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := gm.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return gm.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// GetLastError represents the getLastError command.
//
// The getLastError command is used for getting the last
// error from the last command on a connection.
//
// Since GetLastError only makes sense in the context of
// a single connection, there is no Dispatch method.
type GetLastError struct {
Clock *session.ClusterClock
Session *session.Client
err error
res result.GetLastError
}
// Encode will encode this command into a wire message for the given server description.
func (gle *GetLastError) Encode() (wiremessage.WireMessage, error) {
encoded, err := gle.encode()
if err != nil {
return nil, err
}
return encoded.Encode(description.SelectedServer{})
}
func (gle *GetLastError) encode() (*Read, error) {
// This can probably just be a global variable that we reuse.
cmd := bsonx.Doc{{"getLastError", bsonx.Int32(1)}}
return &Read{
Clock: gle.Clock,
DB: "admin",
ReadPref: readpref.Secondary(),
Session: gle.Session,
Command: cmd,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (gle *GetLastError) Decode(wm wiremessage.WireMessage) *GetLastError {
reply, ok := wm.(wiremessage.Reply)
if !ok {
gle.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return gle
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
gle.err = err
return gle
}
return gle.decode(rdr)
}
func (gle *GetLastError) decode(rdr bson.Raw) *GetLastError {
err := bson.Unmarshal(rdr, &gle.res)
if err != nil {
gle.err = err
return gle
}
return gle
}
// Result returns the result of a decoded wire message and server description.
func (gle *GetLastError) Result() (result.GetLastError, error) {
if gle.err != nil {
return result.GetLastError{}, gle.err
}
return gle.res, nil
}
// Err returns the error set on this command.
func (gle *GetLastError) Err() error { return gle.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (gle *GetLastError) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.GetLastError, error) {
cmd, err := gle.encode()
if err != nil {
return result.GetLastError{}, err
}
rdr, err := cmd.RoundTrip(ctx, description.SelectedServer{}, rw)
if err != nil {
return result.GetLastError{}, err
}
return gle.decode(rdr).Result()
}

View File

@@ -0,0 +1,117 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"runtime"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Handshake represents a generic MongoDB Handshake. It calls isMaster and
// buildInfo.
//
// The isMaster and buildInfo commands are used to build a server description.
type Handshake struct {
Client bsonx.Doc
Compressors []string
SaslSupportedMechs string
ismstr result.IsMaster
err error
}
// Encode will encode the handshake commands into a wire message containing isMaster
func (h *Handshake) Encode() (wiremessage.WireMessage, error) {
var wm wiremessage.WireMessage
ismstr, err := (&IsMaster{
Client: h.Client,
Compressors: h.Compressors,
SaslSupportedMechs: h.SaslSupportedMechs,
}).Encode()
if err != nil {
return wm, err
}
wm = ismstr
return wm, nil
}
// Decode will decode the wire messages.
// Errors during decoding are deferred until either the Result or Err methods
// are called.
func (h *Handshake) Decode(wm wiremessage.WireMessage) *Handshake {
h.ismstr, h.err = (&IsMaster{}).Decode(wm).Result()
if h.err != nil {
return h
}
return h
}
// Result returns the result of decoded wire messages.
func (h *Handshake) Result(addr address.Address) (description.Server, error) {
if h.err != nil {
return description.Server{}, h.err
}
return description.NewServer(addr, h.ismstr), nil
}
// Err returns the error set on this Handshake.
func (h *Handshake) Err() error { return h.err }
// Handshake implements the connection.Handshaker interface. It is identical
// to the RoundTrip methods on other types in this package. It will execute
// the isMaster command.
func (h *Handshake) Handshake(ctx context.Context, addr address.Address, rw wiremessage.ReadWriter) (description.Server, error) {
wm, err := h.Encode()
if err != nil {
return description.Server{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return description.Server{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return description.Server{}, err
}
return h.Decode(wm).Result(addr)
}
// ClientDoc creates a client information document for use in an isMaster
// command.
func ClientDoc(app string) bsonx.Doc {
doc := bsonx.Doc{
{"driver",
bsonx.Document(bsonx.Doc{
{"name", bsonx.String("mongo-go-driver")},
{"version", bsonx.String(version.Driver)},
}),
},
{"os",
bsonx.Document(bsonx.Doc{
{"type", bsonx.String(runtime.GOOS)},
{"architecture", bsonx.String(runtime.GOARCH)},
}),
},
{"platform", bsonx.String(runtime.Version())},
}
if app != "" {
doc = append(doc, bsonx.Elem{"application", bsonx.Document(bsonx.Doc{{"name", bsonx.String(app)}})})
}
return doc
}

View File

@@ -0,0 +1,158 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// this is the amount of reserved buffer space in a message that the
// driver reserves for command overhead.
const reservedCommandBufferBytes = 16 * 10 * 10 * 10
// Insert represents the insert command.
//
// The insert command inserts a set of documents into the database.
//
// Since the Insert command does not return any value other than ok or
// an error, this type has no Err method.
type Insert struct {
ContinueOnError bool
Clock *session.ClusterClock
NS Namespace
Docs []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Session *session.Client
batches []*WriteBatch
result result.Insert
err error
}
// Encode will encode this command into a wire message for the given server description.
func (i *Insert) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := i.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(i.batches, desc)
}
func (i *Insert) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
command, err := encodeBatch(docs, i.Opts, InsertCommand, i.NS.Collection)
if err != nil {
return nil, err
}
for _, opt := range i.Opts {
if opt.Key == "ordered" && !opt.Value.Boolean() {
i.ContinueOnError = true
break
}
}
return &WriteBatch{
&Write{
Clock: i.Clock,
DB: i.NS.DB,
Command: command,
WriteConcern: i.WriteConcern,
Session: i.Session,
},
len(docs),
}, nil
}
func (i *Insert) encode(desc description.SelectedServer) error {
batches, err := splitBatches(i.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := i.encodeBatch(docs, desc)
if err != nil {
return err
}
i.batches = append(i.batches, cmd)
}
return nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (i *Insert) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Insert {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
i.err = err
return i
}
return i.decode(desc, rdr)
}
func (i *Insert) decode(desc description.SelectedServer, rdr bson.Raw) *Insert {
i.err = bson.Unmarshal(rdr, &i.result)
return i
}
// Result returns the result of a decoded wire message and server description.
func (i *Insert) Result() (result.Insert, error) {
if i.err != nil {
return result.Insert{}, i.err
}
return i.result, nil
}
// Err returns the error set on this command.
func (i *Insert) Err() error { return i.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
//func (i *Insert) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.Insert, error) {
func (i *Insert) RoundTrip(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
) (result.Insert, error) {
if i.batches == nil {
err := i.encode(desc)
if err != nil {
return result.Insert{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
i.batches,
i.ContinueOnError,
i.Session,
InsertCommand,
)
// if there are leftover batches, save them for retry
if batches != nil {
i.batches = batches
}
if err != nil {
return result.Insert{}, err
}
res := r.(result.Insert)
return res, nil
}

View File

@@ -0,0 +1,121 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// IsMaster represents the isMaster command.
//
// The isMaster command is used for setting up a connection to MongoDB and
// for monitoring a MongoDB server.
//
// Since IsMaster can only be run on a connection, there is no Dispatch method.
type IsMaster struct {
Client bsonx.Doc
Compressors []string
SaslSupportedMechs string
err error
res result.IsMaster
}
// Encode will encode this command into a wire message for the given server description.
func (im *IsMaster) Encode() (wiremessage.WireMessage, error) {
cmd := bsonx.Doc{{"isMaster", bsonx.Int32(1)}}
if im.Client != nil {
cmd = append(cmd, bsonx.Elem{"client", bsonx.Document(im.Client)})
}
if im.SaslSupportedMechs != "" {
cmd = append(cmd, bsonx.Elem{"saslSupportedMechs", bsonx.String(im.SaslSupportedMechs)})
}
// always send compressors even if empty slice
array := bsonx.Arr{}
for _, compressor := range im.Compressors {
array = append(array, bsonx.String(compressor))
}
cmd = append(cmd, bsonx.Elem{"compression", bsonx.Array(array)})
rdr, err := cmd.MarshalBSON()
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: "admin.$cmd",
Flags: wiremessage.SlaveOK,
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (im *IsMaster) Decode(wm wiremessage.WireMessage) *IsMaster {
reply, ok := wm.(wiremessage.Reply)
if !ok {
im.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return im
}
rdr, err := decodeCommandOpReply(reply)
if err != nil {
im.err = err
return im
}
err = bson.Unmarshal(rdr, &im.res)
if err != nil {
im.err = err
return im
}
// Reconstructs the $clusterTime doc after decode
if im.res.ClusterTime != nil {
im.res.ClusterTime = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", im.res.ClusterTime))
}
return im
}
// Result returns the result of a decoded wire message and server description.
func (im *IsMaster) Result() (result.IsMaster, error) {
if im.err != nil {
return result.IsMaster{}, im.err
}
return im.res, nil
}
// Err returns the error set on this command.
func (im *IsMaster) Err() error { return im.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (im *IsMaster) RoundTrip(ctx context.Context, rw wiremessage.ReadWriter) (result.IsMaster, error) {
wm, err := im.Encode()
if err != nil {
return result.IsMaster{}, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
return result.IsMaster{}, err
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
return result.IsMaster{}, err
}
return im.Decode(wm).Result()
}

View File

@@ -0,0 +1,103 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// KillCursors represents the killCursors command.
//
// The killCursors command kills a set of cursors.
type KillCursors struct {
Clock *session.ClusterClock
NS Namespace
IDs []int64
result result.KillCursors
err error
}
// Encode will encode this command into a wire message for the given server description.
func (kc *KillCursors) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := kc.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (kc *KillCursors) encode(desc description.SelectedServer) (*Read, error) {
idVals := make([]bsonx.Val, 0, len(kc.IDs))
for _, id := range kc.IDs {
idVals = append(idVals, bsonx.Int64(id))
}
cmd := bsonx.Doc{
{"killCursors", bsonx.String(kc.NS.Collection)},
{"cursors", bsonx.Array(idVals)},
}
return &Read{
Clock: kc.Clock,
DB: kc.NS.DB,
Command: cmd,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (kc *KillCursors) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *KillCursors {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
kc.err = err
return kc
}
return kc.decode(desc, rdr)
}
func (kc *KillCursors) decode(desc description.SelectedServer, rdr bson.Raw) *KillCursors {
err := bson.Unmarshal(rdr, &kc.result)
if err != nil {
kc.err = err
return kc
}
return kc
}
// Result returns the result of a decoded wire message and server description.
func (kc *KillCursors) Result() (result.KillCursors, error) {
if kc.err != nil {
return result.KillCursors{}, kc.err
}
return kc.result, nil
}
// Err returns the error set on this command.
func (kc *KillCursors) Err() error { return kc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (kc *KillCursors) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.KillCursors, error) {
cmd, err := kc.encode(desc)
if err != nil {
return result.KillCursors{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.KillCursors{}, err
}
return kc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,102 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ListCollections represents the listCollections command.
//
// The listCollections command lists the collections in a database.
type ListCollections struct {
Clock *session.ClusterClock
DB string
Filter bsonx.Doc
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
ReadPref *readpref.ReadPref
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (lc *ListCollections) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := lc.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (lc *ListCollections) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listCollections", bsonx.Int32(1)}}
if lc.Filter != nil {
cmd = append(cmd, bsonx.Elem{"filter", bsonx.Document(lc.Filter)})
}
cmd = append(cmd, lc.Opts...)
return &Read{
Clock: lc.Clock,
DB: lc.DB,
Command: cmd,
ReadPref: lc.ReadPref,
Session: lc.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decolcng
// are deferred until either the Result or Err methods are called.
func (lc *ListCollections) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListCollections {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
lc.err = err
return lc
}
return lc.decode(desc, rdr)
}
func (lc *ListCollections) decode(desc description.SelectedServer, rdr bson.Raw) *ListCollections {
lc.result = rdr
return lc
}
// Result returns the result of a decoded wire message and server description.
func (lc *ListCollections) Result() (bson.Raw, error) {
if lc.err != nil {
return nil, lc.err
}
return lc.result, nil
}
// Err returns the error set on this command.
func (lc *ListCollections) Err() error { return lc.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (lc *ListCollections) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := lc.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return nil, err
}
return lc.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ListDatabases represents the listDatabases command.
//
// The listDatabases command lists the databases in a MongoDB deployment.
type ListDatabases struct {
Clock *session.ClusterClock
Filter bsonx.Doc
Opts []bsonx.Elem
Session *session.Client
result result.ListDatabases
err error
}
// Encode will encode this command into a wire message for the given server description.
func (ld *ListDatabases) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := ld.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (ld *ListDatabases) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listDatabases", bsonx.Int32(1)}}
if ld.Filter != nil {
cmd = append(cmd, bsonx.Elem{"filter", bsonx.Document(ld.Filter)})
}
cmd = append(cmd, ld.Opts...)
return &Read{
Clock: ld.Clock,
DB: "admin",
Command: cmd,
Session: ld.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (ld *ListDatabases) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListDatabases {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
ld.err = err
return ld
}
return ld.decode(desc, rdr)
}
func (ld *ListDatabases) decode(desc description.SelectedServer, rdr bson.Raw) *ListDatabases {
ld.err = bson.Unmarshal(rdr, &ld.result)
return ld
}
// Result returns the result of a decoded wire message and server description.
func (ld *ListDatabases) Result() (result.ListDatabases, error) {
if ld.err != nil {
return result.ListDatabases{}, ld.err
}
return ld.result, nil
}
// Err returns the error set on this command.
func (ld *ListDatabases) Err() error { return ld.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (ld *ListDatabases) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.ListDatabases, error) {
cmd, err := ld.encode(desc)
if err != nil {
return result.ListDatabases{}, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.ListDatabases{}, err
}
return ld.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,106 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// ErrEmptyCursor is a signaling error when a cursor for list indexes is empty.
var ErrEmptyCursor = errors.New("empty cursor")
// ListIndexes represents the listIndexes command.
//
// The listIndexes command lists the indexes for a namespace.
type ListIndexes struct {
Clock *session.ClusterClock
NS Namespace
CursorOpts []bsonx.Elem
Opts []bsonx.Elem
Session *session.Client
result bson.Raw
err error
}
// Encode will encode this command into a wire message for the given server description.
func (li *ListIndexes) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
encoded, err := li.encode(desc)
if err != nil {
return nil, err
}
return encoded.Encode(desc)
}
func (li *ListIndexes) encode(desc description.SelectedServer) (*Read, error) {
cmd := bsonx.Doc{{"listIndexes", bsonx.String(li.NS.Collection)}}
cmd = append(cmd, li.Opts...)
return &Read{
Clock: li.Clock,
DB: li.NS.DB,
Command: cmd,
Session: li.Session,
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoling
// are deferred until either the Result or Err methods are called.
func (li *ListIndexes) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *ListIndexes {
rdr, err := (&Read{}).Decode(desc, wm).Result()
if err != nil {
if IsNotFound(err) {
li.err = ErrEmptyCursor
return li
}
li.err = err
return li
}
return li.decode(desc, rdr)
}
func (li *ListIndexes) decode(desc description.SelectedServer, rdr bson.Raw) *ListIndexes {
li.result = rdr
return li
}
// Result returns the result of a decoded wire message and server description.
func (li *ListIndexes) Result() (bson.Raw, error) {
if li.err != nil {
return nil, li.err
}
return li.result, nil
}
// Err returns the error set on this command.
func (li *ListIndexes) Err() error { return li.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (li *ListIndexes) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
cmd, err := li.encode(desc)
if err != nil {
return nil, err
}
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
if IsNotFound(err) {
return nil, ErrEmptyCursor
}
return nil, err
}
return li.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,79 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"errors"
"strings"
)
// Namespace encapsulates a database and collection name, which together
// uniquely identifies a collection within a MongoDB cluster.
type Namespace struct {
DB string
Collection string
}
// NewNamespace returns a new Namespace for the
// given database and collection.
func NewNamespace(db, collection string) Namespace { return Namespace{DB: db, Collection: collection} }
// ParseNamespace parses a namespace string into a Namespace.
//
// The namespace string must contain at least one ".", the first of which is the separator
// between the database and collection names. If not, the default (invalid) Namespace is returned.
func ParseNamespace(name string) Namespace {
index := strings.Index(name, ".")
if index == -1 {
return Namespace{}
}
return Namespace{
DB: name[:index],
Collection: name[index+1:],
}
}
// FullName returns the full namespace string, which is the result of joining the database
// name and the collection name with a "." character.
func (ns *Namespace) FullName() string {
return strings.Join([]string{ns.DB, ns.Collection}, ".")
}
// Validate validates the namespace.
func (ns *Namespace) Validate() error {
if err := ns.validateDB(); err != nil {
return err
}
return ns.validateCollection()
}
// validateDB ensures the database name is not an empty string, contain a ".",
// or contain a " ".
func (ns *Namespace) validateDB() error {
if ns.DB == "" {
return errors.New("database name cannot be empty")
}
if strings.Contains(ns.DB, " ") {
return errors.New("database name cannot contain ' '")
}
if strings.Contains(ns.DB, ".") {
return errors.New("database name cannot contain '.'")
}
return nil
}
// validateCollection ensures the collection name is not an empty string.
func (ns *Namespace) validateCollection() error {
if ns.Collection == "" {
return errors.New("collection name cannot be empty")
}
return nil
}

View File

@@ -0,0 +1,53 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
func decodeCommandOpMsg(msg wiremessage.Msg) (bson.Raw, error) {
var mainDoc bsonx.Doc
for _, section := range msg.Sections {
switch converted := section.(type) {
case wiremessage.SectionBody:
err := mainDoc.UnmarshalBSON(converted.Document)
if err != nil {
return nil, err
}
case wiremessage.SectionDocumentSequence:
arr := bsonx.Arr{}
for _, doc := range converted.Documents {
newDoc := bsonx.Doc{}
err := newDoc.UnmarshalBSON(doc)
if err != nil {
return nil, err
}
arr = append(arr, bsonx.Document(newDoc))
}
mainDoc = append(mainDoc, bsonx.Elem{converted.Identifier, bsonx.Array(arr)})
}
}
byteArray, err := mainDoc.MarshalBSON()
if err != nil {
return nil, err
}
rdr := bson.Raw(byteArray)
err = rdr.Validate()
if err != nil {
return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
}
return rdr, extractError(rdr)
}

View File

@@ -0,0 +1,43 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// decodeCommandOpReply handles decoding the OP_REPLY response to an OP_QUERY
// command.
func decodeCommandOpReply(reply wiremessage.Reply) (bson.Raw, error) {
if reply.NumberReturned == 0 {
return nil, ErrNoDocCommandResponse
}
if reply.NumberReturned > 1 {
return nil, ErrMultiDocCommandResponse
}
if len(reply.Documents) != 1 {
return nil, NewCommandResponseError("malformed OP_REPLY: NumberReturned does not match number of documents returned", nil)
}
rdr := reply.Documents[0]
err := rdr.Validate()
if err != nil {
return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
}
if reply.ResponseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
return nil, QueryFailureError{
Message: "command failure",
Response: reply.Documents[0],
}
}
err = extractError(rdr)
if err != nil {
return nil, err
}
return rdr, nil
}

294
vendor/go.mongodb.org/mongo-driver/x/network/command/read.go generated vendored Executable file
View File

@@ -0,0 +1,294 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Read represents a generic database read command.
type Read struct {
DB string
Command bsonx.Doc
ReadPref *readpref.ReadPref
ReadConcern *readconcern.ReadConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
func (r *Read) createReadPref(serverKind description.ServerKind, topologyKind description.TopologyKind, isOpQuery bool) bsonx.Doc {
doc := bsonx.Doc{}
rp := r.ReadPref
if rp == nil {
if topologyKind == description.Single && serverKind != description.Mongos {
return append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
}
return nil
}
switch rp.Mode() {
case readpref.PrimaryMode:
if serverKind == description.Mongos {
return nil
}
if topologyKind == description.Single {
return append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
}
doc = append(doc, bsonx.Elem{"mode", bsonx.String("primary")})
case readpref.PrimaryPreferredMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("primaryPreferred")})
case readpref.SecondaryPreferredMode:
_, ok := r.ReadPref.MaxStaleness()
if serverKind == description.Mongos && isOpQuery && !ok && len(r.ReadPref.TagSets()) == 0 {
return nil
}
doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondaryPreferred")})
case readpref.SecondaryMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("secondary")})
case readpref.NearestMode:
doc = append(doc, bsonx.Elem{"mode", bsonx.String("nearest")})
}
sets := make([]bsonx.Val, 0, len(r.ReadPref.TagSets()))
for _, ts := range r.ReadPref.TagSets() {
if len(ts) == 0 {
continue
}
set := bsonx.Doc{}
for _, t := range ts {
set = append(set, bsonx.Elem{t.Name, bsonx.String(t.Value)})
}
sets = append(sets, bsonx.Document(set))
}
if len(sets) > 0 {
doc = append(doc, bsonx.Elem{"tags", bsonx.Array(sets)})
}
if d, ok := r.ReadPref.MaxStaleness(); ok {
doc = append(doc, bsonx.Elem{"maxStalenessSeconds", bsonx.Int32(int32(d.Seconds()))})
}
return doc
}
// addReadPref will add a read preference to the query document.
//
// NOTE: This method must always return either a valid bson.Reader or an error.
func (r *Read) addReadPref(rp *readpref.ReadPref, serverKind description.ServerKind, topologyKind description.TopologyKind, query bson.Raw) (bson.Raw, error) {
doc := r.createReadPref(serverKind, topologyKind, true)
if doc == nil {
return query, nil
}
qdoc := bsonx.Doc{}
err := bson.Unmarshal(query, &qdoc)
if err != nil {
return query, err
}
return bsonx.Doc{
{"$query", bsonx.Document(qdoc)},
{"$readPreference", bsonx.Document(doc)},
}.MarshalBSON()
}
// Encode r as OP_MSG
func (r *Read) encodeOpMsg(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
msg := wiremessage.Msg{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
Sections: make([]wiremessage.Section, 0),
}
readPrefDoc := r.createReadPref(desc.Server.Kind, desc.Kind, false)
fullDocRdr, err := opmsgAddGlobals(cmd, r.DB, readPrefDoc)
if err != nil {
return nil, err
}
// type 0 doc
msg.Sections = append(msg.Sections, wiremessage.SectionBody{
PayloadType: wiremessage.SingleDocument,
Document: fullDocRdr,
})
// no flags to add
return msg, nil
}
func (r *Read) slaveOK(desc description.SelectedServer) wiremessage.QueryFlag {
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
return wiremessage.SlaveOK
}
if r.ReadPref == nil {
// assume primary
return 0
}
if r.ReadPref.Mode() != readpref.PrimaryMode {
return wiremessage.SlaveOK
}
return 0
}
// Encode c as OP_QUERY
func (r *Read) encodeOpQuery(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
rdr, err := marshalCommand(cmd)
if err != nil {
return nil, err
}
if desc.Server.Kind == description.Mongos {
rdr, err = r.addReadPref(r.ReadPref, desc.Server.Kind, desc.Kind, rdr)
if err != nil {
return nil, err
}
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: r.DB + ".$cmd",
Flags: r.slaveOK(desc),
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
func (r *Read) decodeOpMsg(wm wiremessage.WireMessage) {
msg, ok := wm.(wiremessage.Msg)
if !ok {
r.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
r.result, r.err = decodeCommandOpMsg(msg)
}
func (r *Read) decodeOpReply(wm wiremessage.WireMessage) {
reply, ok := wm.(wiremessage.Reply)
if !ok {
r.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
r.result, r.err = decodeCommandOpReply(reply)
}
// Encode will encode this command into a wire message for the given server description.
func (r *Read) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := r.Command.Copy()
cmd, err := addReadConcern(cmd, desc, r.ReadConcern, r.Session)
if err != nil {
return nil, err
}
cmd, err = addSessionFields(cmd, desc, r.Session)
if err != nil {
return nil, err
}
cmd = addClusterTime(cmd, desc, r.Session, r.Clock)
if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
return r.encodeOpQuery(desc, cmd)
}
return r.encodeOpMsg(desc, cmd)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (r *Read) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Read {
switch wm.(type) {
case wiremessage.Reply:
r.decodeOpReply(wm)
default:
r.decodeOpMsg(wm)
}
if r.err != nil {
// decode functions set error if an invalid response document was returned or if the OK flag in the response was 0
// if the OK flag was 0, a type Error is returned. otherwise, a special type is returned
cerr, ok := r.err.(Error)
if !ok {
return r // for missing/invalid response docs, don't update cluster times
}
if cerr.HasErrorLabel(TransientTransactionError) {
r.Session.ClearPinnedServer()
}
}
_ = updateClusterTimes(r.Session, r.Clock, r.result)
_ = updateOperationTime(r.Session, r.result)
r.Session.UpdateRecoveryToken(r.result)
return r
}
// Result returns the result of a decoded wire message and server description.
func (r *Read) Result() (bson.Raw, error) {
if r.err != nil {
return nil, r.err
}
return r.result, nil
}
// Err returns the error set on this command.
func (r *Read) Err() error {
return r.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (r *Read) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
wm, err := r.Encode(desc)
if err != nil {
return nil, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
r.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
r.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if r.Session != nil {
err = r.Session.UpdateUseTime()
if err != nil {
return nil, err
}
}
return r.Decode(desc, wm).Result()
}

View File

@@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// StartSession represents a startSession command
type StartSession struct {
Clock *session.ClusterClock
result result.StartSession
err error
}
// Encode will encode this command into a wiremessage for the given server description.
func (ss *StartSession) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := ss.encode(desc)
return cmd.Encode(desc)
}
func (ss *StartSession) encode(desc description.SelectedServer) *Write {
cmd := bsonx.Doc{{"startSession", bsonx.Int32(1)}}
return &Write{
Clock: ss.Clock,
DB: "admin",
Command: cmd,
}
}
// Decode will decode the wire message using the provided server description. Errors during decoding are deferred until
// either the Result or Err methods are called.
func (ss *StartSession) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *StartSession {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
ss.err = err
return ss
}
return ss.decode(desc, rdr)
}
func (ss *StartSession) decode(desc description.SelectedServer, rdr bson.Raw) *StartSession {
ss.err = bson.Unmarshal(rdr, &ss.result)
return ss
}
// Result returns the result of a decoded wire message and server description.
func (ss *StartSession) Result() (result.StartSession, error) {
if ss.err != nil {
return result.StartSession{}, ss.err
}
return ss.result, nil
}
// Err returns the error set on this command
func (ss *StartSession) Err() error {
return ss.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter
func (ss *StartSession) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (result.StartSession, error) {
cmd := ss.encode(desc)
rdr, err := cmd.RoundTrip(ctx, desc, rw)
if err != nil {
return result.StartSession{}, err
}
return ss.decode(desc, rdr).Result()
}

View File

@@ -0,0 +1,161 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/result"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Update represents the update command.
//
// The update command updates a set of documents with the database.
type Update struct {
ContinueOnError bool
Clock *session.ClusterClock
NS Namespace
Docs []bsonx.Doc
Opts []bsonx.Elem
WriteConcern *writeconcern.WriteConcern
Session *session.Client
batches []*WriteBatch
result result.Update
err error
}
// Encode will encode this command into a wire message for the given server description.
func (u *Update) Encode(desc description.SelectedServer) ([]wiremessage.WireMessage, error) {
err := u.encode(desc)
if err != nil {
return nil, err
}
return batchesToWireMessage(u.batches, desc)
}
func (u *Update) encode(desc description.SelectedServer) error {
batches, err := splitBatches(u.Docs, int(desc.MaxBatchCount), int(desc.MaxDocumentSize))
if err != nil {
return err
}
for _, docs := range batches {
cmd, err := u.encodeBatch(docs, desc)
if err != nil {
return err
}
u.batches = append(u.batches, cmd)
}
return nil
}
func (u *Update) encodeBatch(docs []bsonx.Doc, desc description.SelectedServer) (*WriteBatch, error) {
copyDocs := make([]bsonx.Doc, 0, len(docs)) // copy of all the documents
for _, doc := range docs {
newDoc := doc.Copy()
copyDocs = append(copyDocs, newDoc)
}
var options []bsonx.Elem
for _, opt := range u.Opts {
switch opt.Key {
case "upsert", "collation", "arrayFilters":
// options that are encoded on each individual document
for idx := range copyDocs {
copyDocs[idx] = append(copyDocs[idx], opt)
}
default:
options = append(options, opt)
}
}
command, err := encodeBatch(copyDocs, options, UpdateCommand, u.NS.Collection)
if err != nil {
return nil, err
}
return &WriteBatch{
&Write{
Clock: u.Clock,
DB: u.NS.DB,
Command: command,
WriteConcern: u.WriteConcern,
Session: u.Session,
},
len(docs),
}, nil
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (u *Update) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Update {
rdr, err := (&Write{}).Decode(desc, wm).Result()
if err != nil {
u.err = err
return u
}
return u.decode(desc, rdr)
}
func (u *Update) decode(desc description.SelectedServer, rdr bson.Raw) *Update {
u.err = bson.Unmarshal(rdr, &u.result)
return u
}
// Result returns the result of a decoded wire message and server description.
func (u *Update) Result() (result.Update, error) {
if u.err != nil {
return result.Update{}, u.err
}
return u.result, nil
}
// Err returns the error set on this command.
func (u *Update) Err() error { return u.err }
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriter.
func (u *Update) RoundTrip(
ctx context.Context,
desc description.SelectedServer,
rw wiremessage.ReadWriter,
) (result.Update, error) {
if u.batches == nil {
err := u.encode(desc)
if err != nil {
return result.Update{}, err
}
}
r, batches, err := roundTripBatches(
ctx, desc, rw,
u.batches,
u.ContinueOnError,
u.Session,
UpdateCommand,
)
// if there are leftover batches, save them for retry
if batches != nil {
u.batches = batches
}
if err != nil {
return result.Update{}, err
}
return r.(result.Update), nil
}

252
vendor/go.mongodb.org/mongo-driver/x/network/command/write.go generated vendored Executable file
View File

@@ -0,0 +1,252 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package command
import (
"context"
"fmt"
"errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Write represents a generic write database command.
// This can be used to send arbitrary write commands to the database.
type Write struct {
DB string
Command bsonx.Doc
WriteConcern *writeconcern.WriteConcern
Clock *session.ClusterClock
Session *session.Client
result bson.Raw
err error
}
// Encode c as OP_MSG
func (w *Write) encodeOpMsg(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
var arr bsonx.Arr
var identifier string
cmd, arr, identifier = opmsgRemoveArray(cmd)
msg := wiremessage.Msg{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
Sections: make([]wiremessage.Section, 0),
}
fullDocRdr, err := opmsgAddGlobals(cmd, w.DB, nil)
if err != nil {
return nil, err
}
// type 0 doc
msg.Sections = append(msg.Sections, wiremessage.SectionBody{
PayloadType: wiremessage.SingleDocument,
Document: fullDocRdr,
})
// type 1 doc
if identifier != "" {
docSequence, err := opmsgCreateDocSequence(arr, identifier)
if err != nil {
return nil, err
}
msg.Sections = append(msg.Sections, docSequence)
}
// flags
if !writeconcern.AckWrite(w.WriteConcern) {
msg.FlagBits |= wiremessage.MoreToCome
}
return msg, nil
}
// Encode w as OP_QUERY
func (w *Write) encodeOpQuery(desc description.SelectedServer, cmd bsonx.Doc) (wiremessage.WireMessage, error) {
rdr, err := marshalCommand(cmd)
if err != nil {
return nil, err
}
query := wiremessage.Query{
MsgHeader: wiremessage.Header{RequestID: wiremessage.NextRequestID()},
FullCollectionName: w.DB + ".$cmd",
Flags: w.slaveOK(desc),
NumberToReturn: -1,
Query: rdr,
}
return query, nil
}
func (w *Write) slaveOK(desc description.SelectedServer) wiremessage.QueryFlag {
if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
return wiremessage.SlaveOK
}
return 0
}
func (w *Write) decodeOpReply(wm wiremessage.WireMessage) {
reply, ok := wm.(wiremessage.Reply)
if !ok {
w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
w.result, w.err = decodeCommandOpReply(reply)
}
func (w *Write) decodeOpMsg(wm wiremessage.WireMessage) {
msg, ok := wm.(wiremessage.Msg)
if !ok {
w.err = fmt.Errorf("unsupported response wiremessage type %T", wm)
return
}
w.result, w.err = decodeCommandOpMsg(msg)
}
// Encode will encode this command into a wire message for the given server description.
func (w *Write) Encode(desc description.SelectedServer) (wiremessage.WireMessage, error) {
cmd := w.Command.Copy()
var err error
if w.Session != nil && w.Session.TransactionStarting() {
// Starting transactions have a read concern, even in writes.
cmd, err = addReadConcern(cmd, desc, nil, w.Session)
if err != nil {
return nil, err
}
}
cmd, err = addWriteConcern(cmd, w.WriteConcern)
if err != nil {
return nil, err
}
if !writeconcern.AckWrite(w.WriteConcern) {
// unack write with explicit session --> raise an error
// unack write with implicit session --> do not send session ID (implicit session shouldn't have been created
// in the first place)
if w.Session != nil && w.Session.SessionType == session.Explicit {
return nil, errors.New("explicit sessions cannot be used with unacknowledged writes")
}
} else {
// only encode session ID for acknowledged writes
cmd, err = addSessionFields(cmd, desc, w.Session)
if err != nil {
return nil, err
}
}
if w.Session != nil && w.Session.RetryWrite && cmd.IndexOf("txnNumber") == -1 {
cmd = append(cmd, bsonx.Elem{"txnNumber", bsonx.Int64(w.Session.TxnNumber)})
}
cmd = addClusterTime(cmd, desc, w.Session, w.Clock)
if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
return w.encodeOpQuery(desc, cmd)
}
return w.encodeOpMsg(desc, cmd)
}
// Decode will decode the wire message using the provided server description. Errors during decoding
// are deferred until either the Result or Err methods are called.
func (w *Write) Decode(desc description.SelectedServer, wm wiremessage.WireMessage) *Write {
switch wm.(type) {
case wiremessage.Reply:
w.decodeOpReply(wm)
default:
w.decodeOpMsg(wm)
}
if w.err != nil {
cerr, ok := w.err.(Error)
if !ok {
return w
}
if cerr.HasErrorLabel(TransientTransactionError) {
w.Session.ClearPinnedServer()
}
}
_ = updateClusterTimes(w.Session, w.Clock, w.result)
w.Session.UpdateRecoveryToken(w.result)
if writeconcern.AckWrite(w.WriteConcern) {
// don't update session operation time for unacknowledged write
_ = updateOperationTime(w.Session, w.result)
}
return w
}
// Result returns the result of a decoded wire message and server description.
func (w *Write) Result() (bson.Raw, error) {
if w.err != nil {
return nil, w.err
}
return w.result, nil
}
// Err returns the error set on this command.
func (w *Write) Err() error {
return w.err
}
// RoundTrip handles the execution of this command using the provided wiremessage.ReadWriteCloser.
func (w *Write) RoundTrip(ctx context.Context, desc description.SelectedServer, rw wiremessage.ReadWriter) (bson.Raw, error) {
wm, err := w.Encode(desc)
if err != nil {
return nil, err
}
err = rw.WriteWireMessage(ctx, wm)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
w.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if msg, ok := wm.(wiremessage.Msg); ok {
// don't expect response if using OP_MSG for an unacknowledged write
if msg.FlagBits&wiremessage.MoreToCome > 0 {
return nil, ErrUnacknowledgedWrite
}
}
wm, err = rw.ReadWireMessage(ctx)
if err != nil {
if _, ok := err.(Error); ok {
return nil, err
}
// Connection errors are transient
w.Session.ClearPinnedServer()
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
if w.Session != nil {
err = w.Session.UpdateUseTime()
if err != nil {
return nil, err
}
}
return w.Decode(desc, wm).Result()
}

View File

@@ -0,0 +1,170 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package compressor // import "go.mongodb.org/mongo-driver/x/network/compressor"
import (
"bytes"
"compress/zlib"
"io"
"github.com/golang/snappy"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Compressor is the interface implemented by types that can compress and decompress wire messages. This is used
// when sending and receiving messages to and from the server.
type Compressor interface {
CompressBytes(src, dest []byte) ([]byte, error)
UncompressBytes(src, dest []byte) ([]byte, error)
CompressorID() wiremessage.CompressorID
Name() string
}
type writer struct {
buf []byte
}
// Write appends bytes to the writer
func (w *writer) Write(p []byte) (n int, err error) {
index := len(w.buf)
if len(p) > cap(w.buf)-index {
buf := make([]byte, 2*cap(w.buf)+len(p))
copy(buf, w.buf)
w.buf = buf
}
w.buf = w.buf[:index+len(p)]
copy(w.buf[index:], p)
return len(p), nil
}
// SnappyCompressor uses the snappy method to compress data
type SnappyCompressor struct {
}
// ZlibCompressor uses the zlib method to compress data
type ZlibCompressor struct {
level int
zlibWriter *zlib.Writer
}
// CompressBytes uses snappy to compress a slice of bytes.
func (s *SnappyCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
dest = dest[:0]
dest = snappy.Encode(dest, src)
return dest, nil
}
// UncompressBytes uses snappy to uncompress a slice of bytes.
func (s *SnappyCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
var err error
dest, err = snappy.Decode(dest, src)
if err != nil {
return dest, err
}
return dest, nil
}
// CompressorID returns the ID for the snappy compressor.
func (s *SnappyCompressor) CompressorID() wiremessage.CompressorID {
return wiremessage.CompressorSnappy
}
// Name returns the string name for the snappy compressor.
func (s *SnappyCompressor) Name() string {
return "snappy"
}
// CompressBytes uses zlib to compress a slice of bytes.
func (z *ZlibCompressor) CompressBytes(src, dest []byte) ([]byte, error) {
output := &writer{
buf: dest[:0],
}
z.zlibWriter.Reset(output)
_, err := z.zlibWriter.Write(src)
if err != nil {
_ = z.zlibWriter.Close()
return output.buf, err
}
err = z.zlibWriter.Close()
if err != nil {
return output.buf, err
}
return output.buf, nil
}
// UncompressBytes uses zlib to uncompress a slice of bytes. It assumes dest is empty and is the exact size that it
// needs to be.
func (z *ZlibCompressor) UncompressBytes(src, dest []byte) ([]byte, error) {
reader := bytes.NewReader(src)
zlibReader, err := zlib.NewReader(reader)
if err != nil {
return dest, err
}
defer func() {
_ = zlibReader.Close()
}()
_, err = io.ReadFull(zlibReader, dest)
if err != nil {
return dest, err
}
return dest, nil
}
// CompressorID returns the ID for the zlib compressor.
func (z *ZlibCompressor) CompressorID() wiremessage.CompressorID {
return wiremessage.CompressorZLib
}
// Name returns the name for the zlib compressor.
func (z *ZlibCompressor) Name() string {
return "zlib"
}
// CreateSnappy creates a snappy compressor
func CreateSnappy() Compressor {
return &SnappyCompressor{}
}
// CreateZlib creates a zlib compressor
func CreateZlib(level *int) (Compressor, error) {
var l int
if level == nil {
l = wiremessage.DefaultZlibLevel
} else {
l = *level
}
if l < zlib.NoCompression {
l = wiremessage.DefaultZlibLevel
}
if l > zlib.BestCompression {
l = zlib.BestCompression
}
var compressBuf bytes.Buffer
zlibWriter, err := zlib.NewWriterLevel(&compressBuf, l)
if err != nil {
return &ZlibCompressor{}, err
}
return &ZlibCompressor{
level: l,
zlibWriter: zlibWriter,
}, nil
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package connection contains the types for building and pooling connections that can speak the
// MongoDB Wire Protocol. Since this low level library is meant to be used in the context of either
// a driver or a server there are some extra identifiers on a connection so one can keep track of
// what a connection is. This package purposefully hides the underlying network and abstracts the
// writing to and reading from a connection to wireops.Op's. This package also provides types for
// listening for and accepting Connections, as well as some types for handling connections and
// proxying connections to another server.
package connection // import "go.mongodb.org/mongo-driver/x/network/connection"
import (
"context"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Connection is used to read and write wire protocol messages to a network.
type Connection interface {
WriteWireMessage(context.Context, wiremessage.WireMessage) error
ReadWireMessage(context.Context) (wiremessage.WireMessage, error)
Close() error
Expired() bool
Alive() bool
ID() string
}

183
vendor/go.mongodb.org/mongo-driver/x/network/result/result.go generated vendored Executable file
View File

@@ -0,0 +1,183 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package result contains the results from various operations.
package result // import "go.mongodb.org/mongo-driver/x/network/result"
import (
"fmt"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// Upsert contains the information for a single upsert.
type Upsert struct {
Index int64 `bson:"index"`
ID interface{} `bson:"_id"`
}
// Insert is a result from an Insert command.
type Insert struct {
N int
WriteErrors []WriteError `bson:"writeErrors"`
WriteConcernError *WriteConcernError `bson:"writeConcernError"`
}
// StartSession is a result from a StartSession command.
type StartSession struct {
ID bsonx.Doc `bson:"id"`
}
// EndSessions is a result from an EndSessions command.
type EndSessions struct{}
// Delete is a result from a Delete command.
type Delete struct {
N int
WriteErrors []WriteError `bson:"writeErrors"`
WriteConcernError *WriteConcernError `bson:"writeConcernError"`
}
// Update is a result of an Update command.
type Update struct {
MatchedCount int64 `bson:"n"`
ModifiedCount int64 `bson:"nModified"`
Upserted []Upsert `bson:"upserted"`
WriteErrors []WriteError `bson:"writeErrors"`
WriteConcernError *WriteConcernError `bson:"writeConcernError"`
}
// Distinct is a result from a Distinct command.
type Distinct struct {
Values []interface{}
}
// FindAndModify is a result from a findAndModify command.
type FindAndModify struct {
Value bson.Raw
LastErrorObject struct {
UpdatedExisting bool
Upserted interface{}
}
WriteConcernError *WriteConcernError `bson:"writeConcernError"`
}
// WriteError is an error from a write operation that is not a write concern
// error.
type WriteError struct {
Index int
Code int
ErrMsg string
}
// WriteConcernError is an error related to a write concern.
type WriteConcernError struct {
Name string `bson:"codeName"`
Code int
ErrMsg string
ErrInfo bson.Raw
}
func (wce WriteConcernError) Error() string {
if wce.Name != "" {
return fmt.Sprintf("(%v) %v", wce.Name, wce.ErrMsg)
}
return wce.ErrMsg
}
// ListDatabases is the result from a listDatabases command.
type ListDatabases struct {
Databases []struct {
Name string
SizeOnDisk int64 `bson:"sizeOnDisk"`
Empty bool
}
TotalSize int64 `bson:"totalSize"`
}
// IsMaster is a result of an IsMaster command.
type IsMaster struct {
Arbiters []string `bson:"arbiters,omitempty"`
ArbiterOnly bool `bson:"arbiterOnly,omitempty"`
ClusterTime bson.Raw `bson:"$clusterTime,omitempty"`
Compression []string `bson:"compression,omitempty"`
ElectionID primitive.ObjectID `bson:"electionId,omitempty"`
Hidden bool `bson:"hidden,omitempty"`
Hosts []string `bson:"hosts,omitempty"`
IsMaster bool `bson:"ismaster,omitempty"`
IsReplicaSet bool `bson:"isreplicaset,omitempty"`
LastWriteTimestamp time.Time `bson:"lastWriteDate,omitempty"`
LogicalSessionTimeoutMinutes uint32 `bson:"logicalSessionTimeoutMinutes,omitempty"`
MaxBSONObjectSize uint32 `bson:"maxBsonObjectSize,omitempty"`
MaxMessageSizeBytes uint32 `bson:"maxMessageSizeBytes,omitempty"`
MaxWriteBatchSize uint32 `bson:"maxWriteBatchSize,omitempty"`
Me string `bson:"me,omitempty"`
MaxWireVersion int32 `bson:"maxWireVersion,omitempty"`
MinWireVersion int32 `bson:"minWireVersion,omitempty"`
Msg string `bson:"msg,omitempty"`
OK int32 `bson:"ok"`
Passives []string `bson:"passives,omitempty"`
ReadOnly bool `bson:"readOnly,omitempty"`
SaslSupportedMechs []string `bson:"saslSupportedMechs,omitempty"`
Secondary bool `bson:"secondary,omitempty"`
SetName string `bson:"setName,omitempty"`
SetVersion uint32 `bson:"setVersion,omitempty"`
Tags map[string]string `bson:"tags,omitempty"`
}
// BuildInfo is a result of a BuildInfo command.
type BuildInfo struct {
OK bool `bson:"ok"`
GitVersion string `bson:"gitVersion,omitempty"`
Version string `bson:"version,omitempty"`
VersionArray []uint8 `bson:"versionArray,omitempty"`
}
// IsZero returns true if the BuildInfo is the zero value.
func (bi BuildInfo) IsZero() bool {
if !bi.OK && bi.GitVersion == "" && bi.Version == "" && bi.VersionArray == nil {
return true
}
return false
}
// GetLastError is a result of a GetLastError command.
type GetLastError struct {
ConnectionID uint32 `bson:"connectionId"`
}
// KillCursors is a result of a KillCursors command.
type KillCursors struct {
CursorsKilled []int64 `bson:"cursorsKilled"`
CursorsNotFound []int64 `bson:"cursorsNotFound"`
CursorsAlive []int64 `bson:"cursorsAlive"`
}
// CreateIndexes is a result of a CreateIndexes command.
type CreateIndexes struct {
CreatedCollectionAutomatically bool `bson:"createdCollectionAutomatically"`
IndexesBefore int `bson:"numIndexesBefore"`
IndexesAfter int `bson:"numIndexesAfter"`
}
// TransactionResult holds the result of committing or aborting a transaction.
type TransactionResult struct {
WriteConcernError *WriteConcernError `bson:"writeConcernError"`
}
// BulkWrite holds the result of a bulk write operation.
type BulkWrite struct {
InsertedCount int64
MatchedCount int64
ModifiedCount int64
DeletedCount int64
UpsertedCount int64
UpsertedIDs map[int64]interface{}
}

View File

@@ -0,0 +1,20 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
func appendInt32(b []byte, i int32) []byte {
return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24))
}
func appendCString(b []byte, str string) []byte {
b = append(b, str...)
return append(b, 0x00)
}
func appendInt64(b []byte, i int64) []byte {
return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56))
}

View File

@@ -0,0 +1,49 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import "go.mongodb.org/mongo-driver/bson"
// Command represents the OP_COMMAND message of the MongoDB wire protocol.
type Command struct {
MsgHeader Header
Database string
CommandName string
Metadata string
CommandArgs string
InputDocs []bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (c Command) MarshalWireMessage() ([]byte, error) {
panic("not implemented")
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (c Command) ValidateWireMessage() error {
panic("not implemented")
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (c Command) AppendWireMessage([]byte) ([]byte, error) {
panic("not implemented")
}
// String implements the fmt.Stringer interface.
func (c Command) String() string {
panic("not implemented")
}
// Len implements the WireMessage interface.
func (c Command) Len() int {
panic("not implemented")
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (c *Command) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}

View File

@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import "go.mongodb.org/mongo-driver/bson"
// CommandReply represents the OP_COMMANDREPLY message of the MongoDB wire protocol.
type CommandReply struct {
MsgHeader Header
Metadata bson.Raw
CommandReply bson.Raw
OutputDocs []bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (cr CommandReply) MarshalWireMessage() ([]byte, error) {
panic("not implemented")
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (cr CommandReply) ValidateWireMessage() error {
panic("not implemented")
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (cr CommandReply) AppendWireMessage([]byte) ([]byte, error) {
panic("not implemented")
}
// String implements the fmt.Stringer interface.
func (cr CommandReply) String() string {
panic("not implemented")
}
// Len implements the WireMessage interface.
func (cr CommandReply) Len() int {
panic("not implemented")
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (cr *CommandReply) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Compressed represents the OP_COMPRESSED message of the MongoDB wire protocol.
type Compressed struct {
MsgHeader Header
OriginalOpCode OpCode
UncompressedSize int32
CompressorID CompressorID
CompressedMessage []byte
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (c Compressed) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, c.Len())
return c.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (c Compressed) ValidateWireMessage() error {
if int(c.MsgHeader.MessageLength) != c.Len() {
return errors.New("incorrect header: message length is not correct")
}
if c.MsgHeader.OpCode != OpCompressed {
return errors.New("incorrect header: opcode is not OpCompressed")
}
if c.OriginalOpCode != c.MsgHeader.OpCode {
return errors.New("incorrect header: original opcode does not match opcode in message header")
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
//
// AppendWireMessage will set the MessageLength property of MsgHeader if it is 0. It will also set the OpCode to
// OpCompressed if the OpCode is 0. If either of these properties are non-zero and not correct, this method will return
// both the []byte with the wire message appended to it and an invalid header error.
func (c Compressed) AppendWireMessage(b []byte) ([]byte, error) {
err := c.MsgHeader.SetDefaults(c.Len(), OpCompressed)
b = c.MsgHeader.AppendHeader(b)
b = appendInt32(b, int32(c.OriginalOpCode))
b = appendInt32(b, c.UncompressedSize)
b = append(b, byte(c.CompressorID))
b = append(b, c.CompressedMessage...)
return b, err
}
// String implements the fmt.Stringer interface.
func (c Compressed) String() string {
return fmt.Sprintf(
`OP_COMPRESSED{MsgHeader: %s, Uncompressed Size: %d, CompressorId: %d, Compressed message: %s}`,
c.MsgHeader, c.UncompressedSize, c.CompressorID, c.CompressedMessage,
)
}
// Len implements the WireMessage interface.
func (c Compressed) Len() int {
// Header + OpCode + UncompressedSize + CompressorId + CompressedMessage
return 16 + 4 + 4 + 1 + len(c.CompressedMessage)
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (c *Compressed) UnmarshalWireMessage(b []byte) error {
var err error
c.MsgHeader, err = ReadHeader(b, 0)
if err != nil {
return err
}
if len(b) < int(c.MsgHeader.MessageLength) {
return Error{Type: ErrOpCompressed, Message: "[]byte too small"}
}
c.OriginalOpCode = OpCode(readInt32(b, 16)) // skip first 16 for header
c.UncompressedSize = readInt32(b, 20)
c.CompressorID = CompressorID(b[24])
// messageLength - Header - OpCode - UncompressedSize - CompressorId
msgLen := c.MsgHeader.MessageLength - 16 - 4 - 4 - 1
c.CompressedMessage = b[25 : 25+msgLen]
return nil
}
// CompressorID is the ID for each type of Compressor.
type CompressorID = wiremessage.CompressorID
// These constants represent the individual compressor IDs for an OP_COMPRESSED.
const (
CompressorNoOp CompressorID = iota
CompressorSnappy
CompressorZLib
)
// DefaultZlibLevel is the default level for zlib compression
const DefaultZlibLevel = 6

View File

@@ -0,0 +1,55 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import "go.mongodb.org/mongo-driver/bson"
// Delete represents the OP_DELETE message of the MongoDB wire protocol.
type Delete struct {
MsgHeader Header
FullCollectionName string
Flags DeleteFlag
Selector bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (d Delete) MarshalWireMessage() ([]byte, error) {
panic("not implemented")
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (d Delete) ValidateWireMessage() error {
panic("not implemented")
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (d Delete) AppendWireMessage([]byte) ([]byte, error) {
panic("not implemented")
}
// String implements the fmt.Stringer interface.
func (d Delete) String() string {
panic("not implemented")
}
// Len implements the WireMessage interface.
func (d Delete) Len() int {
panic("not implemented")
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (d *Delete) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}
// DeleteFlag represents the flags on an OP_DELETE message.
type DeleteFlag int32
// These constants represent the individual flags on an OP_DELETE message.
const (
SingleRemove DeleteFlag = 1 << iota
)

View File

@@ -0,0 +1,103 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"go.mongodb.org/mongo-driver/x/bsonx"
"strings"
)
// GetMore represents the OP_GET_MORE message of the MongoDB wire protocol.
type GetMore struct {
MsgHeader Header
Zero int32
FullCollectionName string
NumberToReturn int32
CursorID int64
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (gm GetMore) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, gm.Len())
return gm.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (gm GetMore) ValidateWireMessage() error {
if int(gm.MsgHeader.MessageLength) != gm.Len() {
return errors.New("incorrect header: message length is not correct")
}
if gm.MsgHeader.OpCode != OpGetMore {
return errors.New("incorrect header: op code is not OpGetMore")
}
if strings.Index(gm.FullCollectionName, ".") == -1 {
return errors.New("incorrect header: collection name does not contain a dot")
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
//
// AppendWireMessage will set the MessageLength property of the MsgHeader
// if it is zero. It will also set the OpCode to OpGetMore if the OpCode is
// zero. If either of these properties are non-zero and not correct, this
// method will return both the []byte with the wire message appended to it
// and an invalid header error.
func (gm GetMore) AppendWireMessage(b []byte) ([]byte, error) {
var err error
err = gm.MsgHeader.SetDefaults(gm.Len(), OpGetMore)
b = gm.MsgHeader.AppendHeader(b)
b = appendInt32(b, gm.Zero)
b = appendCString(b, gm.FullCollectionName)
b = appendInt32(b, gm.NumberToReturn)
b = appendInt64(b, gm.CursorID)
return b, err
}
// String implements the fmt.Stringer interface.
func (gm GetMore) String() string {
return fmt.Sprintf(
`OP_GET_MORE{MsgHeader: %s, Zero: %d, FullCollectionName: %s, NumberToReturn: %d, CursorID: %d}`,
gm.MsgHeader, gm.Zero, gm.FullCollectionName, gm.NumberToReturn, gm.CursorID,
)
}
// Len implements the WireMessage interface.
func (gm GetMore) Len() int {
// Header + Zero + CollectionName + Null Terminator + Return + CursorID
return 16 + 4 + len(gm.FullCollectionName) + 1 + 4 + 8
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (gm *GetMore) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}
// CommandDocument creates a BSON document representing this command.
func (gm GetMore) CommandDocument() bsonx.Doc {
parts := strings.Split(gm.FullCollectionName, ".")
collName := parts[len(parts)-1]
doc := bsonx.Doc{
{"getMore", bsonx.Int64(gm.CursorID)},
{"collection", bsonx.String(collName)},
}
if gm.NumberToReturn != 0 {
doc = doc.Append("batchSize", bsonx.Int32(gm.NumberToReturn))
}
return doc
}
// DatabaseName returns the name of the database for this command.
func (gm GetMore) DatabaseName() string {
return strings.Split(gm.FullCollectionName, ".")[0]
}

View File

@@ -0,0 +1,87 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"fmt"
)
// ErrInvalidHeader is returned when methods are called on a malformed Header.
var ErrInvalidHeader error = Error{Type: ErrHeader, Message: "invalid header"}
// ErrHeaderTooSmall is returned when the size of the header is too small to be valid.
var ErrHeaderTooSmall error = Error{Type: ErrHeader, Message: "the header is too small to be valid"}
// ErrHeaderTooFewBytes is returned when a call to ReadHeader does not contain enough
// bytes to be a valid header.
var ErrHeaderTooFewBytes error = Error{Type: ErrHeader, Message: "invalid header because []byte too small"}
// ErrHeaderInvalidLength is returned when the MessageLength of a header is
// set but is not set to the correct size.
var ErrHeaderInvalidLength error = Error{Type: ErrHeader, Message: "invalid header because MessageLength is imporperly set"}
// ErrHeaderIncorrectOpCode is returned when the OpCode on a header is set but
// is not set to the correct OpCode.
var ErrHeaderIncorrectOpCode error = Error{Type: ErrHeader, Message: "invalid header because OpCode is improperly set"}
// Header represents the header of a MongoDB wire protocol message.
type Header struct {
MessageLength int32
RequestID int32
ResponseTo int32
OpCode OpCode
}
// ReadHeader reads a header from the given slice of bytes starting at offset
// pos.
func ReadHeader(b []byte, pos int32) (Header, error) {
if len(b) < 16 {
return Header{}, ErrHeaderTooFewBytes
}
return Header{
MessageLength: readInt32(b, 0),
RequestID: readInt32(b, 4),
ResponseTo: readInt32(b, 8),
OpCode: OpCode(readInt32(b, 12)),
}, nil
}
func (h Header) String() string {
return fmt.Sprintf(
`Header{MessageLength: %d, RequestID: %d, ResponseTo: %d, OpCode: %v}`,
h.MessageLength, h.RequestID, h.ResponseTo, h.OpCode,
)
}
// AppendHeader will append this header to the given slice of bytes.
func (h Header) AppendHeader(b []byte) []byte {
b = appendInt32(b, h.MessageLength)
b = appendInt32(b, h.RequestID)
b = appendInt32(b, h.ResponseTo)
b = appendInt32(b, int32(h.OpCode))
return b
}
// SetDefaults sets the length and opcode of this header.
func (h *Header) SetDefaults(length int, opcode OpCode) error {
switch h.MessageLength {
case int32(length):
case 0:
h.MessageLength = int32(length)
default:
return ErrHeaderInvalidLength
}
switch h.OpCode {
case opcode:
case OpCode(0):
h.OpCode = opcode
default:
return ErrHeaderIncorrectOpCode
}
return nil
}

View File

@@ -0,0 +1,55 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import "go.mongodb.org/mongo-driver/bson"
// Insert represents the OP_INSERT message of the MongoDB wire protocol.
type Insert struct {
MsgHeader Header
Flags InsertFlag
FullCollectionName string
Documents []bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (i Insert) MarshalWireMessage() ([]byte, error) {
panic("not implemented")
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (i Insert) ValidateWireMessage() error {
panic("not implemented")
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (i Insert) AppendWireMessage([]byte) ([]byte, error) {
panic("not implemented")
}
// String implements the fmt.Stringer interface.
func (i Insert) String() string {
panic("not implemented")
}
// Len implements the WireMessage interface.
func (i Insert) Len() int {
panic("not implemented")
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (i *Insert) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}
// InsertFlag represents the flags on an OP_INSERT message.
type InsertFlag int32
// These constants represent the individual flags on an OP_INSERT message.
const (
ContinueOnError InsertFlag = 1 << iota
)

View File

@@ -0,0 +1,92 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// KillCursors represents the OP_KILL_CURSORS message of the MongoDB wire protocol.
type KillCursors struct {
MsgHeader Header
Zero int32
NumberOfCursorIDs int32
CursorIDs []int64
DatabaseName string
CollectionName string
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (kc KillCursors) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, kc.Len())
return kc.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (kc KillCursors) ValidateWireMessage() error {
if int(kc.MsgHeader.MessageLength) != kc.Len() {
return errors.New("incorrect header: message length is not correct")
}
if kc.MsgHeader.OpCode != OpKillCursors {
return errors.New("incorrect header: op code is not OpGetMore")
}
if kc.NumberOfCursorIDs != int32(len(kc.CursorIDs)) {
return errors.New("incorrect number of cursor IDs")
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (kc KillCursors) AppendWireMessage(b []byte) ([]byte, error) {
var err error
err = kc.MsgHeader.SetDefaults(kc.Len(), OpKillCursors)
b = kc.MsgHeader.AppendHeader(b)
b = appendInt32(b, kc.Zero)
b = appendInt32(b, kc.NumberOfCursorIDs)
for _, id := range kc.CursorIDs {
b = appendInt64(b, id)
}
return b, err
}
// String implements the fmt.Stringer interface.
func (kc KillCursors) String() string {
return fmt.Sprintf(
`OP_KILL_CURSORS{MsgHeader: %s, Zero: %d, Number of Cursor IDS: %d, Cursor IDs: %v}`,
kc.MsgHeader, kc.Zero, kc.NumberOfCursorIDs, kc.CursorIDs,
)
}
// Len implements the WireMessage interface.
func (kc KillCursors) Len() int {
// Header + Zero + Number IDs + 8 * Number IDs
return 16 + 4 + 4 + int(kc.NumberOfCursorIDs*8)
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (kc *KillCursors) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}
// CommandDocument creates a BSON document representing this command.
func (kc KillCursors) CommandDocument() bsonx.Doc {
cursors := make([]bsonx.Val, len(kc.CursorIDs))
for i, id := range kc.CursorIDs {
cursors[i] = bsonx.Int64(id)
}
return bsonx.Doc{
{"killCursors", bsonx.String(kc.CollectionName)},
{"cursors", bsonx.Array(cursors)},
}
}

View File

@@ -0,0 +1,303 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Msg represents the OP_MSG message of the MongoDB wire protocol.
type Msg struct {
MsgHeader Header
FlagBits MsgFlag
Sections []Section
Checksum uint32
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (m Msg) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, m.Len())
return m.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (m Msg) ValidateWireMessage() error {
if int(m.MsgHeader.MessageLength) != m.Len() {
return errors.New("incorrect header: message length is not correct")
}
if m.MsgHeader.OpCode != OpMsg {
return errors.New("incorrect header: opcode is not OpMsg")
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
//
// AppendWireMesssage will set the MessageLength property of the MsgHeader if it is zero. It will also set the Opcode
// to OP_MSG if it is zero. If either of these properties are non-zero and not correct, this method will return both the
// []byte with the wire message appended to it and an invalid header error.
func (m Msg) AppendWireMessage(b []byte) ([]byte, error) {
var err error
err = m.MsgHeader.SetDefaults(m.Len(), OpMsg)
b = m.MsgHeader.AppendHeader(b)
b = appendInt32(b, int32(m.FlagBits))
for _, section := range m.Sections {
newB := make([]byte, 0)
newB = section.AppendSection(newB)
b = section.AppendSection(b)
}
return b, err
}
// String implements the fmt.Stringer interface.
func (m Msg) String() string {
return fmt.Sprintf(
`OP_MSG{MsgHeader: %v, FlagBits: %d, Sections: %v, Checksum: %d}`,
m.MsgHeader, m.FlagBits, m.Sections, m.Checksum,
)
}
// Len implements the WireMessage interface.
func (m Msg) Len() int {
// Header + Flags + len of each section + optional checksum
totalLen := 16 + 4 // header and flag
for _, section := range m.Sections {
totalLen += section.Len()
}
if m.FlagBits&ChecksumPresent > 0 {
totalLen += 4
}
return totalLen
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (m *Msg) UnmarshalWireMessage(b []byte) error {
var err error
m.MsgHeader, err = ReadHeader(b, 0)
if err != nil {
return err
}
if len(b) < int(m.MsgHeader.MessageLength) {
return Error{
Type: ErrOpMsg,
Message: "[]byte too small",
}
}
m.FlagBits = MsgFlag(readInt32(b, 16))
// read each section
sectionBytes := m.MsgHeader.MessageLength - 16 - 4 // number of bytes taken up by sections
hasChecksum := m.FlagBits&ChecksumPresent > 0
if hasChecksum {
sectionBytes -= 4 // 4 bytes at end for checksum
}
m.Sections = make([]Section, 0)
position := 20 // position to read from
for sectionBytes > 0 {
sectionType := SectionType(b[position])
position++
switch sectionType {
case SingleDocument:
rdr, size, err := readDocument(b, int32(position))
if err.Message != "" {
err.Type = ErrOpMsg
return err
}
position += size
sb := SectionBody{
Document: rdr,
}
sb.PayloadType = sb.Kind()
sectionBytes -= int32(sb.Len())
m.Sections = append(m.Sections, sb)
case DocumentSequence:
sds := SectionDocumentSequence{}
sds.Size = readInt32(b, int32(position))
position += 4
identifier, err := readCString(b, int32(position))
if err != nil {
return err
}
sds.Identifier = identifier
position += len(identifier) + 1 // +1 for \0
sds.PayloadType = sds.Kind()
// length of documents to read
// sequenceLen - 4 bytes for size field - identifierLength (including \0)
docsLen := int(sds.Size) - 4 - len(identifier) - 1
for docsLen > 0 {
rdr, size, err := readDocument(b, int32(position))
if err.Message != "" {
err.Type = ErrOpMsg
return err
}
position += size
sds.Documents = append(sds.Documents, rdr)
docsLen -= size
}
sectionBytes -= int32(sds.Len())
m.Sections = append(m.Sections, sds)
}
}
if hasChecksum {
m.Checksum = uint32(readInt32(b, int32(position)))
}
return nil
}
// GetMainDocument returns the document containing the message to send.
func (m *Msg) GetMainDocument() (bsonx.Doc, error) {
return bsonx.ReadDoc(m.Sections[0].(SectionBody).Document)
}
// GetSequenceArray returns this message's document sequence as a BSON array along with the array identifier.
// If this message has no associated document sequence, a nil array is returned.
func (m *Msg) GetSequenceArray() (bsonx.Arr, string, error) {
if len(m.Sections) == 1 {
return nil, "", nil
}
arr := bsonx.Arr{}
sds := m.Sections[1].(SectionDocumentSequence)
for _, rdr := range sds.Documents {
doc, err := bsonx.ReadDoc([]byte(rdr))
if err != nil {
return nil, "", err
}
arr = append(arr, bsonx.Document(doc))
}
return arr, sds.Identifier, nil
}
// AcknowledgedWrite returns true if this msg represents an acknowledged write command.
func (m *Msg) AcknowledgedWrite() bool {
return m.FlagBits&MoreToCome == 0
}
// MsgFlag represents the flags on an OP_MSG message.
type MsgFlag = wiremessage.MsgFlag
// These constants represent the individual flags on an OP_MSG message.
const (
ChecksumPresent MsgFlag = 1 << iota
MoreToCome
ExhaustAllowed MsgFlag = 1 << 16
)
// Section represents a section on an OP_MSG message.
type Section interface {
Kind() SectionType
Len() int
AppendSection([]byte) []byte
}
// SectionBody represents the kind body of an OP_MSG message.
type SectionBody struct {
PayloadType SectionType
Document bson.Raw
}
// Kind implements the Section interface.
func (sb SectionBody) Kind() SectionType {
return SingleDocument
}
// Len implements the Section interface
func (sb SectionBody) Len() int {
return 1 + len(sb.Document) // 1 for PayloadType
}
// AppendSection implements the Section interface.
func (sb SectionBody) AppendSection(dest []byte) []byte {
dest = append(dest, byte(SingleDocument))
dest = append(dest, sb.Document...)
return dest
}
// SectionDocumentSequence represents the kind document sequence of an OP_MSG message.
type SectionDocumentSequence struct {
PayloadType SectionType
Size int32
Identifier string
Documents []bson.Raw
}
// Kind implements the Section interface.
func (sds SectionDocumentSequence) Kind() SectionType {
return DocumentSequence
}
// Len implements the Section interface
func (sds SectionDocumentSequence) Len() int {
// PayloadType + Size + Identifier + 1 (null terminator) + totalDocLen
totalDocLen := 0
for _, doc := range sds.Documents {
totalDocLen += len(doc)
}
return 1 + 4 + len(sds.Identifier) + 1 + totalDocLen
}
// PayloadLen returns the length of the payload
func (sds SectionDocumentSequence) PayloadLen() int {
// 4 bytes for size field, len identifier (including \0), and total docs len
return sds.Len() - 1
}
// AppendSection implements the Section interface
func (sds SectionDocumentSequence) AppendSection(dest []byte) []byte {
dest = append(dest, byte(DocumentSequence))
dest = appendInt32(dest, sds.Size)
dest = appendCString(dest, sds.Identifier)
for _, doc := range sds.Documents {
dest = append(dest, doc...)
}
return dest
}
// SectionType represents the type for 1 section in an OP_MSG
type SectionType = wiremessage.SectionType
// These constants represent the individual section types for a section in an OP_MSG
const (
SingleDocument SectionType = iota
DocumentSequence
)
// OpmsgWireVersion is the minimum wire version needed to use OP_MSG
const OpmsgWireVersion = 6

View File

@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Query represents the OP_QUERY message of the MongoDB wire protocol.
type Query struct {
MsgHeader Header
Flags QueryFlag
FullCollectionName string
NumberToSkip int32
NumberToReturn int32
Query bson.Raw
ReturnFieldsSelector bson.Raw
SkipSet bool
Limit *int32
BatchSize *int32
}
var optionsMap = map[string]string{
"$orderby": "sort",
"$hint": "hint",
"$comment": "comment",
"$maxScan": "maxScan",
"$max": "max",
"$min": "min",
"$returnKey": "returnKey",
"$showDiskLoc": "showRecordId",
"$maxTimeMS": "maxTimeMS",
"$snapshot": "snapshot",
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
//
// See AppendWireMessage for a description of the rules this method follows.
func (q Query) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, q.Len())
return q.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (q Query) ValidateWireMessage() error {
if int(q.MsgHeader.MessageLength) != q.Len() {
return errors.New("incorrect header: message length is not correct")
}
if q.MsgHeader.OpCode != OpQuery {
return errors.New("incorrect header: op code is not OpQuery")
}
if strings.Index(q.FullCollectionName, ".") == -1 {
return errors.New("incorrect header: collection name does not contain a dot")
}
if q.Query != nil && len(q.Query) > 0 {
err := q.Query.Validate()
if err != nil {
return err
}
}
if q.ReturnFieldsSelector != nil && len(q.ReturnFieldsSelector) > 0 {
err := q.ReturnFieldsSelector.Validate()
if err != nil {
return err
}
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
//
// AppendWireMessage will set the MessageLength property of the MsgHeader
// if it is zero. It will also set the OpCode to OpQuery if the OpCode is
// zero. If either of these properties are non-zero and not correct, this
// method will return both the []byte with the wire message appended to it
// and an invalid header error.
func (q Query) AppendWireMessage(b []byte) ([]byte, error) {
var err error
err = q.MsgHeader.SetDefaults(q.Len(), OpQuery)
b = q.MsgHeader.AppendHeader(b)
b = appendInt32(b, int32(q.Flags))
b = appendCString(b, q.FullCollectionName)
b = appendInt32(b, q.NumberToSkip)
b = appendInt32(b, q.NumberToReturn)
b = append(b, q.Query...)
b = append(b, q.ReturnFieldsSelector...)
return b, err
}
// String implements the fmt.Stringer interface.
func (q Query) String() string {
return fmt.Sprintf(
`OP_QUERY{MsgHeader: %s, Flags: %s, FullCollectionname: %s, NumberToSkip: %d, NumberToReturn: %d, Query: %s, ReturnFieldsSelector: %s}`,
q.MsgHeader, q.Flags, q.FullCollectionName, q.NumberToSkip, q.NumberToReturn, q.Query, q.ReturnFieldsSelector,
)
}
// Len implements the WireMessage interface.
func (q Query) Len() int {
// Header + Flags + CollectionName + Null Byte + Skip + Return + Query + ReturnFieldsSelector
return 16 + 4 + len(q.FullCollectionName) + 1 + 4 + 4 + len(q.Query) + len(q.ReturnFieldsSelector)
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (q *Query) UnmarshalWireMessage(b []byte) error {
var err error
q.MsgHeader, err = ReadHeader(b, 0)
if err != nil {
return err
}
if len(b) < int(q.MsgHeader.MessageLength) {
return Error{Type: ErrOpQuery, Message: "[]byte too small"}
}
q.Flags = QueryFlag(readInt32(b, 16))
q.FullCollectionName, err = readCString(b, 20)
if err != nil {
return err
}
pos := 20 + len(q.FullCollectionName) + 1
q.NumberToSkip = readInt32(b, int32(pos))
pos += 4
q.NumberToReturn = readInt32(b, int32(pos))
pos += 4
var size int
var wmerr Error
q.Query, size, wmerr = readDocument(b, int32(pos))
if wmerr.Message != "" {
wmerr.Type = ErrOpQuery
return wmerr
}
pos += size
if pos < len(b) {
q.ReturnFieldsSelector, size, wmerr = readDocument(b, int32(pos))
if wmerr.Message != "" {
wmerr.Type = ErrOpQuery
return wmerr
}
pos += size
}
return nil
}
// AcknowledgedWrite returns true if this command represents an acknowledged write
func (q *Query) AcknowledgedWrite() bool {
wcElem, err := q.Query.LookupErr("writeConcern")
if err != nil {
// no wc --> ack
return true
}
return writeconcern.AcknowledgedValue(wcElem)
}
// Legacy returns true if the query represents a legacy find operation.
func (q Query) Legacy() bool {
return !strings.Contains(q.FullCollectionName, "$cmd")
}
// DatabaseName returns the database name for the query.
func (q Query) DatabaseName() string {
if q.Legacy() {
return strings.Split(q.FullCollectionName, ".")[0]
}
return q.FullCollectionName[:len(q.FullCollectionName)-5] // remove .$cmd
}
// CollectionName returns the collection name for the query.
func (q Query) CollectionName() string {
parts := strings.Split(q.FullCollectionName, ".")
return parts[len(parts)-1]
}
// CommandDocument creates a BSON document representing this command.
func (q Query) CommandDocument() (bsonx.Doc, error) {
if q.Legacy() {
return q.legacyCommandDocument()
}
cmd, err := bsonx.ReadDoc([]byte(q.Query))
if err != nil {
return nil, err
}
cmdElem := cmd[0]
if cmdElem.Key == "$query" {
cmd = cmdElem.Value.Document()
}
return cmd, nil
}
func (q Query) legacyCommandDocument() (bsonx.Doc, error) {
doc, err := bsonx.ReadDoc(q.Query)
if err != nil {
return nil, err
}
parts := strings.Split(q.FullCollectionName, ".")
collName := parts[len(parts)-1]
doc = append(bsonx.Doc{{"find", bsonx.String(collName)}}, doc...)
var filter bsonx.Doc
var queryIndex int
for i, elem := range doc {
if newKey, ok := optionsMap[elem.Key]; ok {
doc[i].Key = newKey
continue
}
if elem.Key == "$query" {
filter = elem.Value.Document()
} else {
// the element is the filter
filter = filter.Append(elem.Key, elem.Value)
}
queryIndex = i
}
doc = append(doc[:queryIndex], doc[queryIndex+1:]...) // remove $query
if len(filter) != 0 {
doc = doc.Append("filter", bsonx.Document(filter))
}
doc, err = q.convertLegacyParams(doc)
if err != nil {
return nil, err
}
return doc, nil
}
func (q Query) convertLegacyParams(doc bsonx.Doc) (bsonx.Doc, error) {
if q.ReturnFieldsSelector != nil {
projDoc, err := bsonx.ReadDoc(q.ReturnFieldsSelector)
if err != nil {
return nil, err
}
doc = doc.Append("projection", bsonx.Document(projDoc))
}
if q.Limit != nil {
limit := *q.Limit
if limit < 0 {
limit *= -1
doc = doc.Append("singleBatch", bsonx.Boolean(true))
}
doc = doc.Append("limit", bsonx.Int32(*q.Limit))
}
if q.BatchSize != nil {
doc = doc.Append("batchSize", bsonx.Int32(*q.BatchSize))
}
if q.SkipSet {
doc = doc.Append("skip", bsonx.Int32(q.NumberToSkip))
}
if q.Flags&TailableCursor > 0 {
doc = doc.Append("tailable", bsonx.Boolean(true))
}
if q.Flags&OplogReplay > 0 {
doc = doc.Append("oplogReplay", bsonx.Boolean(true))
}
if q.Flags&NoCursorTimeout > 0 {
doc = doc.Append("noCursorTimeout", bsonx.Boolean(true))
}
if q.Flags&AwaitData > 0 {
doc = doc.Append("awaitData", bsonx.Boolean(true))
}
if q.Flags&Partial > 0 {
doc = doc.Append("allowPartialResults", bsonx.Boolean(true))
}
return doc, nil
}
// QueryFlag represents the flags on an OP_QUERY message.
type QueryFlag = wiremessage.QueryFlag
// These constants represent the individual flags on an OP_QUERY message.
const (
_ QueryFlag = 1 << iota
TailableCursor
SlaveOK
OplogReplay
NoCursorTimeout
AwaitData
Exhaust
Partial
)

View File

@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"bytes"
"errors"
"go.mongodb.org/mongo-driver/bson"
)
func readInt32(b []byte, pos int32) int32 {
return (int32(b[pos+0])) | (int32(b[pos+1]) << 8) | (int32(b[pos+2]) << 16) | (int32(b[pos+3]) << 24)
}
func readCString(b []byte, pos int32) (string, error) {
null := bytes.IndexByte(b[pos:], 0x00)
if null == -1 {
return "", errors.New("invalid cstring")
}
return string(b[pos : int(pos)+null]), nil
}
func readInt64(b []byte, pos int32) int64 {
return (int64(b[pos+0])) | (int64(b[pos+1]) << 8) | (int64(b[pos+2]) << 16) | (int64(b[pos+3]) << 24) | (int64(b[pos+4]) << 32) |
(int64(b[pos+5]) << 40) | (int64(b[pos+6]) << 48) | (int64(b[pos+7]) << 56)
}
// readDocument will attempt to read a bson.Reader from the given slice of bytes
// from the given position.
func readDocument(b []byte, pos int32) (bson.Raw, int, Error) {
if int(pos)+4 > len(b) {
return nil, 0, Error{Message: "document too small to be valid"}
}
size := int(readInt32(b, int32(pos)))
if int(pos)+size > len(b) {
return nil, 0, Error{Message: "document size is larger than available bytes"}
}
if b[int(pos)+size-1] != 0x00 {
return nil, 0, Error{Message: "document invalid, last byte is not null"}
}
// TODO(GODRIVER-138): When we add 3.0 support, alter this so we either do one larger make or use a pool.
rdr := make(bson.Raw, size)
copy(rdr, b[pos:int(pos)+size])
return rdr, size, Error{Type: ErrNil}
}

View File

@@ -0,0 +1,180 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import (
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Reply represents the OP_REPLY message of the MongoDB wire protocol.
type Reply struct {
MsgHeader Header
ResponseFlags ReplyFlag
CursorID int64
StartingFrom int32
NumberReturned int32
Documents []bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
//
// See AppendWireMessage for a description of the rules this method follows.
func (r Reply) MarshalWireMessage() ([]byte, error) {
b := make([]byte, 0, r.Len())
return r.AppendWireMessage(b)
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (r Reply) ValidateWireMessage() error {
if int(r.MsgHeader.MessageLength) != r.Len() {
return errors.New("incorrect header: message length is not correct")
}
if r.MsgHeader.OpCode != OpReply {
return errors.New("incorrect header: op code is not OpReply")
}
return nil
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
//
// AppendWireMessage will set the MessageLength property of the MsgHeader
// if it is zero. It will also set the OpCode to OpQuery if the OpCode is
// zero. If either of these properties are non-zero and not correct, this
// method will return both the []byte with the wire message appended to it
// and an invalid header error.
func (r Reply) AppendWireMessage(b []byte) ([]byte, error) {
var err error
err = r.MsgHeader.SetDefaults(r.Len(), OpReply)
b = r.MsgHeader.AppendHeader(b)
b = appendInt32(b, int32(r.ResponseFlags))
b = appendInt64(b, r.CursorID)
b = appendInt32(b, r.StartingFrom)
b = appendInt32(b, r.NumberReturned)
for _, d := range r.Documents {
b = append(b, d...)
}
return b, err
}
// String implements the fmt.Stringer interface.
func (r Reply) String() string {
return fmt.Sprintf(
`OP_REPLY{MsgHeader: %s, ResponseFlags: %s, CursorID: %d, StartingFrom: %d, NumberReturned: %d, Documents: %v}`,
r.MsgHeader, r.ResponseFlags, r.CursorID, r.StartingFrom, r.NumberReturned, r.Documents,
)
}
// Len implements the WireMessage interface.
func (r Reply) Len() int {
// Header + Flags + CursorID + StartingFrom + NumberReturned + Length of Length of Documents
docsLen := 0
for _, d := range r.Documents {
docsLen += len(d)
}
return 16 + 4 + 8 + 4 + 4 + docsLen
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (r *Reply) UnmarshalWireMessage(b []byte) error {
var err error
r.MsgHeader, err = ReadHeader(b, 0)
if err != nil {
return err
}
if r.MsgHeader.MessageLength < 36 {
return errors.New("invalid OP_REPLY: header length too small")
}
if len(b) < int(r.MsgHeader.MessageLength) {
return errors.New("invalid OP_REPLY: []byte too small")
}
r.ResponseFlags = ReplyFlag(readInt32(b, 16))
r.CursorID = readInt64(b, 20)
r.StartingFrom = readInt32(b, 28)
r.NumberReturned = readInt32(b, 32)
pos := 36
for pos < len(b) {
rdr, size, err := readDocument(b, int32(pos))
if err.Message != "" {
err.Type = ErrOpReply
return err
}
r.Documents = append(r.Documents, rdr)
pos += size
}
return nil
}
// GetMainLegacyDocument constructs and returns a BSON document for this reply.
func (r *Reply) GetMainLegacyDocument(fullCollectionName string) (bsonx.Doc, error) {
if r.ResponseFlags&CursorNotFound > 0 {
fmt.Println("cursor not found err")
return bsonx.Doc{
{"ok", bsonx.Int32(0)},
}, nil
}
if r.ResponseFlags&QueryFailure > 0 {
firstDoc := r.Documents[0]
return bsonx.Doc{
{"ok", bsonx.Int32(0)},
{"errmsg", bsonx.String(firstDoc.Lookup("$err").StringValue())},
{"code", bsonx.Int32(firstDoc.Lookup("code").Int32())},
}, nil
}
doc := bsonx.Doc{
{"ok", bsonx.Int32(1)},
}
batchStr := "firstBatch"
if r.StartingFrom != 0 {
batchStr = "nextBatch"
}
batchArr := make([]bsonx.Val, len(r.Documents))
for i, docRaw := range r.Documents {
doc, err := bsonx.ReadDoc(docRaw)
if err != nil {
return nil, err
}
batchArr[i] = bsonx.Document(doc)
}
cursorDoc := bsonx.Doc{
{"id", bsonx.Int64(r.CursorID)},
{"ns", bsonx.String(fullCollectionName)},
{batchStr, bsonx.Array(batchArr)},
}
doc = doc.Append("cursor", bsonx.Document(cursorDoc))
return doc, nil
}
// GetMainDocument returns the main BSON document for this reply.
func (r *Reply) GetMainDocument() (bsonx.Doc, error) {
return bsonx.ReadDoc([]byte(r.Documents[0]))
}
// ReplyFlag represents the flags of an OP_REPLY message.
type ReplyFlag = wiremessage.ReplyFlag
// These constants represent the individual flags of an OP_REPLY message.
const (
CursorNotFound ReplyFlag = 1 << iota
QueryFailure
ShardConfigStale
AwaitCapable
)

View File

@@ -0,0 +1,57 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package wiremessage
import "go.mongodb.org/mongo-driver/bson"
// Update represents the OP_UPDATE message of the MongoDB wire protocol.
type Update struct {
MsgHeader Header
FullCollectionName string
Flags UpdateFlag
Selector bson.Raw
Update bson.Raw
}
// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
func (u Update) MarshalWireMessage() ([]byte, error) {
panic("not implemented")
}
// ValidateWireMessage implements the Validator and WireMessage interfaces.
func (u Update) ValidateWireMessage() error {
panic("not implemented")
}
// AppendWireMessage implements the Appender and WireMessage interfaces.
func (u Update) AppendWireMessage([]byte) ([]byte, error) {
panic("not implemented")
}
// String implements the fmt.Stringer interface.
func (u Update) String() string {
panic("not implemented")
}
// Len implements the WireMessage interface.
func (u Update) Len() int {
panic("not implemented")
}
// UnmarshalWireMessage implements the Unmarshaler interface.
func (u *Update) UnmarshalWireMessage([]byte) error {
panic("not implemented")
}
// UpdateFlag represents the flags on an OP_UPDATE message.
type UpdateFlag int32
// These constants represent the individual flags on an OP_UPDATE message.
const (
Upsert UpdateFlag = 1 << iota
MultiUpdate
)

View File

@@ -0,0 +1,178 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package wiremessage contains types for speaking the MongoDB Wire Protocol. Since this low
// level library is meant to be used in the context of a driver and in the context of a server
// all of the flags and types of the wire protocol are implemented. For each op there are two
// corresponding implementations. One prefixed with Immutable which can be created by casting a
// []byte to the type, and another prefixed with Mutable that is a struct with methods to mutate
// the op.
package wiremessage // import "go.mongodb.org/mongo-driver/x/network/wiremessage"
import (
"context"
"errors"
"fmt"
"io"
"sync/atomic"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// ErrInvalidMessageLength is returned when the provided message length is too small to be valid.
var ErrInvalidMessageLength = errors.New("the message length is too small, it must be at least 16")
// ErrUnknownOpCode is returned when the provided opcode is not a valid opcode.
var ErrUnknownOpCode = errors.New("the opcode is unknown")
var globalRequestID int32
// CurrentRequestID returns the current request ID.
func CurrentRequestID() int32 { return atomic.LoadInt32(&globalRequestID) }
// NextRequestID returns the next request ID.
func NextRequestID() int32 { return atomic.AddInt32(&globalRequestID, 1) }
// Error represents an error related to wire protocol messages.
type Error struct {
Type ErrorType
Message string
}
// Error implements the err interface.
func (e Error) Error() string {
return e.Message
}
// ErrorType is the type of error, which indicates from which part of the code
// the error originated.
type ErrorType uint16
// These constants are the types of errors exposed by this package.
const (
ErrNil ErrorType = iota
ErrHeader
ErrOpQuery
ErrOpReply
ErrOpCompressed
ErrOpMsg
ErrRead
)
// OpCode represents a MongoDB wire protocol opcode.
type OpCode = wiremessage.OpCode
// These constants are the valid opcodes for the version of the wireprotocol
// supported by this library. The skipped OpCodes are historical OpCodes that
// are no longer used.
const (
OpReply OpCode = 1
_ OpCode = 1001
OpUpdate OpCode = 2001
OpInsert OpCode = 2002
_ OpCode = 2003
OpQuery OpCode = 2004
OpGetMore OpCode = 2005
OpDelete OpCode = 2006
OpKillCursors OpCode = 2007
OpCommand OpCode = 2010
OpCommandReply OpCode = 2011
OpCompressed OpCode = 2012
OpMsg OpCode = 2013
)
// WireMessage represents a message in the MongoDB wire protocol.
type WireMessage interface {
Marshaler
Validator
Appender
fmt.Stringer
// Len returns the length in bytes of this WireMessage.
Len() int
}
// Validator is the interface implemented by types that can validate
// themselves as a MongoDB wire protocol message.
type Validator interface {
ValidateWireMessage() error
}
// Marshaler is the interface implemented by types that can marshal
// themselves into a valid MongoDB wire protocol message.
type Marshaler interface {
MarshalWireMessage() ([]byte, error)
}
// Appender is the interface implemented by types that can append themselves, as
// a MongoDB wire protocol message, to the provided slice of bytes.
type Appender interface {
AppendWireMessage([]byte) ([]byte, error)
}
// Unmarshaler is the interface implemented by types that can unmarshal a
// MongoDB wire protocol message version of themselves. The input can be
// assumed to be a valid MongoDB wire protocol message. UnmarshalWireMessage
// must copy the data if it wishes to retain the data after returning.
type Unmarshaler interface {
UnmarshalWireMessage([]byte) error
}
// Writer is the interface implemented by types that can have WireMessages
// written to them.
//
// Implementation must obey the cancellation, timeouts, and deadlines of the
// provided context.Context object.
type Writer interface {
WriteWireMessage(context.Context, WireMessage) error
}
// Reader is the interface implemented by types that can have WireMessages
// read from them.
//
// Implementation must obey the cancellation, timeouts, and deadlines of the
// provided context.Context object.
type Reader interface {
ReadWireMessage(context.Context) (WireMessage, error)
}
// ReadWriter is the interface implemented by types that can both read and write
// WireMessages.
type ReadWriter interface {
Reader
Writer
}
// ReadWriteCloser is the interface implemented by types that can read and write
// WireMessages and can also be closed.
type ReadWriteCloser interface {
Reader
Writer
io.Closer
}
// Transformer is the interface implemented by types that can alter a WireMessage.
// Implementations should not directly alter the provided WireMessage and instead
// make a copy of the message, alter it, and returned the new message.
type Transformer interface {
TransformWireMessage(WireMessage) (WireMessage, error)
}
// ReadFrom will read a single WireMessage from the given io.Reader. This function will
// validate the WireMessage. If the WireMessage is not valid, this method will
// return both the error and the invalid WireMessage. If another type of processing
// error occurs, WireMessage will be nil.
//
// This function will return the immutable versions of wire protocol messages. The
// Convert function can be used to retrieve a mutable version of wire protocol
// messages.
func ReadFrom(io.Reader) (WireMessage, error) { return nil, nil }
// Unmarshal will unmarshal data into a WireMessage.
func Unmarshal([]byte) (WireMessage, error) { return nil, nil }
// Validate will validate that data is a valid MongoDB wire protocol message.
func Validate([]byte) error { return nil }