newline battles continue

This commit is contained in:
bel
2020-01-19 20:41:30 +00:00
parent 98adb53caf
commit 573696774e
1456 changed files with 501133 additions and 6 deletions

97
vendor/go.mongodb.org/mongo-driver/x/bsonx/array.go generated vendored Executable file

@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx // import "go.mongodb.org/mongo-driver/x/bsonx"
import (
"bytes"
"errors"
"fmt"
"strconv"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilArray indicates that an operation was attempted on a nil *Array.
var ErrNilArray = errors.New("array is nil")
// Arr represents an array in BSON.
type Arr []Val
// String implements the fmt.Stringer interface.
func (a Arr) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Array["))
for idx, val := range a {
if idx > 0 {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%s", val)
}
buf.WriteByte(']')
return buf.String()
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
func (a Arr) MarshalBSONValue() (bsontype.Type, []byte, error) {
if a == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
idx, dst := bsoncore.ReserveLength(nil)
for idx, value := range a {
t, data, _ := value.MarshalBSONValue() // marshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, strconv.Itoa(idx)...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return bsontype.Array, dst, nil
}
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
func (a *Arr) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if a == nil {
return ErrNilArray
}
*a = (*a)[:0]
elements, err := bsoncore.Document(data).Elements()
if err != nil {
return err
}
for _, elem := range elements {
var val Val
rawval := elem.Value()
err = val.UnmarshalBSONValue(rawval.Type, rawval.Data)
if err != nil {
return err
}
*a = append(*a, val)
}
return nil
}
// Equal compares this document to another, returning true if they are equal.
func (a Arr) Equal(a2 Arr) bool {
if len(a) != len(a2) {
return false
}
for idx := range a {
if !a[idx].Equal(a2[idx]) {
return false
}
}
return true
}
func (Arr) idoc() {}
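
A minimal usage sketch for the Arr type above (an editor's illustration, not part of the vendored file; it assumes the bsonx constructors such as Int32 and String that appear later in this commit):

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx"
)

func main() {
    arr := bsonx.Arr{bsonx.Int32(1), bsonx.String("two")}

    // MarshalBSONValue returns the BSON type (array) and the raw array bytes.
    t, data, err := arr.MarshalBSONValue()
    if err != nil {
        panic(err)
    }
    fmt.Println(t, len(data))

    // Println uses Arr's String method, which renders each Val in turn.
    fmt.Println(arr)
}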


@@ -0,0 +1,827 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncore contains functions that can be used to encode and decode BSON
// elements and values to or from a slice of bytes. These functions are aimed at
// allowing low level manipulation of BSON and can be used to build a higher
// level BSON library.
//
// The Read* functions within this package return the values of the element and
// a boolean indicating whether the values are valid. A boolean is used instead
// of an error because any error that would be returned would be the same: not
// enough bytes. This library attempts to do no validation; it only returns
// false if there are not enough bytes for an item to be read. For example, the
// ReadDocument function checks the length: if that length is larger than the
// number of bytes available, it returns false; if there are enough bytes, it
// returns those bytes and true. It is the consumer's responsibility to
// validate those bytes.
//
// The Append* functions within this package will append the type value to the
// given dst slice. If the slice has enough capacity, it will not grow the
// slice. The Append*Element functions within this package operate in the same
// way, but additionally append the BSON type and the key before the value.
package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
import (
"bytes"
"fmt"
"math"
"strconv"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// AppendType will append t to dst and return the extended buffer.
func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) }
// AppendKey will append key to dst and return the extended buffer.
func AppendKey(dst []byte, key string) []byte { return append(dst, key+string(0x00)...) }
// AppendHeader will append Type t and key to dst and return the extended
// buffer.
func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
dst = AppendType(dst, t)
dst = append(dst, key...)
return append(dst, 0x00)
// return append(AppendType(dst, t), key+string(0x00)...)
}
// TODO(skriptble): All of the Read* functions should return src resliced to start just after what
// was read.
// ReadType will return the first byte of the provided []byte as a type. If
// there is no available byte, false is returned.
func ReadType(src []byte) (bsontype.Type, []byte, bool) {
if len(src) < 1 {
return 0, src, false
}
return bsontype.Type(src[0]), src[1:], true
}
// ReadKey will read a key from src. The 0x00 byte will not be present
// in the returned string. If there are not enough bytes available, false is
// returned.
func ReadKey(src []byte) (string, []byte, bool) { return readcstring(src) }
// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
// not be present in the returned string. If there are not enough bytes
// available, false is returned.
func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return readcstringbytes(src) }
// ReadHeader will read a type byte and a key from src. If both of these
// values cannot be read, false is returned.
func ReadHeader(src []byte) (t bsontype.Type, key string, rem []byte, ok bool) {
t, rem, ok = ReadType(src)
if !ok {
return 0, "", src, false
}
key, rem, ok = ReadKey(rem)
if !ok {
return 0, "", src, false
}
return t, key, rem, true
}
// ReadHeaderBytes will read a type and a key from src, returning the remainder of the
// bytes as rem. If either the type or the key cannot be read, ok will be false.
func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) {
if len(src) < 1 {
return nil, src, false
}
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
// idx is relative to src[1:], so the key's null terminator sits at src[idx+1]; the header
// spans the type byte, the key, and that terminator.
return src[:idx+2], src[idx+2:], true
}
// ReadElement reads the next full element from src. It returns the element, the remaining bytes in
// the slice, and a boolean indicating if the read was successful.
func ReadElement(src []byte) (Element, []byte, bool) {
if len(src) < 1 {
return nil, src, false
}
t := bsontype.Type(src[0])
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
length, ok := valueLength(src[idx+2:], t) // We add 2 here because we called IndexByte with src[1:]
if !ok {
return nil, src, false
}
elemLength := 1 + idx + 1 + int(length)
if elemLength > len(src) {
return nil, src, false
}
return src[:elemLength], src[elemLength:], true
}
// AppendValueElement appends value to dst as an element using key as the element's key.
func AppendValueElement(dst []byte, key string, value Value) []byte {
dst = AppendHeader(dst, value.Type, key)
dst = append(dst, value.Data...)
return dst
}
// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes,
// and a boolean indicating if the read was successful.
func ReadValue(src []byte, t bsontype.Type) (Value, []byte, bool) {
data, rem, ok := readValue(src, t)
if !ok {
return Value{}, src, false
}
return Value{Type: t, Data: data}, rem, true
}
// AppendDouble will append f to dst and return the extended buffer.
func AppendDouble(dst []byte, f float64) []byte {
return appendu64(dst, math.Float64bits(f))
}
// AppendDoubleElement will append a BSON double element using key and f to dst
// and return the extended buffer.
func AppendDoubleElement(dst []byte, key string, f float64) []byte {
return AppendDouble(AppendHeader(dst, bsontype.Double, key), f)
}
// ReadDouble will read a float64 from src. If there are not enough bytes it
// will return false.
func ReadDouble(src []byte) (float64, []byte, bool) {
bits, src, ok := readu64(src)
if !ok {
return 0, src, false
}
return math.Float64frombits(bits), src, true
}
// AppendString will append s to dst and return the extended buffer.
func AppendString(dst []byte, s string) []byte {
return appendstring(dst, s)
}
// AppendStringElement will append a BSON string element using key and val to dst
// and return the extended buffer.
func AppendStringElement(dst []byte, key, val string) []byte {
return AppendString(AppendHeader(dst, bsontype.String, key), val)
}
// ReadString will read a string from src. If there are not enough bytes it
// will return false.
func ReadString(src []byte) (string, []byte, bool) {
return readstring(src)
}
// AppendDocumentStart reserves a document's length and returns the index where the length begins.
// This index can later be used to write the length of the document.
//
// TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd.
// AppendDocumentStart would handle calling ReserveLength and providing the index of the start of
// the document. AppendDocumentEnd would handle taking that start index, adding the null byte,
// calculating the length, and filling in the length at the start of the document.
func AppendDocumentStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
// index int32 which allows this function to be used inline.
func AppendDocumentStartInline(dst []byte, index *int32) []byte {
idx, doc := AppendDocumentStart(dst)
*index = idx
return doc
}
// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendDocumentStart(AppendHeader(dst, bsontype.EmbeddedDocument, key))
}
// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
// The index should be the beginning of the document's length bytes.
func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
if int(index) > len(dst)-4 {
return dst, fmt.Errorf("not enough bytes available after index to write length")
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, index, int32(len(dst[index:])))
return dst, nil
}
// AppendDocument will append doc to dst and return the extended buffer.
func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
// AppendDocumentElement will append a BSON embeded document element using key
// and doc to dst and return the extended buffer.
func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
return AppendDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), doc)
}
// BuildDocument will create a document with the given elements and will append it to dst.
func BuildDocument(dst []byte, elems []byte) []byte {
idx, dst := ReserveLength(dst)
dst = append(dst, elems...)
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildDocumentFromElements will create a document with the given slice of elements and will append
// it to dst and return the extended buffer.
func BuildDocumentFromElements(dst []byte, elems ...[]byte) []byte {
idx, dst := ReserveLength(dst)
for _, elem := range elems {
dst = append(dst, elem...)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// ReadDocument will read a document from src. If there are not enough bytes it
// will return false.
func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendArrayStart appends the length bytes to an array and then returns the index of the start
// of those length bytes.
func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
// AppendArrayElementStart appends an array element header and then the length bytes for an array,
// returning the index where the length starts.
func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendArrayStart(AppendHeader(dst, bsontype.Array, key))
}
// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
// calculated length starting at index.
func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) }
// AppendArray will append arr to dst and return the extended buffer.
func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) }
// AppendArrayElement will append a BSON array element using key and arr to dst
// and return the extended buffer.
func AppendArrayElement(dst []byte, key string, arr []byte) []byte {
return AppendArray(AppendHeader(dst, bsontype.Array, key), arr)
}
// BuildArray will append a BSON array to dst built from values.
func BuildArray(dst []byte, values ...Value) []byte {
idx, dst := ReserveLength(dst)
for pos, val := range values {
dst = AppendValueElement(dst, strconv.Itoa(pos), val)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildArrayElement will create an array element using the provided values.
func BuildArrayElement(dst []byte, key string, values ...Value) []byte {
return BuildArray(AppendHeader(dst, bsontype.Array, key), values...)
}
// ReadArray will read an array from src. If there are not enough bytes it
// will return false.
func ReadArray(src []byte) (arr Document, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendBinary will append subtype and b to dst and return the extended buffer.
func AppendBinary(dst []byte, subtype byte, b []byte) []byte {
if subtype == 0x02 {
return appendBinarySubtype2(dst, subtype, b)
}
dst = append(appendLength(dst, int32(len(b))), subtype)
return append(dst, b...)
}
// AppendBinaryElement will append a BSON binary element using key, subtype, and
// b to dst and return the extended buffer.
func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte {
return AppendBinary(AppendHeader(dst, bsontype.Binary, key), subtype, b)
}
// ReadBinary will read a subtype and bin from src. If there are not enough bytes it
// will return false.
func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok {
return 0x00, nil, src, false
}
if len(rem) < 1 { // subtype
return 0x00, nil, src, false
}
subtype, rem = rem[0], rem[1:]
if len(rem) < int(length) {
return 0x00, nil, src, false
}
if subtype == 0x02 {
length, rem, ok = ReadLength(rem)
if !ok || len(rem) < int(length) {
return 0x00, nil, src, false
}
}
return subtype, rem[:length], rem[length:], true
}
// AppendUndefinedElement will append a BSON undefined element using key to dst
// and return the extended buffer.
func AppendUndefinedElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.Undefined, key)
}
// AppendObjectID will append oid to dst and return the extended buffer.
func AppendObjectID(dst []byte, oid primitive.ObjectID) []byte { return append(dst, oid[:]...) }
// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst
// and return the extended buffer.
func AppendObjectIDElement(dst []byte, key string, oid primitive.ObjectID) []byte {
return AppendObjectID(AppendHeader(dst, bsontype.ObjectID, key), oid)
}
// ReadObjectID will read an ObjectID from src. If there are not enough bytes it
// will return false.
func ReadObjectID(src []byte) (primitive.ObjectID, []byte, bool) {
if len(src) < 12 {
return primitive.ObjectID{}, src, false
}
var oid primitive.ObjectID
copy(oid[:], src[0:12])
return oid, src[12:], true
}
// AppendBoolean will append b to dst and return the extended buffer.
func AppendBoolean(dst []byte, b bool) []byte {
if b {
return append(dst, 0x01)
}
return append(dst, 0x00)
}
// AppendBooleanElement will append a BSON boolean element using key and b to dst
// and return the extended buffer.
func AppendBooleanElement(dst []byte, key string, b bool) []byte {
return AppendBoolean(AppendHeader(dst, bsontype.Boolean, key), b)
}
// ReadBoolean will read a bool from src. If there are not enough bytes it
// will return false.
func ReadBoolean(src []byte) (bool, []byte, bool) {
if len(src) < 1 {
return false, src, false
}
return src[0] == 0x01, src[1:], true
}
// AppendDateTime will append dt to dst and return the extended buffer.
func AppendDateTime(dst []byte, dt int64) []byte { return appendi64(dst, dt) }
// AppendDateTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendDateTimeElement(dst []byte, key string, dt int64) []byte {
return AppendDateTime(AppendHeader(dst, bsontype.DateTime, key), dt)
}
// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it
// will return false.
func ReadDateTime(src []byte) (int64, []byte, bool) { return readi64(src) }
// AppendTime will append time as a BSON DateTime to dst and return the extended buffer.
func AppendTime(dst []byte, t time.Time) []byte {
return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6))
}
// AppendTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendTimeElement(dst []byte, key string, t time.Time) []byte {
return AppendTime(AppendHeader(dst, bsontype.DateTime, key), t)
}
// ReadTime will read an time.Time datetime from src. If there are not enough bytes it
// will return false.
func ReadTime(src []byte) (time.Time, []byte, bool) {
dt, rem, ok := readi64(src)
return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok
}
// AppendNullElement will append a BSON null element using key to dst
// and return the extended buffer.
func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, bsontype.Null, key) }
// AppendRegex will append pattern and options to dst and return the extended buffer.
func AppendRegex(dst []byte, pattern, options string) []byte {
return append(dst, pattern+string(0x00)+options+string(0x00)...)
}
// AppendRegexElement will append a BSON regex element using key, pattern, and
// options to dst and return the extended buffer.
func AppendRegexElement(dst []byte, key, pattern, options string) []byte {
return AppendRegex(AppendHeader(dst, bsontype.Regex, key), pattern, options)
}
// ReadRegex will read a pattern and options from src. If there are not enough bytes it
// will return false.
func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) {
pattern, rem, ok = readcstring(src)
if !ok {
return "", "", src, false
}
options, rem, ok = readcstring(rem)
if !ok {
return "", "", src, false
}
return pattern, options, rem, true
}
// AppendDBPointer will append ns and oid to dst and return the extended buffer.
func AppendDBPointer(dst []byte, ns string, oid primitive.ObjectID) []byte {
return append(appendstring(dst, ns), oid[:]...)
}
// AppendDBPointerElement will append a BSON DBPointer element using key, ns,
// and oid to dst and return the extended buffer.
func AppendDBPointerElement(dst []byte, key, ns string, oid primitive.ObjectID) []byte {
return AppendDBPointer(AppendHeader(dst, bsontype.DBPointer, key), ns, oid)
}
// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it
// will return false.
func ReadDBPointer(src []byte) (ns string, oid primitive.ObjectID, rem []byte, ok bool) {
ns, rem, ok = readstring(src)
if !ok {
return "", primitive.ObjectID{}, src, false
}
oid, rem, ok = ReadObjectID(rem)
if !ok {
return "", primitive.ObjectID{}, src, false
}
return ns, oid, rem, true
}
// AppendJavaScript will append js to dst and return the extended buffer.
func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) }
// AppendJavaScriptElement will append a BSON JavaScript element using key and
// js to dst and return the extended buffer.
func AppendJavaScriptElement(dst []byte, key, js string) []byte {
return AppendJavaScript(AppendHeader(dst, bsontype.JavaScript, key), js)
}
// ReadJavaScript will read a js string from src. If there are not enough bytes it
// will return false.
func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) }
// AppendSymbol will append symbol to dst and return the extended buffer.
func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) }
// AppendSymbolElement will append a BSON symbol element using key and symbol to dst
// and return the extended buffer.
func AppendSymbolElement(dst []byte, key, symbol string) []byte {
return AppendSymbol(AppendHeader(dst, bsontype.Symbol, key), symbol)
}
// ReadSymbol will read a symbol string from src. If there are not enough bytes it
// will return false.
func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) }
// AppendCodeWithScope will append code and scope to dst and return the extended buffer.
func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte {
length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope
dst = appendLength(dst, length)
return append(appendstring(dst, code), scope...)
}
// AppendCodeWithScopeElement will append a BSON code with scope element using
// key, code, and scope to dst
// and return the extended buffer.
func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte {
return AppendCodeWithScope(AppendHeader(dst, bsontype.CodeWithScope, key), code, scope)
}
// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it
// will return false.
func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok || len(src) < int(length) {
return "", nil, src, false
}
code, rem, ok = readstring(rem)
if !ok {
return "", nil, src, false
}
scope, rem, ok = ReadDocument(rem)
if !ok {
return "", nil, src, false
}
return code, scope, rem, true
}
// AppendInt32 will append i32 to dst and return the extended buffer.
func AppendInt32(dst []byte, i32 int32) []byte { return appendi32(dst, i32) }
// AppendInt32Element will append a BSON int32 element using key and i32 to dst
// and return the extended buffer.
func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
return AppendInt32(AppendHeader(dst, bsontype.Int32, key), i32)
}
// ReadInt32 will read an int32 from src. If there are not enough bytes it
// will return false.
func ReadInt32(src []byte) (int32, []byte, bool) { return readi32(src) }
// AppendTimestamp will append t and i to dst and return the extended buffer.
func AppendTimestamp(dst []byte, t, i uint32) []byte {
return appendu32(appendu32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
}
// AppendTimestampElement will append a BSON timestamp element using key, t, and
// i to dst and return the extended buffer.
func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
return AppendTimestamp(AppendHeader(dst, bsontype.Timestamp, key), t, i)
}
// ReadTimestamp will read t and i from src. If there are not enough bytes it
// will return false.
func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
i, rem, ok = readu32(src)
if !ok {
return 0, 0, src, false
}
t, rem, ok = readu32(rem)
if !ok {
return 0, 0, src, false
}
return t, i, rem, true
}
// AppendInt64 will append i64 to dst and return the extended buffer.
func AppendInt64(dst []byte, i64 int64) []byte { return appendi64(dst, i64) }
// AppendInt64Element will append a BSON int64 element using key and i64 to dst
// and return the extended buffer.
func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
return AppendInt64(AppendHeader(dst, bsontype.Int64, key), i64)
}
// ReadInt64 will read an int64 from src. If there are not enough bytes it
// will return false.
func ReadInt64(src []byte) (int64, []byte, bool) { return readi64(src) }
// AppendDecimal128 will append d128 to dst and return the extended buffer.
func AppendDecimal128(dst []byte, d128 primitive.Decimal128) []byte {
high, low := d128.GetBytes()
return appendu64(appendu64(dst, low), high)
}
// AppendDecimal128Element will append a BSON primitive.Decimal128 element using key and
// d128 to dst and return the extended buffer.
func AppendDecimal128Element(dst []byte, key string, d128 primitive.Decimal128) []byte {
return AppendDecimal128(AppendHeader(dst, bsontype.Decimal128, key), d128)
}
// ReadDecimal128 will read a primitive.Decimal128 from src. If there are not enough bytes it
// will return false.
func ReadDecimal128(src []byte) (primitive.Decimal128, []byte, bool) {
l, rem, ok := readu64(src)
if !ok {
return primitive.Decimal128{}, src, false
}
h, rem, ok := readu64(rem)
if !ok {
return primitive.Decimal128{}, src, false
}
return primitive.NewDecimal128(h, l), rem, true
}
// AppendMaxKeyElement will append a BSON max key element using key to dst
// and return the extended buffer.
func AppendMaxKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.MaxKey, key)
}
// AppendMinKeyElement will append a BSON min key element using key to dst
// and return the extended buffer.
func AppendMinKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.MinKey, key)
}
// EqualValue will return true if the two values are equal.
func EqualValue(t1, t2 bsontype.Type, v1, v2 []byte) bool {
if t1 != t2 {
return false
}
v1, _, ok := readValue(v1, t1)
if !ok {
return false
}
v2, _, ok = readValue(v2, t2)
if !ok {
return false
}
return bytes.Equal(v1, v2)
}
// valueLength will determine the length of the next value contained in src as if it
// is type t. The returned bool will be false if there are not enough bytes in src for
// a value of type t.
func valueLength(src []byte, t bsontype.Type) (int32, bool) {
var length int32
ok := true
switch t {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
length, _, ok = ReadLength(src)
case bsontype.Binary:
length, _, ok = ReadLength(src)
length += 4 + 1 // binary length + subtype byte
case bsontype.Boolean:
length = 1
case bsontype.DBPointer:
length, _, ok = ReadLength(src)
length += 4 + 12 // string length + ObjectID length
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
length = 8
case bsontype.Decimal128:
length = 16
case bsontype.Int32:
length = 4
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
length, _, ok = ReadLength(src)
length += 4
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
length = 0
case bsontype.ObjectID:
length = 12
case bsontype.Regex:
regex := bytes.IndexByte(src, 0x00)
if regex < 0 {
ok = false
break
}
pattern := bytes.IndexByte(src[regex+1:], 0x00)
if pattern < 0 {
ok = false
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
ok = false
}
return length, ok
}
func readValue(src []byte, t bsontype.Type) ([]byte, []byte, bool) {
length, ok := valueLength(src, t)
if !ok || int(length) > len(src) {
return nil, src, false
}
return src[:length], src[length:], true
}
// ReserveLength reserves the space required for length and returns the index where to write the length
// and the []byte with reserved space.
func ReserveLength(dst []byte) (int32, []byte) {
index := len(dst)
return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
}
// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
dst[index] = byte(length)
dst[index+1] = byte(length >> 8)
dst[index+2] = byte(length >> 16)
dst[index+3] = byte(length >> 24)
return dst
}
func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }
func appendi32(dst []byte, i32 int32) []byte {
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
}
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
// there aren't enough bytes to read a valid length, src is returned unmodified and the returned
// bool will be false.
func ReadLength(src []byte) (int32, []byte, bool) { return readi32(src) }
func readi32(src []byte) (int32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
}
func appendi64(dst []byte, i64 int64) []byte {
return append(dst,
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
)
}
func readi64(src []byte) (int64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
return i64, src[8:], true
}
func appendu32(dst []byte, u32 uint32) []byte {
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
}
func readu32(src []byte) (uint32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
}
func appendu64(dst []byte, u64 uint64) []byte {
return append(dst,
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
)
}
func readu64(src []byte) (uint64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
return u64, src[8:], true
}
// keep in sync with readcstringbytes
func readcstring(src []byte) (string, []byte, bool) {
idx := bytes.IndexByte(src, 0x00)
if idx < 0 {
return "", src, false
}
return string(src[:idx]), src[idx+1:], true
}
// keep in sync with readcstring
func readcstringbytes(src []byte) ([]byte, []byte, bool) {
idx := bytes.IndexByte(src, 0x00)
if idx < 0 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
func appendstring(dst []byte, s string) []byte {
l := int32(len(s) + 1)
dst = appendLength(dst, l)
dst = append(dst, s...)
return append(dst, 0x00)
}
func readstring(src []byte) (string, []byte, bool) {
l, rem, ok := ReadLength(src)
if !ok {
return "", src, false
}
if len(src[4:]) < int(l) {
return "", src, false
}
return string(rem[:l-1]), rem[l:], true
}
// readLengthBytes attempts to read a length and that number of bytes. This
// function requires that the length include the four bytes for itself.
func readLengthBytes(src []byte) ([]byte, []byte, bool) {
l, _, ok := ReadLength(src)
if !ok {
return nil, src, false
}
if len(src) < int(l) {
return nil, src, false
}
return src[:l], src[l:], true
}
func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes
dst = append(dst, subtype)
dst = appendLength(dst, int32(len(b)))
return append(dst, b...)
}
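
A rough illustration of the Append*/Build*/Read* conventions described in the package documentation above (an editor's sketch, not part of the vendored file):

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

func main() {
    // Build {"name": "mongo", "count": 3} one element at a time, then wrap the
    // elements with length bytes and a trailing null via BuildDocumentFromElements.
    elems := bsoncore.AppendStringElement(nil, "name", "mongo")
    elems = bsoncore.AppendInt32Element(elems, "count", 3)
    doc := bsoncore.BuildDocumentFromElements(nil, elems)

    // The Read* functions only check that enough bytes are present; they do not
    // validate the contents.
    length, rem, ok := bsoncore.ReadLength(doc)
    fmt.Println(length, ok)
    for {
        var elem bsoncore.Element
        elem, rem, ok = bsoncore.ReadElement(rem)
        if !ok {
            break // the document terminator (0x00) is not an element
        }
        fmt.Println(elem.Key())
    }
}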


@@ -0,0 +1,399 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"github.com/go-stack/stack"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// DocumentValidationError is an error type returned when attempting to validate a document.
type DocumentValidationError string
func (dve DocumentValidationError) Error() string { return string(dve) }
// NewDocumentLengthError creates and returns an error for when the length of a document exceeds the
// bytes available.
func NewDocumentLengthError(length, rem int) error {
return DocumentValidationError(
fmt.Sprintf("document length exceeds available bytes. length=%d remainingBytes=%d", length, rem),
)
}
// InsufficientBytesError indicates that there were not enough bytes to read the next component.
type InsufficientBytesError struct {
Source []byte
Remaining []byte
Stack stack.CallStack
}
// NewInsufficientBytesError creates a new InsufficientBytesError with the given Document, remaining
// bytes, and the current stack.
func NewInsufficientBytesError(src, rem []byte) InsufficientBytesError {
return InsufficientBytesError{Source: src, Remaining: rem, Stack: stack.Trace().TrimRuntime()}
}
// Error implements the error interface.
func (ibe InsufficientBytesError) Error() string {
return "too few bytes to read next component"
}
// ErrorStack returns a string representing the stack at the point where the error occurred.
func (ibe InsufficientBytesError) ErrorStack() string {
s := bytes.NewBufferString("too few bytes to read next component: [")
for i, call := range ibe.Stack {
if i != 0 {
s.WriteString(", ")
}
// go vet doesn't like %k even though it's part of stack's API, so we move the format
// string so it doesn't complain. (We also can't make it a constant, or go vet still
// complains.)
callFormat := "%k.%n %v"
s.WriteString(fmt.Sprintf(callFormat, call, call, call))
}
s.WriteRune(']')
return s.String()
}
// Equal checks that err2 is also an InsufficientBytesError.
func (ibe InsufficientBytesError) Equal(err2 error) bool {
switch err2.(type) {
case InsufficientBytesError:
return true
default:
return false
}
}
// InvalidDepthTraversalError is returned when attempting a recursive Lookup when one component of
// the path is neither an embedded document nor an array.
type InvalidDepthTraversalError struct {
Key string
Type bsontype.Type
}
func (idte InvalidDepthTraversalError) Error() string {
return fmt.Sprintf(
"attempt to traverse into %s, but it's type is %s, not %s nor %s",
idte.Key, idte.Type, bsontype.EmbeddedDocument, bsontype.Array,
)
}
// ErrMissingNull is returned when a document's last byte is not null.
const ErrMissingNull DocumentValidationError = "document end is missing null byte"
// ErrNilReader indicates that an operation was attempted on a nil io.Reader.
var ErrNilReader = errors.New("nil reader")
// ErrInvalidLength indicates that a length in a binary representation of a BSON document is invalid.
var ErrInvalidLength = errors.New("document length is invalid")
// ErrEmptyKey indicates that no key was provided to a Lookup method.
var ErrEmptyKey = errors.New("empty key provided")
// ErrElementNotFound indicates that an Element matching a certain condition does not exist.
var ErrElementNotFound = errors.New("element not found")
// ErrOutOfBounds indicates that an index provided to access something was invalid.
var ErrOutOfBounds = errors.New("out of bounds")
// Document is a raw bytes representation of a BSON document.
type Document []byte
// Array is a raw bytes representation of a BSON array.
type Array = Document
// NewDocumentFromReader reads a document from r. This function will only validate that the length
// is correct and that the document ends with a null byte.
func NewDocumentFromReader(r io.Reader) (Document, error) {
if r == nil {
return nil, ErrNilReader
}
var lengthBytes [4]byte
// ReadFull guarantees that we will have read at least len(lengthBytes) if err == nil
_, err := io.ReadFull(r, lengthBytes[:])
if err != nil {
return nil, err
}
length, _, _ := readi32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length
if length < 0 {
return nil, ErrInvalidLength
}
document := make([]byte, length)
copy(document, lengthBytes[:])
_, err = io.ReadFull(r, document[4:])
if err != nil {
return nil, err
}
if document[length-1] != 0x00 {
return nil, ErrMissingNull
}
return document, nil
}
// Lookup searches the document, potentially recursively, for the given key. If there are multiple
// keys provided, this method will recurse down, as long as the top and intermediate nodes are
// either documents or arrays. If an error occurs or if the value doesn't exist, an empty Value is
// returned.
func (d Document) Lookup(key ...string) Value {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr is the same as Lookup, except it returns an error in addition to an empty Value.
func (d Document) LookupErr(key ...string) (Value, error) {
if len(key) < 1 {
return Value{}, ErrEmptyKey
}
length, rem, ok := ReadLength(d)
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
if elem.Key() != key[0] {
continue
}
if len(key) > 1 {
tt := bsontype.Type(elem[0])
switch tt {
case bsontype.EmbeddedDocument:
val, err := elem.Value().Document().LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
case bsontype.Array:
val, err := elem.Value().Array().LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
default:
return Value{}, InvalidDepthTraversalError{Key: elem.Key(), Type: tt}
}
}
return elem.ValueErr()
}
return Value{}, ErrElementNotFound
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (d Document) Index(index uint) Element {
elem, err := d.IndexErr(index)
if err != nil {
panic(err)
}
return elem
}
// IndexErr searches for and retrieves the element at the given index.
func (d Document) IndexErr(index uint) (Element, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var current uint
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
if current != index {
current++
continue
}
return elem, nil
}
return nil, ErrOutOfBounds
}
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (d Document) DebugString() string {
if len(d) < 5 {
return "<malformed>"
}
var buf bytes.Buffer
buf.WriteString("Document")
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString("){")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
fmt.Fprintf(&buf, "%s ", elem.DebugString())
}
buf.WriteByte('}')
return buf.String()
}
// String outputs an ExtendedJSON version of Document. If the document is not valid, this method
// returns an empty string.
func (d Document) String() string {
if len(d) < 5 {
return ""
}
var buf bytes.Buffer
buf.WriteByte('{')
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
length -= 4
var elem Element
var ok bool
first := true
for length > 1 {
if !first {
buf.WriteByte(',')
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return ""
}
fmt.Fprintf(&buf, "%s", elem.String())
first = false
}
buf.WriteByte('}')
return buf.String()
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (d Document) Elements() ([]Element, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
var elems []Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return elems, NewInsufficientBytesError(d, rem)
}
if err := elem.Validate(); err != nil {
return elems, err
}
elems = append(elems, elem)
}
return elems, nil
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (d Document) Values() ([]Value, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
var vals []Value
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return vals, NewInsufficientBytesError(d, rem)
}
if err := elem.Value().Validate(); err != nil {
return vals, err
}
vals = append(vals, elem.Value())
}
return vals, nil
}
// Validate validates the document and ensures the elements contained within are valid.
func (d Document) Validate() error {
length, rem, ok := ReadLength(d)
if !ok {
return NewInsufficientBytesError(d, rem)
}
if int(length) > len(d) {
return d.lengtherror(int(length), len(d))
}
if d[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(d, rem)
}
err := elem.Validate()
if err != nil {
return err
}
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}
func (Document) lengtherror(length, rem int) error {
return DocumentValidationError(fmt.Sprintf("document length exceeds available bytes. length=%d remainingBytes=%d", length, rem))
}
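
A short sketch of Lookup and Elements on a raw Document (hypothetical usage, reusing the bsoncore builders shown earlier):

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

func main() {
    inner := bsoncore.BuildDocumentFromElements(nil,
        bsoncore.AppendInt32Element(nil, "y", 7),
    )
    doc := bsoncore.Document(bsoncore.BuildDocumentFromElements(nil,
        bsoncore.AppendDocumentElement(nil, "x", inner),
    ))

    // Recursive lookup descends through the embedded document.
    val := doc.Lookup("x", "y")
    fmt.Println(val.Type) // bsontype.Int32

    elems, err := doc.Elements()
    fmt.Println(len(elems), err) // 1 <nil>
}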


@@ -0,0 +1,167 @@
package bsoncore
import (
"errors"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// DocumentSequenceStyle is used to represent how a document sequence is laid out in a slice of
// bytes.
type DocumentSequenceStyle uint32
// These constants are the valid styles for a DocumentSequence.
const (
_ DocumentSequenceStyle = iota
SequenceStyle
ArrayStyle
)
// DocumentSequence represents a sequence of documents. The Style field indicates how the documents
// are laid out inside of the Data field.
type DocumentSequence struct {
Style DocumentSequenceStyle
Data []byte
Pos int
}
// ErrCorruptedDocument is returned when a full document couldn't be read from the sequence.
var ErrCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document")
// ErrNonDocument is returned when a DocumentSequence contains a non-document BSON value.
var ErrNonDocument = errors.New("invalid DocumentSequence: a non-document value was found in sequence")
// ErrInvalidDocumentSequenceStyle is returned when an unknown DocumentSequenceStyle is set on a
// DocumentSequence.
var ErrInvalidDocumentSequenceStyle = errors.New("invalid DocumentSequenceStyle")
// DocumentCount returns the number of documents in the sequence.
func (ds *DocumentSequence) DocumentCount() int {
if ds == nil {
return 0
}
switch ds.Style {
case SequenceStyle:
var count int
var ok bool
rem := ds.Data
for len(rem) > 0 {
_, rem, ok = ReadDocument(rem)
if !ok {
return 0
}
count++
}
return count
case ArrayStyle:
_, rem, ok := ReadLength(ds.Data)
if !ok {
return 0
}
var count int
for len(rem) > 1 {
_, rem, ok = ReadElement(rem)
if !ok {
return 0
}
count++
}
return count
default:
return 0
}
}
// ResetIterator resets the iteration point for the Next method to the beginning of the document
// sequence.
func (ds *DocumentSequence) ResetIterator() {
if ds == nil {
return
}
ds.Pos = 0
}
// Documents returns a slice of the documents. A nil result means that either the Data field is
// nil or it could not be properly read.
func (ds *DocumentSequence) Documents() ([]Document, error) {
if ds == nil {
return nil, nil
}
switch ds.Style {
case SequenceStyle:
rem := ds.Data
var docs []Document
var doc Document
var ok bool
for {
doc, rem, ok = ReadDocument(rem)
if !ok {
if len(rem) == 0 {
break
}
return nil, ErrCorruptedDocument
}
docs = append(docs, doc)
}
return docs, nil
case ArrayStyle:
if len(ds.Data) == 0 {
return nil, nil
}
vals, err := Document(ds.Data).Values()
if err != nil {
return nil, ErrCorruptedDocument
}
docs := make([]Document, 0, len(vals))
for _, v := range vals {
if v.Type != bsontype.EmbeddedDocument {
return nil, ErrNonDocument
}
docs = append(docs, v.Data)
}
return docs, nil
default:
return nil, ErrInvalidDocumentSequenceStyle
}
}
// Next retrieves the next document from this sequence and returns it. This method will return
// io.EOF when it has reached the end of the sequence.
func (ds *DocumentSequence) Next() (Document, error) {
if ds == nil || ds.Pos >= len(ds.Data) {
return nil, io.EOF
}
switch ds.Style {
case SequenceStyle:
doc, _, ok := ReadDocument(ds.Data[ds.Pos:])
if !ok {
return nil, ErrCorruptedDocument
}
ds.Pos += len(doc)
return doc, nil
case ArrayStyle:
if ds.Pos < 4 {
if len(ds.Data) < 4 {
return nil, ErrCorruptedDocument
}
ds.Pos = 4 // Skip the length of the document
}
if len(ds.Data[ds.Pos:]) == 1 && ds.Data[ds.Pos] == 0x00 {
return nil, io.EOF // At the end of the document
}
elem, _, ok := ReadElement(ds.Data[ds.Pos:])
if !ok {
return nil, ErrCorruptedDocument
}
ds.Pos += len(elem)
val := elem.Value()
if val.Type != bsontype.EmbeddedDocument {
return nil, ErrNonDocument
}
return val.Data, nil
default:
return nil, ErrInvalidDocumentSequenceStyle
}
}
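
A minimal sketch of iterating a SequenceStyle DocumentSequence with Next (an illustration, not part of the vendored file):

package main

import (
    "fmt"
    "io"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

func main() {
    // SequenceStyle data is simply complete documents laid out back to back.
    var data []byte
    data = append(data, bsoncore.BuildDocumentFromElements(nil,
        bsoncore.AppendInt32Element(nil, "a", 1))...)
    data = append(data, bsoncore.BuildDocumentFromElements(nil,
        bsoncore.AppendInt32Element(nil, "a", 2))...)

    seq := &bsoncore.DocumentSequence{Style: bsoncore.SequenceStyle, Data: data}
    fmt.Println(seq.DocumentCount()) // 2
    for {
        doc, err := seq.Next()
        if err == io.EOF {
            break
        }
        if err != nil {
            panic(err)
        }
        fmt.Println(doc)
    }
}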


@@ -0,0 +1,152 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// MalformedElementError represents a class of errors that RawElement methods return.
type MalformedElementError string
func (mee MalformedElementError) Error() string { return string(mee) }
// ErrElementMissingKey is returned when a RawElement is missing a key.
const ErrElementMissingKey MalformedElementError = "element is missing key"
// ErrElementMissingType is returned when a RawElement is missing a type.
const ErrElementMissingType MalformedElementError = "element is missing type"
// Element is a raw bytes representation of a BSON element.
type Element []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If it is important to know whether the element is valid, use KeyErr.
func (e Element) Key() string {
key, _ := e.KeyErr()
return key
}
// KeyBytes returns the key for this element as a []byte. If the element is not valid, this method
// returns nil. If it is important to know whether the element is valid, use KeyBytesErr. This method
// does not include the key's trailing null byte in the returned slice.
func (e Element) KeyBytes() []byte {
key, _ := e.KeyBytesErr()
return key
}
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (e Element) KeyErr() (string, error) {
key, err := e.KeyBytesErr()
return string(key), err
}
// KeyBytesErr returns the key for this element as a []byte, returning an error if the element is
// not valid.
func (e Element) KeyBytesErr() ([]byte, error) {
if len(e) <= 0 {
return nil, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return nil, ErrElementMissingKey
}
return e[1 : idx+1], nil
}
// Validate ensures the element is a valid BSON element.
func (e Element) Validate() error {
if len(e) < 1 {
return ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ErrElementMissingKey
}
return Value{Type: bsontype.Type(e[0]), Data: e[idx+2:]}.Validate()
}
// CompareKey will compare this element's key to key. This method makes it easy to compare keys
// without needing to allocate a string. The key may be null terminated. If a valid key cannot be
// read this method will return false.
func (e Element) CompareKey(key []byte) bool {
if len(e) < 2 {
return false
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return false
}
if index := bytes.IndexByte(key, 0x00); index > -1 {
key = key[:index]
}
return bytes.Equal(e[1:idx+1], key)
}
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If it is important to know whether the element is valid, use ValueErr.
func (e Element) Value() Value {
val, _ := e.ValueErr()
return val
}
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (e Element) ValueErr() (Value, error) {
if len(e) <= 0 {
return Value{}, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return Value{}, ErrElementMissingKey
}
val, rem, exists := ReadValue(e[idx+2:], bsontype.Type(e[0]))
if !exists {
return Value{}, NewInsufficientBytesError(e, rem)
}
return val, nil
}
// String implements the fmt.Stringer interface. The output will be in extended JSON format.
func (e Element) String() string {
if len(e) <= 0 {
return ""
}
t := bsontype.Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ""
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return ""
}
return fmt.Sprintf(`"%s": %v`, key, val)
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (e Element) DebugString() string {
if len(e) <= 0 {
return "<malformed>"
}
t := bsontype.Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return fmt.Sprintf(`bson.Element{[%s]<malformed>}`, t)
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return fmt.Sprintf(`bson.Element{[%s]"%s": <malformed>}`, t, key)
}
return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, t, key, val)
}
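
A brief sketch exercising the Element accessors above (an editor's illustration, not part of the vendored file):

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)

func main() {
    elem := bsoncore.Element(bsoncore.AppendDoubleElement(nil, "pi", 3.14159))

    fmt.Println(elem.Key())                    // pi
    fmt.Println(elem.CompareKey([]byte("pi"))) // true
    fmt.Println(elem.Validate())               // <nil> for a well-formed element
    fmt.Println(elem.DebugString())            // bson.Element{[double]"pi": ...}
}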

223
vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/tables.go generated vendored Executable file

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsoncore
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

1015
vendor/go.mongodb.org/mongo-driver/x/bsonx/bsoncore/value.go generated vendored Executable file

File diff suppressed because it is too large.

166
vendor/go.mongodb.org/mongo-driver/x/bsonx/constructor.go generated vendored Executable file

@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"encoding/binary"
"math"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// IDoc is the interface implemented by Doc and MDoc. It allows either of these types to be provided
// to the Document function to create a Value.
type IDoc interface {
idoc()
}
// Double constructs a BSON double Value.
func Double(f64 float64) Val {
v := Val{t: bsontype.Double}
binary.LittleEndian.PutUint64(v.bootstrap[0:8], math.Float64bits(f64))
return v
}
// String constructs a BSON string Value.
func String(str string) Val { return Val{t: bsontype.String}.writestring(str) }
// Document constructs a Value from the given IDoc. If nil is provided, a BSON Null value will be
// returned.
func Document(doc IDoc) Val {
var v Val
switch tt := doc.(type) {
case Doc:
if tt == nil {
v.t = bsontype.Null
break
}
v.t = bsontype.EmbeddedDocument
v.primitive = tt
case MDoc:
if tt == nil {
v.t = bsontype.Null
break
}
v.t = bsontype.EmbeddedDocument
v.primitive = tt
default:
v.t = bsontype.Null
}
return v
}
// Array constructs a Value from arr. If arr is nil, a BSON Null value is returned.
func Array(arr Arr) Val {
if arr == nil {
return Val{t: bsontype.Null}
}
return Val{t: bsontype.Array, primitive: arr}
}
// Binary constructs a BSON binary Value.
func Binary(subtype byte, data []byte) Val {
return Val{t: bsontype.Binary, primitive: primitive.Binary{Subtype: subtype, Data: data}}
}
// Undefined constructs a BSON undefined Value.
func Undefined() Val { return Val{t: bsontype.Undefined} }
// ObjectID constructs a BSON objectid Value.
func ObjectID(oid primitive.ObjectID) Val {
v := Val{t: bsontype.ObjectID}
copy(v.bootstrap[0:12], oid[:])
return v
}
// Boolean constructs a BSON boolean Value.
func Boolean(b bool) Val {
v := Val{t: bsontype.Boolean}
if b {
v.bootstrap[0] = 0x01
}
return v
}
// DateTime constructs a BSON datetime Value.
func DateTime(dt int64) Val { return Val{t: bsontype.DateTime}.writei64(dt) }
// Time constructs a BSON datetime Value.
func Time(t time.Time) Val {
return Val{t: bsontype.DateTime}.writei64(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
}
// Null constructs a BSON null Value.
func Null() Val { return Val{t: bsontype.Null} }
// Regex constructs a BSON regex Value.
func Regex(pattern, options string) Val {
regex := primitive.Regex{Pattern: pattern, Options: options}
return Val{t: bsontype.Regex, primitive: regex}
}
// DBPointer constructs a BSON dbpointer Value.
func DBPointer(ns string, ptr primitive.ObjectID) Val {
dbptr := primitive.DBPointer{DB: ns, Pointer: ptr}
return Val{t: bsontype.DBPointer, primitive: dbptr}
}
// JavaScript constructs a BSON javascript Value.
func JavaScript(js string) Val {
return Val{t: bsontype.JavaScript}.writestring(js)
}
// Symbol constructs a BSON symbol Value.
func Symbol(symbol string) Val {
return Val{t: bsontype.Symbol}.writestring(symbol)
}
// CodeWithScope constructs a BSON code with scope Value.
func CodeWithScope(code string, scope IDoc) Val {
cws := primitive.CodeWithScope{Code: primitive.JavaScript(code), Scope: scope}
return Val{t: bsontype.CodeWithScope, primitive: cws}
}
// Int32 constructs a BSON int32 Value.
func Int32(i32 int32) Val {
v := Val{t: bsontype.Int32}
v.bootstrap[0] = byte(i32)
v.bootstrap[1] = byte(i32 >> 8)
v.bootstrap[2] = byte(i32 >> 16)
v.bootstrap[3] = byte(i32 >> 24)
return v
}
// Timestamp constructs a BSON timestamp Value.
func Timestamp(t, i uint32) Val {
v := Val{t: bsontype.Timestamp}
v.bootstrap[0] = byte(i)
v.bootstrap[1] = byte(i >> 8)
v.bootstrap[2] = byte(i >> 16)
v.bootstrap[3] = byte(i >> 24)
v.bootstrap[4] = byte(t)
v.bootstrap[5] = byte(t >> 8)
v.bootstrap[6] = byte(t >> 16)
v.bootstrap[7] = byte(t >> 24)
return v
}
// Int64 constructs a BSON int64 Value.
func Int64(i64 int64) Val { return Val{t: bsontype.Int64}.writei64(i64) }
// Decimal128 constructs a BSON decimal128 Value.
func Decimal128(d128 primitive.Decimal128) Val {
return Val{t: bsontype.Decimal128, primitive: d128}
}
// MinKey constructs a BSON minkey Value.
func MinKey() Val { return Val{t: bsontype.MinKey} }
// MaxKey constructs a BSON maxkey Value.
func MaxKey() Val { return Val{t: bsontype.MaxKey} }
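
For reference, a minimal usage sketch of these constructors (the keys and values are illustrative; only the bsonx API listed above is assumed):

package main

import (
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/x/bsonx"
)

func main() {
	// Pair each key with a constructed Val to build an ordered document.
	doc := bsonx.Doc{
		{Key: "name", Value: bsonx.String("alice")},
		{Key: "age", Value: bsonx.Int32(30)},
		{Key: "joined", Value: bsonx.Time(time.Now())},
		{Key: "_id", Value: bsonx.ObjectID(primitive.NewObjectID())},
		{Key: "tags", Value: bsonx.Array(bsonx.Arr{bsonx.String("a"), bsonx.String("b")})},
	}
	fmt.Println(doc)
}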

305
vendor/go.mongodb.org/mongo-driver/x/bsonx/document.go generated vendored Executable file
View File

@@ -0,0 +1,305 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilDocument indicates that an operation was attempted on a nil *Doc or *MDoc.
var ErrNilDocument = errors.New("document is nil")
// KeyNotFound is an error type returned from the Lookup methods on Document. This type contains
// information about which key was not found and whether it was genuinely missing or whether a
// component of the key (other than the last) was neither a document nor an array.
type KeyNotFound struct {
Key []string // The keys that were searched for.
Depth uint // Which key either was not found or was an incorrect type.
Type bsontype.Type // The type of the key that was found but was an incorrect type.
}
func (knf KeyNotFound) Error() string {
depth := knf.Depth
if depth >= uint(len(knf.Key)) {
depth = uint(len(knf.Key)) - 1
}
if len(knf.Key) == 0 {
return "no keys were provided for lookup"
}
if knf.Type != bsontype.Type(0) {
return fmt.Sprintf(`key "%s" was found but was not valid to traverse BSON type %s`, knf.Key[depth], knf.Type)
}
return fmt.Sprintf(`key "%s" was not found`, knf.Key[depth])
}
// Doc is a type safe, concise BSON document representation.
type Doc []Elem
// ReadDoc will create a Document using the provided slice of bytes. If the
// slice of bytes is not a valid BSON document, this method will return an error.
func ReadDoc(b []byte) (Doc, error) {
doc := make(Doc, 0)
err := doc.UnmarshalBSON(b)
if err != nil {
return nil, err
}
return doc, nil
}
// Copy makes a shallow copy of this document.
func (d Doc) Copy() Doc {
d2 := make(Doc, len(d))
copy(d2, d)
return d2
}
// Append adds an element to the end of the document, creating it from the key and value provided.
func (d Doc) Append(key string, val Val) Doc {
return append(d, Elem{Key: key, Value: val})
}
// Prepend adds an element to the beginning of the document, creating it from the key and value provided.
func (d Doc) Prepend(key string, val Val) Doc {
// TODO: should we just modify d itself instead of doing an alloc here?
return append(Doc{{Key: key, Value: val}}, d...)
}
// Set replaces an element of a document. If an element with a matching key is
// found, the element will be replaced with the one provided. If the document
// does not have an element with that key, the element is appended to the
// document instead.
func (d Doc) Set(key string, val Val) Doc {
idx := d.IndexOf(key)
if idx == -1 {
return append(d, Elem{Key: key, Value: val})
}
d[idx] = Elem{Key: key, Value: val}
return d
}
// IndexOf returns the index of the first element with the given key, or -1 if no element with that
// key was found.
func (d Doc) IndexOf(key string) int {
for i, e := range d {
if e.Key == key {
return i
}
}
return -1
}
// Delete removes the element with key if it exists and returns the updated Doc.
func (d Doc) Delete(key string) Doc {
idx := d.IndexOf(key)
if idx == -1 {
return d
}
return append(d[:idx], d[idx+1:]...)
}
// Lookup searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Value if the key does not exist. To know if the key actually
// exists, use LookupErr.
func (d Doc) Lookup(key ...string) Val {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (d Doc) LookupErr(key ...string) (Val, error) {
elem, err := d.LookupElementErr(key...)
return elem.Value, err
}
// LookupElement searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Element if the key does not exist. To know if the key actually
// exists, use LookupElementErr.
func (d Doc) LookupElement(key ...string) Elem {
elem, _ := d.LookupElementErr(key...)
return elem
}
// LookupElementErr searches the document and potentially subdocuments for the
// provided key. Each key provided to this method represents a layer of depth.
func (d Doc) LookupElementErr(key ...string) (Elem, error) {
// KeyNotFound operates by being created where the error happens and then the depth is
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
// the proper keys.
if len(key) == 0 {
return Elem{}, KeyNotFound{Key: key}
}
var elem Elem
var err error
idx := d.IndexOf(key[0])
if idx == -1 {
return Elem{}, KeyNotFound{Key: key}
}
elem = d[idx]
if len(key) == 1 {
return elem, nil
}
switch elem.Value.Type() {
case bsontype.EmbeddedDocument:
switch tt := elem.Value.primitive.(type) {
case Doc:
elem, err = tt.LookupElementErr(key[1:]...)
case MDoc:
elem, err = tt.LookupElementErr(key[1:]...)
}
default:
return Elem{}, KeyNotFound{Type: elem.Value.Type()}
}
switch tt := err.(type) {
case KeyNotFound:
tt.Depth++
tt.Key = key
return Elem{}, tt
case nil:
return elem, nil
default:
return Elem{}, err // We can't actually hit this.
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
//
// This method will never return an error.
func (d Doc) MarshalBSONValue() (bsontype.Type, []byte, error) {
if d == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
data, _ := d.MarshalBSON()
return bsontype.EmbeddedDocument, data, nil
}
// MarshalBSON implements the Marshaler interface.
//
// This method will never return an error.
func (d Doc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
// AppendMarshalBSON marshals Doc to BSON bytes, appending to dst.
//
// This method will never return an error.
func (d Doc) AppendMarshalBSON(dst []byte) ([]byte, error) {
idx, dst := bsoncore.ReserveLength(dst)
for _, elem := range d {
t, data, _ := elem.Value.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, elem.Key...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst, nil
}
// UnmarshalBSON implements the Unmarshaler interface.
func (d *Doc) UnmarshalBSON(b []byte) error {
if d == nil {
return ErrNilDocument
}
if err := bsoncore.Document(b).Validate(); err != nil {
return err
}
elems, err := bsoncore.Document(b).Elements()
if err != nil {
return err
}
var val Val
for _, elem := range elems {
rawv := elem.Value()
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
if err != nil {
return err
}
*d = d.Append(elem.Key(), val)
}
return nil
}
// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface.
func (d *Doc) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if t != bsontype.EmbeddedDocument {
return fmt.Errorf("cannot unmarshal %s into a bsonx.Doc", t)
}
return d.UnmarshalBSON(data)
}
// Equal compares this document to another, returning true if they are equal.
func (d Doc) Equal(id IDoc) bool {
switch tt := id.(type) {
case Doc:
d2 := tt
if len(d) != len(d2) {
return false
}
for idx := range d {
if !d[idx].Equal(d2[idx]) {
return false
}
}
case MDoc:
unique := make(map[string]struct{}, 0)
for _, elem := range d {
unique[elem.Key] = struct{}{}
val, ok := tt[elem.Key]
if !ok {
return false
}
if !val.Equal(elem.Value) {
return false
}
}
if len(unique) != len(tt) {
return false
}
case nil:
return d == nil
default:
return false
}
return true
}
// String implements the fmt.Stringer interface.
func (d Doc) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Document{"))
for idx, elem := range d {
if idx > 0 {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%v", elem)
}
buf.WriteByte('}')
return buf.String()
}
func (Doc) idoc() {}
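
A short sketch of how Doc's builder and lookup methods compose, using only the methods defined above (the keys and values are illustrative):

package main

import (
	"fmt"
	"log"

	"go.mongodb.org/mongo-driver/x/bsonx"
)

func main() {
	d := bsonx.Doc{}.
		Append("user", bsonx.String("alice")).
		Append("meta", bsonx.Document(bsonx.Doc{{Key: "role", Value: bsonx.String("admin")}}))

	d = d.Set("user", bsonx.String("bob")) // replaces the existing element

	role, err := d.LookupErr("meta", "role") // one key per level of nesting
	if err != nil {
		log.Fatal(err) // a KeyNotFound error reports which key was missing and at what depth
	}
	fmt.Println(role.StringValue()) // "admin"

	raw, _ := d.MarshalBSON() // MarshalBSON never returns an error
	fmt.Printf("%d BSON bytes\n", len(raw))
}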

53
vendor/go.mongodb.org/mongo-driver/x/bsonx/element.go generated vendored Executable file
View File

@@ -0,0 +1,53 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
const validateMaxDepthDefault = 2048
// ElementTypeError specifies that a method to obtain a BSON value of an incorrect type was called on a bson.Value.
//
// TODO: rename this ValueTypeError.
type ElementTypeError struct {
Method string
Type bsontype.Type
}
// Error implements the error interface.
func (ete ElementTypeError) Error() string {
return "Call of " + ete.Method + " on " + ete.Type.String() + " type"
}
// Elem represents a BSON element.
//
// NOTE: Element cannot be the value of a map or a property of a struct without special handling.
// The default encoders and decoders will not process Element correctly. To do so would require
// information loss since an Element contains a key, but the keys used when encoding a struct are
// the struct field names. Instead of using an Element, use a Value as a value in a map or a
// property of a struct.
type Elem struct {
Key string
Value Val
}
// Equal compares e and e2 and returns true if they are equal.
func (e Elem) Equal(e2 Elem) bool {
if e.Key != e2.Key {
return false
}
return e.Value.Equal(e2.Value)
}
func (e Elem) String() string {
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
return fmt.Sprintf(`bson.Element{"%s": %v}`, e.Key, e.Value)
}

231
vendor/go.mongodb.org/mongo-driver/x/bsonx/mdocument.go generated vendored Executable file
View File

@@ -0,0 +1,231 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// MDoc is an unordered, type safe, concise BSON document representation. This type should not be
// used if you require ordering of values or duplicate keys.
type MDoc map[string]Val
// ReadMDoc will create an MDoc using the provided slice of bytes. If the
// slice of bytes is not a valid BSON document, this method will return an error.
func ReadMDoc(b []byte) (MDoc, error) {
doc := make(MDoc, 0)
err := doc.UnmarshalBSON(b)
if err != nil {
return nil, err
}
return doc, nil
}
// Copy makes a shallow copy of this document.
func (d MDoc) Copy() MDoc {
d2 := make(MDoc, len(d))
for k, v := range d {
d2[k] = v
}
return d2
}
// Lookup searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Value if the key does not exist. To know if the key actually
// exists, use LookupErr.
func (d MDoc) Lookup(key ...string) Val {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (d MDoc) LookupErr(key ...string) (Val, error) {
elem, err := d.LookupElementErr(key...)
return elem.Value, err
}
// LookupElement searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Element if the key does not exist. To know if the key actually
// exists, use LookupElementErr.
func (d MDoc) LookupElement(key ...string) Elem {
elem, _ := d.LookupElementErr(key...)
return elem
}
// LookupElementErr searches the document and potentially subdocuments for the
// provided key. Each key provided to this method represents a layer of depth.
func (d MDoc) LookupElementErr(key ...string) (Elem, error) {
// KeyNotFound operates by being created where the error happens and then the depth is
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
// the proper keys.
if len(key) == 0 {
return Elem{}, KeyNotFound{Key: key}
}
var elem Elem
var err error
val, ok := d[key[0]]
if !ok {
return Elem{}, KeyNotFound{Key: key}
}
if len(key) == 1 {
return Elem{Key: key[0], Value: val}, nil
}
switch val.Type() {
case bsontype.EmbeddedDocument:
switch tt := val.primitive.(type) {
case Doc:
elem, err = tt.LookupElementErr(key[1:]...)
case MDoc:
elem, err = tt.LookupElementErr(key[1:]...)
}
default:
return Elem{}, KeyNotFound{Type: val.Type()}
}
switch tt := err.(type) {
case KeyNotFound:
tt.Depth++
tt.Key = key
return Elem{}, tt
case nil:
return elem, nil
default:
return Elem{}, err // We can't actually hit this.
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
//
// This method will never return an error.
func (d MDoc) MarshalBSONValue() (bsontype.Type, []byte, error) {
if d == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
data, _ := d.MarshalBSON()
return bsontype.EmbeddedDocument, data, nil
}
// MarshalBSON implements the Marshaler interface.
//
// This method will never return an error.
func (d MDoc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
// AppendMarshalBSON marshals MDoc to BSON bytes, appending to dst.
//
// This method will never return an error.
func (d MDoc) AppendMarshalBSON(dst []byte) ([]byte, error) {
idx, dst := bsoncore.ReserveLength(dst)
for k, v := range d {
t, data, _ := v.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, k...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst, nil
}
// UnmarshalBSON implements the Unmarshaler interface.
func (d *MDoc) UnmarshalBSON(b []byte) error {
if d == nil {
return ErrNilDocument
}
if err := bsoncore.Document(b).Validate(); err != nil {
return err
}
elems, err := bsoncore.Document(b).Elements()
if err != nil {
return err
}
var val Val
for _, elem := range elems {
rawv := elem.Value()
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
if err != nil {
return err
}
(*d)[elem.Key()] = val
}
return nil
}
// Equal compares this document to another, returning true if they are equal.
func (d MDoc) Equal(id IDoc) bool {
switch tt := id.(type) {
case MDoc:
d2 := tt
if len(d) != len(d2) {
return false
}
for key, value := range d {
value2, ok := d2[key]
if !ok {
return false
}
if !value.Equal(value2) {
return false
}
}
case Doc:
unique := make(map[string]struct{}, 0)
for _, elem := range tt {
unique[elem.Key] = struct{}{}
val, ok := d[elem.Key]
if !ok {
return false
}
if !val.Equal(elem.Value) {
return false
}
}
if len(unique) != len(d) {
return false
}
case nil:
return d == nil
default:
return false
}
return true
}
// String implements the fmt.Stringer interface.
func (d MDoc) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Document{"))
first := true
for key, value := range d {
if !first {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%v", Elem{Key: key, Value: value})
first = false
}
buf.WriteByte('}')
return buf.String()
}
func (MDoc) idoc() {}
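
MDoc is a plain map, so construction uses ordinary Go map literals. A brief sketch under the same assumptions (keys are illustrative; remember that iteration order, and therefore String and MarshalBSON output, is unordered):

package main

import (
	"fmt"

	"go.mongodb.org/mongo-driver/x/bsonx"
)

func main() {
	m := bsonx.MDoc{
		"user": bsonx.String("alice"),
		"meta": bsonx.Document(bsonx.MDoc{"role": bsonx.String("admin")}),
	}

	// Lookup behaves the same as on Doc: one key per level of nesting.
	fmt.Println(m.Lookup("meta", "role").StringValue()) // "admin"

	raw, _ := m.MarshalBSON() // element order in the output is not deterministic
	fmt.Printf("%d BSON bytes\n", len(raw))
}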

View File

@@ -0,0 +1,638 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"errors"
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var primitiveCodecs PrimitiveCodecs
var tDocument = reflect.TypeOf((Doc)(nil))
var tMDoc = reflect.TypeOf((MDoc)(nil))
var tArray = reflect.TypeOf((Arr)(nil))
var tValue = reflect.TypeOf(Val{})
var tElementSlice = reflect.TypeOf(([]Elem)(nil))
// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
// defined in this package.
type PrimitiveCodecs struct{}
// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
// with the provided RegistryBuilder. It panics if rb is nil.
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
}
rb.
RegisterEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)).
RegisterEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)).
RegisterEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)).
RegisterEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)).
RegisterDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)).
RegisterDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)).
RegisterDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)).
RegisterDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue))
}
// DocumentEncodeValue is the ValueEncoderFunc for Doc.
func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDocument {
return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
doc := val.Interface().(Doc)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return pc.encodeDocument(ec, dw, doc)
}
// DocumentDecodeValue is the ValueDecoderFunc for Doc.
func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tDocument {
return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
}
return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc))
}
func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error {
dr, err := vr.ReadDocument()
if err != nil {
return err
}
return pc.decodeDocument(dctx, dr, doc)
}
// ArrayEncodeValue is the ValueEncoderFunc for Arr.
func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tArray {
return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
arr := val.Interface().(Arr)
aw, err := vw.WriteArray()
if err != nil {
return err
}
for _, val := range arr {
dvw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = pc.encodeValue(ec, dvw, val)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// ArrayDecodeValue is the ValueDecoderFunc for Arr.
func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tArray {
return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
}
ar, err := vr.ReadArray()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeSlice(tArray, 0, 0))
}
val.SetLen(0)
for {
vr, err := ar.ReadValue()
if err == bsonrw.ErrEOA {
break
}
if err != nil {
return err
}
var elem Val
err = pc.valueDecodeValue(dc, vr, &elem)
if err != nil {
return err
}
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
}
// ElementSliceEncodeValue is the ValueEncoderFunc for []Elem.
func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tElementSlice {
return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument))
}
// ElementSliceDecodeValue is the ValueDecoderFunc for []Elem.
func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tElementSlice {
return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
dr, err := vr.ReadDocument()
if err != nil {
return err
}
elems := make([]reflect.Value, 0)
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
var elem Elem
err = pc.elementDecodeValue(dc, vr, key, &elem)
if err != nil {
return err
}
elems = append(elems, reflect.ValueOf(elem))
}
val.Set(reflect.Append(val, elems...))
return nil
}
// ValueEncodeValue is the ValueEncoderFunc for Val.
func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tValue {
return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val}
}
v := val.Interface().(Val)
return pc.encodeValue(ec, vw, v)
}
// ValueDecodeValue is the ValueDecoderFunc for Val.
func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tValue {
return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val}
}
return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val))
}
// encodeDocument is a separate function that we use because CodeWithScope
// returns us a DocumentWriter and we need to do the same logic that we would do
// for a document but cannot use a Codec.
func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error {
for _, elem := range doc {
dvw, err := dw.WriteDocumentElement(elem.Key)
if err != nil {
return err
}
err = pc.encodeValue(ec, dvw, elem.Value)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// DecodeDocument handles decoding into a Doc from a bsonrw.DocumentReader.
func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
return pc.decodeDocument(dctx, dr, pdoc)
}
func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
if *pdoc == nil {
*pdoc = make(Doc, 0)
}
*pdoc = (*pdoc)[:0]
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
var elem Elem
err = pc.elementDecodeValue(dctx, vr, key, &elem)
if err != nil {
return err
}
*pdoc = append(*pdoc, elem)
}
return nil
}
func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error {
var val Val
switch vr.Type() {
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
val = Double(f64)
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return err
}
val = String(str)
case bsontype.EmbeddedDocument:
var embeddedDoc Doc
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
if err != nil {
return err
}
val = Document(embeddedDoc)
case bsontype.Array:
arr := reflect.New(tArray).Elem()
err := pc.ArrayDecodeValue(dc, vr, arr)
if err != nil {
return err
}
val = Array(arr.Interface().(Arr))
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
val = Binary(subtype, data)
case bsontype.Undefined:
err := vr.ReadUndefined()
if err != nil {
return err
}
val = Undefined()
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return err
}
val = ObjectID(oid)
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
val = Boolean(b)
case bsontype.DateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return err
}
val = DateTime(dt)
case bsontype.Null:
err := vr.ReadNull()
if err != nil {
return err
}
val = Null()
case bsontype.Regex:
pattern, options, err := vr.ReadRegex()
if err != nil {
return err
}
val = Regex(pattern, options)
case bsontype.DBPointer:
ns, pointer, err := vr.ReadDBPointer()
if err != nil {
return err
}
val = DBPointer(ns, pointer)
case bsontype.JavaScript:
js, err := vr.ReadJavascript()
if err != nil {
return err
}
val = JavaScript(js)
case bsontype.Symbol:
symbol, err := vr.ReadSymbol()
if err != nil {
return err
}
val = Symbol(symbol)
case bsontype.CodeWithScope:
code, scope, err := vr.ReadCodeWithScope()
if err != nil {
return err
}
var doc Doc
err = pc.decodeDocument(dc, scope, &doc)
if err != nil {
return err
}
val = CodeWithScope(code, doc)
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
val = Int32(i32)
case bsontype.Timestamp:
t, i, err := vr.ReadTimestamp()
if err != nil {
return err
}
val = Timestamp(t, i)
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
val = Int64(i64)
case bsontype.Decimal128:
d128, err := vr.ReadDecimal128()
if err != nil {
return err
}
val = Decimal128(d128)
case bsontype.MinKey:
err := vr.ReadMinKey()
if err != nil {
return err
}
val = MinKey()
case bsontype.MaxKey:
err := vr.ReadMaxKey()
if err != nil {
return err
}
val = MaxKey()
default:
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
}
*elem = Elem{Key: key, Value: val}
return nil
}
// encodeValue does no validation; callers must perform validation on val before calling
// this method.
func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error {
var err error
switch val.Type() {
case bsontype.Double:
err = vw.WriteDouble(val.Double())
case bsontype.String:
err = vw.WriteString(val.StringValue())
case bsontype.EmbeddedDocument:
var encoder bsoncodec.ValueEncoder
encoder, err = ec.LookupEncoder(tDocument)
if err != nil {
break
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document()))
case bsontype.Array:
var encoder bsoncodec.ValueEncoder
encoder, err = ec.LookupEncoder(tArray)
if err != nil {
break
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array()))
case bsontype.Binary:
// TODO: FIX THIS (╯°□°)╯︵ ┻━┻
subtype, data := val.Binary()
err = vw.WriteBinaryWithSubtype(data, subtype)
case bsontype.Undefined:
err = vw.WriteUndefined()
case bsontype.ObjectID:
err = vw.WriteObjectID(val.ObjectID())
case bsontype.Boolean:
err = vw.WriteBoolean(val.Boolean())
case bsontype.DateTime:
err = vw.WriteDateTime(val.DateTime())
case bsontype.Null:
err = vw.WriteNull()
case bsontype.Regex:
err = vw.WriteRegex(val.Regex())
case bsontype.DBPointer:
err = vw.WriteDBPointer(val.DBPointer())
case bsontype.JavaScript:
err = vw.WriteJavascript(val.JavaScript())
case bsontype.Symbol:
err = vw.WriteSymbol(val.Symbol())
case bsontype.CodeWithScope:
code, scope := val.CodeWithScope()
var cwsw bsonrw.DocumentWriter
cwsw, err = vw.WriteCodeWithScope(code)
if err != nil {
break
}
err = pc.encodeDocument(ec, cwsw, scope)
case bsontype.Int32:
err = vw.WriteInt32(val.Int32())
case bsontype.Timestamp:
err = vw.WriteTimestamp(val.Timestamp())
case bsontype.Int64:
err = vw.WriteInt64(val.Int64())
case bsontype.Decimal128:
err = vw.WriteDecimal128(val.Decimal128())
case bsontype.MinKey:
err = vw.WriteMinKey()
case bsontype.MaxKey:
err = vw.WriteMaxKey()
default:
err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type())
}
return err
}
func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error {
switch vr.Type() {
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
*val = Double(f64)
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return err
}
*val = String(str)
case bsontype.EmbeddedDocument:
var embeddedDoc Doc
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
if err != nil {
return err
}
*val = Document(embeddedDoc)
case bsontype.Array:
arr := reflect.New(tArray).Elem()
err := pc.ArrayDecodeValue(dc, vr, arr)
if err != nil {
return err
}
*val = Array(arr.Interface().(Arr))
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
*val = Binary(subtype, data)
case bsontype.Undefined:
err := vr.ReadUndefined()
if err != nil {
return err
}
*val = Undefined()
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return err
}
*val = ObjectID(oid)
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
*val = Boolean(b)
case bsontype.DateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return err
}
*val = DateTime(dt)
case bsontype.Null:
err := vr.ReadNull()
if err != nil {
return err
}
*val = Null()
case bsontype.Regex:
pattern, options, err := vr.ReadRegex()
if err != nil {
return err
}
*val = Regex(pattern, options)
case bsontype.DBPointer:
ns, pointer, err := vr.ReadDBPointer()
if err != nil {
return err
}
*val = DBPointer(ns, pointer)
case bsontype.JavaScript:
js, err := vr.ReadJavascript()
if err != nil {
return err
}
*val = JavaScript(js)
case bsontype.Symbol:
symbol, err := vr.ReadSymbol()
if err != nil {
return err
}
*val = Symbol(symbol)
case bsontype.CodeWithScope:
code, scope, err := vr.ReadCodeWithScope()
if err != nil {
return err
}
var scopeDoc Doc
err = pc.decodeDocument(dc, scope, &scopeDoc)
if err != nil {
return err
}
*val = CodeWithScope(code, scopeDoc)
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
*val = Int32(i32)
case bsontype.Timestamp:
t, i, err := vr.ReadTimestamp()
if err != nil {
return err
}
*val = Timestamp(t, i)
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
*val = Int64(i64)
case bsontype.Decimal128:
d128, err := vr.ReadDecimal128()
if err != nil {
return err
}
*val = Decimal128(d128)
case bsontype.MinKey:
err := vr.ReadMinKey()
if err != nil {
return err
}
*val = MinKey()
case bsontype.MaxKey:
err := vr.ReadMaxKey()
if err != nil {
return err
}
*val = MaxKey()
default:
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
}
return nil
}

22
vendor/go.mongodb.org/mongo-driver/x/bsonx/registry.go generated vendored Executable file
View File

@@ -0,0 +1,22 @@
package bsonx
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
)
// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the
// primitive codecs.
var DefaultRegistry = NewRegistryBuilder().Build()
// NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and
// decoders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the
// PrimitiveCodecs type in this package.
func NewRegistryBuilder() *bsoncodec.RegistryBuilder {
rb := bsoncodec.NewRegistryBuilder()
bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb)
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb)
primitiveCodecs.RegisterPrimitiveCodecs(rb)
return rb
}
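
A hedged sketch of putting the registry to use. It assumes bson.MarshalWithRegistry from the top-level bson package, which should be available at this driver version, but verify against the vendored API before relying on it:

package main

import (
	"fmt"
	"log"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/x/bsonx"
)

func main() {
	doc := bsonx.Doc{{Key: "hello", Value: bsonx.String("world")}}

	// DefaultRegistry already has the bsonx primitive codecs registered,
	// so Doc, MDoc, Arr, Val, and []Elem values encode directly.
	raw, err := bson.MarshalWithRegistry(bsonx.DefaultRegistry, doc)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%d BSON bytes\n", len(raw))
}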

899
vendor/go.mongodb.org/mongo-driver/x/bsonx/value.go generated vendored Executable file
View File

@@ -0,0 +1,899 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// Val represents a BSON value.
type Val struct {
// NOTE: The bootstrap is a small amount of space that'll be on the stack. At 15 bytes this
// doesn't make this type any larger, since there are 7 bytes of padding and we want an int64 to
// store small values (e.g. boolean, double, int64, etc...). The primitive property is where all
// of the larger values go. They will use either Go primitives or the primitive.* types.
t bsontype.Type
bootstrap [15]byte
primitive interface{}
}
func (v Val) reset() Val {
v.primitive = nil // clear out any pointers so we don't accidentally stop them from being garbage collected.
v.t = bsontype.Type(0)
v.bootstrap[0] = 0x00
v.bootstrap[1] = 0x00
v.bootstrap[2] = 0x00
v.bootstrap[3] = 0x00
v.bootstrap[4] = 0x00
v.bootstrap[5] = 0x00
v.bootstrap[6] = 0x00
v.bootstrap[7] = 0x00
v.bootstrap[8] = 0x00
v.bootstrap[9] = 0x00
v.bootstrap[10] = 0x00
v.bootstrap[11] = 0x00
v.bootstrap[12] = 0x00
v.bootstrap[13] = 0x00
v.bootstrap[14] = 0x00
return v
}
func (v Val) string() string {
if v.primitive != nil {
return v.primitive.(string)
}
// Short strings live in the bootstrap space with their length in the first byte; longer strings are stored in primitive.
length := uint8(v.bootstrap[0])
return string(v.bootstrap[1 : length+1])
}
func (v Val) writestring(str string) Val {
switch {
case len(str) < 15:
v.bootstrap[0] = uint8(len(str))
copy(v.bootstrap[1:], str)
default:
v.primitive = str
}
return v
}
func (v Val) i64() int64 {
return int64(v.bootstrap[0]) | int64(v.bootstrap[1])<<8 | int64(v.bootstrap[2])<<16 |
int64(v.bootstrap[3])<<24 | int64(v.bootstrap[4])<<32 | int64(v.bootstrap[5])<<40 |
int64(v.bootstrap[6])<<48 | int64(v.bootstrap[7])<<56
}
func (v Val) writei64(i64 int64) Val {
v.bootstrap[0] = byte(i64)
v.bootstrap[1] = byte(i64 >> 8)
v.bootstrap[2] = byte(i64 >> 16)
v.bootstrap[3] = byte(i64 >> 24)
v.bootstrap[4] = byte(i64 >> 32)
v.bootstrap[5] = byte(i64 >> 40)
v.bootstrap[6] = byte(i64 >> 48)
v.bootstrap[7] = byte(i64 >> 56)
return v
}
// IsZero returns true if this value is zero or a BSON null.
func (v Val) IsZero() bool { return v.t == bsontype.Type(0) || v.t == bsontype.Null }
func (v Val) String() string {
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
return fmt.Sprintf("%v", v.Interface())
}
// Interface returns the Go value of this Value as an empty interface.
//
// This method returns primitive.Null{} if the value is empty or of an unrecognized type; otherwise
// it returns a Go primitive or a primitive.* instance.
func (v Val) Interface() interface{} {
switch v.Type() {
case bsontype.Double:
return v.Double()
case bsontype.String:
return v.StringValue()
case bsontype.EmbeddedDocument:
switch v.primitive.(type) {
case Doc:
return v.primitive.(Doc)
case MDoc:
return v.primitive.(MDoc)
default:
return primitive.Null{}
}
case bsontype.Array:
return v.Array()
case bsontype.Binary:
return v.primitive.(primitive.Binary)
case bsontype.Undefined:
return primitive.Undefined{}
case bsontype.ObjectID:
return v.ObjectID()
case bsontype.Boolean:
return v.Boolean()
case bsontype.DateTime:
return v.DateTime()
case bsontype.Null:
return primitive.Null{}
case bsontype.Regex:
return v.primitive.(primitive.Regex)
case bsontype.DBPointer:
return v.primitive.(primitive.DBPointer)
case bsontype.JavaScript:
return v.JavaScript()
case bsontype.Symbol:
return v.Symbol()
case bsontype.CodeWithScope:
return v.primitive.(primitive.CodeWithScope)
case bsontype.Int32:
return v.Int32()
case bsontype.Timestamp:
t, i := v.Timestamp()
return primitive.Timestamp{T: t, I: i}
case bsontype.Int64:
return v.Int64()
case bsontype.Decimal128:
return v.Decimal128()
case bsontype.MinKey:
return primitive.MinKey{}
case bsontype.MaxKey:
return primitive.MaxKey{}
default:
return primitive.Null{}
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
func (v Val) MarshalBSONValue() (bsontype.Type, []byte, error) {
return v.MarshalAppendBSONValue(nil)
}
// MarshalAppendBSONValue is similar to MarshalBSONValue, but allows the caller to specify a slice
// to add the bytes to.
func (v Val) MarshalAppendBSONValue(dst []byte) (bsontype.Type, []byte, error) {
t := v.Type()
switch v.Type() {
case bsontype.Double:
dst = bsoncore.AppendDouble(dst, v.Double())
case bsontype.String:
dst = bsoncore.AppendString(dst, v.String())
case bsontype.EmbeddedDocument:
switch v.primitive.(type) {
case Doc:
t, dst, _ = v.primitive.(Doc).MarshalBSONValue() // Doc.MarshalBSONValue never returns an error.
case MDoc:
t, dst, _ = v.primitive.(MDoc).MarshalBSONValue() // MDoc.MarshalBSONValue never returns an error.
}
case bsontype.Array:
t, dst, _ = v.Array().MarshalBSONValue() // Arr.MarshalBSONValue never returns an error.
case bsontype.Binary:
subtype, bindata := v.Binary()
dst = bsoncore.AppendBinary(dst, subtype, bindata)
case bsontype.Undefined:
case bsontype.ObjectID:
dst = bsoncore.AppendObjectID(dst, v.ObjectID())
case bsontype.Boolean:
dst = bsoncore.AppendBoolean(dst, v.Boolean())
case bsontype.DateTime:
dst = bsoncore.AppendDateTime(dst, int64(v.DateTime()))
case bsontype.Null:
case bsontype.Regex:
pattern, options := v.Regex()
dst = bsoncore.AppendRegex(dst, pattern, options)
case bsontype.DBPointer:
ns, ptr := v.DBPointer()
dst = bsoncore.AppendDBPointer(dst, ns, ptr)
case bsontype.JavaScript:
dst = bsoncore.AppendJavaScript(dst, string(v.JavaScript()))
case bsontype.Symbol:
dst = bsoncore.AppendSymbol(dst, string(v.Symbol()))
case bsontype.CodeWithScope:
code, doc := v.CodeWithScope()
var scope []byte
scope, _ = doc.MarshalBSON() // Doc.MarshalBSON never returns an error.
dst = bsoncore.AppendCodeWithScope(dst, code, scope)
case bsontype.Int32:
dst = bsoncore.AppendInt32(dst, v.Int32())
case bsontype.Timestamp:
t, i := v.Timestamp()
dst = bsoncore.AppendTimestamp(dst, t, i)
case bsontype.Int64:
dst = bsoncore.AppendInt64(dst, v.Int64())
case bsontype.Decimal128:
dst = bsoncore.AppendDecimal128(dst, v.Decimal128())
case bsontype.MinKey:
case bsontype.MaxKey:
default:
panic(fmt.Errorf("invalid BSON type %v", t))
}
return t, dst, nil
}
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
func (v *Val) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if v == nil {
return errors.New("cannot unmarshal into nil Value")
}
var err error
var ok = true
var rem []byte
switch t {
case bsontype.Double:
var f64 float64
f64, rem, ok = bsoncore.ReadDouble(data)
*v = Double(f64)
case bsontype.String:
var str string
str, rem, ok = bsoncore.ReadString(data)
*v = String(str)
case bsontype.EmbeddedDocument:
var raw []byte
var doc Doc
raw, rem, ok = bsoncore.ReadDocument(data)
doc, err = ReadDoc(raw)
*v = Document(doc)
case bsontype.Array:
var raw []byte
arr := make(Arr, 0)
raw, rem, ok = bsoncore.ReadArray(data)
err = arr.UnmarshalBSONValue(t, raw)
*v = Array(arr)
case bsontype.Binary:
var subtype byte
var bindata []byte
subtype, bindata, rem, ok = bsoncore.ReadBinary(data)
*v = Binary(subtype, bindata)
case bsontype.Undefined:
*v = Undefined()
case bsontype.ObjectID:
var oid primitive.ObjectID
oid, rem, ok = bsoncore.ReadObjectID(data)
*v = ObjectID(oid)
case bsontype.Boolean:
var b bool
b, rem, ok = bsoncore.ReadBoolean(data)
*v = Boolean(b)
case bsontype.DateTime:
var dt int64
dt, rem, ok = bsoncore.ReadDateTime(data)
*v = DateTime(dt)
case bsontype.Null:
*v = Null()
case bsontype.Regex:
var pattern, options string
pattern, options, rem, ok = bsoncore.ReadRegex(data)
*v = Regex(pattern, options)
case bsontype.DBPointer:
var ns string
var ptr primitive.ObjectID
ns, ptr, rem, ok = bsoncore.ReadDBPointer(data)
*v = DBPointer(ns, ptr)
case bsontype.JavaScript:
var js string
js, rem, ok = bsoncore.ReadJavaScript(data)
*v = JavaScript(js)
case bsontype.Symbol:
var symbol string
symbol, rem, ok = bsoncore.ReadSymbol(data)
*v = Symbol(symbol)
case bsontype.CodeWithScope:
var raw []byte
var code string
var scope Doc
code, raw, rem, ok = bsoncore.ReadCodeWithScope(data)
scope, err = ReadDoc(raw)
*v = CodeWithScope(code, scope)
case bsontype.Int32:
var i32 int32
i32, rem, ok = bsoncore.ReadInt32(data)
*v = Int32(i32)
case bsontype.Timestamp:
var i, t uint32
t, i, rem, ok = bsoncore.ReadTimestamp(data)
*v = Timestamp(t, i)
case bsontype.Int64:
var i64 int64
i64, rem, ok = bsoncore.ReadInt64(data)
*v = Int64(i64)
case bsontype.Decimal128:
var d128 primitive.Decimal128
d128, rem, ok = bsoncore.ReadDecimal128(data)
*v = Decimal128(d128)
case bsontype.MinKey:
*v = MinKey()
case bsontype.MaxKey:
*v = MaxKey()
default:
err = fmt.Errorf("invalid BSON type %v", t)
}
if !ok && err == nil {
err = bsoncore.NewInsufficientBytesError(data, rem)
}
return err
}
// Type returns the BSON type of this value.
func (v Val) Type() bsontype.Type {
if v.t == bsontype.Type(0) {
return bsontype.Null
}
return v.t
}
// IsNumber returns true if the type of v is a numeric BSON type.
func (v Val) IsNumber() bool {
switch v.Type() {
case bsontype.Double, bsontype.Int32, bsontype.Int64, bsontype.Decimal128:
return true
default:
return false
}
}
// Double returns the BSON double value the Value represents. It panics if the value is a BSON type
// other than double.
func (v Val) Double() float64 {
if v.t != bsontype.Double {
panic(ElementTypeError{"bson.Value.Double", v.t})
}
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8]))
}
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (v Val) DoubleOK() (float64, bool) {
if v.t != bsontype.Double {
return 0, false
}
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8])), true
}
// StringValue returns the BSON string the Value represents. It panics if the value is a BSON type
// other than string.
//
// NOTE: This method is called StringValue to avoid it implementing the
// fmt.Stringer interface.
func (v Val) StringValue() string {
if v.t != bsontype.String {
panic(ElementTypeError{"bson.Value.StringValue", v.t})
}
return v.string()
}
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (v Val) StringValueOK() (string, bool) {
if v.t != bsontype.String {
return "", false
}
return v.string(), true
}
func (v Val) asDoc() Doc {
doc, ok := v.primitive.(Doc)
if ok {
return doc
}
mdoc := v.primitive.(MDoc)
for k, v := range mdoc {
doc = append(doc, Elem{k, v})
}
return doc
}
func (v Val) asMDoc() MDoc {
mdoc, ok := v.primitive.(MDoc)
if ok {
return mdoc
}
doc := v.primitive.(Doc)
mdoc = make(MDoc, len(doc)) // allocate before use; assigning into a nil map would panic
for _, elem := range doc {
mdoc[elem.Key] = elem.Value
}
return mdoc
}
// Document returns the BSON embedded document value the Value represents. It panics if the value
// is a BSON type other than embedded document.
func (v Val) Document() Doc {
if v.t != bsontype.EmbeddedDocument {
panic(ElementTypeError{"bson.Value.Document", v.t})
}
return v.asDoc()
}
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (v Val) DocumentOK() (Doc, bool) {
if v.t != bsontype.EmbeddedDocument {
return nil, false
}
return v.asDoc(), true
}
// MDocument returns the BSON embedded document value the Value represents. It panics if the value
// is a BSON type other than embedded document.
func (v Val) MDocument() MDoc {
if v.t != bsontype.EmbeddedDocument {
panic(ElementTypeError{"bson.Value.MDocument", v.t})
}
return v.asMDoc()
}
// MDocumentOK is the same as MDocument, except it returns a boolean
// instead of panicking.
func (v Val) MDocumentOK() (MDoc, bool) {
if v.t != bsontype.EmbeddedDocument {
return nil, false
}
return v.asMDoc(), true
}
// Array returns the BSON array value the Value represents. It panics if the value is a BSON type
// other than array.
func (v Val) Array() Arr {
if v.t != bsontype.Array {
panic(ElementTypeError{"bson.Value.Array", v.t})
}
return v.primitive.(Arr)
}
// ArrayOK is the same as Array, except it returns a boolean
// instead of panicking.
func (v Val) ArrayOK() (Arr, bool) {
if v.t != bsontype.Array {
return nil, false
}
return v.primitive.(Arr), true
}
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) Binary() (byte, []byte) {
if v.t != bsontype.Binary {
panic(ElementTypeError{"bson.Value.Binary", v.t})
}
bin := v.primitive.(primitive.Binary)
return bin.Subtype, bin.Data
}
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (v Val) BinaryOK() (byte, []byte, bool) {
if v.t != bsontype.Binary {
return 0x00, nil, false
}
bin := v.primitive.(primitive.Binary)
return bin.Subtype, bin.Data, true
}
// Undefined returns the BSON undefined the Value represents. It panics if the value is a BSON type
// other than undefined.
func (v Val) Undefined() {
if v.t != bsontype.Undefined {
panic(ElementTypeError{"bson.Value.Undefined", v.t})
}
return
}
// UndefinedOK is the same as Undefined, except it returns a boolean instead of
// panicking.
func (v Val) UndefinedOK() bool {
if v.t != bsontype.Undefined {
return false
}
return true
}
// ObjectID returns the BSON ObjectID the Value represents. It panics if the value is a BSON type
// other than ObjectID.
func (v Val) ObjectID() primitive.ObjectID {
if v.t != bsontype.ObjectID {
panic(ElementTypeError{"bson.Value.ObjectID", v.t})
}
var oid primitive.ObjectID
copy(oid[:], v.bootstrap[:12])
return oid
}
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (v Val) ObjectIDOK() (primitive.ObjectID, bool) {
if v.t != bsontype.ObjectID {
return primitive.ObjectID{}, false
}
var oid primitive.ObjectID
copy(oid[:], v.bootstrap[:12])
return oid, true
}
// Boolean returns the BSON boolean the Value represents. It panics if the value is a BSON type
// other than boolean.
func (v Val) Boolean() bool {
if v.t != bsontype.Boolean {
panic(ElementTypeError{"bson.Value.Boolean", v.t})
}
return v.bootstrap[0] == 0x01
}
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (v Val) BooleanOK() (bool, bool) {
if v.t != bsontype.Boolean {
return false, false
}
return v.bootstrap[0] == 0x01, true
}
// DateTime returns the BSON datetime the Value represents. It panics if the value is a BSON type
// other than datetime.
func (v Val) DateTime() int64 {
if v.t != bsontype.DateTime {
panic(ElementTypeError{"bson.Value.DateTime", v.t})
}
return v.i64()
}
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (v Val) DateTimeOK() (int64, bool) {
if v.t != bsontype.DateTime {
return 0, false
}
return v.i64(), true
}
// Time returns the BSON datetime the Value represents as time.Time. It panics if the value is a BSON
// type other than datetime.
func (v Val) Time() time.Time {
if v.t != bsontype.DateTime {
panic(ElementTypeError{"bson.Value.Time", v.t})
}
i := v.i64()
return time.Unix(int64(i)/1000, int64(i)%1000*1000000)
}
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (v Val) TimeOK() (time.Time, bool) {
if v.t != bsontype.DateTime {
return time.Time{}, false
}
i := v.i64()
return time.Unix(int64(i)/1000, int64(i)%1000*1000000), true
}
// Null returns the BSON null the Value represents. It panics if the value is a BSON type
// other than null.
func (v Val) Null() {
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
panic(ElementTypeError{"bson.Value.Null", v.t})
}
return
}
// NullOK is the same as Null, except it returns a boolean instead of
// panicking.
func (v Val) NullOK() bool {
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
return false
}
return true
}
// Regex returns the BSON regex the Value represents. It panics if the value is a BSON type
// other than regex.
func (v Val) Regex() (pattern, options string) {
if v.t != bsontype.Regex {
panic(ElementTypeError{"bson.Value.Regex", v.t})
}
regex := v.primitive.(primitive.Regex)
return regex.Pattern, regex.Options
}
// RegexOK is the same as Regex, except that it returns a boolean
// instead of panicking.
func (v Val) RegexOK() (pattern, options string, ok bool) {
if v.t != bsontype.Regex {
return "", "", false
}
regex := v.primitive.(primitive.Regex)
return regex.Pattern, regex.Options, true
}
// DBPointer returns the BSON dbpointer the Value represents. It panics if the value is a BSON type
// other than dbpointer.
func (v Val) DBPointer() (string, primitive.ObjectID) {
if v.t != bsontype.DBPointer {
panic(ElementTypeError{"bson.Value.DBPointer", v.t})
}
dbptr := v.primitive.(primitive.DBPointer)
return dbptr.DB, dbptr.Pointer
}
// DBPointerOK is the same as DBPointer, except that it returns a boolean
// instead of panicking.
func (v Val) DBPointerOK() (string, primitive.ObjectID, bool) {
if v.t != bsontype.DBPointer {
return "", primitive.ObjectID{}, false
}
dbptr := v.primitive.(primitive.DBPointer)
return dbptr.DB, dbptr.Pointer, true
}
// JavaScript returns the BSON JavaScript the Value represents. It panics if the value is a BSON type
// other than JavaScript.
func (v Val) JavaScript() string {
if v.t != bsontype.JavaScript {
panic(ElementTypeError{"bson.Value.JavaScript", v.t})
}
return v.string()
}
// JavaScriptOK is the same as JavaScript, except that it returns a boolean
// instead of panicking.
func (v Val) JavaScriptOK() (string, bool) {
if v.t != bsontype.JavaScript {
return "", false
}
return v.string(), true
}
// Symbol returns the BSON symbol the Value represents. It panics if the value is a BSON type
// other than symbol.
func (v Val) Symbol() string {
if v.t != bsontype.Symbol {
panic(ElementTypeError{"bson.Value.Symbol", v.t})
}
return v.string()
}
// SymbolOK is the same as Symbol, except that it returns a boolean
// instead of panicking.
func (v Val) SymbolOK() (string, bool) {
if v.t != bsontype.Symbol {
return "", false
}
return v.string(), true
}
// CodeWithScope returns the BSON code with scope value the Value represents. It panics if the
// value is a BSON type other than code with scope.
func (v Val) CodeWithScope() (string, Doc) {
if v.t != bsontype.CodeWithScope {
panic(ElementTypeError{"bson.Value.CodeWithScope", v.t})
}
cws := v.primitive.(primitive.CodeWithScope)
return string(cws.Code), cws.Scope.(Doc)
}
// CodeWithScopeOK is the same as CodeWithScope,
// except that it returns a boolean instead of panicking.
func (v Val) CodeWithScopeOK() (string, Doc, bool) {
if v.t != bsontype.CodeWithScope {
return "", nil, false
}
cws := v.primitive.(primitive.CodeWithScope)
return string(cws.Code), cws.Scope.(Doc), true
}
// Int32 returns the BSON int32 the Value represents. It panics if the value is a BSON type
// other than int32.
func (v Val) Int32() int32 {
if v.t != bsontype.Int32 {
panic(ElementTypeError{"bson.Value.Int32", v.t})
}
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24
}
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (v Val) Int32OK() (int32, bool) {
if v.t != bsontype.Int32 {
return 0, false
}
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24,
true
}
// Timestamp returns the BSON timestamp the Value represents. It panics if the value is a
// BSON type other than timestamp.
func (v Val) Timestamp() (t, i uint32) {
if v.t != bsontype.Timestamp {
panic(ElementTypeError{"bson.Value.Timestamp", v.t})
}
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24
}
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (v Val) TimestampOK() (t uint32, i uint32, ok bool) {
if v.t != bsontype.Timestamp {
return 0, 0, false
}
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24,
true
}
// Int64 returns the BSON int64 the Value represents. It panics if the value is a BSON type
// other than int64.
func (v Val) Int64() int64 {
if v.t != bsontype.Int64 {
panic(ElementTypeError{"bson.Value.Int64", v.t})
}
return v.i64()
}
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (v Val) Int64OK() (int64, bool) {
if v.t != bsontype.Int64 {
return 0, false
}
return v.i64(), true
}
// Decimal128 returns the BSON decimal128 value the Value represents. It panics if the value is a
// BSON type other than decimal128.
func (v Val) Decimal128() primitive.Decimal128 {
if v.t != bsontype.Decimal128 {
panic(ElementTypeError{"bson.Value.Decimal128", v.t})
}
return v.primitive.(primitive.Decimal128)
}
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (v Val) Decimal128OK() (primitive.Decimal128, bool) {
if v.t != bsontype.Decimal128 {
return primitive.Decimal128{}, false
}
return v.primitive.(primitive.Decimal128), true
}
// MinKey returns the BSON minkey the Value represents. It panics if the value is a BSON type
// other than minkey.
func (v Val) MinKey() {
if v.t != bsontype.MinKey {
panic(ElementTypeError{"bson.Value.MinKey", v.t})
}
return
}
// MinKeyOK is the same as MinKey, except it returns a boolean instead of
// panicking.
func (v Val) MinKeyOK() bool {
if v.t != bsontype.MinKey {
return false
}
return true
}
// MaxKey returns the BSON maxkey the Value represents. It panics if the value is a BSON type
// other than maxkey.
func (v Val) MaxKey() {
if v.t != bsontype.MaxKey {
panic(ElementTypeError{"bson.Value.MaxKey", v.t})
}
return
}
// MaxKeyOK is the same as MaxKey, except it returns a boolean instead of
// panicking.
func (v Val) MaxKeyOK() bool {
if v.t != bsontype.MaxKey {
return false
}
return true
}
// Equal compares v to v2 and returns true if they are equal. Unknown BSON types are
// never equal. Two empty values are equal.
func (v Val) Equal(v2 Val) bool {
if v.Type() != v2.Type() {
return false
}
if v.IsZero() && v2.IsZero() {
return true
}
switch v.Type() {
case bsontype.Double, bsontype.DateTime, bsontype.Timestamp, bsontype.Int64:
return bytes.Equal(v.bootstrap[0:8], v2.bootstrap[0:8])
case bsontype.String:
return v.string() == v2.string()
case bsontype.EmbeddedDocument:
return v.equalDocs(v2)
case bsontype.Array:
return v.Array().Equal(v2.Array())
case bsontype.Binary:
return v.primitive.(primitive.Binary).Equal(v2.primitive.(primitive.Binary))
case bsontype.Undefined:
return true
case bsontype.ObjectID:
return bytes.Equal(v.bootstrap[0:12], v2.bootstrap[0:12])
case bsontype.Boolean:
return v.bootstrap[0] == v2.bootstrap[0]
case bsontype.Null:
return true
case bsontype.Regex:
return v.primitive.(primitive.Regex).Equal(v2.primitive.(primitive.Regex))
case bsontype.DBPointer:
return v.primitive.(primitive.DBPointer).Equal(v2.primitive.(primitive.DBPointer))
case bsontype.JavaScript:
return v.JavaScript() == v2.JavaScript()
case bsontype.Symbol:
return v.Symbol() == v2.Symbol()
case bsontype.CodeWithScope:
code1, scope1 := v.primitive.(primitive.CodeWithScope).Code, v.primitive.(primitive.CodeWithScope).Scope
code2, scope2 := v2.primitive.(primitive.CodeWithScope).Code, v2.primitive.(primitive.CodeWithScope).Scope
return code1 == code2 && v.equalInterfaceDocs(scope1, scope2)
case bsontype.Int32:
return v.Int32() == v2.Int32()
case bsontype.Decimal128:
h, l := v.Decimal128().GetBytes()
h2, l2 := v2.Decimal128().GetBytes()
return h == h2 && l == l2
case bsontype.MinKey:
return true
case bsontype.MaxKey:
return true
default:
return false
}
}
func (v Val) equalDocs(v2 Val) bool {
_, ok1 := v.primitive.(MDoc)
_, ok2 := v2.primitive.(MDoc)
if ok1 || ok2 {
return v.asMDoc().Equal(v2.asMDoc())
}
return v.asDoc().Equal(v2.asDoc())
}
func (Val) equalInterfaceDocs(i, i2 interface{}) bool {
switch d := i.(type) {
case MDoc:
d2, ok := i2.(IDoc)
if !ok {
return false
}
return d.Equal(d2)
case Doc:
d2, ok := i2.(IDoc)
if !ok {
return false
}
return d.Equal(d2)
case nil:
return i2 == nil
default:
return false
}
}

23
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/DESIGN.md generated vendored Executable file
View File

@@ -0,0 +1,23 @@
# Driver Library Design
This document outlines the design for this package.
## Deployment, Server, and Connection
Acquiring a `Connection` from a `Server` selected from a `Deployment` enables sending and receiving
wire messages. A `Deployment` represents a set of MongoDB servers and a `Server` represents a
member of that set. These three types form the operation execution stack.
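As a rough illustration of this execution stack, the sketch below walks from a `Deployment` to a
`Server` to a `Connection` before exchanging a wire message. The interface shapes here are
simplified assumptions for illustration only, not the driver's exact definitions.

```go
// Package driversketch is an illustrative sketch only; it is not part of the driver.
package driversketch

import "context"

// Connection can exchange wire messages (assumed, simplified shape).
type Connection interface {
	WriteWireMessage(ctx context.Context, wm []byte) error
	ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error)
}

// Server hands out connections to a single member of the set (assumed shape).
type Server interface {
	Connection(ctx context.Context) (Connection, error)
}

// Deployment selects a Server from the set it represents (assumed shape; the
// real method also takes a server selector, elided here for brevity).
type Deployment interface {
	SelectServer(ctx context.Context) (Server, error)
}

// RoundTrip walks the Deployment -> Server -> Connection stack once:
// select a server, acquire a connection, then exchange one wire message.
func RoundTrip(ctx context.Context, d Deployment, wm []byte) ([]byte, error) {
	srv, err := d.SelectServer(ctx)
	if err != nil {
		return nil, err
	}
	conn, err := srv.Connection(ctx)
	if err != nil {
		return nil, err
	}
	// The connection compresses the request with the compressor negotiated
	// during its handshake; the Operation type decompresses the reply.
	if err := conn.WriteWireMessage(ctx, wm); err != nil {
		return nil, err
	}
	return conn.ReadWireMessage(ctx, nil)
}
```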
### Compression
Compression is handled by the `Connection` type, while decompression is handled automatically by
the `Operation` type. This is done because the compressor used to compress a wire message is chosen
by the connection during the handshake, while decompression can be performed without that
information. This makes the compression design asymmetric, but it is simpler to implement and more
consistent.
## Operation
The `Operation` type handles executing a series of commands using a `Deployment`. For most uses
`Operation` will only execute a single command, but the main use case for a series of commands is
write commands that must be split into batches, such as insert. The type itself is heavily
documented, so reading the code and comments together should provide an understanding of how the
type works.
This type is not meant to be used directly by callers. Instead, a wrapping type should be defined
using the IDL and an implementation generated using `operationgen`.

View File

@@ -0,0 +1,49 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package address // import "go.mongodb.org/mongo-driver/x/mongo/driver/address"
import (
"net"
"strings"
)
const defaultPort = "27017"
// Address is a network address. It can either be an IP address or a DNS name.
type Address string
// Network is the network protocol for this address. In most cases this will be
// "tcp" or "unix".
func (a Address) Network() string {
if strings.HasSuffix(string(a), "sock") {
return "unix"
}
return "tcp"
}
// String is the canonical version of this address, e.g. localhost:27017,
// 1.2.3.4:27017, example.com:27017.
func (a Address) String() string {
// TODO: unicode case folding?
s := strings.ToLower(string(a))
if len(s) == 0 {
return ""
}
if a.Network() != "unix" {
_, _, err := net.SplitHostPort(s)
if err != nil && strings.Contains(err.Error(), "missing port in address") {
s += ":" + defaultPort
}
}
return s
}
// Canonicalize creates a canonicalized address.
func (a Address) Canonicalize() Address {
return Address(a.String())
}
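A short usage sketch of the `Address` type above: the default port is appended only when the
address is not a unix socket and does not already carry a port.

```go
package main

import (
	"fmt"

	"go.mongodb.org/mongo-driver/x/mongo/driver/address"
)

func main() {
	// A bare hostname has the default port appended by String.
	addr := address.Address("example.com")
	fmt.Println(addr.Network()) // tcp
	fmt.Println(addr.String())  // example.com:27017

	// A path ending in "sock" is treated as a unix address and gets no port.
	sock := address.Address("/tmp/mongodb-27017.sock")
	fmt.Println(sock.Network()) // unix
	fmt.Println(sock.String())  // /tmp/mongodb-27017.sock
}
```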

View File

@@ -0,0 +1,136 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
var authFactories = make(map[string]AuthenticatorFactory)
func init() {
RegisterAuthenticatorFactory("", newDefaultAuthenticator)
RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
}
// CreateAuthenticator creates an authenticator.
func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
if f, ok := authFactories[name]; ok {
return f(cred)
}
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
}
// RegisterAuthenticatorFactory registers the authenticator factory.
func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
authFactories[name] = factory
}
// HandshakeOptions packages options that can be passed to the Handshaker()
// function. DBUser is optional but must be of the form <dbname.username>;
// if non-empty, then the connection will do SASL mechanism negotiation.
type HandshakeOptions struct {
AppName string
Authenticator Authenticator
Compressors []string
DBUser string
PerformAuthentication func(description.Server) bool
}
// Handshaker creates a connection handshaker for the given authenticator.
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
return driver.HandshakerFunc(func(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) {
desc, err := operation.NewIsMaster().
AppName(options.AppName).
Compressors(options.Compressors).
SASLSupportedMechs(options.DBUser).
Handshake(ctx, addr, conn)
if err != nil {
return description.Server{}, newAuthError("handshake failure", err)
}
performAuth := options.PerformAuthentication
if performAuth == nil {
performAuth = func(serv description.Server) bool {
return serv.Kind == description.RSPrimary ||
serv.Kind == description.RSSecondary ||
serv.Kind == description.Mongos ||
serv.Kind == description.Standalone
}
}
if performAuth(desc) && options.Authenticator != nil {
err = options.Authenticator.Auth(ctx, desc, conn)
if err != nil {
return description.Server{}, newAuthError("auth error", err)
}
}
if h == nil {
return desc, nil
}
return h.Handshake(ctx, addr, conn)
})
}
// Authenticator handles authenticating a connection.
type Authenticator interface {
// Auth authenticates the connection.
Auth(context.Context, description.Server, driver.Connection) error
}
func newAuthError(msg string, inner error) error {
return &Error{
message: msg,
inner: inner,
}
}
func newError(err error, mech string) error {
return &Error{
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
inner: err,
}
}
// Error is an error that occurred during authentication.
type Error struct {
message string
inner error
}
func (e *Error) Error() string {
if e.inner == nil {
return e.message
}
return fmt.Sprintf("%s: %s", e.message, e.inner)
}
// Inner returns the wrapped error.
func (e *Error) Inner() error {
return e.inner
}
// Message returns the message.
func (e *Error) Message() string {
return e.message
}
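A minimal sketch of wiring these pieces together: build an authenticator from a credential via
`CreateAuthenticator` and wrap a handshaker with `Handshaker` and `HandshakeOptions`. The
credential values and app name are hypothetical placeholders; in real use they come from the client
configuration or connection string.

```go
package main

import (
	"log"

	"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
)

func main() {
	// Hypothetical example credentials.
	cred := &auth.Cred{
		Source:      "admin",
		Username:    "user",
		Password:    "pencil",
		PasswordSet: true,
	}

	authenticator, err := auth.CreateAuthenticator(auth.SCRAMSHA256, cred)
	if err != nil {
		log.Fatal(err)
	}

	// Wrap a nil inner handshaker so authentication runs right after the
	// isMaster handshake on eligible server kinds.
	handshaker := auth.Handshaker(nil, &auth.HandshakeOptions{
		AppName:       "example-app",
		Authenticator: authenticator,
	})
	_ = handshaker // in real code this is passed into connection options
}
```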

View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
// Cred is a user's credential.
type Cred struct {
Source string
Username string
Password string
PasswordSet bool
Props map[string]string
}

View File

@@ -0,0 +1,67 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
func newDefaultAuthenticator(cred *Cred) (Authenticator, error) {
return &DefaultAuthenticator{
Cred: cred,
}, nil
}
// DefaultAuthenticator uses SCRAM-SHA-1 or MONGODB-CR depending
// on the server version.
type DefaultAuthenticator struct {
Cred *Cred
}
// Auth authenticates the connection.
func (a *DefaultAuthenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error {
var actual Authenticator
var err error
switch chooseAuthMechanism(desc) {
case SCRAMSHA256:
actual, err = newScramSHA256Authenticator(a.Cred)
case SCRAMSHA1:
actual, err = newScramSHA1Authenticator(a.Cred)
default:
actual, err = newMongoDBCRAuthenticator(a.Cred)
}
if err != nil {
return newAuthError("error creating authenticator", err)
}
return actual.Auth(ctx, desc, conn)
}
// If the server provides a list of supported SASL mechanisms, we choose SCRAM-SHA-256 if it is
// included and otherwise must use SCRAM-SHA-1. If no list is provided, we decide based on the
// server's wire version.
func chooseAuthMechanism(desc description.Server) string {
if desc.SaslSupportedMechs != nil {
for _, v := range desc.SaslSupportedMechs {
if v == SCRAMSHA256 {
return v
}
}
return SCRAMSHA1
}
if err := description.ScramSHA1Supported(desc.WireVersion); err == nil {
return SCRAMSHA1
}
return MONGODBCR
}

View File

@@ -0,0 +1,23 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package auth is not for public use.
//
// The APIs for packages in the 'private' directory have no stability
// guarantee.
//
// The packages within the 'private' directory would normally be put into an
// 'internal' directory to prohibit their use outside the 'mongo' directory.
// However, some MongoDB tools require very low-level access to the building
// blocks of a driver, so we have placed them under 'private' to allow these
// packages to be imported by projects that need them.
//
// These package APIs may be modified in backwards-incompatible ways at any
// time.
//
// You are strongly discouraged from directly using any packages
// under 'private'.
package auth

View File

@@ -0,0 +1,60 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build windows linux darwin
package auth
import (
"context"
"fmt"
"net"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}
return &GSSAPIAuthenticator{
Username: cred.Username,
Password: cred.Password,
PasswordSet: cred.PasswordSet,
Props: cred.Props,
}, nil
}
// GSSAPIAuthenticator uses the GSSAPI algorithm over SASL to authenticate a connection.
type GSSAPIAuthenticator struct {
Username string
Password string
PasswordSet bool
Props map[string]string
}
// Auth authenticates the connection.
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error {
target := desc.Addr.String()
hostname, _, err := net.SplitHostPort(target)
if err != nil {
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
}
client, err := gssapi.New(hostname, a.Username, a.Password, a.PasswordSet, a.Props)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, conn, "$external", client)
}

View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build !gssapi
package auth
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
}

View File

@@ -0,0 +1,21 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,!windows,!linux,!darwin
package auth
import (
"fmt"
"runtime"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
}

View File

@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
package gssapi
/*
#cgo linux CFLAGS: -DGOOS_linux
#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
#cgo darwin CFLAGS: -DGOOS_darwin
#cgo darwin LDFLAGS: -framework GSS
#include "gss_wrapper.h"
*/
import "C"
import (
"fmt"
"runtime"
"strings"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
serviceName := "mongodb"
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_REALM":
return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
target = value
default:
return nil, fmt.Errorf("unknown mechanism property %s", key)
}
}
servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, target)
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.gssapi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.gssapi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
if status != C.GSSAPI_OK {
return mechName, nil, sc.getError("unable to initialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var buf unsafe.Pointer
var bufLen C.size_t
var outBuf unsafe.Pointer
var outBufLen C.size_t
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.gssapi_client_username(&sc.state, &cusername)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf = unsafe.Pointer(&bytes[0])
bufLen = C.size_t(len(bytes))
status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
if len(challenge) > 0 {
buf = unsafe.Pointer(&challenge[0])
bufLen = C.size_t(len(challenge))
}
status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.GSSAPI_OK:
sc.contextComplete = true
case C.GSSAPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != nil {
defer C.free(outBuf)
}
return C.GoBytes(outBuf, C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
var desc *C.char
status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
if status != C.GSSAPI_OK {
if desc != nil {
C.free(unsafe.Pointer(desc))
}
return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
}
defer C.free(unsafe.Pointer(desc))
return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
}

View File

@@ -0,0 +1,248 @@
//+build gssapi
//+build linux darwin
#include <string.h>
#include <stdio.h>
#include "gss_wrapper.h"
OM_uint32 gssapi_canonicalize_name(
OM_uint32* minor_status,
char *input_name,
gss_OID input_name_type,
gss_name_t *output_name
)
{
OM_uint32 major_status;
gss_name_t imported_name = GSS_C_NO_NAME;
gss_buffer_desc buffer = GSS_C_EMPTY_BUFFER;
buffer.value = input_name;
buffer.length = strlen(input_name);
major_status = gss_import_name(minor_status, &buffer, input_name_type, &imported_name);
if (GSS_ERROR(major_status)) {
return major_status;
}
major_status = gss_canonicalize_name(minor_status, imported_name, (gss_OID)gss_mech_krb5, output_name);
if (imported_name != GSS_C_NO_NAME) {
OM_uint32 ignored;
gss_release_name(&ignored, &imported_name);
}
return major_status;
}
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
)
{
OM_uint32 stat = maj_stat;
int stat_type = GSS_C_GSS_CODE;
if (min_stat != 0) {
stat = min_stat;
stat_type = GSS_C_MECH_CODE;
}
OM_uint32 local_maj_stat, local_min_stat;
OM_uint32 msg_ctx = 0;
gss_buffer_desc desc_buffer;
do
{
local_maj_stat = gss_display_status(
&local_min_stat,
stat,
stat_type,
GSS_C_NO_OID,
&msg_ctx,
&desc_buffer
);
if (GSS_ERROR(local_maj_stat)) {
return GSSAPI_ERROR;
}
if (*desc) {
free(*desc);
}
*desc = malloc(desc_buffer.length+1);
memcpy(*desc, desc_buffer.value, desc_buffer.length+1);
gss_release_buffer(&local_min_stat, &desc_buffer);
}
while(msg_ctx != 0);
return GSSAPI_OK;
}
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
)
{
client->cred = GSS_C_NO_CREDENTIAL;
client->ctx = GSS_C_NO_CONTEXT;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, spn, GSS_C_NT_HOSTBASED_SERVICE, &client->spn);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (username) {
gss_name_t name;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, username, GSS_C_NT_USER_NAME, &name);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (password) {
gss_buffer_desc password_buffer;
password_buffer.value = password;
password_buffer.length = strlen(password);
client->maj_stat = gss_acquire_cred_with_password(&client->min_stat, name, &password_buffer, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
} else {
client->maj_stat = gss_acquire_cred(&client->min_stat, name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
OM_uint32 ignored;
gss_release_name(&ignored, &name);
}
return GSSAPI_OK;
}
int gssapi_client_username(
gssapi_client_state *client,
char** username
)
{
OM_uint32 ignored;
gss_name_t name = GSS_C_NO_NAME;
client->maj_stat = gss_inquire_context(&client->min_stat, client->ctx, &name, NULL, NULL, NULL, NULL, NULL, NULL);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
gss_buffer_desc name_buffer;
client->maj_stat = gss_display_name(&client->min_stat, name, &name_buffer, NULL);
if (GSS_ERROR(client->maj_stat)) {
gss_release_name(&ignored, &name);
return GSSAPI_ERROR;
}
*username = malloc(name_buffer.length+1);
memcpy(*username, name_buffer.value, name_buffer.length+1);
gss_release_buffer(&ignored, &name_buffer);
gss_release_name(&ignored, &name);
return GSSAPI_OK;
}
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
if (input) {
input_buffer.value = input;
input_buffer.length = input_length;
}
client->maj_stat = gss_init_sec_context(
&client->min_stat,
client->cred,
&client->ctx,
client->spn,
GSS_C_NO_OID,
GSS_C_MUTUAL_FLAG | GSS_C_SEQUENCE_FLAG,
0,
GSS_C_NO_CHANNEL_BINDINGS,
&input_buffer,
NULL,
&output_buffer,
NULL,
NULL
);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
OM_uint32 ignored;
gss_release_buffer(&ignored, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
} else if (client->maj_stat == GSS_S_CONTINUE_NEEDED) {
return GSSAPI_CONTINUE;
}
return GSSAPI_OK;
}
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
input_buffer.value = input;
input_buffer.length = input_length;
client->maj_stat = gss_wrap(&client->min_stat, client->ctx, 0, GSS_C_QOP_DEFAULT, &input_buffer, NULL, &output_buffer);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
gss_release_buffer(&client->min_stat, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
return GSSAPI_OK;
}
int gssapi_client_destroy(
gssapi_client_state *client
)
{
OM_uint32 ignored;
if (client->ctx != GSS_C_NO_CONTEXT) {
gss_delete_sec_context(&ignored, &client->ctx, GSS_C_NO_BUFFER);
}
if (client->spn != GSS_C_NO_NAME) {
gss_release_name(&ignored, &client->spn);
}
if (client->cred != GSS_C_NO_CREDENTIAL) {
gss_release_cred(&ignored, &client->cred);
}
return GSSAPI_OK;
}

View File

@@ -0,0 +1,66 @@
//+build gssapi
//+build linux darwin
#ifndef GSS_WRAPPER_H
#define GSS_WRAPPER_H
#include <stdlib.h>
#ifdef GOOS_linux
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_krb5.h>
#endif
#ifdef GOOS_darwin
#include <GSS/GSS.h>
#endif
#define GSSAPI_OK 0
#define GSSAPI_CONTINUE 1
#define GSSAPI_ERROR 2
typedef struct {
gss_name_t spn;
gss_cred_id_t cred;
gss_ctx_id_t ctx;
OM_uint32 maj_stat;
OM_uint32 min_stat;
} gssapi_client_state;
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
);
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
);
int gssapi_client_username(
gssapi_client_state *client,
char** username
);
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_destroy(
gssapi_client_state *client
);
#endif

View File

@@ -0,0 +1,352 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
package gssapi
// #include "sspi_wrapper.h"
import "C"
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
initOnce.Do(initSSPI)
if initError != nil {
return nil, initError
}
var err error
serviceName := "mongodb"
serviceRealm := ""
canonicalizeHostName := false
var serviceHostSet bool
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
canonicalizeHostName, err = strconv.ParseBool(value)
if err != nil {
return nil, fmt.Errorf("%s must be a boolean (true, false, 0, 1) but got '%s'", key, value)
}
case "SERVICE_REALM":
serviceRealm = value
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
serviceHostSet = true
target = value
}
}
if canonicalizeHostName {
// Should not canonicalize the SERVICE_HOST
if serviceHostSet {
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME and SERVICE_HOST canonot both be specified")
}
names, err := net.LookupAddr(target)
if err != nil || len(names) == 0 {
return nil, fmt.Errorf("unable to canonicalize hostname: %s", err)
}
target = names[0]
if target[len(target)-1] == '.' {
target = target[:len(target)-1]
}
}
servicePrincipalName := fmt.Sprintf("%s/%s", serviceName, target)
if serviceRealm != "" {
servicePrincipalName += "@" + serviceRealm
}
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.sspi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.sspi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.sspi_client_init(&sc.state, cusername, cpassword)
if status != C.SSPI_OK {
return mechName, nil, sc.getError("unable to intitialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var outBuf C.PVOID
var outBufLen C.ULONG
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.sspi_client_username(&sc.state, &cusername)
if status != C.SSPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf := (C.PVOID)(unsafe.Pointer(&bytes[0]))
bufLen := C.ULONG(len(bytes))
status := C.sspi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.SSPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
var buf C.PVOID
var bufLen C.ULONG
if len(challenge) > 0 {
buf = (C.PVOID)(unsafe.Pointer(&challenge[0]))
bufLen = C.ULONG(len(challenge))
}
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
status := C.sspi_client_negotiate(&sc.state, cservicePrincipalName, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.SSPI_OK:
sc.contextComplete = true
case C.SSPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != C.PVOID(nil) {
defer C.free(unsafe.Pointer(outBuf))
}
return C.GoBytes(unsafe.Pointer(outBuf), C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
return getError(prefix, sc.state.status)
}
var initOnce sync.Once
var initError error
func initSSPI() {
rc := C.sspi_init()
if rc != 0 {
initError = fmt.Errorf("error initializing sspi: %v", rc)
}
}
func getError(prefix string, status C.SECURITY_STATUS) error {
var s string
switch status {
case C.SEC_E_ALGORITHM_MISMATCH:
s = "The client and server cannot communicate because they do not possess a common algorithm."
case C.SEC_E_BAD_BINDINGS:
s = "The SSPI channel bindings supplied by the client are incorrect."
case C.SEC_E_BAD_PKGID:
s = "The requested package identifier does not exist."
case C.SEC_E_BUFFER_TOO_SMALL:
s = "The buffers supplied to the function are not large enough to contain the information."
case C.SEC_E_CANNOT_INSTALL:
s = "The security package cannot initialize successfully and should not be installed."
case C.SEC_E_CANNOT_PACK:
s = "The package is unable to pack the context."
case C.SEC_E_CERT_EXPIRED:
s = "The received certificate has expired."
case C.SEC_E_CERT_UNKNOWN:
s = "An unknown error occurred while processing the certificate."
case C.SEC_E_CERT_WRONG_USAGE:
s = "The certificate is not valid for the requested usage."
case C.SEC_E_CONTEXT_EXPIRED:
s = "The application is referencing a context that has already been closed. A properly written application should not receive this error."
case C.SEC_E_CROSSREALM_DELEGATION_FAILURE:
s = "The server attempted to make a Kerberos-constrained delegation request for a target outside the server's realm."
case C.SEC_E_CRYPTO_SYSTEM_INVALID:
s = "The cryptographic system or checksum function is not valid because a required function is unavailable."
case C.SEC_E_DECRYPT_FAILURE:
s = "The specified data could not be decrypted."
case C.SEC_E_DELEGATION_REQUIRED:
s = "The requested operation cannot be completed. The computer must be trusted for delegation"
case C.SEC_E_DOWNGRADE_DETECTED:
s = "The system detected a possible attempt to compromise security. Verify that the server that authenticated you can be contacted."
case C.SEC_E_ENCRYPT_FAILURE:
s = "The specified data could not be encrypted."
case C.SEC_E_ILLEGAL_MESSAGE:
s = "The message received was unexpected or badly formatted."
case C.SEC_E_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. The context could not be initialized."
case C.SEC_E_INCOMPLETE_MESSAGE:
s = "The message supplied was incomplete. The signature was not verified."
case C.SEC_E_INSUFFICIENT_MEMORY:
s = "Not enough memory is available to complete the request."
case C.SEC_E_INTERNAL_ERROR:
s = "An error occurred that did not map to an SSPI error code."
case C.SEC_E_INVALID_HANDLE:
s = "The handle passed to the function is not valid."
case C.SEC_E_INVALID_TOKEN:
s = "The token passed to the function is not valid."
case C.SEC_E_ISSUING_CA_UNTRUSTED:
s = "An untrusted certification authority (CA) was detected while processing the smart card certificate used for authentication."
case C.SEC_E_ISSUING_CA_UNTRUSTED_KDC:
s = "An untrusted CA was detected while processing the domain controller certificate used for authentication. The system event log contains additional information."
case C.SEC_E_KDC_CERT_EXPIRED:
s = "The domain controller certificate used for smart card logon has expired."
case C.SEC_E_KDC_CERT_REVOKED:
s = "The domain controller certificate used for smart card logon has been revoked."
case C.SEC_E_KDC_INVALID_REQUEST:
s = "A request that is not valid was sent to the KDC."
case C.SEC_E_KDC_UNABLE_TO_REFER:
s = "The KDC was unable to generate a referral for the service requested."
case C.SEC_E_KDC_UNKNOWN_ETYPE:
s = "The requested encryption type is not supported by the KDC."
case C.SEC_E_LOGON_DENIED:
s = "The logon has been denied"
case C.SEC_E_MAX_REFERRALS_EXCEEDED:
s = "The number of maximum ticket referrals has been exceeded."
case C.SEC_E_MESSAGE_ALTERED:
s = "The message supplied for verification has been altered."
case C.SEC_E_MULTIPLE_ACCOUNTS:
s = "The received certificate was mapped to multiple accounts."
case C.SEC_E_MUST_BE_KDC:
s = "The local computer must be a Kerberos domain controller (KDC)"
case C.SEC_E_NO_AUTHENTICATING_AUTHORITY:
s = "No authority could be contacted for authentication."
case C.SEC_E_NO_CREDENTIALS:
s = "No credentials are available."
case C.SEC_E_NO_IMPERSONATION:
s = "No impersonation is allowed for this context."
case C.SEC_E_NO_IP_ADDRESSES:
s = "Unable to accomplish the requested task because the local computer does not have any IP addresses."
case C.SEC_E_NO_KERB_KEY:
s = "No Kerberos key was found."
case C.SEC_E_NO_PA_DATA:
s = "Policy administrator (PA) data is needed to determine the encryption type"
case C.SEC_E_NO_S4U_PROT_SUPPORT:
s = "The Kerberos subsystem encountered an error. A service for user protocol request was made against a domain controller which does not support service for a user."
case C.SEC_E_NO_TGT_REPLY:
s = "The client is trying to negotiate a context and the server requires a user-to-user connection"
case C.SEC_E_NOT_OWNER:
s = "The caller of the function does not own the credentials."
case C.SEC_E_OK:
s = "The operation completed successfully."
case C.SEC_E_OUT_OF_SEQUENCE:
s = "The message supplied for verification is out of sequence."
case C.SEC_E_PKINIT_CLIENT_FAILURE:
s = "The smart card certificate used for authentication is not trusted."
case C.SEC_E_PKINIT_NAME_MISMATCH:
s = "The client certificate does not contain a valid UPN or does not match the client name in the logon request."
case C.SEC_E_QOP_NOT_SUPPORTED:
s = "The quality of protection attribute is not supported by this package."
case C.SEC_E_REVOCATION_OFFLINE_C:
s = "The revocation status of the smart card certificate used for authentication could not be determined."
case C.SEC_E_REVOCATION_OFFLINE_KDC:
s = "The revocation status of the domain controller certificate used for smart card authentication could not be determined. The system event log contains additional information."
case C.SEC_E_SECPKG_NOT_FOUND:
s = "The security package was not recognized."
case C.SEC_E_SECURITY_QOS_FAILED:
s = "The security context could not be established due to a failure in the requested quality of service (for example"
case C.SEC_E_SHUTDOWN_IN_PROGRESS:
s = "A system shutdown is in progress."
case C.SEC_E_SMARTCARD_CERT_EXPIRED:
s = "The smart card certificate used for authentication has expired."
case C.SEC_E_SMARTCARD_CERT_REVOKED:
s = "The smart card certificate used for authentication has been revoked. Additional information may exist in the event log."
case C.SEC_E_SMARTCARD_LOGON_REQUIRED:
s = "Smart card logon is required and was not used."
case C.SEC_E_STRONG_CRYPTO_NOT_SUPPORTED:
s = "The other end of the security negotiation requires strong cryptography"
case C.SEC_E_TARGET_UNKNOWN:
s = "The target was not recognized."
case C.SEC_E_TIME_SKEW:
s = "The clocks on the client and server computers do not match."
case C.SEC_E_TOO_MANY_PRINCIPALS:
s = "The KDC reply contained more than one principal name."
case C.SEC_E_UNFINISHED_CONTEXT_DELETED:
s = "A security context was deleted before the context was completed. This is considered a logon failure."
case C.SEC_E_UNKNOWN_CREDENTIALS:
s = "The credentials provided were not recognized."
case C.SEC_E_UNSUPPORTED_FUNCTION:
s = "The requested function is not supported."
case C.SEC_E_UNSUPPORTED_PREAUTH:
s = "An unsupported preauthentication mechanism was presented to the Kerberos package."
case C.SEC_E_UNTRUSTED_ROOT:
s = "The certificate chain was issued by an authority that is not trusted."
case C.SEC_E_WRONG_CREDENTIAL_HANDLE:
s = "The supplied credential handle does not match the credential associated with the security context."
case C.SEC_E_WRONG_PRINCIPAL:
s = "The target principal name is incorrect."
case C.SEC_I_COMPLETE_AND_CONTINUE:
s = "The function completed successfully"
case C.SEC_I_COMPLETE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_CONTEXT_EXPIRED:
s = "The message sender has finished using the connection and has initiated a shutdown. For information about initiating or recognizing a shutdown"
case C.SEC_I_CONTINUE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. Additional information can be returned from the context."
case C.SEC_I_LOCAL_LOGON:
s = "The logon was completed"
case C.SEC_I_NO_LSA_CONTEXT:
s = "There is no LSA mode context associated with this context."
case C.SEC_I_RENEGOTIATE:
s = "The context data must be renegotiated with the peer."
default:
return fmt.Errorf("%s: 0x%x", prefix, uint32(status))
}
return fmt.Errorf("%s: %s(0x%x)", prefix, s, uint32(status))
}

View File

@@ -0,0 +1,218 @@
//+build gssapi,windows
#include "sspi_wrapper.h"
static HINSTANCE sspi_secur32_dll = NULL;
static PSecurityFunctionTable sspi_functions = NULL;
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
int sspi_init(
)
{
sspi_secur32_dll = LoadLibrary("secur32.dll");
if (!sspi_secur32_dll) {
return GetLastError();
}
INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
if (!init_security_interface) {
return -1;
}
sspi_functions = (*init_security_interface)();
if (!sspi_functions) {
return -2;
}
return SSPI_OK;
}
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
)
{
TimeStamp timestamp;
if (username) {
if (password) {
SEC_WINNT_AUTH_IDENTITY auth_identity;
#ifdef _UNICODE
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
#else
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
#endif
auth_identity.User = (LPSTR) username;
auth_identity.UserLength = strlen(username);
auth_identity.Password = (LPSTR) password;
auth_identity.PasswordLength = strlen(password);
auth_identity.Domain = NULL;
auth_identity.DomainLength = 0;
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, &timestamp);
} else {
client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
} else {
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
return SSPI_OK;
}
int sspi_client_username(
sspi_client_state *client,
char** username
)
{
SecPkgCredentials_Names names;
client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
int len = strlen(names.sUserName) + 1;
*username = malloc(len);
memcpy(*username, names.sUserName, len);
sspi_functions->FreeContextBuffer(names.sUserName);
return SSPI_OK;
}
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecBufferDesc inbuf;
SecBuffer in_bufs[1];
SecBufferDesc outbuf;
SecBuffer out_bufs[1];
if (client->has_ctx > 0) {
inbuf.ulVersion = SECBUFFER_VERSION;
inbuf.cBuffers = 1;
inbuf.pBuffers = in_bufs;
in_bufs[0].pvBuffer = input;
in_bufs[0].cbBuffer = input_length;
in_bufs[0].BufferType = SECBUFFER_TOKEN;
}
outbuf.ulVersion = SECBUFFER_VERSION;
outbuf.cBuffers = 1;
outbuf.pBuffers = out_bufs;
out_bufs[0].pvBuffer = NULL;
out_bufs[0].cbBuffer = 0;
out_bufs[0].BufferType = SECBUFFER_TOKEN;
ULONG context_attr = 0;
client->status = sspi_functions->InitializeSecurityContext(
&client->cred,
client->has_ctx > 0 ? &client->ctx : NULL,
(LPSTR) spn,
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
0,
SECURITY_NETWORK_DREP,
client->has_ctx > 0 ? &inbuf : NULL,
0,
&client->ctx,
&outbuf,
&context_attr,
NULL);
if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
return SSPI_ERROR;
}
client->has_ctx = 1;
*output = malloc(out_bufs[0].cbBuffer);
*output_length = out_bufs[0].cbBuffer;
memcpy(*output, out_bufs[0].pvBuffer, *output_length);
sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
if (client->status == SEC_I_CONTINUE_NEEDED) {
return SSPI_CONTINUE;
}
return SSPI_OK;
}
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecPkgContext_Sizes sizes;
client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
SecBuffer wrap_bufs[3];
SecBufferDesc wrap_buf_desc;
wrap_buf_desc.cBuffers = 3;
wrap_buf_desc.pBuffers = wrap_bufs;
wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
wrap_bufs[0].pvBuffer = msg;
wrap_bufs[1].cbBuffer = input_length;
wrap_bufs[1].BufferType = SECBUFFER_DATA;
wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
wrap_bufs[2].BufferType = SECBUFFER_PADDING;
wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
if (client->status != SEC_E_OK) {
free(msg);
return SSPI_ERROR;
}
*output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
*output = malloc(*output_length);
memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
free(msg);
return SSPI_OK;
}
int sspi_client_destroy(
sspi_client_state *client
)
{
if (client->has_ctx > 0) {
sspi_functions->DeleteSecurityContext(&client->ctx);
}
sspi_functions->FreeCredentialsHandle(&client->cred);
return SSPI_OK;
}

View File

@@ -0,0 +1,58 @@
//+build gssapi,windows
#ifndef SSPI_WRAPPER_H
#define SSPI_WRAPPER_H
#define SECURITY_WIN32 1 /* Required for SSPI */
#include <windows.h>
#include <sspi.h>
#define SSPI_OK 0
#define SSPI_CONTINUE 1
#define SSPI_ERROR 2
typedef struct {
CredHandle cred;
CtxtHandle ctx;
int has_ctx;
SECURITY_STATUS status;
} sspi_client_state;
int sspi_init();
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
);
int sspi_client_username(
sspi_client_state *client,
char** username
);
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_destroy(
sspi_client_state *client
);
#endif

View File

@@ -0,0 +1,94 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"crypto/md5"
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MONGODBCR is the mechanism name for MONGODB-CR.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 4.0.
const MONGODBCR = "MONGODB-CR"
func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) {
return &MongoDBCRAuthenticator{
DB: cred.Source,
Username: cred.Username,
Password: cred.Password,
}, nil
}
// MongoDBCRAuthenticator uses the MONGODB-CR algorithm to authenticate a connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 4.0.
type MongoDBCRAuthenticator struct {
DB string
Username string
Password string
}
// Auth authenticates the connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 4.0.
func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error {
db := a.DB
if db == "" {
db = defaultAuthDB
}
doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
cmd := operation.NewCommand(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn})
err := cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
rdr := cmd.Result()
var getNonceResult struct {
Nonce string `bson:"nonce"`
}
err = bson.Unmarshal(rdr, &getNonceResult)
if err != nil {
return newAuthError("unmarshal error", err)
}
doc = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "user", a.Username),
bsoncore.AppendStringElement(nil, "nonce", getNonceResult.Nonce),
bsoncore.AppendStringElement(nil, "key", a.createKey(getNonceResult.Nonce)),
)
cmd = operation.NewCommand(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn})
err = cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
return nil
}
func (a *MongoDBCRAuthenticator) createKey(nonce string) string {
h := md5.New()
_, _ = io.WriteString(h, nonce)
_, _ = io.WriteString(h, a.Username)
_, _ = io.WriteString(h, mongoPasswordDigest(a.Username, a.Password))
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,56 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// PLAIN is the mechanism name for PLAIN.
const PLAIN = "PLAIN"
func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
}, nil
}
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
type PlainAuthenticator struct {
Username string
Password string
}
// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error {
return ConductSaslConversation(ctx, conn, "$external", &plainSaslClient{
username: a.Username,
password: a.Password,
})
}
type plainSaslClient struct {
username string
password string
}
func (c *plainSaslClient) Start() (string, []byte, error) {
b := []byte("\x00" + c.username + "\x00" + c.password)
return PLAIN, b, nil
}
func (c *plainSaslClient) Next(challenge []byte) ([]byte, error) {
return nil, newAuthError("unexpected server challenge", nil)
}
func (c *plainSaslClient) Completed() bool {
return true
}

View File

@@ -0,0 +1,112 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// SaslClient is the client piece of a sasl conversation.
type SaslClient interface {
Start() (string, []byte, error)
Next(challenge []byte) ([]byte, error)
Completed() bool
}
// SaslClientCloser is a SaslClient that has resources to clean up.
type SaslClientCloser interface {
SaslClient
Close()
}
// ConductSaslConversation handles running a sasl conversation with MongoDB.
func ConductSaslConversation(ctx context.Context, conn driver.Connection, db string, client SaslClient) error {
if db == "" {
db = defaultAuthDB
}
if closer, ok := client.(SaslClientCloser); ok {
defer closer.Close()
}
mech, payload, err := client.Start()
if err != nil {
return newError(err, mech)
}
doc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", mech),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
saslStartCmd := operation.NewCommand(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn})
type saslResponse struct {
ConversationID int `bson:"conversationId"`
Code int `bson:"code"`
Done bool `bson:"done"`
Payload []byte `bson:"payload"`
}
var saslResp saslResponse
err = saslStartCmd.Execute(ctx)
if err != nil {
return newError(err, mech)
}
rdr := saslStartCmd.Result()
err = bson.Unmarshal(rdr, &saslResp)
if err != nil {
return newAuthError("unmarshall error", err)
}
cid := saslResp.ConversationID
for {
if saslResp.Code != 0 {
return newError(err, mech)
}
if saslResp.Done && client.Completed() {
return nil
}
payload, err = client.Next(saslResp.Payload)
if err != nil {
return newError(err, mech)
}
if saslResp.Done && client.Completed() {
return nil
}
doc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslContinue", 1),
bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
saslContinueCmd := operation.NewCommand(doc).Database(db).Deployment(driver.SingleConnectionDeployment{conn})
err = saslContinueCmd.Execute(ctx)
if err != nil {
return newError(err, mech)
}
rdr = saslContinueCmd.Result()
err = bson.Unmarshal(rdr, &saslResp)
if err != nil {
return newAuthError("unmarshal error", err)
}
}
}
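To make the `SaslClient` contract above concrete, here is a hypothetical single-step client
(mirroring the shape of the PLAIN client shown earlier) and how it would be handed to
`ConductSaslConversation`. The mechanism name and payload are placeholders, not a real mechanism.

```go
// Package saslsketch is an illustrative sketch; it is not part of the driver.
package saslsketch

import (
	"context"

	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
)

// singleStepClient is a hypothetical mechanism whose whole payload fits in
// the saslStart command.
type singleStepClient struct {
	mech    string
	payload []byte
}

func (c *singleStepClient) Start() (string, []byte, error) {
	return c.mech, c.payload, nil
}

func (c *singleStepClient) Next(challenge []byte) ([]byte, error) {
	// A single-step mechanism never expects a server challenge.
	return nil, nil
}

func (c *singleStepClient) Completed() bool { return true }

// Authenticate runs the conversation on an already-established connection.
func Authenticate(ctx context.Context, conn driver.Connection) error {
	client := &singleStepClient{mech: "EXAMPLE-MECH", payload: []byte("token")}
	return auth.ConductSaslConversation(ctx, conn, "$external", client)
}
```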

View File

@@ -0,0 +1,102 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Copyright (C) MongoDB, Inc. 2018-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"github.com/xdg/scram"
"github.com/xdg/stringprep"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
const SCRAMSHA1 = "SCRAM-SHA-1"
// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
const SCRAMSHA256 = "SCRAM-SHA-256"
func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) {
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
client: client,
}, nil
}
func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) {
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError(fmt.Sprintf("error SASLprepping password '%s'", cred.Password), err)
}
client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
client: client,
}, nil
}
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
type ScramAuthenticator struct {
mechanism string
source string
client *scram.Client
}
// Auth authenticates the connection.
func (a *ScramAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error {
adapter := &scramSaslAdapter{conversation: a.client.NewConversation(), mechanism: a.mechanism}
err := ConductSaslConversation(ctx, conn, a.source, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
type scramSaslAdapter struct {
mechanism string
conversation *scram.ClientConversation
}
func (a *scramSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step("")
if err != nil {
return a.mechanism, nil, err
}
return a.mechanism, []byte(step), nil
}
func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(string(challenge))
if err != nil {
return nil, err
}
return []byte(step), nil
}
func (a *scramSaslAdapter) Completed() bool {
return a.conversation.Done()
}

View File

@@ -0,0 +1,23 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"crypto/md5"
"fmt"
"io"
)
const defaultAuthDB = "admin"
func mongoPasswordDigest(username, password string) string {
h := md5.New()
_, _ = io.WriteString(h, username)
_, _ = io.WriteString(h, ":mongo:")
_, _ = io.WriteString(h, password)
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,49 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MongoDBX509 is the mechanism name for MongoDBX509.
const MongoDBX509 = "MONGODB-X509"
func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) {
return &MongoDBX509Authenticator{User: cred.Username}, nil
}
// MongoDBX509Authenticator uses X.509 certificates over TLS to authenticate a connection.
type MongoDBX509Authenticator struct {
User string
}
// Auth implements the Authenticator interface.
func (a *MongoDBX509Authenticator) Auth(ctx context.Context, desc description.Server, conn driver.Connection) error {
requestDoc := bsoncore.AppendInt32Element(nil, "authenticate", 1)
requestDoc = bsoncore.AppendStringElement(requestDoc, "mechanism", MongoDBX509)
if desc.WireVersion == nil || desc.WireVersion.Max < 5 {
requestDoc = bsoncore.AppendStringElement(requestDoc, "user", a.User)
}
authCmd := operation.
NewCommand(bsoncore.BuildDocument(nil, requestDoc)).
Database("$external").
Deployment(driver.SingleConnectionDeployment{conn})
err := authCmd.Execute(ctx)
if err != nil {
return newAuthError("round trip error", err)
}
return nil
}

View File

@@ -0,0 +1,325 @@
package driver
import (
"context"
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
// of one at a time. An individual document cursor can be built on top of this batch cursor.
type BatchCursor struct {
clientSession *session.Client
clock *session.ClusterClock
database string
collection string
id int64
err error
server Server
batchSize int32
maxTimeMS int64
currentBatch *bsoncore.DocumentSequence
firstBatch bool
cmdMonitor *event.CommandMonitor
postBatchResumeToken bsoncore.Document
// legacy server (< 3.2) fields
legacy bool // This field is provided for ListCollectionsBatchCursor.
limit int32
numReturned int32 // number of docs returned by server
}
// CursorResponse represents the response from a command that results in a cursor. A BatchCursor can
// be constructed from a CursorResponse.
type CursorResponse struct {
Server Server
Desc description.Server
FirstBatch *bsoncore.DocumentSequence
Database string
Collection string
ID int64
postBatchResumeToken bsoncore.Document
}
// NewCursorResponse constructs a cursor response from the given response and server. This method
// can be used within the ProcessResponse method for an operation.
func NewCursorResponse(response bsoncore.Document, server Server, desc description.Server) (CursorResponse, error) {
cur, ok := response.Lookup("cursor").DocumentOK()
if !ok {
return CursorResponse{}, fmt.Errorf("cursor should be an embedded document but is of BSON type %s", response.Lookup("cursor").Type)
}
elems, err := cur.Elements()
if err != nil {
return CursorResponse{}, err
}
curresp := CursorResponse{Server: server, Desc: desc}
for _, elem := range elems {
switch elem.Key() {
case "firstBatch":
arr, ok := elem.Value().ArrayOK()
if !ok {
return CursorResponse{}, fmt.Errorf("firstBatch should be an array but is a BSON %s", elem.Value().Type)
}
curresp.FirstBatch = &bsoncore.DocumentSequence{Style: bsoncore.ArrayStyle, Data: arr}
case "ns":
ns, ok := elem.Value().StringValueOK()
if !ok {
return CursorResponse{}, fmt.Errorf("ns should be a string but is a BSON %s", elem.Value().Type)
}
index := strings.Index(ns, ".")
if index == -1 {
return CursorResponse{}, errors.New("ns field must contain a valid namespace, but is missing '.'")
}
curresp.Database = ns[:index]
curresp.Collection = ns[index+1:]
case "id":
curresp.ID, ok = elem.Value().Int64OK()
if !ok {
return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
}
case "postBatchResumeToken":
curresp.postBatchResumeToken, ok = elem.Value().DocumentOK()
if !ok {
return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type)
}
}
}
return curresp, nil
}
// CursorOptions are extra options that are required to construct a BatchCursor.
type CursorOptions struct {
BatchSize int32
MaxTimeMS int64
Limit int32
CommandMonitor *event.CommandMonitor
}
// NewBatchCursor creates a new BatchCursor from the provided parameters.
func NewBatchCursor(cr CursorResponse, clientSession *session.Client, clock *session.ClusterClock, opts CursorOptions) (*BatchCursor, error) {
ds := cr.FirstBatch
bc := &BatchCursor{
clientSession: clientSession,
clock: clock,
database: cr.Database,
collection: cr.Collection,
id: cr.ID,
server: cr.Server,
batchSize: opts.BatchSize,
maxTimeMS: opts.MaxTimeMS,
cmdMonitor: opts.CommandMonitor,
firstBatch: true,
postBatchResumeToken: cr.postBatchResumeToken,
}
if ds != nil {
bc.numReturned = int32(ds.DocumentCount())
}
if cr.Desc.WireVersion == nil || cr.Desc.WireVersion.Max < 4 {
bc.legacy = true
bc.limit = opts.Limit
// Take as many documents from the batch as needed.
if bc.limit != 0 && bc.limit < bc.numReturned {
for i := int32(0); i < bc.limit; i++ {
_, err := ds.Next()
if err != nil {
return nil, err
}
}
ds.Data = ds.Data[:ds.Pos]
ds.ResetIterator()
}
}
bc.currentBatch = ds
return bc, nil
}
// NewEmptyBatchCursor returns a batch cursor that is empty.
func NewEmptyBatchCursor() *BatchCursor {
return &BatchCursor{currentBatch: new(bsoncore.DocumentSequence)}
}
// ID returns the cursor ID for this batch cursor.
func (bc *BatchCursor) ID() int64 {
return bc.id
}
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// If Next returns true, there is a valid batch of documents available; if it returns false, there
// is not.
func (bc *BatchCursor) Next(ctx context.Context) bool {
if ctx == nil {
ctx = context.Background()
}
if bc.firstBatch {
bc.firstBatch = false
return true
}
if bc.id == 0 || bc.server == nil {
return false
}
bc.getMore(ctx)
switch bc.currentBatch.Style {
case bsoncore.SequenceStyle:
return len(bc.currentBatch.Data) > 0
case bsoncore.ArrayStyle:
return len(bc.currentBatch.Data) > 5
default:
return false
}
}
// Batch will return a DocumentSequence for the current batch of documents. The returned
// DocumentSequence is only valid until the next call to Next or Close.
func (bc *BatchCursor) Batch() *bsoncore.DocumentSequence { return bc.currentBatch }
// Err returns the latest error encountered.
func (bc *BatchCursor) Err() error { return bc.err }
// Close closes this batch cursor.
func (bc *BatchCursor) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
err := bc.KillCursor(ctx)
bc.id = 0
bc.currentBatch.Data = nil
bc.currentBatch.Style = 0
bc.currentBatch.ResetIterator()
return err
}
// Server returns the server for this cursor.
func (bc *BatchCursor) Server() Server {
return bc.server
}
func (bc *BatchCursor) clearBatch() {
bc.currentBatch.Data = bc.currentBatch.Data[:0]
}
// KillCursor kills the cursor on the server without closing the batch cursor.
func (bc *BatchCursor) KillCursor(ctx context.Context) error {
if bc.server == nil || bc.id == 0 {
return nil
}
return Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "killCursors", bc.collection)
dst = bsoncore.BuildArrayElement(dst, "cursors", bsoncore.Value{Type: bsontype.Int64, Data: bsoncore.AppendInt64(nil, bc.id)})
return dst, nil
},
Database: bc.database,
Deployment: SingleServerDeployment{Server: bc.server},
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyKillCursors,
CommandMonitor: bc.cmdMonitor,
}.Execute(ctx, nil)
}
func (bc *BatchCursor) getMore(ctx context.Context) {
bc.clearBatch()
if bc.id == 0 {
return
}
// Required for legacy operations which don't support limit.
numToReturn := bc.batchSize
if bc.limit != 0 && bc.numReturned+bc.batchSize > bc.limit {
numToReturn = bc.limit - bc.numReturned
if numToReturn <= 0 {
err := bc.Close(ctx)
if err != nil {
bc.err = err
}
return
}
}
bc.err = Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
if numToReturn > 0 {
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
}
if bc.maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", bc.maxTimeMS)
}
return dst, nil
},
Database: bc.database,
Deployment: SingleServerDeployment{Server: bc.server},
ProcessResponseFn: func(response bsoncore.Document, srvr Server, desc description.Server) error {
id, ok := response.Lookup("cursor", "id").Int64OK()
if !ok {
return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type)
}
bc.id = id
batch, ok := response.Lookup("cursor", "nextBatch").ArrayOK()
if !ok {
return fmt.Errorf("cursor.nextBatch should be an array but is a BSON %s", response.Lookup("cursor", "nextBatch").Type)
}
bc.currentBatch.Style = bsoncore.ArrayStyle
bc.currentBatch.Data = batch
bc.currentBatch.ResetIterator()
bc.numReturned += int32(bc.currentBatch.DocumentCount()) // Required for legacy operations which don't support limit.
pbrt, err := response.LookupErr("cursor", "postBatchResumeToken")
if err != nil {
// A missing postBatchResumeToken is not an error: the field is optional and only returned for change stream cursors on newer servers.
return nil
}
pbrtDoc, ok := pbrt.DocumentOK()
if !ok {
bc.err = fmt.Errorf("expected BSON type for post batch resume token to be EmbeddedDocument but got %s", pbrt.Type)
return nil
}
bc.postBatchResumeToken = bsoncore.Document(pbrtDoc)
return nil
},
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyGetMore,
CommandMonitor: bc.cmdMonitor,
}.Execute(ctx, nil)
// Required for legacy operations which don't support limit.
if bc.limit != 0 && bc.numReturned >= bc.limit {
// call KillCursor instead of Close because Close will clear out the data for the current batch.
err := bc.KillCursor(ctx)
if err != nil && bc.err == nil {
bc.err = err
}
}
return
}
// PostBatchResumeToken returns the latest seen post batch resume token.
func (bc *BatchCursor) PostBatchResumeToken() bsoncore.Document {
return bc.postBatchResumeToken
}
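
A minimal consumption sketch (not taken from the driver itself): given an already-constructed *BatchCursor, for example from NewBatchCursor, iterate batch by batch and then document by document. It assumes bsoncore.DocumentSequence.Next reports io.EOF once a batch is exhausted; the function name is illustrative.

package example

import (
    "context"
    "fmt"
    "io"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
)

// drainCursor prints every document in every batch, then closes the cursor.
func drainCursor(ctx context.Context, bc *driver.BatchCursor) error {
    defer bc.Close(ctx)
    for bc.Next(ctx) { // fetches the next batch (the first batch is already buffered)
        batch := bc.Batch() // valid only until the next call to Next or Close
        for {
            doc, err := batch.Next()
            if err == io.EOF {
                break // batch exhausted
            }
            if err != nil {
                return err
            }
            fmt.Println(doc)
        }
    }
    return bc.Err()
}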

69
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/batches.go generated vendored Executable file

@@ -0,0 +1,69 @@
package driver
import (
"errors"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// reservedCommandBufferBytes is the amount of buffer space (16,000 bytes) that the
// driver reserves in a message for command overhead.
const reservedCommandBufferBytes = 16 * 10 * 10 * 10
// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a
// server is passed to an insert command.
var ErrDocumentTooLarge = errors.New("an inserted document is too large")
// Batches contains the necessary information to batch split an operation. This is only used for write
// operations.
type Batches struct {
Identifier string
Documents []bsoncore.Document
Current []bsoncore.Document
Ordered *bool
}
// Valid returns true if the Batches is non-nil, has an identifier, and contains at least one
// document.
func (b *Batches) Valid() bool { return b != nil && b.Identifier != "" && len(b.Documents) > 0 }
// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the
// next batch.
func (b *Batches) ClearBatch() { b.Current = b.Current[:0] }
// AdvanceBatch splits the next batch using maxCount and targetBatchSize. This method will do nothing if
// the current batch has not been cleared. We do this so that when this is called during execute we
// can call it without first needing to check if we already have a batch, which makes the code
// simpler and makes retrying easier.
func (b *Batches) AdvanceBatch(maxCount, targetBatchSize int) error {
if len(b.Current) > 0 {
return nil
}
if targetBatchSize > reservedCommandBufferBytes {
targetBatchSize -= reservedCommandBufferBytes
}
if maxCount <= 0 {
maxCount = 1
}
splitAfter := 0
size := 1
for i, doc := range b.Documents {
if i == maxCount {
break
}
if len(doc) > targetBatchSize {
return ErrDocumentTooLarge
}
if size+len(doc) > targetBatchSize {
break
}
size += len(doc)
splitAfter++
}
b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:]
return nil
}
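
A sketch (not from the driver) of the ClearBatch/AdvanceBatch loop an operation runs: AdvanceBatch only produces a new Current batch after the previous one has been cleared. The maxCount and targetBatchSize values below are placeholders for the limits a server advertises (maxWriteBatchSize, maxBsonObjectSize).

package example

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
)

// splitIntoBatches walks docs in insert-style batches of at most 3 documents
// or roughly 16MB, whichever limit is hit first.
func splitIntoBatches(docs []bsoncore.Document) error {
    b := &driver.Batches{Identifier: "documents", Documents: docs}
    for b.Valid() {
        if err := b.AdvanceBatch(3, 16*1024*1024); err != nil {
            return err // e.g. driver.ErrDocumentTooLarge
        }
        fmt.Printf("would send a batch of %d documents\n", len(b.Current))
        b.ClearBatch() // required before AdvanceBatch will move on
    }
    return nil
}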


@@ -0,0 +1,699 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package connstring // import "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
// Parse parses the provided URI and returns a ConnString object.
func Parse(s string) (ConnString, error) {
p := parser{dnsResolver: dns.DefaultResolver}
err := p.parse(s)
if err != nil {
err = internal.WrapErrorf(err, "error parsing uri")
}
return p.ConnString, err
}
// ConnString represents a connection string to mongodb.
type ConnString struct {
Original string
AppName string
AuthMechanism string
AuthMechanismProperties map[string]string
AuthSource string
Compressors []string
Connect ConnectMode
ConnectSet bool
ConnectTimeout time.Duration
ConnectTimeoutSet bool
Database string
HeartbeatInterval time.Duration
HeartbeatIntervalSet bool
Hosts []string
J bool
JSet bool
LocalThreshold time.Duration
LocalThresholdSet bool
MaxConnIdleTime time.Duration
MaxConnIdleTimeSet bool
MaxPoolSize uint16
MaxPoolSizeSet bool
Password string
PasswordSet bool
ReadConcernLevel string
ReadPreference string
ReadPreferenceTagSets []map[string]string
RetryWrites bool
RetryWritesSet bool
MaxStaleness time.Duration
MaxStalenessSet bool
ReplicaSet string
Scheme string
ServerSelectionTimeout time.Duration
ServerSelectionTimeoutSet bool
SocketTimeout time.Duration
SocketTimeoutSet bool
SSL bool
SSLSet bool
SSLClientCertificateKeyFile string
SSLClientCertificateKeyFileSet bool
SSLClientCertificateKeyPassword func() string
SSLClientCertificateKeyPasswordSet bool
SSLInsecure bool
SSLInsecureSet bool
SSLCaFile string
SSLCaFileSet bool
WString string
WNumber int
WNumberSet bool
Username string
ZlibLevel int
ZlibLevelSet bool
WTimeout time.Duration
WTimeoutSet bool
WTimeoutSetFromOption bool
Options map[string][]string
UnknownOptions map[string][]string
}
func (u *ConnString) String() string {
return u.Original
}
// ConnectMode informs the driver on how to connect
// to the server.
type ConnectMode uint8
// ConnectMode constants.
const (
AutoConnect ConnectMode = iota
SingleConnect
)
// Scheme constants
const (
SchemeMongoDB = "mongodb"
SchemeMongoDBSRV = "mongodb+srv"
)
type parser struct {
ConnString
dnsResolver *dns.Resolver
}
func (p *parser) parse(original string) error {
p.Original = original
uri := original
var err error
if strings.HasPrefix(uri, SchemeMongoDBSRV+"://") {
p.Scheme = SchemeMongoDBSRV
// remove the scheme
uri = uri[len(SchemeMongoDBSRV)+3:]
} else if strings.HasPrefix(uri, SchemeMongoDB+"://") {
p.Scheme = SchemeMongoDB
// remove the scheme
uri = uri[len(SchemeMongoDB)+3:]
} else {
return fmt.Errorf("scheme must be \"mongodb\" or \"mongodb+srv\"")
}
if idx := strings.Index(uri, "@"); idx != -1 {
userInfo := uri[:idx]
uri = uri[idx+1:]
username := userInfo
var password string
if idx := strings.Index(userInfo, ":"); idx != -1 {
username = userInfo[:idx]
password = userInfo[idx+1:]
p.PasswordSet = true
}
if len(username) > 1 {
if strings.Contains(username, "/") {
return fmt.Errorf("unescaped slash in username")
}
}
p.Username, err = url.QueryUnescape(username)
if err != nil {
return internal.WrapErrorf(err, "invalid username")
}
if len(password) > 1 {
if strings.Contains(password, ":") {
return fmt.Errorf("unescaped colon in password")
}
if strings.Contains(password, "/") {
return fmt.Errorf("unescaped slash in password")
}
p.Password, err = url.QueryUnescape(password)
if err != nil {
return internal.WrapErrorf(err, "invalid password")
}
}
}
// fetch the hosts field
hosts := uri
if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
if uri[idx] == '@' {
return fmt.Errorf("unescaped @ sign in user info")
}
if uri[idx] == '?' {
return fmt.Errorf("must have a / before the query ?")
}
hosts = uri[:idx]
}
var connectionArgsFromTXT []string
parsedHosts := strings.Split(hosts, ",")
if p.Scheme == SchemeMongoDBSRV {
parsedHosts, err = p.dnsResolver.ParseHosts(hosts, true)
if err != nil {
return err
}
connectionArgsFromTXT, err = p.dnsResolver.GetConnectionArgsFromTXT(hosts)
if err != nil {
return err
}
// SSL is enabled by default for SRV, but can be manually disabled with "ssl=false".
p.SSL = true
p.SSLSet = true
}
for _, host := range parsedHosts {
err = p.addHost(host)
if err != nil {
return internal.WrapErrorf(err, "invalid host \"%s\"", host)
}
}
if len(p.Hosts) == 0 {
return fmt.Errorf("must have at least 1 host")
}
uri = uri[len(hosts):]
extractedDatabase, err := extractDatabaseFromURI(uri)
if err != nil {
return err
}
uri = extractedDatabase.uri
p.Database = extractedDatabase.db
connectionArgsFromQueryString, err := extractQueryArgsFromURI(uri)
if err != nil {
return err
}
connectionArgPairs := append(connectionArgsFromTXT, connectionArgsFromQueryString...)
for _, pair := range connectionArgPairs {
err = p.addOption(pair)
if err != nil {
return err
}
}
err = p.setDefaultAuthParams(extractedDatabase.db)
if err != nil {
return err
}
err = p.validateAuth()
if err != nil {
return err
}
// Check for invalid write concern (i.e. w=0 and j=true)
if p.WNumberSet && p.WNumber == 0 && p.JSet && p.J {
return writeconcern.ErrInconsistent
}
// If WTimeout was set from manual options passed in, set WTimeoutSet to true.
if p.WTimeoutSetFromOption {
p.WTimeoutSet = true
}
return nil
}
func (p *parser) setDefaultAuthParams(dbName string) error {
switch strings.ToLower(p.AuthMechanism) {
case "plain":
if p.AuthSource == "" {
p.AuthSource = dbName
if p.AuthSource == "" {
p.AuthSource = "$external"
}
}
case "gssapi":
if p.AuthMechanismProperties == nil {
p.AuthMechanismProperties = map[string]string{
"SERVICE_NAME": "mongodb",
}
} else if v, ok := p.AuthMechanismProperties["SERVICE_NAME"]; !ok || v == "" {
p.AuthMechanismProperties["SERVICE_NAME"] = "mongodb"
}
fallthrough
case "mongodb-x509":
if p.AuthSource == "" {
p.AuthSource = "$external"
} else if p.AuthSource != "$external" {
return fmt.Errorf("auth source must be $external")
}
case "mongodb-cr":
fallthrough
case "scram-sha-1":
fallthrough
case "scram-sha-256":
if p.AuthSource == "" {
p.AuthSource = dbName
if p.AuthSource == "" {
p.AuthSource = "admin"
}
}
case "":
if p.AuthSource == "" && (p.AuthMechanismProperties != nil || p.Username != "" || p.PasswordSet) {
p.AuthSource = dbName
if p.AuthSource == "" {
p.AuthSource = "admin"
}
}
default:
return fmt.Errorf("invalid auth mechanism")
}
return nil
}
func (p *parser) validateAuth() error {
switch strings.ToLower(p.AuthMechanism) {
case "mongodb-cr":
if p.Username == "" {
return fmt.Errorf("username required for MONGO-CR")
}
if p.Password == "" {
return fmt.Errorf("password required for MONGO-CR")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("MONGO-CR cannot have mechanism properties")
}
case "mongodb-x509":
if p.Password != "" {
return fmt.Errorf("password cannot be specified for MONGO-X509")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("MONGO-X509 cannot have mechanism properties")
}
case "gssapi":
if p.Username == "" {
return fmt.Errorf("username required for GSSAPI")
}
for k := range p.AuthMechanismProperties {
if k != "SERVICE_NAME" && k != "CANONICALIZE_HOST_NAME" && k != "SERVICE_REALM" {
return fmt.Errorf("invalid auth property for GSSAPI")
}
}
case "plain":
if p.Username == "" {
return fmt.Errorf("username required for PLAIN")
}
if p.Password == "" {
return fmt.Errorf("password required for PLAIN")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("PLAIN cannot have mechanism properties")
}
case "scram-sha-1":
if p.Username == "" {
return fmt.Errorf("username required for SCRAM-SHA-1")
}
if p.Password == "" {
return fmt.Errorf("password required for SCRAM-SHA-1")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("SCRAM-SHA-1 cannot have mechanism properties")
}
case "scram-sha-256":
if p.Username == "" {
return fmt.Errorf("username required for SCRAM-SHA-256")
}
if p.Password == "" {
return fmt.Errorf("password required for SCRAM-SHA-256")
}
if p.AuthMechanismProperties != nil {
return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties")
}
case "":
if p.Username == "" && p.AuthSource != "" {
return fmt.Errorf("authsource without username is invalid")
}
default:
return fmt.Errorf("invalid auth mechanism")
}
return nil
}
func (p *parser) addHost(host string) error {
if host == "" {
return nil
}
host, err := url.QueryUnescape(host)
if err != nil {
return internal.WrapErrorf(err, "invalid host \"%s\"", host)
}
_, port, err := net.SplitHostPort(host)
// It is unfortunate that SplitHostPort requires a port to exist; a
// "missing port in address" error is tolerated here.
if err != nil {
if addrError, ok := err.(*net.AddrError); !ok || addrError.Err != "missing port in address" {
return err
}
}
if port != "" {
d, err := strconv.Atoi(port)
if err != nil {
return internal.WrapErrorf(err, "port must be an integer")
}
if d <= 0 || d >= 65536 {
return fmt.Errorf("port must be in the range [1, 65535]")
}
}
p.Hosts = append(p.Hosts, host)
return nil
}
func (p *parser) addOption(pair string) error {
kv := strings.SplitN(pair, "=", 2)
if len(kv) != 2 || kv[0] == "" {
return fmt.Errorf("invalid option")
}
key, err := url.QueryUnescape(kv[0])
if err != nil {
return internal.WrapErrorf(err, "invalid option key \"%s\"", kv[0])
}
value, err := url.QueryUnescape(kv[1])
if err != nil {
return internal.WrapErrorf(err, "invalid option value \"%s\"", kv[1])
}
lowerKey := strings.ToLower(key)
switch lowerKey {
case "appname":
p.AppName = value
case "authmechanism":
p.AuthMechanism = value
case "authmechanismproperties":
p.AuthMechanismProperties = make(map[string]string)
pairs := strings.Split(value, ",")
for _, pair := range pairs {
kv := strings.SplitN(pair, ":", 2)
if len(kv) != 2 || kv[0] == "" {
return fmt.Errorf("invalid authMechanism property")
}
p.AuthMechanismProperties[kv[0]] = kv[1]
}
case "authsource":
p.AuthSource = value
case "compressors":
compressors := strings.Split(value, ",")
if len(compressors) < 1 {
return fmt.Errorf("must have at least 1 compressor")
}
p.Compressors = compressors
case "connect":
switch strings.ToLower(value) {
case "automatic":
case "direct":
p.Connect = SingleConnect
default:
return fmt.Errorf("invalid 'connect' value: %s", value)
}
p.ConnectSet = true
case "connecttimeoutms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.ConnectTimeout = time.Duration(n) * time.Millisecond
p.ConnectTimeoutSet = true
case "heartbeatintervalms", "heartbeatfrequencyms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.HeartbeatInterval = time.Duration(n) * time.Millisecond
p.HeartbeatIntervalSet = true
case "journal":
switch value {
case "true":
p.J = true
case "false":
p.J = false
default:
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.JSet = true
case "localthresholdms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.LocalThreshold = time.Duration(n) * time.Millisecond
p.LocalThresholdSet = true
case "maxidletimems":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.MaxConnIdleTime = time.Duration(n) * time.Millisecond
p.MaxConnIdleTimeSet = true
case "maxpoolsize":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.MaxPoolSize = uint16(n)
p.MaxPoolSizeSet = true
case "readconcernlevel":
p.ReadConcernLevel = value
case "readpreference":
p.ReadPreference = value
case "readpreferencetags":
if value == "" {
// skip empty values, e.g. when "readPreferenceTags=" appears at the end of the URI
break
}
tags := make(map[string]string)
items := strings.Split(value, ",")
for _, item := range items {
parts := strings.Split(item, ":")
if len(parts) != 2 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
tags[parts[0]] = parts[1]
}
p.ReadPreferenceTagSets = append(p.ReadPreferenceTagSets, tags)
case "maxstaleness":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.MaxStaleness = time.Duration(n) * time.Second
p.MaxStalenessSet = true
case "replicaset":
p.ReplicaSet = value
case "retrywrites":
p.RetryWrites = value == "true"
p.RetryWritesSet = true
case "serverselectiontimeoutms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.ServerSelectionTimeout = time.Duration(n) * time.Millisecond
p.ServerSelectionTimeoutSet = true
case "sockettimeoutms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.SocketTimeout = time.Duration(n) * time.Millisecond
p.SocketTimeoutSet = true
case "ssl":
switch value {
case "true":
p.SSL = true
case "false":
p.SSL = false
default:
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.SSLSet = true
case "sslclientcertificatekeyfile":
p.SSL = true
p.SSLSet = true
p.SSLClientCertificateKeyFile = value
p.SSLClientCertificateKeyFileSet = true
case "sslclientcertificatekeypassword":
p.SSLClientCertificateKeyPassword = func() string { return value }
p.SSLClientCertificateKeyPasswordSet = true
case "sslinsecure":
switch value {
case "true":
p.SSLInsecure = true
case "false":
p.SSLInsecure = false
default:
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.SSLInsecureSet = true
case "sslcertificateauthorityfile":
p.SSL = true
p.SSLSet = true
p.SSLCaFile = value
p.SSLCaFileSet = true
case "w":
if w, err := strconv.Atoi(value); err == nil {
if w < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.WNumber = w
p.WNumberSet = true
p.WString = ""
break
}
p.WString = value
p.WNumberSet = false
case "wtimeoutms":
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.WTimeout = time.Duration(n) * time.Millisecond
p.WTimeoutSet = true
case "wtimeout":
// Defer to wtimeoutms, but not to a manually-set option.
if p.WTimeoutSet {
break
}
n, err := strconv.Atoi(value)
if err != nil || n < 0 {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
p.WTimeout = time.Duration(n) * time.Millisecond
case "zlibcompressionlevel":
level, err := strconv.Atoi(value)
if err != nil || (level < -1 || level > 9) {
return fmt.Errorf("invalid value for %s: %s", key, value)
}
if level == -1 {
level = wiremessage.DefaultZlibLevel
}
p.ZlibLevel = level
p.ZlibLevelSet = true
default:
if p.UnknownOptions == nil {
p.UnknownOptions = make(map[string][]string)
}
p.UnknownOptions[lowerKey] = append(p.UnknownOptions[lowerKey], value)
}
if p.Options == nil {
p.Options = make(map[string][]string)
}
p.Options[lowerKey] = append(p.Options[lowerKey], value)
return nil
}
func extractQueryArgsFromURI(uri string) ([]string, error) {
if len(uri) == 0 {
return nil, nil
}
if uri[0] != '?' {
return nil, errors.New("must have a ? separator between path and query")
}
uri = uri[1:]
if len(uri) == 0 {
return nil, nil
}
return strings.FieldsFunc(uri, func(r rune) bool { return r == ';' || r == '&' }), nil
}
type extractedDatabase struct {
uri string
db string
}
// extractDatabaseFromURI is a helper function to retrieve information about
// the database from the passed in URI. It accepts as an argument the currently
// parsed URI and returns the remainder of the uri, the database it found,
// and any error it encounters while parsing.
func extractDatabaseFromURI(uri string) (extractedDatabase, error) {
if len(uri) == 0 {
return extractedDatabase{}, nil
}
if uri[0] != '/' {
return extractedDatabase{}, errors.New("must have a / separator between hosts and path")
}
uri = uri[1:]
if len(uri) == 0 {
return extractedDatabase{}, nil
}
database := uri
if idx := strings.IndexRune(uri, '?'); idx != -1 {
database = uri[:idx]
}
escapedDatabase, err := url.QueryUnescape(database)
if err != nil {
return extractedDatabase{}, internal.WrapErrorf(err, "invalid database \"%s\"", database)
}
uri = uri[len(database):]
return extractedDatabase{
uri: uri,
db: escapedDatabase,
}, nil
}
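
A small usage sketch (not part of the package): parse a standard, non-SRV URI and read back a few of the fields populated above. The URI, hosts, and credentials are placeholders.

package main

import (
    "fmt"
    "log"

    "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)

func main() {
    cs, err := connstring.Parse("mongodb://alice:s3cret@host1:27017,host2:27018/appdb?maxPoolSize=20&readPreference=secondaryPreferred")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(cs.Hosts)          // [host1:27017 host2:27018]
    fmt.Println(cs.Database)       // appdb
    fmt.Println(cs.MaxPoolSize)    // 20
    fmt.Println(cs.ReadPreference) // secondaryPreferred
    fmt.Println(cs.AuthSource)     // appdb (defaulted from the database name)
}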


@@ -0,0 +1,10 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description // import "go.mongodb.org/mongo-driver/x/mongo/driver/description"
// Unknown is an unknown server or topology kind.
const Unknown = 0


@@ -0,0 +1,36 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import (
"fmt"
)
// MaxStalenessSupported returns an error if the given server version
// does not support max staleness.
func MaxStalenessSupported(wireVersion *VersionRange) error {
if wireVersion != nil && wireVersion.Max < 5 {
return fmt.Errorf("max staleness is only supported for servers 3.4 or newer")
}
return nil
}
// ScramSHA1Supported returns an error if the given server version
// does not support scram-sha-1.
func ScramSHA1Supported(wireVersion *VersionRange) error {
if wireVersion != nil && wireVersion.Max < 3 {
return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer")
}
return nil
}
// SessionsSupported returns true if the given server version indicates that it supports sessions.
func SessionsSupported(wireVersion *VersionRange) bool {
return wireVersion != nil && wireVersion.Max >= 6
}


@@ -0,0 +1,154 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import (
"fmt"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/tag"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/network/result"
)
// UnsetRTT is the unset value for a round trip time.
const UnsetRTT = -1 * time.Millisecond
// SelectedServer represents a selected server that is a member of a topology.
type SelectedServer struct {
Server
Kind TopologyKind
}
// Server represents a description of a server. This is created from an isMaster
// command.
type Server struct {
Addr address.Address
AverageRTT time.Duration
AverageRTTSet bool
Compression []string // compression methods returned by server
CanonicalAddr address.Address
ElectionID primitive.ObjectID
HeartbeatInterval time.Duration
LastError error
LastUpdateTime time.Time
LastWriteTime time.Time
MaxBatchCount uint32
MaxDocumentSize uint32
MaxMessageSize uint32
Members []address.Address
ReadOnly bool
SessionTimeoutMinutes uint32
SetName string
SetVersion uint32
Tags tag.Set
Kind ServerKind
WireVersion *VersionRange
SaslSupportedMechs []string // user-specific from server handshake
}
// NewServer creates a new server description from the given parameters.
func NewServer(addr address.Address, isMaster result.IsMaster) Server {
i := Server{
Addr: addr,
CanonicalAddr: address.Address(isMaster.Me).Canonicalize(),
Compression: isMaster.Compression,
ElectionID: isMaster.ElectionID,
LastUpdateTime: time.Now().UTC(),
LastWriteTime: isMaster.LastWriteTimestamp,
MaxBatchCount: isMaster.MaxWriteBatchSize,
MaxDocumentSize: isMaster.MaxBSONObjectSize,
MaxMessageSize: isMaster.MaxMessageSizeBytes,
SaslSupportedMechs: isMaster.SaslSupportedMechs,
SessionTimeoutMinutes: isMaster.LogicalSessionTimeoutMinutes,
SetName: isMaster.SetName,
SetVersion: isMaster.SetVersion,
Tags: tag.NewTagSetFromMap(isMaster.Tags),
}
if i.CanonicalAddr == "" {
i.CanonicalAddr = addr
}
if isMaster.OK != 1 {
i.LastError = fmt.Errorf("not ok")
return i
}
for _, host := range isMaster.Hosts {
i.Members = append(i.Members, address.Address(host).Canonicalize())
}
for _, passive := range isMaster.Passives {
i.Members = append(i.Members, address.Address(passive).Canonicalize())
}
for _, arbiter := range isMaster.Arbiters {
i.Members = append(i.Members, address.Address(arbiter).Canonicalize())
}
i.Kind = Standalone
if isMaster.IsReplicaSet {
i.Kind = RSGhost
} else if isMaster.SetName != "" {
if isMaster.IsMaster {
i.Kind = RSPrimary
} else if isMaster.Hidden {
i.Kind = RSMember
} else if isMaster.Secondary {
i.Kind = RSSecondary
} else if isMaster.ArbiterOnly {
i.Kind = RSArbiter
} else {
i.Kind = RSMember
}
} else if isMaster.Msg == "isdbgrid" {
i.Kind = Mongos
}
i.WireVersion = &VersionRange{
Min: isMaster.MinWireVersion,
Max: isMaster.MaxWireVersion,
}
return i
}
// SetAverageRTT sets the average round trip time for this server description.
func (s Server) SetAverageRTT(rtt time.Duration) Server {
s.AverageRTT = rtt
if rtt == UnsetRTT {
s.AverageRTTSet = false
} else {
s.AverageRTTSet = true
}
return s
}
// DataBearing returns true if the server is a data bearing server.
func (s Server) DataBearing() bool {
return s.Kind == RSPrimary ||
s.Kind == RSSecondary ||
s.Kind == Mongos ||
s.Kind == Standalone
}
// SelectServer selects this server if it is in the list of given candidates.
func (s Server) SelectServer(_ Topology, candidates []Server) ([]Server, error) {
for _, candidate := range candidates {
if candidate.Addr == s.Addr {
return []Server{candidate}, nil
}
}
return nil, nil
}
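
A sketch (not from the driver) of how NewServer classifies a server from an isMaster reply; only a few result.IsMaster fields are filled in and the address is a placeholder.

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver/address"
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
    "go.mongodb.org/mongo-driver/x/network/result"
)

func main() {
    im := result.IsMaster{
        OK:             1,
        Msg:            "isdbgrid", // the msg value a mongos reports
        MinWireVersion: 2,
        MaxWireVersion: 7,
    }
    s := description.NewServer(address.Address("router1:27017"), im)
    fmt.Println(s.Kind)        // Mongos
    fmt.Println(s.WireVersion) // [2, 7]
}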


@@ -0,0 +1,43 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
// ServerKind represents the type of a server.
type ServerKind uint32
// These constants are the possible types of servers.
const (
Standalone ServerKind = 1
RSMember ServerKind = 2
RSPrimary ServerKind = 4 + RSMember
RSSecondary ServerKind = 8 + RSMember
RSArbiter ServerKind = 16 + RSMember
RSGhost ServerKind = 32 + RSMember
Mongos ServerKind = 256
)
// String implements the fmt.Stringer interface.
func (kind ServerKind) String() string {
switch kind {
case Standalone:
return "Standalone"
case RSMember:
return "RSOther"
case RSPrimary:
return "RSPrimary"
case RSSecondary:
return "RSSecondary"
case RSArbiter:
return "RSArbiter"
case RSGhost:
return "RSGhost"
case Mongos:
return "Mongos"
}
return "Unknown"
}


@@ -0,0 +1,279 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import (
"fmt"
"math"
"time"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/tag"
)
// ServerSelector is an interface implemented by types that can select a server given a
// topology description.
type ServerSelector interface {
SelectServer(Topology, []Server) ([]Server, error)
}
// ServerSelectorFunc is a function that can be used as a ServerSelector.
type ServerSelectorFunc func(Topology, []Server) ([]Server, error)
// SelectServer implements the ServerSelector interface.
func (ssf ServerSelectorFunc) SelectServer(t Topology, s []Server) ([]Server, error) {
return ssf(t, s)
}
type compositeSelector struct {
selectors []ServerSelector
}
// CompositeSelector combines multiple selectors into a single selector.
func CompositeSelector(selectors []ServerSelector) ServerSelector {
return &compositeSelector{selectors: selectors}
}
func (cs *compositeSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) {
var err error
for _, sel := range cs.selectors {
candidates, err = sel.SelectServer(t, candidates)
if err != nil {
return nil, err
}
}
return candidates, nil
}
type latencySelector struct {
latency time.Duration
}
// LatencySelector creates a ServerSelector which selects servers based on their latency.
func LatencySelector(latency time.Duration) ServerSelector {
return &latencySelector{latency: latency}
}
func (ls *latencySelector) SelectServer(t Topology, candidates []Server) ([]Server, error) {
if ls.latency < 0 {
return candidates, nil
}
switch len(candidates) {
case 0, 1:
return candidates, nil
default:
min := time.Duration(math.MaxInt64)
for _, candidate := range candidates {
if candidate.AverageRTTSet {
if candidate.AverageRTT < min {
min = candidate.AverageRTT
}
}
}
if min == math.MaxInt64 {
return candidates, nil
}
max := min + ls.latency
var result []Server
for _, candidate := range candidates {
if candidate.AverageRTTSet {
if candidate.AverageRTT <= max {
result = append(result, candidate)
}
}
}
return result, nil
}
}
// WriteSelector selects all the writable servers.
func WriteSelector() ServerSelector {
return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) {
switch t.Kind {
case Single:
return candidates, nil
default:
result := []Server{}
for _, candidate := range candidates {
switch candidate.Kind {
case Mongos, RSPrimary, Standalone:
result = append(result, candidate)
}
}
return result, nil
}
})
}
// ReadPrefSelector selects servers based on the provided read preference.
func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector {
return ServerSelectorFunc(func(t Topology, candidates []Server) ([]Server, error) {
if _, set := rp.MaxStaleness(); set {
for _, s := range candidates {
if s.Kind != Unknown {
if err := MaxStalenessSupported(s.WireVersion); err != nil {
return nil, err
}
}
}
}
switch t.Kind {
case Single:
return candidates, nil
case ReplicaSetNoPrimary, ReplicaSetWithPrimary:
return selectForReplicaSet(rp, t, candidates)
case Sharded:
return selectByKind(candidates, Mongos), nil
}
return nil, nil
})
}
func selectForReplicaSet(rp *readpref.ReadPref, t Topology, candidates []Server) ([]Server, error) {
if err := verifyMaxStaleness(rp, t); err != nil {
return nil, err
}
switch rp.Mode() {
case readpref.PrimaryMode:
return selectByKind(candidates, RSPrimary), nil
case readpref.PrimaryPreferredMode:
selected := selectByKind(candidates, RSPrimary)
if len(selected) == 0 {
selected = selectSecondaries(rp, candidates)
return selectByTagSet(selected, rp.TagSets()), nil
}
return selected, nil
case readpref.SecondaryPreferredMode:
selected := selectSecondaries(rp, candidates)
selected = selectByTagSet(selected, rp.TagSets())
if len(selected) > 0 {
return selected, nil
}
return selectByKind(candidates, RSPrimary), nil
case readpref.SecondaryMode:
selected := selectSecondaries(rp, candidates)
return selectByTagSet(selected, rp.TagSets()), nil
case readpref.NearestMode:
selected := selectByKind(candidates, RSPrimary)
selected = append(selected, selectSecondaries(rp, candidates)...)
return selectByTagSet(selected, rp.TagSets()), nil
}
return nil, fmt.Errorf("unsupported mode: %d", rp.Mode())
}
func selectSecondaries(rp *readpref.ReadPref, candidates []Server) []Server {
secondaries := selectByKind(candidates, RSSecondary)
if len(secondaries) == 0 {
return secondaries
}
if maxStaleness, set := rp.MaxStaleness(); set {
primaries := selectByKind(candidates, RSPrimary)
if len(primaries) == 0 {
baseTime := secondaries[0].LastWriteTime
for i := 1; i < len(secondaries); i++ {
if secondaries[i].LastWriteTime.After(baseTime) {
baseTime = secondaries[i].LastWriteTime
}
}
var selected []Server
for _, secondary := range secondaries {
estimatedStaleness := baseTime.Sub(secondary.LastWriteTime) + secondary.HeartbeatInterval
if estimatedStaleness <= maxStaleness {
selected = append(selected, secondary)
}
}
return selected
}
primary := primaries[0]
var selected []Server
for _, secondary := range secondaries {
estimatedStaleness := secondary.LastUpdateTime.Sub(secondary.LastWriteTime) - primary.LastUpdateTime.Sub(primary.LastWriteTime) + secondary.HeartbeatInterval
if estimatedStaleness <= maxStaleness {
selected = append(selected, secondary)
}
}
return selected
}
return secondaries
}
func selectByTagSet(candidates []Server, tagSets []tag.Set) []Server {
if len(tagSets) == 0 {
return candidates
}
for _, ts := range tagSets {
var results []Server
for _, s := range candidates {
if len(s.Tags) > 0 && s.Tags.ContainsAll(ts) {
results = append(results, s)
}
}
if len(results) > 0 {
return results
}
}
return []Server{}
}
func selectByKind(candidates []Server, kind ServerKind) []Server {
var result []Server
for _, s := range candidates {
if s.Kind == kind {
result = append(result, s)
}
}
return result
}
func verifyMaxStaleness(rp *readpref.ReadPref, t Topology) error {
maxStaleness, set := rp.MaxStaleness()
if !set {
return nil
}
if maxStaleness < 90*time.Second {
return fmt.Errorf("max staleness (%s) must be greater than or equal to 90s", maxStaleness)
}
if len(t.Servers) < 1 {
// Maybe we should return an error here instead?
return nil
}
// we'll assume all candidates have the same heartbeat interval.
s := t.Servers[0]
idleWritePeriod := 10 * time.Second
if maxStaleness < s.HeartbeatInterval+idleWritePeriod {
return fmt.Errorf(
"max staleness (%s) must be greater than or equal to the heartbeat interval (%s) plus idle write period (%s)",
maxStaleness, s.HeartbeatInterval, idleWritePeriod,
)
}
return nil
}
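
A sketch (not from the driver) that composes the selectors above and runs them against a hand-built topology description; the addresses and round-trip times are placeholders.

package main

import (
    "fmt"
    "time"

    "go.mongodb.org/mongo-driver/x/mongo/driver/address"
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

func main() {
    t := description.Topology{
        Kind: description.ReplicaSetWithPrimary,
        Servers: []description.Server{
            {Addr: address.Address("a:27017"), Kind: description.RSPrimary, AverageRTT: 5 * time.Millisecond, AverageRTTSet: true},
            {Addr: address.Address("b:27017"), Kind: description.RSSecondary, AverageRTT: 50 * time.Millisecond, AverageRTTSet: true},
        },
    }
    sel := description.CompositeSelector([]description.ServerSelector{
        description.WriteSelector(),                        // writable servers only (primary, mongos, standalone)
        description.LatencySelector(15 * time.Millisecond), // within 15ms of the fastest remaining candidate
    })
    chosen, err := sel.SelectServer(t, t.Servers)
    if err != nil {
        fmt.Println(err)
        return
    }
    for _, s := range chosen {
        fmt.Println(s.Addr, s.Kind) // a:27017 RSPrimary
    }
}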


@@ -0,0 +1,136 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import (
"sort"
"strings"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
)
// Topology represents a description of a MongoDB topology.
type Topology struct {
Servers []Server
Kind TopologyKind
SessionTimeoutMinutes uint32
}
// Server returns the server for the given address. Returns false if the server
// could not be found.
func (t Topology) Server(addr address.Address) (Server, bool) {
for _, server := range t.Servers {
if server.Addr.String() == addr.String() {
return server, true
}
}
return Server{}, false
}
// TopologyDiff is the difference between two different topology descriptions.
type TopologyDiff struct {
Added []Server
Removed []Server
}
// DiffTopology compares the two topology descriptions and returns the difference.
func DiffTopology(old, new Topology) TopologyDiff {
var diff TopologyDiff
// TODO: do this without sorting...
oldServers := serverSorter(old.Servers)
newServers := serverSorter(new.Servers)
sort.Sort(oldServers)
sort.Sort(newServers)
i := 0
j := 0
for {
if i < len(oldServers) && j < len(newServers) {
comp := strings.Compare(oldServers[i].Addr.String(), newServers[j].Addr.String())
switch comp {
case 1:
// newServers[j] sorts before oldServers[i], so newServers[j] was added
diff.Added = append(diff.Added, newServers[j])
j++
case -1:
// oldServers[i] sorts before newServers[j], so oldServers[i] was removed
diff.Removed = append(diff.Removed, oldServers[i])
i++
case 0:
i++
j++
}
} else if i < len(oldServers) {
diff.Removed = append(diff.Removed, oldServers[i])
i++
} else if j < len(newServers) {
diff.Added = append(diff.Added, newServers[j])
j++
} else {
break
}
}
return diff
}
// HostlistDiff is the difference between a topology and a host list.
type HostlistDiff struct {
Added []string
Removed []string
}
// DiffHostlist compares the topology description and host list and returns the difference.
func (t Topology) DiffHostlist(hostlist []string) HostlistDiff {
var diff HostlistDiff
oldServers := serverSorter(t.Servers)
sort.Sort(oldServers)
sort.Strings(hostlist)
i := 0
j := 0
for {
if i < len(oldServers) && j < len(hostlist) {
oldServer := oldServers[i].Addr.String()
comp := strings.Compare(oldServer, hostlist[j])
switch comp {
case 1:
// oldServers[i] is bigger
diff.Added = append(diff.Added, hostlist[j])
j++
case -1:
// hostlist[j] is bigger
diff.Removed = append(diff.Removed, oldServer)
i++
case 0:
i++
j++
}
} else if i < len(oldServers) {
diff.Removed = append(diff.Removed, oldServers[i].Addr.String())
i++
} else if j < len(hostlist) {
diff.Added = append(diff.Added, hostlist[j])
j++
} else {
break
}
}
return diff
}
type serverSorter []Server
func (ss serverSorter) Len() int { return len(ss) }
func (ss serverSorter) Swap(i, j int) { ss[i], ss[j] = ss[j], ss[i] }
func (ss serverSorter) Less(i, j int) bool {
return strings.Compare(ss[i].Addr.String(), ss[j].Addr.String()) < 0
}
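
A sketch (not from the driver) of DiffTopology in action; the addresses are placeholders.

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver/address"
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

func main() {
    prev := description.Topology{Servers: []description.Server{
        {Addr: address.Address("a:27017")},
        {Addr: address.Address("b:27017")},
    }}
    curr := description.Topology{Servers: []description.Server{
        {Addr: address.Address("b:27017")},
        {Addr: address.Address("c:27017")},
    }}
    diff := description.DiffTopology(prev, curr)
    for _, s := range diff.Removed {
        fmt.Println("removed:", s.Addr) // removed: a:27017
    }
    for _, s := range diff.Added {
        fmt.Println("added:", s.Addr) // added: c:27017
    }
}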


@@ -0,0 +1,37 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
// TopologyKind represents a specific topology configuration.
type TopologyKind uint32
// These constants are the available topology configurations.
const (
Single TopologyKind = 1
ReplicaSet TopologyKind = 2
ReplicaSetNoPrimary TopologyKind = 4 + ReplicaSet
ReplicaSetWithPrimary TopologyKind = 8 + ReplicaSet
Sharded TopologyKind = 256
)
// String implements the fmt.Stringer interface.
func (kind TopologyKind) String() string {
switch kind {
case Single:
return "Single"
case ReplicaSet:
return "ReplicaSet"
case ReplicaSetNoPrimary:
return "ReplicaSetNoPrimary"
case ReplicaSetWithPrimary:
return "ReplicaSetWithPrimary"
case Sharded:
return "Sharded"
}
return "Unknown"
}


@@ -0,0 +1,44 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import "strconv"
// Version represents a software version.
type Version struct {
Desc string
Parts []uint8
}
// AtLeast returns true if the version is at least as large as the "other" version.
func (v Version) AtLeast(other ...uint8) bool {
for i := range other {
if i == len(v.Parts) {
return false
}
if v.Parts[i] < other[i] {
return false
}
}
return true
}
// String provides the string representation of the Version.
func (v Version) String() string {
if v.Desc == "" {
var s string
for i, p := range v.Parts {
if i != 0 {
s += "."
}
s += strconv.Itoa(int(p))
}
return s
}
return v.Desc
}
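
A small sketch (not from the driver) showing AtLeast's part-by-part comparison, including the case where the queried version has more parts than the described one.

package main

import (
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

func main() {
    v := description.Version{Parts: []uint8{3, 6, 4}}
    fmt.Println(v)                     // 3.6.4
    fmt.Println(v.AtLeast(3, 4))       // true
    fmt.Println(v.AtLeast(4, 0))       // false
    fmt.Println(v.AtLeast(3, 6, 4, 1)) // false: more parts requested than v has
}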


@@ -0,0 +1,31 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package description
import "fmt"
// VersionRange represents a range of versions.
type VersionRange struct {
Min int32
Max int32
}
// NewVersionRange creates a new VersionRange given a min and a max.
func NewVersionRange(min, max int32) VersionRange {
return VersionRange{Min: min, Max: max}
}
// Includes returns a bool indicating whether the supplied integer is included
// in the range.
func (vr VersionRange) Includes(v int32) bool {
return v >= vr.Min && v <= vr.Max
}
// String implements the fmt.Stringer interface.
func (vr VersionRange) String() string {
return fmt.Sprintf("[%d, %d]", vr.Min, vr.Max)
}

137
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/dns/dns.go generated vendored Executable file

@@ -0,0 +1,137 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package dns
import (
"errors"
"fmt"
"net"
"runtime"
"strings"
)
// Resolver resolves DNS records.
type Resolver struct {
// Holds the functions to use for DNS lookups
LookupSRV func(string, string, string) (string, []*net.SRV, error)
LookupTXT func(string) ([]string, error)
}
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}
// ParseHosts uses the srv string to get the hosts.
func (r *Resolver) ParseHosts(host string, stopOnErr bool) ([]string, error) {
parsedHosts := strings.Split(host, ",")
if len(parsedHosts) != 1 {
return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
}
return r.fetchSeedlistFromSRV(parsedHosts[0], stopOnErr)
}
// GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
var connectionArgsFromTXT []string
// error ignored because not finding a TXT record should not be
// considered an error.
recordsFromTXT, _ := r.LookupTXT(host)
// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
// On Windows it will currently, and incorrectly, concatenate multiple TXT records
// into one.
if runtime.GOOS == "windows" {
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
}
if len(recordsFromTXT) > 1 {
return nil, errors.New("multiple records from TXT not supported")
}
if len(recordsFromTXT) > 0 {
connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
err := validateTXTResult(connectionArgsFromTXT)
if err != nil {
return nil, err
}
}
return connectionArgsFromTXT, nil
}
func (r *Resolver) fetchSeedlistFromSRV(host string, stopOnErr bool) ([]string, error) {
var err error
_, _, err = net.SplitHostPort(host)
if err == nil {
// we were able to successfully extract a port from the host,
// but should not be able to when using SRV
return nil, fmt.Errorf("URI with srv must not include a port number")
}
_, addresses, err := r.LookupSRV("mongodb", "tcp", host)
if err != nil {
return nil, err
}
var parsedHosts []string
for _, address := range addresses {
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
err := validateSRVResult(trimmedAddressTarget, host)
if err != nil {
if stopOnErr {
return nil, err
}
continue
}
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
}
return parsedHosts, nil
}
func validateSRVResult(recordFromSRV, inputHostName string) error {
separatedInputDomain := strings.Split(inputHostName, ".")
separatedRecord := strings.Split(recordFromSRV, ".")
if len(separatedRecord) < 2 {
return errors.New("DNS name must contain at least 2 labels")
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("Domain suffix from SRV record not matched input domain")
}
inputDomainSuffix := separatedInputDomain[1:]
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
if label != recordDomainSuffix[ix] {
return errors.New("Domain suffix from SRV record not matched input domain")
}
}
return nil
}
var allowedTXTOptions = map[string]struct{}{
"authsource": {},
"replicaset": {},
}
func validateTXTResult(paramsFromTXT []string) error {
for _, param := range paramsFromTXT {
kv := strings.SplitN(param, "=", 2)
if len(kv) != 2 {
return errors.New("Invalid TXT record")
}
key := strings.ToLower(kv[0])
if _, ok := allowedTXTOptions[key]; !ok {
return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
}
}
return nil
}
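
A sketch (not from the driver) of a Resolver with stubbed lookup functions, which makes the SRV and TXT handling above easy to exercise without real DNS; the host names and record contents are placeholders.

package main

import (
    "fmt"
    "net"

    "go.mongodb.org/mongo-driver/x/mongo/driver/dns"
)

func main() {
    r := &dns.Resolver{
        LookupSRV: func(service, proto, name string) (string, []*net.SRV, error) {
            return "", []*net.SRV{
                {Target: "shard0.example.com.", Port: 27017},
                {Target: "shard1.example.com.", Port: 27017},
            }, nil
        },
        LookupTXT: func(name string) ([]string, error) {
            return []string{"authSource=admin&replicaSet=rs0"}, nil
        },
    }
    hosts, err := r.ParseHosts("cluster0.example.com", true)
    fmt.Println(hosts, err) // [shard0.example.com:27017 shard1.example.com:27017] <nil>
    args, err := r.GetConnectionArgsFromTXT("cluster0.example.com")
    fmt.Println(args, err) // [authSource=admin replicaSet=rs0] <nil>
}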

154
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/driver.go generated vendored Executable file

@@ -0,0 +1,154 @@
package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver"
import (
"context"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// Deployment is implemented by types that can select a server from a deployment.
type Deployment interface {
SelectServer(context.Context, description.ServerSelector) (Server, error)
SupportsRetry() bool
Kind() description.TopologyKind
}
// Server represents a MongoDB server. Implementations should pool connections and handle the
// retrieving and returning of connections.
type Server interface {
Connection(context.Context) (Connection, error)
}
// Connection represents a connection to a MongoDB server.
type Connection interface {
WriteWireMessage(context.Context, []byte) error
ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error)
Description() description.Server
Close() error
ID() string
Address() address.Address
}
// Compressor is an interface used to compress wire messages. If a Connection supports compression
// it should implement this interface as well. The CompressWireMessage method will be called during
// the execution of an operation if the wire message is allowed to be compressed.
type Compressor interface {
CompressWireMessage(src, dst []byte) ([]byte, error)
}
// ErrorProcessor implementations can handle processing errors, which may modify their internal state.
// If this type is implemented by a Server, then Operation.Execute will call its ProcessError
// method after it decodes a wire message.
type ErrorProcessor interface {
ProcessError(error)
}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker interface {
Handshake(context.Context, address.Address, Connection) (description.Server, error)
}
// HandshakerFunc is an adapter to allow the use of ordinary functions as
// connection handshakers.
type HandshakerFunc func(context.Context, address.Address, Connection) (description.Server, error)
// Handshake implements the Handshaker interface.
func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, conn Connection) (description.Server, error) {
return hf(ctx, addr, conn)
}
// SingleServerDeployment is an implementation of Deployment that always returns a single server.
type SingleServerDeployment struct{ Server }
var _ Deployment = SingleServerDeployment{}
// SelectServer implements the Deployment interface. This method does not use the
// description.ServerSelector provided and instead returns the embedded Server.
func (ssd SingleServerDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return ssd.Server, nil
}
// SupportsRetry implements the Deployment interface. It always returns false, because a single
// server does not support retryability.
func (SingleServerDeployment) SupportsRetry() bool { return false }
// Kind implements the Deployment interface. It always returns description.Single.
func (SingleServerDeployment) Kind() description.TopologyKind { return description.Single }
// SingleConnectionDeployment is an implementation of Deployment that always returns the same
// Connection.
type SingleConnectionDeployment struct{ C Connection }
var _ Deployment = SingleConnectionDeployment{}
var _ Server = SingleConnectionDeployment{}
// SelectServer implements the Deployment interface. This method does not use the
// description.ServerSelector provided and instead returns itself. The Connections returned from the
// Connection method have a no-op Close method.
func (ssd SingleConnectionDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return ssd, nil
}
// SupportsRetry implements the Deployment interface. It always returns false, because a single
// connection does not support retryability.
func (ssd SingleConnectionDeployment) SupportsRetry() bool { return false }
// Kind implements the Deployment interface. It always returns description.Single.
func (ssd SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single }
// Connection implements the Server interface. It always returns the embedded connection.
//
// This method returns a Connection with a no-op Close method. This ensures that a
// SingleConnectionDeployment can be used across multiple operation executions.
func (ssd SingleConnectionDeployment) Connection(context.Context) (Connection, error) {
return nopCloserConnection{ssd.C}, nil
}
// nopCloserConnection is an adapter used in a SingleConnectionDeployment. It passes through all
// functionality except for closing, which is a no-op. This is done so the connection can be used
// across multiple operations.
type nopCloserConnection struct{ Connection }
func (ncc nopCloserConnection) Close() error { return nil }
// TODO(GODRIVER-617): We can likely use 1 type for both the RetryType and the RetryMode by using
// 2 bits for the mode and 1 bit for the type. In a practical sense we might not want to do that,
// since the type of retryability is tied to the operation itself and isn't going to change, e.g. an
// insert operation will always be a write; however, some operations are both reads and writes, for
// instance aggregate is a read but with a $out parameter it's a write.
// RetryType specifies whether a retry is for a read or a write, or if retrying is disabled.
type RetryType uint
// These are the available types of retry.
const (
_ RetryType = iota
RetryWrite
RetryRead
)
// RetryMode specifies the way that retries are handled for retryable operations.
type RetryMode uint
// These are the modes available for retrying.
const (
// RetryNone disables retrying.
RetryNone RetryMode = iota
// RetryOnce will enable retrying the entire operation once.
RetryOnce
// RetryOncePerCommand will enable retrying each command associated with an operation. For
// example, if an insert is batch split into 4 commands then each of those commands is eligible
// for one retry.
RetryOncePerCommand
// RetryContext will enable retrying until the context.Context's deadline is exceeded or it is
// cancelled.
RetryContext
)
// Enabled returns if this RetryMode enables retrying.
func (rm RetryMode) Enabled() bool {
return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext
}
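
A sketch (not from the driver) of the HandshakerFunc adapter: an ordinary function with the right signature can be used wherever a Handshaker is expected. A real handshaker would run isMaster and authentication; this stub just returns the connection's cached description.

package example

import (
    "context"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/address"
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
)

// noopHandshake satisfies the HandshakerFunc signature without doing any work.
func noopHandshake(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) {
    return conn.Description(), nil
}

// Any configuration that expects a driver.Handshaker can now accept the plain function.
var _ driver.Handshaker = driver.HandshakerFunc(noopHandshake)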

345
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/errors.go generated vendored Executable file

@@ -0,0 +1,345 @@
package driver
import (
"bytes"
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var (
retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001}
nodeIsRecoveringCodes = []int32{11600, 11602, 13436, 189, 91}
notMasterCodes = []int32{10107, 13435}
)
var (
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
// TransientTransactionError is an error label for transient errors with transactions.
TransientTransactionError = "TransientTransactionError"
// NetworkError is an error label for network errors.
NetworkError = "NetworkError"
// ErrCursorNotFound is the cursor not found error for legacy find operations.
ErrCursorNotFound = errors.New("cursor not found")
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern.
ErrUnacknowledgedWrite = errors.New("unacknowledged write")
)
// QueryFailureError is an error representing a command failure as a document.
type QueryFailureError struct {
Message string
Response bsoncore.Document
}
// Error implements the error interface.
func (e QueryFailureError) Error() string {
return fmt.Sprintf("%s: %v", e.Message, e.Response)
}
// ResponseError is an error parsing the response to a command.
type ResponseError struct {
Message string
Wrapped error
}
// NewCommandResponseError creates a new ResponseError.
func NewCommandResponseError(msg string, err error) ResponseError {
return ResponseError{Message: msg, Wrapped: err}
}
// Error implements the error interface.
func (e ResponseError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
}
return e.Message
}
// WriteCommandError is an error for a write command.
type WriteCommandError struct {
WriteConcernError *WriteConcernError
WriteErrors WriteErrors
}
func (wce WriteCommandError) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write command error: [")
fmt.Fprintf(&buf, "{%s}, ", wce.WriteErrors)
fmt.Fprintf(&buf, "{%s}]", wce.WriteConcernError)
return buf.String()
}
// Retryable returns true if the error is retryable.
func (wce WriteCommandError) Retryable() bool {
if wce.WriteConcernError == nil {
return false
}
return (*wce.WriteConcernError).Retryable()
}
// WriteConcernError is a write concern failure that occurred as a result of a
// write operation.
type WriteConcernError struct {
Name string
Code int64
Message string
Details bsoncore.Document
}
func (wce WriteConcernError) Error() string {
if wce.Name != "" {
return fmt.Sprintf("(%v) %v", wce.Name, wce.Message)
}
return wce.Message
}
// Retryable returns true if the error is retryable.
func (wce WriteConcernError) Retryable() bool {
for _, code := range retryableCodes {
if wce.Code == int64(code) {
return true
}
}
if strings.Contains(wce.Message, "not master") || strings.Contains(wce.Message, "node is recovering") {
return true
}
return false
}
// WriteError is a non-write concern failure that occurred as a result of a write
// operation.
type WriteError struct {
Index int64
Code int64
Message string
}
func (we WriteError) Error() string { return we.Message }
// WriteErrors is a group of non-write concern failures that occurred as a result
// of a write operation.
type WriteErrors []WriteError
func (we WriteErrors) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write errors: [")
for idx, err := range we {
if idx != 0 {
fmt.Fprintf(&buf, ", ")
}
fmt.Fprintf(&buf, "{%s}", err)
}
fmt.Fprint(&buf, "]")
return buf.String()
}
// Error is a command execution error from the database.
type Error struct {
Code int32
Message string
Labels []string
Name string
}
// Error implements the error interface.
func (e Error) Error() string {
if e.Name != "" {
return fmt.Sprintf("(%v) %v", e.Name, e.Message)
}
return e.Message
}
// HasErrorLabel returns true if the error contains the specified label.
func (e Error) HasErrorLabel(label string) bool {
if e.Labels != nil {
for _, l := range e.Labels {
if l == label {
return true
}
}
}
return false
}
// Retryable returns true if the error is retryable.
func (e Error) Retryable() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
if strings.Contains(e.Message, "not master") || strings.Contains(e.Message, "node is recovering") {
return true
}
return false
}
// NetworkError returns true if the error is a network error.
func (e Error) NetworkError() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
return false
}
// NodeIsRecovering returns true if this error is a node is recovering error.
func (e Error) NodeIsRecovering() bool {
for _, code := range nodeIsRecoveringCodes {
if e.Code == code {
return true
}
}
return strings.Contains(e.Message, "node is recovering")
}
// NotMaster returns true if this error is a not master error.
func (e Error) NotMaster() bool {
for _, code := range notMasterCodes {
if e.Code == code {
return true
}
}
return strings.Contains(e.Message, "not master")
}
// NamespaceNotFound returns true if this error is a NamespaceNotFound error.
func (e Error) NamespaceNotFound() bool {
return e.Code == 26 || e.Message == "ns not found"
}
// extractError inspects a server response document and returns the command error, write command
// error, or write concern error that it describes, if any; parsing failures are returned as well.
func extractError(rdr bsoncore.Document) error {
var errmsg, codeName string
var code int32
var labels []string
var ok bool
var wcError WriteCommandError
elems, err := rdr.Elements()
if err != nil {
return err
}
for _, elem := range elems {
switch elem.Key() {
case "ok":
switch elem.Value().Type {
case bson.TypeInt32:
if elem.Value().Int32() == 1 {
ok = true
}
case bson.TypeInt64:
if elem.Value().Int64() == 1 {
ok = true
}
case bson.TypeDouble:
if elem.Value().Double() == 1 {
ok = true
}
}
case "errmsg":
if str, okay := elem.Value().StringValueOK(); okay {
errmsg = str
}
case "codeName":
if str, okay := elem.Value().StringValueOK(); okay {
codeName = str
}
case "code":
if c, okay := elem.Value().Int32OK(); okay {
code = c
}
case "errorLabels":
if arr, okay := elem.Value().ArrayOK(); okay {
elems, err := arr.Elements()
if err != nil {
continue
}
for _, elem := range elems {
if str, ok := elem.Value().StringValueOK(); ok {
labels = append(labels, str)
}
}
}
case "writeErrors":
arr, exists := elem.Value().ArrayOK()
if !exists {
break
}
vals, err := arr.Values()
if err != nil {
continue
}
for _, val := range vals {
var we WriteError
doc, exists := val.DocumentOK()
if !exists {
continue
}
if index, exists := doc.Lookup("index").AsInt64OK(); exists {
we.Index = index
}
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
we.Code = code
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
we.Message = msg
}
wcError.WriteErrors = append(wcError.WriteErrors, we)
}
case "writeConcernError":
doc, exists := elem.Value().DocumentOK()
if !exists {
break
}
wcError.WriteConcernError = new(WriteConcernError)
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
wcError.WriteConcernError.Code = code
}
if name, exists := doc.Lookup("codeName").StringValueOK(); exists {
wcError.WriteConcernError.Name = name
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
wcError.WriteConcernError.Message = msg
}
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
wcError.WriteConcernError.Details = make([]byte, len(info))
copy(wcError.WriteConcernError.Details, info)
}
}
}
if !ok {
if errmsg == "" {
errmsg = "command failed"
}
return Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
}
}
if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
return wcError
}
return nil
}
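
As an illustrative sketch of how callers might classify the error values produced above once they surface from an operation's Execute (the inspect helper is hypothetical):

package example

import (
	"errors"
	"fmt"

	"go.mongodb.org/mongo-driver/x/mongo/driver"
)

// inspect classifies an error returned by a driver operation's Execute.
func inspect(err error) {
	var wce driver.WriteCommandError
	var de driver.Error
	switch {
	case err == nil:
		fmt.Println("ok")
	case errors.As(err, &wce):
		fmt.Printf("write failure (retryable=%v): %v\n", wce.Retryable(), wce)
	case errors.As(err, &de):
		fmt.Printf("command failure %d (retryable=%v, transient txn=%v): %v\n",
			de.Code, de.Retryable(), de.HasErrorLabel(driver.TransientTransactionError), de)
	default:
		fmt.Println("other error:", err)
	}
}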

16
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/legacy.go generated vendored Executable file

@@ -0,0 +1,16 @@
package driver
// LegacyOperationKind indicates whether an operation is a legacy find, getMore, killCursors,
// listCollections, or listIndexes operation. This is used in Operation.Execute, which will create
// legacy OP_QUERY, OP_GET_MORE, or OP_KILL_CURSORS messages instead of sending the operation as a command.
type LegacyOperationKind uint
// These constants represent the different kinds of legacy operations.
const (
LegacyNone LegacyOperationKind = iota
LegacyFind
LegacyGetMore
LegacyKillCursors
LegacyListCollections
LegacyListIndexes
)


@@ -0,0 +1,129 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"io"
"strings"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ListCollectionsBatchCursor is a special batch cursor returned from ListCollections that properly
// handles current and legacy ListCollections operations.
type ListCollectionsBatchCursor struct {
legacy bool // server version < 3.0
bc *BatchCursor
currentBatch *bsoncore.DocumentSequence
err error
}
// NewListCollectionsBatchCursor creates a new non-legacy ListCollectionsBatchCursor.
func NewListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
if bc == nil {
return nil, errors.New("batch cursor must not be nil")
}
return &ListCollectionsBatchCursor{bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
}
// NewLegacyListCollectionsBatchCursor creates a new legacy ListCollectionsBatchCursor.
func NewLegacyListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
if bc == nil {
return nil, errors.New("batch cursor must not be nil")
}
return &ListCollectionsBatchCursor{legacy: true, bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
}
// ID returns the cursor ID for this batch cursor.
func (lcbc *ListCollectionsBatchCursor) ID() int64 {
return lcbc.bc.ID()
}
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
// is not a valid batch of documents available.
func (lcbc *ListCollectionsBatchCursor) Next(ctx context.Context) bool {
if !lcbc.bc.Next(ctx) {
return false
}
if !lcbc.legacy {
lcbc.currentBatch.Style = lcbc.bc.currentBatch.Style
lcbc.currentBatch.Data = lcbc.bc.currentBatch.Data
lcbc.currentBatch.ResetIterator()
return true
}
lcbc.currentBatch.Style = bsoncore.SequenceStyle
lcbc.currentBatch.Data = lcbc.currentBatch.Data[:0]
var doc bsoncore.Document
for {
doc, lcbc.err = lcbc.bc.currentBatch.Next()
if lcbc.err != nil {
if lcbc.err == io.EOF {
lcbc.err = nil
break
}
return false
}
doc, lcbc.err = lcbc.projectNameElement(doc)
if lcbc.err != nil {
return false
}
lcbc.currentBatch.Data = append(lcbc.currentBatch.Data, doc...)
}
return true
}
// Batch will return a DocumentSequence for the current batch of documents. The returned
// DocumentSequence is only valid until the next call to Next or Close.
func (lcbc *ListCollectionsBatchCursor) Batch() *bsoncore.DocumentSequence { return lcbc.currentBatch }
// Server returns a pointer to the cursor's server.
func (lcbc *ListCollectionsBatchCursor) Server() Server { return lcbc.bc.server }
// Err returns the latest error encountered.
func (lcbc *ListCollectionsBatchCursor) Err() error {
if lcbc.err != nil {
return lcbc.err
}
return lcbc.bc.Err()
}
// Close closes this batch cursor.
func (lcbc *ListCollectionsBatchCursor) Close(ctx context.Context) error { return lcbc.bc.Close(ctx) }
// projectNameElement strips the database prefix from the "name" field returned by a legacy server.
func (*ListCollectionsBatchCursor) projectNameElement(rawDoc bsoncore.Document) (bsoncore.Document, error) {
elems, err := rawDoc.Elements()
if err != nil {
return nil, err
}
var filteredElems []byte
for _, elem := range elems {
key := elem.Key()
if key != "name" {
filteredElems = append(filteredElems, elem...)
continue
}
name := elem.Value().StringValue()
collName := name[strings.Index(name, ".")+1:]
filteredElems = bsoncore.AppendStringElement(filteredElems, "name", collName)
}
var filteredDoc []byte
filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems)
return filteredDoc, nil
}
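
As a usage sketch (the collectionDocuments helper is illustrative, assuming a cursor obtained from NewListCollectionsBatchCursor), draining the cursor mirrors the io.EOF handling used inside Next:

package example

import (
	"context"
	"io"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
)

// collectionDocuments drains a ListCollectionsBatchCursor, returning every
// collection document it yields.
func collectionDocuments(ctx context.Context, cur *driver.ListCollectionsBatchCursor) ([]bsoncore.Document, error) {
	defer cur.Close(ctx)
	var docs []bsoncore.Document
	for cur.Next(ctx) {
		batch := cur.Batch()
		for {
			doc, err := batch.Next()
			if err == io.EOF {
				break
			}
			if err != nil {
				return nil, err
			}
			docs = append(docs, doc)
		}
	}
	return docs, cur.Err()
}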

1224
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/operation.go generated vendored Executable file

File diff suppressed because it is too large


@@ -0,0 +1,328 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Performs an aggregate operation
type Aggregate struct {
allowDiskUse *bool
batchSize *int32
bypassDocumentValidation *bool
collation bsoncore.Document
comment *string
hint bsoncore.Value
maxTimeMS *int64
pipeline bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result driver.CursorResponse
}
// NewAggregate constructs and returns a new Aggregate.
func NewAggregate(pipeline bsoncore.Document) *Aggregate {
return &Aggregate{
pipeline: pipeline,
}
}
// Result returns the result of executing this operation.
func (a *Aggregate) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
clientSession := a.session
clock := a.clock
return driver.NewBatchCursor(a.result, clientSession, clock, opts)
}
func (a *Aggregate) ResultCursorResponse() driver.CursorResponse {
return a.result
}
func (a *Aggregate) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
a.result, err = driver.NewCursorResponse(response, srvr, desc)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (a *Aggregate) Execute(ctx context.Context) error {
if a.deployment == nil {
return errors.New("the Aggregate operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: a.command,
ProcessResponseFn: a.processResponse,
Client: a.session,
Clock: a.clock,
CommandMonitor: a.monitor,
Database: a.database,
Deployment: a.deployment,
ReadConcern: a.readConcern,
ReadPreference: a.readPreference,
Selector: a.selector,
WriteConcern: a.writeConcern,
MinimumWriteConcernWireVersion: 5,
}.Execute(ctx, nil)
}
func (a *Aggregate) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
header := bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, a.collection)}
if a.collection == "" {
header = bsoncore.Value{Type: bsontype.Int32, Data: []byte{0x01, 0x00, 0x00, 0x00}}
}
dst = bsoncore.AppendValueElement(dst, "aggregate", header)
cursorIdx, cursorDoc := bsoncore.AppendDocumentStart(nil)
if a.allowDiskUse != nil {
dst = bsoncore.AppendBooleanElement(dst, "allowDiskUse", *a.allowDiskUse)
}
if a.batchSize != nil {
cursorDoc = bsoncore.AppendInt32Element(cursorDoc, "batchSize", *a.batchSize)
}
if a.bypassDocumentValidation != nil {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *a.bypassDocumentValidation)
}
if a.collation != nil {
if desc.WireVersion == nil || !desc.WireVersion.Includes(5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", a.collation)
}
if a.comment != nil {
dst = bsoncore.AppendStringElement(dst, "comment", *a.comment)
}
if a.hint.Type != bsontype.Type(0) {
dst = bsoncore.AppendValueElement(dst, "hint", a.hint)
}
if a.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *a.maxTimeMS)
}
if a.pipeline != nil {
dst = bsoncore.AppendArrayElement(dst, "pipeline", a.pipeline)
}
cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx)
dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc)
return dst, nil
}
// AllowDiskUse enables writing to temporary files. When true, aggregation stages can write to the dbPath/_tmp directory.
func (a *Aggregate) AllowDiskUse(allowDiskUse bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.allowDiskUse = &allowDiskUse
return a
}
// BatchSize specifies the number of documents to return in every batch.
func (a *Aggregate) BatchSize(batchSize int32) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.batchSize = &batchSize
return a
}
// BypassDocumentValidation allows the write to opt-out of document level validation. This only applies when the $out stage is specified.
func (a *Aggregate) BypassDocumentValidation(bypassDocumentValidation bool) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.bypassDocumentValidation = &bypassDocumentValidation
return a
}
// Collation specifies a collation. This option is only valid for server versions 3.4 and above.
func (a *Aggregate) Collation(collation bsoncore.Document) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.collation = collation
return a
}
// Comment specifies an arbitrary string to help trace the operation through the database profiler, currentOp, and logs.
func (a *Aggregate) Comment(comment string) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.comment = &comment
return a
}
// Hint specifies the index to use.
func (a *Aggregate) Hint(hint bsoncore.Value) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.hint = hint
return a
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (a *Aggregate) MaxTimeMS(maxTimeMS int64) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.maxTimeMS = &maxTimeMS
return a
}
// Pipeline determines how data is transformed for an aggregation.
func (a *Aggregate) Pipeline(pipeline bsoncore.Document) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.pipeline = pipeline
return a
}
// Session sets the session for this operation.
func (a *Aggregate) Session(session *session.Client) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.session = session
return a
}
// ClusterClock sets the cluster clock for this operation.
func (a *Aggregate) ClusterClock(clock *session.ClusterClock) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.clock = clock
return a
}
// Collection sets the collection that this command will run against.
func (a *Aggregate) Collection(collection string) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.collection = collection
return a
}
// CommandMonitor sets the monitor to use for APM events.
func (a *Aggregate) CommandMonitor(monitor *event.CommandMonitor) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.monitor = monitor
return a
}
// Database sets the database to run this operation against.
func (a *Aggregate) Database(database string) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.database = database
return a
}
// Deployment sets the deployment to use for this operation.
func (a *Aggregate) Deployment(deployment driver.Deployment) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.deployment = deployment
return a
}
// ReadConcern specifies the read concern for this operation.
func (a *Aggregate) ReadConcern(readConcern *readconcern.ReadConcern) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.readConcern = readConcern
return a
}
// ReadPreference sets the read preference used with this operation.
func (a *Aggregate) ReadPreference(readPreference *readpref.ReadPref) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.readPreference = readPreference
return a
}
// ServerSelector sets the selector used to retrieve a server.
func (a *Aggregate) ServerSelector(selector description.ServerSelector) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.selector = selector
return a
}
// WriteConcern sets the write concern for this operation.
func (a *Aggregate) WriteConcern(writeConcern *writeconcern.WriteConcern) *Aggregate {
if a == nil {
a = new(Aggregate)
}
a.writeConcern = writeConcern
return a
}
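
As a hedged usage sketch (runAggregate is illustrative; the deployment, database, and collection names are assumed to be supplied by the caller), building a one-stage pipeline with bsoncore and executing the operation looks roughly like this:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// runAggregate executes the pipeline [{"$match": {"status": "A"}}] against
// db.coll and returns a batch cursor over the results.
func runAggregate(ctx context.Context, d driver.Deployment, db, coll string) (*driver.BatchCursor, error) {
	match := bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "status", "A"))
	stage := bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$match", match))

	// A BSON array is encoded like a document whose keys are "0", "1", ...
	idx, pipeline := bsoncore.AppendDocumentStart(nil)
	pipeline = bsoncore.AppendDocumentElement(pipeline, "0", stage)
	pipeline, _ = bsoncore.AppendDocumentEnd(pipeline, idx)

	op := operation.NewAggregate(pipeline).
		Database(db).
		Collection(coll).
		Deployment(d).
		BatchSize(100)
	if err := op.Execute(ctx); err != nil {
		return nil, err
	}
	return op.Result(driver.CursorOptions{})
}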


@@ -0,0 +1,47 @@
version = 0
name = "Aggregate"
documentation = "Performs an aggregate operation"
response.type = "batch cursor"
[properties]
enabled = ["read concern", "read preference", "write concern"]
MinimumWriteConcernWireVersion = 5
[command]
name = "aggregate"
parameter = "collection"
database = true
[request.pipeline]
type = "array"
constructor = true
documentation = "Pipeline determines how data is transformed for an aggregation."
[request.allowDiskUse]
type = "boolean"
documentation = "AllowDiskUse enables writing to temporary files. When true, aggregation stages can write to the dbPath/_tmp directory."
[request.batchSize]
type = "int32"
documentation = "BatchSize specifies the number of documents to return in every batch."
[request.bypassDocumentValidation]
type = "boolean"
documentation = "BypassDocumentValidation allows the write to opt-out of document level validation. This only applies when the $out stage is specified."
[request.collation]
type = "document"
minWireVersionRequired = 5
documentation = "Collation specifies a collation. This option is only valid for server versions 3.4 and above."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."
[request.comment]
type = "string"
documentation = "Comment specifies an arbitrary string to help trace the operation through the database profiler, currentOp, and logs."
[request.hint]
type = "value"
documentation = "Hint specifies the index to use."


@@ -0,0 +1,164 @@
// NOTE: This file is maintained by hand because operationgen cannot generate it.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Command is used to run a generic operation.
type Command struct {
command bsoncore.Document
readConcern *readconcern.ReadConcern
database string
deployment driver.Deployment
selector description.ServerSelector
readPreference *readpref.ReadPref
clock *session.ClusterClock
session *session.Client
monitor *event.CommandMonitor
result bsoncore.Document
srvr driver.Server
desc description.Server
}
// NewCommand constructs and returns a new Command.
func NewCommand(command bsoncore.Document) *Command { return &Command{command: command} }
// Result returns the result of executing this operation.
func (c *Command) Result() bsoncore.Document { return c.result }
// ResultCursor parses the command response as a cursor and returns the resulting BatchCursor.
func (c *Command) ResultCursor(opts driver.CursorOptions) (*driver.BatchCursor, error) {
cursorRes, err := driver.NewCursorResponse(c.result, c.srvr, c.desc)
if err != nil {
return nil, err
}
return driver.NewBatchCursor(cursorRes, c.session, c.clock, opts)
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (c *Command) Execute(ctx context.Context) error {
if c.deployment == nil {
return errors.New("the Command operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
return append(dst, c.command[4:len(c.command)-1]...), nil
},
ProcessResponseFn: func(resp bsoncore.Document, srvr driver.Server, desc description.Server) error {
c.result = resp
c.srvr = srvr
c.desc = desc
return nil
},
Client: c.session,
Clock: c.clock,
CommandMonitor: c.monitor,
Database: c.database,
Deployment: c.deployment,
ReadPreference: c.readPreference,
Selector: c.selector,
}.Execute(ctx, nil)
}
// Command sets the command to be run.
func (c *Command) Command(command bsoncore.Document) *Command {
if c == nil {
c = new(Command)
}
c.command = command
return c
}
// Session sets the session for this operation.
func (c *Command) Session(session *session.Client) *Command {
if c == nil {
c = new(Command)
}
c.session = session
return c
}
// ClusterClock sets the cluster clock for this operation.
func (c *Command) ClusterClock(clock *session.ClusterClock) *Command {
if c == nil {
c = new(Command)
}
c.clock = clock
return c
}
// CommandMonitor sets the monitor to use for APM events.
func (c *Command) CommandMonitor(monitor *event.CommandMonitor) *Command {
if c == nil {
c = new(Command)
}
c.monitor = monitor
return c
}
// Database sets the database to run this operation against.
func (c *Command) Database(database string) *Command {
if c == nil {
c = new(Command)
}
c.database = database
return c
}
// Deployment sets the deployment to use for this operation.
func (c *Command) Deployment(deployment driver.Deployment) *Command {
if c == nil {
c = new(Command)
}
c.deployment = deployment
return c
}
// ReadConcern specifies the read concern for this operation.
func (c *Command) ReadConcern(readConcern *readconcern.ReadConcern) *Command {
if c == nil {
c = new(Command)
}
c.readConcern = readConcern
return c
}
// ReadPreference sets the read preference used with this operation.
func (c *Command) ReadPreference(readPreference *readpref.ReadPref) *Command {
if c == nil {
c = new(Command)
}
c.readPreference = readPreference
return c
}
// ServerSelector sets the selector used to retrieve a server.
func (c *Command) ServerSelector(selector description.ServerSelector) *Command {
if c == nil {
c = new(Command)
}
c.selector = selector
return c
}
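
As a brief illustrative sketch (ping is hypothetical; deployment and database are assumed to come from the caller), the generic Command operation can run an arbitrary command such as {"ping": 1}:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// ping issues {"ping": 1} against db and returns the raw server reply.
func ping(ctx context.Context, d driver.Deployment, db string) (bsoncore.Document, error) {
	cmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ping", 1))
	op := operation.NewCommand(cmd).Database(db).Deployment(d)
	if err := op.Execute(ctx); err != nil {
		return nil, err
	}
	return op.Result(), nil
}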


@@ -0,0 +1,167 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// CommitTransaction attempts to commit a transaction.
type CommitTransaction struct {
recoveryToken bsoncore.Document
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
}
// NewCommitTransaction constructs and returns a new CommitTransaction.
func NewCommitTransaction() *CommitTransaction {
return &CommitTransaction{}
}
func (ct *CommitTransaction) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (ct *CommitTransaction) Execute(ctx context.Context) error {
if ct.deployment == nil {
return errors.New("the CommitTransaction operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ct.command,
ProcessResponseFn: ct.processResponse,
RetryMode: ct.retry,
RetryType: driver.RetryWrite,
Client: ct.session,
Clock: ct.clock,
CommandMonitor: ct.monitor,
Database: ct.database,
Deployment: ct.deployment,
Selector: ct.selector,
WriteConcern: ct.writeConcern,
}.Execute(ctx, nil)
}
func (ct *CommitTransaction) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "commitTransaction", 1)
if ct.recoveryToken != nil {
dst = bsoncore.AppendDocumentElement(dst, "recoveryToken", ct.recoveryToken)
}
return dst, nil
}
// RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction.
func (ct *CommitTransaction) RecoveryToken(recoveryToken bsoncore.Document) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.recoveryToken = recoveryToken
return ct
}
// Session sets the session for this operation.
func (ct *CommitTransaction) Session(session *session.Client) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.session = session
return ct
}
// ClusterClock sets the cluster clock for this operation.
func (ct *CommitTransaction) ClusterClock(clock *session.ClusterClock) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.clock = clock
return ct
}
// CommandMonitor sets the monitor to use for APM events.
func (ct *CommitTransaction) CommandMonitor(monitor *event.CommandMonitor) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.monitor = monitor
return ct
}
// Database sets the database to run this operation against.
func (ct *CommitTransaction) Database(database string) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.database = database
return ct
}
// Deployment sets the deployment to use for this operation.
func (ct *CommitTransaction) Deployment(deployment driver.Deployment) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.deployment = deployment
return ct
}
// ServerSelector sets the selector used to retrieve a server.
func (ct *CommitTransaction) ServerSelector(selector description.ServerSelector) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.selector = selector
return ct
}
// WriteConcern sets the write concern for this operation.
func (ct *CommitTransaction) WriteConcern(writeConcern *writeconcern.WriteConcern) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.writeConcern = writeConcern
return ct
}
// Retry enables retryable writes for this operation. Retries are not handled automatically;
// instead, a boolean is returned from Execute and SelectAndExecute that indicates whether the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (ct *CommitTransaction) Retry(retry driver.RetryMode) *CommitTransaction {
if ct == nil {
ct = new(CommitTransaction)
}
ct.retry = &retry
return ct
}
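
As a hedged sketch (the commit helper, the "admin" database choice, and the inputs are assumptions for illustration), a retryable commit might be wired up like this:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

// commit sends commitTransaction with one retry per command, attaching the
// recovery token when one is available.
func commit(ctx context.Context, d driver.Deployment, sess *session.Client, recoveryToken bsoncore.Document) error {
	op := operation.NewCommitTransaction().
		Session(sess).
		Database("admin").
		Deployment(d).
		Retry(driver.RetryOncePerCommand)
	if recoveryToken != nil {
		op = op.RecoveryToken(recoveryToken)
	}
	return op.Execute(ctx)
}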


@@ -0,0 +1,18 @@
version = 0
name = "CommitTransaction"
documentation = "CommitTransaction attempts to commit a transaction."
[properties]
enabled = ["write concern"]
disabled = ["collection"]
retryable = {mode = "once per command", type = "writes"}
[command]
name = "commitTransaction"
parameter = "database"
[request.recoveryToken]
type = "document"
documentation = """
RecoveryToken sets the recovery token to use when committing or aborting a sharded transaction.\
"""


@@ -0,0 +1,211 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// CreateIndexes performs a createIndexes operation.
type CreateIndexes struct {
indexes bsoncore.Document
maxTimeMS *int64
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
result CreateIndexesResult
}
type CreateIndexesResult struct {
// If the collection was created automatically.
CreatedCollectionAutomatically bool
// The number of indexes existing after this command.
IndexesAfter int32
// The number of indexes existing before this command.
IndexesBefore int32
}
func buildCreateIndexesResult(response bsoncore.Document, srvr driver.Server) (CreateIndexesResult, error) {
elements, err := response.Elements()
if err != nil {
return CreateIndexesResult{}, err
}
cir := CreateIndexesResult{}
for _, element := range elements {
switch element.Key() {
case "createdCollectionAutomatically":
var ok bool
cir.CreatedCollectionAutomatically, ok = element.Value().BooleanOK()
if !ok {
err = fmt.Errorf("response field 'createdCollectionAutomatically' is type bool, but received BSON type %s", element.Value().Type)
}
case "indexesAfter":
var ok bool
cir.IndexesAfter, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'indexesAfter' is type int32, but received BSON type %s", element.Value().Type)
}
case "indexesBefore":
var ok bool
cir.IndexesBefore, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'indexesBefore' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return cir, nil
}
// NewCreateIndexes constructs and returns a new CreateIndexes.
func NewCreateIndexes(indexes bsoncore.Document) *CreateIndexes {
return &CreateIndexes{
indexes: indexes,
}
}
// Result returns the result of executing this operation.
func (ci *CreateIndexes) Result() CreateIndexesResult { return ci.result }
func (ci *CreateIndexes) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
ci.result, err = buildCreateIndexesResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (ci *CreateIndexes) Execute(ctx context.Context) error {
if ci.deployment == nil {
return errors.New("the CreateIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ci.command,
ProcessResponseFn: ci.processResponse,
Client: ci.session,
Clock: ci.clock,
CommandMonitor: ci.monitor,
Database: ci.database,
Deployment: ci.deployment,
Selector: ci.selector,
}.Execute(ctx, nil)
}
func (ci *CreateIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "createIndexes", ci.collection)
if ci.indexes != nil {
dst = bsoncore.AppendArrayElement(dst, "indexes", ci.indexes)
}
if ci.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *ci.maxTimeMS)
}
return dst, nil
}
// An array containing index specification documents for the indexes being created.
func (ci *CreateIndexes) Indexes(indexes bsoncore.Document) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.indexes = indexes
return ci
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (ci *CreateIndexes) MaxTimeMS(maxTimeMS int64) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.maxTimeMS = &maxTimeMS
return ci
}
// Session sets the session for this operation.
func (ci *CreateIndexes) Session(session *session.Client) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.session = session
return ci
}
// ClusterClock sets the cluster clock for this operation.
func (ci *CreateIndexes) ClusterClock(clock *session.ClusterClock) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.clock = clock
return ci
}
// Collection sets the collection that this command will run against.
func (ci *CreateIndexes) Collection(collection string) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.collection = collection
return ci
}
// CommandMonitor sets the monitor to use for APM events.
func (ci *CreateIndexes) CommandMonitor(monitor *event.CommandMonitor) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.monitor = monitor
return ci
}
// Database sets the database to run this operation against.
func (ci *CreateIndexes) Database(database string) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.database = database
return ci
}
// Deployment sets the deployment to use for this operation.
func (ci *CreateIndexes) Deployment(deployment driver.Deployment) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.deployment = deployment
return ci
}
// ServerSelector sets the selector used to retrieve a server.
func (ci *CreateIndexes) ServerSelector(selector description.ServerSelector) *CreateIndexes {
if ci == nil {
ci = new(CreateIndexes)
}
ci.selector = selector
return ci
}
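
As an illustrative sketch (createAscendingIndex is hypothetical; names are placeholders), the indexes parameter is a BSON array of index specification documents:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// createAscendingIndex creates the index {key: {<field>: 1}, name: "<field>_1"}
// on db.coll and returns the index counts reported by the server.
func createAscendingIndex(ctx context.Context, d driver.Deployment, db, coll, field string) (operation.CreateIndexesResult, error) {
	key := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, field, 1))
	var specElems []byte
	specElems = bsoncore.AppendDocumentElement(specElems, "key", key)
	specElems = bsoncore.AppendStringElement(specElems, "name", field+"_1")
	spec := bsoncore.BuildDocument(nil, specElems)

	// The indexes parameter is a BSON array: a document keyed "0", "1", ...
	idx, indexes := bsoncore.AppendDocumentStart(nil)
	indexes = bsoncore.AppendDocumentElement(indexes, "0", spec)
	indexes, _ = bsoncore.AppendDocumentEnd(indexes, idx)

	op := operation.NewCreateIndexes(indexes).Database(db).Collection(coll).Deployment(d)
	if err := op.Execute(ctx); err != nil {
		return operation.CreateIndexesResult{}, err
	}
	return op.Result(), nil
}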


@@ -0,0 +1,31 @@
version = 0
name = "CreateIndexes"
documentation = "CreateIndexes performs a createIndexes operation."
[command]
name = "createIndexes"
parameter = "collection"
[request.indexes]
type = "array"
constructor = true
documentation = "An array containing index specification documents for the indexes being created."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."
[response]
name = "CreateIndexesResult"
[response.field.createdCollectionAutomatically]
type = "boolean"
documentation = "If the collection was created automatically."
[response.field.indexesBefore]
type = "int32"
documentation = "The number of indexes existing before this command."
[response.field.indexesAfter]
type = "int32"
documentation = "The number of indexes existing after this command."


@@ -0,0 +1,229 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Delete performs a delete operation
type Delete struct {
deletes []bsoncore.Document
ordered *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
result DeleteResult
}
type DeleteResult struct {
// Number of documents successfully deleted.
N int32
}
func buildDeleteResult(response bsoncore.Document, srvr driver.Server) (DeleteResult, error) {
elements, err := response.Elements()
if err != nil {
return DeleteResult{}, err
}
dr := DeleteResult{}
for _, element := range elements {
switch element.Key() {
case "n":
var ok bool
dr.N, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'n' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return dr, nil
}
// NewDelete constructs and returns a new Delete.
func NewDelete(deletes ...bsoncore.Document) *Delete {
return &Delete{
deletes: deletes,
}
}
// Result returns the result of executing this operation.
func (d *Delete) Result() DeleteResult { return d.result }
func (d *Delete) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
d.result, err = buildDeleteResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (d *Delete) Execute(ctx context.Context) error {
if d.deployment == nil {
return errors.New("the Delete operation must have a Deployment set before Execute can be called")
}
batches := &driver.Batches{
Identifier: "deletes",
Documents: d.deletes,
Ordered: d.ordered,
}
return driver.Operation{
CommandFn: d.command,
ProcessResponseFn: d.processResponse,
Batches: batches,
RetryMode: d.retry,
RetryType: driver.RetryWrite,
Client: d.session,
Clock: d.clock,
CommandMonitor: d.monitor,
Database: d.database,
Deployment: d.deployment,
Selector: d.selector,
WriteConcern: d.writeConcern,
}.Execute(ctx, nil)
}
func (d *Delete) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "delete", d.collection)
if d.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *d.ordered)
}
return dst, nil
}
// Deletes adds documents to this operation that will be used to determine what documents to delete when this operation
// is executed. These documents should have the form {q: <query>, limit: <integer limit>, collation: <document>}. The
// collation field is optional. If limit is 0, there will be no limit on the number of documents deleted.
func (d *Delete) Deletes(deletes ...bsoncore.Document) *Delete {
if d == nil {
d = new(Delete)
}
d.deletes = deletes
return d
}
// Ordered sets ordered. If true, when a write fails, the operation will return the error, when
// false write failures do not stop execution of the operation.
func (d *Delete) Ordered(ordered bool) *Delete {
if d == nil {
d = new(Delete)
}
d.ordered = &ordered
return d
}
// Session sets the session for this operation.
func (d *Delete) Session(session *session.Client) *Delete {
if d == nil {
d = new(Delete)
}
d.session = session
return d
}
// ClusterClock sets the cluster clock for this operation.
func (d *Delete) ClusterClock(clock *session.ClusterClock) *Delete {
if d == nil {
d = new(Delete)
}
d.clock = clock
return d
}
// Collection sets the collection that this command will run against.
func (d *Delete) Collection(collection string) *Delete {
if d == nil {
d = new(Delete)
}
d.collection = collection
return d
}
// CommandMonitor sets the monitor to use for APM events.
func (d *Delete) CommandMonitor(monitor *event.CommandMonitor) *Delete {
if d == nil {
d = new(Delete)
}
d.monitor = monitor
return d
}
// Database sets the database to run this operation against.
func (d *Delete) Database(database string) *Delete {
if d == nil {
d = new(Delete)
}
d.database = database
return d
}
// Deployment sets the deployment to use for this operation.
func (d *Delete) Deployment(deployment driver.Deployment) *Delete {
if d == nil {
d = new(Delete)
}
d.deployment = deployment
return d
}
// ServerSelector sets the selector used to retrieve a server.
func (d *Delete) ServerSelector(selector description.ServerSelector) *Delete {
if d == nil {
d = new(Delete)
}
d.selector = selector
return d
}
// WriteConcern sets the write concern for this operation.
func (d *Delete) WriteConcern(writeConcern *writeconcern.WriteConcern) *Delete {
if d == nil {
d = new(Delete)
}
d.writeConcern = writeConcern
return d
}
// Retry enables retryable writes for this operation. Retries are not handled automatically;
// instead, a boolean is returned from Execute and SelectAndExecute that indicates whether the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (d *Delete) Retry(retry driver.RetryMode) *Delete {
if d == nil {
d = new(Delete)
}
d.retry = &retry
return d
}
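
As a hedged sketch (deleteByStatus is illustrative; identifiers are placeholders), a single delete statement of the documented form {q: <query>, limit: <n>} can be built and executed like this:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// deleteByStatus removes every document in db.coll whose status equals value
// (limit 0 means no cap) and returns the number of documents deleted.
func deleteByStatus(ctx context.Context, d driver.Deployment, db, coll, value string) (int32, error) {
	q := bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "status", value))
	var elems []byte
	elems = bsoncore.AppendDocumentElement(elems, "q", q)
	elems = bsoncore.AppendInt32Element(elems, "limit", 0)
	stmt := bsoncore.BuildDocument(nil, elems)

	op := operation.NewDelete(stmt).
		Database(db).
		Collection(coll).
		Deployment(d).
		Ordered(true).
		Retry(driver.RetryOncePerCommand)
	if err := op.Execute(ctx); err != nil {
		return 0, err
	}
	return op.Result().N, nil
}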


@@ -0,0 +1,38 @@
version = 0
name = "Delete"
documentation = "Delete performs a delete operation"
[properties]
enabled = ["write concern"]
retryable = {mode = "once per command", type = "writes"}
batches = "deletes"
[command]
name = "delete"
parameter = "collection"
[request.deletes]
type = "document"
slice = true
constructor = true
variadic = true
required = true
documentation = """
Deletes adds documents to this operation that will be used to determine what documents to delete when this operation
is executed. These documents should have the form {q: <query>, limit: <integer limit>, collation: <document>}. The
collation field is optional. If limit is 0, there will be no limit on the number of documents deleted.\
"""
[request.ordered]
type = "boolean"
documentation = """
Ordered sets ordered. If true, when a write fails, the operation will return the error, when
false write failures do not stop execution of the operation.\
"""
[response]
name = "DeleteResult"
[response.field.n]
type = "int32"
documentation = "Number of documents successfully deleted."


@@ -0,0 +1,248 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Distinct performs a distinct operation.
type Distinct struct {
collation bsoncore.Document
key *string
maxTimeMS *int64
query bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
result DistinctResult
}
type DistinctResult struct {
// The distinct values for the field.
Values bsoncore.Value
}
func buildDistinctResult(response bsoncore.Document, srvr driver.Server) (DistinctResult, error) {
elements, err := response.Elements()
if err != nil {
return DistinctResult{}, err
}
dr := DistinctResult{}
for _, element := range elements {
switch element.Key() {
case "values":
dr.Values = element.Value()
}
}
return dr, nil
}
// NewDistinct constructs and returns a new Distinct.
func NewDistinct(key string, query bsoncore.Document) *Distinct {
return &Distinct{
key: &key,
query: query,
}
}
// Result returns the result of executing this operation.
func (d *Distinct) Result() DistinctResult { return d.result }
func (d *Distinct) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
d.result, err = buildDistinctResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (d *Distinct) Execute(ctx context.Context) error {
if d.deployment == nil {
return errors.New("the Distinct operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: d.command,
ProcessResponseFn: d.processResponse,
Client: d.session,
Clock: d.clock,
CommandMonitor: d.monitor,
Database: d.database,
Deployment: d.deployment,
ReadConcern: d.readConcern,
ReadPreference: d.readPreference,
Selector: d.selector,
}.Execute(ctx, nil)
}
func (d *Distinct) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "distinct", d.collection)
if d.collation != nil {
if desc.WireVersion == nil || !desc.WireVersion.Includes(5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", d.collation)
}
if d.key != nil {
dst = bsoncore.AppendStringElement(dst, "key", *d.key)
}
if d.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *d.maxTimeMS)
}
if d.query != nil {
dst = bsoncore.AppendDocumentElement(dst, "query", d.query)
}
return dst, nil
}
// Collation specifies a collation to be used.
func (d *Distinct) Collation(collation bsoncore.Document) *Distinct {
if d == nil {
d = new(Distinct)
}
d.collation = collation
return d
}
// Key specifies which field to return distinct values for.
func (d *Distinct) Key(key string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.key = &key
return d
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (d *Distinct) MaxTimeMS(maxTimeMS int64) *Distinct {
if d == nil {
d = new(Distinct)
}
d.maxTimeMS = &maxTimeMS
return d
}
// Query specifies which documents to return distinct values from.
func (d *Distinct) Query(query bsoncore.Document) *Distinct {
if d == nil {
d = new(Distinct)
}
d.query = query
return d
}
// Session sets the session for this operation.
func (d *Distinct) Session(session *session.Client) *Distinct {
if d == nil {
d = new(Distinct)
}
d.session = session
return d
}
// ClusterClock sets the cluster clock for this operation.
func (d *Distinct) ClusterClock(clock *session.ClusterClock) *Distinct {
if d == nil {
d = new(Distinct)
}
d.clock = clock
return d
}
// Collection sets the collection that this command will run against.
func (d *Distinct) Collection(collection string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.collection = collection
return d
}
// CommandMonitor sets the monitor to use for APM events.
func (d *Distinct) CommandMonitor(monitor *event.CommandMonitor) *Distinct {
if d == nil {
d = new(Distinct)
}
d.monitor = monitor
return d
}
// Database sets the database to run this operation against.
func (d *Distinct) Database(database string) *Distinct {
if d == nil {
d = new(Distinct)
}
d.database = database
return d
}
// Deployment sets the deployment to use for this operation.
func (d *Distinct) Deployment(deployment driver.Deployment) *Distinct {
if d == nil {
d = new(Distinct)
}
d.deployment = deployment
return d
}
// ReadConcern specifies the read concern for this operation.
func (d *Distinct) ReadConcern(readConcern *readconcern.ReadConcern) *Distinct {
if d == nil {
d = new(Distinct)
}
d.readConcern = readConcern
return d
}
// ReadPreference sets the read preference used with this operation.
func (d *Distinct) ReadPreference(readPreference *readpref.ReadPref) *Distinct {
if d == nil {
d = new(Distinct)
}
d.readPreference = readPreference
return d
}
// ServerSelector sets the selector used to retrieve a server.
func (d *Distinct) ServerSelector(selector description.ServerSelector) *Distinct {
if d == nil {
d = new(Distinct)
}
d.selector = selector
return d
}
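
As a brief sketch (distinctValues is hypothetical; names are placeholders), running distinct with an empty query returns the raw BSON array of values:

package example

import (
	"context"

	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// distinctValues returns the raw BSON array of distinct values of field in db.coll.
func distinctValues(ctx context.Context, d driver.Deployment, db, coll, field string) (bsoncore.Value, error) {
	// An empty query document matches every document in the collection.
	idx, query := bsoncore.AppendDocumentStart(nil)
	query, _ = bsoncore.AppendDocumentEnd(query, idx)

	op := operation.NewDistinct(field, query).Database(db).Collection(coll).Deployment(d)
	if err := op.Execute(ctx); err != nil {
		return bsoncore.Value{}, err
	}
	return op.Result().Values, nil
}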


@@ -0,0 +1,36 @@
version = 0
name = "Distinct"
documentation = "Distinct performs a distinct operation."
[properties]
enabled = ["read concern", "read preference"]
[command]
name = "distinct"
parameter = "collection"
[request.key]
type = "string"
constructor = true
documentation = "Key specifies which field to return distinct values for."
[request.query]
type = "document"
constructor = true
documentation = "Query specifies which documents to return distinct values from."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."
[request.collation]
type = "document"
minWireVersionRequired = 5
documentation = "Collation specifies a collation to be used."
[response]
name = "DistinctResult"
[response.field.values]
type = "value"
documentation = "The distinct values for the field."


@@ -0,0 +1,186 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// DropCollection performs a drop operation.
type DropCollection struct {
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result DropCollectionResult
}
type DropCollectionResult struct {
// The number of indexes in the dropped collection.
NIndexesWas int32
// The namespace of the dropped collection.
Ns string
}
func buildDropCollectionResult(response bsoncore.Document, srvr driver.Server) (DropCollectionResult, error) {
elements, err := response.Elements()
if err != nil {
return DropCollectionResult{}, err
}
dcr := DropCollectionResult{}
for _, element := range elements {
switch element.Key() {
case "nIndexesWas":
var ok bool
dcr.NIndexesWas, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'nIndexesWas' is type int32, but received BSON type %s", element.Value().Type)
}
case "ns":
var ok bool
dcr.Ns, ok = element.Value().StringValueOK()
if !ok {
err = fmt.Errorf("response field 'ns' is type string, but received BSON type %s", element.Value().Type)
}
}
}
return dcr, nil
}
// NewDropCollection constructs and returns a new DropCollection.
func NewDropCollection() *DropCollection {
return &DropCollection{}
}
// Result returns the result of executing this operation.
func (dc *DropCollection) Result() DropCollectionResult { return dc.result }
func (dc *DropCollection) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
dc.result, err = buildDropCollectionResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (dc *DropCollection) Execute(ctx context.Context) error {
if dc.deployment == nil {
return errors.New("the DropCollection operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: dc.command,
ProcessResponseFn: dc.processResponse,
Client: dc.session,
Clock: dc.clock,
CommandMonitor: dc.monitor,
Database: dc.database,
Deployment: dc.deployment,
Selector: dc.selector,
WriteConcern: dc.writeConcern,
}.Execute(ctx, nil)
}
func (dc *DropCollection) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "drop", dc.collection)
return dst, nil
}
// Session sets the session for this operation.
func (dc *DropCollection) Session(session *session.Client) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.session = session
return dc
}
// ClusterClock sets the cluster clock for this operation.
func (dc *DropCollection) ClusterClock(clock *session.ClusterClock) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.clock = clock
return dc
}
// Collection sets the collection that this command will run against.
func (dc *DropCollection) Collection(collection string) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.collection = collection
return dc
}
// CommandMonitor sets the monitor to use for APM events.
func (dc *DropCollection) CommandMonitor(monitor *event.CommandMonitor) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.monitor = monitor
return dc
}
// Database sets the database to run this operation against.
func (dc *DropCollection) Database(database string) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.database = database
return dc
}
// Deployment sets the deployment to use for this operation.
func (dc *DropCollection) Deployment(deployment driver.Deployment) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.deployment = deployment
return dc
}
// ServerSelector sets the selector used to retrieve a server.
func (dc *DropCollection) ServerSelector(selector description.ServerSelector) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.selector = selector
return dc
}
// WriteConcern sets the write concern for this operation.
func (dc *DropCollection) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropCollection {
if dc == nil {
dc = new(DropCollection)
}
dc.writeConcern = writeConcern
return dc
}
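
As a short sketch (dropCollection is illustrative), a caller might drop a collection and treat the server's NamespaceNotFound response as success, using the helper defined on driver.Error in errors.go above:

package example

import (
	"context"
	"errors"

	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// dropCollection drops db.coll, ignoring the error the server returns when the
// collection does not exist.
func dropCollection(ctx context.Context, d driver.Deployment, db, coll string) error {
	op := operation.NewDropCollection().Database(db).Collection(coll).Deployment(d)
	err := op.Execute(ctx)
	var de driver.Error
	if errors.As(err, &de) && de.NamespaceNotFound() {
		return nil
	}
	return err
}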


@@ -0,0 +1,21 @@
version = 0
name = "DropCollection"
documentation = "DropCollection performs a drop operation."
[command]
name = "drop"
parameter = "collection"
[properties]
enabled = ["write concern"]
[response]
name = "DropCollectionResult"
[response.field.ns]
type = "string"
documentation = "The namespace of the dropped collection."
[response.field.nIndexesWas]
type = "int32"
documentation = "The number of indexes in the dropped collection."


@@ -0,0 +1,168 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// DropDatabase performs a dropDatabase operation
type DropDatabase struct {
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result DropDatabaseResult
}
type DropDatabaseResult struct {
// The dropped database.
Dropped string
}
func buildDropDatabaseResult(response bsoncore.Document, srvr driver.Server) (DropDatabaseResult, error) {
elements, err := response.Elements()
if err != nil {
return DropDatabaseResult{}, err
}
ddr := DropDatabaseResult{}
for _, element := range elements {
switch element.Key() {
case "dropped":
var ok bool
ddr.Dropped, ok = element.Value().StringValueOK()
if !ok {
err = fmt.Errorf("response field 'dropped' is type string, but received BSON type %s", element.Value().Type)
}
}
}
return ddr, nil
}
// NewDropDatabase constructs and returns a new DropDatabase.
func NewDropDatabase() *DropDatabase {
return &DropDatabase{}
}
// Result returns the result of executing this operation.
func (dd *DropDatabase) Result() DropDatabaseResult { return dd.result }
func (dd *DropDatabase) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
dd.result, err = buildDropDatabaseResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (dd *DropDatabase) Execute(ctx context.Context) error {
if dd.deployment == nil {
return errors.New("the DropDatabase operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: dd.command,
ProcessResponseFn: dd.processResponse,
Client: dd.session,
Clock: dd.clock,
CommandMonitor: dd.monitor,
Database: dd.database,
Deployment: dd.deployment,
Selector: dd.selector,
WriteConcern: dd.writeConcern,
}.Execute(ctx, nil)
}
func (dd *DropDatabase) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "dropDatabase", 1)
return dst, nil
}
// Session sets the session for this operation.
func (dd *DropDatabase) Session(session *session.Client) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.session = session
return dd
}
// ClusterClock sets the cluster clock for this operation.
func (dd *DropDatabase) ClusterClock(clock *session.ClusterClock) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.clock = clock
return dd
}
// CommandMonitor sets the monitor to use for APM events.
func (dd *DropDatabase) CommandMonitor(monitor *event.CommandMonitor) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.monitor = monitor
return dd
}
// Database sets the database to run this operation against.
func (dd *DropDatabase) Database(database string) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.database = database
return dd
}
// Deployment sets the deployment to use for this operation.
func (dd *DropDatabase) Deployment(deployment driver.Deployment) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.deployment = deployment
return dd
}
// ServerSelector sets the selector used to retrieve a server.
func (dd *DropDatabase) ServerSelector(selector description.ServerSelector) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.selector = selector
return dd
}
// WriteConcern sets the write concern for this operation.
func (dd *DropDatabase) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropDatabase {
if dd == nil {
dd = new(DropDatabase)
}
dd.writeConcern = writeConcern
return dd
}
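
The generated DropDatabase type above follows the package's usual pattern: a zero-argument constructor, chainable setters, Execute, and a typed Result. A minimal sketch under that pattern, with the driver.Deployment (normally the driver's internal topology) supplied by the caller:

package example

import (
    "context"
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// dropDatabase executes the generated dropDatabase command and reads the
// "dropped" field back from the typed result.
func dropDatabase(ctx context.Context, dep driver.Deployment) error {
    op := operation.NewDropDatabase().
        Database("db").
        Deployment(dep)
    if err := op.Execute(ctx); err != nil {
        return err
    }
    fmt.Println("dropped database:", op.Result().Dropped)
    return nil
}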

View File

@@ -0,0 +1,18 @@
version = 0
name = "DropDatabase"
documentation = "DropDatabase performs a dropDatabase operation"
[properties]
enabled = ["write concern"]
disabled = ["collection"]
[command]
name = "dropDatabase"
parameter = "database"
[response]
name = "DropDatabaseResult"
[response.field.dropped]
type = "string"
documentation = "The dropped database."

View File

@@ -0,0 +1,209 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// DropIndexes performs a dropIndexes operation.
type DropIndexes struct {
index *string
maxTimeMS *int64
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
result DropIndexesResult
}
type DropIndexesResult struct {
// Number of indexes that existed before the drop was executed.
NIndexesWas int32
}
func buildDropIndexesResult(response bsoncore.Document, srvr driver.Server) (DropIndexesResult, error) {
elements, err := response.Elements()
if err != nil {
return DropIndexesResult{}, err
}
dir := DropIndexesResult{}
for _, element := range elements {
switch element.Key() {
case "nIndexesWas":
var ok bool
dir.NIndexesWas, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'nIndexesWas' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return dir, nil
}
// NewDropIndexes constructs and returns a new DropIndexes.
func NewDropIndexes(index string) *DropIndexes {
return &DropIndexes{
index: &index,
}
}
// Result returns the result of executing this operation.
func (di *DropIndexes) Result() DropIndexesResult { return di.result }
func (di *DropIndexes) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
di.result, err = buildDropIndexesResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (di *DropIndexes) Execute(ctx context.Context) error {
if di.deployment == nil {
return errors.New("the DropIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: di.command,
ProcessResponseFn: di.processResponse,
Client: di.session,
Clock: di.clock,
CommandMonitor: di.monitor,
Database: di.database,
Deployment: di.deployment,
Selector: di.selector,
WriteConcern: di.writeConcern,
}.Execute(ctx, nil)
}
func (di *DropIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "dropIndexes", di.collection)
if di.index != nil {
dst = bsoncore.AppendStringElement(dst, "index", *di.index)
}
if di.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *di.maxTimeMS)
}
return dst, nil
}
// Index specifies the name of the index to drop. If '*' is specified, all indexes will be dropped.
func (di *DropIndexes) Index(index string) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.index = &index
return di
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (di *DropIndexes) MaxTimeMS(maxTimeMS int64) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.maxTimeMS = &maxTimeMS
return di
}
// Session sets the session for this operation.
func (di *DropIndexes) Session(session *session.Client) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.session = session
return di
}
// ClusterClock sets the cluster clock for this operation.
func (di *DropIndexes) ClusterClock(clock *session.ClusterClock) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.clock = clock
return di
}
// Collection sets the collection that this command will run against.
func (di *DropIndexes) Collection(collection string) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.collection = collection
return di
}
// CommandMonitor sets the monitor to use for APM events.
func (di *DropIndexes) CommandMonitor(monitor *event.CommandMonitor) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.monitor = monitor
return di
}
// Database sets the database to run this operation against.
func (di *DropIndexes) Database(database string) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.database = database
return di
}
// Deployment sets the deployment to use for this operation.
func (di *DropIndexes) Deployment(deployment driver.Deployment) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.deployment = deployment
return di
}
// ServerSelector sets the selector used to retrieve a server.
func (di *DropIndexes) ServerSelector(selector description.ServerSelector) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.selector = selector
return di
}
// WriteConcern sets the write concern for this operation.
func (di *DropIndexes) WriteConcern(writeConcern *writeconcern.WriteConcern) *DropIndexes {
if di == nil {
di = new(DropIndexes)
}
di.writeConcern = writeConcern
return di
}
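
The same builder pattern applies to DropIndexes; the only difference is that the index name is a constructor argument ('*' asks the server to drop every index on the collection). A sketch, again assuming an externally supplied driver.Deployment:

package example

import (
    "context"
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// dropAllIndexes drops every index on db.coll and reports how many indexes
// existed before the drop was executed.
func dropAllIndexes(ctx context.Context, dep driver.Deployment) error {
    op := operation.NewDropIndexes("*").
        Collection("coll").
        Database("db").
        Deployment(dep).
        MaxTimeMS(5000)
    if err := op.Execute(ctx); err != nil {
        return err
    }
    fmt.Println("nIndexesWas:", op.Result().NIndexesWas)
    return nil
}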

View File

@@ -0,0 +1,28 @@
version = 0
name = "DropIndexes"
documentation = "DropIndexes performs an dropIndexes operation."
[properties]
enabled = ["write concern"]
[command]
name = "dropIndexes"
parameter = "collection"
[request.index]
type = "string"
constructor = true
documentation = """
Index specifies the name of the index to drop. If '*' is specified, all indexes will be dropped.
"""
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."
[response]
name = "DropIndexesResult"
[response.field.nIndexesWas]
type = "int32"
documentation = "Number of indexes that existed before the drop was executed."

View File

@@ -0,0 +1,469 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Find performs a find operation.
type Find struct {
allowPartialResults *bool
awaitData *bool
batchSize *int32
collation bsoncore.Document
comment *string
filter bsoncore.Document
hint bsoncore.Value
limit *int64
max bsoncore.Document
maxTimeMS *int64
min bsoncore.Document
noCursorTimeout *bool
oplogReplay *bool
projection bsoncore.Document
returnKey *bool
showRecordID *bool
singleBatch *bool
skip *int64
snapshot *bool
sort bsoncore.Document
tailable *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readConcern *readconcern.ReadConcern
readPreference *readpref.ReadPref
selector description.ServerSelector
result driver.CursorResponse
}
// NewFind constructs and returns a new Find.
func NewFind(filter bsoncore.Document) *Find {
return &Find{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (f *Find) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
return driver.NewBatchCursor(f.result, f.session, f.clock, opts)
}
func (f *Find) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
f.result, err = driver.NewCursorResponse(response, srvr, desc)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (f *Find) Execute(ctx context.Context) error {
if f.deployment == nil {
return errors.New("the Find operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: f.command,
ProcessResponseFn: f.processResponse,
Client: f.session,
Clock: f.clock,
CommandMonitor: f.monitor,
Database: f.database,
Deployment: f.deployment,
ReadConcern: f.readConcern,
ReadPreference: f.readPreference,
Selector: f.selector,
Legacy: driver.LegacyFind,
}.Execute(ctx, nil)
}
func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "find", f.collection)
if f.allowPartialResults != nil {
dst = bsoncore.AppendBooleanElement(dst, "allowPartialResults", *f.allowPartialResults)
}
if f.awaitData != nil {
dst = bsoncore.AppendBooleanElement(dst, "awaitData", *f.awaitData)
}
if f.batchSize != nil {
dst = bsoncore.AppendInt32Element(dst, "batchSize", *f.batchSize)
}
if f.collation != nil {
if desc.WireVersion == nil || !desc.WireVersion.Includes(5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", f.collation)
}
if f.comment != nil {
dst = bsoncore.AppendStringElement(dst, "comment", *f.comment)
}
if f.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", f.filter)
}
if f.hint.Type != bsontype.Type(0) {
dst = bsoncore.AppendValueElement(dst, "hint", f.hint)
}
if f.limit != nil {
dst = bsoncore.AppendInt64Element(dst, "limit", *f.limit)
}
if f.max != nil {
dst = bsoncore.AppendDocumentElement(dst, "max", f.max)
}
if f.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *f.maxTimeMS)
}
if f.min != nil {
dst = bsoncore.AppendDocumentElement(dst, "min", f.min)
}
if f.noCursorTimeout != nil {
dst = bsoncore.AppendBooleanElement(dst, "noCursorTimeout", *f.noCursorTimeout)
}
if f.oplogReplay != nil {
dst = bsoncore.AppendBooleanElement(dst, "oplogReplay", *f.oplogReplay)
}
if f.projection != nil {
dst = bsoncore.AppendDocumentElement(dst, "projection", f.projection)
}
if f.returnKey != nil {
dst = bsoncore.AppendBooleanElement(dst, "returnKey", *f.returnKey)
}
if f.showRecordID != nil {
dst = bsoncore.AppendBooleanElement(dst, "showRecordId", *f.showRecordID)
}
if f.singleBatch != nil {
dst = bsoncore.AppendBooleanElement(dst, "singleBatch", *f.singleBatch)
}
if f.skip != nil {
dst = bsoncore.AppendInt64Element(dst, "skip", *f.skip)
}
if f.snapshot != nil {
dst = bsoncore.AppendBooleanElement(dst, "snapshot", *f.snapshot)
}
if f.sort != nil {
dst = bsoncore.AppendDocumentElement(dst, "sort", f.sort)
}
if f.tailable != nil {
dst = bsoncore.AppendBooleanElement(dst, "tailable", *f.tailable)
}
return dst, nil
}
// AllowPartialResults when true allows partial results to be returned if some shards are down.
func (f *Find) AllowPartialResults(allowPartialResults bool) *Find {
if f == nil {
f = new(Find)
}
f.allowPartialResults = &allowPartialResults
return f
}
// AwaitData when true makes a cursor block before returning when no data is available.
func (f *Find) AwaitData(awaitData bool) *Find {
if f == nil {
f = new(Find)
}
f.awaitData = &awaitData
return f
}
// BatchSize specifies the number of documents to return in every batch.
func (f *Find) BatchSize(batchSize int32) *Find {
if f == nil {
f = new(Find)
}
f.batchSize = &batchSize
return f
}
// Collation specifies a collation to be used.
func (f *Find) Collation(collation bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.collation = collation
return f
}
// Comment sets a string to help trace an operation.
func (f *Find) Comment(comment string) *Find {
if f == nil {
f = new(Find)
}
f.comment = &comment
return f
}
// Filter determines what results are returned from find.
func (f *Find) Filter(filter bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.filter = filter
return f
}
// Hint specifies the index to use.
func (f *Find) Hint(hint bsoncore.Value) *Find {
if f == nil {
f = new(Find)
}
f.hint = hint
return f
}
// Limit sets a limit on the number of documents to return.
func (f *Find) Limit(limit int64) *Find {
if f == nil {
f = new(Find)
}
f.limit = &limit
return f
}
// Max sets an exclusive upper bound for a specific index.
func (f *Find) Max(max bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.max = max
return f
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (f *Find) MaxTimeMS(maxTimeMS int64) *Find {
if f == nil {
f = new(Find)
}
f.maxTimeMS = &maxTimeMS
return f
}
// Min sets an inclusive lower bound for a specific index.
func (f *Find) Min(min bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.min = min
return f
}
// NoCursorTimeout when true prevents cursor from timing out after an inactivity period.
func (f *Find) NoCursorTimeout(noCursorTimeout bool) *Find {
if f == nil {
f = new(Find)
}
f.noCursorTimeout = &noCursorTimeout
return f
}
// OplogReplay when true replays a replica set's oplog.
func (f *Find) OplogReplay(oplogReplay bool) *Find {
if f == nil {
f = new(Find)
}
f.oplogReplay = &oplogReplay
return f
}
// Projection limits the fields returned for all documents.
func (f *Find) Projection(projection bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.projection = projection
return f
}
// ReturnKey when true returns index keys for all result documents.
func (f *Find) ReturnKey(returnKey bool) *Find {
if f == nil {
f = new(Find)
}
f.returnKey = &returnKey
return f
}
// ShowRecordID when true adds a $recordId field with the record identifier to returned documents.
func (f *Find) ShowRecordID(showRecordID bool) *Find {
if f == nil {
f = new(Find)
}
f.showRecordID = &showRecordID
return f
}
// SingleBatch specifies whether the results should be returned in a single batch.
func (f *Find) SingleBatch(singleBatch bool) *Find {
if f == nil {
f = new(Find)
}
f.singleBatch = &singleBatch
return f
}
// Skip specifies the number of documents to skip before returning.
func (f *Find) Skip(skip int64) *Find {
if f == nil {
f = new(Find)
}
f.skip = &skip
return f
}
// Snapshot prevents the cursor from returning a document more than once because of an intervening write operation.
func (f *Find) Snapshot(snapshot bool) *Find {
if f == nil {
f = new(Find)
}
f.snapshot = &snapshot
return f
}
// Sort specifies the order in which to return results.
func (f *Find) Sort(sort bsoncore.Document) *Find {
if f == nil {
f = new(Find)
}
f.sort = sort
return f
}
// Tailable keeps a cursor open and resumable after the last data has been retrieved.
func (f *Find) Tailable(tailable bool) *Find {
if f == nil {
f = new(Find)
}
f.tailable = &tailable
return f
}
// Session sets the session for this operation.
func (f *Find) Session(session *session.Client) *Find {
if f == nil {
f = new(Find)
}
f.session = session
return f
}
// ClusterClock sets the cluster clock for this operation.
func (f *Find) ClusterClock(clock *session.ClusterClock) *Find {
if f == nil {
f = new(Find)
}
f.clock = clock
return f
}
// Collection sets the collection that this command will run against.
func (f *Find) Collection(collection string) *Find {
if f == nil {
f = new(Find)
}
f.collection = collection
return f
}
// CommandMonitor sets the monitor to use for APM events.
func (f *Find) CommandMonitor(monitor *event.CommandMonitor) *Find {
if f == nil {
f = new(Find)
}
f.monitor = monitor
return f
}
// Database sets the database to run this operation against.
func (f *Find) Database(database string) *Find {
if f == nil {
f = new(Find)
}
f.database = database
return f
}
// Deployment sets the deployment to use for this operation.
func (f *Find) Deployment(deployment driver.Deployment) *Find {
if f == nil {
f = new(Find)
}
f.deployment = deployment
return f
}
// ReadConcern specifies the read concern for this operation.
func (f *Find) ReadConcern(readConcern *readconcern.ReadConcern) *Find {
if f == nil {
f = new(Find)
}
f.readConcern = readConcern
return f
}
// ReadPreference sets the read preference used with this operation.
func (f *Find) ReadPreference(readPreference *readpref.ReadPref) *Find {
if f == nil {
f = new(Find)
}
f.readPreference = readPreference
return f
}
// ServerSelector sets the selector used to retrieve a server.
func (f *Find) ServerSelector(selector description.ServerSelector) *Find {
if f == nil {
f = new(Find)
}
f.selector = selector
return f
}
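
Find differs from the write operations above in that its response is a driver.CursorResponse, which Result wraps into a *driver.BatchCursor. The sketch below shows how the pieces fit together; it hand-builds the filter document with bsoncore helpers (the same family of Append* functions used by the generated command builders) and leaves the deployment and cursor iteration to the caller.

package example

import (
    "context"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// findActive issues find {"status": "active"} against db.coll and returns the
// batch cursor that the higher-level mongo package would normally wrap.
func findActive(ctx context.Context, dep driver.Deployment) (*driver.BatchCursor, error) {
    // Build the filter document byte-by-byte with bsoncore.
    idx, filter := bsoncore.AppendDocumentStart(nil)
    filter = bsoncore.AppendStringElement(filter, "status", "active")
    filter, _ = bsoncore.AppendDocumentEnd(filter, idx)

    op := operation.NewFind(filter).
        Collection("coll").
        Database("db").
        Deployment(dep).
        BatchSize(100)
    if err := op.Execute(ctx); err != nil {
        return nil, err
    }
    // CursorOptions are left at their zero values for brevity.
    return op.Result(driver.CursorOptions{})
}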

View File

@@ -0,0 +1,99 @@
version = 0
name = "Find"
documentation = "Find performs a find operation."
response.type = "batch cursor"
[properties]
enabled = ["collection", "read concern", "read preference", "command monitor", "client session", "cluster clock"]
legacy = "find"
[command]
name = "find"
parameter = "collection"
[request.filter]
type = "document"
constructor = true
documentation = "Filter determines what results are returned from find."
[request.sort]
type = "document"
documentation = "Sort specifies the order in which to return results."
[request.projection]
type = "document"
documentation = "Project limits the fields returned for all documents."
[request.hint]
type = "value"
documentation = "Hint specifies the index to use."
[request.skip]
type = "int64"
documentation = "Skip specifies the number of documents to skip before returning."
[request.limit]
type = "int64"
documentation = "Limit sets a limit on the number of documents to return."
[request.batchSize]
type = "int32"
documentation = "BatchSize specifies the number of documents to return in every batch."
[request.singleBatch]
type = "boolean"
documentation = "SingleBatch specifies whether the results should be returned in a single batch."
[request.comment]
type = "string"
documentation = "Comment sets a string to help trace an operation."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."
[request.max]
type = "document"
documentation = "Max sets an exclusive upper bound for a specific index."
[request.min]
type = "document"
documentation = "Min sets an inclusive lower bound for a specific index."
[request.returnKey]
type = "boolean"
documentation = "ReturnKey when true returns index keys for all result documents."
[request.showRecordID]
type = "boolean"
documentation = "ShowRecordID when true adds a $recordId field with the record identifier to returned documents."
keyName = "showRecordId"
[request.oplogReplay]
type = "boolean"
documentation = "OplogReplay when true replays a replica set's oplog."
[request.noCursorTimeout]
type = "boolean"
documentation = "NoCursorTimeout when true prevents cursor from timing out after an inactivity period."
[request.tailable]
type = "boolean"
documentation = "Tailable keeps a cursor open and resumable after the last data has been retrieved."
[request.awaitData]
type = "boolean"
documentation = "AwaitData when true makes a cursor block before returning when no data is available."
[request.allowPartialResults]
type = "boolean"
documentation = "AllowPartialResults when true allows partial results to be returned if some shards are down."
[request.collation]
type = "document"
minWireVersionRequired = 5
documentation = "Collation specifies a collation to be used."
[request.snapshot]
type = "boolean"
documentation = "Snapshot prevents the cursor from returning a document more than once because of an intervening write operation."

View File

@@ -0,0 +1,395 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// FindAndModify performs a findAndModify operation.
type FindAndModify struct {
arrayFilters bsoncore.Document
bypassDocumentValidation *bool
collation bsoncore.Document
fields bsoncore.Document
maxTimeMS *int64
newDocument *bool
query bsoncore.Document
remove *bool
sort bsoncore.Document
update bsoncore.Document
upsert *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
result FindAndModifyResult
}
type LastErrorObject struct {
// True if an update modified an existing document.
UpdatedExisting bool
// Object ID of the upserted document.
Upserted interface{}
}
type FindAndModifyResult struct {
// Either the old or modified document, depending on the value of the new parameter.
Value bsoncore.Document
// Contains information about updates and upserts.
LastErrorObject LastErrorObject
}
func buildFindAndModifyResult(response bsoncore.Document, srvr driver.Server) (FindAndModifyResult, error) {
elements, err := response.Elements()
if err != nil {
return FindAndModifyResult{}, err
}
famr := FindAndModifyResult{}
for _, element := range elements {
switch element.Key() {
case "value":
var ok bool
famr.Value, ok = element.Value().DocumentOK()
if !ok {
err = fmt.Errorf("response field 'value' is type document, but received BSON type %s", element.Value().Type)
}
case "lastErrorObject":
valDoc, ok := element.Value().DocumentOK()
if !ok {
err = fmt.Errorf("response field 'lastErrorObject' is type document, but received BSON type %s", element.Value().Type)
break
}
var leo LastErrorObject
if err = bson.Unmarshal(valDoc, &leo); err != nil {
break
}
famr.LastErrorObject = leo
}
}
return famr, nil
}
// NewFindAndModify constructs and returns a new FindAndModify.
func NewFindAndModify(query bsoncore.Document) *FindAndModify {
return &FindAndModify{
query: query,
}
}
// Result returns the result of executing this operation.
func (fam *FindAndModify) Result() FindAndModifyResult { return fam.result }
func (fam *FindAndModify) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
fam.result, err = buildFindAndModifyResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (fam *FindAndModify) Execute(ctx context.Context) error {
if fam.deployment == nil {
return errors.New("the FindAndModify operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: fam.command,
ProcessResponseFn: fam.processResponse,
RetryMode: fam.retry,
RetryType: driver.RetryWrite,
Client: fam.session,
Clock: fam.clock,
CommandMonitor: fam.monitor,
Database: fam.database,
Deployment: fam.deployment,
Selector: fam.selector,
WriteConcern: fam.writeConcern,
}.Execute(ctx, nil)
}
func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "findAndModify", fam.collection)
if fam.arrayFilters != nil {
if desc.WireVersion == nil || !desc.WireVersion.Includes(6) {
return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6")
}
dst = bsoncore.AppendArrayElement(dst, "arrayFilters", fam.arrayFilters)
}
if fam.bypassDocumentValidation != nil {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *fam.bypassDocumentValidation)
}
if fam.collation != nil {
if desc.WireVersion == nil || !desc.WireVersion.Includes(5) {
return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5")
}
dst = bsoncore.AppendDocumentElement(dst, "collation", fam.collation)
}
if fam.fields != nil {
dst = bsoncore.AppendDocumentElement(dst, "fields", fam.fields)
}
if fam.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *fam.maxTimeMS)
}
if fam.newDocument != nil {
dst = bsoncore.AppendBooleanElement(dst, "new", *fam.newDocument)
}
if fam.query != nil {
dst = bsoncore.AppendDocumentElement(dst, "query", fam.query)
}
if fam.remove != nil {
dst = bsoncore.AppendBooleanElement(dst, "remove", *fam.remove)
}
if fam.sort != nil {
dst = bsoncore.AppendDocumentElement(dst, "sort", fam.sort)
}
if fam.update != nil {
dst = bsoncore.AppendDocumentElement(dst, "update", fam.update)
}
if fam.upsert != nil {
dst = bsoncore.AppendBooleanElement(dst, "upsert", *fam.upsert)
}
return dst, nil
}
// ArrayFilters specifies an array of filter documents that determines which array elements to modify for an update operation on an array field.
func (fam *FindAndModify) ArrayFilters(arrayFilters bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.arrayFilters = arrayFilters
return fam
}
// BypassDocumentValidation specifies if document validation can be skipped when executing the operation.
func (fam *FindAndModify) BypassDocumentValidation(bypassDocumentValidation bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.bypassDocumentValidation = &bypassDocumentValidation
return fam
}
// Collation specifies a collation to be used.
func (fam *FindAndModify) Collation(collation bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.collation = collation
return fam
}
// Fields specifies a subset of fields to return.
func (fam *FindAndModify) Fields(fields bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.fields = fields
return fam
}
// MaxTimeMS specifies the maximum amount of time to allow the operation to run.
func (fam *FindAndModify) MaxTimeMS(maxTimeMS int64) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.maxTimeMS = &maxTimeMS
return fam
}
// NewDocument specifies whether to return the modified document or the original. Defaults to false (return original).
func (fam *FindAndModify) NewDocument(newDocument bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.newDocument = &newDocument
return fam
}
// Query specifies the selection criteria for the modification.
func (fam *FindAndModify) Query(query bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.query = query
return fam
}
// Remove specifies that the matched document should be removed. Defaults to false.
func (fam *FindAndModify) Remove(remove bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.remove = &remove
return fam
}
// Sort determines which document the operation modifies if the query matches multiple documents. The first document matched by the sort order will be modified.
func (fam *FindAndModify) Sort(sort bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.sort = sort
return fam
}
// Update specifies the update document to perform on the matched document.
func (fam *FindAndModify) Update(update bsoncore.Document) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.update = update
return fam
}
// Upsert specifies whether or not to create a new document if no documents match the query when doing an update. Defaults to false.
func (fam *FindAndModify) Upsert(upsert bool) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.upsert = &upsert
return fam
}
// Session sets the session for this operation.
func (fam *FindAndModify) Session(session *session.Client) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.session = session
return fam
}
// ClusterClock sets the cluster clock for this operation.
func (fam *FindAndModify) ClusterClock(clock *session.ClusterClock) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.clock = clock
return fam
}
// Collection sets the collection that this command will run against.
func (fam *FindAndModify) Collection(collection string) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.collection = collection
return fam
}
// CommandMonitor sets the monitor to use for APM events.
func (fam *FindAndModify) CommandMonitor(monitor *event.CommandMonitor) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.monitor = monitor
return fam
}
// Database sets the database to run this operation against.
func (fam *FindAndModify) Database(database string) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.database = database
return fam
}
// Deployment sets the deployment to use for this operation.
func (fam *FindAndModify) Deployment(deployment driver.Deployment) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.deployment = deployment
return fam
}
// ServerSelector sets the selector used to retrieve a server.
func (fam *FindAndModify) ServerSelector(selector description.ServerSelector) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.selector = selector
return fam
}
// WriteConcern sets the write concern for this operation.
func (fam *FindAndModify) WriteConcern(writeConcern *writeconcern.WriteConcern) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.writeConcern = writeConcern
return fam
}
// Retry enables retryable writes for this operation. Retries are not handled automatically;
// instead, a boolean is returned from Execute and SelectAndExecute that indicates whether the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (fam *FindAndModify) Retry(retry driver.RetryMode) *FindAndModify {
if fam == nil {
fam = new(FindAndModify)
}
fam.retry = &retry
return fam
}
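
A sketch of the findAndModify flow above: hand-build the query and a $set update with bsoncore, ask for the post-modification document via NewDocument(true), then read the typed result. As in the earlier sketches, the deployment is assumed to be supplied by the caller and the example is illustrative rather than supported public API.

package example

import (
    "context"
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// bumpVisits runs findAndModify with query {"name": "example"} and update
// {"$set": {"visits": 1}}, then prints whether an existing document was updated.
func bumpVisits(ctx context.Context, dep driver.Deployment) error {
    qidx, query := bsoncore.AppendDocumentStart(nil)
    query = bsoncore.AppendStringElement(query, "name", "example")
    query, _ = bsoncore.AppendDocumentEnd(query, qidx)

    uidx, update := bsoncore.AppendDocumentStart(nil)
    sidx, update := bsoncore.AppendDocumentElementStart(update, "$set")
    update = bsoncore.AppendInt32Element(update, "visits", 1)
    update, _ = bsoncore.AppendDocumentEnd(update, sidx)
    update, _ = bsoncore.AppendDocumentEnd(update, uidx)

    op := operation.NewFindAndModify(query).
        Update(update).
        NewDocument(true).
        Collection("coll").
        Database("db").
        Deployment(dep)
    if err := op.Execute(ctx); err != nil {
        return err
    }
    res := op.Result()
    fmt.Println("updatedExisting:", res.LastErrorObject.UpdatedExisting)
    _ = res.Value // the modified document as raw BSON
    return nil
}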

View File

@@ -0,0 +1,68 @@
version = 0
name = "FindAndModify"
documentation = "FindAndModify performs a findAndModify operation."
[properties]
enabled = ["write concern"]
retryable = {mode = "once", type = "writes"}
[command]
name = "findAndModify"
parameter = "collection"
[request.query]
type = "document"
constructor = true
documentation = "Query specifies the selection criteria for the modification."
[request.sort]
type = "document"
documentation = """
Sort determines which document the operation modifies if the query matches multiple documents. \
The first document matched by the sort order will be modified.
"""
[request.remove]
type = "boolean"
documentation = "Remove specifies that the matched document should be removed. Defaults to false."
[request.update]
type = "document"
documentation = "Update specifies the update document to perform on the matched document."
[request.newDocument]
type = "boolean"
documentation = "NewDocument specifies whether to return the modified document or the original. Defaults to false (return original)."
[request.fields]
type = "document"
documentation = "Fields specifies a subset of fields to return."
[request.upsert]
type = "boolean"
documentation = "Upsert specifies whether or not to create a new document if no documents match the query when doing an update. Defaults to false."
[request.bypassDocumentValidation]
type = "boolean"
documentation = "BypassDocumentValidation specifies if document validation can be skipped when executing the operation."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the operation to run."
[request.collation]
type = "document"
minWireVersionRequired = 5
documentation = "Collation specifies a collation to be used."
[request.arrayFilters]
type = "array"
minWireVersionRequired = 6
documentation = "ArrayFilters specifies an array of filter documents that determines which array elements to modify for an update operation on an array field."
[response]
name = "FindAndModifyResult"
[response.field.value]
type = "document"
documentation = "Either the old or modified document, depending on the value of the new parameter."

View File

@@ -0,0 +1,243 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Insert performs an insert operation.
type Insert struct {
bypassDocumentValidation *bool
documents []bsoncore.Document
ordered *bool
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
result InsertResult
}
type InsertResult struct {
// Number of documents successfully inserted.
N int32
}
func buildInsertResult(response bsoncore.Document, srvr driver.Server) (InsertResult, error) {
elements, err := response.Elements()
if err != nil {
return InsertResult{}, err
}
ir := InsertResult{}
for _, element := range elements {
switch element.Key() {
case "n":
var ok bool
ir.N, ok = element.Value().AsInt32OK()
if !ok {
err = fmt.Errorf("response field 'n' is type int32, but received BSON type %s", element.Value().Type)
}
}
}
return ir, nil
}
// NewInsert constructs and returns a new Insert.
func NewInsert(documents ...bsoncore.Document) *Insert {
return &Insert{
documents: documents,
}
}
// Result returns the result of executing this operation.
func (i *Insert) Result() InsertResult { return i.result }
func (i *Insert) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
i.result, err = buildInsertResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (i *Insert) Execute(ctx context.Context) error {
if i.deployment == nil {
return errors.New("the Insert operation must have a Deployment set before Execute can be called")
}
batches := &driver.Batches{
Identifier: "documents",
Documents: i.documents,
Ordered: i.ordered,
}
return driver.Operation{
CommandFn: i.command,
ProcessResponseFn: i.processResponse,
Batches: batches,
RetryMode: i.retry,
RetryType: driver.RetryWrite,
Client: i.session,
Clock: i.clock,
CommandMonitor: i.monitor,
Database: i.database,
Deployment: i.deployment,
Selector: i.selector,
WriteConcern: i.writeConcern,
}.Execute(ctx, nil)
}
func (i *Insert) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "insert", i.collection)
if i.bypassDocumentValidation != nil && (desc.WireVersion != nil && desc.WireVersion.Includes(4)) {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *i.bypassDocumentValidation)
}
if i.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *i.ordered)
}
return dst, nil
}
// BypassDocumentValidation allows the operation to opt out of document-level validation. Valid
// for server versions >= 3.2. For servers < 3.2, this setting is ignored.
func (i *Insert) BypassDocumentValidation(bypassDocumentValidation bool) *Insert {
if i == nil {
i = new(Insert)
}
i.bypassDocumentValidation = &bypassDocumentValidation
return i
}
// Documents adds documents to this operation that will be inserted when this operation is
// executed.
func (i *Insert) Documents(documents ...bsoncore.Document) *Insert {
if i == nil {
i = new(Insert)
}
i.documents = documents
return i
}
// Ordered sets ordered. If true, when a write fails, the operation will return the error; when
// false, write failures do not stop execution of the operation.
func (i *Insert) Ordered(ordered bool) *Insert {
if i == nil {
i = new(Insert)
}
i.ordered = &ordered
return i
}
// Session sets the session for this operation.
func (i *Insert) Session(session *session.Client) *Insert {
if i == nil {
i = new(Insert)
}
i.session = session
return i
}
// ClusterClock sets the cluster clock for this operation.
func (i *Insert) ClusterClock(clock *session.ClusterClock) *Insert {
if i == nil {
i = new(Insert)
}
i.clock = clock
return i
}
// Collection sets the collection that this command will run against.
func (i *Insert) Collection(collection string) *Insert {
if i == nil {
i = new(Insert)
}
i.collection = collection
return i
}
// CommandMonitor sets the monitor to use for APM events.
func (i *Insert) CommandMonitor(monitor *event.CommandMonitor) *Insert {
if i == nil {
i = new(Insert)
}
i.monitor = monitor
return i
}
// Database sets the database to run this operation against.
func (i *Insert) Database(database string) *Insert {
if i == nil {
i = new(Insert)
}
i.database = database
return i
}
// Deployment sets the deployment to use for this operation.
func (i *Insert) Deployment(deployment driver.Deployment) *Insert {
if i == nil {
i = new(Insert)
}
i.deployment = deployment
return i
}
// ServerSelector sets the selector used to retrieve a server.
func (i *Insert) ServerSelector(selector description.ServerSelector) *Insert {
if i == nil {
i = new(Insert)
}
i.selector = selector
return i
}
// WriteConcern sets the write concern for this operation.
func (i *Insert) WriteConcern(writeConcern *writeconcern.WriteConcern) *Insert {
if i == nil {
i = new(Insert)
}
i.writeConcern = writeConcern
return i
}
// Retry enables retryable writes for this operation. Retries are not handled automatically;
// instead, a boolean is returned from Execute and SelectAndExecute that indicates whether the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (i *Insert) Retry(retry driver.RetryMode) *Insert {
if i == nil {
i = new(Insert)
}
i.retry = &retry
return i
}
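
To round out the write path, a sketch of driving the generated Insert: the documents become the batched "documents" payload built in Execute above, and Result().N reports how many were inserted. The deployment is again assumed to be supplied by the caller.

package example

import (
    "context"
    "fmt"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// insertOne inserts a single {"name": "example"} document in ordered mode.
func insertOne(ctx context.Context, dep driver.Deployment) error {
    idx, doc := bsoncore.AppendDocumentStart(nil)
    doc = bsoncore.AppendStringElement(doc, "name", "example")
    doc, _ = bsoncore.AppendDocumentEnd(doc, idx)

    op := operation.NewInsert(doc).
        Collection("coll").
        Database("db").
        Deployment(dep).
        Ordered(true)
    if err := op.Execute(ctx); err != nil {
        return err
    }
    fmt.Println("documents inserted:", op.Result().N)
    return nil
}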

View File

@@ -0,0 +1,45 @@
version = 0
name = "Insert"
documentation = "Insert performs an insert operation."
[properties]
enabled = ["write concern"]
retryable = {mode = "once per command", type = "writes"}
batches = "documents"
[command]
name = "insert"
parameter = "collection"
[request.documents]
type = "document"
slice = true
constructor = true
variadic = true
required = true
documentation = """
Documents adds documents to this operation that will be inserted when this operation is
executed.\
"""
[request.ordered]
type = "boolean"
documentation = """
Ordered sets ordered. If true, when a write fails, the operation will return the error; when
false, write failures do not stop execution of the operation.\
"""
[request.bypassDocumentValidation]
type = "boolean"
minWireVersion = 4
documentation = """
BypassDocumentValidation allows the operation to opt out of document-level validation. Valid
for server versions >= 3.2. For servers < 3.2, this setting is ignored.\
"""
[response]
name = "InsertResult"
[response.field.n]
type = "int32"
documentation = "Number of documents successfully inserted."

View File

@@ -0,0 +1,412 @@
package operation
import (
"context"
"errors"
"fmt"
"runtime"
"strconv"
"time"
"go.mongodb.org/mongo-driver/tag"
"go.mongodb.org/mongo-driver/version"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// IsMaster is used to run the isMaster handshake operation.
type IsMaster struct {
appname string
compressors []string
saslSupportedMechs string
d driver.Deployment
clock *session.ClusterClock
res bsoncore.Document
}
// NewIsMaster constructs an IsMaster.
func NewIsMaster() *IsMaster { return &IsMaster{} }
// AppName sets the application name in the client metadata sent in this operation.
func (im *IsMaster) AppName(appname string) *IsMaster {
im.appname = appname
return im
}
// ClusterClock sets the cluster clock for this operation.
func (im *IsMaster) ClusterClock(clock *session.ClusterClock) *IsMaster {
if im == nil {
im = new(IsMaster)
}
im.clock = clock
return im
}
// Compressors sets the compressors that can be used.
func (im *IsMaster) Compressors(compressors []string) *IsMaster {
im.compressors = compressors
return im
}
// SASLSupportedMechs retrieves the supported SASL mechanisms for the given user when this operation
// is run.
func (im *IsMaster) SASLSupportedMechs(username string) *IsMaster {
im.saslSupportedMechs = username
return im
}
// Deployment sets the Deployment for this operation.
func (im *IsMaster) Deployment(d driver.Deployment) *IsMaster {
im.d = d
return im
}
// Result returns the result of executing this operation.
func (im *IsMaster) Result(addr address.Address) description.Server {
desc := description.Server{Addr: addr, CanonicalAddr: addr, LastUpdateTime: time.Now().UTC()}
elements, err := im.res.Elements()
if err != nil {
desc.LastError = err
return desc
}
var ok bool
var isReplicaSet, isMaster, hidden, secondary, arbiterOnly bool
var msg string
var version description.VersionRange
var hosts, passives, arbiters []string
for _, element := range elements {
switch element.Key() {
case "arbiters":
var err error
arbiters, err = im.decodeStringSlice(element, "arbiters")
if err != nil {
desc.LastError = err
return desc
}
case "arbiterOnly":
arbiterOnly, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'arbiterOnly' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "compression":
var err error
desc.Compression, err = im.decodeStringSlice(element, "compression")
if err != nil {
desc.LastError = err
return desc
}
case "electionId":
desc.ElectionID, ok = element.Value().ObjectIDOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'electionId' to be a objectID but it's a BSON %s", element.Value().Type)
return desc
}
case "hidden":
hidden, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'hidden' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "hosts":
var err error
hosts, err = im.decodeStringSlice(element, "hosts")
if err != nil {
desc.LastError = err
return desc
}
case "ismaster":
isMaster, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'isMaster' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "isreplicaset":
isReplicaSet, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'isreplicaset' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "lastWriteDate":
dt, ok := element.Value().DateTimeOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'lastWriteDate' to be a datetime but it's a BSON %s", element.Value().Type)
return desc
}
desc.LastWriteTime = time.Unix(dt/1000, dt%1000*1000000).UTC()
case "logicalSessionTimeoutMinutes":
i64, ok := element.Value().AsInt64OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'logicalSessionTimeoutMinutes' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
desc.SessionTimeoutMinutes = uint32(i64)
case "maxBsonObjectSize":
i64, ok := element.Value().AsInt64OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'maxBsonObjectSize' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
desc.MaxDocumentSize = uint32(i64)
case "maxMessageSizeBytes":
i64, ok := element.Value().AsInt64OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'maxMessageSizeBytes' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
desc.MaxMessageSize = uint32(i64)
case "maxWriteBatchSize":
i64, ok := element.Value().AsInt64OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'maxWriteBatchSize' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
desc.MaxBatchCount = uint32(i64)
case "me":
me, ok := element.Value().StringValueOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'me' to be a string but it's a BSON %s", element.Value().Type)
return desc
}
desc.CanonicalAddr = address.Address(me).Canonicalize()
case "maxWireVersion":
version.Max, ok = element.Value().AsInt32OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'maxWireVersion' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
case "minWireVersion":
version.Min, ok = element.Value().AsInt32OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'minWireVersion' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
case "msg":
msg, ok = element.Value().StringValueOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'msg' to be a string but it's a BSON %s", element.Value().Type)
return desc
}
case "ok":
okay, ok := element.Value().AsInt32OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'ok' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
if okay != 1 {
desc.LastError = errors.New("not ok")
return desc
}
case "passives":
var err error
passives, err = im.decodeStringSlice(element, "passives")
if err != nil {
desc.LastError = err
return desc
}
case "readOnly":
desc.ReadOnly, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'readOnly' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "saslSupportedMechs":
var err error
desc.SaslSupportedMechs, err = im.decodeStringSlice(element, "saslSupportedMechs")
if err != nil {
desc.LastError = err
return desc
}
case "secondary":
secondary, ok = element.Value().BooleanOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'secondary' to be a boolean but it's a BSON %s", element.Value().Type)
return desc
}
case "setName":
desc.SetName, ok = element.Value().StringValueOK()
if !ok {
desc.LastError = fmt.Errorf("expected 'setName' to be a string but it's a BSON %s", element.Value().Type)
return desc
}
case "setVersion":
i64, ok := element.Value().AsInt64OK()
if !ok {
desc.LastError = fmt.Errorf("expected 'setVersion' to be an integer but it's a BSON %s", element.Value().Type)
return desc
}
desc.SetVersion = uint32(i64)
case "tags":
m, err := im.decodeStringMap(element, "tags")
if err != nil {
desc.LastError = err
return desc
}
desc.Tags = tag.NewTagSetFromMap(m)
}
}
for _, host := range hosts {
desc.Members = append(desc.Members, address.Address(host).Canonicalize())
}
for _, passive := range passives {
desc.Members = append(desc.Members, address.Address(passive).Canonicalize())
}
for _, arbiter := range arbiters {
desc.Members = append(desc.Members, address.Address(arbiter).Canonicalize())
}
desc.Kind = description.Standalone
if isReplicaSet {
desc.Kind = description.RSGhost
} else if desc.SetName != "" {
if isMaster {
desc.Kind = description.RSPrimary
} else if hidden {
desc.Kind = description.RSMember
} else if secondary {
desc.Kind = description.RSSecondary
} else if arbiterOnly {
desc.Kind = description.RSArbiter
} else {
desc.Kind = description.RSMember
}
} else if msg == "isdbgrid" {
desc.Kind = description.Mongos
}
desc.WireVersion = &version
return desc
}
func (im *IsMaster) decodeStringSlice(element bsoncore.Element, name string) ([]string, error) {
arr, ok := element.Value().ArrayOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be an array but it's a BSON %s", name, element.Value().Type)
}
vals, err := arr.Values()
if err != nil {
return nil, err
}
var strs []string
for _, val := range vals {
str, ok := val.StringValueOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be an array of strings, but found a BSON %s", name, val.Type)
}
strs = append(strs, str)
}
return strs, nil
}
func (im *IsMaster) decodeStringMap(element bsoncore.Element, name string) (map[string]string, error) {
doc, ok := element.Value().DocumentOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be a document but it's a BSON %s", name, element.Value().Type)
}
elements, err := doc.Elements()
if err != nil {
return nil, err
}
m := make(map[string]string)
for _, element := range elements {
key := element.Key()
value, ok := element.Value().StringValueOK()
if !ok {
return nil, fmt.Errorf("expected '%s' to be a document of strings, but found a BSON %s", name, element.Value().Type)
}
m[key] = value
}
return m, nil
}
// handshakeCommand appends all necessary command fields as well as client metadata, SASL supported mechs, and compression.
func (im *IsMaster) handshakeCommand(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst, err := im.command(dst, desc)
if err != nil {
return dst, err
}
if im.saslSupportedMechs != "" {
dst = bsoncore.AppendStringElement(dst, "saslSupportedMechs", im.saslSupportedMechs)
}
var idx int32
idx, dst = bsoncore.AppendArrayElementStart(dst, "compression")
for i, compressor := range im.compressors {
dst = bsoncore.AppendStringElement(dst, strconv.Itoa(i), compressor)
}
dst, _ = bsoncore.AppendArrayEnd(dst, idx)
// append client metadata
idx, dst = bsoncore.AppendDocumentElementStart(dst, "client")
didx, dst := bsoncore.AppendDocumentElementStart(dst, "driver")
dst = bsoncore.AppendStringElement(dst, "name", "mongo-go-driver")
dst = bsoncore.AppendStringElement(dst, "version", version.Driver)
dst, _ = bsoncore.AppendDocumentEnd(dst, didx)
didx, dst = bsoncore.AppendDocumentElementStart(dst, "os")
dst = bsoncore.AppendStringElement(dst, "type", runtime.GOOS)
dst = bsoncore.AppendStringElement(dst, "architecture", runtime.GOARCH)
dst, _ = bsoncore.AppendDocumentEnd(dst, didx)
dst = bsoncore.AppendStringElement(dst, "platform", runtime.Version())
if im.appname != "" {
didx, dst = bsoncore.AppendDocumentElementStart(dst, "application")
dst = bsoncore.AppendStringElement(dst, "name", im.appname)
dst, _ = bsoncore.AppendDocumentEnd(dst, didx)
}
dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
return dst, nil
}
// command appends all necessary command fields.
func (im *IsMaster) command(dst []byte, _ description.SelectedServer) ([]byte, error) {
return bsoncore.AppendInt32Element(dst, "isMaster", 1), nil
}
// Execute runs this operation.
func (im *IsMaster) Execute(ctx context.Context) error {
if im.d == nil {
return errors.New("an IsMaster must have a Deployment set before Execute can be called")
}
return driver.Operation{
Clock: im.clock,
CommandFn: im.command,
Database: "admin",
Deployment: im.d,
ProcessResponseFn: func(response bsoncore.Document, _ driver.Server, _ description.Server) error {
im.res = response
return nil
},
}.Execute(ctx, nil)
}
// Handshake implements the Handshaker interface.
func (im *IsMaster) Handshake(ctx context.Context, _ address.Address, c driver.Connection) (description.Server, error) {
err := driver.Operation{
Clock: im.clock,
CommandFn: im.handshakeCommand,
Deployment: driver.SingleConnectionDeployment{c},
Database: "admin",
ProcessResponseFn: func(response bsoncore.Document, _ driver.Server, _ description.Server) error {
im.res = response
return nil
},
}.Execute(ctx, nil)
if err != nil {
return description.Server{}, err
}
return im.Result(c.Address()), nil
}
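
IsMaster is the one operation here that doubles as the connection handshake, via the Handshaker implementation above. The sketch below shows how a caller holding a raw driver.Connection might run it; the connection itself is assumed to come from the driver's connection pool, and the example is illustrative only.

package example

import (
    "context"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// describeServer performs the isMaster handshake on an established connection
// and returns the resulting server description.
func describeServer(ctx context.Context, conn driver.Connection) (description.Server, error) {
    im := operation.NewIsMaster().
        AppName("example-app").
        Compressors([]string{"snappy"})
    return im.Handshake(ctx, conn.Address(), conn)
}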

View File

@@ -0,0 +1,274 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// ListDatabases performs a listDatabases operation.
type ListDatabases struct {
filter bsoncore.Document
nameOnly *bool
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readPreference *readpref.ReadPref
selector description.ServerSelector
result ListDatabasesResult
}
type ListDatabasesResult struct {
// An array of documents, one document for each database.
Databases []databaseRecord
// The sum of the size of all the database files on disk in bytes.
TotalSize int64
}
type databaseRecord struct {
Name string
SizeOnDisk int64 `bson:"sizeOnDisk"`
Empty bool
}
func buildListDatabasesResult(response bsoncore.Document, srvr driver.Server) (ListDatabasesResult, error) {
elements, err := response.Elements()
if err != nil {
return ListDatabasesResult{}, err
}
ir := ListDatabasesResult{}
for _, element := range elements {
switch element.Key() {
case "totalSize":
var ok bool
ir.TotalSize, ok = element.Value().AsInt64OK()
if !ok {
err = fmt.Errorf("response field 'totalSize' is type int64, but received BSON type %s: %s", element.Value().Type, element.Value())
}
case "databases":
// TODO: Make operationgen handle array results.
arr, ok := element.Value().ArrayOK()
if !ok {
err = fmt.Errorf("response field 'databases' is type array, but received BSON type %s", element.Value().Type)
continue
}
var tmp bsoncore.Document
marshalErr := bson.Unmarshal(arr, &tmp)
if marshalErr != nil {
err = marshalErr
continue
}
records, marshalErr := tmp.Elements()
if marshalErr != nil {
err = marshalErr
continue
}
ir.Databases = make([]databaseRecord, len(records))
for i, val := range records {
valueDoc, ok := val.Value().DocumentOK()
if !ok {
err = fmt.Errorf("'databases' element is type document, but received BSON type %s", val.Value().Type)
continue
}
elems, marshalErr := valueDoc.Elements()
if marshalErr != nil {
err = marshalErr
continue
}
for _, elem := range elems {
switch elem.Key() {
case "name":
ir.Databases[i].Name, ok = elem.Value().StringValueOK()
if !ok {
err = fmt.Errorf("response field 'name' is type string, but received BSON type %s", elem.Value().Type)
continue
}
case "sizeOnDisk":
ir.Databases[i].SizeOnDisk, ok = elem.Value().AsInt64OK()
if !ok {
err = fmt.Errorf("response field 'sizeOnDisk' is type int64, but received BSON type %s", elem.Value().Type)
continue
}
case "empty":
ir.Databases[i].Empty, ok = elem.Value().BooleanOK()
if !ok {
err = fmt.Errorf("response field 'empty' is type bool, but received BSON type %s", elem.Value().Type)
continue
}
}
}
}
}
}
return ir, err
}
// NewListDatabases constructs and returns a new ListDatabases.
func NewListDatabases(filter bsoncore.Document) *ListDatabases {
return &ListDatabases{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (ld *ListDatabases) Result() ListDatabasesResult { return ld.result }
func (ld *ListDatabases) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
ld.result, err = buildListDatabasesResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (ld *ListDatabases) Execute(ctx context.Context) error {
if ld.deployment == nil {
return errors.New("the ListDatabases operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: ld.command,
ProcessResponseFn: ld.processResponse,
Client: ld.session,
Clock: ld.clock,
CommandMonitor: ld.monitor,
Database: ld.database,
Deployment: ld.deployment,
ReadPreference: ld.readPreference,
Selector: ld.selector,
}.Execute(ctx, nil)
}
func (ld *ListDatabases) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "listDatabases", 1)
if ld.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", ld.filter)
}
if ld.nameOnly != nil {
dst = bsoncore.AppendBooleanElement(dst, "nameOnly", *ld.nameOnly)
}
return dst, nil
}
// Filter determines what results are returned from listDatabases.
func (ld *ListDatabases) Filter(filter bsoncore.Document) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.filter = filter
return ld
}
// NameOnly specifies whether to only return database names.
func (ld *ListDatabases) NameOnly(nameOnly bool) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.nameOnly = &nameOnly
return ld
}
// Session sets the session for this operation.
func (ld *ListDatabases) Session(session *session.Client) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.session = session
return ld
}
// ClusterClock sets the cluster clock for this operation.
func (ld *ListDatabases) ClusterClock(clock *session.ClusterClock) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.clock = clock
return ld
}
// CommandMonitor sets the monitor to use for APM events.
func (ld *ListDatabases) CommandMonitor(monitor *event.CommandMonitor) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.monitor = monitor
return ld
}
// Database sets the database to run this operation against.
func (ld *ListDatabases) Database(database string) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.database = database
return ld
}
// Deployment sets the deployment to use for this operation.
func (ld *ListDatabases) Deployment(deployment driver.Deployment) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.deployment = deployment
return ld
}
// ReadPreference sets the read preference used with this operation.
func (ld *ListDatabases) ReadPreference(readPreference *readpref.ReadPref) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.readPreference = readPreference
return ld
}
// ServerSelector sets the selector used to retrieve a server.
func (ld *ListDatabases) ServerSelector(selector description.ServerSelector) *ListDatabases {
if ld == nil {
ld = new(ListDatabases)
}
ld.selector = selector
return ld
}
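The sketch below is not part of the vendored file; it shows one way a caller might drive the generated ListDatabases builder. The deployment value depl stands in for an already-connected driver.Deployment, and the package and function names are illustrative only.
package example

import (
    "context"
    "fmt"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// printDatabaseNames executes listDatabases against the admin database and prints each
// database name. A nil filter matches every database.
func printDatabaseNames(ctx context.Context, depl driver.Deployment) error {
    op := operation.NewListDatabases(nil).
        Database("admin").
        NameOnly(true).
        Deployment(depl)
    if err := op.Execute(ctx); err != nil {
        return err
    }
    for _, db := range op.Result().Databases {
        fmt.Println(db.Name)
    }
    return nil
}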

View File

@@ -0,0 +1,32 @@
version = 0
name = "ListDatabases"
documentation = "ListDatabases performs a listDatabases operation."
[properties]
enabled = ["read preference"]
disabled = ["collection"]
[command]
name = "listDatabases"
parameter = "database"
[request.filter]
type = "document"
constructor = true
documentation = "Filter determines what results are returned from listDatabases."
[request.nameOnly]
type = "boolean"
documentation = "NameOnly specifies whether to only return database names."
[response]
name = "ListDatabasesResult"
[response.field.totalSize]
type = "int64"
documentation = "The sum of the size of all the database files on disk in bytes."
[response.field.databases]
type = "value"
documentation = "An array of documents, one document for each database"

View File

@@ -0,0 +1,184 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// ListCollections performs a listCollections operation.
type ListCollections struct {
filter bsoncore.Document
nameOnly *bool
session *session.Client
clock *session.ClusterClock
monitor *event.CommandMonitor
database string
deployment driver.Deployment
readPreference *readpref.ReadPref
selector description.ServerSelector
result driver.CursorResponse
}
// NewListCollections constructs and returns a new ListCollections.
func NewListCollections(filter bsoncore.Document) *ListCollections {
return &ListCollections{
filter: filter,
}
}
// Result returns the result of executing this operation.
func (lc *ListCollections) Result(opts driver.CursorOptions) (*driver.ListCollectionsBatchCursor, error) {
bc, err := driver.NewBatchCursor(lc.result, lc.session, lc.clock, opts)
if err != nil {
return nil, err
}
desc := lc.result.Desc
if desc.WireVersion == nil || desc.WireVersion.Max < 3 {
return driver.NewLegacyListCollectionsBatchCursor(bc)
}
return driver.NewListCollectionsBatchCursor(bc)
}
func (lc *ListCollections) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
lc.result, err = driver.NewCursorResponse(response, srvr, desc)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (lc *ListCollections) Execute(ctx context.Context) error {
if lc.deployment == nil {
return errors.New("the ListCollections operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: lc.command,
ProcessResponseFn: lc.processResponse,
Client: lc.session,
Clock: lc.clock,
CommandMonitor: lc.monitor,
Database: lc.database,
Deployment: lc.deployment,
ReadPreference: lc.readPreference,
Selector: lc.selector,
Legacy: driver.LegacyListCollections,
}.Execute(ctx, nil)
}
func (lc *ListCollections) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt32Element(dst, "listCollections", 1)
if lc.filter != nil {
dst = bsoncore.AppendDocumentElement(dst, "filter", lc.filter)
}
if lc.nameOnly != nil {
dst = bsoncore.AppendBooleanElement(dst, "nameOnly", *lc.nameOnly)
}
return dst, nil
}
// Filter determines what results are returned from listCollections.
func (lc *ListCollections) Filter(filter bsoncore.Document) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.filter = filter
return lc
}
// NameOnly specifies whether to only return collection names.
func (lc *ListCollections) NameOnly(nameOnly bool) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.nameOnly = &nameOnly
return lc
}
// Session sets the session for this operation.
func (lc *ListCollections) Session(session *session.Client) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.session = session
return lc
}
// ClusterClock sets the cluster clock for this operation.
func (lc *ListCollections) ClusterClock(clock *session.ClusterClock) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.clock = clock
return lc
}
// CommandMonitor sets the monitor to use for APM events.
func (lc *ListCollections) CommandMonitor(monitor *event.CommandMonitor) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.monitor = monitor
return lc
}
// Database sets the database to run this operation against.
func (lc *ListCollections) Database(database string) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.database = database
return lc
}
// Deployment sets the deployment to use for this operation.
func (lc *ListCollections) Deployment(deployment driver.Deployment) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.deployment = deployment
return lc
}
// ReadPreference sets the read preference used with this operation.
func (lc *ListCollections) ReadPreference(readPreference *readpref.ReadPref) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.readPreference = readPreference
return lc
}
// ServerSelector sets the selector used to retrieve a server.
func (lc *ListCollections) ServerSelector(selector description.ServerSelector) *ListCollections {
if lc == nil {
lc = new(ListCollections)
}
lc.selector = selector
return lc
}
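A comparable caller-side sketch for ListCollections, again assuming a connected driver.Deployment named depl; zero-valued CursorOptions are enough for illustration, and none of these names come from the vendored file.
package example

import (
    "context"

    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// openListCollectionsCursor runs listCollections against one database and returns the
// batch cursor produced by Result. A nil filter returns every collection.
func openListCollectionsCursor(ctx context.Context, depl driver.Deployment) (*driver.ListCollectionsBatchCursor, error) {
    op := operation.NewListCollections(nil).
        Database("test").
        Deployment(depl)
    if err := op.Execute(ctx); err != nil {
        return nil, err
    }
    return op.Result(driver.CursorOptions{})
}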

View File

@@ -0,0 +1,22 @@
version = 0
name = "ListCollections"
documentation = "ListCollections performs a listCollections operation."
response.type = "list collections batch cursor"
[properties]
enabled = ["read preference"]
disabled = ["collection"]
legacy = "listCollections"
[command]
name = "listCollections"
parameter = "database"
[request.filter]
type = "document"
constructor = true
documentation = "Filter determines what results are returned from listCollections."
[request.nameOnly]
type = "boolean"
documentation = "NameOnly specifies whether to only return collection names."

View File

@@ -0,0 +1,186 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Code generated by operationgen. DO NOT EDIT.
package operation
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// ListIndexes performs a listIndexes operation.
type ListIndexes struct {
batchSize *int32
maxTimeMS *int64
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
result driver.CursorResponse
}
// NewListIndexes constructs and returns a new ListIndexes.
func NewListIndexes() *ListIndexes {
return &ListIndexes{}
}
// Result returns the result of executing this operation.
func (li *ListIndexes) Result(opts driver.CursorOptions) (*driver.BatchCursor, error) {
clientSession := li.session
clock := li.clock
return driver.NewBatchCursor(li.result, clientSession, clock, opts)
}
func (li *ListIndexes) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
li.result, err = driver.NewCursorResponse(response, srvr, desc)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (li *ListIndexes) Execute(ctx context.Context) error {
if li.deployment == nil {
return errors.New("the ListIndexes operation must have a Deployment set before Execute can be called")
}
return driver.Operation{
CommandFn: li.command,
ProcessResponseFn: li.processResponse,
Client: li.session,
Clock: li.clock,
CommandMonitor: li.monitor,
Database: li.database,
Deployment: li.deployment,
Selector: li.selector,
Legacy: driver.LegacyListIndexes,
}.Execute(ctx, nil)
}
func (li *ListIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "listIndexes", li.collection)
cursorIdx, cursorDoc := bsoncore.AppendDocumentStart(nil)
if li.batchSize != nil {
cursorDoc = bsoncore.AppendInt32Element(cursorDoc, "batchSize", *li.batchSize)
}
if li.maxTimeMS != nil {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", *li.maxTimeMS)
}
cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx)
dst = bsoncore.AppendDocumentElement(dst, "cursor", cursorDoc)
return dst, nil
}
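// For reference (a sketch, not generated output): with both options set, the command built
// above places maxTimeMS at the top level while batchSize lives inside the nested cursor
// document, e.g. {listIndexes: "coll", maxTimeMS: 500, cursor: {batchSize: 25}}.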
// BatchSize specifies the number of documents to return in every batch.
func (li *ListIndexes) BatchSize(batchSize int32) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.batchSize = &batchSize
return li
}
// MaxTimeMS specifies the maximum amount of time to allow the query to run.
func (li *ListIndexes) MaxTimeMS(maxTimeMS int64) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.maxTimeMS = &maxTimeMS
return li
}
// Session sets the session for this operation.
func (li *ListIndexes) Session(session *session.Client) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.session = session
return li
}
// ClusterClock sets the cluster clock for this operation.
func (li *ListIndexes) ClusterClock(clock *session.ClusterClock) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.clock = clock
return li
}
// Collection sets the collection that this command will run against.
func (li *ListIndexes) Collection(collection string) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.collection = collection
return li
}
// CommandMonitor sets the monitor to use for APM events.
func (li *ListIndexes) CommandMonitor(monitor *event.CommandMonitor) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.monitor = monitor
return li
}
// Database sets the database to run this operation against.
func (li *ListIndexes) Database(database string) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.database = database
return li
}
// Deployment sets the deployment to use for this operation.
func (li *ListIndexes) Deployment(deployment driver.Deployment) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.deployment = deployment
return li
}
// ServerSelector sets the selector used to retrieve a server.
func (li *ListIndexes) ServerSelector(selector description.ServerSelector) *ListIndexes {
if li == nil {
li = new(ListIndexes)
}
li.selector = selector
return li
}

View File

@@ -0,0 +1,19 @@
version = 0
name = "ListIndexes"
documentation = "ListIndexes performs a listIndexes operation."
response.type = "batch cursor"
[properties]
legacy = "listIndexes"
[command]
name = "listIndexes"
parameter = "collection"
[request.batchSize]
type = "int32"
documentation = "BatchSize specifies the number of documents to return in every batch."
[request.maxTimeMS]
type = "int64"
documentation = "MaxTimeMS specifies the maximum amount of time to allow the query to run."

View File

@@ -0,0 +1,12 @@
package operation
//go:generate operationgen insert.toml operation insert.go
//go:generate operationgen find.toml operation find.go
//go:generate operationgen list_collections.toml operation list_collections.go
//go:generate operationgen createIndexes.toml operation createIndexes.go
//go:generate operationgen drop_collection.toml operation drop_collection.go
//go:generate operationgen distinct.toml operation distinct.go
//go:generate operationgen delete.toml operation delete.go
//go:generate operationgen drop_indexes.toml operation drop_indexes.go
//go:generate operationgen drop_database.toml operation drop_database.go
//go:generate operationgen commit_transaction.toml operation commit_transaction.go
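// Each directive above uses the same argument order: the TOML spec, the target package
// name, and the Go file to generate, i.e. operationgen <spec>.toml <package> <output>.go.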

View File

@@ -0,0 +1,295 @@
// Copyright (C) MongoDB, Inc. 2019-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// NOTE: This file is maintained by hand because operationgen cannot generate it.
package operation
import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Update performs an update operation.
type Update struct {
bypassDocumentValidation *bool
ordered *bool
updates []bsoncore.Document
session *session.Client
clock *session.ClusterClock
collection string
monitor *event.CommandMonitor
database string
deployment driver.Deployment
selector description.ServerSelector
writeConcern *writeconcern.WriteConcern
retry *driver.RetryMode
result UpdateResult
}
// Upsert contains the information for an upsert in an Update operation.
type Upsert struct {
Index int64
ID interface{} `bson:"_id"`
}
// UpdateResult contains information for the result of an Update operation.
type UpdateResult struct {
// Number of documents matched.
N int32
// Number of documents modified.
NModified int32
// Information about upserted documents.
Upserted []Upsert
}
func buildUpdateResult(response bsoncore.Document, srvr driver.Server) (UpdateResult, error) {
elements, err := response.Elements()
if err != nil {
return UpdateResult{}, err
}
ur := UpdateResult{}
for _, element := range elements {
switch element.Key() {
case "nModified":
var ok bool
ur.NModified, ok = element.Value().Int32OK()
if !ok {
err = fmt.Errorf("response field 'nModified' is type int32, but received BSON type %s", element.Value().Type)
}
case "n":
var ok bool
ur.N, ok = element.Value().Int32OK()
if !ok {
err = fmt.Errorf("response field 'n' is type int32, but received BSON type %s", element.Value().Type)
}
case "upserted":
arr, ok := element.Value().ArrayOK()
if !ok {
err = fmt.Errorf("response field 'upserted' is type array, but received BSON type %s", element.Value().Type)
break
}
var values []bsoncore.Value
values, err = arr.Values()
if err != nil {
break
}
for _, val := range values {
valDoc, ok := val.DocumentOK()
if !ok {
err = fmt.Errorf("upserted value is type document, but received BSON type %s", val.Type)
break
}
var upsert Upsert
if err = bson.Unmarshal(valDoc, &upsert); err != nil {
break
}
ur.Upserted = append(ur.Upserted, upsert)
}
}
}
return ur, err
}
// NewUpdate constructs and returns a new Update.
func NewUpdate(updates ...bsoncore.Document) *Update {
return &Update{
updates: updates,
}
}
// Result returns the result of executing this operation.
func (u *Update) Result() UpdateResult { return u.result }
func (u *Update) processResponse(response bsoncore.Document, srvr driver.Server, desc description.Server) error {
var err error
u.result, err = buildUpdateResult(response, srvr)
return err
}
// Execute runs this operation and returns an error if the operation did not execute successfully.
func (u *Update) Execute(ctx context.Context) error {
if u.deployment == nil {
return errors.New("the Update operation must have a Deployment set before Execute can be called")
}
batches := &driver.Batches{
Identifier: "updates",
Documents: u.updates,
Ordered: u.ordered,
}
return driver.Operation{
CommandFn: u.command,
ProcessResponseFn: u.processResponse,
Batches: batches,
RetryMode: u.retry,
RetryType: driver.RetryWrite,
Client: u.session,
Clock: u.clock,
CommandMonitor: u.monitor,
Database: u.database,
Deployment: u.deployment,
Selector: u.selector,
WriteConcern: u.writeConcern,
}.Execute(ctx, nil)
}
func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "update", u.collection)
if u.bypassDocumentValidation != nil &&
(desc.WireVersion != nil && desc.WireVersion.Includes(4)) {
dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *u.bypassDocumentValidation)
}
if u.ordered != nil {
dst = bsoncore.AppendBooleanElement(dst, "ordered", *u.ordered)
}
return dst, nil
}
// BypassDocumentValidation allows the operation to opt-out of document level validation. Valid
// for server versions >= 3.2. For servers < 3.2, this setting is ignored.
func (u *Update) BypassDocumentValidation(bypassDocumentValidation bool) *Update {
if u == nil {
u = new(Update)
}
u.bypassDocumentValidation = &bypassDocumentValidation
return u
}
// Ordered sets ordered. If true, when a write fails, the operation will return the error; when
// false, write failures do not stop execution of the operation.
func (u *Update) Ordered(ordered bool) *Update {
if u == nil {
u = new(Update)
}
u.ordered = &ordered
return u
}
// Updates specifies an array of update statements to perform when this operation is executed.
// Each update document must have the following structure: {q: <query>, u: <update>, multi: <boolean>, collation: Optional<Document>, arrayFilters: Optional<Array>}.
func (u *Update) Updates(updates ...bsoncore.Document) *Update {
if u == nil {
u = new(Update)
}
u.updates = updates
return u
}
// Session sets the session for this operation.
func (u *Update) Session(session *session.Client) *Update {
if u == nil {
u = new(Update)
}
u.session = session
return u
}
// ClusterClock sets the cluster clock for this operation.
func (u *Update) ClusterClock(clock *session.ClusterClock) *Update {
if u == nil {
u = new(Update)
}
u.clock = clock
return u
}
// Collection sets the collection that this command will run against.
func (u *Update) Collection(collection string) *Update {
if u == nil {
u = new(Update)
}
u.collection = collection
return u
}
// CommandMonitor sets the monitor to use for APM events.
func (u *Update) CommandMonitor(monitor *event.CommandMonitor) *Update {
if u == nil {
u = new(Update)
}
u.monitor = monitor
return u
}
// Database sets the database to run this operation against.
func (u *Update) Database(database string) *Update {
if u == nil {
u = new(Update)
}
u.database = database
return u
}
// Deployment sets the deployment to use for this operation.
func (u *Update) Deployment(deployment driver.Deployment) *Update {
if u == nil {
u = new(Update)
}
u.deployment = deployment
return u
}
// ServerSelector sets the selector used to retrieve a server.
func (u *Update) ServerSelector(selector description.ServerSelector) *Update {
if u == nil {
u = new(Update)
}
u.selector = selector
return u
}
// WriteConcern sets the write concern for this operation.
func (u *Update) WriteConcern(writeConcern *writeconcern.WriteConcern) *Update {
if u == nil {
u = new(Update)
}
u.writeConcern = writeConcern
return u
}
// Retry enables retryable writes for this operation. Retries are not handled automatically;
// instead, a boolean is returned from Execute and SelectAndExecute that indicates if the
// operation can be retried. Retrying is handled by calling RetryExecute.
func (u *Update) Retry(retry driver.RetryMode) *Update {
if u == nil {
u = new(Update)
}
u.retry = &retry
return u
}
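As a rough caller-side sketch (not part of the vendored file), the snippet below builds a single update statement of the documented {q, u, multi} shape with bsoncore and executes it; depl again stands in for a connected driver.Deployment and the helper names are hypothetical.
package example

import (
    "context"

    "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    "go.mongodb.org/mongo-driver/x/mongo/driver"
    "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)

// buildSetNameUpdate assembles {q: {name: "old"}, u: {$set: {name: "new"}}, multi: false}.
func buildSetNameUpdate() bsoncore.Document {
    dIdx, doc := bsoncore.AppendDocumentStart(nil)

    var idx int32
    idx, doc = bsoncore.AppendDocumentElementStart(doc, "q")
    doc = bsoncore.AppendStringElement(doc, "name", "old")
    doc, _ = bsoncore.AppendDocumentEnd(doc, idx)

    idx, doc = bsoncore.AppendDocumentElementStart(doc, "u")
    var setIdx int32
    setIdx, doc = bsoncore.AppendDocumentElementStart(doc, "$set")
    doc = bsoncore.AppendStringElement(doc, "name", "new")
    doc, _ = bsoncore.AppendDocumentEnd(doc, setIdx)
    doc, _ = bsoncore.AppendDocumentEnd(doc, idx)

    doc = bsoncore.AppendBooleanElement(doc, "multi", false)
    doc, _ = bsoncore.AppendDocumentEnd(doc, dIdx)
    return doc
}

// runUpdate executes the statement against test.coll and returns the server's counts.
func runUpdate(ctx context.Context, depl driver.Deployment) (operation.UpdateResult, error) {
    op := operation.NewUpdate(buildSetNameUpdate()).
        Database("test").
        Collection("coll").
        Deployment(depl)
    if err := op.Execute(ctx); err != nil {
        return operation.UpdateResult{}, err
    }
    return op.Result(), nil
}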

View File

@@ -0,0 +1,49 @@
version = 0
name = "Update"
documentation = "Update performs an update operation."
[properties]
enabled = ["write concern"]
retryable = {mode = "once per command", type = "writes"}
batches = "updates"
[command]
name = "update"
parameter = "collection"
[request.updates]
type = "document"
slice = true
constructor = true
variadic = true
required = true
documentation = """
Updates specifies an array of update statements to perform when this operation is executed.
Each update document must have the following structure: {q: <query>, u: <update>, multi: <boolean>, collation: Optional<Document>, arrayFilters: Optional<Array>}.\
"""
[request.ordered]
type = "boolean"
documentation = """
Ordered sets ordered. If true, when a write fails, the operation will return the error; when
false, write failures do not stop execution of the operation.\
"""
[request.bypassDocumentValidation]
type = "boolean"
minWireVersion = 4
documentation = """
BypassDocumentValidation allows the operation to opt-out of document level validation. Valid
for server versions >= 3.2. For servers < 3.2, this setting is ignored.\
"""
[response]
name = "UpdateResult"
[response.field.n]
type = "int32"
documentation = "Number of documents matched."
[response.field.nModified]
type = "int32"
documentation = "Number of documents modified."

View File

@@ -0,0 +1,669 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"strconv"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
var (
firstBatchIdentifier = "firstBatch"
nextBatchIdentifier = "nextBatch"
listCollectionsNamespace = "system.namespaces"
listIndexesNamespace = "system.indexes"
// ErrFilterType is returned when the filter for a legacy list collections operation is of the wrong type.
ErrFilterType = errors.New("filter for list collections operation must be a string")
)
func (op Operation) getFullCollectionName(coll string) string {
return op.Database + "." + coll
}
func (op Operation) legacyFind(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error {
wm, startedInfo, collName, err := op.createLegacyFindWireMessage(dst, desc)
if err != nil {
return err
}
startedInfo.connID = conn.ID()
op.publishStartedEvent(ctx, startedInfo)
finishedInfo := finishedInformation{
cmdName: startedInfo.cmdName,
requestID: startedInfo.requestID,
startTime: time.Now(),
connID: startedInfo.connID,
}
finishedInfo.response, finishedInfo.cmdErr = op.roundTripLegacyCursor(ctx, wm, srvr, conn, collName, firstBatchIdentifier)
op.publishFinishedEvent(ctx, finishedInfo)
if op.ProcessResponseFn != nil {
return op.ProcessResponseFn(finishedInfo.response, srvr, desc.Server)
}
return nil
}
// returns the wire message, started-event information, the collection name, and an error
func (op Operation) createLegacyFindWireMessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) {
info := startedInformation{
requestID: wiremessage.NextRequestID(),
cmdName: "find",
}
// call CommandFn on an empty slice rather than dst because the options will need to be converted to legacy
var cmdDoc bsoncore.Document
var cmdIndex int32
var err error
cmdIndex, cmdDoc = bsoncore.AppendDocumentStart(cmdDoc)
cmdDoc, err = op.CommandFn(cmdDoc, desc)
if err != nil {
return dst, info, "", err
}
cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIndex)
// for monitoring legacy events, the upconverted document should be captured rather than the legacy one
info.cmd = cmdDoc
cmdElems, err := cmdDoc.Elements()
if err != nil {
return dst, info, "", err
}
// take each option from the non-legacy command and convert it
// build options as a byte slice of elements rather than a bsoncore.Document because they will be appended
// to another document with $query
var optsElems []byte
flags := op.slaveOK(desc)
var numToSkip, numToReturn, batchSize, limit int32 // numToReturn calculated from batchSize and limit
var filter, returnFieldsSelector bsoncore.Document
var collName string
var singleBatch bool
for _, elem := range cmdElems {
switch elem.Key() {
case "find":
collName = elem.Value().StringValue()
case "filter":
filter = elem.Value().Data
case "sort":
optsElems = bsoncore.AppendValueElement(optsElems, "$orderby", elem.Value())
case "hint":
optsElems = bsoncore.AppendValueElement(optsElems, "$hint", elem.Value())
case "comment":
optsElems = bsoncore.AppendValueElement(optsElems, "$comment", elem.Value())
case "maxScan":
optsElems = bsoncore.AppendValueElement(optsElems, "$maxScan", elem.Value())
case "max":
optsElems = bsoncore.AppendValueElement(optsElems, "$max", elem.Value())
case "min":
optsElems = bsoncore.AppendValueElement(optsElems, "$min", elem.Value())
case "returnKey":
optsElems = bsoncore.AppendValueElement(optsElems, "$returnKey", elem.Value())
case "showRecordId":
optsElems = bsoncore.AppendValueElement(optsElems, "$showDiskLoc", elem.Value())
case "maxTimeMS":
optsElems = bsoncore.AppendValueElement(optsElems, "$maxTimeMS", elem.Value())
case "snapshot":
optsElems = bsoncore.AppendValueElement(optsElems, "$snapshot", elem.Value())
case "projection":
returnFieldsSelector = elem.Value().Data
case "skip":
// CRUD spec declares skip as int64 but numToSkip is int32 in OP_QUERY
numToSkip = int32(elem.Value().Int64())
case "batchSize":
batchSize = elem.Value().Int32()
// Not possible to use batchSize = 1 because cursor will be closed on first batch
if batchSize == 1 {
batchSize = 2
}
case "limit":
// CRUD spec declares limit as int64 but numToReturn is int32 in OP_QUERY
limit = int32(elem.Value().Int64())
case "singleBatch":
singleBatch = elem.Value().Boolean()
case "tailable":
flags |= wiremessage.TailableCursor
case "awaitData":
flags |= wiremessage.AwaitData
case "oplogReply":
flags |= wiremessage.OplogReplay
case "noCursorTimeout":
flags |= wiremessage.NoCursorTimeout
case "allowPartialResults":
flags |= wiremessage.Partial
}
}
// for non-legacy servers, a negative limit is implemented as a positive limit + singleBatch = true
if singleBatch {
limit = limit * -1
}
numToReturn = op.calculateNumberToReturn(limit, batchSize)
// add read preference if needed
rp, err := op.createReadPref(desc.Server.Kind, desc.Kind, true)
if err != nil {
return dst, info, "", err
}
if len(rp) > 0 {
optsElems = bsoncore.AppendDocumentElement(optsElems, "$readPreference", rp)
}
if len(filter) == 0 {
var fidx int32
fidx, filter = bsoncore.AppendDocumentStart(filter)
filter, _ = bsoncore.AppendDocumentEnd(filter, fidx)
}
var wmIdx int32
wmIdx, dst = wiremessagex.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
dst = wiremessagex.AppendQueryFlags(dst, flags)
dst = wiremessagex.AppendQueryFullCollectionName(dst, op.getFullCollectionName(collName))
dst = wiremessagex.AppendQueryNumberToSkip(dst, numToSkip)
dst = wiremessagex.AppendQueryNumberToReturn(dst, numToReturn)
dst = op.appendLegacyQueryDocument(dst, filter, optsElems)
if len(returnFieldsSelector) != 0 {
// returnFieldsSelector is optional
dst = append(dst, returnFieldsSelector...)
}
return bsoncore.UpdateLength(dst, wmIdx, int32(len(dst[wmIdx:]))), info, collName, nil
}
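// As a concrete sketch, an upconverted command such as
//   {find: "coll", filter: {x: 1}, sort: {y: 1}, limit: 4, singleBatch: true}
// is sent as an OP_QUERY against "<db>.coll" with numberToReturn = -4 and the query document
//   {$query: {x: 1}, $orderby: {y: 1}}
// (plus $readPreference when one is required).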
func (op Operation) calculateNumberToReturn(limit, batchSize int32) int32 {
var numToReturn int32
if limit < 0 {
numToReturn = limit
} else if limit == 0 {
numToReturn = batchSize
} else if batchSize == 0 {
numToReturn = limit
} else if limit < batchSize {
numToReturn = limit
} else {
numToReturn = batchSize
}
return numToReturn
}
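// A few concrete values under the precedence implemented above:
//   limit=-4, any batchSize  -> numToReturn = -4 (negative limit wins outright)
//   limit=0,  batchSize=10   -> numToReturn = 10 (no limit, use batchSize)
//   limit=7,  batchSize=0    -> numToReturn = 7  (no batchSize, use limit)
//   limit=7,  batchSize=10   -> numToReturn = 7  (smaller of the two)
//   limit=20, batchSize=10   -> numToReturn = 10 (smaller of the two)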
func (op Operation) legacyGetMore(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error {
wm, startedInfo, collName, err := op.createLegacyGetMoreWiremessage(dst, desc)
if err != nil {
return err
}
startedInfo.connID = conn.ID()
op.publishStartedEvent(ctx, startedInfo)
finishedInfo := finishedInformation{
cmdName: startedInfo.cmdName,
requestID: startedInfo.requestID,
startTime: time.Now(),
connID: startedInfo.connID,
}
finishedInfo.response, finishedInfo.cmdErr = op.roundTripLegacyCursor(ctx, wm, srvr, conn, collName, nextBatchIdentifier)
op.publishFinishedEvent(ctx, finishedInfo)
if op.ProcessResponseFn != nil {
return op.ProcessResponseFn(finishedInfo.response, srvr, desc.Server)
}
return nil
}
func (op Operation) createLegacyGetMoreWiremessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) {
info := startedInformation{
requestID: wiremessage.NextRequestID(),
cmdName: "getMore",
}
var cmdDoc bsoncore.Document
var cmdIdx int32
var err error
cmdIdx, cmdDoc = bsoncore.AppendDocumentStart(cmdDoc)
cmdDoc, err = op.CommandFn(cmdDoc, desc)
if err != nil {
return dst, info, "", err
}
cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIdx)
info.cmd = cmdDoc
cmdElems, err := cmdDoc.Elements()
if err != nil {
return dst, info, "", err
}
var cursorID int64
var numToReturn int32
var collName string
for _, elem := range cmdElems {
switch elem.Key() {
case "getMore":
cursorID = elem.Value().Int64()
case "collection":
collName = elem.Value().StringValue()
case "batchSize":
numToReturn = elem.Value().Int32()
}
}
var wmIdx int32
wmIdx, dst = wiremessagex.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpGetMore)
dst = wiremessagex.AppendGetMoreZero(dst)
dst = wiremessagex.AppendGetMoreFullCollectionName(dst, op.getFullCollectionName(collName))
dst = wiremessagex.AppendGetMoreNumberToReturn(dst, numToReturn)
dst = wiremessagex.AppendGetMoreCursorID(dst, cursorID)
return bsoncore.UpdateLength(dst, wmIdx, int32(len(dst[wmIdx:]))), info, collName, nil
}
func (op Operation) legacyKillCursors(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error {
wm, startedInfo, _, err := op.createLegacyKillCursorsWiremessage(dst, desc)
if err != nil {
return err
}
startedInfo.connID = conn.ID()
op.publishStartedEvent(ctx, startedInfo)
// skip startTime because OP_KILL_CURSORS does not return a response
finishedInfo := finishedInformation{
cmdName: "killCursors",
requestID: startedInfo.requestID,
connID: startedInfo.connID,
}
err = conn.WriteWireMessage(ctx, wm)
if err != nil {
err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
if ep, ok := srvr.(ErrorProcessor); ok {
ep.ProcessError(err)
}
finishedInfo.cmdErr = err
op.publishFinishedEvent(ctx, finishedInfo)
return err
}
ridx, response := bsoncore.AppendDocumentStart(nil)
response = bsoncore.AppendInt32Element(response, "ok", 1)
response = bsoncore.AppendArrayElement(response, "cursorsKilled", startedInfo.cmd.Lookup("cursors").Array())
response, _ = bsoncore.AppendDocumentEnd(response, ridx)
finishedInfo.response = response
op.publishFinishedEvent(ctx, finishedInfo)
return nil
}
func (op Operation) createLegacyKillCursorsWiremessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) {
info := startedInformation{
cmdName: "killCursors",
requestID: wiremessage.NextRequestID(),
}
var cmdDoc bsoncore.Document
var cmdIdx int32
var err error
cmdIdx, cmdDoc = bsoncore.AppendDocumentStart(cmdDoc)
cmdDoc, err = op.CommandFn(cmdDoc, desc)
if err != nil {
return nil, info, "", err
}
cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIdx)
info.cmd = cmdDoc
cmdElems, err := cmdDoc.Elements()
if err != nil {
return nil, info, "", err
}
var collName string
var cursors bsoncore.Document
for _, elem := range cmdElems {
switch elem.Key() {
case "killCursors":
collName = elem.Value().StringValue()
case "cursors":
cursors = elem.Value().Array()
}
}
var cursorIDs []int64
if cursors != nil {
cursorValues, err := cursors.Values()
if err != nil {
return nil, info, "", err
}
for _, cursorVal := range cursorValues {
cursorIDs = append(cursorIDs, cursorVal.Int64())
}
}
var wmIdx int32
wmIdx, dst = wiremessagex.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpKillCursors)
dst = wiremessagex.AppendKillCursorsZero(dst)
dst = wiremessagex.AppendKillCursorsNumberIDs(dst, int32(len(cursorIDs)))
dst = wiremessagex.AppendKillCursorsCursorIDs(dst, cursorIDs)
return bsoncore.UpdateLength(dst, wmIdx, int32(len(dst[wmIdx:]))), info, collName, nil
}
func (op Operation) legacyListCollections(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error {
wm, startedInfo, collName, err := op.createLegacyListCollectionsWiremessage(dst, desc)
if err != nil {
return err
}
startedInfo.connID = conn.ID()
op.publishStartedEvent(ctx, startedInfo)
finishedInfo := finishedInformation{
cmdName: startedInfo.cmdName,
requestID: startedInfo.requestID,
startTime: time.Now(),
connID: startedInfo.connID,
}
finishedInfo.response, finishedInfo.cmdErr = op.roundTripLegacyCursor(ctx, wm, srvr, conn, collName, firstBatchIdentifier)
op.publishFinishedEvent(ctx, finishedInfo)
if op.ProcessResponseFn != nil {
return op.ProcessResponseFn(finishedInfo.response, srvr, desc.Server)
}
return nil
}
func (op Operation) createLegacyListCollectionsWiremessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) {
info := startedInformation{
cmdName: "find",
requestID: wiremessage.NextRequestID(),
}
var cmdDoc bsoncore.Document
var cmdIdx int32
var err error
cmdIdx, cmdDoc = bsoncore.AppendDocumentStart(cmdDoc)
if cmdDoc, err = op.CommandFn(cmdDoc, desc); err != nil {
return dst, info, "", err
}
cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIdx)
info.cmd, err = op.convertCommandToFind(cmdDoc, listCollectionsNamespace)
if err != nil {
return nil, info, "", err
}
// lookup filter directly instead of calling cmdDoc.Elements() because the only listCollections option is nameOnly,
// which doesn't apply to legacy servers
var originalFilter bsoncore.Document
if filterVal, err := cmdDoc.LookupErr("filter"); err == nil {
originalFilter = filterVal.Document()
}
var optsElems []byte
filter, err := op.transformListCollectionsFilter(originalFilter)
if err != nil {
return dst, info, "", err
}
rp, err := op.createReadPref(desc.Server.Kind, desc.Kind, true)
if err != nil {
return dst, info, "", err
}
if len(rp) > 0 {
optsElems = bsoncore.AppendDocumentElement(optsElems, "$readPreference", rp)
}
var wmIdx int32
wmIdx, dst = wiremessagex.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
dst = wiremessagex.AppendQueryFlags(dst, op.slaveOK(desc))
dst = wiremessagex.AppendQueryFullCollectionName(dst, op.getFullCollectionName(listCollectionsNamespace))
dst = wiremessagex.AppendQueryNumberToSkip(dst, 0)
dst = wiremessagex.AppendQueryNumberToReturn(dst, 0)
dst = op.appendLegacyQueryDocument(dst, filter, optsElems)
// leave out returnFieldsSelector because it is optional
return bsoncore.UpdateLength(dst, wmIdx, int32(len(dst[wmIdx:]))), info, listCollectionsNamespace, nil
}
func (op Operation) transformListCollectionsFilter(filter bsoncore.Document) (bsoncore.Document, error) {
// filter out results containing $ because those represent indexes
var regexFilter bsoncore.Document
var ridx int32
ridx, regexFilter = bsoncore.AppendDocumentStart(regexFilter)
regexFilter = bsoncore.AppendRegexElement(regexFilter, "name", "^[^$]*$", "")
regexFilter, _ = bsoncore.AppendDocumentEnd(regexFilter, ridx)
if len(filter) == 0 {
return regexFilter, nil
}
convertedIdx, convertedFilter := bsoncore.AppendDocumentStart(nil)
elems, err := filter.Elements()
if err != nil {
return nil, err
}
for _, elem := range elems {
if elem.Key() != "name" {
convertedFilter = append(convertedFilter, elem...)
continue
}
// the name value in a filter for legacy list collections must be a string and has to be prepended
// with the database name
nameVal := elem.Value()
if nameVal.Type != bsontype.String {
return nil, ErrFilterType
}
convertedFilter = bsoncore.AppendStringElement(convertedFilter, "name", op.getFullCollectionName(nameVal.StringValue()))
}
convertedFilter, _ = bsoncore.AppendDocumentEnd(convertedFilter, convertedIdx)
// combine regexFilter and convertedFilter with $and
var combinedFilter bsoncore.Document
var cidx, aidx int32
cidx, combinedFilter = bsoncore.AppendDocumentStart(combinedFilter)
aidx, combinedFilter = bsoncore.AppendArrayElementStart(combinedFilter, "$and")
combinedFilter = bsoncore.AppendDocumentElement(combinedFilter, "0", regexFilter)
combinedFilter = bsoncore.AppendDocumentElement(combinedFilter, "1", convertedFilter)
combinedFilter, _ = bsoncore.AppendArrayEnd(combinedFilter, aidx)
combinedFilter, _ = bsoncore.AppendDocumentEnd(combinedFilter, cidx)
return combinedFilter, nil
}
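// For example (an extended-JSON sketch), with op.Database = "db" an input filter of
// {name: "coll"} becomes
//   {$and: [{name: {$regex: "^[^$]*$"}}, {name: "db.coll"}]}
// while an empty input filter yields only the index-excluding regex document.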
func (op Operation) legacyListIndexes(ctx context.Context, dst []byte, srvr Server, conn Connection, desc description.SelectedServer) error {
wm, startedInfo, collName, err := op.createLegacyListIndexesWiremessage(dst, desc)
if err != nil {
return err
}
startedInfo.connID = conn.ID()
op.publishStartedEvent(ctx, startedInfo)
finishedInfo := finishedInformation{
cmdName: startedInfo.cmdName,
requestID: startedInfo.requestID,
startTime: time.Now(),
connID: startedInfo.connID,
}
finishedInfo.response, finishedInfo.cmdErr = op.roundTripLegacyCursor(ctx, wm, srvr, conn, collName, firstBatchIdentifier)
op.publishFinishedEvent(ctx, finishedInfo)
if op.ProcessResponseFn != nil {
return op.ProcessResponseFn(finishedInfo.response, srvr, desc.Server)
}
return nil
}
func (op Operation) createLegacyListIndexesWiremessage(dst []byte, desc description.SelectedServer) ([]byte, startedInformation, string, error) {
info := startedInformation{
cmdName: "find",
requestID: wiremessage.NextRequestID(),
}
var cmdDoc bsoncore.Document
var cmdIndex int32
var err error
cmdIndex, cmdDoc = bsoncore.AppendDocumentStart(cmdDoc)
cmdDoc, err = op.CommandFn(cmdDoc, desc)
if err != nil {
return dst, info, "", err
}
cmdDoc, _ = bsoncore.AppendDocumentEnd(cmdDoc, cmdIndex)
info.cmd, err = op.convertCommandToFind(cmdDoc, listIndexesNamespace)
if err != nil {
return nil, info, "", err
}
cmdElems, err := cmdDoc.Elements()
if err != nil {
return nil, info, "", err
}
var filterCollName string
var batchSize int32
var optsElems []byte // options elements
for _, elem := range cmdElems {
switch elem.Key() {
case "listIndexes":
filterCollName = elem.Value().StringValue()
case "batchSize":
batchSize = elem.Value().Int32()
case "maxTimeMS":
optsElems = bsoncore.AppendValueElement(optsElems, "$maxTimeMS", elem.Value())
}
}
// always filter with {ns: db.collection}
fidx, filter := bsoncore.AppendDocumentStart(nil)
filter = bsoncore.AppendStringElement(filter, "ns", op.getFullCollectionName(filterCollName))
filter, _ = bsoncore.AppendDocumentEnd(filter, fidx)
rp, err := op.createReadPref(desc.Server.Kind, desc.Kind, true)
if err != nil {
return dst, info, "", err
}
if len(rp) > 0 {
optsElems = bsoncore.AppendDocumentElement(optsElems, "$readPreference", rp)
}
var wmIdx int32
wmIdx, dst = wiremessagex.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
dst = wiremessagex.AppendQueryFlags(dst, op.slaveOK(desc))
dst = wiremessagex.AppendQueryFullCollectionName(dst, op.getFullCollectionName(listIndexesNamespace))
dst = wiremessagex.AppendQueryNumberToSkip(dst, 0)
dst = wiremessagex.AppendQueryNumberToReturn(dst, batchSize)
dst = op.appendLegacyQueryDocument(dst, filter, optsElems)
// leave out returnFieldsSelector because it is optional
return bsoncore.UpdateLength(dst, wmIdx, int32(len(dst[wmIdx:]))), info, listIndexesNamespace, nil
}
// convertCommandToFind takes a non-legacy command document for a command that needs to be run as a find on legacy
// servers and converts it to a find command document for APM.
func (op Operation) convertCommandToFind(cmdDoc bsoncore.Document, collName string) (bsoncore.Document, error) {
cidx, converted := bsoncore.AppendDocumentStart(nil)
elems, err := cmdDoc.Elements()
if err != nil {
return nil, err
}
converted = bsoncore.AppendStringElement(converted, "find", collName)
// skip the first element because that will have the old command name
for i := 1; i < len(elems); i++ {
converted = bsoncore.AppendValueElement(converted, elems[i].Key(), elems[i].Value())
}
converted, _ = bsoncore.AppendDocumentEnd(converted, cidx)
return converted, nil
}
// appendLegacyQueryDocument takes a filter and a list of options elements for a legacy find operation, creates
// a query document, and appends it to dst.
func (op Operation) appendLegacyQueryDocument(dst []byte, filter bsoncore.Document, opts []byte) []byte {
if len(opts) == 0 {
dst = append(dst, filter...)
return dst
}
// filter must be wrapped in $query if other $-modifiers are used
var qidx int32
qidx, dst = bsoncore.AppendDocumentStart(dst)
dst = bsoncore.AppendDocumentElement(dst, "$query", filter)
dst = append(dst, opts...)
dst, _ = bsoncore.AppendDocumentEnd(dst, qidx)
return dst
}
// roundTripLegacyCursor sends a wiremessage for an operation expecting a cursor result and converts the legacy
// document sequence into a cursor document.
func (op Operation) roundTripLegacyCursor(ctx context.Context, wm []byte, srvr Server, conn Connection, collName, identifier string) (bsoncore.Document, error) {
wm, err := op.roundTripLegacy(ctx, conn, wm)
if ep, ok := srvr.(ErrorProcessor); ok {
ep.ProcessError(err)
}
if err != nil {
return nil, err
}
return op.upconvertCursorResponse(wm, identifier, collName)
}
// roundTripLegacy handles writing a wire message and reading the response.
func (op Operation) roundTripLegacy(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
err := conn.WriteWireMessage(ctx, wm)
if err != nil {
return nil, Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
wm, err = conn.ReadWireMessage(ctx, wm[:0])
if err != nil {
err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}}
}
return wm, err
}
func (op Operation) upconvertCursorResponse(wm []byte, batchIdentifier string, collName string) (bsoncore.Document, error) {
reply := op.decodeOpReply(wm, true)
if reply.err != nil {
return nil, reply.err
}
cursorIdx, cursorDoc := bsoncore.AppendDocumentStart(nil)
// convert reply documents to BSON array
var arrIdx int32
arrIdx, cursorDoc = bsoncore.AppendArrayElementStart(cursorDoc, batchIdentifier)
for i, doc := range reply.documents {
cursorDoc = bsoncore.AppendDocumentElement(cursorDoc, strconv.Itoa(i), doc)
}
cursorDoc, _ = bsoncore.AppendArrayEnd(cursorDoc, arrIdx)
cursorDoc = bsoncore.AppendInt64Element(cursorDoc, "id", reply.cursorID)
cursorDoc = bsoncore.AppendStringElement(cursorDoc, "ns", op.getFullCollectionName(collName))
cursorDoc, _ = bsoncore.AppendDocumentEnd(cursorDoc, cursorIdx)
resIdx, resDoc := bsoncore.AppendDocumentStart(nil)
resDoc = bsoncore.AppendInt32Element(resDoc, "ok", 1)
resDoc = bsoncore.AppendDocumentElement(resDoc, "cursor", cursorDoc)
resDoc, _ = bsoncore.AppendDocumentEnd(resDoc, resIdx)
return resDoc, nil
}
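// The upconverted result mirrors a modern cursor response; for a legacy find on "db.coll"
// that returned two documents with cursor ID 42, the output looks like (sketch only)
//   {ok: 1, cursor: {firstBatch: [<doc0>, <doc1>], id: 42, ns: "db.coll"}}.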

View File

@@ -0,0 +1,391 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session"
import (
"errors"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
)
// ErrSessionEnded is returned when a client session is used after a call to endSession().
var ErrSessionEnded = errors.New("ended session was used")
// ErrNoTransactStarted is returned if a transaction operation is called when no transaction has started.
var ErrNoTransactStarted = errors.New("no transaction started")
// ErrTransactInProgress is returned if startTransaction() is called when a transaction is in progress.
var ErrTransactInProgress = errors.New("transaction already in progress")
// ErrAbortAfterCommit is returned when abort is called after a commit.
var ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")
// ErrAbortTwice is returned if abort is called after transaction is already aborted.
var ErrAbortTwice = errors.New("cannot call abortTransaction twice")
// ErrCommitAfterAbort is returned if commit is called after an abort.
var ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")
// ErrUnackWCUnsupported is returned if an unacknowledged write concern is used for a transaction.
var ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")
// Type describes the type of the session
type Type uint8
// These constants are the valid types for a client session.
const (
Explicit Type = iota
Implicit
)
// State indicates the state of the FSM.
type state uint8
// Client Session states
const (
None state = iota
Starting
InProgress
Committed
Aborted
)
// Client is a session for clients to run commands.
type Client struct {
*Server
ClientID uuid.UUID
ClusterTime bson.Raw
Consistent bool // causal consistency
OperationTime *primitive.Timestamp
SessionType Type
Terminated bool
RetryingCommit bool
Committing bool
Aborting bool
RetryWrite bool
// options for the current transaction
// most recently set by transactionopt
CurrentRc *readconcern.ReadConcern
CurrentRp *readpref.ReadPref
CurrentWc *writeconcern.WriteConcern
// default transaction options
transactionRc *readconcern.ReadConcern
transactionRp *readpref.ReadPref
transactionWc *writeconcern.WriteConcern
pool *Pool
state state
PinnedServer *description.Server
RecoveryToken bson.Raw
}
func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
if clusterTime == nil {
return 0, 0
}
clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
if err != nil {
return 0, 0
}
timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
if err != nil {
return 0, 0
}
return timestampVal.Timestamp()
}
// MaxClusterTime compares 2 clusterTime documents and returns the document representing the highest cluster time.
func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
epoch1, ord1 := getClusterTime(ct1)
epoch2, ord2 := getClusterTime(ct2)
if epoch1 > epoch2 {
return ct1
} else if epoch1 < epoch2 {
return ct2
} else if ord1 > ord2 {
return ct1
} else if ord1 < ord2 {
return ct2
}
return ct1
}
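// For instance, ct1 carrying {$clusterTime: {clusterTime: Timestamp(5, 1)}} loses to ct2
// carrying Timestamp(5, 3): the epochs tie and ct2 wins on the ordinal. A document without
// $clusterTime parses as (0, 0), so any document with a real timestamp takes precedence.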
// NewClientSession creates a Client.
func NewClientSession(pool *Pool, clientID uuid.UUID, sessionType Type, opts ...*ClientOptions) (*Client, error) {
c := &Client{
Consistent: true, // set default
ClientID: clientID,
SessionType: sessionType,
pool: pool,
}
mergedOpts := mergeClientOptions(opts...)
if mergedOpts.CausalConsistency != nil {
c.Consistent = *mergedOpts.CausalConsistency
}
if mergedOpts.DefaultReadPreference != nil {
c.transactionRp = mergedOpts.DefaultReadPreference
}
if mergedOpts.DefaultReadConcern != nil {
c.transactionRc = mergedOpts.DefaultReadConcern
}
if mergedOpts.DefaultWriteConcern != nil {
c.transactionWc = mergedOpts.DefaultWriteConcern
}
servSess, err := pool.GetSession()
if err != nil {
return nil, err
}
c.Server = servSess
return c, nil
}
// AdvanceClusterTime updates the session's cluster time.
func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
if c.Terminated {
return ErrSessionEnded
}
c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
return nil
}
// AdvanceOperationTime updates the session's operation time.
func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error {
if c.Terminated {
return ErrSessionEnded
}
if c.OperationTime == nil {
c.OperationTime = opTime
return nil
}
if opTime.T > c.OperationTime.T {
c.OperationTime = opTime
} else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
c.OperationTime = opTime
}
return nil
}
// UpdateUseTime updates the session's last used time.
// Must be called whenever this session is used to send a command to the server.
func (c *Client) UpdateUseTime() error {
if c.Terminated {
return ErrSessionEnded
}
c.updateUseTime()
return nil
}
// UpdateRecoveryToken updates the session's recovery token from the server response.
func (c *Client) UpdateRecoveryToken(response bson.Raw) {
if c == nil {
return
}
token, err := response.LookupErr("recoveryToken")
if err != nil {
return
}
c.RecoveryToken = token.Document()
}
// ClearPinnedServer sets the PinnedServer to nil.
func (c *Client) ClearPinnedServer() {
if c != nil {
c.PinnedServer = nil
}
}
// EndSession ends the session.
func (c *Client) EndSession() {
if c.Terminated {
return
}
c.Terminated = true
c.pool.ReturnSession(c.Server)
return
}
// TransactionInProgress returns true if the client session is in an active transaction.
func (c *Client) TransactionInProgress() bool {
return c.state == InProgress
}
// TransactionStarting returns true if the client session is starting a transaction.
func (c *Client) TransactionStarting() bool {
return c.state == Starting
}
// TransactionRunning returns true if the client session has started the transaction
// and it hasn't been committed or aborted.
func (c *Client) TransactionRunning() bool {
return c != nil && (c.state == Starting || c.state == InProgress)
}
// TransactionCommitted returns true if the client session just committed a transaction.
func (c *Client) TransactionCommitted() bool {
return c.state == Committed
}
// CheckStartTransaction checks to see if allowed to start transaction and returns
// an error if not allowed.
func (c *Client) CheckStartTransaction() error {
if c.state == InProgress || c.state == Starting {
return ErrTransactInProgress
}
return nil
}
// StartTransaction initializes the transaction options and advances the state machine.
// It does not contact the server to start the transaction.
func (c *Client) StartTransaction(opts *TransactionOptions) error {
err := c.CheckStartTransaction()
if err != nil {
return err
}
c.IncrementTxnNumber()
c.RetryingCommit = false
if opts != nil {
c.CurrentRc = opts.ReadConcern
c.CurrentRp = opts.ReadPreference
c.CurrentWc = opts.WriteConcern
}
if c.CurrentRc == nil {
c.CurrentRc = c.transactionRc
}
if c.CurrentRp == nil {
c.CurrentRp = c.transactionRp
}
if c.CurrentWc == nil {
c.CurrentWc = c.transactionWc
}
if !writeconcern.AckWrite(c.CurrentWc) {
c.clearTransactionOpts()
return ErrUnackWCUnsupported
}
c.state = Starting
c.PinnedServer = nil
return nil
}
// CheckCommitTransaction checks to see if allowed to commit transaction and returns
// an error if not allowed.
func (c *Client) CheckCommitTransaction() error {
if c.state == None {
return ErrNoTransactStarted
} else if c.state == Aborted {
return ErrCommitAfterAbort
}
return nil
}
// CommitTransaction updates the state for a successfully committed transaction and returns
// an error if not permissible. It does not actually perform the commit.
func (c *Client) CommitTransaction() error {
err := c.CheckCommitTransaction()
if err != nil {
return err
}
c.state = Committed
return nil
}
// UpdateCommitTransactionWriteConcern will set the write concern to majority and potentially set a
// w timeout of 10 seconds. This should be called after a commit transaction operation fails with a
// retryable error or after a successful commit transaction operation.
func (c *Client) UpdateCommitTransactionWriteConcern() {
wc := c.CurrentWc
timeout := 10 * time.Second
if wc != nil && wc.GetWTimeout() != 0 {
timeout = wc.GetWTimeout()
}
c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout))
}
// CheckAbortTransaction checks to see if allowed to abort transaction and returns
// an error if not allowed.
func (c *Client) CheckAbortTransaction() error {
if c.state == None {
return ErrNoTransactStarted
} else if c.state == Committed {
return ErrAbortAfterCommit
} else if c.state == Aborted {
return ErrAbortTwice
}
return nil
}
// AbortTransaction updates the state for a successfully aborted transaction and returns
// an error if not permissible. It does not actually perform the abort.
func (c *Client) AbortTransaction() error {
err := c.CheckAbortTransaction()
if err != nil {
return err
}
c.state = Aborted
c.clearTransactionOpts()
return nil
}
// ApplyCommand advances the state machine upon command execution.
func (c *Client) ApplyCommand(desc description.Server) {
if c.Committing {
// Do not change state if committing after already committed
return
}
if c.state == Starting {
c.state = InProgress
// If this is in a transaction and the server is a mongos, pin it
if desc.Kind == description.Mongos {
c.PinnedServer = &desc
}
} else if c.state == Committed || c.state == Aborted {
c.clearTransactionOpts()
c.state = None
}
}
func (c *Client) clearTransactionOpts() {
c.RetryingCommit = false
c.Aborting = false
c.Committing = false
c.CurrentWc = nil
c.CurrentRp = nil
c.CurrentRc = nil
c.PinnedServer = nil
c.RecoveryToken = nil
}
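The sketch below (not part of the vendored file) walks the state machine above for a committed transaction. The client value is assumed to come from NewClientSession with a valid server session pool, and mongosDesc is assumed to describe a mongos so the pinning branch is exercised; both names are placeholders.
package example

import (
    "go.mongodb.org/mongo-driver/x/mongo/driver/description"
    "go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

// walkTransactionStates exercises the bookkeeping only; the real commands are issued by
// the operation layer, not by these calls.
func walkTransactionStates(client *session.Client, mongosDesc description.Server) error {
    // None -> Starting: records transaction options and bumps the transaction number.
    if err := client.StartTransaction(nil); err != nil {
        return err
    }
    // Starting -> InProgress: the first command in the transaction pins the mongos.
    client.ApplyCommand(mongosDesc)
    // InProgress -> Committed: state change only; no commitTransaction command is sent here.
    if err := client.CommitTransaction(); err != nil {
        return err
    }
    // Committed -> None: the next command clears the per-transaction options.
    client.ApplyCommand(mongosDesc)
    return nil
}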

View File

@@ -0,0 +1,36 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"sync"
"go.mongodb.org/mongo-driver/bson"
)
// ClusterClock represents a logical clock for keeping track of cluster time.
type ClusterClock struct {
clusterTime bson.Raw
lock sync.Mutex
}
// GetClusterTime returns the cluster's current time.
func (cc *ClusterClock) GetClusterTime() bson.Raw {
var ct bson.Raw
cc.lock.Lock()
ct = cc.clusterTime
cc.lock.Unlock()
return ct
}
// AdvanceClusterTime updates the cluster's current time.
func (cc *ClusterClock) AdvanceClusterTime(clusterTime bson.Raw) {
cc.lock.Lock()
cc.clusterTime = MaxClusterTime(cc.clusterTime, clusterTime)
cc.lock.Unlock()
}

View File

@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
)
// ClientOptions represents all possible options for creating a client session.
type ClientOptions struct {
CausalConsistency *bool
DefaultReadConcern *readconcern.ReadConcern
DefaultWriteConcern *writeconcern.WriteConcern
DefaultReadPreference *readpref.ReadPref
}
// TransactionOptions represents all possible options for starting a transaction in a session.
type TransactionOptions struct {
ReadConcern *readconcern.ReadConcern
WriteConcern *writeconcern.WriteConcern
ReadPreference *readpref.ReadPref
}
func mergeClientOptions(opts ...*ClientOptions) *ClientOptions {
c := &ClientOptions{}
for _, opt := range opts {
if opt == nil {
continue
}
if opt.CausalConsistency != nil {
c.CausalConsistency = opt.CausalConsistency
}
if opt.DefaultReadConcern != nil {
c.DefaultReadConcern = opt.DefaultReadConcern
}
if opt.DefaultReadPreference != nil {
c.DefaultReadPreference = opt.DefaultReadPreference
}
if opt.DefaultWriteConcern != nil {
c.DefaultWriteConcern = opt.DefaultWriteConcern
}
}
return c
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"time"
"crypto/rand"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/uuid"
)
var rander = rand.Reader
// Server is an open session with the server.
type Server struct {
SessionID bsonx.Doc
TxnNumber int64
LastUsed time.Time
}
// expired returns whether a session has expired given a timeout in minutes.
// A session is considered expired if it has less than 1 minute left before becoming stale.
func (ss *Server) expired(timeoutMinutes uint32) bool {
if timeoutMinutes <= 0 {
return true
}
timeUnused := time.Since(ss.LastUsed).Minutes()
return timeUnused > float64(timeoutMinutes-1)
}
// update the last used time for this session.
// must be called whenever this server session is used to send a command to the server.
func (ss *Server) updateUseTime() {
ss.LastUsed = time.Now()
}
func newServerSession() (*Server, error) {
id, err := uuid.New()
if err != nil {
return nil, err
}
idDoc := bsonx.Doc{{"id", bsonx.Binary(UUIDSubtype, id[:])}}
return &Server{
SessionID: idDoc,
LastUsed: time.Now(),
}, nil
}
// IncrementTxnNumber increments the transaction number.
func (ss *Server) IncrementTxnNumber() {
ss.TxnNumber++
}
// UUIDSubtype is the BSON binary subtype that a UUID should be encoded as
const UUIDSubtype byte = 4

View File

@@ -0,0 +1,175 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package session
import (
"sync"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// Node represents a server session in a linked list
type Node struct {
*Server
next *Node
prev *Node
}
// Pool is a pool of server sessions that can be reused.
type Pool struct {
descChan <-chan description.Topology
head *Node
tail *Node
timeout uint32
mutex sync.Mutex // mutex to protect list and sessionTimeout
checkedOut int // number of sessions checked out of pool
}
func (p *Pool) createServerSession() (*Server, error) {
s, err := newServerSession()
if err != nil {
return nil, err
}
p.checkedOut++
return s, nil
}
// NewPool creates a new server session pool
func NewPool(descChan <-chan description.Topology) *Pool {
p := &Pool{
descChan: descChan,
}
return p
}
// assumes caller has mutex to protect the pool
func (p *Pool) updateTimeout() {
select {
case newDesc := <-p.descChan:
p.timeout = newDesc.SessionTimeoutMinutes
default:
// no new description waiting
}
}
// GetSession retrieves an unexpired session from the pool.
func (p *Pool) GetSession() (*Server, error) {
p.mutex.Lock() // prevent changing the linked list while seeing if sessions have expired
defer p.mutex.Unlock()
// empty pool
if p.head == nil && p.tail == nil {
return p.createServerSession()
}
p.updateTimeout()
for p.head != nil {
// pull session from head of queue and return if it is valid for at least 1 more minute
if p.head.expired(p.timeout) {
p.head = p.head.next
continue
}
// found unexpired session
session := p.head.Server
if p.head.next != nil {
p.head.next.prev = nil
}
if p.tail == p.head {
p.tail = nil
p.head = nil
} else {
p.head = p.head.next
}
p.checkedOut++
return session, nil
}
// no valid session found
p.tail = nil // empty list
return p.createServerSession()
}
// ReturnSession returns a session to the pool if it has not expired.
func (p *Pool) ReturnSession(ss *Server) {
if ss == nil {
return
}
p.mutex.Lock()
defer p.mutex.Unlock()
p.checkedOut--
p.updateTimeout()
// check sessions at end of queue for expired
// stop checking after hitting the first valid session
for p.tail != nil && p.tail.expired(p.timeout) {
if p.tail.prev != nil {
p.tail.prev.next = nil
}
p.tail = p.tail.prev
}
// session expired
if ss.expired(p.timeout) {
return
}
newNode := &Node{
Server: ss,
next: nil,
prev: nil,
}
// empty list
if p.tail == nil {
p.head = newNode
p.tail = newNode
return
}
// at least 1 valid session in list
newNode.next = p.head
p.head.prev = newNode
p.head = newNode
}
// IDSlice returns a slice of session IDs for each session in the pool
func (p *Pool) IDSlice() []bsonx.Doc {
p.mutex.Lock()
defer p.mutex.Unlock()
ids := []bsonx.Doc{}
for node := p.head; node != nil; node = node.next {
ids = append(ids, node.SessionID)
}
return ids
}
// String implements the Stringer interface
func (p *Pool) String() string {
p.mutex.Lock()
defer p.mutex.Unlock()
s := ""
for head := p.head; head != nil; head = head.next {
s += head.SessionID.String() + "\n"
}
return s
}
// CheckedOut returns number of sessions checked out from pool.
func (p *Pool) CheckedOut() int {
return p.checkedOut
}

View File

@@ -0,0 +1,40 @@
# Topology Package Design
This document outlines the design for this package.
## Topology
The `Topology` type handles monitoring the state of a MongoDB deployment and selecting servers.
Updating the description is handled by a finite state machine that implements the server discovery
and monitoring specification. A `Topology` can be connected and fully disconnected, which enables
it to release resources when not in use. The `Topology` type also handles server selection following
the server selection specification.
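The sketch below is illustrative only; the names `miniTopology`, `ServerDesc`, `Apply`, and `Select` are invented and are not part of this package. It shows the core idea: each monitored server's description is folded into a guarded snapshot, and selection reads from that snapshot.
```go
package main

import (
	"fmt"
	"sync"
)

// ServerDesc stands in for a monitored server's description.
type ServerDesc struct {
	Addr string
	Kind string
}

// miniTopology keeps a guarded snapshot of the deployment. Apply plays the
// role of the state machine step that folds in a new server description.
type miniTopology struct {
	mu      sync.RWMutex
	servers map[string]ServerDesc
}

func newMiniTopology() *miniTopology {
	return &miniTopology{servers: make(map[string]ServerDesc)}
}

// Apply records the latest description for a server, replacing any older one.
func (t *miniTopology) Apply(desc ServerDesc) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.servers[desc.Addr] = desc
}

// Select returns the addresses whose kind matches; a stand-in for server selection.
func (t *miniTopology) Select(kind string) []string {
	t.mu.RLock()
	defer t.mu.RUnlock()
	var out []string
	for addr, d := range t.servers {
		if d.Kind == kind {
			out = append(out, addr)
		}
	}
	return out
}

func main() {
	topo := newMiniTopology()
	topo.Apply(ServerDesc{Addr: "localhost:27017", Kind: "RSPrimary"})
	topo.Apply(ServerDesc{Addr: "localhost:27018", Kind: "RSSecondary"})
	fmt.Println(topo.Select("RSPrimary")) // [localhost:27017]
}
```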
## Server
The `Server` type handles heartbeating a MongoDB server and holds a pool of connections.
## Connection
Connections are handled by two main types and an auxiliary type. The two main types are `connection`
and `Connection`. The first holds most of the logic required to actually read and write wire
messages. Instances can be created with the `newConnection` method. Inside the `newConnection`
method, the auxiliary type `initConnection` is used to perform the connection handshake. This is
required because the `connection` type does not fully implement `driver.Connection`, which is
required during handshaking. The `Connection` type is what is actually returned to a consumer of the
`topology` package. This type does implement the `driver.Connection` interface, holds a reference to a
`connection` instance, and exists mainly to prevent accidental continued usage of a connection after
closing it.
The connection implementations in this package are conduits for wire messages but they have no
ability to encode, decode, or validate wire messages. That must be handled by consumers.
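A minimal sketch of that wrapper pattern follows. The `rawConn` and `Conn` names are invented for illustration and are not this package's types: the public wrapper nils out its inner pointer on `Close`, so later calls fail instead of touching a connection that may have been recycled.
```go
package main

import (
	"errors"
	"fmt"
)

// rawConn stands in for the internal type that actually owns the socket.
type rawConn struct{ id string }

func (c *rawConn) write(msg []byte) error {
	fmt.Printf("conn %s writes %d bytes\n", c.id, len(msg))
	return nil
}

var errClosed = errors.New("connection is closed")

// Conn is the public wrapper. Once Close is called the inner pointer is nil,
// so any further use returns an error instead of silently reusing the socket.
type Conn struct {
	inner *rawConn
}

func (c *Conn) Write(msg []byte) error {
	if c.inner == nil {
		return errClosed
	}
	return c.inner.write(msg)
}

func (c *Conn) Close() error {
	// In the real package the inner connection would be returned to a pool here.
	c.inner = nil
	return nil
}

func main() {
	c := &Conn{inner: &rawConn{id: "demo"}}
	_ = c.Write([]byte("hello"))
	_ = c.Close()
	fmt.Println(c.Write([]byte("again"))) // connection is closed
}
```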
## Pool
The `pool` type implements a connection pool. It handles caching idle connections and dialing
new ones, but it does not track a maximum number of connections. That is the responsibility of a
wrapping type, like `Server`.
The `pool` type has no concept of closing; instead it has concepts of connecting and disconnecting.
This allows a `Topology` to be disconnected but keep its memory around so it can be reconnected later.
There is a `close` method, but it is used to close an individual connection, not the pool.
There are three methods related to getting and putting connections: `get`, `close`, and `put`. The
`get` method will either retrieve a connection from the cache or dial a new `connection`. The
`close` method will close the underlying socket of a `connection`. The `put` method will return a
connection to the pool, placing it in the cache if there is space; otherwise it will close it.
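The sketch below illustrates that division of responsibility with invented names (`demoPool`, `demoConn`); it is not the package's implementation and omits the locking and generation tracking the real pool needs.
```go
package main

import (
	"errors"
	"fmt"
)

// demoConn stands in for the package's connection type.
type demoConn struct{ id int }

// demoPool illustrates the get/put/close split: get hands out an idle
// connection or "dials" a new one, put caches a connection if there is
// room, and close discards it entirely.
type demoPool struct {
	idle   []*demoConn
	max    int
	nextID int
	open   bool
}

var errPoolClosed = errors.New("pool is disconnected")

func (p *demoPool) get() (*demoConn, error) {
	if !p.open {
		return nil, errPoolClosed
	}
	if n := len(p.idle); n > 0 {
		c := p.idle[n-1]
		p.idle = p.idle[:n-1]
		return c, nil
	}
	p.nextID++
	return &demoConn{id: p.nextID}, nil // "dial" a new connection
}

func (p *demoPool) put(c *demoConn) {
	if !p.open || len(p.idle) >= p.max {
		p.close(c) // no space (or disconnected): close instead of caching
		return
	}
	p.idle = append(p.idle, c)
}

func (p *demoPool) close(c *demoConn) {
	fmt.Printf("closing connection %d\n", c.id)
}

func main() {
	p := &demoPool{max: 1, open: true}
	a, _ := p.get()
	b, _ := p.get()
	p.put(a) // cached for reuse
	p.put(b) // cache already full: closed
	c, _ := p.get()
	fmt.Println(c == a) // true: the cached connection was reused
}
```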

View File

@@ -0,0 +1,458 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"compress/zlib"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"strings"
"github.com/golang/snappy"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
"go.mongodb.org/mongo-driver/x/network/command"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
var globalConnectionID uint64
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
type connection struct {
id string
nc net.Conn // When nil, the connection is closed.
addr address.Address
idleTimeout time.Duration
idleDeadline time.Time
lifetimeDeadline time.Time
readTimeout time.Duration
writeTimeout time.Duration
desc description.Server
compressor wiremessage.CompressorID
zliblevel int
// pool related fields
pool *pool
poolID uint64
generation uint64
}
// newConnection handles the creation of a connection. It will dial, configure TLS, and perform
// initialization handshakes.
func newConnection(ctx context.Context, addr address.Address, opts ...ConnectionOption) (*connection, error) {
cfg, err := newConnectionConfig(opts...)
if err != nil {
return nil, err
}
nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String())
if err != nil {
return nil, ConnectionError{Wrapped: err, init: true}
}
if cfg.tlsConfig != nil {
tlsConfig := cfg.tlsConfig.Clone()
nc, err = configureTLS(ctx, nc, addr, tlsConfig)
if err != nil {
return nil, ConnectionError{Wrapped: err, init: true}
}
}
var lifetimeDeadline time.Time
if cfg.lifeTimeout > 0 {
lifetimeDeadline = time.Now().Add(cfg.lifeTimeout)
}
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
c := &connection{
id: id,
nc: nc,
addr: addr,
idleTimeout: cfg.idleTimeout,
lifetimeDeadline: lifetimeDeadline,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
}
c.bumpIdleDeadline()
// running isMaster and authentication is handled by a handshaker on the configuration instance.
if cfg.handshaker != nil {
c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c})
if err != nil {
if c.nc != nil {
_ = c.nc.Close()
}
return nil, ConnectionError{Wrapped: err, init: true}
}
if cfg.descCallback != nil {
cfg.descCallback(c.desc)
}
if len(c.desc.Compression) > 0 {
clientMethodLoop:
for _, method := range cfg.compressors {
for _, serverMethod := range c.desc.Compression {
if method != serverMethod {
continue
}
switch strings.ToLower(method) {
case "snappy":
c.compressor = wiremessage.CompressorSnappy
case "zlib":
c.compressor = wiremessage.CompressorZLib
c.zliblevel = wiremessage.DefaultZlibLevel
if cfg.zlibLevel != nil {
c.zliblevel = *cfg.zlibLevel
}
}
break clientMethodLoop
}
}
}
}
return c, nil
}
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if c.nc == nil {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
default:
}
var deadline time.Time
if c.writeTimeout != 0 {
deadline = time.Now().Add(c.writeTimeout)
}
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
deadline = dl
}
if err := c.nc.SetWriteDeadline(deadline); err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
}
_, err = c.nc.Write(wm)
if err != nil {
c.close()
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to write wire message to network"}
}
c.bumpIdleDeadline()
return nil
}
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
if c.nc == nil {
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
// We close the connection because we don't know if there is an unread message on the wire.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
default:
}
var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
}
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
deadline = dl
}
if err := c.nc.SetReadDeadline(deadline); err != nil {
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
}
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to decode message length"}
}
// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
if int(size) > cap(dst) {
// Since we can't grow this slice without allocating, just allocate an entirely new slice.
dst = make([]byte, 0, size)
}
// We need to ensure we don't accidentally read into a subsequent wire message, so we set the
// size to read exactly this wire message.
dst = dst[:size]
copy(dst, sizeBuf[:])
_, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to read full message"}
}
c.bumpIdleDeadline()
return dst, nil
}
func (c *connection) close() error {
if c.nc == nil {
return nil
}
if c.pool == nil {
err := c.nc.Close()
c.nc = nil
return err
}
return c.pool.close(c)
}
func (c *connection) expired() bool {
now := time.Now()
if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
return true
}
if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) {
return true
}
return c.nc == nil
}
func (c *connection) bumpIdleDeadline() {
if c.idleTimeout > 0 {
c.idleDeadline = time.Now().Add(c.idleTimeout)
}
}
// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
type initConnection struct{ *connection }
var _ driver.Connection = initConnection{}
func (c initConnection) Description() description.Server { return description.Server{} }
func (c initConnection) Close() error { return nil }
func (c initConnection) ID() string { return c.id }
func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
return c.writeWireMessage(ctx, wm)
}
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
return c.readWireMessage(ctx, dst)
}
// Connection implements the driver.Connection interface. It allows reading and writing wire
// messages.
type Connection struct {
*connection
s *Server
mu sync.RWMutex
}
var _ driver.Connection = (*Connection)(nil)
// WriteWireMessage handles writing a wire message to the underlying connection.
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return ErrConnectionClosed
}
return c.writeWireMessage(ctx, wm)
}
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
// will be overwritten with the new wire message.
func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
return c.readWireMessage(ctx, dst)
}
// CompressWireMessage handles compressing the provided wire message using the underlying
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
// there is no compressor set on the underlying connection, then no compression will be performed.
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
if c.connection.compressor == wiremessage.CompressorNoOp {
return append(dst, src...), nil
}
_, reqid, respto, origcode, rem, ok := wiremessagex.ReadHeader(src)
if !ok {
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
}
idx, dst := wiremessagex.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
dst = wiremessagex.AppendCompressedOriginalOpCode(dst, origcode)
dst = wiremessagex.AppendCompressedUncompressedSize(dst, int32(len(rem)))
dst = wiremessagex.AppendCompressedCompressorID(dst, c.connection.compressor)
switch c.connection.compressor {
case wiremessage.CompressorSnappy:
compressed := snappy.Encode(nil, rem)
dst = wiremessagex.AppendCompressedCompressedMessage(dst, compressed)
case wiremessage.CompressorZLib:
var b bytes.Buffer
w, err := zlib.NewWriterLevel(&b, c.connection.zliblevel)
if err != nil {
return dst, err
}
_, err = w.Write(rem)
if err != nil {
return dst, err
}
err = w.Close()
if err != nil {
return dst, err
}
dst = wiremessagex.AppendCompressedCompressedMessage(dst, b.Bytes())
default:
return dst, fmt.Errorf("unknown compressor ID %v", c.connection.compressor)
}
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
}
// Description returns the server description of the server this connection is connected to.
func (c *Connection) Description() description.Server {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return description.Server{}
}
return c.desc
}
// Close returns this connection to the connection pool. This method may not close the underlying
// socket.
func (c *Connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return nil
}
if c.s != nil {
c.s.sem.Release(1)
}
err := c.pool.put(c.connection)
if err != nil {
return err
}
c.connection = nil
return nil
}
// ID returns the ID of this connection.
func (c *Connection) ID() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.id
}
// Address returns the address of this connection.
func (c *Connection) Address() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return address.Address("0.0.0.0")
}
return c.addr
}
var notMasterCodes = []int32{10107, 13435}
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
func isRecoveringError(err command.Error) bool {
for _, c := range recoveringCodes {
if c == err.Code {
return true
}
}
return strings.Contains(err.Error(), "node is recovering")
}
func isNotMasterError(err command.Error) bool {
for _, c := range notMasterCodes {
if c == err.Code {
return true
}
}
return strings.Contains(err.Error(), "not master")
}
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config) (net.Conn, error) {
if !config.InsecureSkipVerify {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}
hostname = hostname[:colonPos]
config.ServerName = hostname
}
client := tls.Client(nc, config)
errChan := make(chan error, 1)
go func() {
errChan <- client.Handshake()
}()
select {
case err := <-errChan:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, errors.New("server connection cancelled/timeout during TLS handshake")
}
return client, nil
}

View File

@@ -0,0 +1,615 @@
package topology
import (
"context"
"fmt"
"sync"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/command"
"go.mongodb.org/mongo-driver/x/network/compressor"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
var emptyDoc bson.Raw
type connectionLegacy struct {
*connection
writeBuf []byte
readBuf []byte
monitor *event.CommandMonitor
cmdMap map[int64]*commandMetadata // map for monitoring commands sent to server
wireMessageBuf []byte // buffer to store uncompressed wire message before compressing
compressBuf []byte // buffer to compress messages
uncompressBuf []byte // buffer to uncompress messages
compressor compressor.Compressor // use for compressing messages
// server can compress response with any compressor supported by driver
compressorMap map[wiremessage.CompressorID]compressor.Compressor
s *Server
sync.RWMutex
}
func newConnectionLegacy(c *connection, s *Server, opts ...ConnectionOption) (*connectionLegacy, error) {
cfg, err := newConnectionConfig(opts...)
if err != nil {
return nil, err
}
compressorMap := make(map[wiremessage.CompressorID]compressor.Compressor)
for _, comp := range cfg.compressors {
switch comp {
case "snappy":
snappyComp := compressor.CreateSnappy()
compressorMap[snappyComp.CompressorID()] = snappyComp
case "zlib":
zlibComp, err := compressor.CreateZlib(cfg.zlibLevel)
if err != nil {
return nil, err
}
compressorMap[zlibComp.CompressorID()] = zlibComp
}
}
cl := &connectionLegacy{
connection: c,
compressorMap: compressorMap,
cmdMap: make(map[int64]*commandMetadata),
compressBuf: make([]byte, 256),
readBuf: make([]byte, 256),
uncompressBuf: make([]byte, 256),
writeBuf: make([]byte, 0, 256),
wireMessageBuf: make([]byte, 256),
s: s,
}
d := c.desc
if len(d.Compression) > 0 {
clientMethodLoop:
for _, comp := range cl.compressorMap {
method := comp.Name()
for _, serverMethod := range d.Compression {
if method != serverMethod {
continue
}
cl.compressor = comp // found matching compressor
break clientMethodLoop
}
}
}
cl.monitor = cfg.cmdMonitor // Attach the command monitor later to avoid monitoring auth.
return cl, nil
}
func (c *connectionLegacy) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
// Truncate the write buffer
c.writeBuf = c.writeBuf[:0]
messageToWrite := wm
// Compress if possible
if c.compressor != nil {
compressed, err := c.compressMessage(wm)
if err != nil {
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to compress wire message",
}
}
messageToWrite = compressed
}
var err error
c.writeBuf, err = messageToWrite.AppendWireMessage(c.writeBuf)
if err != nil {
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to encode wire message",
}
}
err = c.writeWireMessage(ctx, c.writeBuf)
if c.s != nil {
c.s.ProcessError(err)
}
if err != nil {
// The error we got back was probably a ConnectionError already, so we don't really need to
// wrap it here.
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to write wire message to network",
}
}
return c.commandStartedEvent(ctx, wm)
}
func (c *connectionLegacy) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
// Truncate the read buffer
c.readBuf = c.readBuf[:0]
var err error
c.readBuf, err = c.readWireMessage(ctx, c.readBuf)
if c.s != nil {
c.s.ProcessError(err)
}
if err != nil {
// The error we got back was probably a ConnectionError already, so we don't really need to
// wrap it here.
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to read wire message from network",
}
}
hdr, err := wiremessage.ReadHeader(c.readBuf, 0)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode header",
}
}
messageToDecode := c.readBuf
opcodeToCheck := hdr.OpCode
if hdr.OpCode == wiremessage.OpCompressed {
var compressed wiremessage.Compressed
err := compressed.UnmarshalWireMessage(c.readBuf)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_COMPRESSED",
}
}
uncompressed, origOpcode, err := c.uncompressMessage(compressed)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to uncompress message",
}
}
messageToDecode = uncompressed
opcodeToCheck = origOpcode
}
var wm wiremessage.WireMessage
switch opcodeToCheck {
case wiremessage.OpReply:
var r wiremessage.Reply
err := r.UnmarshalWireMessage(messageToDecode)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_REPLY",
}
}
wm = r
case wiremessage.OpMsg:
var reply wiremessage.Msg
err := reply.UnmarshalWireMessage(messageToDecode)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_MSG",
}
}
wm = reply
default:
return nil, ConnectionError{
ConnectionID: c.id,
message: fmt.Sprintf("opcode %s not implemented", hdr.OpCode),
}
}
if c.s != nil {
c.s.ProcessError(command.DecodeError(wm))
}
// TODO: do we care if monitoring fails?
return wm, c.commandFinishedEvent(ctx, wm)
}
func (c *connectionLegacy) Close() error {
c.Lock()
defer c.Unlock()
if c.connection == nil {
return nil
}
if c.s != nil {
c.s.sem.Release(1)
}
err := c.pool.put(c.connection)
if err != nil {
return err
}
c.connection = nil
return nil
}
func (c *connectionLegacy) Expired() bool {
c.RLock()
defer c.RUnlock()
return c.connection == nil || c.expired()
}
func (c *connectionLegacy) Alive() bool {
c.RLock()
defer c.RUnlock()
return c.connection != nil
}
func (c *connectionLegacy) ID() string {
c.RLock()
defer c.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.id
}
func (c *connectionLegacy) commandStartedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
if c.monitor == nil || c.monitor.Started == nil {
return nil
}
startedEvent := &event.CommandStartedEvent{
ConnectionID: c.id,
}
var cmd bsonx.Doc
var err error
var legacy bool
var fullCollName string
var acknowledged bool
switch converted := wm.(type) {
case wiremessage.Query:
cmd, err = converted.CommandDocument()
if err != nil {
return err
}
acknowledged = converted.AcknowledgedWrite()
startedEvent.DatabaseName = converted.DatabaseName()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
legacy = converted.Legacy()
fullCollName = converted.FullCollectionName
case wiremessage.Msg:
cmd, err = converted.GetMainDocument()
if err != nil {
return err
}
acknowledged = converted.AcknowledgedWrite()
arr, identifier, err := converted.GetSequenceArray()
if err != nil {
return err
}
if arr != nil {
cmd = cmd.Copy() // make copy to avoid changing original command
cmd = append(cmd, bsonx.Elem{identifier, bsonx.Array(arr)})
}
dbVal, err := cmd.LookupErr("$db")
if err != nil {
return err
}
startedEvent.DatabaseName = dbVal.StringValue()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
case wiremessage.GetMore:
cmd = converted.CommandDocument()
startedEvent.DatabaseName = converted.DatabaseName()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
acknowledged = true
legacy = true
fullCollName = converted.FullCollectionName
case wiremessage.KillCursors:
cmd = converted.CommandDocument()
startedEvent.DatabaseName = converted.DatabaseName
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
legacy = true
}
rawcmd, _ := cmd.MarshalBSON()
startedEvent.Command = rawcmd
startedEvent.CommandName = cmd[0].Key
if !canMonitor(startedEvent.CommandName) {
startedEvent.Command = emptyDoc
}
c.monitor.Started(ctx, startedEvent)
if !acknowledged {
if c.monitor.Succeeded == nil {
return nil
}
// unack writes must provide a CommandSucceededEvent with an { ok: 1 } reply
finishedEvent := event.CommandFinishedEvent{
DurationNanos: 0,
CommandName: startedEvent.CommandName,
RequestID: startedEvent.RequestID,
ConnectionID: c.id,
}
c.monitor.Succeeded(ctx, &event.CommandSucceededEvent{
CommandFinishedEvent: finishedEvent,
Reply: bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)),
})
return nil
}
c.cmdMap[startedEvent.RequestID] = createMetadata(startedEvent.CommandName, legacy, fullCollName)
return nil
}
func (c *connectionLegacy) commandFinishedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
if c.monitor == nil {
return nil
}
var reply bsonx.Doc
var requestID int64
var err error
switch converted := wm.(type) {
case wiremessage.Reply:
requestID = int64(converted.MsgHeader.ResponseTo)
case wiremessage.Msg:
requestID = int64(converted.MsgHeader.ResponseTo)
}
cmdMetadata := c.cmdMap[requestID]
delete(c.cmdMap, requestID)
switch converted := wm.(type) {
case wiremessage.Reply:
if cmdMetadata.Legacy {
reply, err = converted.GetMainLegacyDocument(cmdMetadata.FullCollectionName)
} else {
reply, err = converted.GetMainDocument()
}
case wiremessage.Msg:
reply, err = converted.GetMainDocument()
}
if err != nil {
return err
}
success, errmsg := processReply(reply)
if (success && c.monitor.Succeeded == nil) || (!success && c.monitor.Failed == nil) {
return nil
}
finishedEvent := event.CommandFinishedEvent{
DurationNanos: cmdMetadata.TimeDifference(),
CommandName: cmdMetadata.Name,
RequestID: requestID,
ConnectionID: c.id,
}
if success {
if !canMonitor(finishedEvent.CommandName) {
successEvent := &event.CommandSucceededEvent{
Reply: emptyDoc,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Succeeded(ctx, successEvent)
return nil
}
// if response has type 1 document sequence, the sequence must be included as a BSON array in the event's reply.
if opmsg, ok := wm.(wiremessage.Msg); ok {
arr, identifier, err := opmsg.GetSequenceArray()
if err != nil {
return err
}
if arr != nil {
reply = reply.Copy() // make copy to avoid changing original command
reply = append(reply, bsonx.Elem{identifier, bsonx.Array(arr)})
}
}
replyraw, _ := reply.MarshalBSON()
successEvent := &event.CommandSucceededEvent{
Reply: replyraw,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Succeeded(ctx, successEvent)
return nil
}
failureEvent := &event.CommandFailedEvent{
Failure: errmsg,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Failed(ctx, failureEvent)
return nil
}
func canCompress(cmd string) bool {
if cmd == "isMaster" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "authenticate" ||
cmd == "createUser" || cmd == "updateUser" || cmd == "copydbSaslStart" || cmd == "copydbgetnonce" || cmd == "copydb" {
return false
}
return true
}
func (c *connectionLegacy) compressMessage(wm wiremessage.WireMessage) (wiremessage.WireMessage, error) {
var requestID int32
var responseTo int32
var origOpcode wiremessage.OpCode
switch converted := wm.(type) {
case wiremessage.Query:
firstElem, err := converted.Query.IndexErr(0)
if err != nil {
return wiremessage.Compressed{}, err
}
key := firstElem.Key()
if !canCompress(key) {
return wm, nil // return original message because this command can't be compressed
}
requestID = converted.MsgHeader.RequestID
origOpcode = wiremessage.OpQuery
responseTo = converted.MsgHeader.ResponseTo
case wiremessage.Msg:
firstElem, err := converted.Sections[0].(wiremessage.SectionBody).Document.IndexErr(0)
if err != nil {
return wiremessage.Compressed{}, err
}
key := firstElem.Key()
if !canCompress(key) {
return wm, nil
}
requestID = converted.MsgHeader.RequestID
origOpcode = wiremessage.OpMsg
responseTo = converted.MsgHeader.ResponseTo
}
// can compress
c.wireMessageBuf = c.wireMessageBuf[:0] // truncate
var err error
c.wireMessageBuf, err = wm.AppendWireMessage(c.wireMessageBuf)
if err != nil {
return wiremessage.Compressed{}, err
}
c.wireMessageBuf = c.wireMessageBuf[16:] // strip header
c.compressBuf = c.compressBuf[:0]
compressedBytes, err := c.compressor.CompressBytes(c.wireMessageBuf, c.compressBuf)
if err != nil {
return wiremessage.Compressed{}, err
}
compressedMessage := wiremessage.Compressed{
MsgHeader: wiremessage.Header{
// MessageLength and OpCode will be set when marshalling wire message by SetDefaults()
RequestID: requestID,
ResponseTo: responseTo,
},
OriginalOpCode: origOpcode,
UncompressedSize: int32(len(c.wireMessageBuf)), // length of uncompressed message excluding MsgHeader
CompressorID: wiremessage.CompressorID(c.compressor.CompressorID()),
CompressedMessage: compressedBytes,
}
return compressedMessage, nil
}
// returns []byte of uncompressed message with reconstructed header, original opcode, error
func (c *connectionLegacy) uncompressMessage(compressed wiremessage.Compressed) ([]byte, wiremessage.OpCode, error) {
// server doesn't guarantee the same compression method will be used each time so the CompressorID field must be
// used to find the correct method for uncompressing data
uncompressor := c.compressorMap[compressed.CompressorID]
// reset uncompressBuf
c.uncompressBuf = c.uncompressBuf[:0]
if int(compressed.UncompressedSize) > cap(c.uncompressBuf) {
c.uncompressBuf = make([]byte, 0, compressed.UncompressedSize)
}
uncompressedMessage, err := uncompressor.UncompressBytes(compressed.CompressedMessage, c.uncompressBuf[:compressed.UncompressedSize])
if err != nil {
return nil, 0, err
}
origHeader := wiremessage.Header{
MessageLength: int32(len(uncompressedMessage)) + 16, // add 16 for original header
RequestID: compressed.MsgHeader.RequestID,
ResponseTo: compressed.MsgHeader.ResponseTo,
}
switch compressed.OriginalOpCode {
case wiremessage.OpReply:
origHeader.OpCode = wiremessage.OpReply
case wiremessage.OpMsg:
origHeader.OpCode = wiremessage.OpMsg
default:
return nil, 0, fmt.Errorf("opcode %s not implemented", compressed.OriginalOpCode)
}
var fullMessage []byte
fullMessage = origHeader.AppendHeader(fullMessage)
fullMessage = append(fullMessage, uncompressedMessage...)
return fullMessage, origHeader.OpCode, nil
}
func canMonitor(cmd string) bool {
if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
return false
}
return true
}
func processReply(reply bsonx.Doc) (bool, string) {
var success bool
var errmsg string
var errCode int32
for _, elem := range reply {
switch elem.Key {
case "ok":
switch elem.Value.Type() {
case bsontype.Int32:
if elem.Value.Int32() == 1 {
success = true
}
case bsontype.Int64:
if elem.Value.Int64() == 1 {
success = true
}
case bsontype.Double:
if elem.Value.Double() == 1 {
success = true
}
}
case "errmsg":
if str, ok := elem.Value.StringValueOK(); ok {
errmsg = str
}
case "code":
if c, ok := elem.Value.Int32OK(); ok {
errCode = c
}
}
}
if success {
return true, ""
}
fullErrMsg := fmt.Sprintf("Error code %d: %s", errCode, errmsg)
return false, fullErrMsg
}

View File

@@ -0,0 +1,28 @@
package topology
import "time"
// commandMetadata contains metadata about a command sent to the server.
type commandMetadata struct {
Name string
Time time.Time
Legacy bool
FullCollectionName string
}
// createMetadata creates metadata for a command.
func createMetadata(name string, legacy bool, fullCollName string) *commandMetadata {
return &commandMetadata{
Name: name,
Time: time.Now(),
Legacy: legacy,
FullCollectionName: fullCollName,
}
}
// TimeDifference returns the difference between now and the time a command was sent in nanoseconds.
func (cm *commandMetadata) TimeDifference() int64 {
t := time.Now()
duration := t.Sub(cm.Time)
return duration.Nanoseconds()
}

View File

@@ -0,0 +1,187 @@
package topology
import (
"context"
"crypto/tls"
"net"
"time"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// Dialer is used to make network connections.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DialerFunc is a type implemented by functions that can be used as a Dialer.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
// DialContext implements the Dialer interface.
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return df(ctx, network, address)
}
// DefaultDialer is the Dialer implementation that is used by this package. Changing this
// will also change the Dialer used for this package. This should only be changed when all
// of the connections being made need to use a different Dialer. Most of the time, using a
// WithDialer option is more appropriate than changing this variable.
var DefaultDialer Dialer = &net.Dialer{}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker
// HandshakerFunc is an adapter to allow the use of ordinary functions as
// connection handshakers.
type HandshakerFunc = driver.HandshakerFunc
type connectionConfig struct {
appName string
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
idleTimeout time.Duration
lifeTimeout time.Duration
cmdMonitor *event.CommandMonitor
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
compressors []string
zlibLevel *int
descCallback func(description.Server)
}
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
cfg := &connectionConfig{
connectTimeout: 30 * time.Second,
dialer: nil,
idleTimeout: 10 * time.Minute,
lifeTimeout: 30 * time.Minute,
}
for _, opt := range opts {
err := opt(cfg)
if err != nil {
return nil, err
}
}
if cfg.dialer == nil {
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
}
return cfg, nil
}
func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
return append(opts, ConnectionOption(func(c *connectionConfig) error {
c.descCallback = callback
return nil
}))
}
// ConnectionOption is used to configure a connection.
type ConnectionOption func(*connectionConfig) error
// WithAppName sets the application name which gets sent to MongoDB when it
// first connects.
func WithAppName(fn func(string) string) ConnectionOption {
return func(c *connectionConfig) error {
c.appName = fn(c.appName)
return nil
}
}
// WithCompressors sets the compressors that can be used for communication.
func WithCompressors(fn func([]string) []string) ConnectionOption {
return func(c *connectionConfig) error {
c.compressors = fn(c.compressors)
return nil
}
}
// WithConnectTimeout configures the maximum amount of time a dial will wait for a
// connect to complete. The default is 30 seconds.
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.connectTimeout = fn(c.connectTimeout)
return nil
}
}
// WithDialer configures the Dialer to use when making a new connection to MongoDB.
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
return func(c *connectionConfig) error {
c.dialer = fn(c.dialer)
return nil
}
}
// WithHandshaker configures the Handshaker that will be used to initialize newly
// dialed connections.
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
return func(c *connectionConfig) error {
c.handshaker = fn(c.handshaker)
return nil
}
}
// WithIdleTimeout configures the maximum idle time to allow for a connection.
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.idleTimeout = fn(c.idleTimeout)
return nil
}
}
// WithLifeTimeout configures the maximum life of a connection.
func WithLifeTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.lifeTimeout = fn(c.lifeTimeout)
return nil
}
}
// WithReadTimeout configures the maximum read time for a connection.
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.readTimeout = fn(c.readTimeout)
return nil
}
}
// WithWriteTimeout configures the maximum write time for a connection.
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.writeTimeout = fn(c.writeTimeout)
return nil
}
}
// WithTLSConfig configures the TLS options for a connection.
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
return func(c *connectionConfig) error {
c.tlsConfig = fn(c.tlsConfig)
return nil
}
}
// WithMonitor configures an event monitor for command monitoring.
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
return func(c *connectionConfig) error {
c.cmdMonitor = fn(c.cmdMonitor)
return nil
}
}
// WithZlibLevel sets the zLib compression level.
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
return func(c *connectionConfig) error {
c.zlibLevel = fn(c.zlibLevel)
return nil
}
}

View File

@@ -0,0 +1,22 @@
package topology
import "fmt"
// ConnectionError represents a connection error.
type ConnectionError struct {
ConnectionID string
Wrapped error
// init will be set to true if this error occurred during connection initialization or
// during a connection handshake.
init bool
message string
}
// Error implements the error interface.
func (e ConnectionError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, e.message, e.Wrapped.Error())
}
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message)
}

View File

@@ -0,0 +1,350 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
var supportedWireVersions = description.NewVersionRange(2, 6)
var minSupportedMongoDBVersion = "2.6"
type fsm struct {
description.Topology
SetName string
maxElectionID primitive.ObjectID
maxSetVersion uint32
}
func newFSM() *fsm {
return new(fsm)
}
// apply should operate on immutable TopologyDescriptions and Descriptions. This way we don't have to
// lock for the entire time we're applying a server description.
func (f *fsm) apply(s description.Server) (description.Topology, error) {
newServers := make([]description.Server, len(f.Servers))
copy(newServers, f.Servers)
oldMinutes := f.SessionTimeoutMinutes
f.Topology = description.Topology{
Kind: f.Kind,
Servers: newServers,
}
// For data bearing servers, set SessionTimeoutMinutes to the lowest among them
if oldMinutes == 0 {
// If the timeout is currently 0, check all servers to see if any still lack a timeout.
// If they all have a timeout, pick the lowest.
timeout := s.SessionTimeoutMinutes
for _, server := range f.Servers {
if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
timeout = server.SessionTimeoutMinutes
}
}
f.SessionTimeoutMinutes = timeout
} else {
if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
} else {
f.SessionTimeoutMinutes = oldMinutes
}
}
if _, ok := f.findServer(s.Addr); !ok {
return f.Topology, nil
}
if s.WireVersion != nil {
if s.WireVersion.Max < supportedWireVersions.Min {
return description.Topology{}, fmt.Errorf(
"server at %s reports wire version %d, but this version of the Go driver requires "+
"at least %d (MongoDB %s)",
s.Addr.String(),
s.WireVersion.Max,
supportedWireVersions.Min,
minSupportedMongoDBVersion,
)
}
if s.WireVersion.Min > supportedWireVersions.Max {
return description.Topology{}, fmt.Errorf(
"server at %s requires wire version %d, but this version of the Go driver only "+
"supports up to %d",
s.Addr.String(),
s.WireVersion.Min,
supportedWireVersions.Max,
)
}
}
switch f.Kind {
case description.Unknown:
f.applyToUnknown(s)
case description.Sharded:
f.applyToSharded(s)
case description.ReplicaSetNoPrimary:
f.applyToReplicaSetNoPrimary(s)
case description.ReplicaSetWithPrimary:
f.applyToReplicaSetWithPrimary(s)
case description.Single:
f.applyToSingle(s)
}
return f.Topology, nil
}
func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithoutPrimary(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
}
func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithPrimaryFromMember(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
f.checkIfHasPrimary()
}
}
func (f *fsm) applyToSharded(s description.Server) {
switch s.Kind {
case description.Mongos, description.Unknown:
f.replaceServer(s)
case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
f.removeServerByAddr(s.Addr)
}
}
func (f *fsm) applyToSingle(s description.Server) {
switch s.Kind {
case description.Unknown:
f.replaceServer(s)
case description.Standalone, description.Mongos:
if f.SetName != "" {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
if f.SetName != "" && f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
}
}
func (f *fsm) applyToUnknown(s description.Server) {
switch s.Kind {
case description.Mongos:
f.setKind(description.Sharded)
f.replaceServer(s)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.setKind(description.ReplicaSetNoPrimary)
f.updateRSWithoutPrimary(s)
case description.Standalone:
f.updateUnknownWithStandalone(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
}
func (f *fsm) checkIfHasPrimary() {
if _, ok := f.findPrimary(); ok {
f.setKind(description.ReplicaSetWithPrimary)
} else {
f.setKind(description.ReplicaSetNoPrimary)
}
}
func (f *fsm) updateRSFromPrimary(s description.Server) {
if f.SetName == "" {
f.SetName = s.SetName
} else if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
if s.SetVersion != 0 && !bytes.Equal(s.ElectionID[:], primitive.NilObjectID[:]) {
if f.maxSetVersion > s.SetVersion || bytes.Compare(f.maxElectionID[:], s.ElectionID[:]) == 1 {
f.replaceServer(description.Server{
Addr: s.Addr,
LastError: fmt.Errorf("was a primary, but its set version or election id is stale"),
})
f.checkIfHasPrimary()
return
}
f.maxElectionID = s.ElectionID
}
if s.SetVersion > f.maxSetVersion {
f.maxSetVersion = s.SetVersion
}
if j, ok := f.findPrimary(); ok {
f.setServer(j, description.Server{
Addr: f.Servers[j].Addr,
LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
})
}
f.replaceServer(s)
for j := len(f.Servers) - 1; j >= 0; j-- {
found := false
for _, member := range s.Members {
if member == f.Servers[j].Addr {
found = true
break
}
}
if !found {
f.removeServer(j)
}
}
for _, member := range s.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
f.checkIfHasPrimary()
}
func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
f.replaceServer(s)
if _, ok := f.findPrimary(); !ok {
f.setKind(description.ReplicaSetNoPrimary)
}
}
func (f *fsm) updateRSWithoutPrimary(s description.Server) {
if f.SetName == "" {
f.SetName = s.SetName
} else if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
return
}
for _, member := range s.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
}
func (f *fsm) updateUnknownWithStandalone(s description.Server) {
if len(f.Servers) > 1 {
f.removeServerByAddr(s.Addr)
return
}
f.setKind(description.Single)
f.replaceServer(s)
}
func (f *fsm) addServer(addr address.Address) {
f.Servers = append(f.Servers, description.Server{
Addr: addr.Canonicalize(),
})
}
func (f *fsm) findPrimary() (int, bool) {
for i, s := range f.Servers {
if s.Kind == description.RSPrimary {
return i, true
}
}
return 0, false
}
func (f *fsm) findServer(addr address.Address) (int, bool) {
canon := addr.Canonicalize()
for i, s := range f.Servers {
if canon == s.Addr {
return i, true
}
}
return 0, false
}
func (f *fsm) removeServer(i int) {
f.Servers = append(f.Servers[:i], f.Servers[i+1:]...)
}
func (f *fsm) removeServerByAddr(addr address.Address) {
if i, ok := f.findServer(addr); ok {
f.removeServer(i)
}
}
func (f *fsm) replaceServer(s description.Server) bool {
if i, ok := f.findServer(s.Addr); ok {
f.setServer(i, s)
return true
}
return false
}
func (f *fsm) setServer(i int, s description.Server) {
f.Servers[i] = s
}
func (f *fsm) setKind(k description.TopologyKind) {
f.Kind = k
}

View File

@@ -0,0 +1,213 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
)
// ErrPoolConnected is returned from an attempt to connect an already connected pool
var ErrPoolConnected = PoolError("pool is connected")
// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected
// or disconnecting pool.
var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting")
// ErrConnectionClosed is returned from an attempt to use an already closed connection.
var ErrConnectionClosed = ConnectionError{ConnectionID: "<closed>", message: "connection is closed"}
// ErrWrongPool is returned when a connection is returned to a pool it doesn't belong to.
var ErrWrongPool = PoolError("connection does not belong to this pool")
// PoolError is an error returned from a Pool method.
type PoolError string
// pruneInterval is the interval at which the background routine to close expired connections will be run.
var pruneInterval = time.Minute
func (pe PoolError) Error() string { return string(pe) }
type pool struct {
address address.Address
opts []ConnectionOption
conns *resourcePool // pool for idle connections
generation uint64
connected int32 // Must be accessed using the sync/atomic package.
nextid uint64
opened map[uint64]*connection // opened holds all of the currently open connections.
sync.Mutex
}
func connectionExpiredFunc(v interface{}) bool {
return v.(*connection).expired()
}
func connectionCloseFunc(v interface{}) {
c := v.(*connection)
go c.pool.close(c)
}
// newPool creates a new pool that will hold size number of idle connections. It will use the
// provided options when creating connections.
func newPool(addr address.Address, size uint64, opts ...ConnectionOption) *pool {
return &pool{
address: addr,
conns: newResourcePool(size, connectionExpiredFunc, connectionCloseFunc, pruneInterval),
generation: 0,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
}
}
// drain lazily drains the pool by increasing the generation ID.
func (p *pool) drain() { atomic.AddUint64(&p.generation, 1) }
func (p *pool) expired(generation uint64) bool { return generation < atomic.LoadUint64(&p.generation) }
// connect puts the pool into the connected state, allowing it to be used.
func (p *pool) connect() error {
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
return ErrPoolConnected
}
atomic.AddUint64(&p.generation, 1)
return nil
}
func (p *pool) disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
return ErrPoolDisconnected
}
// We first clear out the idle connections, then we wait until the context's deadline is hit or
// it's cancelled, after which we aggressively close the remaining open connections.
for {
connVal := p.conns.Get()
if connVal == nil {
break
}
_ = p.close(connVal.(*connection))
}
if dl, ok := ctx.Deadline(); ok {
// If we have a deadline then we interpret it as a request to gracefully shutdown. We wait
// until either all the connections have landed back in the pool (and have been closed) or
// until the timer is done.
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
timer := time.NewTimer(time.Until(dl))
defer timer.Stop()
for {
select {
case <-timer.C:
case <-ticker.C: // Can we replace this with an actual signal channel? We will know when p.inflight hits zero from the close method.
p.Lock()
if len(p.opened) > 0 {
p.Unlock()
continue
}
p.Unlock()
}
break
}
}
// We copy the remaining connections into a slice, then iterate it to close them. This allows us
// to use a single function to actually clean up and close connections at the expense of a
// double iteration in the worst case.
p.Lock()
toClose := make([]*connection, 0, len(p.opened))
for _, pc := range p.opened {
toClose = append(toClose, pc)
}
p.Unlock()
for _, pc := range toClose {
_ = p.close(pc) // We don't care about errors while closing the connection.
}
atomic.StoreInt32(&p.connected, disconnected)
return nil
}
func (p *pool) get(ctx context.Context) (*connection, error) {
if atomic.LoadInt32(&p.connected) != connected {
return nil, ErrPoolDisconnected
}
// try to get an unexpired idle connection
connVal := p.conns.Get()
if connVal != nil {
return connVal.(*connection), nil
}
// couldn't find an unexpired connection. create a new one.
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
c, err := newConnection(ctx, p.address, p.opts...)
if err != nil {
return nil, err
}
c.pool = p
c.poolID = atomic.AddUint64(&p.nextid, 1)
c.generation = p.generation
if atomic.LoadInt32(&p.connected) != connected {
_ = p.close(c) // The pool is disconnected or disconnecting, ignore the error from closing the connection.
return nil, ErrPoolDisconnected
}
p.Lock()
p.opened[c.poolID] = c
p.Unlock()
return c, nil
}
}
// close closes a connection, not the pool itself. This method will actually close the connection,
// making it unusable; to return the connection to the pool instead, use put.
func (p *pool) close(c *connection) error {
if c.pool != p {
return ErrWrongPool
}
p.Lock()
delete(p.opened, c.poolID)
nc := c.nc
c.nc = nil
p.Unlock()
if nc == nil {
return nil // We're closing an already closed connection.
}
err := nc.Close()
if err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to close net.Conn"}
}
return nil
}
// put returns a connection to this pool. If the pool is connected, the connection is not
// expired, and there is space in the cache, the connection is returned to the cache.
func (p *pool) put(c *connection) error {
if c.pool != p {
return ErrWrongPool
}
if atomic.LoadInt32(&p.connected) != connected || c.expired() {
return p.close(c)
}
// close the connection if the underlying pool is full
if !p.conns.Put(c) {
return p.close(c)
}
return nil
}

Some files were not shown because too many files have changed in this diff.