newline battles continue
This commit is contained in:
40
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/DESIGN.md
generated
vendored
Executable file
40
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/DESIGN.md
generated
vendored
Executable file
@@ -0,0 +1,40 @@
|
||||
# Topology Package Design
|
||||
This document outlines the design for this package.
|
||||
|
||||
## Topology
|
||||
The `Topology` type handles monitoring the state of a MongoDB deployment and selecting servers.
|
||||
Updating the description is handled by finite state machine which implements the server discovery
|
||||
and monitoring specification. A `Topology` can be connected and fully disconnected, which enables
|
||||
saving resources. The `Topology` type also handles server selection following the server selection
|
||||
specification.
|
||||
|
||||
## Server
|
||||
The `Server` type handles heartbeating a MongoDB server and holds a pool of connections.
|
||||
|
||||
## Connection
|
||||
Connections are handled by two main types and an auxiliary type. The two main types are `connection`
|
||||
and `Connection`. The first holds most of the logic required to actually read and write wire
|
||||
messages. Instances can be created with the `newConnection` method. Inside the `newConnection`
|
||||
method the auxiliary type, `initConnection` is used to perform the connection handshake. This is
|
||||
required because the `connection` type does not fully implement `driver.Connection` which is
|
||||
required during handshaking. The `Connection` type is what is actually returned to a consumer of the
|
||||
`topology` package. This type does implement the `driver.Connection` type, holds a reference to a
|
||||
`connection` instance, and exists mainly to prevent accidental continued usage of a connection after
|
||||
closing it.
|
||||
|
||||
The connection implementations in this package are conduits for wire messages but they have no
|
||||
ability to encode, decode, or validate wire messages. That must be handled by consumers.
|
||||
|
||||
## Pool
|
||||
The `pool` type implements a connection pool. It handles caching idle connections and dialing
|
||||
new ones, but it does not track a maximum number of connections. That is the responsibility of a
|
||||
wrapping type, like `Server`.
|
||||
|
||||
The `pool` type has no concept of closing, instead it has concepts of connecting and disconnecting.
|
||||
This allows a `Topology` to be disconnected,but keeping the memory around to be reconnected later.
|
||||
There is a `close` method, but this is used to close a connection.
|
||||
|
||||
There are three methods related to getting and putting connections: `get`, `close`, and `put`. The
|
||||
`get` method will either retrieve a connection from the cache or it will dial a new `connection`.
|
||||
The `close` method will close the underlying socket of a `connection`. The `put` method will put a
|
||||
connection into the pool, placing it in the cahce if there is space, otherwise it will close it.
|
||||
458
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
generated
vendored
Executable file
458
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go
generated
vendored
Executable file
@@ -0,0 +1,458 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"strings"
|
||||
|
||||
"github.com/golang/snappy"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
|
||||
wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||||
"go.mongodb.org/mongo-driver/x/network/command"
|
||||
"go.mongodb.org/mongo-driver/x/network/wiremessage"
|
||||
)
|
||||
|
||||
var globalConnectionID uint64
|
||||
|
||||
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
|
||||
|
||||
type connection struct {
|
||||
id string
|
||||
nc net.Conn // When nil, the connection is closed.
|
||||
addr address.Address
|
||||
idleTimeout time.Duration
|
||||
idleDeadline time.Time
|
||||
lifetimeDeadline time.Time
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
desc description.Server
|
||||
compressor wiremessage.CompressorID
|
||||
zliblevel int
|
||||
|
||||
// pool related fields
|
||||
pool *pool
|
||||
poolID uint64
|
||||
generation uint64
|
||||
}
|
||||
|
||||
// newConnection handles the creation of a connection. It will dial, configure TLS, and perform
|
||||
// initialization handshakes.
|
||||
func newConnection(ctx context.Context, addr address.Address, opts ...ConnectionOption) (*connection, error) {
|
||||
cfg, err := newConnectionConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String())
|
||||
if err != nil {
|
||||
return nil, ConnectionError{Wrapped: err, init: true}
|
||||
}
|
||||
|
||||
if cfg.tlsConfig != nil {
|
||||
tlsConfig := cfg.tlsConfig.Clone()
|
||||
nc, err = configureTLS(ctx, nc, addr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{Wrapped: err, init: true}
|
||||
}
|
||||
}
|
||||
|
||||
var lifetimeDeadline time.Time
|
||||
if cfg.lifeTimeout > 0 {
|
||||
lifetimeDeadline = time.Now().Add(cfg.lifeTimeout)
|
||||
}
|
||||
|
||||
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
|
||||
|
||||
c := &connection{
|
||||
id: id,
|
||||
nc: nc,
|
||||
addr: addr,
|
||||
idleTimeout: cfg.idleTimeout,
|
||||
lifetimeDeadline: lifetimeDeadline,
|
||||
readTimeout: cfg.readTimeout,
|
||||
writeTimeout: cfg.writeTimeout,
|
||||
}
|
||||
|
||||
c.bumpIdleDeadline()
|
||||
|
||||
// running isMaster and authentication is handled by a handshaker on the configuration instance.
|
||||
if cfg.handshaker != nil {
|
||||
c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c})
|
||||
if err != nil {
|
||||
if c.nc != nil {
|
||||
_ = c.nc.Close()
|
||||
}
|
||||
return nil, ConnectionError{Wrapped: err, init: true}
|
||||
}
|
||||
if cfg.descCallback != nil {
|
||||
cfg.descCallback(c.desc)
|
||||
}
|
||||
if len(c.desc.Compression) > 0 {
|
||||
clientMethodLoop:
|
||||
for _, method := range cfg.compressors {
|
||||
for _, serverMethod := range c.desc.Compression {
|
||||
if method != serverMethod {
|
||||
continue
|
||||
}
|
||||
|
||||
switch strings.ToLower(method) {
|
||||
case "snappy":
|
||||
c.compressor = wiremessage.CompressorSnappy
|
||||
case "zlib":
|
||||
c.compressor = wiremessage.CompressorZLib
|
||||
c.zliblevel = wiremessage.DefaultZlibLevel
|
||||
if cfg.zlibLevel != nil {
|
||||
c.zliblevel = *cfg.zlibLevel
|
||||
}
|
||||
}
|
||||
break clientMethodLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
|
||||
var err error
|
||||
if c.nc == nil {
|
||||
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
|
||||
default:
|
||||
}
|
||||
|
||||
var deadline time.Time
|
||||
if c.writeTimeout != 0 {
|
||||
deadline = time.Now().Add(c.writeTimeout)
|
||||
}
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
|
||||
deadline = dl
|
||||
}
|
||||
|
||||
if err := c.nc.SetWriteDeadline(deadline); err != nil {
|
||||
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
|
||||
}
|
||||
|
||||
_, err = c.nc.Write(wm)
|
||||
if err != nil {
|
||||
c.close()
|
||||
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to write wire message to network"}
|
||||
}
|
||||
|
||||
c.bumpIdleDeadline()
|
||||
return nil
|
||||
}
|
||||
|
||||
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
|
||||
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
|
||||
if c.nc == nil {
|
||||
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// We close the connection because we don't know if there is an unread message on the wire.
|
||||
c.close()
|
||||
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
|
||||
default:
|
||||
}
|
||||
|
||||
var deadline time.Time
|
||||
if c.readTimeout != 0 {
|
||||
deadline = time.Now().Add(c.readTimeout)
|
||||
}
|
||||
|
||||
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
|
||||
deadline = dl
|
||||
}
|
||||
|
||||
if err := c.nc.SetReadDeadline(deadline); err != nil {
|
||||
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
|
||||
}
|
||||
|
||||
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
|
||||
// reslice dst once instead of twice.
|
||||
var sizeBuf [4]byte
|
||||
|
||||
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
|
||||
// because there might be more than one wire message waiting to be read, for example when
|
||||
// reading messages from an exhaust cursor.
|
||||
_, err := io.ReadFull(c.nc, sizeBuf[:])
|
||||
if err != nil {
|
||||
// We close the connection because we don't know if there are other bytes left to read.
|
||||
c.close()
|
||||
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to decode message length"}
|
||||
}
|
||||
|
||||
// read the length as an int32
|
||||
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
|
||||
|
||||
if int(size) > cap(dst) {
|
||||
// Since we can't grow this slice without allocating, just allocate an entirely new slice.
|
||||
dst = make([]byte, 0, size)
|
||||
}
|
||||
// We need to ensure we don't accidentally read into a subsequent wire message, so we set the
|
||||
// size to read exactly this wire message.
|
||||
dst = dst[:size]
|
||||
copy(dst, sizeBuf[:])
|
||||
|
||||
_, err = io.ReadFull(c.nc, dst[4:])
|
||||
if err != nil {
|
||||
// We close the connection because we don't know if there are other bytes left to read.
|
||||
c.close()
|
||||
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to read full message"}
|
||||
}
|
||||
|
||||
c.bumpIdleDeadline()
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
func (c *connection) close() error {
|
||||
if c.nc == nil {
|
||||
return nil
|
||||
}
|
||||
if c.pool == nil {
|
||||
err := c.nc.Close()
|
||||
c.nc = nil
|
||||
return err
|
||||
}
|
||||
return c.pool.close(c)
|
||||
}
|
||||
|
||||
func (c *connection) expired() bool {
|
||||
now := time.Now()
|
||||
if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
|
||||
return true
|
||||
}
|
||||
|
||||
if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) {
|
||||
return true
|
||||
}
|
||||
|
||||
return c.nc == nil
|
||||
}
|
||||
|
||||
func (c *connection) bumpIdleDeadline() {
|
||||
if c.idleTimeout > 0 {
|
||||
c.idleDeadline = time.Now().Add(c.idleTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// initConnection is an adapter used during connection initialization. It has the minimum
|
||||
// functionality necessary to implement the driver.Connection interface, which is required to pass a
|
||||
// *connection to a Handshaker.
|
||||
type initConnection struct{ *connection }
|
||||
|
||||
var _ driver.Connection = initConnection{}
|
||||
|
||||
func (c initConnection) Description() description.Server { return description.Server{} }
|
||||
func (c initConnection) Close() error { return nil }
|
||||
func (c initConnection) ID() string { return c.id }
|
||||
func (c initConnection) Address() address.Address { return c.addr }
|
||||
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
|
||||
return c.writeWireMessage(ctx, wm)
|
||||
}
|
||||
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
|
||||
return c.readWireMessage(ctx, dst)
|
||||
}
|
||||
|
||||
// Connection implements the driver.Connection interface. It allows reading and writing wire
|
||||
// messages.
|
||||
type Connection struct {
|
||||
*connection
|
||||
s *Server
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var _ driver.Connection = (*Connection)(nil)
|
||||
|
||||
// WriteWireMessage handles writing a wire message to the underlying connection.
|
||||
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
return c.writeWireMessage(ctx, wm)
|
||||
}
|
||||
|
||||
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
|
||||
// will be overwritten with the new wire message.
|
||||
func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return dst, ErrConnectionClosed
|
||||
}
|
||||
return c.readWireMessage(ctx, dst)
|
||||
}
|
||||
|
||||
// CompressWireMessage handles compressing the provided wire message using the underlying
|
||||
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
|
||||
// there is no compressor set on the underlying connection, then no compression will be performed.
|
||||
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return dst, ErrConnectionClosed
|
||||
}
|
||||
if c.connection.compressor == wiremessage.CompressorNoOp {
|
||||
return append(dst, src...), nil
|
||||
}
|
||||
_, reqid, respto, origcode, rem, ok := wiremessagex.ReadHeader(src)
|
||||
if !ok {
|
||||
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
|
||||
}
|
||||
idx, dst := wiremessagex.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
|
||||
dst = wiremessagex.AppendCompressedOriginalOpCode(dst, origcode)
|
||||
dst = wiremessagex.AppendCompressedUncompressedSize(dst, int32(len(rem)))
|
||||
dst = wiremessagex.AppendCompressedCompressorID(dst, c.connection.compressor)
|
||||
switch c.connection.compressor {
|
||||
case wiremessage.CompressorSnappy:
|
||||
compressed := snappy.Encode(nil, rem)
|
||||
dst = wiremessagex.AppendCompressedCompressedMessage(dst, compressed)
|
||||
case wiremessage.CompressorZLib:
|
||||
var b bytes.Buffer
|
||||
w, err := zlib.NewWriterLevel(&b, c.connection.zliblevel)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
_, err = w.Write(rem)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
dst = wiremessagex.AppendCompressedCompressedMessage(dst, b.Bytes())
|
||||
default:
|
||||
return dst, fmt.Errorf("unknown compressor ID %v", c.connection.compressor)
|
||||
}
|
||||
|
||||
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
|
||||
}
|
||||
|
||||
// Description returns the server description of the server this connection is connected to.
|
||||
func (c *Connection) Description() description.Server {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return description.Server{}
|
||||
}
|
||||
return c.desc
|
||||
}
|
||||
|
||||
// Close returns this connection to the connection pool. This method may not close the underlying
|
||||
// socket.
|
||||
func (c *Connection) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.connection == nil {
|
||||
return nil
|
||||
}
|
||||
if c.s != nil {
|
||||
c.s.sem.Release(1)
|
||||
}
|
||||
err := c.pool.put(c.connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.connection = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID returns the ID of this connection.
|
||||
func (c *Connection) ID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return "<closed>"
|
||||
}
|
||||
return c.id
|
||||
}
|
||||
|
||||
// Address returns the address of this connection.
|
||||
func (c *Connection) Address() address.Address {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.connection == nil {
|
||||
return address.Address("0.0.0.0")
|
||||
}
|
||||
return c.addr
|
||||
}
|
||||
|
||||
var notMasterCodes = []int32{10107, 13435}
|
||||
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
|
||||
|
||||
func isRecoveringError(err command.Error) bool {
|
||||
for _, c := range recoveringCodes {
|
||||
if c == err.Code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), "node is recovering")
|
||||
}
|
||||
|
||||
func isNotMasterError(err command.Error) bool {
|
||||
for _, c := range notMasterCodes {
|
||||
if c == err.Code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), "not master")
|
||||
}
|
||||
|
||||
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config) (net.Conn, error) {
|
||||
if !config.InsecureSkipVerify {
|
||||
hostname := addr.String()
|
||||
colonPos := strings.LastIndex(hostname, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(hostname)
|
||||
}
|
||||
|
||||
hostname = hostname[:colonPos]
|
||||
config.ServerName = hostname
|
||||
}
|
||||
|
||||
client := tls.Client(nc, config)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- client.Handshake()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, errors.New("server connection cancelled/timeout during TLS handshake")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
615
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_legacy.go
generated
vendored
Executable file
615
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_legacy.go
generated
vendored
Executable file
@@ -0,0 +1,615 @@
|
||||
package topology
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/event"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/network/command"
|
||||
"go.mongodb.org/mongo-driver/x/network/compressor"
|
||||
"go.mongodb.org/mongo-driver/x/network/wiremessage"
|
||||
)
|
||||
|
||||
var emptyDoc bson.Raw
|
||||
|
||||
type connectionLegacy struct {
|
||||
*connection
|
||||
writeBuf []byte
|
||||
readBuf []byte
|
||||
monitor *event.CommandMonitor
|
||||
cmdMap map[int64]*commandMetadata // map for monitoring commands sent to server
|
||||
|
||||
wireMessageBuf []byte // buffer to store uncompressed wire message before compressing
|
||||
compressBuf []byte // buffer to compress messages
|
||||
uncompressBuf []byte // buffer to uncompress messages
|
||||
compressor compressor.Compressor // use for compressing messages
|
||||
// server can compress response with any compressor supported by driver
|
||||
compressorMap map[wiremessage.CompressorID]compressor.Compressor
|
||||
|
||||
s *Server
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func newConnectionLegacy(c *connection, s *Server, opts ...ConnectionOption) (*connectionLegacy, error) {
|
||||
cfg, err := newConnectionConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
compressorMap := make(map[wiremessage.CompressorID]compressor.Compressor)
|
||||
|
||||
for _, comp := range cfg.compressors {
|
||||
switch comp {
|
||||
case "snappy":
|
||||
snappyComp := compressor.CreateSnappy()
|
||||
compressorMap[snappyComp.CompressorID()] = snappyComp
|
||||
case "zlib":
|
||||
zlibComp, err := compressor.CreateZlib(cfg.zlibLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
compressorMap[zlibComp.CompressorID()] = zlibComp
|
||||
}
|
||||
}
|
||||
|
||||
cl := &connectionLegacy{
|
||||
connection: c,
|
||||
compressorMap: compressorMap,
|
||||
cmdMap: make(map[int64]*commandMetadata),
|
||||
compressBuf: make([]byte, 256),
|
||||
readBuf: make([]byte, 256),
|
||||
uncompressBuf: make([]byte, 256),
|
||||
writeBuf: make([]byte, 0, 256),
|
||||
wireMessageBuf: make([]byte, 256),
|
||||
|
||||
s: s,
|
||||
}
|
||||
|
||||
d := c.desc
|
||||
if len(d.Compression) > 0 {
|
||||
clientMethodLoop:
|
||||
for _, comp := range cl.compressorMap {
|
||||
method := comp.Name()
|
||||
|
||||
for _, serverMethod := range d.Compression {
|
||||
if method != serverMethod {
|
||||
continue
|
||||
}
|
||||
|
||||
cl.compressor = comp // found matching compressor
|
||||
break clientMethodLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cl.monitor = cfg.cmdMonitor // Attach the command monitor later to avoid monitoring auth.
|
||||
return cl, nil
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
|
||||
// Truncate the write buffer
|
||||
c.writeBuf = c.writeBuf[:0]
|
||||
|
||||
messageToWrite := wm
|
||||
// Compress if possible
|
||||
if c.compressor != nil {
|
||||
compressed, err := c.compressMessage(wm)
|
||||
if err != nil {
|
||||
return ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to compress wire message",
|
||||
}
|
||||
}
|
||||
messageToWrite = compressed
|
||||
}
|
||||
|
||||
var err error
|
||||
c.writeBuf, err = messageToWrite.AppendWireMessage(c.writeBuf)
|
||||
if err != nil {
|
||||
return ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to encode wire message",
|
||||
}
|
||||
}
|
||||
|
||||
err = c.writeWireMessage(ctx, c.writeBuf)
|
||||
if c.s != nil {
|
||||
c.s.ProcessError(err)
|
||||
}
|
||||
if err != nil {
|
||||
// The error we got back was probably a ConnectionError already, so we don't really need to
|
||||
// wrap it here.
|
||||
return ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to write wire message to network",
|
||||
}
|
||||
}
|
||||
|
||||
return c.commandStartedEvent(ctx, wm)
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
|
||||
// Truncate the write buffer
|
||||
c.readBuf = c.readBuf[:0]
|
||||
|
||||
var err error
|
||||
c.readBuf, err = c.readWireMessage(ctx, c.readBuf)
|
||||
if c.s != nil {
|
||||
c.s.ProcessError(err)
|
||||
}
|
||||
if err != nil {
|
||||
// The error we got back was probably a ConnectionError already, so we don't really need to
|
||||
// wrap it here.
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to read wire message from network",
|
||||
}
|
||||
}
|
||||
hdr, err := wiremessage.ReadHeader(c.readBuf, 0)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to decode header",
|
||||
}
|
||||
}
|
||||
|
||||
messageToDecode := c.readBuf
|
||||
opcodeToCheck := hdr.OpCode
|
||||
|
||||
if hdr.OpCode == wiremessage.OpCompressed {
|
||||
var compressed wiremessage.Compressed
|
||||
err := compressed.UnmarshalWireMessage(c.readBuf)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to decode OP_COMPRESSED",
|
||||
}
|
||||
}
|
||||
|
||||
uncompressed, origOpcode, err := c.uncompressMessage(compressed)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to uncompress message",
|
||||
}
|
||||
}
|
||||
messageToDecode = uncompressed
|
||||
opcodeToCheck = origOpcode
|
||||
}
|
||||
|
||||
var wm wiremessage.WireMessage
|
||||
switch opcodeToCheck {
|
||||
case wiremessage.OpReply:
|
||||
var r wiremessage.Reply
|
||||
err := r.UnmarshalWireMessage(messageToDecode)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to decode OP_REPLY",
|
||||
}
|
||||
}
|
||||
wm = r
|
||||
case wiremessage.OpMsg:
|
||||
var reply wiremessage.Msg
|
||||
err := reply.UnmarshalWireMessage(messageToDecode)
|
||||
if err != nil {
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
Wrapped: err,
|
||||
message: "unable to decode OP_MSG",
|
||||
}
|
||||
}
|
||||
wm = reply
|
||||
default:
|
||||
return nil, ConnectionError{
|
||||
ConnectionID: c.id,
|
||||
message: fmt.Sprintf("opcode %s not implemented", hdr.OpCode),
|
||||
}
|
||||
}
|
||||
|
||||
if c.s != nil {
|
||||
c.s.ProcessError(command.DecodeError(wm))
|
||||
}
|
||||
|
||||
// TODO: do we care if monitoring fails?
|
||||
return wm, c.commandFinishedEvent(ctx, wm)
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) Close() error {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
if c.connection == nil {
|
||||
return nil
|
||||
}
|
||||
if c.s != nil {
|
||||
c.s.sem.Release(1)
|
||||
}
|
||||
err := c.pool.put(c.connection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.connection = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) Expired() bool {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return c.connection == nil || c.expired()
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) Alive() bool {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
return c.connection != nil
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) ID() string {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
if c.connection == nil {
|
||||
return "<closed>"
|
||||
}
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) commandStartedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
|
||||
if c.monitor == nil || c.monitor.Started == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
startedEvent := &event.CommandStartedEvent{
|
||||
ConnectionID: c.id,
|
||||
}
|
||||
|
||||
var cmd bsonx.Doc
|
||||
var err error
|
||||
var legacy bool
|
||||
var fullCollName string
|
||||
|
||||
var acknowledged bool
|
||||
switch converted := wm.(type) {
|
||||
case wiremessage.Query:
|
||||
cmd, err = converted.CommandDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
acknowledged = converted.AcknowledgedWrite()
|
||||
startedEvent.DatabaseName = converted.DatabaseName()
|
||||
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
|
||||
legacy = converted.Legacy()
|
||||
fullCollName = converted.FullCollectionName
|
||||
case wiremessage.Msg:
|
||||
cmd, err = converted.GetMainDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
acknowledged = converted.AcknowledgedWrite()
|
||||
arr, identifier, err := converted.GetSequenceArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if arr != nil {
|
||||
cmd = cmd.Copy() // make copy to avoid changing original command
|
||||
cmd = append(cmd, bsonx.Elem{identifier, bsonx.Array(arr)})
|
||||
}
|
||||
|
||||
dbVal, err := cmd.LookupErr("$db")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startedEvent.DatabaseName = dbVal.StringValue()
|
||||
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
|
||||
case wiremessage.GetMore:
|
||||
cmd = converted.CommandDocument()
|
||||
startedEvent.DatabaseName = converted.DatabaseName()
|
||||
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
|
||||
acknowledged = true
|
||||
legacy = true
|
||||
fullCollName = converted.FullCollectionName
|
||||
case wiremessage.KillCursors:
|
||||
cmd = converted.CommandDocument()
|
||||
startedEvent.DatabaseName = converted.DatabaseName
|
||||
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
|
||||
legacy = true
|
||||
}
|
||||
|
||||
rawcmd, _ := cmd.MarshalBSON()
|
||||
startedEvent.Command = rawcmd
|
||||
startedEvent.CommandName = cmd[0].Key
|
||||
if !canMonitor(startedEvent.CommandName) {
|
||||
startedEvent.Command = emptyDoc
|
||||
}
|
||||
|
||||
c.monitor.Started(ctx, startedEvent)
|
||||
|
||||
if !acknowledged {
|
||||
if c.monitor.Succeeded == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// unack writes must provide a CommandSucceededEvent with an { ok: 1 } reply
|
||||
finishedEvent := event.CommandFinishedEvent{
|
||||
DurationNanos: 0,
|
||||
CommandName: startedEvent.CommandName,
|
||||
RequestID: startedEvent.RequestID,
|
||||
ConnectionID: c.id,
|
||||
}
|
||||
|
||||
c.monitor.Succeeded(ctx, &event.CommandSucceededEvent{
|
||||
CommandFinishedEvent: finishedEvent,
|
||||
Reply: bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.cmdMap[startedEvent.RequestID] = createMetadata(startedEvent.CommandName, legacy, fullCollName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) commandFinishedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
|
||||
if c.monitor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var reply bsonx.Doc
|
||||
var requestID int64
|
||||
var err error
|
||||
|
||||
switch converted := wm.(type) {
|
||||
case wiremessage.Reply:
|
||||
requestID = int64(converted.MsgHeader.ResponseTo)
|
||||
case wiremessage.Msg:
|
||||
requestID = int64(converted.MsgHeader.ResponseTo)
|
||||
}
|
||||
cmdMetadata := c.cmdMap[requestID]
|
||||
delete(c.cmdMap, requestID)
|
||||
|
||||
switch converted := wm.(type) {
|
||||
case wiremessage.Reply:
|
||||
if cmdMetadata.Legacy {
|
||||
reply, err = converted.GetMainLegacyDocument(cmdMetadata.FullCollectionName)
|
||||
} else {
|
||||
reply, err = converted.GetMainDocument()
|
||||
}
|
||||
case wiremessage.Msg:
|
||||
reply, err = converted.GetMainDocument()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
success, errmsg := processReply(reply)
|
||||
|
||||
if (success && c.monitor.Succeeded == nil) || (!success && c.monitor.Failed == nil) {
|
||||
return nil
|
||||
}
|
||||
|
||||
finishedEvent := event.CommandFinishedEvent{
|
||||
DurationNanos: cmdMetadata.TimeDifference(),
|
||||
CommandName: cmdMetadata.Name,
|
||||
RequestID: requestID,
|
||||
ConnectionID: c.id,
|
||||
}
|
||||
|
||||
if success {
|
||||
if !canMonitor(finishedEvent.CommandName) {
|
||||
successEvent := &event.CommandSucceededEvent{
|
||||
Reply: emptyDoc,
|
||||
CommandFinishedEvent: finishedEvent,
|
||||
}
|
||||
c.monitor.Succeeded(ctx, successEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
// if response has type 1 document sequence, the sequence must be included as a BSON array in the event's reply.
|
||||
if opmsg, ok := wm.(wiremessage.Msg); ok {
|
||||
arr, identifier, err := opmsg.GetSequenceArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if arr != nil {
|
||||
reply = reply.Copy() // make copy to avoid changing original command
|
||||
reply = append(reply, bsonx.Elem{identifier, bsonx.Array(arr)})
|
||||
}
|
||||
}
|
||||
|
||||
replyraw, _ := reply.MarshalBSON()
|
||||
successEvent := &event.CommandSucceededEvent{
|
||||
Reply: replyraw,
|
||||
CommandFinishedEvent: finishedEvent,
|
||||
}
|
||||
|
||||
c.monitor.Succeeded(ctx, successEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
failureEvent := &event.CommandFailedEvent{
|
||||
Failure: errmsg,
|
||||
CommandFinishedEvent: finishedEvent,
|
||||
}
|
||||
|
||||
c.monitor.Failed(ctx, failureEvent)
|
||||
return nil
|
||||
}
|
||||
|
||||
func canCompress(cmd string) bool {
|
||||
if cmd == "isMaster" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "authenticate" ||
|
||||
cmd == "createUser" || cmd == "updateUser" || cmd == "copydbSaslStart" || cmd == "copydbgetnonce" || cmd == "copydb" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *connectionLegacy) compressMessage(wm wiremessage.WireMessage) (wiremessage.WireMessage, error) {
|
||||
var requestID int32
|
||||
var responseTo int32
|
||||
var origOpcode wiremessage.OpCode
|
||||
|
||||
switch converted := wm.(type) {
|
||||
case wiremessage.Query:
|
||||
firstElem, err := converted.Query.IndexErr(0)
|
||||
if err != nil {
|
||||
return wiremessage.Compressed{}, err
|
||||
}
|
||||
|
||||
key := firstElem.Key()
|
||||
if !canCompress(key) {
|
||||
return wm, nil // return original message because this command can't be compressed
|
||||
}
|
||||
requestID = converted.MsgHeader.RequestID
|
||||
origOpcode = wiremessage.OpQuery
|
||||
responseTo = converted.MsgHeader.ResponseTo
|
||||
case wiremessage.Msg:
|
||||
firstElem, err := converted.Sections[0].(wiremessage.SectionBody).Document.IndexErr(0)
|
||||
if err != nil {
|
||||
return wiremessage.Compressed{}, err
|
||||
}
|
||||
|
||||
key := firstElem.Key()
|
||||
if !canCompress(key) {
|
||||
return wm, nil
|
||||
}
|
||||
|
||||
requestID = converted.MsgHeader.RequestID
|
||||
origOpcode = wiremessage.OpMsg
|
||||
responseTo = converted.MsgHeader.ResponseTo
|
||||
}
|
||||
|
||||
// can compress
|
||||
c.wireMessageBuf = c.wireMessageBuf[:0] // truncate
|
||||
var err error
|
||||
c.wireMessageBuf, err = wm.AppendWireMessage(c.wireMessageBuf)
|
||||
if err != nil {
|
||||
return wiremessage.Compressed{}, err
|
||||
}
|
||||
|
||||
c.wireMessageBuf = c.wireMessageBuf[16:] // strip header
|
||||
c.compressBuf = c.compressBuf[:0]
|
||||
compressedBytes, err := c.compressor.CompressBytes(c.wireMessageBuf, c.compressBuf)
|
||||
if err != nil {
|
||||
return wiremessage.Compressed{}, err
|
||||
}
|
||||
|
||||
compressedMessage := wiremessage.Compressed{
|
||||
MsgHeader: wiremessage.Header{
|
||||
// MessageLength and OpCode will be set when marshalling wire message by SetDefaults()
|
||||
RequestID: requestID,
|
||||
ResponseTo: responseTo,
|
||||
},
|
||||
OriginalOpCode: origOpcode,
|
||||
UncompressedSize: int32(len(c.wireMessageBuf)), // length of uncompressed message excluding MsgHeader
|
||||
CompressorID: wiremessage.CompressorID(c.compressor.CompressorID()),
|
||||
CompressedMessage: compressedBytes,
|
||||
}
|
||||
|
||||
return compressedMessage, nil
|
||||
}
|
||||
|
||||
// returns []byte of uncompressed message with reconstructed header, original opcode, error
|
||||
func (c *connectionLegacy) uncompressMessage(compressed wiremessage.Compressed) ([]byte, wiremessage.OpCode, error) {
|
||||
// server doesn't guarantee the same compression method will be used each time so the CompressorID field must be
|
||||
// used to find the correct method for uncompressing data
|
||||
uncompressor := c.compressorMap[compressed.CompressorID]
|
||||
|
||||
// reset uncompressBuf
|
||||
c.uncompressBuf = c.uncompressBuf[:0]
|
||||
if int(compressed.UncompressedSize) > cap(c.uncompressBuf) {
|
||||
c.uncompressBuf = make([]byte, 0, compressed.UncompressedSize)
|
||||
}
|
||||
|
||||
uncompressedMessage, err := uncompressor.UncompressBytes(compressed.CompressedMessage, c.uncompressBuf[:compressed.UncompressedSize])
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
origHeader := wiremessage.Header{
|
||||
MessageLength: int32(len(uncompressedMessage)) + 16, // add 16 for original header
|
||||
RequestID: compressed.MsgHeader.RequestID,
|
||||
ResponseTo: compressed.MsgHeader.ResponseTo,
|
||||
}
|
||||
|
||||
switch compressed.OriginalOpCode {
|
||||
case wiremessage.OpReply:
|
||||
origHeader.OpCode = wiremessage.OpReply
|
||||
case wiremessage.OpMsg:
|
||||
origHeader.OpCode = wiremessage.OpMsg
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("opcode %s not implemented", compressed.OriginalOpCode)
|
||||
}
|
||||
|
||||
var fullMessage []byte
|
||||
fullMessage = origHeader.AppendHeader(fullMessage)
|
||||
fullMessage = append(fullMessage, uncompressedMessage...)
|
||||
return fullMessage, origHeader.OpCode, nil
|
||||
}
|
||||
|
||||
func canMonitor(cmd string) bool {
|
||||
if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
|
||||
cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func processReply(reply bsonx.Doc) (bool, string) {
|
||||
var success bool
|
||||
var errmsg string
|
||||
var errCode int32
|
||||
|
||||
for _, elem := range reply {
|
||||
switch elem.Key {
|
||||
case "ok":
|
||||
switch elem.Value.Type() {
|
||||
case bsontype.Int32:
|
||||
if elem.Value.Int32() == 1 {
|
||||
success = true
|
||||
}
|
||||
case bsontype.Int64:
|
||||
if elem.Value.Int64() == 1 {
|
||||
success = true
|
||||
}
|
||||
case bsontype.Double:
|
||||
if elem.Value.Double() == 1 {
|
||||
success = true
|
||||
}
|
||||
}
|
||||
case "errmsg":
|
||||
if str, ok := elem.Value.StringValueOK(); ok {
|
||||
errmsg = str
|
||||
}
|
||||
case "code":
|
||||
if c, ok := elem.Value.Int32OK(); ok {
|
||||
errCode = c
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if success {
|
||||
return true, ""
|
||||
}
|
||||
|
||||
fullErrMsg := fmt.Sprintf("Error code %d: %s", errCode, errmsg)
|
||||
return false, fullErrMsg
|
||||
}
|
||||
28
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_legacy_command_metadata.go
generated
vendored
Executable file
28
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_legacy_command_metadata.go
generated
vendored
Executable file
@@ -0,0 +1,28 @@
|
||||
package topology
|
||||
|
||||
import "time"
|
||||
|
||||
// commandMetadata contains metadata about a command sent to the server.
|
||||
type commandMetadata struct {
|
||||
Name string
|
||||
Time time.Time
|
||||
Legacy bool
|
||||
FullCollectionName string
|
||||
}
|
||||
|
||||
// createMetadata creates metadata for a command.
|
||||
func createMetadata(name string, legacy bool, fullCollName string) *commandMetadata {
|
||||
return &commandMetadata{
|
||||
Name: name,
|
||||
Time: time.Now(),
|
||||
Legacy: legacy,
|
||||
FullCollectionName: fullCollName,
|
||||
}
|
||||
}
|
||||
|
||||
// TimeDifference returns the difference between now and the time a command was sent in nanoseconds.
|
||||
func (cm *commandMetadata) TimeDifference() int64 {
|
||||
t := time.Now()
|
||||
duration := t.Sub(cm.Time)
|
||||
return duration.Nanoseconds()
|
||||
}
|
||||
187
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go
generated
vendored
Executable file
187
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_options.go
generated
vendored
Executable file
@@ -0,0 +1,187 @@
|
||||
package topology
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/event"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
|
||||
)
|
||||
|
||||
// Dialer is used to make network connections.
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// DialerFunc is a type implemented by functions that can be used as a Dialer.
|
||||
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
|
||||
// DialContext implements the Dialer interface.
|
||||
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return df(ctx, network, address)
|
||||
}
|
||||
|
||||
// DefaultDialer is the Dialer implementation that is used by this package. Changing this
|
||||
// will also change the Dialer used for this package. This should only be changed why all
|
||||
// of the connections being made need to use a different Dialer. Most of the time, using a
|
||||
// WithDialer option is more appropriate than changing this variable.
|
||||
var DefaultDialer Dialer = &net.Dialer{}
|
||||
|
||||
// Handshaker is the interface implemented by types that can perform a MongoDB
|
||||
// handshake over a provided driver.Connection. This is used during connection
|
||||
// initialization. Implementations must be goroutine safe.
|
||||
type Handshaker = driver.Handshaker
|
||||
|
||||
// HandshakerFunc is an adapter to allow the use of ordinary functions as
|
||||
// connection handshakers.
|
||||
type HandshakerFunc = driver.HandshakerFunc
|
||||
|
||||
type connectionConfig struct {
|
||||
appName string
|
||||
connectTimeout time.Duration
|
||||
dialer Dialer
|
||||
handshaker Handshaker
|
||||
idleTimeout time.Duration
|
||||
lifeTimeout time.Duration
|
||||
cmdMonitor *event.CommandMonitor
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
tlsConfig *tls.Config
|
||||
compressors []string
|
||||
zlibLevel *int
|
||||
descCallback func(description.Server)
|
||||
}
|
||||
|
||||
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
|
||||
cfg := &connectionConfig{
|
||||
connectTimeout: 30 * time.Second,
|
||||
dialer: nil,
|
||||
idleTimeout: 10 * time.Minute,
|
||||
lifeTimeout: 30 * time.Minute,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.dialer == nil {
|
||||
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
|
||||
return append(opts, ConnectionOption(func(c *connectionConfig) error {
|
||||
c.descCallback = callback
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// ConnectionOption is used to configure a connection.
|
||||
type ConnectionOption func(*connectionConfig) error
|
||||
|
||||
// WithAppName sets the application name which gets sent to MongoDB when it
|
||||
// first connects.
|
||||
func WithAppName(fn func(string) string) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.appName = fn(c.appName)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCompressors sets the compressors that can be used for communication.
|
||||
func WithCompressors(fn func([]string) []string) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.compressors = fn(c.compressors)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithConnectTimeout configures the maximum amount of time a dial will wait for a
|
||||
// connect to complete. The default is 30 seconds.
|
||||
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.connectTimeout = fn(c.connectTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithDialer configures the Dialer to use when making a new connection to MongoDB.
|
||||
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.dialer = fn(c.dialer)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHandshaker configures the Handshaker that wll be used to initialize newly
|
||||
// dialed connections.
|
||||
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.handshaker = fn(c.handshaker)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdleTimeout configures the maximum idle time to allow for a connection.
|
||||
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.idleTimeout = fn(c.idleTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithLifeTimeout configures the maximum life of a connection.
|
||||
func WithLifeTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.lifeTimeout = fn(c.lifeTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithReadTimeout configures the maximum read time for a connection.
|
||||
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.readTimeout = fn(c.readTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithWriteTimeout configures the maximum write time for a connection.
|
||||
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.writeTimeout = fn(c.writeTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig configures the TLS options for a connection.
|
||||
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.tlsConfig = fn(c.tlsConfig)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithMonitor configures a event for command monitoring.
|
||||
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.cmdMonitor = fn(c.cmdMonitor)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithZlibLevel sets the zLib compression level.
|
||||
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
|
||||
return func(c *connectionConfig) error {
|
||||
c.zlibLevel = fn(c.zlibLevel)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
22
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go
generated
vendored
Executable file
22
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/errors.go
generated
vendored
Executable file
@@ -0,0 +1,22 @@
|
||||
package topology
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ConnectionError represents a connection error.
|
||||
type ConnectionError struct {
|
||||
ConnectionID string
|
||||
Wrapped error
|
||||
|
||||
// init will be set to true if this error occured during connection initialization or
|
||||
// during a connection handshake.
|
||||
init bool
|
||||
message string
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ConnectionError) Error() string {
|
||||
if e.Wrapped != nil {
|
||||
return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, e.message, e.Wrapped.Error())
|
||||
}
|
||||
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message)
|
||||
}
|
||||
350
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go
generated
vendored
Executable file
350
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/fsm.go
generated
vendored
Executable file
@@ -0,0 +1,350 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
|
||||
)
|
||||
|
||||
var supportedWireVersions = description.NewVersionRange(2, 6)
|
||||
var minSupportedMongoDBVersion = "2.6"
|
||||
|
||||
type fsm struct {
|
||||
description.Topology
|
||||
SetName string
|
||||
maxElectionID primitive.ObjectID
|
||||
maxSetVersion uint32
|
||||
}
|
||||
|
||||
func newFSM() *fsm {
|
||||
return new(fsm)
|
||||
}
|
||||
|
||||
// apply should operate on immutable TopologyDescriptions and Descriptions. This way we don't have to
|
||||
// lock for the entire time we're applying server description.
|
||||
func (f *fsm) apply(s description.Server) (description.Topology, error) {
|
||||
|
||||
newServers := make([]description.Server, len(f.Servers))
|
||||
copy(newServers, f.Servers)
|
||||
|
||||
oldMinutes := f.SessionTimeoutMinutes
|
||||
f.Topology = description.Topology{
|
||||
Kind: f.Kind,
|
||||
Servers: newServers,
|
||||
}
|
||||
|
||||
// For data bearing servers, set SessionTimeoutMinutes to the lowest among them
|
||||
if oldMinutes == 0 {
|
||||
// If timeout currently 0, check all servers to see if any still don't have a timeout
|
||||
// If they all have timeout, pick the lowest.
|
||||
timeout := s.SessionTimeoutMinutes
|
||||
for _, server := range f.Servers {
|
||||
if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
|
||||
timeout = server.SessionTimeoutMinutes
|
||||
}
|
||||
}
|
||||
f.SessionTimeoutMinutes = timeout
|
||||
} else {
|
||||
if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
|
||||
f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
|
||||
} else {
|
||||
f.SessionTimeoutMinutes = oldMinutes
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := f.findServer(s.Addr); !ok {
|
||||
return f.Topology, nil
|
||||
}
|
||||
|
||||
if s.WireVersion != nil {
|
||||
if s.WireVersion.Max < supportedWireVersions.Min {
|
||||
return description.Topology{}, fmt.Errorf(
|
||||
"server at %s reports wire version %d, but this version of the Go driver requires "+
|
||||
"at least %d (MongoDB %s)",
|
||||
s.Addr.String(),
|
||||
s.WireVersion.Max,
|
||||
supportedWireVersions.Min,
|
||||
minSupportedMongoDBVersion,
|
||||
)
|
||||
}
|
||||
|
||||
if s.WireVersion.Min > supportedWireVersions.Max {
|
||||
return description.Topology{}, fmt.Errorf(
|
||||
"server at %s requires wire version %d, but this version of the Go driver only "+
|
||||
"supports up to %d",
|
||||
s.Addr.String(),
|
||||
s.WireVersion.Min,
|
||||
supportedWireVersions.Max,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch f.Kind {
|
||||
case description.Unknown:
|
||||
f.applyToUnknown(s)
|
||||
case description.Sharded:
|
||||
f.applyToSharded(s)
|
||||
case description.ReplicaSetNoPrimary:
|
||||
f.applyToReplicaSetNoPrimary(s)
|
||||
case description.ReplicaSetWithPrimary:
|
||||
f.applyToReplicaSetWithPrimary(s)
|
||||
case description.Single:
|
||||
f.applyToSingle(s)
|
||||
}
|
||||
|
||||
return f.Topology, nil
|
||||
}
|
||||
|
||||
func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) {
|
||||
switch s.Kind {
|
||||
case description.Standalone, description.Mongos:
|
||||
f.removeServerByAddr(s.Addr)
|
||||
case description.RSPrimary:
|
||||
f.updateRSFromPrimary(s)
|
||||
case description.RSSecondary, description.RSArbiter, description.RSMember:
|
||||
f.updateRSWithoutPrimary(s)
|
||||
case description.Unknown, description.RSGhost:
|
||||
f.replaceServer(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) {
|
||||
switch s.Kind {
|
||||
case description.Standalone, description.Mongos:
|
||||
f.removeServerByAddr(s.Addr)
|
||||
f.checkIfHasPrimary()
|
||||
case description.RSPrimary:
|
||||
f.updateRSFromPrimary(s)
|
||||
case description.RSSecondary, description.RSArbiter, description.RSMember:
|
||||
f.updateRSWithPrimaryFromMember(s)
|
||||
case description.Unknown, description.RSGhost:
|
||||
f.replaceServer(s)
|
||||
f.checkIfHasPrimary()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) applyToSharded(s description.Server) {
|
||||
switch s.Kind {
|
||||
case description.Mongos, description.Unknown:
|
||||
f.replaceServer(s)
|
||||
case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
|
||||
f.removeServerByAddr(s.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) applyToSingle(s description.Server) {
|
||||
switch s.Kind {
|
||||
case description.Unknown:
|
||||
f.replaceServer(s)
|
||||
case description.Standalone, description.Mongos:
|
||||
if f.SetName != "" {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
f.replaceServer(s)
|
||||
case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
|
||||
if f.SetName != "" && f.SetName != s.SetName {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
f.replaceServer(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) applyToUnknown(s description.Server) {
|
||||
switch s.Kind {
|
||||
case description.Mongos:
|
||||
f.setKind(description.Sharded)
|
||||
f.replaceServer(s)
|
||||
case description.RSPrimary:
|
||||
f.updateRSFromPrimary(s)
|
||||
case description.RSSecondary, description.RSArbiter, description.RSMember:
|
||||
f.setKind(description.ReplicaSetNoPrimary)
|
||||
f.updateRSWithoutPrimary(s)
|
||||
case description.Standalone:
|
||||
f.updateUnknownWithStandalone(s)
|
||||
case description.Unknown, description.RSGhost:
|
||||
f.replaceServer(s)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) checkIfHasPrimary() {
|
||||
if _, ok := f.findPrimary(); ok {
|
||||
f.setKind(description.ReplicaSetWithPrimary)
|
||||
} else {
|
||||
f.setKind(description.ReplicaSetNoPrimary)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) updateRSFromPrimary(s description.Server) {
|
||||
if f.SetName == "" {
|
||||
f.SetName = s.SetName
|
||||
} else if f.SetName != s.SetName {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
f.checkIfHasPrimary()
|
||||
return
|
||||
}
|
||||
|
||||
if s.SetVersion != 0 && !bytes.Equal(s.ElectionID[:], primitive.NilObjectID[:]) {
|
||||
if f.maxSetVersion > s.SetVersion || bytes.Compare(f.maxElectionID[:], s.ElectionID[:]) == 1 {
|
||||
f.replaceServer(description.Server{
|
||||
Addr: s.Addr,
|
||||
LastError: fmt.Errorf("was a primary, but its set version or election id is stale"),
|
||||
})
|
||||
f.checkIfHasPrimary()
|
||||
return
|
||||
}
|
||||
|
||||
f.maxElectionID = s.ElectionID
|
||||
}
|
||||
|
||||
if s.SetVersion > f.maxSetVersion {
|
||||
f.maxSetVersion = s.SetVersion
|
||||
}
|
||||
|
||||
if j, ok := f.findPrimary(); ok {
|
||||
f.setServer(j, description.Server{
|
||||
Addr: f.Servers[j].Addr,
|
||||
LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
|
||||
})
|
||||
}
|
||||
|
||||
f.replaceServer(s)
|
||||
|
||||
for j := len(f.Servers) - 1; j >= 0; j-- {
|
||||
found := false
|
||||
for _, member := range s.Members {
|
||||
if member == f.Servers[j].Addr {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
f.removeServer(j)
|
||||
}
|
||||
}
|
||||
|
||||
for _, member := range s.Members {
|
||||
if _, ok := f.findServer(member); !ok {
|
||||
f.addServer(member)
|
||||
}
|
||||
}
|
||||
|
||||
f.checkIfHasPrimary()
|
||||
}
|
||||
|
||||
func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
|
||||
if f.SetName != s.SetName {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
f.checkIfHasPrimary()
|
||||
return
|
||||
}
|
||||
|
||||
if s.Addr != s.CanonicalAddr {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
f.checkIfHasPrimary()
|
||||
return
|
||||
}
|
||||
|
||||
f.replaceServer(s)
|
||||
|
||||
if _, ok := f.findPrimary(); !ok {
|
||||
f.setKind(description.ReplicaSetNoPrimary)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) updateRSWithoutPrimary(s description.Server) {
|
||||
if f.SetName == "" {
|
||||
f.SetName = s.SetName
|
||||
} else if f.SetName != s.SetName {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
for _, member := range s.Members {
|
||||
if _, ok := f.findServer(member); !ok {
|
||||
f.addServer(member)
|
||||
}
|
||||
}
|
||||
|
||||
if s.Addr != s.CanonicalAddr {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
f.replaceServer(s)
|
||||
}
|
||||
|
||||
func (f *fsm) updateUnknownWithStandalone(s description.Server) {
|
||||
if len(f.Servers) > 1 {
|
||||
f.removeServerByAddr(s.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
f.setKind(description.Single)
|
||||
f.replaceServer(s)
|
||||
}
|
||||
|
||||
func (f *fsm) addServer(addr address.Address) {
|
||||
f.Servers = append(f.Servers, description.Server{
|
||||
Addr: addr.Canonicalize(),
|
||||
})
|
||||
}
|
||||
|
||||
func (f *fsm) findPrimary() (int, bool) {
|
||||
for i, s := range f.Servers {
|
||||
if s.Kind == description.RSPrimary {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (f *fsm) findServer(addr address.Address) (int, bool) {
|
||||
canon := addr.Canonicalize()
|
||||
for i, s := range f.Servers {
|
||||
if canon == s.Addr {
|
||||
return i, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (f *fsm) removeServer(i int) {
|
||||
f.Servers = append(f.Servers[:i], f.Servers[i+1:]...)
|
||||
}
|
||||
|
||||
func (f *fsm) removeServerByAddr(addr address.Address) {
|
||||
if i, ok := f.findServer(addr); ok {
|
||||
f.removeServer(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fsm) replaceServer(s description.Server) bool {
|
||||
if i, ok := f.findServer(s.Addr); ok {
|
||||
f.setServer(i, s)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (f *fsm) setServer(i int, s description.Server) {
|
||||
f.Servers[i] = s
|
||||
}
|
||||
|
||||
func (f *fsm) setKind(k description.TopologyKind) {
|
||||
f.Kind = k
|
||||
}
|
||||
213
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go
generated
vendored
Executable file
213
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/pool.go
generated
vendored
Executable file
@@ -0,0 +1,213 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
|
||||
)
|
||||
|
||||
// ErrPoolConnected is returned from an attempt to connect an already connected pool
|
||||
var ErrPoolConnected = PoolError("pool is connected")
|
||||
|
||||
// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected
|
||||
// or disconnecting pool.
|
||||
var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting")
|
||||
|
||||
// ErrConnectionClosed is returned from an attempt to use an already closed connection.
|
||||
var ErrConnectionClosed = ConnectionError{ConnectionID: "<closed>", message: "connection is closed"}
|
||||
|
||||
// ErrWrongPool is return when a connection is returned to a pool it doesn't belong to.
|
||||
var ErrWrongPool = PoolError("connection does not belong to this pool")
|
||||
|
||||
// PoolError is an error returned from a Pool method.
|
||||
type PoolError string
|
||||
|
||||
// pruneInterval is the interval at which the background routine to close expired connections will be run.
|
||||
var pruneInterval = time.Minute
|
||||
|
||||
func (pe PoolError) Error() string { return string(pe) }
|
||||
|
||||
type pool struct {
|
||||
address address.Address
|
||||
opts []ConnectionOption
|
||||
conns *resourcePool // pool for idle connections
|
||||
generation uint64
|
||||
|
||||
connected int32 // Must be accessed using the sync/atomic package.
|
||||
nextid uint64
|
||||
opened map[uint64]*connection // opened holds all of the currently open connections.
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func connectionExpiredFunc(v interface{}) bool {
|
||||
return v.(*connection).expired()
|
||||
}
|
||||
|
||||
func connectionCloseFunc(v interface{}) {
|
||||
c := v.(*connection)
|
||||
go c.pool.close(c)
|
||||
}
|
||||
|
||||
// newPool creates a new pool that will hold size number of idle connections. It will use the
|
||||
// provided options when creating connections.
|
||||
func newPool(addr address.Address, size uint64, opts ...ConnectionOption) *pool {
|
||||
return &pool{
|
||||
address: addr,
|
||||
conns: newResourcePool(size, connectionExpiredFunc, connectionCloseFunc, pruneInterval),
|
||||
generation: 0,
|
||||
connected: disconnected,
|
||||
opened: make(map[uint64]*connection),
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// drain lazily drains the pool by increasing the generation ID.
|
||||
func (p *pool) drain() { atomic.AddUint64(&p.generation, 1) }
|
||||
func (p *pool) expired(generation uint64) bool { return generation < atomic.LoadUint64(&p.generation) }
|
||||
|
||||
// connect puts the pool into the connected state, allowing it to be used.
|
||||
func (p *pool) connect() error {
|
||||
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
|
||||
return ErrPoolConnected
|
||||
}
|
||||
atomic.AddUint64(&p.generation, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pool) disconnect(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
|
||||
return ErrPoolDisconnected
|
||||
}
|
||||
|
||||
// We first clear out the idle connections, then we wait until the context's deadline is hit or
|
||||
// it's cancelled, after which we aggressively close the remaining open connections.
|
||||
for {
|
||||
connVal := p.conns.Get()
|
||||
if connVal == nil {
|
||||
break
|
||||
}
|
||||
|
||||
_ = p.close(connVal.(*connection))
|
||||
}
|
||||
if dl, ok := ctx.Deadline(); ok {
|
||||
// If we have a deadline then we interpret it as a request to gracefully shutdown. We wait
|
||||
// until either all the connections have landed back in the pool (and have been closed) or
|
||||
// until the timer is done.
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
timer := time.NewTimer(time.Now().Sub(dl))
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-ticker.C: // Can we repalce this with an actual signal channel? We will know when p.inflight hits zero from the close method.
|
||||
p.Lock()
|
||||
if len(p.opened) > 0 {
|
||||
p.Unlock()
|
||||
continue
|
||||
}
|
||||
p.Unlock()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// We copy the remaining connections into a slice, then iterate it to close them. This allows us
|
||||
// to use a single function to actually clean up and close connections at the expense of a
|
||||
// double itertion in the worse case.
|
||||
p.Lock()
|
||||
toClose := make([]*connection, 0, len(p.opened))
|
||||
for _, pc := range p.opened {
|
||||
toClose = append(toClose, pc)
|
||||
}
|
||||
p.Unlock()
|
||||
for _, pc := range toClose {
|
||||
_ = p.close(pc) // We don't care about errors while closing the connection.
|
||||
}
|
||||
atomic.StoreInt32(&p.connected, disconnected)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pool) get(ctx context.Context) (*connection, error) {
|
||||
if atomic.LoadInt32(&p.connected) != connected {
|
||||
return nil, ErrPoolDisconnected
|
||||
}
|
||||
|
||||
// try to get an unexpired idle connection
|
||||
connVal := p.conns.Get()
|
||||
if connVal != nil {
|
||||
return connVal.(*connection), nil
|
||||
}
|
||||
|
||||
// couldn't find an unexpired connection. create a new one.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
c, err := newConnection(ctx, p.address, p.opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.pool = p
|
||||
c.poolID = atomic.AddUint64(&p.nextid, 1)
|
||||
c.generation = p.generation
|
||||
|
||||
if atomic.LoadInt32(&p.connected) != connected {
|
||||
_ = p.close(c) // The pool is disconnected or disconnecting, ignore the error from closing the connection.
|
||||
return nil, ErrPoolDisconnected
|
||||
}
|
||||
p.Lock()
|
||||
p.opened[c.poolID] = c
|
||||
p.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// close closes a connection, not the pool itself. This method will actually close the connection,
|
||||
// making it unusable, to instead return the connection to the pool, use put.
|
||||
func (p *pool) close(c *connection) error {
|
||||
if c.pool != p {
|
||||
return ErrWrongPool
|
||||
}
|
||||
p.Lock()
|
||||
delete(p.opened, c.poolID)
|
||||
nc := c.nc
|
||||
c.nc = nil
|
||||
p.Unlock()
|
||||
if nc == nil {
|
||||
return nil // We're closing an already closed connection.
|
||||
}
|
||||
err := nc.Close()
|
||||
if err != nil {
|
||||
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to close net.Conn"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// put returns a connection to this pool. If the pool is connected, the connection is not
|
||||
// expired, and there is space in the cache, the connection is returned to the cache.
|
||||
func (p *pool) put(c *connection) error {
|
||||
if c.pool != p {
|
||||
return ErrWrongPool
|
||||
}
|
||||
if atomic.LoadInt32(&p.connected) != connected || c.expired() {
|
||||
return p.close(c)
|
||||
}
|
||||
|
||||
// close the connection if the underlying pool is full
|
||||
if !p.conns.Put(c) {
|
||||
return p.close(c)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
121
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/resource_pool.go
generated
vendored
Executable file
121
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/resource_pool.go
generated
vendored
Executable file
@@ -0,0 +1,121 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// expiredFunc is the function type used for testing whether or not resources in a resourcePool have expired. It should
|
||||
// return true if the resource has expired and can be removed from the pool.
|
||||
type expiredFunc func(interface{}) bool
|
||||
|
||||
// closeFunc is the function type used to close resources in a resourcePool. The pool will always call this function
|
||||
// asynchronously.
|
||||
type closeFunc func(interface{})
|
||||
|
||||
// resourcePool is a concurrent resource pool that implements the behavior described in the sessions spec.
|
||||
type resourcePool struct {
|
||||
deque *list.List
|
||||
len, maxSize uint64
|
||||
expiredFn expiredFunc
|
||||
closeFn closeFunc
|
||||
pruneTimer *time.Timer
|
||||
pruneInterval time.Duration
|
||||
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
// NewResourcePool creates a new resourcePool instance that is capped to maxSize resources.
|
||||
// If maxSize is 0, the pool size will be unbounded.
|
||||
func newResourcePool(maxSize uint64, expiredFn expiredFunc, closeFn closeFunc, pruneInterval time.Duration) *resourcePool {
|
||||
rp := &resourcePool{
|
||||
deque: list.New(),
|
||||
maxSize: maxSize,
|
||||
expiredFn: expiredFn,
|
||||
closeFn: closeFn,
|
||||
pruneInterval: pruneInterval,
|
||||
}
|
||||
rp.Lock()
|
||||
rp.pruneTimer = time.AfterFunc(rp.pruneInterval, rp.Prune)
|
||||
rp.Unlock()
|
||||
return rp
|
||||
}
|
||||
|
||||
// Get returns the first un-expired resource from the pool. If no such resource can be found, nil is returned.
|
||||
func (rp *resourcePool) Get() interface{} {
|
||||
rp.Lock()
|
||||
defer rp.Unlock()
|
||||
|
||||
var next *list.Element
|
||||
for curr := rp.deque.Front(); curr != nil; curr = next {
|
||||
next = curr.Next()
|
||||
|
||||
// remove the current resource and return it if it is valid
|
||||
rp.deque.Remove(curr)
|
||||
rp.len--
|
||||
if !rp.expiredFn(curr.Value) {
|
||||
// found un-expired resource
|
||||
return curr.Value
|
||||
}
|
||||
|
||||
// close expired resources
|
||||
rp.closeFn(curr.Value)
|
||||
}
|
||||
|
||||
// did not find a valid resource
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put clears expired resources from the pool and then returns resource v to the pool if there is room. It returns true
|
||||
// if v was successfully added to the pool and false otherwise.
|
||||
func (rp *resourcePool) Put(v interface{}) bool {
|
||||
rp.Lock()
|
||||
defer rp.Unlock()
|
||||
|
||||
// close expired resources from the back of the pool
|
||||
rp.prune()
|
||||
if (rp.maxSize != 0 && rp.len == rp.maxSize) || rp.expiredFn(v) {
|
||||
return false
|
||||
}
|
||||
rp.deque.PushFront(v)
|
||||
rp.len++
|
||||
return true
|
||||
}
|
||||
|
||||
// Prune clears expired resources from the pool.
|
||||
func (rp *resourcePool) Prune() {
|
||||
rp.Lock()
|
||||
defer rp.Unlock()
|
||||
rp.prune()
|
||||
}
|
||||
|
||||
func (rp *resourcePool) prune() {
|
||||
// iterate over the list and stop at the first valid value
|
||||
var prev *list.Element
|
||||
for curr := rp.deque.Back(); curr != nil; curr = prev {
|
||||
prev = curr.Prev()
|
||||
if !rp.expiredFn(curr.Value) {
|
||||
// found unexpired resource
|
||||
break
|
||||
}
|
||||
|
||||
// remove and close expired resources
|
||||
rp.deque.Remove(curr)
|
||||
rp.closeFn(curr.Value)
|
||||
rp.len--
|
||||
}
|
||||
|
||||
// reset the timer for the background cleanup routine
|
||||
if !rp.pruneTimer.Stop() {
|
||||
rp.pruneTimer = time.AfterFunc(rp.pruneInterval, rp.Prune)
|
||||
return
|
||||
}
|
||||
rp.pruneTimer.Reset(rp.pruneInterval)
|
||||
}
|
||||
643
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go
generated
vendored
Executable file
643
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go
generated
vendored
Executable file
@@ -0,0 +1,643 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/event"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection"
|
||||
"go.mongodb.org/mongo-driver/x/network/result"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
const minHeartbeatInterval = 500 * time.Millisecond
|
||||
const connectionSemaphoreSize = math.MaxInt64
|
||||
|
||||
var isMasterOrRecoveringCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91}
|
||||
|
||||
// ErrServerClosed occurs when an attempt to get a connection is made after
|
||||
// the server has been closed.
|
||||
var ErrServerClosed = errors.New("server is closed")
|
||||
|
||||
// ErrServerConnected occurs when at attempt to connect is made after a server
|
||||
// has already been connected.
|
||||
var ErrServerConnected = errors.New("server is connected")
|
||||
|
||||
// SelectedServer represents a specific server that was selected during server selection.
|
||||
// It contains the kind of the topology it was selected from.
|
||||
type SelectedServer struct {
|
||||
*Server
|
||||
|
||||
Kind description.TopologyKind
|
||||
}
|
||||
|
||||
// Description returns a description of the server as of the last heartbeat.
|
||||
func (ss *SelectedServer) Description() description.SelectedServer {
|
||||
sdesc := ss.Server.Description()
|
||||
return description.SelectedServer{
|
||||
Server: sdesc,
|
||||
Kind: ss.Kind,
|
||||
}
|
||||
}
|
||||
|
||||
// These constants represent the connection states of a server.
|
||||
const (
|
||||
disconnected int32 = iota
|
||||
disconnecting
|
||||
connected
|
||||
connecting
|
||||
)
|
||||
|
||||
func connectionStateString(state int32) string {
|
||||
switch state {
|
||||
case 0:
|
||||
return "Disconnected"
|
||||
case 1:
|
||||
return "Disconnecting"
|
||||
case 2:
|
||||
return "Connected"
|
||||
case 3:
|
||||
return "Connecting"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Server is a single server within a topology.
|
||||
type Server struct {
|
||||
cfg *serverConfig
|
||||
address address.Address
|
||||
connectionstate int32
|
||||
|
||||
// connection related fields
|
||||
pool *pool
|
||||
sem *semaphore.Weighted
|
||||
|
||||
// goroutine management fields
|
||||
done chan struct{}
|
||||
checkNow chan struct{}
|
||||
closewg sync.WaitGroup
|
||||
|
||||
// description related fields
|
||||
desc atomic.Value // holds a description.Server
|
||||
updateTopologyCallback atomic.Value
|
||||
averageRTTSet bool
|
||||
averageRTT time.Duration
|
||||
|
||||
// subscriber related fields
|
||||
subLock sync.Mutex
|
||||
subscribers map[uint64]chan description.Server
|
||||
currentSubscriberID uint64
|
||||
subscriptionsClosed bool
|
||||
}
|
||||
|
||||
// ConnectServer creates a new Server and then initializes it using the
|
||||
// Connect method.
|
||||
func ConnectServer(addr address.Address, updateCallback func(description.Server), opts ...ServerOption) (*Server, error) {
|
||||
srvr, err := NewServer(addr, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = srvr.Connect(updateCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srvr, nil
|
||||
}
|
||||
|
||||
// NewServer creates a new server. The mongodb server at the address will be monitored
|
||||
// on an internal monitoring goroutine.
|
||||
func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
|
||||
cfg, err := newServerConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var maxConns = uint64(cfg.maxConns)
|
||||
if maxConns == 0 {
|
||||
maxConns = math.MaxInt64
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
address: addr,
|
||||
|
||||
sem: semaphore.NewWeighted(int64(maxConns)),
|
||||
|
||||
done: make(chan struct{}),
|
||||
checkNow: make(chan struct{}, 1),
|
||||
|
||||
subscribers: make(map[uint64]chan description.Server),
|
||||
}
|
||||
s.desc.Store(description.Server{Addr: addr})
|
||||
|
||||
callback := func(desc description.Server) { s.updateDescription(desc, false) }
|
||||
s.pool = newPool(addr, uint64(cfg.maxIdleConns), withServerDescriptionCallback(callback, cfg.connectionOpts...)...)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Connect initializes the Server by starting background monitoring goroutines.
|
||||
// This method must be called before a Server can be used.
|
||||
func (s *Server) Connect(updateCallback func(description.Server)) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
|
||||
return ErrServerConnected
|
||||
}
|
||||
s.desc.Store(description.Server{Addr: s.address})
|
||||
s.updateTopologyCallback.Store(updateCallback)
|
||||
go s.update()
|
||||
s.closewg.Add(1)
|
||||
return s.pool.connect()
|
||||
}
|
||||
|
||||
// Disconnect closes sockets to the server referenced by this Server.
|
||||
// Subscriptions to this Server will be closed. Disconnect will shutdown
|
||||
// any monitoring goroutines, close the idle connection pool, and will
|
||||
// wait until all the in use connections have been returned to the connection
|
||||
// pool and are closed before returning. If the context expires via
|
||||
// cancellation, deadline, or timeout before the in use connections have been
|
||||
// returned, the in use connections will be closed, resulting in the failure of
|
||||
// any in flight read or write operations. If this method returns with no
|
||||
// errors, all connections associated with this Server have been closed.
|
||||
func (s *Server) Disconnect(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) {
|
||||
return ErrServerClosed
|
||||
}
|
||||
|
||||
s.updateTopologyCallback.Store((func(description.Server))(nil))
|
||||
|
||||
// For every call to Connect there must be at least 1 goroutine that is
|
||||
// waiting on the done channel.
|
||||
s.done <- struct{}{}
|
||||
err := s.pool.disconnect(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.closewg.Wait()
|
||||
atomic.StoreInt32(&s.connectionstate, disconnected)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connection gets a connection to the server.
|
||||
func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
|
||||
if atomic.LoadInt32(&s.connectionstate) != connected {
|
||||
return nil, ErrServerClosed
|
||||
}
|
||||
|
||||
err := s.sem.Acquire(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := s.pool.get(ctx)
|
||||
if err != nil {
|
||||
s.sem.Release(1)
|
||||
connerr, ok := err.(ConnectionError)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Since the only kind of ConnectionError we receive from pool.get will be an initialization
|
||||
// error, we should set the description.Server appropriately.
|
||||
desc := description.Server{
|
||||
Kind: description.Unknown,
|
||||
LastError: connerr.Wrapped,
|
||||
}
|
||||
s.updateDescription(desc, false)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Connection{connection: conn, s: s}, nil
|
||||
}
|
||||
|
||||
// ConnectionLegacy gets a connection to the server.
|
||||
func (s *Server) ConnectionLegacy(ctx context.Context) (connectionlegacy.Connection, error) {
|
||||
if atomic.LoadInt32(&s.connectionstate) != connected {
|
||||
return nil, ErrServerClosed
|
||||
}
|
||||
|
||||
err := s.sem.Acquire(ctx, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := s.pool.get(ctx)
|
||||
if err != nil {
|
||||
s.sem.Release(1)
|
||||
connerr, ok := err.(ConnectionError)
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Since the only kind of ConnectionError we receive from pool.get will be an initialization
|
||||
// error, we should set the description.Server appropriately.
|
||||
desc := description.Server{
|
||||
Kind: description.Unknown,
|
||||
LastError: connerr.Wrapped,
|
||||
}
|
||||
s.updateDescription(desc, false)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
return newConnectionLegacy(conn, s, s.cfg.connectionOpts...)
|
||||
}
|
||||
|
||||
// Description returns a description of the server as of the last heartbeat.
|
||||
func (s *Server) Description() description.Server {
|
||||
return s.desc.Load().(description.Server)
|
||||
}
|
||||
|
||||
// SelectedDescription returns a description.SelectedServer with a Kind of
|
||||
// Single. This can be used when performing tasks like monitoring a batch
|
||||
// of servers and you want to run one off commands against those servers.
|
||||
func (s *Server) SelectedDescription() description.SelectedServer {
|
||||
sdesc := s.Description()
|
||||
return description.SelectedServer{
|
||||
Server: sdesc,
|
||||
Kind: description.Single,
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe returns a ServerSubscription which has a channel on which all
|
||||
// updated server descriptions will be sent. The channel will have a buffer
|
||||
// size of one, and will be pre-populated with the current description.
|
||||
func (s *Server) Subscribe() (*ServerSubscription, error) {
|
||||
if atomic.LoadInt32(&s.connectionstate) != connected {
|
||||
return nil, ErrSubscribeAfterClosed
|
||||
}
|
||||
ch := make(chan description.Server, 1)
|
||||
ch <- s.desc.Load().(description.Server)
|
||||
|
||||
s.subLock.Lock()
|
||||
defer s.subLock.Unlock()
|
||||
if s.subscriptionsClosed {
|
||||
return nil, ErrSubscribeAfterClosed
|
||||
}
|
||||
id := s.currentSubscriberID
|
||||
s.subscribers[id] = ch
|
||||
s.currentSubscriberID++
|
||||
|
||||
ss := &ServerSubscription{
|
||||
C: ch,
|
||||
s: s,
|
||||
id: id,
|
||||
}
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// RequestImmediateCheck will cause the server to send a heartbeat immediately
|
||||
// instead of waiting for the heartbeat timeout.
|
||||
func (s *Server) RequestImmediateCheck() {
|
||||
select {
|
||||
case s.checkNow <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
|
||||
func (s *Server) ProcessError(err error) {
|
||||
// Invalidate server description if not master or node recovering error occurs
|
||||
if cerr, ok := err.(driver.Error); ok && (cerr.NetworkError() || cerr.NodeIsRecovering() || cerr.NotMaster()) {
|
||||
desc := s.Description()
|
||||
desc.Kind = description.Unknown
|
||||
desc.LastError = err
|
||||
// updates description to unknown
|
||||
s.updateDescription(desc, false)
|
||||
s.RequestImmediateCheck()
|
||||
s.pool.drain()
|
||||
return
|
||||
}
|
||||
|
||||
ne, ok := err.(ConnectionError)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if netErr, ok := ne.Wrapped.(net.Error); ok && netErr.Timeout() {
|
||||
return
|
||||
}
|
||||
if ne.Wrapped == context.Canceled || ne.Wrapped == context.DeadlineExceeded {
|
||||
return
|
||||
}
|
||||
|
||||
desc := s.Description()
|
||||
desc.Kind = description.Unknown
|
||||
desc.LastError = err
|
||||
// updates description to unknown
|
||||
s.updateDescription(desc, false)
|
||||
}
|
||||
|
||||
// ProcessWriteConcernError checks if a WriteConcernError is an isNotMaster or
|
||||
// isRecovering error, and if so updates the server accordingly.
|
||||
func (s *Server) ProcessWriteConcernError(err *result.WriteConcernError) {
|
||||
if err == nil || !wceIsNotMasterOrRecovering(err) {
|
||||
return
|
||||
}
|
||||
desc := s.Description()
|
||||
desc.Kind = description.Unknown
|
||||
desc.LastError = err
|
||||
// updates description to unknown
|
||||
s.updateDescription(desc, false)
|
||||
s.RequestImmediateCheck()
|
||||
}
|
||||
|
||||
func wceIsNotMasterOrRecovering(wce *result.WriteConcernError) bool {
|
||||
for _, code := range isMasterOrRecoveringCodes {
|
||||
if int32(wce.Code) == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.Contains(wce.ErrMsg, "not master") || strings.Contains(wce.ErrMsg, "node is recovering") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// update handles performing heartbeats and updating any subscribers of the
|
||||
// newest description.Server retrieved.
|
||||
func (s *Server) update() {
|
||||
defer s.closewg.Done()
|
||||
heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
|
||||
rateLimiter := time.NewTicker(minHeartbeatInterval)
|
||||
defer heartbeatTicker.Stop()
|
||||
defer rateLimiter.Stop()
|
||||
checkNow := s.checkNow
|
||||
done := s.done
|
||||
|
||||
var doneOnce bool
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if doneOnce {
|
||||
return
|
||||
}
|
||||
// We keep this goroutine alive attempting to read from the done channel.
|
||||
<-done
|
||||
}
|
||||
}()
|
||||
|
||||
var conn *connection
|
||||
var desc description.Server
|
||||
|
||||
desc, conn = s.heartbeat(nil)
|
||||
s.updateDescription(desc, true)
|
||||
|
||||
closeServer := func() {
|
||||
doneOnce = true
|
||||
s.subLock.Lock()
|
||||
for id, c := range s.subscribers {
|
||||
close(c)
|
||||
delete(s.subscribers, id)
|
||||
}
|
||||
s.subscriptionsClosed = true
|
||||
s.subLock.Unlock()
|
||||
if conn == nil || conn.nc == nil {
|
||||
return
|
||||
}
|
||||
conn.nc.Close()
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-heartbeatTicker.C:
|
||||
case <-checkNow:
|
||||
case <-done:
|
||||
closeServer()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-rateLimiter.C:
|
||||
case <-done:
|
||||
closeServer()
|
||||
return
|
||||
}
|
||||
|
||||
desc, conn = s.heartbeat(conn)
|
||||
s.updateDescription(desc, false)
|
||||
}
|
||||
}
|
||||
|
||||
// updateDescription handles updating the description on the Server, notifying
|
||||
// subscribers, and potentially draining the connection pool. The initial
|
||||
// parameter is used to determine if this is the first description from the
|
||||
// server.
|
||||
func (s *Server) updateDescription(desc description.Server, initial bool) {
|
||||
defer func() {
|
||||
// ¯\_(ツ)_/¯
|
||||
_ = recover()
|
||||
}()
|
||||
s.desc.Store(desc)
|
||||
|
||||
callback, ok := s.updateTopologyCallback.Load().(func(description.Server))
|
||||
if ok && callback != nil {
|
||||
callback(desc)
|
||||
}
|
||||
|
||||
s.subLock.Lock()
|
||||
for _, c := range s.subscribers {
|
||||
select {
|
||||
// drain the channel if it isn't empty
|
||||
case <-c:
|
||||
default:
|
||||
}
|
||||
c <- desc
|
||||
}
|
||||
s.subLock.Unlock()
|
||||
|
||||
if initial {
|
||||
// We don't clear the pool on the first update on the description.
|
||||
return
|
||||
}
|
||||
|
||||
switch desc.Kind {
|
||||
case description.Unknown:
|
||||
s.pool.drain()
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat sends a heartbeat to the server using the given connection. The connection can be nil.
|
||||
func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
|
||||
const maxRetry = 2
|
||||
var saved error
|
||||
var desc description.Server
|
||||
var set bool
|
||||
var err error
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 1; i <= maxRetry; i++ {
|
||||
var now time.Time
|
||||
var descPtr *description.Server
|
||||
|
||||
if conn != nil && conn.expired() {
|
||||
if conn.nc != nil {
|
||||
conn.nc.Close()
|
||||
}
|
||||
conn = nil
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
opts := []ConnectionOption{
|
||||
WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
|
||||
WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
|
||||
WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
|
||||
}
|
||||
opts = append(opts, s.cfg.connectionOpts...)
|
||||
// We override whatever handshaker is currently attached to the options with a basic
|
||||
// one because need to make sure we don't do auth.
|
||||
opts = append(opts, WithHandshaker(func(h Handshaker) Handshaker {
|
||||
now = time.Now()
|
||||
return operation.NewIsMaster().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts)
|
||||
}))
|
||||
|
||||
// Override any command monitors specified in options with nil to avoid monitoring heartbeats.
|
||||
opts = append(opts, WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
|
||||
return nil
|
||||
}))
|
||||
|
||||
conn, err = newConnection(ctx, s.address, opts...)
|
||||
if err == nil {
|
||||
descPtr = &conn.desc
|
||||
}
|
||||
}
|
||||
|
||||
// do a heartbeat because a new connection wasn't created so a handshake was not performed
|
||||
if descPtr == nil && err == nil {
|
||||
now = time.Now()
|
||||
op := operation.
|
||||
NewIsMaster().
|
||||
ClusterClock(s.cfg.clock).
|
||||
Deployment(driver.SingleConnectionDeployment{initConnection{conn}})
|
||||
err = op.Execute(ctx)
|
||||
if err == nil {
|
||||
tmpDesc := op.Result(s.address)
|
||||
descPtr = &tmpDesc
|
||||
}
|
||||
}
|
||||
|
||||
// we do a retry if the server is connected, if succeed return new server desc (see below)
|
||||
if err != nil {
|
||||
saved = err
|
||||
if conn != nil && conn.nc != nil {
|
||||
conn.nc.Close()
|
||||
}
|
||||
conn = nil
|
||||
if _, ok := err.(ConnectionError); ok {
|
||||
s.pool.drain()
|
||||
// If the server is not connected, give up and exit loop
|
||||
if s.Description().Kind == description.Unknown {
|
||||
break
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
desc = *descPtr
|
||||
delay := time.Since(now)
|
||||
desc = desc.SetAverageRTT(s.updateAverageRTT(delay))
|
||||
desc.HeartbeatInterval = s.cfg.heartbeatInterval
|
||||
set = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if !set {
|
||||
desc = description.Server{
|
||||
Addr: s.address,
|
||||
LastError: saved,
|
||||
Kind: description.Unknown,
|
||||
}
|
||||
}
|
||||
|
||||
return desc, conn
|
||||
}
|
||||
|
||||
func (s *Server) updateAverageRTT(delay time.Duration) time.Duration {
|
||||
if !s.averageRTTSet {
|
||||
s.averageRTT = delay
|
||||
} else {
|
||||
alpha := 0.2
|
||||
s.averageRTT = time.Duration(alpha*float64(delay) + (1-alpha)*float64(s.averageRTT))
|
||||
}
|
||||
return s.averageRTT
|
||||
}
|
||||
|
||||
// Drain will drain the connection pool of this server. This is mainly here so the
|
||||
// pool for the server doesn't need to be directly exposed and so that when an error
|
||||
// is returned from reading or writing, a client can drain the pool for this server.
|
||||
// This is exposed here so we don't have to wrap the Connection type and sniff responses
|
||||
// for errors that would cause the pool to be drained, which can in turn centralize the
|
||||
// logic for handling errors in the Client type.
|
||||
//
|
||||
// TODO(GODRIVER-617): I don't think we actually need this method. It's likely replaced by
|
||||
// ProcessError.
|
||||
func (s *Server) Drain() error {
|
||||
s.pool.drain()
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (s *Server) String() string {
|
||||
desc := s.Description()
|
||||
connState := atomic.LoadInt32(&s.connectionstate)
|
||||
str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
|
||||
s.address, desc.Kind, connectionStateString(connState))
|
||||
if len(desc.Tags) != 0 {
|
||||
str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
|
||||
}
|
||||
if connState == connected {
|
||||
str += fmt.Sprintf(", Average RTT: %d", desc.AverageRTT)
|
||||
}
|
||||
if desc.LastError != nil {
|
||||
str += fmt.Sprintf(", Last error: %s", desc.LastError)
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// ServerSubscription represents a subscription to the description.Server updates for
|
||||
// a specific server.
|
||||
type ServerSubscription struct {
|
||||
C <-chan description.Server
|
||||
s *Server
|
||||
id uint64
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes this ServerSubscription from updates and closes the
|
||||
// subscription channel.
|
||||
func (ss *ServerSubscription) Unsubscribe() error {
|
||||
ss.s.subLock.Lock()
|
||||
defer ss.s.subLock.Unlock()
|
||||
if ss.s.subscriptionsClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, ok := ss.s.subscribers[ss.id]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(ss.s.subscribers, ss.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
120
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go
generated
vendored
Executable file
120
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_options.go
generated
vendored
Executable file
@@ -0,0 +1,120 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
var defaultRegistry = bson.NewRegistryBuilder().Build()
|
||||
|
||||
type serverConfig struct {
|
||||
clock *session.ClusterClock
|
||||
compressionOpts []string
|
||||
connectionOpts []ConnectionOption
|
||||
appname string
|
||||
heartbeatInterval time.Duration
|
||||
heartbeatTimeout time.Duration
|
||||
maxConns uint16
|
||||
maxIdleConns uint16
|
||||
registry *bsoncodec.Registry
|
||||
}
|
||||
|
||||
func newServerConfig(opts ...ServerOption) (*serverConfig, error) {
|
||||
cfg := &serverConfig{
|
||||
heartbeatInterval: 10 * time.Second,
|
||||
heartbeatTimeout: 10 * time.Second,
|
||||
maxConns: 100,
|
||||
maxIdleConns: 100,
|
||||
registry: defaultRegistry,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ServerOption configures a server.
|
||||
type ServerOption func(*serverConfig) error
|
||||
|
||||
// WithConnectionOptions configures the server's connections.
|
||||
func WithConnectionOptions(fn func(...ConnectionOption) []ConnectionOption) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.connectionOpts = fn(cfg.connectionOpts...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithCompressionOptions configures the server's compressors.
|
||||
func WithCompressionOptions(fn func(...string) []string) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.compressionOpts = fn(cfg.compressionOpts...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeartbeatInterval configures a server's heartbeat interval.
|
||||
func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.heartbeatInterval = fn(cfg.heartbeatInterval)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to
|
||||
// connection.
|
||||
func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxConnections configures the maximum number of connections to allow for
|
||||
// a given server. If max is 0, then there is no upper limit to the number of
|
||||
// connections.
|
||||
func WithMaxConnections(fn func(uint16) uint16) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.maxConns = fn(cfg.maxConns)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxIdleConnections configures the maximum number of idle connections
|
||||
// allowed for the server.
|
||||
func WithMaxIdleConnections(fn func(uint16) uint16) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.maxIdleConns = fn(cfg.maxIdleConns)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithClock configures the ClusterClock for the server to use.
|
||||
func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.clock = fn(cfg.clock)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRegistry configures the registry for the server to use when creating
|
||||
// cursors.
|
||||
func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption {
|
||||
return func(cfg *serverConfig) error {
|
||||
cfg.registry = fn(cfg.registry)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
624
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go
generated
vendored
Executable file
624
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go
generated
vendored
Executable file
@@ -0,0 +1,624 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Package topology contains types that handles the discovery, monitoring, and selection
|
||||
// of servers. This package is designed to expose enough inner workings of service discovery
|
||||
// and monitoring to allow low level applications to have fine grained control, while hiding
|
||||
// most of the detailed implementation of the algorithms.
|
||||
package topology // import "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
// ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
|
||||
// closed Server or Topology.
|
||||
var ErrSubscribeAfterClosed = errors.New("cannot subscribe after close")
|
||||
|
||||
// ErrTopologyClosed is returned when a user attempts to call a method on a
|
||||
// closed Topology.
|
||||
var ErrTopologyClosed = errors.New("topology is closed")
|
||||
|
||||
// ErrTopologyConnected is returned whena user attempts to connect to an
|
||||
// already connected Topology.
|
||||
var ErrTopologyConnected = errors.New("topology is connected or connecting")
|
||||
|
||||
// ErrServerSelectionTimeout is returned from server selection when the server
|
||||
// selection process took longer than allowed by the timeout.
|
||||
var ErrServerSelectionTimeout = errors.New("server selection timeout")
|
||||
|
||||
// MonitorMode represents the way in which a server is monitored.
|
||||
type MonitorMode uint8
|
||||
|
||||
// These constants are the available monitoring modes.
|
||||
const (
|
||||
AutomaticMode MonitorMode = iota
|
||||
SingleMode
|
||||
)
|
||||
|
||||
// Topology represents a MongoDB deployment.
|
||||
type Topology struct {
|
||||
connectionstate int32
|
||||
|
||||
cfg *config
|
||||
|
||||
desc atomic.Value // holds a description.Topology
|
||||
|
||||
dnsResolver *dns.Resolver
|
||||
|
||||
done chan struct{}
|
||||
|
||||
pollingDone chan struct{}
|
||||
pollingwg sync.WaitGroup
|
||||
rescanSRVInterval time.Duration
|
||||
pollHeartbeatTime atomic.Value // holds a bool
|
||||
|
||||
fsm *fsm
|
||||
|
||||
SessionPool *session.Pool
|
||||
|
||||
// This should really be encapsulated into it's own type. This will likely
|
||||
// require a redesign so we can share a minimum of data between the
|
||||
// subscribers and the topology.
|
||||
subscribers map[uint64]chan description.Topology
|
||||
currentSubscriberID uint64
|
||||
subscriptionsClosed bool
|
||||
subLock sync.Mutex
|
||||
|
||||
// We should redesign how we connect and handle individal servers. This is
|
||||
// too difficult to maintain and it's rather easy to accidentally access
|
||||
// the servers without acquiring the lock or checking if the servers are
|
||||
// closed. This lock should also be an RWMutex.
|
||||
serversLock sync.Mutex
|
||||
serversClosed bool
|
||||
servers map[address.Address]*Server
|
||||
}
|
||||
|
||||
// New creates a new topology.
|
||||
func New(opts ...Option) (*Topology, error) {
|
||||
cfg, err := newConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Topology{
|
||||
cfg: cfg,
|
||||
done: make(chan struct{}),
|
||||
pollingDone: make(chan struct{}),
|
||||
rescanSRVInterval: 60 * time.Second,
|
||||
fsm: newFSM(),
|
||||
subscribers: make(map[uint64]chan description.Topology),
|
||||
servers: make(map[address.Address]*Server),
|
||||
dnsResolver: dns.DefaultResolver,
|
||||
}
|
||||
t.desc.Store(description.Topology{})
|
||||
|
||||
if cfg.replicaSetName != "" {
|
||||
t.fsm.SetName = cfg.replicaSetName
|
||||
t.fsm.Kind = description.ReplicaSetNoPrimary
|
||||
}
|
||||
|
||||
if cfg.mode == SingleMode {
|
||||
t.fsm.Kind = description.Single
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Connect initializes a Topology and starts the monitoring process. This function
|
||||
// must be called to properly monitor the topology.
|
||||
func (t *Topology) Connect() error {
|
||||
if !atomic.CompareAndSwapInt32(&t.connectionstate, disconnected, connecting) {
|
||||
return ErrTopologyConnected
|
||||
}
|
||||
|
||||
t.desc.Store(description.Topology{})
|
||||
var err error
|
||||
t.serversLock.Lock()
|
||||
for _, a := range t.cfg.seedList {
|
||||
addr := address.Address(a).Canonicalize()
|
||||
t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr})
|
||||
err = t.addServer(addr)
|
||||
}
|
||||
t.serversLock.Unlock()
|
||||
|
||||
if srvPollingRequired(t.cfg.cs.Original) {
|
||||
go t.pollSRVRecords()
|
||||
t.pollingwg.Add(1)
|
||||
}
|
||||
|
||||
t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
|
||||
|
||||
atomic.StoreInt32(&t.connectionstate, connected)
|
||||
|
||||
// After connection, make a subscription to keep the pool updated
|
||||
sub, err := t.Subscribe()
|
||||
t.SessionPool = session.NewPool(sub.C)
|
||||
return err
|
||||
}
|
||||
|
||||
// Disconnect closes the topology. It stops the monitoring thread and
|
||||
// closes all open subscriptions.
|
||||
func (t *Topology) Disconnect(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&t.connectionstate, connected, disconnecting) {
|
||||
return ErrTopologyClosed
|
||||
}
|
||||
|
||||
servers := make(map[address.Address]*Server)
|
||||
t.serversLock.Lock()
|
||||
t.serversClosed = true
|
||||
for addr, server := range t.servers {
|
||||
servers[addr] = server
|
||||
}
|
||||
t.serversLock.Unlock()
|
||||
|
||||
for _, server := range servers {
|
||||
_ = server.Disconnect(ctx)
|
||||
}
|
||||
|
||||
t.subLock.Lock()
|
||||
for id, ch := range t.subscribers {
|
||||
close(ch)
|
||||
delete(t.subscribers, id)
|
||||
}
|
||||
t.subscriptionsClosed = true
|
||||
t.subLock.Unlock()
|
||||
|
||||
if srvPollingRequired(t.cfg.cs.Original) {
|
||||
t.pollingDone <- struct{}{}
|
||||
t.pollingwg.Wait()
|
||||
}
|
||||
|
||||
t.desc.Store(description.Topology{})
|
||||
|
||||
atomic.StoreInt32(&t.connectionstate, disconnected)
|
||||
return nil
|
||||
}
|
||||
|
||||
func srvPollingRequired(connstr string) bool {
|
||||
return strings.HasPrefix(connstr, "mongodb+srv://")
|
||||
}
|
||||
|
||||
// Description returns a description of the topology.
|
||||
func (t *Topology) Description() description.Topology {
|
||||
td, ok := t.desc.Load().(description.Topology)
|
||||
if !ok {
|
||||
td = description.Topology{}
|
||||
}
|
||||
return td
|
||||
}
|
||||
|
||||
// Kind returns the topology kind of this Topology.
|
||||
func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind }
|
||||
|
||||
// Subscribe returns a Subscription on which all updated description.Topologys
|
||||
// will be sent. The channel of the subscription will have a buffer size of one,
|
||||
// and will be pre-populated with the current description.Topology.
|
||||
func (t *Topology) Subscribe() (*Subscription, error) {
|
||||
if atomic.LoadInt32(&t.connectionstate) != connected {
|
||||
return nil, errors.New("cannot subscribe to Topology that is not connected")
|
||||
}
|
||||
ch := make(chan description.Topology, 1)
|
||||
td, ok := t.desc.Load().(description.Topology)
|
||||
if !ok {
|
||||
td = description.Topology{}
|
||||
}
|
||||
ch <- td
|
||||
|
||||
t.subLock.Lock()
|
||||
defer t.subLock.Unlock()
|
||||
if t.subscriptionsClosed {
|
||||
return nil, ErrSubscribeAfterClosed
|
||||
}
|
||||
id := t.currentSubscriberID
|
||||
t.subscribers[id] = ch
|
||||
t.currentSubscriberID++
|
||||
|
||||
return &Subscription{
|
||||
C: ch,
|
||||
t: t,
|
||||
id: id,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequestImmediateCheck will send heartbeats to all the servers in the
|
||||
// topology right away, instead of waiting for the heartbeat timeout.
|
||||
func (t *Topology) RequestImmediateCheck() {
|
||||
if atomic.LoadInt32(&t.connectionstate) != connected {
|
||||
return
|
||||
}
|
||||
t.serversLock.Lock()
|
||||
for _, server := range t.servers {
|
||||
server.RequestImmediateCheck()
|
||||
}
|
||||
t.serversLock.Unlock()
|
||||
}
|
||||
|
||||
// SupportsSessions returns true if the topology supports sessions.
|
||||
func (t *Topology) SupportsSessions() bool {
|
||||
return t.Description().SessionTimeoutMinutes != 0 && t.Description().Kind != description.Single
|
||||
}
|
||||
|
||||
// SupportsRetry returns true if the topology supports retryability, which it does if it supports sessions.
|
||||
func (t *Topology) SupportsRetry() bool { return t.SupportsSessions() }
|
||||
|
||||
// SelectServer selects a server with given a selector. SelectServer complies with the
|
||||
// server selection spec, and will time out after severSelectionTimeout or when the
|
||||
// parent context is done.
|
||||
func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
|
||||
if atomic.LoadInt32(&t.connectionstate) != connected {
|
||||
return nil, ErrTopologyClosed
|
||||
}
|
||||
var ssTimeoutCh <-chan time.Time
|
||||
|
||||
if t.cfg.serverSelectionTimeout > 0 {
|
||||
ssTimeout := time.NewTimer(t.cfg.serverSelectionTimeout)
|
||||
ssTimeoutCh = ssTimeout.C
|
||||
defer ssTimeout.Stop()
|
||||
}
|
||||
|
||||
sub, err := t.Subscribe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
for {
|
||||
suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selected := suitable[rand.Intn(len(suitable))]
|
||||
selectedS, err := t.FindServer(selected)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
case selectedS != nil:
|
||||
return selectedS, nil
|
||||
default:
|
||||
// We don't have an actual server for the provided description.
|
||||
// This could happen for a number of reasons, including that the
|
||||
// server has since stopped being a part of this topology, or that
|
||||
// the server selector returned no suitable servers.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SelectServerLegacy selects a server with given a selector. SelectServerLegacy complies with the
|
||||
// server selection spec, and will time out after severSelectionTimeout or when the
|
||||
// parent context is done.
|
||||
func (t *Topology) SelectServerLegacy(ctx context.Context, ss description.ServerSelector) (*SelectedServer, error) {
|
||||
if atomic.LoadInt32(&t.connectionstate) != connected {
|
||||
return nil, ErrTopologyClosed
|
||||
}
|
||||
var ssTimeoutCh <-chan time.Time
|
||||
|
||||
if t.cfg.serverSelectionTimeout > 0 {
|
||||
ssTimeout := time.NewTimer(t.cfg.serverSelectionTimeout)
|
||||
ssTimeoutCh = ssTimeout.C
|
||||
defer ssTimeout.Stop()
|
||||
}
|
||||
|
||||
sub, err := t.Subscribe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
for {
|
||||
suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selected := suitable[rand.Intn(len(suitable))]
|
||||
selectedS, err := t.FindServer(selected)
|
||||
switch {
|
||||
case err != nil:
|
||||
return nil, err
|
||||
case selectedS != nil:
|
||||
return selectedS, nil
|
||||
default:
|
||||
// We don't have an actual server for the provided description.
|
||||
// This could happen for a number of reasons, including that the
|
||||
// server has since stopped being a part of this topology, or that
|
||||
// the server selector returned no suitable servers.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FindServer will attempt to find a server that fits the given server description.
|
||||
// This method will return nil, nil if a matching server could not be found.
|
||||
func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
|
||||
if atomic.LoadInt32(&t.connectionstate) != connected {
|
||||
return nil, ErrTopologyClosed
|
||||
}
|
||||
t.serversLock.Lock()
|
||||
defer t.serversLock.Unlock()
|
||||
server, ok := t.servers[selected.Addr]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
desc := t.Description()
|
||||
return &SelectedServer{
|
||||
Server: server,
|
||||
Kind: desc.Kind,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func wrapServerSelectionError(err error, t *Topology) error {
|
||||
return fmt.Errorf("server selection error: %v\ncurrent topology: %s", err, t.String())
|
||||
}
|
||||
|
||||
// selectServer is the core piece of server selection. It handles getting
|
||||
// topology descriptions and running sever selection on those descriptions.
|
||||
func (t *Topology) selectServer(ctx context.Context, subscriptionCh <-chan description.Topology, ss description.ServerSelector, timeoutCh <-chan time.Time) ([]description.Server, error) {
|
||||
var current description.Topology
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timeoutCh:
|
||||
return nil, wrapServerSelectionError(ErrServerSelectionTimeout, t)
|
||||
case current = <-subscriptionCh:
|
||||
}
|
||||
|
||||
var allowed []description.Server
|
||||
for _, s := range current.Servers {
|
||||
if s.Kind != description.Unknown {
|
||||
allowed = append(allowed, s)
|
||||
}
|
||||
}
|
||||
|
||||
suitable, err := ss.SelectServer(current, allowed)
|
||||
if err != nil {
|
||||
return nil, wrapServerSelectionError(err, t)
|
||||
}
|
||||
|
||||
if len(suitable) > 0 {
|
||||
return suitable, nil
|
||||
}
|
||||
|
||||
t.RequestImmediateCheck()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Topology) pollSRVRecords() {
|
||||
defer t.pollingwg.Done()
|
||||
|
||||
serverConfig, _ := newServerConfig(t.cfg.serverOpts...)
|
||||
heartbeatInterval := serverConfig.heartbeatInterval
|
||||
|
||||
pollTicker := time.NewTicker(t.rescanSRVInterval)
|
||||
defer pollTicker.Stop()
|
||||
t.pollHeartbeatTime.Store(false)
|
||||
var doneOnce bool
|
||||
defer func() {
|
||||
// ¯\_(ツ)_/¯
|
||||
if r := recover(); r != nil && !doneOnce {
|
||||
<-t.pollingDone
|
||||
}
|
||||
}()
|
||||
|
||||
// remove the scheme
|
||||
uri := t.cfg.cs.Original[14:]
|
||||
hosts := uri
|
||||
if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
|
||||
hosts = uri[:idx]
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pollTicker.C:
|
||||
case <-t.pollingDone:
|
||||
doneOnce = true
|
||||
return
|
||||
}
|
||||
topoKind := t.Description().Kind
|
||||
if !(topoKind == description.Unknown || topoKind == description.Sharded) {
|
||||
break
|
||||
}
|
||||
|
||||
parsedHosts, err := t.dnsResolver.ParseHosts(hosts, false)
|
||||
// DNS problem or no verified hosts returned
|
||||
if err != nil || len(parsedHosts) == 0 {
|
||||
if !t.pollHeartbeatTime.Load().(bool) {
|
||||
pollTicker.Stop()
|
||||
pollTicker = time.NewTicker(heartbeatInterval)
|
||||
t.pollHeartbeatTime.Store(true)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if t.pollHeartbeatTime.Load().(bool) {
|
||||
pollTicker.Stop()
|
||||
pollTicker = time.NewTicker(t.rescanSRVInterval)
|
||||
t.pollHeartbeatTime.Store(false)
|
||||
}
|
||||
|
||||
cont := t.processSRVResults(parsedHosts)
|
||||
if !cont {
|
||||
break
|
||||
}
|
||||
}
|
||||
<-t.pollingDone
|
||||
doneOnce = true
|
||||
}
|
||||
|
||||
func (t *Topology) processSRVResults(parsedHosts []string) bool {
|
||||
t.serversLock.Lock()
|
||||
defer t.serversLock.Unlock()
|
||||
|
||||
if t.serversClosed {
|
||||
return false
|
||||
}
|
||||
diff := t.fsm.Topology.DiffHostlist(parsedHosts)
|
||||
|
||||
if len(diff.Added) == 0 && len(diff.Removed) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, r := range diff.Removed {
|
||||
addr := address.Address(r).Canonicalize()
|
||||
s, ok := t.servers[addr]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_ = s.Disconnect(cancelCtx)
|
||||
}()
|
||||
delete(t.servers, addr)
|
||||
t.fsm.removeServerByAddr(addr)
|
||||
}
|
||||
for _, a := range diff.Added {
|
||||
addr := address.Address(a).Canonicalize()
|
||||
_ = t.addServer(addr)
|
||||
t.fsm.addServer(addr)
|
||||
}
|
||||
//store new description
|
||||
newDesc := description.Topology{
|
||||
Kind: t.fsm.Kind,
|
||||
Servers: t.fsm.Servers,
|
||||
SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
|
||||
}
|
||||
t.desc.Store(newDesc)
|
||||
|
||||
t.subLock.Lock()
|
||||
for _, ch := range t.subscribers {
|
||||
// We drain the description if there's one in the channel
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
}
|
||||
ch <- newDesc
|
||||
}
|
||||
t.subLock.Unlock()
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func (t *Topology) apply(ctx context.Context, desc description.Server) {
|
||||
var err error
|
||||
|
||||
t.serversLock.Lock()
|
||||
defer t.serversLock.Unlock()
|
||||
|
||||
if _, ok := t.servers[desc.Addr]; t.serversClosed || !ok {
|
||||
return
|
||||
}
|
||||
|
||||
prev := t.fsm.Topology
|
||||
|
||||
current, err := t.fsm.apply(desc)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
diff := description.DiffTopology(prev, current)
|
||||
|
||||
for _, removed := range diff.Removed {
|
||||
if s, ok := t.servers[removed.Addr]; ok {
|
||||
go func() {
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
_ = s.Disconnect(cancelCtx)
|
||||
}()
|
||||
delete(t.servers, removed.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
for _, added := range diff.Added {
|
||||
_ = t.addServer(added.Addr)
|
||||
}
|
||||
|
||||
t.desc.Store(current)
|
||||
|
||||
t.subLock.Lock()
|
||||
for _, ch := range t.subscribers {
|
||||
// We drain the description if there's one in the channel
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
}
|
||||
ch <- current
|
||||
}
|
||||
t.subLock.Unlock()
|
||||
|
||||
}
|
||||
|
||||
func (t *Topology) addServer(addr address.Address) error {
|
||||
if _, ok := t.servers[addr]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
topoFunc := func(desc description.Server) {
|
||||
t.apply(context.TODO(), desc)
|
||||
}
|
||||
svr, err := ConnectServer(addr, topoFunc, t.cfg.serverOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.servers[addr] = svr
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implements the Stringer interface
|
||||
func (t *Topology) String() string {
|
||||
desc := t.Description()
|
||||
str := fmt.Sprintf("Type: %s\nServers:\n", desc.Kind)
|
||||
for _, s := range t.servers {
|
||||
str += s.String() + "\n"
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Subscription is a subscription to updates to the description of the Topology that created this
|
||||
// Subscription.
|
||||
type Subscription struct {
|
||||
C <-chan description.Topology
|
||||
t *Topology
|
||||
id uint64
|
||||
}
|
||||
|
||||
// Unsubscribe unsubscribes this Subscription from updates and closes the
|
||||
// subscription channel.
|
||||
func (s *Subscription) Unsubscribe() error {
|
||||
s.t.subLock.Lock()
|
||||
defer s.t.subLock.Unlock()
|
||||
if s.t.subscriptionsClosed {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, ok := s.t.subscribers[s.id]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
close(ch)
|
||||
delete(s.t.subscribers, s.id)
|
||||
|
||||
return nil
|
||||
}
|
||||
386
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go
generated
vendored
Executable file
386
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options.go
generated
vendored
Executable file
@@ -0,0 +1,386 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
)
|
||||
|
||||
// Option is a configuration option for a topology.
|
||||
type Option func(*config) error
|
||||
|
||||
type config struct {
|
||||
mode MonitorMode
|
||||
replicaSetName string
|
||||
seedList []string
|
||||
serverOpts []ServerOption
|
||||
cs connstring.ConnString
|
||||
serverSelectionTimeout time.Duration
|
||||
}
|
||||
|
||||
func newConfig(opts ...Option) (*config, error) {
|
||||
cfg := &config{
|
||||
seedList: []string{"localhost:27017"},
|
||||
serverSelectionTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// WithConnString configures the topology using the connection string.
|
||||
func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option {
|
||||
return func(c *config) error {
|
||||
cs := fn(c.cs)
|
||||
c.cs = cs
|
||||
|
||||
if cs.ServerSelectionTimeoutSet {
|
||||
c.serverSelectionTimeout = cs.ServerSelectionTimeout
|
||||
}
|
||||
|
||||
var connOpts []ConnectionOption
|
||||
|
||||
if cs.AppName != "" {
|
||||
connOpts = append(connOpts, WithAppName(func(string) string { return cs.AppName }))
|
||||
}
|
||||
|
||||
switch cs.Connect {
|
||||
case connstring.SingleConnect:
|
||||
c.mode = SingleMode
|
||||
}
|
||||
|
||||
c.seedList = cs.Hosts
|
||||
|
||||
if cs.ConnectTimeout > 0 {
|
||||
c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
|
||||
connOpts = append(connOpts, WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
|
||||
}
|
||||
|
||||
if cs.SocketTimeoutSet {
|
||||
connOpts = append(
|
||||
connOpts,
|
||||
WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
|
||||
WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
|
||||
)
|
||||
}
|
||||
|
||||
if cs.HeartbeatInterval > 0 {
|
||||
c.serverOpts = append(c.serverOpts, WithHeartbeatInterval(func(time.Duration) time.Duration { return cs.HeartbeatInterval }))
|
||||
}
|
||||
|
||||
if cs.MaxConnIdleTime > 0 {
|
||||
connOpts = append(connOpts, WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime }))
|
||||
}
|
||||
|
||||
if cs.MaxPoolSizeSet {
|
||||
c.serverOpts = append(c.serverOpts, WithMaxConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
|
||||
c.serverOpts = append(c.serverOpts, WithMaxIdleConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
|
||||
}
|
||||
|
||||
if cs.ReplicaSet != "" {
|
||||
c.replicaSetName = cs.ReplicaSet
|
||||
}
|
||||
|
||||
var x509Username string
|
||||
if cs.SSL {
|
||||
tlsConfig := new(tls.Config)
|
||||
|
||||
if cs.SSLCaFileSet {
|
||||
err := addCACertFromFile(tlsConfig, cs.SSLCaFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if cs.SSLInsecure {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
if cs.SSLClientCertificateKeyFileSet {
|
||||
var keyPasswd string
|
||||
if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
|
||||
keyPasswd = cs.SSLClientCertificateKeyPassword()
|
||||
}
|
||||
s, err := addClientCertFromFile(tlsConfig, cs.SSLClientCertificateKeyFile, keyPasswd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The Go x509 package gives the subject with the pairs in reverse order that we want.
|
||||
pairs := strings.Split(s, ",")
|
||||
b := bytes.NewBufferString("")
|
||||
|
||||
for i := len(pairs) - 1; i >= 0; i-- {
|
||||
b.WriteString(pairs[i])
|
||||
|
||||
if i > 0 {
|
||||
b.WriteString(",")
|
||||
}
|
||||
}
|
||||
|
||||
x509Username = b.String()
|
||||
}
|
||||
|
||||
connOpts = append(connOpts, WithTLSConfig(func(*tls.Config) *tls.Config { return tlsConfig }))
|
||||
}
|
||||
|
||||
if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI {
|
||||
cred := &auth.Cred{
|
||||
Source: "admin",
|
||||
Username: cs.Username,
|
||||
Password: cs.Password,
|
||||
PasswordSet: cs.PasswordSet,
|
||||
Props: cs.AuthMechanismProperties,
|
||||
}
|
||||
|
||||
if cs.AuthSource != "" {
|
||||
cred.Source = cs.AuthSource
|
||||
} else {
|
||||
switch cs.AuthMechanism {
|
||||
case auth.MongoDBX509:
|
||||
if cred.Username == "" {
|
||||
cred.Username = x509Username
|
||||
}
|
||||
fallthrough
|
||||
case auth.GSSAPI, auth.PLAIN:
|
||||
cred.Source = "$external"
|
||||
default:
|
||||
cred.Source = cs.Database
|
||||
}
|
||||
}
|
||||
|
||||
authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connOpts = append(connOpts, WithHandshaker(func(h Handshaker) Handshaker {
|
||||
options := &auth.HandshakeOptions{
|
||||
AppName: cs.AppName,
|
||||
Authenticator: authenticator,
|
||||
Compressors: cs.Compressors,
|
||||
}
|
||||
if cs.AuthMechanism == "" {
|
||||
// Required for SASL mechanism negotiation during handshake
|
||||
options.DBUser = cred.Source + "." + cred.Username
|
||||
}
|
||||
return auth.Handshaker(h, options)
|
||||
}))
|
||||
} else {
|
||||
// We need to add a non-auth Handshaker to the connection options
|
||||
connOpts = append(connOpts, WithHandshaker(func(h driver.Handshaker) driver.Handshaker {
|
||||
return operation.NewIsMaster().AppName(cs.AppName).Compressors(cs.Compressors)
|
||||
}))
|
||||
}
|
||||
|
||||
if len(cs.Compressors) > 0 {
|
||||
connOpts = append(connOpts, WithCompressors(func(compressors []string) []string {
|
||||
return append(compressors, cs.Compressors...)
|
||||
}))
|
||||
|
||||
for _, comp := range cs.Compressors {
|
||||
if comp == "zlib" {
|
||||
connOpts = append(connOpts, WithZlibLevel(func(level *int) *int {
|
||||
return &cs.ZlibLevel
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
c.serverOpts = append(c.serverOpts, WithCompressionOptions(func(opts ...string) []string {
|
||||
return append(opts, cs.Compressors...)
|
||||
}))
|
||||
}
|
||||
|
||||
if len(connOpts) > 0 {
|
||||
c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
|
||||
return append(opts, connOpts...)
|
||||
}))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithMode configures the topology's monitor mode.
|
||||
func WithMode(fn func(MonitorMode) MonitorMode) Option {
|
||||
return func(cfg *config) error {
|
||||
cfg.mode = fn(cfg.mode)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithReplicaSetName configures the topology's default replica set name.
|
||||
func WithReplicaSetName(fn func(string) string) Option {
|
||||
return func(cfg *config) error {
|
||||
cfg.replicaSetName = fn(cfg.replicaSetName)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithSeedList configures a topology's seed list.
|
||||
func WithSeedList(fn func(...string) []string) Option {
|
||||
return func(cfg *config) error {
|
||||
cfg.seedList = fn(cfg.seedList...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerOptions configures a topology's server options for when a new server
|
||||
// needs to be created.
|
||||
func WithServerOptions(fn func(...ServerOption) []ServerOption) Option {
|
||||
return func(cfg *config) error {
|
||||
cfg.serverOpts = fn(cfg.serverOpts...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithServerSelectionTimeout configures a topology's server selection timeout.
|
||||
// A server selection timeout of 0 means there is no timeout for server selection.
|
||||
func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option {
|
||||
return func(cfg *config) error {
|
||||
cfg.serverSelectionTimeout = fn(cfg.serverSelectionTimeout)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// addCACertFromFile adds a root CA certificate to the configuration given a path
|
||||
// to the containing file.
|
||||
func addCACertFromFile(cfg *tls.Config, file string) error {
|
||||
data, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
certBytes, err := loadCert(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.RootCAs == nil {
|
||||
cfg.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
|
||||
cfg.RootCAs.AddCert(cert)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadCert(data []byte) ([]byte, error) {
|
||||
var certBlock *pem.Block
|
||||
|
||||
for certBlock == nil {
|
||||
if data == nil || len(data) == 0 {
|
||||
return nil, errors.New(".pem file must have both a CERTIFICATE and an RSA PRIVATE KEY section")
|
||||
}
|
||||
|
||||
block, rest := pem.Decode(data)
|
||||
if block == nil {
|
||||
return nil, errors.New("invalid .pem file")
|
||||
}
|
||||
|
||||
switch block.Type {
|
||||
case "CERTIFICATE":
|
||||
if certBlock != nil {
|
||||
return nil, errors.New("multiple CERTIFICATE sections in .pem file")
|
||||
}
|
||||
|
||||
certBlock = block
|
||||
}
|
||||
|
||||
data = rest
|
||||
}
|
||||
|
||||
return certBlock.Bytes, nil
|
||||
}
|
||||
|
||||
// addClientCertFromFile adds a client certificate to the configuration given a path to the
|
||||
// containing file and returns the certificate's subject name.
|
||||
func addClientCertFromFile(cfg *tls.Config, clientFile, keyPasswd string) (string, error) {
|
||||
data, err := ioutil.ReadFile(clientFile)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var currentBlock *pem.Block
|
||||
var certBlock, certDecodedBlock, keyBlock []byte
|
||||
|
||||
remaining := data
|
||||
start := 0
|
||||
for {
|
||||
currentBlock, remaining = pem.Decode(remaining)
|
||||
if currentBlock == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if currentBlock.Type == "CERTIFICATE" {
|
||||
certBlock = data[start : len(data)-len(remaining)]
|
||||
certDecodedBlock = currentBlock.Bytes
|
||||
start += len(certBlock)
|
||||
} else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") {
|
||||
if keyPasswd != "" && x509.IsEncryptedPEMBlock(currentBlock) {
|
||||
var encoded bytes.Buffer
|
||||
buf, err := x509.DecryptPEMBlock(currentBlock, []byte(keyPasswd))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf})
|
||||
keyBlock = encoded.Bytes()
|
||||
start = len(data) - len(remaining)
|
||||
} else {
|
||||
keyBlock = data[start : len(data)-len(remaining)]
|
||||
start += len(keyBlock)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(certBlock) == 0 {
|
||||
return "", fmt.Errorf("failed to find CERTIFICATE")
|
||||
}
|
||||
if len(keyBlock) == 0 {
|
||||
return "", fmt.Errorf("failed to find PRIVATE KEY")
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(certBlock, keyBlock)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cfg.Certificates = append(cfg.Certificates, cert)
|
||||
|
||||
// The documentation for the tls.X509KeyPair indicates that the Leaf certificate is not
|
||||
// retained.
|
||||
crt, err := x509.ParseCertificate(certDecodedBlock)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return x509CertSubject(crt), nil
|
||||
}
|
||||
9
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options_1_10.go
generated
vendored
Executable file
9
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options_1_10.go
generated
vendored
Executable file
@@ -0,0 +1,9 @@
|
||||
// +build go1.10
|
||||
|
||||
package topology
|
||||
|
||||
import "crypto/x509"
|
||||
|
||||
func x509CertSubject(cert *x509.Certificate) string {
|
||||
return cert.Subject.String()
|
||||
}
|
||||
13
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options_1_9.go
generated
vendored
Executable file
13
vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_options_1_9.go
generated
vendored
Executable file
@@ -0,0 +1,13 @@
|
||||
// +build !go1.10
|
||||
|
||||
package topology
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// We don't support version less then 1.10, but Evergreen needs to be able to compile the driver
|
||||
// using version 1.8.
|
||||
func x509CertSubject(cert *x509.Certificate) string {
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user