newline battles continue

This commit is contained in:
bel
2020-01-19 20:41:30 +00:00
parent 98adb53caf
commit 573696774e
1456 changed files with 501133 additions and 6 deletions


@@ -0,0 +1,40 @@
# Topology Package Design
This document outlines the design for this package.
## Topology
The `Topology` type handles monitoring the state of a MongoDB deployment and selecting servers.
Updating the description is handled by a finite state machine that implements the server discovery
and monitoring specification. A `Topology` can be connected and fully disconnected, which enables
saving resources. The `Topology` type also handles server selection following the server selection
specification.
## Server
The `Server` type handles heartbeating a MongoDB server and holds a pool of connections.
## Connection
Connections are handled by two main types and an auxiliary type. The two main types are `connection`
and `Connection`. The first holds most of the logic required to actually read and write wire
messages. Instances can be created with the `newConnection` function. Inside `newConnection`, the
auxiliary type `initConnection` is used to perform the connection handshake. This is required
because the `connection` type does not fully implement `driver.Connection`, which is required
during handshaking. The `Connection` type is what is actually returned to a consumer of the
`topology` package. This type implements the `driver.Connection` interface, holds a reference to a
`connection` instance, and exists mainly to prevent accidental continued usage of a connection after
closing it.
The connection implementations in this package are conduits for wire messages but they have no
ability to encode, decode, or validate wire messages. That must be handled by consumers.
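The wrapper pattern can be illustrated with a minimal, self-contained sketch (simplified names, not
the driver's actual types): the exported wrapper nils out its internal pointer on close, so any
later call fails fast instead of touching a `connection` that may already have been recycled.

```go
package example

import "errors"

var errClosed = errors.New("connection is closed")

// conn owns the socket and the read/write logic.
type conn struct{ /* net.Conn, buffers, ... */ }

func (c *conn) write(b []byte) error { return nil }

// Conn is the consumer-facing wrapper around a *conn.
type Conn struct{ c *conn }

func (w *Conn) Write(b []byte) error {
	if w.c == nil {
		return errClosed
	}
	return w.c.write(b)
}

// Close detaches the wrapper; the real implementation returns the underlying
// connection to its pool rather than discarding it.
func (w *Conn) Close() error {
	w.c = nil
	return nil
}
```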
## Pool
The `pool` type implements a connection pool. It handles caching idle connections and dialing
new ones, but it does not track a maximum number of connections. That is the responsibility of a
wrapping type, like `Server`.
The `pool` type has no concept of closing; instead, it has concepts of connecting and disconnecting.
This allows a `Topology` to be disconnected while keeping the memory around, so it can be reconnected later.
There is a `close` method, but this is used to close a connection.
There are three methods related to getting and putting connections: `get`, `close`, and `put`. The
`get` method will either retrieve a connection from the cache or it will dial a new `connection`.
The `close` method will close the underlying socket of a `connection`. The `put` method will put a
connection into the pool, placing it in the cache if there is space; otherwise it will close it.
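The `put` behavior described above can be sketched with a simplified model (a buffered channel as
the cache; the real pool uses a different structure and also checks connection expiry):

```go
package example

type pooledConn struct{}

func (c *pooledConn) close() error { return nil }

type cachePool struct{ cache chan *pooledConn }

// put returns c to the cache if there is space; otherwise it closes c.
func (p *cachePool) put(c *pooledConn) error {
	select {
	case p.cache <- c: // space available: keep the connection for reuse
		return nil
	default: // cache full: close the underlying socket
		return c.close()
	}
}
```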


@@ -0,0 +1,458 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"compress/zlib"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"strings"
"github.com/golang/snappy"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
wiremessagex "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
"go.mongodb.org/mongo-driver/x/network/command"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
var globalConnectionID uint64
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
type connection struct {
id string
nc net.Conn // When nil, the connection is closed.
addr address.Address
idleTimeout time.Duration
idleDeadline time.Time
lifetimeDeadline time.Time
readTimeout time.Duration
writeTimeout time.Duration
desc description.Server
compressor wiremessage.CompressorID
zliblevel int
// pool related fields
pool *pool
poolID uint64
generation uint64
}
// newConnection handles the creation of a connection. It will dial, configure TLS, and perform
// initialization handshakes.
func newConnection(ctx context.Context, addr address.Address, opts ...ConnectionOption) (*connection, error) {
cfg, err := newConnectionConfig(opts...)
if err != nil {
return nil, err
}
nc, err := cfg.dialer.DialContext(ctx, addr.Network(), addr.String())
if err != nil {
return nil, ConnectionError{Wrapped: err, init: true}
}
if cfg.tlsConfig != nil {
tlsConfig := cfg.tlsConfig.Clone()
nc, err = configureTLS(ctx, nc, addr, tlsConfig)
if err != nil {
return nil, ConnectionError{Wrapped: err, init: true}
}
}
var lifetimeDeadline time.Time
if cfg.lifeTimeout > 0 {
lifetimeDeadline = time.Now().Add(cfg.lifeTimeout)
}
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
c := &connection{
id: id,
nc: nc,
addr: addr,
idleTimeout: cfg.idleTimeout,
lifetimeDeadline: lifetimeDeadline,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
}
c.bumpIdleDeadline()
// running isMaster and authentication is handled by a handshaker on the configuration instance.
if cfg.handshaker != nil {
c.desc, err = cfg.handshaker.Handshake(ctx, c.addr, initConnection{c})
if err != nil {
if c.nc != nil {
_ = c.nc.Close()
}
return nil, ConnectionError{Wrapped: err, init: true}
}
if cfg.descCallback != nil {
cfg.descCallback(c.desc)
}
if len(c.desc.Compression) > 0 {
clientMethodLoop:
for _, method := range cfg.compressors {
for _, serverMethod := range c.desc.Compression {
if method != serverMethod {
continue
}
switch strings.ToLower(method) {
case "snappy":
c.compressor = wiremessage.CompressorSnappy
case "zlib":
c.compressor = wiremessage.CompressorZLib
c.zliblevel = wiremessage.DefaultZlibLevel
if cfg.zlibLevel != nil {
c.zliblevel = *cfg.zlibLevel
}
}
break clientMethodLoop
}
}
}
}
return c, nil
}
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if c.nc == nil {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
default:
}
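// Use the earlier of the connection's write timeout and the context deadline.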
var deadline time.Time
if c.writeTimeout != 0 {
deadline = time.Now().Add(c.writeTimeout)
}
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
deadline = dl
}
if err := c.nc.SetWriteDeadline(deadline); err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
}
_, err = c.nc.Write(wm)
if err != nil {
c.close()
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to write wire message to network"}
}
c.bumpIdleDeadline()
return nil
}
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
if c.nc == nil {
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
// We close the connection because we don't know if there is an unread message on the wire.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
default:
}
var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
}
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
deadline = dl
}
if err := c.nc.SetReadDeadline(deadline); err != nil {
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
}
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to decode message length"}
}
// read the length prefix as a little-endian int32 (the MongoDB wire protocol is little-endian)
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
if int(size) > cap(dst) {
// Since we can't grow this slice without allocating, just allocate an entirely new slice.
dst = make([]byte, 0, size)
}
// We need to ensure we don't accidentally read into a subsequent wire message, so we set the
// size to read exactly this wire message.
dst = dst[:size]
copy(dst, sizeBuf[:])
_, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "unable to read full message"}
}
c.bumpIdleDeadline()
return dst, nil
}
func (c *connection) close() error {
if c.nc == nil {
return nil
}
if c.pool == nil {
err := c.nc.Close()
c.nc = nil
return err
}
return c.pool.close(c)
}
func (c *connection) expired() bool {
now := time.Now()
if !c.idleDeadline.IsZero() && now.After(c.idleDeadline) {
return true
}
if !c.lifetimeDeadline.IsZero() && now.After(c.lifetimeDeadline) {
return true
}
return c.nc == nil
}
func (c *connection) bumpIdleDeadline() {
if c.idleTimeout > 0 {
c.idleDeadline = time.Now().Add(c.idleTimeout)
}
}
// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
type initConnection struct{ *connection }
var _ driver.Connection = initConnection{}
func (c initConnection) Description() description.Server { return description.Server{} }
func (c initConnection) Close() error { return nil }
func (c initConnection) ID() string { return c.id }
func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
return c.writeWireMessage(ctx, wm)
}
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
return c.readWireMessage(ctx, dst)
}
// Connection implements the driver.Connection interface. It allows reading and writing wire
// messages.
type Connection struct {
*connection
s *Server
mu sync.RWMutex
}
var _ driver.Connection = (*Connection)(nil)
// WriteWireMessage handles writing a wire message to the underlying connection.
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return ErrConnectionClosed
}
return c.writeWireMessage(ctx, wm)
}
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
// will be overwritten with the new wire message.
func (c *Connection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
return c.readWireMessage(ctx, dst)
}
// CompressWireMessage handles compressing the provided wire message using the underlying
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
// there is no compressor set on the underlying connection, then no compression will be performed.
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
if c.connection.compressor == wiremessage.CompressorNoOp {
return append(dst, src...), nil
}
_, reqid, respto, origcode, rem, ok := wiremessagex.ReadHeader(src)
if !ok {
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
}
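// OP_COMPRESSED layout: a standard 16-byte header (with opcode OP_COMPRESSED), followed by the
// original opcode, the uncompressed body size, the compressor ID, and the compressed body.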
idx, dst := wiremessagex.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
dst = wiremessagex.AppendCompressedOriginalOpCode(dst, origcode)
dst = wiremessagex.AppendCompressedUncompressedSize(dst, int32(len(rem)))
dst = wiremessagex.AppendCompressedCompressorID(dst, c.connection.compressor)
switch c.connection.compressor {
case wiremessage.CompressorSnappy:
compressed := snappy.Encode(nil, rem)
dst = wiremessagex.AppendCompressedCompressedMessage(dst, compressed)
case wiremessage.CompressorZLib:
var b bytes.Buffer
w, err := zlib.NewWriterLevel(&b, c.connection.zliblevel)
if err != nil {
return dst, err
}
_, err = w.Write(rem)
if err != nil {
return dst, err
}
err = w.Close()
if err != nil {
return dst, err
}
dst = wiremessagex.AppendCompressedCompressedMessage(dst, b.Bytes())
default:
return dst, fmt.Errorf("unknown compressor ID %v", c.connection.compressor)
}
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
}
// Description returns the server description of the server this connection is connected to.
func (c *Connection) Description() description.Server {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return description.Server{}
}
return c.desc
}
// Close returns this connection to the connection pool. This method may not close the underlying
// socket.
func (c *Connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return nil
}
if c.s != nil {
c.s.sem.Release(1)
}
err := c.pool.put(c.connection)
if err != nil {
return err
}
c.connection = nil
return nil
}
// ID returns the ID of this connection.
func (c *Connection) ID() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.id
}
// Address returns the address of this connection.
func (c *Connection) Address() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return address.Address("0.0.0.0")
}
return c.addr
}
var notMasterCodes = []int32{10107, 13435}
var recoveringCodes = []int32{11600, 11602, 13436, 189, 91}
func isRecoveringError(err command.Error) bool {
for _, c := range recoveringCodes {
if c == err.Code {
return true
}
}
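// Fall back to matching the error message for servers that do not report an error code.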
return strings.Contains(err.Error(), "node is recovering")
}
func isNotMasterError(err command.Error) bool {
for _, c := range notMasterCodes {
if c == err.Code {
return true
}
}
return strings.Contains(err.Error(), "not master")
}
func configureTLS(ctx context.Context, nc net.Conn, addr address.Address, config *tls.Config) (net.Conn, error) {
if !config.InsecureSkipVerify {
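// Strip any port suffix so that only the hostname is used for certificate verification.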
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}
hostname = hostname[:colonPos]
config.ServerName = hostname
}
client := tls.Client(nc, config)
errChan := make(chan error, 1)
go func() {
errChan <- client.Handshake()
}()
select {
case err := <-errChan:
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, errors.New("server connection cancelled/timeout during TLS handshake")
}
return client, nil
}


@@ -0,0 +1,615 @@
package topology
import (
"context"
"fmt"
"sync"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/bsonx"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/network/command"
"go.mongodb.org/mongo-driver/x/network/compressor"
"go.mongodb.org/mongo-driver/x/network/wiremessage"
)
var emptyDoc bson.Raw
type connectionLegacy struct {
*connection
writeBuf []byte
readBuf []byte
monitor *event.CommandMonitor
cmdMap map[int64]*commandMetadata // map for monitoring commands sent to server
wireMessageBuf []byte // buffer to store uncompressed wire message before compressing
compressBuf []byte // buffer to compress messages
uncompressBuf []byte // buffer to uncompress messages
compressor compressor.Compressor // use for compressing messages
// server can compress response with any compressor supported by driver
compressorMap map[wiremessage.CompressorID]compressor.Compressor
s *Server
sync.RWMutex
}
func newConnectionLegacy(c *connection, s *Server, opts ...ConnectionOption) (*connectionLegacy, error) {
cfg, err := newConnectionConfig(opts...)
if err != nil {
return nil, err
}
compressorMap := make(map[wiremessage.CompressorID]compressor.Compressor)
for _, comp := range cfg.compressors {
switch comp {
case "snappy":
snappyComp := compressor.CreateSnappy()
compressorMap[snappyComp.CompressorID()] = snappyComp
case "zlib":
zlibComp, err := compressor.CreateZlib(cfg.zlibLevel)
if err != nil {
return nil, err
}
compressorMap[zlibComp.CompressorID()] = zlibComp
}
}
cl := &connectionLegacy{
connection: c,
compressorMap: compressorMap,
cmdMap: make(map[int64]*commandMetadata),
compressBuf: make([]byte, 256),
readBuf: make([]byte, 256),
uncompressBuf: make([]byte, 256),
writeBuf: make([]byte, 0, 256),
wireMessageBuf: make([]byte, 256),
s: s,
}
d := c.desc
if len(d.Compression) > 0 {
clientMethodLoop:
for _, comp := range cl.compressorMap {
method := comp.Name()
for _, serverMethod := range d.Compression {
if method != serverMethod {
continue
}
cl.compressor = comp // found matching compressor
break clientMethodLoop
}
}
}
cl.monitor = cfg.cmdMonitor // Attach the command monitor later to avoid monitoring auth.
return cl, nil
}
func (c *connectionLegacy) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
// Truncate the write buffer
c.writeBuf = c.writeBuf[:0]
messageToWrite := wm
// Compress if possible
if c.compressor != nil {
compressed, err := c.compressMessage(wm)
if err != nil {
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to compress wire message",
}
}
messageToWrite = compressed
}
var err error
c.writeBuf, err = messageToWrite.AppendWireMessage(c.writeBuf)
if err != nil {
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to encode wire message",
}
}
err = c.writeWireMessage(ctx, c.writeBuf)
if c.s != nil {
c.s.ProcessError(err)
}
if err != nil {
// The error we got back was probably a ConnectionError already, so we don't really need to
// wrap it here.
return ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to write wire message to network",
}
}
return c.commandStartedEvent(ctx, wm)
}
func (c *connectionLegacy) ReadWireMessage(ctx context.Context) (wiremessage.WireMessage, error) {
// Truncate the write buffer
c.readBuf = c.readBuf[:0]
var err error
c.readBuf, err = c.readWireMessage(ctx, c.readBuf)
if c.s != nil {
c.s.ProcessError(err)
}
if err != nil {
// The error we got back was probably a ConnectionError already, so we don't really need to
// wrap it here.
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to read wire message from network",
}
}
hdr, err := wiremessage.ReadHeader(c.readBuf, 0)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode header",
}
}
messageToDecode := c.readBuf
opcodeToCheck := hdr.OpCode
if hdr.OpCode == wiremessage.OpCompressed {
var compressed wiremessage.Compressed
err := compressed.UnmarshalWireMessage(c.readBuf)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_COMPRESSED",
}
}
uncompressed, origOpcode, err := c.uncompressMessage(compressed)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to uncompress message",
}
}
messageToDecode = uncompressed
opcodeToCheck = origOpcode
}
var wm wiremessage.WireMessage
switch opcodeToCheck {
case wiremessage.OpReply:
var r wiremessage.Reply
err := r.UnmarshalWireMessage(messageToDecode)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_REPLY",
}
}
wm = r
case wiremessage.OpMsg:
var reply wiremessage.Msg
err := reply.UnmarshalWireMessage(messageToDecode)
if err != nil {
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: err,
message: "unable to decode OP_MSG",
}
}
wm = reply
default:
return nil, ConnectionError{
ConnectionID: c.id,
message: fmt.Sprintf("opcode %s not implemented", hdr.OpCode),
}
}
if c.s != nil {
c.s.ProcessError(command.DecodeError(wm))
}
// TODO: do we care if monitoring fails?
return wm, c.commandFinishedEvent(ctx, wm)
}
func (c *connectionLegacy) Close() error {
c.Lock()
defer c.Unlock()
if c.connection == nil {
return nil
}
if c.s != nil {
c.s.sem.Release(1)
}
err := c.pool.put(c.connection)
if err != nil {
return err
}
c.connection = nil
return nil
}
func (c *connectionLegacy) Expired() bool {
c.RLock()
defer c.RUnlock()
return c.connection == nil || c.expired()
}
func (c *connectionLegacy) Alive() bool {
c.RLock()
defer c.RUnlock()
return c.connection != nil
}
func (c *connectionLegacy) ID() string {
c.RLock()
defer c.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.id
}
func (c *connectionLegacy) commandStartedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
if c.monitor == nil || c.monitor.Started == nil {
return nil
}
startedEvent := &event.CommandStartedEvent{
ConnectionID: c.id,
}
var cmd bsonx.Doc
var err error
var legacy bool
var fullCollName string
var acknowledged bool
switch converted := wm.(type) {
case wiremessage.Query:
cmd, err = converted.CommandDocument()
if err != nil {
return err
}
acknowledged = converted.AcknowledgedWrite()
startedEvent.DatabaseName = converted.DatabaseName()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
legacy = converted.Legacy()
fullCollName = converted.FullCollectionName
case wiremessage.Msg:
cmd, err = converted.GetMainDocument()
if err != nil {
return err
}
acknowledged = converted.AcknowledgedWrite()
arr, identifier, err := converted.GetSequenceArray()
if err != nil {
return err
}
if arr != nil {
cmd = cmd.Copy() // make copy to avoid changing original command
cmd = append(cmd, bsonx.Elem{identifier, bsonx.Array(arr)})
}
dbVal, err := cmd.LookupErr("$db")
if err != nil {
return err
}
startedEvent.DatabaseName = dbVal.StringValue()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
case wiremessage.GetMore:
cmd = converted.CommandDocument()
startedEvent.DatabaseName = converted.DatabaseName()
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
acknowledged = true
legacy = true
fullCollName = converted.FullCollectionName
case wiremessage.KillCursors:
cmd = converted.CommandDocument()
startedEvent.DatabaseName = converted.DatabaseName
startedEvent.RequestID = int64(converted.MsgHeader.RequestID)
legacy = true
}
rawcmd, _ := cmd.MarshalBSON()
startedEvent.Command = rawcmd
startedEvent.CommandName = cmd[0].Key
if !canMonitor(startedEvent.CommandName) {
startedEvent.Command = emptyDoc
}
c.monitor.Started(ctx, startedEvent)
if !acknowledged {
if c.monitor.Succeeded == nil {
return nil
}
// unacknowledged writes must provide a CommandSucceededEvent with an { ok: 1 } reply
finishedEvent := event.CommandFinishedEvent{
DurationNanos: 0,
CommandName: startedEvent.CommandName,
RequestID: startedEvent.RequestID,
ConnectionID: c.id,
}
c.monitor.Succeeded(ctx, &event.CommandSucceededEvent{
CommandFinishedEvent: finishedEvent,
Reply: bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)),
})
return nil
}
c.cmdMap[startedEvent.RequestID] = createMetadata(startedEvent.CommandName, legacy, fullCollName)
return nil
}
func (c *connectionLegacy) commandFinishedEvent(ctx context.Context, wm wiremessage.WireMessage) error {
if c.monitor == nil {
return nil
}
var reply bsonx.Doc
var requestID int64
var err error
switch converted := wm.(type) {
case wiremessage.Reply:
requestID = int64(converted.MsgHeader.ResponseTo)
case wiremessage.Msg:
requestID = int64(converted.MsgHeader.ResponseTo)
}
cmdMetadata := c.cmdMap[requestID]
delete(c.cmdMap, requestID)
switch converted := wm.(type) {
case wiremessage.Reply:
if cmdMetadata.Legacy {
reply, err = converted.GetMainLegacyDocument(cmdMetadata.FullCollectionName)
} else {
reply, err = converted.GetMainDocument()
}
case wiremessage.Msg:
reply, err = converted.GetMainDocument()
}
if err != nil {
return err
}
success, errmsg := processReply(reply)
if (success && c.monitor.Succeeded == nil) || (!success && c.monitor.Failed == nil) {
return nil
}
finishedEvent := event.CommandFinishedEvent{
DurationNanos: cmdMetadata.TimeDifference(),
CommandName: cmdMetadata.Name,
RequestID: requestID,
ConnectionID: c.id,
}
if success {
if !canMonitor(finishedEvent.CommandName) {
successEvent := &event.CommandSucceededEvent{
Reply: emptyDoc,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Succeeded(ctx, successEvent)
return nil
}
// if the response has a type 1 document sequence, the sequence must be included as a BSON array in the event's reply.
if opmsg, ok := wm.(wiremessage.Msg); ok {
arr, identifier, err := opmsg.GetSequenceArray()
if err != nil {
return err
}
if arr != nil {
reply = reply.Copy() // make copy to avoid changing original command
reply = append(reply, bsonx.Elem{identifier, bsonx.Array(arr)})
}
}
replyraw, _ := reply.MarshalBSON()
successEvent := &event.CommandSucceededEvent{
Reply: replyraw,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Succeeded(ctx, successEvent)
return nil
}
failureEvent := &event.CommandFailedEvent{
Failure: errmsg,
CommandFinishedEvent: finishedEvent,
}
c.monitor.Failed(ctx, failureEvent)
return nil
}
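// canCompress returns false for handshake- and authentication-related commands, which are never
// sent compressed.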
func canCompress(cmd string) bool {
if cmd == "isMaster" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "authenticate" ||
cmd == "createUser" || cmd == "updateUser" || cmd == "copydbSaslStart" || cmd == "copydbgetnonce" || cmd == "copydb" {
return false
}
return true
}
func (c *connectionLegacy) compressMessage(wm wiremessage.WireMessage) (wiremessage.WireMessage, error) {
var requestID int32
var responseTo int32
var origOpcode wiremessage.OpCode
switch converted := wm.(type) {
case wiremessage.Query:
firstElem, err := converted.Query.IndexErr(0)
if err != nil {
return wiremessage.Compressed{}, err
}
key := firstElem.Key()
if !canCompress(key) {
return wm, nil // return original message because this command can't be compressed
}
requestID = converted.MsgHeader.RequestID
origOpcode = wiremessage.OpQuery
responseTo = converted.MsgHeader.ResponseTo
case wiremessage.Msg:
firstElem, err := converted.Sections[0].(wiremessage.SectionBody).Document.IndexErr(0)
if err != nil {
return wiremessage.Compressed{}, err
}
key := firstElem.Key()
if !canCompress(key) {
return wm, nil
}
requestID = converted.MsgHeader.RequestID
origOpcode = wiremessage.OpMsg
responseTo = converted.MsgHeader.ResponseTo
}
// can compress
c.wireMessageBuf = c.wireMessageBuf[:0] // truncate
var err error
c.wireMessageBuf, err = wm.AppendWireMessage(c.wireMessageBuf)
if err != nil {
return wiremessage.Compressed{}, err
}
c.wireMessageBuf = c.wireMessageBuf[16:] // strip header
c.compressBuf = c.compressBuf[:0]
compressedBytes, err := c.compressor.CompressBytes(c.wireMessageBuf, c.compressBuf)
if err != nil {
return wiremessage.Compressed{}, err
}
compressedMessage := wiremessage.Compressed{
MsgHeader: wiremessage.Header{
// MessageLength and OpCode will be set when marshalling wire message by SetDefaults()
RequestID: requestID,
ResponseTo: responseTo,
},
OriginalOpCode: origOpcode,
UncompressedSize: int32(len(c.wireMessageBuf)), // length of uncompressed message excluding MsgHeader
CompressorID: wiremessage.CompressorID(c.compressor.CompressorID()),
CompressedMessage: compressedBytes,
}
return compressedMessage, nil
}
// uncompressMessage returns the uncompressed message bytes with a reconstructed header, the original opcode, and an error.
func (c *connectionLegacy) uncompressMessage(compressed wiremessage.Compressed) ([]byte, wiremessage.OpCode, error) {
// server doesn't guarantee the same compression method will be used each time so the CompressorID field must be
// used to find the correct method for uncompressing data
uncompressor := c.compressorMap[compressed.CompressorID]
// reset uncompressBuf
c.uncompressBuf = c.uncompressBuf[:0]
if int(compressed.UncompressedSize) > cap(c.uncompressBuf) {
c.uncompressBuf = make([]byte, 0, compressed.UncompressedSize)
}
uncompressedMessage, err := uncompressor.UncompressBytes(compressed.CompressedMessage, c.uncompressBuf[:compressed.UncompressedSize])
if err != nil {
return nil, 0, err
}
origHeader := wiremessage.Header{
MessageLength: int32(len(uncompressedMessage)) + 16, // add 16 for original header
RequestID: compressed.MsgHeader.RequestID,
ResponseTo: compressed.MsgHeader.ResponseTo,
}
switch compressed.OriginalOpCode {
case wiremessage.OpReply:
origHeader.OpCode = wiremessage.OpReply
case wiremessage.OpMsg:
origHeader.OpCode = wiremessage.OpMsg
default:
return nil, 0, fmt.Errorf("opcode %s not implemented", compressed.OriginalOpCode)
}
var fullMessage []byte
fullMessage = origHeader.AppendHeader(fullMessage)
fullMessage = append(fullMessage, uncompressedMessage...)
return fullMessage, origHeader.OpCode, nil
}
func canMonitor(cmd string) bool {
if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
return false
}
return true
}
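// processReply reports whether a reply document indicates success and, on failure, returns an
// error message built from the errmsg and code fields.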
func processReply(reply bsonx.Doc) (bool, string) {
var success bool
var errmsg string
var errCode int32
for _, elem := range reply {
switch elem.Key {
case "ok":
switch elem.Value.Type() {
case bsontype.Int32:
if elem.Value.Int32() == 1 {
success = true
}
case bsontype.Int64:
if elem.Value.Int64() == 1 {
success = true
}
case bsontype.Double:
if elem.Value.Double() == 1 {
success = true
}
}
case "errmsg":
if str, ok := elem.Value.StringValueOK(); ok {
errmsg = str
}
case "code":
if c, ok := elem.Value.Int32OK(); ok {
errCode = c
}
}
}
if success {
return true, ""
}
fullErrMsg := fmt.Sprintf("Error code %d: %s", errCode, errmsg)
return false, fullErrMsg
}


@@ -0,0 +1,28 @@
package topology
import "time"
// commandMetadata contains metadata about a command sent to the server.
type commandMetadata struct {
Name string
Time time.Time
Legacy bool
FullCollectionName string
}
// createMetadata creates metadata for a command.
func createMetadata(name string, legacy bool, fullCollName string) *commandMetadata {
return &commandMetadata{
Name: name,
Time: time.Now(),
Legacy: legacy,
FullCollectionName: fullCollName,
}
}
// TimeDifference returns the difference between now and the time a command was sent in nanoseconds.
func (cm *commandMetadata) TimeDifference() int64 {
t := time.Now()
duration := t.Sub(cm.Time)
return duration.Nanoseconds()
}


@@ -0,0 +1,187 @@
package topology
import (
"context"
"crypto/tls"
"net"
"time"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
// Dialer is used to make network connections.
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// DialerFunc is a type implemented by functions that can be used as a Dialer.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)
// DialContext implements the Dialer interface.
func (df DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return df(ctx, network, address)
}
// DefaultDialer is the Dialer implementation that is used by this package. Changing this
will also change the Dialer used for this package. This should only be changed when all
// of the connections being made need to use a different Dialer. Most of the time, using a
// WithDialer option is more appropriate than changing this variable.
var DefaultDialer Dialer = &net.Dialer{}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker
// HandshakerFunc is an adapter to allow the use of ordinary functions as
// connection handshakers.
type HandshakerFunc = driver.HandshakerFunc
type connectionConfig struct {
appName string
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
idleTimeout time.Duration
lifeTimeout time.Duration
cmdMonitor *event.CommandMonitor
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
compressors []string
zlibLevel *int
descCallback func(description.Server)
}
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
cfg := &connectionConfig{
connectTimeout: 30 * time.Second,
dialer: nil,
idleTimeout: 10 * time.Minute,
lifeTimeout: 30 * time.Minute,
}
for _, opt := range opts {
err := opt(cfg)
if err != nil {
return nil, err
}
}
if cfg.dialer == nil {
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
}
return cfg, nil
}
func withServerDescriptionCallback(callback func(description.Server), opts ...ConnectionOption) []ConnectionOption {
return append(opts, ConnectionOption(func(c *connectionConfig) error {
c.descCallback = callback
return nil
}))
}
// ConnectionOption is used to configure a connection.
type ConnectionOption func(*connectionConfig) error
// WithAppName sets the application name which gets sent to MongoDB when it
// first connects.
func WithAppName(fn func(string) string) ConnectionOption {
return func(c *connectionConfig) error {
c.appName = fn(c.appName)
return nil
}
}
// WithCompressors sets the compressors that can be used for communication.
func WithCompressors(fn func([]string) []string) ConnectionOption {
return func(c *connectionConfig) error {
c.compressors = fn(c.compressors)
return nil
}
}
// WithConnectTimeout configures the maximum amount of time a dial will wait for a
// connect to complete. The default is 30 seconds.
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.connectTimeout = fn(c.connectTimeout)
return nil
}
}
// WithDialer configures the Dialer to use when making a new connection to MongoDB.
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
return func(c *connectionConfig) error {
c.dialer = fn(c.dialer)
return nil
}
}
// WithHandshaker configures the Handshaker that will be used to initialize newly
// dialed connections.
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
return func(c *connectionConfig) error {
c.handshaker = fn(c.handshaker)
return nil
}
}
// WithIdleTimeout configures the maximum idle time to allow for a connection.
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.idleTimeout = fn(c.idleTimeout)
return nil
}
}
// WithLifeTimeout configures the maximum life of a connection.
func WithLifeTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.lifeTimeout = fn(c.lifeTimeout)
return nil
}
}
// WithReadTimeout configures the maximum read time for a connection.
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.readTimeout = fn(c.readTimeout)
return nil
}
}
// WithWriteTimeout configures the maximum write time for a connection.
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
return func(c *connectionConfig) error {
c.writeTimeout = fn(c.writeTimeout)
return nil
}
}
// WithTLSConfig configures the TLS options for a connection.
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
return func(c *connectionConfig) error {
c.tlsConfig = fn(c.tlsConfig)
return nil
}
}
// WithMonitor configures an event monitor for command monitoring.
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
return func(c *connectionConfig) error {
c.cmdMonitor = fn(c.cmdMonitor)
return nil
}
}
// WithZlibLevel sets the zlib compression level.
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
return func(c *connectionConfig) error {
c.zlibLevel = fn(c.zlibLevel)
return nil
}
}
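// A short sketch of how these functional options compose: each option receives the current value
// and returns the new one, so callers can wrap or replace defaults. exampleConnectionConfig and
// the literal values are illustrative, not part of the package.
func exampleConnectionConfig() (*connectionConfig, error) {
return newConnectionConfig(
WithConnectTimeout(func(time.Duration) time.Duration { return 10 * time.Second }),
WithIdleTimeout(func(time.Duration) time.Duration { return 5 * time.Minute }),
WithCompressors(func(cs []string) []string { return append(cs, "snappy", "zlib") }),
)
}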


@@ -0,0 +1,22 @@
package topology
import "fmt"
// ConnectionError represents a connection error.
type ConnectionError struct {
ConnectionID string
Wrapped error
// init will be set to true if this error occurred during connection initialization or
// during a connection handshake.
init bool
message string
}
// Error implements the error interface.
func (e ConnectionError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("connection(%s) %s: %s", e.ConnectionID, e.message, e.Wrapped.Error())
}
return fmt.Sprintf("connection(%s) %s", e.ConnectionID, e.message)
}


@@ -0,0 +1,350 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
)
var supportedWireVersions = description.NewVersionRange(2, 6)
var minSupportedMongoDBVersion = "2.6"
type fsm struct {
description.Topology
SetName string
maxElectionID primitive.ObjectID
maxSetVersion uint32
}
func newFSM() *fsm {
return new(fsm)
}
// apply should operate on immutable TopologyDescriptions and Descriptions. This way we don't have to
// lock for the entire time we're applying a server description.
func (f *fsm) apply(s description.Server) (description.Topology, error) {
newServers := make([]description.Server, len(f.Servers))
copy(newServers, f.Servers)
oldMinutes := f.SessionTimeoutMinutes
f.Topology = description.Topology{
Kind: f.Kind,
Servers: newServers,
}
// For data bearing servers, set SessionTimeoutMinutes to the lowest among them
if oldMinutes == 0 {
// If the timeout is currently 0, check all servers to see if any still don't have a timeout.
// If they all have a timeout, pick the lowest.
timeout := s.SessionTimeoutMinutes
for _, server := range f.Servers {
if server.DataBearing() && server.SessionTimeoutMinutes < timeout {
timeout = server.SessionTimeoutMinutes
}
}
f.SessionTimeoutMinutes = timeout
} else {
if s.DataBearing() && oldMinutes > s.SessionTimeoutMinutes {
f.SessionTimeoutMinutes = s.SessionTimeoutMinutes
} else {
f.SessionTimeoutMinutes = oldMinutes
}
}
if _, ok := f.findServer(s.Addr); !ok {
return f.Topology, nil
}
if s.WireVersion != nil {
if s.WireVersion.Max < supportedWireVersions.Min {
return description.Topology{}, fmt.Errorf(
"server at %s reports wire version %d, but this version of the Go driver requires "+
"at least %d (MongoDB %s)",
s.Addr.String(),
s.WireVersion.Max,
supportedWireVersions.Min,
minSupportedMongoDBVersion,
)
}
if s.WireVersion.Min > supportedWireVersions.Max {
return description.Topology{}, fmt.Errorf(
"server at %s requires wire version %d, but this version of the Go driver only "+
"supports up to %d",
s.Addr.String(),
s.WireVersion.Min,
supportedWireVersions.Max,
)
}
}
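// Dispatch on the current topology kind; each case applies the transition rules from the server
// discovery and monitoring (SDAM) spec for that kind.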
switch f.Kind {
case description.Unknown:
f.applyToUnknown(s)
case description.Sharded:
f.applyToSharded(s)
case description.ReplicaSetNoPrimary:
f.applyToReplicaSetNoPrimary(s)
case description.ReplicaSetWithPrimary:
f.applyToReplicaSetWithPrimary(s)
case description.Single:
f.applyToSingle(s)
}
return f.Topology, nil
}
func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithoutPrimary(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
}
func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) {
switch s.Kind {
case description.Standalone, description.Mongos:
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.updateRSWithPrimaryFromMember(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
f.checkIfHasPrimary()
}
}
func (f *fsm) applyToSharded(s description.Server) {
switch s.Kind {
case description.Mongos, description.Unknown:
f.replaceServer(s)
case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
f.removeServerByAddr(s.Addr)
}
}
func (f *fsm) applyToSingle(s description.Server) {
switch s.Kind {
case description.Unknown:
f.replaceServer(s)
case description.Standalone, description.Mongos:
if f.SetName != "" {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost:
if f.SetName != "" && f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
}
}
func (f *fsm) applyToUnknown(s description.Server) {
switch s.Kind {
case description.Mongos:
f.setKind(description.Sharded)
f.replaceServer(s)
case description.RSPrimary:
f.updateRSFromPrimary(s)
case description.RSSecondary, description.RSArbiter, description.RSMember:
f.setKind(description.ReplicaSetNoPrimary)
f.updateRSWithoutPrimary(s)
case description.Standalone:
f.updateUnknownWithStandalone(s)
case description.Unknown, description.RSGhost:
f.replaceServer(s)
}
}
func (f *fsm) checkIfHasPrimary() {
if _, ok := f.findPrimary(); ok {
f.setKind(description.ReplicaSetWithPrimary)
} else {
f.setKind(description.ReplicaSetNoPrimary)
}
}
func (f *fsm) updateRSFromPrimary(s description.Server) {
if f.SetName == "" {
f.SetName = s.SetName
} else if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
if s.SetVersion != 0 && !bytes.Equal(s.ElectionID[:], primitive.NilObjectID[:]) {
if f.maxSetVersion > s.SetVersion || bytes.Compare(f.maxElectionID[:], s.ElectionID[:]) == 1 {
f.replaceServer(description.Server{
Addr: s.Addr,
LastError: fmt.Errorf("was a primary, but its set version or election id is stale"),
})
f.checkIfHasPrimary()
return
}
f.maxElectionID = s.ElectionID
}
if s.SetVersion > f.maxSetVersion {
f.maxSetVersion = s.SetVersion
}
if j, ok := f.findPrimary(); ok {
f.setServer(j, description.Server{
Addr: f.Servers[j].Addr,
LastError: fmt.Errorf("was a primary, but a new primary was discovered"),
})
}
f.replaceServer(s)
for j := len(f.Servers) - 1; j >= 0; j-- {
found := false
for _, member := range s.Members {
if member == f.Servers[j].Addr {
found = true
break
}
}
if !found {
f.removeServer(j)
}
}
for _, member := range s.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
f.checkIfHasPrimary()
}
func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) {
if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
f.checkIfHasPrimary()
return
}
f.replaceServer(s)
if _, ok := f.findPrimary(); !ok {
f.setKind(description.ReplicaSetNoPrimary)
}
}
func (f *fsm) updateRSWithoutPrimary(s description.Server) {
if f.SetName == "" {
f.SetName = s.SetName
} else if f.SetName != s.SetName {
f.removeServerByAddr(s.Addr)
return
}
for _, member := range s.Members {
if _, ok := f.findServer(member); !ok {
f.addServer(member)
}
}
if s.Addr != s.CanonicalAddr {
f.removeServerByAddr(s.Addr)
return
}
f.replaceServer(s)
}
func (f *fsm) updateUnknownWithStandalone(s description.Server) {
if len(f.Servers) > 1 {
f.removeServerByAddr(s.Addr)
return
}
f.setKind(description.Single)
f.replaceServer(s)
}
func (f *fsm) addServer(addr address.Address) {
f.Servers = append(f.Servers, description.Server{
Addr: addr.Canonicalize(),
})
}
func (f *fsm) findPrimary() (int, bool) {
for i, s := range f.Servers {
if s.Kind == description.RSPrimary {
return i, true
}
}
return 0, false
}
func (f *fsm) findServer(addr address.Address) (int, bool) {
canon := addr.Canonicalize()
for i, s := range f.Servers {
if canon == s.Addr {
return i, true
}
}
return 0, false
}
func (f *fsm) removeServer(i int) {
f.Servers = append(f.Servers[:i], f.Servers[i+1:]...)
}
func (f *fsm) removeServerByAddr(addr address.Address) {
if i, ok := f.findServer(addr); ok {
f.removeServer(i)
}
}
func (f *fsm) replaceServer(s description.Server) bool {
if i, ok := f.findServer(s.Addr); ok {
f.setServer(i, s)
return true
}
return false
}
func (f *fsm) setServer(i int, s description.Server) {
f.Servers[i] = s
}
func (f *fsm) setKind(k description.TopologyKind) {
f.Kind = k
}


@@ -0,0 +1,213 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
)
// ErrPoolConnected is returned from an attempt to connect an already connected pool
var ErrPoolConnected = PoolError("pool is connected")
// ErrPoolDisconnected is returned from an attempt to disconnect an already disconnected
// or disconnecting pool.
var ErrPoolDisconnected = PoolError("pool is disconnected or disconnecting")
// ErrConnectionClosed is returned from an attempt to use an already closed connection.
var ErrConnectionClosed = ConnectionError{ConnectionID: "<closed>", message: "connection is closed"}
// ErrWrongPool is returned when a connection is returned to a pool it doesn't belong to.
var ErrWrongPool = PoolError("connection does not belong to this pool")
// PoolError is an error returned from a Pool method.
type PoolError string
// pruneInterval is the interval at which the background routine to close expired connections will be run.
var pruneInterval = time.Minute
func (pe PoolError) Error() string { return string(pe) }
type pool struct {
address address.Address
opts []ConnectionOption
conns *resourcePool // pool for idle connections
generation uint64
connected int32 // Must be accessed using the sync/atomic package.
nextid uint64
opened map[uint64]*connection // opened holds all of the currently open connections.
sync.Mutex
}
func connectionExpiredFunc(v interface{}) bool {
return v.(*connection).expired()
}
func connectionCloseFunc(v interface{}) {
c := v.(*connection)
go c.pool.close(c)
}
// newPool creates a new pool that will hold size number of idle connections. It will use the
// provided options when creating connections.
func newPool(addr address.Address, size uint64, opts ...ConnectionOption) *pool {
return &pool{
address: addr,
conns: newResourcePool(size, connectionExpiredFunc, connectionCloseFunc, pruneInterval),
generation: 0,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
}
}
// drain lazily drains the pool by increasing the generation ID.
func (p *pool) drain() { atomic.AddUint64(&p.generation, 1) }
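// expired reports whether a connection created in the given generation predates the most recent
// generation bump (drain or connect) and should therefore be treated as stale.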
func (p *pool) expired(generation uint64) bool { return generation < atomic.LoadUint64(&p.generation) }
// connect puts the pool into the connected state, allowing it to be used.
func (p *pool) connect() error {
if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) {
return ErrPoolConnected
}
atomic.AddUint64(&p.generation, 1)
return nil
}
func (p *pool) disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&p.connected, connected, disconnecting) {
return ErrPoolDisconnected
}
// We first clear out the idle connections, then we wait until the context's deadline is hit or
// it's cancelled, after which we aggressively close the remaining open connections.
for {
connVal := p.conns.Get()
if connVal == nil {
break
}
_ = p.close(connVal.(*connection))
}
if dl, ok := ctx.Deadline(); ok {
// If we have a deadline then we interpret it as a request to gracefully shutdown. We wait
// until either all the connections have landed back in the pool (and have been closed) or
// until the timer is done.
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
timer := time.NewTimer(time.Until(dl))
defer timer.Stop()
for {
select {
case <-timer.C:
case <-ticker.C: // Can we replace this with an actual signal channel? We will know when len(p.opened) hits zero from the close method.
p.Lock()
if len(p.opened) > 0 {
p.Unlock()
continue
}
p.Unlock()
}
break
}
}
// We copy the remaining connections into a slice, then iterate it to close them. This allows us
// to use a single function to actually clean up and close connections at the expense of a
// double iteration in the worst case.
p.Lock()
toClose := make([]*connection, 0, len(p.opened))
for _, pc := range p.opened {
toClose = append(toClose, pc)
}
p.Unlock()
for _, pc := range toClose {
_ = p.close(pc) // We don't care about errors while closing the connection.
}
atomic.StoreInt32(&p.connected, disconnected)
return nil
}
func (p *pool) get(ctx context.Context) (*connection, error) {
if atomic.LoadInt32(&p.connected) != connected {
return nil, ErrPoolDisconnected
}
// try to get an unexpired idle connection
connVal := p.conns.Get()
if connVal != nil {
return connVal.(*connection), nil
}
// couldn't find an unexpired connection. create a new one.
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
c, err := newConnection(ctx, p.address, p.opts...)
if err != nil {
return nil, err
}
c.pool = p
c.poolID = atomic.AddUint64(&p.nextid, 1)
c.generation = p.generation
if atomic.LoadInt32(&p.connected) != connected {
_ = p.close(c) // The pool is disconnected or disconnecting, ignore the error from closing the connection.
return nil, ErrPoolDisconnected
}
p.Lock()
p.opened[c.poolID] = c
p.Unlock()
return c, nil
}
}
// close closes a connection, not the pool itself. This method will actually close the connection,
// making it unusable. To instead return the connection to the pool, use put.
func (p *pool) close(c *connection) error {
if c.pool != p {
return ErrWrongPool
}
p.Lock()
delete(p.opened, c.poolID)
nc := c.nc
c.nc = nil
p.Unlock()
if nc == nil {
return nil // We're closing an already closed connection.
}
err := nc.Close()
if err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to close net.Conn"}
}
return nil
}
// put returns a connection to this pool. If the pool is connected, the connection is not
// expired, and there is space in the cache, the connection is returned to the cache.
func (p *pool) put(c *connection) error {
if c.pool != p {
return ErrWrongPool
}
if atomic.LoadInt32(&p.connected) != connected || c.expired() {
return p.close(c)
}
// close the connection if the underlying pool is full
if !p.conns.Put(c) {
return p.close(c)
}
return nil
}


@@ -0,0 +1,121 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"container/list"
"sync"
"time"
)
// expiredFunc is the function type used for testing whether or not resources in a resourcePool have expired. It should
// return true if the resource has expired and can be removed from the pool.
type expiredFunc func(interface{}) bool
// closeFunc is the function type used to close resources in a resourcePool. The pool will always call this function
// asynchronously.
type closeFunc func(interface{})
// resourcePool is a concurrent resource pool that implements the behavior described in the sessions spec.
type resourcePool struct {
deque *list.List
len, maxSize uint64
expiredFn expiredFunc
closeFn closeFunc
pruneTimer *time.Timer
pruneInterval time.Duration
sync.Mutex
}
// newResourcePool creates a new resourcePool instance that is capped to maxSize resources.
// If maxSize is 0, the pool size will be unbounded.
func newResourcePool(maxSize uint64, expiredFn expiredFunc, closeFn closeFunc, pruneInterval time.Duration) *resourcePool {
rp := &resourcePool{
deque: list.New(),
maxSize: maxSize,
expiredFn: expiredFn,
closeFn: closeFn,
pruneInterval: pruneInterval,
}
rp.Lock()
rp.pruneTimer = time.AfterFunc(rp.pruneInterval, rp.Prune)
rp.Unlock()
return rp
}
// Get returns the first un-expired resource from the pool. If no such resource can be found, nil is returned.
func (rp *resourcePool) Get() interface{} {
rp.Lock()
defer rp.Unlock()
var next *list.Element
for curr := rp.deque.Front(); curr != nil; curr = next {
next = curr.Next()
// remove the current resource and return it if it is valid
rp.deque.Remove(curr)
rp.len--
if !rp.expiredFn(curr.Value) {
// found un-expired resource
return curr.Value
}
// close expired resources
rp.closeFn(curr.Value)
}
// did not find a valid resource
return nil
}
// Put clears expired resources from the pool and then returns resource v to the pool if there is room. It returns true
// if v was successfully added to the pool and false otherwise.
func (rp *resourcePool) Put(v interface{}) bool {
rp.Lock()
defer rp.Unlock()
// close expired resources from the back of the pool
rp.prune()
if (rp.maxSize != 0 && rp.len == rp.maxSize) || rp.expiredFn(v) {
return false
}
rp.deque.PushFront(v)
rp.len++
return true
}
// Prune clears expired resources from the pool.
func (rp *resourcePool) Prune() {
rp.Lock()
defer rp.Unlock()
rp.prune()
}
func (rp *resourcePool) prune() {
// iterate over the list and stop at the first valid value
var prev *list.Element
for curr := rp.deque.Back(); curr != nil; curr = prev {
prev = curr.Prev()
if !rp.expiredFn(curr.Value) {
// found unexpired resource
break
}
// remove and close expired resources
rp.deque.Remove(curr)
rp.closeFn(curr.Value)
rp.len--
}
// reset the timer for the background cleanup routine
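// If Stop returns false the timer has already fired, so a fresh timer must be created instead of
// reusing the old one; otherwise the existing timer can safely be Reset.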
if !rp.pruneTimer.Stop() {
rp.pruneTimer = time.AfterFunc(rp.pruneInterval, rp.Prune)
return
}
rp.pruneTimer.Reset(rp.pruneInterval)
}


@@ -0,0 +1,643 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"context"
"errors"
"fmt"
"math"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
connectionlegacy "go.mongodb.org/mongo-driver/x/network/connection"
"go.mongodb.org/mongo-driver/x/network/result"
"golang.org/x/sync/semaphore"
)
const minHeartbeatInterval = 500 * time.Millisecond
const connectionSemaphoreSize = math.MaxInt64
var isMasterOrRecoveringCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91}
// ErrServerClosed occurs when an attempt to get a connection is made after
// the server has been closed.
var ErrServerClosed = errors.New("server is closed")
// ErrServerConnected occurs when an attempt to connect is made after a server
// has already been connected.
var ErrServerConnected = errors.New("server is connected")
// SelectedServer represents a specific server that was selected during server selection.
// It contains the kind of the topology it was selected from.
type SelectedServer struct {
*Server
Kind description.TopologyKind
}
// Description returns a description of the server as of the last heartbeat.
func (ss *SelectedServer) Description() description.SelectedServer {
sdesc := ss.Server.Description()
return description.SelectedServer{
Server: sdesc,
Kind: ss.Kind,
}
}
// These constants represent the connection states of a server.
const (
disconnected int32 = iota
disconnecting
connected
connecting
)
func connectionStateString(state int32) string {
switch state {
case 0:
return "Disconnected"
case 1:
return "Disconnecting"
case 2:
return "Connected"
case 3:
return "Connecting"
}
return ""
}
// Server is a single server within a topology.
type Server struct {
cfg *serverConfig
address address.Address
connectionstate int32
// connection related fields
pool *pool
sem *semaphore.Weighted
// goroutine management fields
done chan struct{}
checkNow chan struct{}
closewg sync.WaitGroup
// description related fields
desc atomic.Value // holds a description.Server
updateTopologyCallback atomic.Value
averageRTTSet bool
averageRTT time.Duration
// subscriber related fields
subLock sync.Mutex
subscribers map[uint64]chan description.Server
currentSubscriberID uint64
subscriptionsClosed bool
}
// ConnectServer creates a new Server and then initializes it using the
// Connect method.
func ConnectServer(addr address.Address, updateCallback func(description.Server), opts ...ServerOption) (*Server, error) {
srvr, err := NewServer(addr, opts...)
if err != nil {
return nil, err
}
err = srvr.Connect(updateCallback)
if err != nil {
return nil, err
}
return srvr, nil
}
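// Illustrative sketch (not part of the original source): ConnectServer wires up
// a description callback, which a Topology normally uses to apply updates to
// its finite state machine.
//
//   srvr, err := ConnectServer(address.Address("localhost:27017"),
//       func(desc description.Server) {
//           // invoked with a fresh description after every heartbeat
//       })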
// NewServer creates a new server. The MongoDB server at the address will be
// monitored by an internal monitoring goroutine.
func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
cfg, err := newServerConfig(opts...)
if err != nil {
return nil, err
}
var maxConns = uint64(cfg.maxConns)
if maxConns == 0 {
maxConns = math.MaxInt64
}
s := &Server{
cfg: cfg,
address: addr,
sem: semaphore.NewWeighted(int64(maxConns)),
done: make(chan struct{}),
checkNow: make(chan struct{}, 1),
subscribers: make(map[uint64]chan description.Server),
}
s.desc.Store(description.Server{Addr: addr})
callback := func(desc description.Server) { s.updateDescription(desc, false) }
s.pool = newPool(addr, uint64(cfg.maxIdleConns), withServerDescriptionCallback(callback, cfg.connectionOpts...)...)
return s, nil
}
// Connect initializes the Server by starting background monitoring goroutines.
// This method must be called before a Server can be used.
func (s *Server) Connect(updateCallback func(description.Server)) error {
if !atomic.CompareAndSwapInt32(&s.connectionstate, disconnected, connected) {
return ErrServerConnected
}
s.desc.Store(description.Server{Addr: s.address})
s.updateTopologyCallback.Store(updateCallback)
// Add to the wait group before starting the goroutine that calls Done on it.
s.closewg.Add(1)
go s.update()
return s.pool.connect()
}
// Disconnect closes sockets to the server referenced by this Server.
// Subscriptions to this Server will be closed. Disconnect will shutdown
// any monitoring goroutines, close the idle connection pool, and will
// wait until all the in use connections have been returned to the connection
// pool and are closed before returning. If the context expires via
// cancellation, deadline, or timeout before the in use connections have been
// returned, the in use connections will be closed, resulting in the failure of
// any in flight read or write operations. If this method returns with no
// errors, all connections associated with this Server have been closed.
func (s *Server) Disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&s.connectionstate, connected, disconnecting) {
return ErrServerClosed
}
s.updateTopologyCallback.Store((func(description.Server))(nil))
// For every call to Connect there must be at least 1 goroutine that is
// waiting on the done channel.
s.done <- struct{}{}
err := s.pool.disconnect(ctx)
if err != nil {
return err
}
s.closewg.Wait()
atomic.StoreInt32(&s.connectionstate, disconnected)
return nil
}
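// Illustrative sketch (not part of the original source): a typical shutdown
// gives in use connections a bounded amount of time to drain before they are
// forcibly closed.
//
//   ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//   defer cancel()
//   if err := srvr.Disconnect(ctx); err != nil {
//       // ErrServerClosed: the server was never connected or already closed
//   }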
// Connection gets a connection to the server.
func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
if atomic.LoadInt32(&s.connectionstate) != connected {
return nil, ErrServerClosed
}
err := s.sem.Acquire(ctx, 1)
if err != nil {
return nil, err
}
conn, err := s.pool.get(ctx)
if err != nil {
s.sem.Release(1)
connerr, ok := err.(ConnectionError)
if !ok {
return nil, err
}
// Since the only kind of ConnectionError we receive from pool.get will be an initialization
// error, we should set the description.Server appropriately.
desc := description.Server{
Kind: description.Unknown,
LastError: connerr.Wrapped,
}
s.updateDescription(desc, false)
return nil, err
}
return &Connection{connection: conn, s: s}, nil
}
// ConnectionLegacy gets a connection to the server.
func (s *Server) ConnectionLegacy(ctx context.Context) (connectionlegacy.Connection, error) {
if atomic.LoadInt32(&s.connectionstate) != connected {
return nil, ErrServerClosed
}
err := s.sem.Acquire(ctx, 1)
if err != nil {
return nil, err
}
conn, err := s.pool.get(ctx)
if err != nil {
s.sem.Release(1)
connerr, ok := err.(ConnectionError)
if !ok {
return nil, err
}
// Since the only kind of ConnectionError we receive from pool.get will be an initialization
// error, we should set the description.Server appropriately.
desc := description.Server{
Kind: description.Unknown,
LastError: connerr.Wrapped,
}
s.updateDescription(desc, false)
return nil, err
}
return newConnectionLegacy(conn, s, s.cfg.connectionOpts...)
}
// Description returns a description of the server as of the last heartbeat.
func (s *Server) Description() description.Server {
return s.desc.Load().(description.Server)
}
// SelectedDescription returns a description.SelectedServer with a Kind of
// Single. This can be used when performing tasks like monitoring a batch
// of servers where you want to run one-off commands against those servers.
func (s *Server) SelectedDescription() description.SelectedServer {
sdesc := s.Description()
return description.SelectedServer{
Server: sdesc,
Kind: description.Single,
}
}
// Subscribe returns a ServerSubscription which has a channel on which all
// updated server descriptions will be sent. The channel will have a buffer
// size of one, and will be pre-populated with the current description.
func (s *Server) Subscribe() (*ServerSubscription, error) {
if atomic.LoadInt32(&s.connectionstate) != connected {
return nil, ErrSubscribeAfterClosed
}
ch := make(chan description.Server, 1)
ch <- s.desc.Load().(description.Server)
s.subLock.Lock()
defer s.subLock.Unlock()
if s.subscriptionsClosed {
return nil, ErrSubscribeAfterClosed
}
id := s.currentSubscriberID
s.subscribers[id] = ch
s.currentSubscriberID++
ss := &ServerSubscription{
C: ch,
s: s,
id: id,
}
return ss, nil
}
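// Illustrative sketch (not part of the original source): consuming description
// updates from a subscription. The channel is pre-populated, so the first
// receive returns the current description immediately.
//
//   sub, err := s.Subscribe()
//   if err != nil {
//       return err // the server is closed
//   }
//   defer sub.Unsubscribe()
//   for desc := range sub.C {
//       _ = desc // react to each new description.Server, e.g. inspect desc.Kind
//   }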
// RequestImmediateCheck will cause the server to send a heartbeat immediately
// instead of waiting until the next scheduled heartbeat.
func (s *Server) RequestImmediateCheck() {
select {
case s.checkNow <- struct{}{}:
default:
}
}
// ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
func (s *Server) ProcessError(err error) {
// Invalidate server description if not master or node recovering error occurs
if cerr, ok := err.(driver.Error); ok && (cerr.NetworkError() || cerr.NodeIsRecovering() || cerr.NotMaster()) {
desc := s.Description()
desc.Kind = description.Unknown
desc.LastError = err
// updates description to unknown
s.updateDescription(desc, false)
s.RequestImmediateCheck()
s.pool.drain()
return
}
ne, ok := err.(ConnectionError)
if !ok {
return
}
if netErr, ok := ne.Wrapped.(net.Error); ok && netErr.Timeout() {
return
}
if ne.Wrapped == context.Canceled || ne.Wrapped == context.DeadlineExceeded {
return
}
desc := s.Description()
desc.Kind = description.Unknown
desc.LastError = err
// updates description to unknown
s.updateDescription(desc, false)
}
// ProcessWriteConcernError checks if a WriteConcernError is an isNotMaster or
// isRecovering error, and if so updates the server accordingly.
func (s *Server) ProcessWriteConcernError(err *result.WriteConcernError) {
if err == nil || !wceIsNotMasterOrRecovering(err) {
return
}
desc := s.Description()
desc.Kind = description.Unknown
desc.LastError = err
// updates description to unknown
s.updateDescription(desc, false)
s.RequestImmediateCheck()
}
func wceIsNotMasterOrRecovering(wce *result.WriteConcernError) bool {
for _, code := range isMasterOrRecoveringCodes {
if int32(wce.Code) == code {
return true
}
}
if strings.Contains(wce.ErrMsg, "not master") || strings.Contains(wce.ErrMsg, "node is recovering") {
return true
}
return false
}
// update handles performing heartbeats and updating any subscribers of the
// newest description.Server retrieved.
func (s *Server) update() {
defer s.closewg.Done()
heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
rateLimiter := time.NewTicker(minHeartbeatInterval)
defer heartbeatTicker.Stop()
defer rateLimiter.Stop()
checkNow := s.checkNow
done := s.done
var doneOnce bool
defer func() {
if r := recover(); r != nil {
if doneOnce {
return
}
// We keep this goroutine alive by attempting to read from the done channel.
<-done
}
}()
var conn *connection
var desc description.Server
desc, conn = s.heartbeat(nil)
s.updateDescription(desc, true)
closeServer := func() {
doneOnce = true
s.subLock.Lock()
for id, c := range s.subscribers {
close(c)
delete(s.subscribers, id)
}
s.subscriptionsClosed = true
s.subLock.Unlock()
if conn == nil || conn.nc == nil {
return
}
conn.nc.Close()
}
for {
select {
case <-heartbeatTicker.C:
case <-checkNow:
case <-done:
closeServer()
return
}
select {
case <-rateLimiter.C:
case <-done:
closeServer()
return
}
desc, conn = s.heartbeat(conn)
s.updateDescription(desc, false)
}
}
// updateDescription handles updating the description on the Server, notifying
// subscribers, and potentially draining the connection pool. The initial
// parameter is used to determine if this is the first description from the
// server.
func (s *Server) updateDescription(desc description.Server, initial bool) {
defer func() {
// ¯\_(ツ)_/¯
_ = recover()
}()
s.desc.Store(desc)
callback, ok := s.updateTopologyCallback.Load().(func(description.Server))
if ok && callback != nil {
callback(desc)
}
s.subLock.Lock()
for _, c := range s.subscribers {
select {
// drain the channel if it isn't empty
case <-c:
default:
}
c <- desc
}
s.subLock.Unlock()
if initial {
// We don't clear the pool on the first update on the description.
return
}
switch desc.Kind {
case description.Unknown:
s.pool.drain()
}
}
// heartbeat sends a heartbeat to the server using the given connection. The connection can be nil.
func (s *Server) heartbeat(conn *connection) (description.Server, *connection) {
const maxRetry = 2
var saved error
var desc description.Server
var set bool
var err error
ctx := context.Background()
for i := 1; i <= maxRetry; i++ {
var now time.Time
var descPtr *description.Server
if conn != nil && conn.expired() {
if conn.nc != nil {
conn.nc.Close()
}
conn = nil
}
if conn == nil {
opts := []ConnectionOption{
WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
}
opts = append(opts, s.cfg.connectionOpts...)
// We override whatever handshaker is currently attached to the options with a basic
// one because we need to make sure we don't do auth.
opts = append(opts, WithHandshaker(func(h Handshaker) Handshaker {
now = time.Now()
return operation.NewIsMaster().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts)
}))
// Override any command monitors specified in options with nil to avoid monitoring heartbeats.
opts = append(opts, WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor {
return nil
}))
conn, err = newConnection(ctx, s.address, opts...)
if err == nil {
descPtr = &conn.desc
}
}
// If no new connection was created above, no handshake ran, so perform an explicit heartbeat.
if descPtr == nil && err == nil {
now = time.Now()
op := operation.
NewIsMaster().
ClusterClock(s.cfg.clock).
Deployment(driver.SingleConnectionDeployment{initConnection{conn}})
err = op.Execute(ctx)
if err == nil {
tmpDesc := op.Result(s.address)
descPtr = &tmpDesc
}
}
// Retry if the server is still considered connected; on success the new server description is returned below.
if err != nil {
saved = err
if conn != nil && conn.nc != nil {
conn.nc.Close()
}
conn = nil
if _, ok := err.(ConnectionError); ok {
s.pool.drain()
// If the server is not connected, give up and exit loop
if s.Description().Kind == description.Unknown {
break
}
}
continue
}
desc = *descPtr
delay := time.Since(now)
desc = desc.SetAverageRTT(s.updateAverageRTT(delay))
desc.HeartbeatInterval = s.cfg.heartbeatInterval
set = true
break
}
if !set {
desc = description.Server{
Addr: s.address,
LastError: saved,
Kind: description.Unknown,
}
}
return desc, conn
}
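// updateAverageRTT maintains an exponentially weighted moving average of the
// heartbeat round trip time: rtt = alpha*sample + (1-alpha)*previous with
// alpha = 0.2. The first sample seeds the average directly.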
func (s *Server) updateAverageRTT(delay time.Duration) time.Duration {
if !s.averageRTTSet {
s.averageRTT = delay
} else {
alpha := 0.2
s.averageRTT = time.Duration(alpha*float64(delay) + (1-alpha)*float64(s.averageRTT))
}
return s.averageRTT
}
// Drain will drain the connection pool of this server. This is mainly here so the
// pool for the server doesn't need to be directly exposed and so that when an error
// is returned from reading or writing, a client can drain the pool for this server.
// This is exposed here so we don't have to wrap the Connection type and sniff responses
// for errors that would cause the pool to be drained, which can in turn centralize the
// logic for handling errors in the Client type.
//
// TODO(GODRIVER-617): I don't think we actually need this method. It's likely replaced by
// ProcessError.
func (s *Server) Drain() error {
s.pool.drain()
return nil
}
// String implements the Stringer interface.
func (s *Server) String() string {
desc := s.Description()
connState := atomic.LoadInt32(&s.connectionstate)
str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
s.address, desc.Kind, connectionStateString(connState))
if len(desc.Tags) != 0 {
str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
}
if connState == connected {
str += fmt.Sprintf(", Average RTT: %d", desc.AverageRTT)
}
if desc.LastError != nil {
str += fmt.Sprintf(", Last error: %s", desc.LastError)
}
return str
}
// ServerSubscription represents a subscription to the description.Server updates for
// a specific server.
type ServerSubscription struct {
C <-chan description.Server
s *Server
id uint64
}
// Unsubscribe unsubscribes this ServerSubscription from updates and closes the
// subscription channel.
func (ss *ServerSubscription) Unsubscribe() error {
ss.s.subLock.Lock()
defer ss.s.subLock.Unlock()
if ss.s.subscriptionsClosed {
return nil
}
ch, ok := ss.s.subscribers[ss.id]
if !ok {
return nil
}
close(ch)
delete(ss.s.subscribers, ss.id)
return nil
}

View File

@@ -0,0 +1,120 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
var defaultRegistry = bson.NewRegistryBuilder().Build()
type serverConfig struct {
clock *session.ClusterClock
compressionOpts []string
connectionOpts []ConnectionOption
appname string
heartbeatInterval time.Duration
heartbeatTimeout time.Duration
maxConns uint16
maxIdleConns uint16
registry *bsoncodec.Registry
}
func newServerConfig(opts ...ServerOption) (*serverConfig, error) {
cfg := &serverConfig{
heartbeatInterval: 10 * time.Second,
heartbeatTimeout: 10 * time.Second,
maxConns: 100,
maxIdleConns: 100,
registry: defaultRegistry,
}
for _, opt := range opts {
err := opt(cfg)
if err != nil {
return nil, err
}
}
return cfg, nil
}
// ServerOption configures a server.
type ServerOption func(*serverConfig) error
// WithConnectionOptions configures the server's connections.
func WithConnectionOptions(fn func(...ConnectionOption) []ConnectionOption) ServerOption {
return func(cfg *serverConfig) error {
cfg.connectionOpts = fn(cfg.connectionOpts...)
return nil
}
}
// WithCompressionOptions configures the server's compressors.
func WithCompressionOptions(fn func(...string) []string) ServerOption {
return func(cfg *serverConfig) error {
cfg.compressionOpts = fn(cfg.compressionOpts...)
return nil
}
}
// WithHeartbeatInterval configures a server's heartbeat interval.
func WithHeartbeatInterval(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) error {
cfg.heartbeatInterval = fn(cfg.heartbeatInterval)
return nil
}
}
// WithHeartbeatTimeout configures how long to wait for a heartbeat socket to
// connect.
func WithHeartbeatTimeout(fn func(time.Duration) time.Duration) ServerOption {
return func(cfg *serverConfig) error {
cfg.heartbeatTimeout = fn(cfg.heartbeatTimeout)
return nil
}
}
// WithMaxConnections configures the maximum number of connections to allow for
// a given server. If max is 0, then there is no upper limit to the number of
// connections.
func WithMaxConnections(fn func(uint16) uint16) ServerOption {
return func(cfg *serverConfig) error {
cfg.maxConns = fn(cfg.maxConns)
return nil
}
}
// WithMaxIdleConnections configures the maximum number of idle connections
// allowed for the server.
func WithMaxIdleConnections(fn func(uint16) uint16) ServerOption {
return func(cfg *serverConfig) error {
cfg.maxIdleConns = fn(cfg.maxIdleConns)
return nil
}
}
// WithClock configures the ClusterClock for the server to use.
func WithClock(fn func(clock *session.ClusterClock) *session.ClusterClock) ServerOption {
return func(cfg *serverConfig) error {
cfg.clock = fn(cfg.clock)
return nil
}
}
// WithRegistry configures the registry for the server to use when creating
// cursors.
func WithRegistry(fn func(*bsoncodec.Registry) *bsoncodec.Registry) ServerOption {
return func(cfg *serverConfig) error {
cfg.registry = fn(cfg.registry)
return nil
}
}
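// Illustrative sketch (not part of the original source): composing
// ServerOptions. Each option takes a function from the current value to the
// new value, so options can layer on top of earlier configuration.
//
//   srvr, err := NewServer(address.Address("localhost:27017"),
//       WithHeartbeatInterval(func(time.Duration) time.Duration { return 15 * time.Second }),
//       WithMaxConnections(func(uint16) uint16 { return 50 }),
//   )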

View File

@@ -0,0 +1,624 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package topology contains types that handle the discovery, monitoring, and selection
// of servers. This package is designed to expose enough inner workings of service discovery
// and monitoring to allow low-level applications to have fine-grained control, while hiding
// most of the detailed implementation of the algorithms.
package topology // import "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
import (
"context"
"errors"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"fmt"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
// closed Server or Topology.
var ErrSubscribeAfterClosed = errors.New("cannot subscribe after close")
// ErrTopologyClosed is returned when a user attempts to call a method on a
// closed Topology.
var ErrTopologyClosed = errors.New("topology is closed")
// ErrTopologyConnected is returned when a user attempts to connect to an
// already connected Topology.
var ErrTopologyConnected = errors.New("topology is connected or connecting")
// ErrServerSelectionTimeout is returned from server selection when the server
// selection process took longer than allowed by the timeout.
var ErrServerSelectionTimeout = errors.New("server selection timeout")
// MonitorMode represents the way in which a server is monitored.
type MonitorMode uint8
// These constants are the available monitoring modes.
const (
AutomaticMode MonitorMode = iota
SingleMode
)
// Topology represents a MongoDB deployment.
type Topology struct {
connectionstate int32
cfg *config
desc atomic.Value // holds a description.Topology
dnsResolver *dns.Resolver
done chan struct{}
pollingDone chan struct{}
pollingwg sync.WaitGroup
rescanSRVInterval time.Duration
pollHeartbeatTime atomic.Value // holds a bool
fsm *fsm
SessionPool *session.Pool
// This should really be encapsulated into its own type. This will likely
// require a redesign so we can share a minimum of data between the
// subscribers and the topology.
subscribers map[uint64]chan description.Topology
currentSubscriberID uint64
subscriptionsClosed bool
subLock sync.Mutex
// We should redesign how we connect and handle individual servers. This is
// too difficult to maintain and it's rather easy to accidentally access
// the servers without acquiring the lock or checking if the servers are
// closed. This lock should also be an RWMutex.
serversLock sync.Mutex
serversClosed bool
servers map[address.Address]*Server
}
// New creates a new topology.
func New(opts ...Option) (*Topology, error) {
cfg, err := newConfig(opts...)
if err != nil {
return nil, err
}
t := &Topology{
cfg: cfg,
done: make(chan struct{}),
pollingDone: make(chan struct{}),
rescanSRVInterval: 60 * time.Second,
fsm: newFSM(),
subscribers: make(map[uint64]chan description.Topology),
servers: make(map[address.Address]*Server),
dnsResolver: dns.DefaultResolver,
}
t.desc.Store(description.Topology{})
if cfg.replicaSetName != "" {
t.fsm.SetName = cfg.replicaSetName
t.fsm.Kind = description.ReplicaSetNoPrimary
}
if cfg.mode == SingleMode {
t.fsm.Kind = description.Single
}
return t, nil
}
// Connect initializes a Topology and starts the monitoring process. This function
// must be called to properly monitor the topology.
func (t *Topology) Connect() error {
if !atomic.CompareAndSwapInt32(&t.connectionstate, disconnected, connecting) {
return ErrTopologyConnected
}
t.desc.Store(description.Topology{})
var err error
t.serversLock.Lock()
for _, a := range t.cfg.seedList {
addr := address.Address(a).Canonicalize()
t.fsm.Servers = append(t.fsm.Servers, description.Server{Addr: addr})
err = t.addServer(addr)
}
t.serversLock.Unlock()
if srvPollingRequired(t.cfg.cs.Original) {
// Add to the wait group before starting the polling goroutine that calls Done on it.
t.pollingwg.Add(1)
go t.pollSRVRecords()
}
t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
atomic.StoreInt32(&t.connectionstate, connected)
// After connection, make a subscription to keep the pool updated
sub, err := t.Subscribe()
if err != nil {
return err
}
t.SessionPool = session.NewPool(sub.C)
return nil
}
// Disconnect closes the topology. It stops the monitoring thread and
// closes all open subscriptions.
func (t *Topology) Disconnect(ctx context.Context) error {
if !atomic.CompareAndSwapInt32(&t.connectionstate, connected, disconnecting) {
return ErrTopologyClosed
}
servers := make(map[address.Address]*Server)
t.serversLock.Lock()
t.serversClosed = true
for addr, server := range t.servers {
servers[addr] = server
}
t.serversLock.Unlock()
for _, server := range servers {
_ = server.Disconnect(ctx)
}
t.subLock.Lock()
for id, ch := range t.subscribers {
close(ch)
delete(t.subscribers, id)
}
t.subscriptionsClosed = true
t.subLock.Unlock()
if srvPollingRequired(t.cfg.cs.Original) {
t.pollingDone <- struct{}{}
t.pollingwg.Wait()
}
t.desc.Store(description.Topology{})
atomic.StoreInt32(&t.connectionstate, disconnected)
return nil
}
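// Illustrative sketch (not part of the original source): the Connect/Disconnect
// lifecycle of a Topology built from a seed list.
//
//   t, err := New(WithSeedList(func(...string) []string {
//       return []string{"localhost:27017"}
//   }))
//   if err != nil {
//       return err
//   }
//   if err = t.Connect(); err != nil {
//       return err
//   }
//   defer t.Disconnect(context.Background())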
func srvPollingRequired(connstr string) bool {
return strings.HasPrefix(connstr, "mongodb+srv://")
}
// Description returns a description of the topology.
func (t *Topology) Description() description.Topology {
td, ok := t.desc.Load().(description.Topology)
if !ok {
td = description.Topology{}
}
return td
}
// Kind returns the topology kind of this Topology.
func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind }
// Subscribe returns a Subscription on which all updated description.Topologys
// will be sent. The channel of the subscription will have a buffer size of one,
// and will be pre-populated with the current description.Topology.
func (t *Topology) Subscribe() (*Subscription, error) {
if atomic.LoadInt32(&t.connectionstate) != connected {
return nil, errors.New("cannot subscribe to Topology that is not connected")
}
ch := make(chan description.Topology, 1)
td, ok := t.desc.Load().(description.Topology)
if !ok {
td = description.Topology{}
}
ch <- td
t.subLock.Lock()
defer t.subLock.Unlock()
if t.subscriptionsClosed {
return nil, ErrSubscribeAfterClosed
}
id := t.currentSubscriberID
t.subscribers[id] = ch
t.currentSubscriberID++
return &Subscription{
C: ch,
t: t,
id: id,
}, nil
}
// RequestImmediateCheck will send heartbeats to all the servers in the
// topology right away, instead of waiting until the next scheduled heartbeats.
func (t *Topology) RequestImmediateCheck() {
if atomic.LoadInt32(&t.connectionstate) != connected {
return
}
t.serversLock.Lock()
for _, server := range t.servers {
server.RequestImmediateCheck()
}
t.serversLock.Unlock()
}
// SupportsSessions returns true if the topology supports sessions.
func (t *Topology) SupportsSessions() bool {
return t.Description().SessionTimeoutMinutes != 0 && t.Description().Kind != description.Single
}
// SupportsRetry returns true if the topology supports retryability, which it does if it supports sessions.
func (t *Topology) SupportsRetry() bool { return t.SupportsSessions() }
// SelectServer selects a server given a selector. SelectServer complies with the
// server selection spec, and will time out after serverSelectionTimeout or when the
// parent context is done.
func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
if atomic.LoadInt32(&t.connectionstate) != connected {
return nil, ErrTopologyClosed
}
var ssTimeoutCh <-chan time.Time
if t.cfg.serverSelectionTimeout > 0 {
ssTimeout := time.NewTimer(t.cfg.serverSelectionTimeout)
ssTimeoutCh = ssTimeout.C
defer ssTimeout.Stop()
}
sub, err := t.Subscribe()
if err != nil {
return nil, err
}
defer sub.Unsubscribe()
for {
suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh)
if err != nil {
return nil, err
}
selected := suitable[rand.Intn(len(suitable))]
selectedS, err := t.FindServer(selected)
switch {
case err != nil:
return nil, err
case selectedS != nil:
return selectedS, nil
default:
// We don't have an actual server for the provided description.
// This could happen for a number of reasons, including that the
// server has since stopped being a part of this topology, or that
// the server selector returned no suitable servers.
}
}
}
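// Illustrative sketch (not part of the original source): selecting a writable
// server. description.WriteSelector is assumed here as one example of a
// description.ServerSelector implementation.
//
//   ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//   defer cancel()
//   srv, err := t.SelectServer(ctx, description.WriteSelector())
//   if err != nil {
//       return err // no suitable server before the timeout
//   }
//   _ = srv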
// SelectServerLegacy selects a server given a selector. SelectServerLegacy complies with the
// server selection spec, and will time out after serverSelectionTimeout or when the
// parent context is done.
func (t *Topology) SelectServerLegacy(ctx context.Context, ss description.ServerSelector) (*SelectedServer, error) {
if atomic.LoadInt32(&t.connectionstate) != connected {
return nil, ErrTopologyClosed
}
var ssTimeoutCh <-chan time.Time
if t.cfg.serverSelectionTimeout > 0 {
ssTimeout := time.NewTimer(t.cfg.serverSelectionTimeout)
ssTimeoutCh = ssTimeout.C
defer ssTimeout.Stop()
}
sub, err := t.Subscribe()
if err != nil {
return nil, err
}
defer sub.Unsubscribe()
for {
suitable, err := t.selectServer(ctx, sub.C, ss, ssTimeoutCh)
if err != nil {
return nil, err
}
selected := suitable[rand.Intn(len(suitable))]
selectedS, err := t.FindServer(selected)
switch {
case err != nil:
return nil, err
case selectedS != nil:
return selectedS, nil
default:
// We don't have an actual server for the provided description.
// This could happen for a number of reasons, including that the
// server has since stopped being a part of this topology, or that
// the server selector returned no suitable servers.
}
}
}
// FindServer will attempt to find a server that fits the given server description.
// This method will return nil, nil if a matching server could not be found.
func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
if atomic.LoadInt32(&t.connectionstate) != connected {
return nil, ErrTopologyClosed
}
t.serversLock.Lock()
defer t.serversLock.Unlock()
server, ok := t.servers[selected.Addr]
if !ok {
return nil, nil
}
desc := t.Description()
return &SelectedServer{
Server: server,
Kind: desc.Kind,
}, nil
}
func wrapServerSelectionError(err error, t *Topology) error {
return fmt.Errorf("server selection error: %v\ncurrent topology: %s", err, t.String())
}
// selectServer is the core piece of server selection. It handles getting
// topology descriptions and running server selection on those descriptions.
func (t *Topology) selectServer(ctx context.Context, subscriptionCh <-chan description.Topology, ss description.ServerSelector, timeoutCh <-chan time.Time) ([]description.Server, error) {
var current description.Topology
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-timeoutCh:
return nil, wrapServerSelectionError(ErrServerSelectionTimeout, t)
case current = <-subscriptionCh:
}
var allowed []description.Server
for _, s := range current.Servers {
if s.Kind != description.Unknown {
allowed = append(allowed, s)
}
}
suitable, err := ss.SelectServer(current, allowed)
if err != nil {
return nil, wrapServerSelectionError(err, t)
}
if len(suitable) > 0 {
return suitable, nil
}
t.RequestImmediateCheck()
}
}
func (t *Topology) pollSRVRecords() {
defer t.pollingwg.Done()
serverConfig, _ := newServerConfig(t.cfg.serverOpts...)
heartbeatInterval := serverConfig.heartbeatInterval
pollTicker := time.NewTicker(t.rescanSRVInterval)
defer pollTicker.Stop()
t.pollHeartbeatTime.Store(false)
var doneOnce bool
defer func() {
// ¯\_(ツ)_/¯
if r := recover(); r != nil && !doneOnce {
<-t.pollingDone
}
}()
// remove the scheme prefix (len("mongodb+srv://") == 14)
uri := t.cfg.cs.Original[14:]
hosts := uri
if idx := strings.IndexAny(uri, "/?@"); idx != -1 {
hosts = uri[:idx]
}
for {
select {
case <-pollTicker.C:
case <-t.pollingDone:
doneOnce = true
return
}
topoKind := t.Description().Kind
if !(topoKind == description.Unknown || topoKind == description.Sharded) {
break
}
parsedHosts, err := t.dnsResolver.ParseHosts(hosts, false)
// DNS problem or no verified hosts returned
if err != nil || len(parsedHosts) == 0 {
if !t.pollHeartbeatTime.Load().(bool) {
pollTicker.Stop()
pollTicker = time.NewTicker(heartbeatInterval)
t.pollHeartbeatTime.Store(true)
}
continue
}
if t.pollHeartbeatTime.Load().(bool) {
pollTicker.Stop()
pollTicker = time.NewTicker(t.rescanSRVInterval)
t.pollHeartbeatTime.Store(false)
}
cont := t.processSRVResults(parsedHosts)
if !cont {
break
}
}
<-t.pollingDone
doneOnce = true
}
func (t *Topology) processSRVResults(parsedHosts []string) bool {
t.serversLock.Lock()
defer t.serversLock.Unlock()
if t.serversClosed {
return false
}
diff := t.fsm.Topology.DiffHostlist(parsedHosts)
if len(diff.Added) == 0 && len(diff.Removed) == 0 {
return true
}
for _, r := range diff.Removed {
addr := address.Address(r).Canonicalize()
s, ok := t.servers[addr]
if !ok {
continue
}
go func() {
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
_ = s.Disconnect(cancelCtx)
}()
delete(t.servers, addr)
t.fsm.removeServerByAddr(addr)
}
for _, a := range diff.Added {
addr := address.Address(a).Canonicalize()
_ = t.addServer(addr)
t.fsm.addServer(addr)
}
// store the new description and notify subscribers
newDesc := description.Topology{
Kind: t.fsm.Kind,
Servers: t.fsm.Servers,
SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
}
t.desc.Store(newDesc)
t.subLock.Lock()
for _, ch := range t.subscribers {
// We drain the description if there's one in the channel
select {
case <-ch:
default:
}
ch <- newDesc
}
t.subLock.Unlock()
return true
}
func (t *Topology) apply(ctx context.Context, desc description.Server) {
t.serversLock.Lock()
defer t.serversLock.Unlock()
if _, ok := t.servers[desc.Addr]; t.serversClosed || !ok {
return
}
prev := t.fsm.Topology
current, err := t.fsm.apply(desc)
if err != nil {
return
}
diff := description.DiffTopology(prev, current)
for _, removed := range diff.Removed {
if s, ok := t.servers[removed.Addr]; ok {
go func() {
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
_ = s.Disconnect(cancelCtx)
}()
delete(t.servers, removed.Addr)
}
}
for _, added := range diff.Added {
_ = t.addServer(added.Addr)
}
t.desc.Store(current)
t.subLock.Lock()
for _, ch := range t.subscribers {
// We drain the description if there's one in the channel
select {
case <-ch:
default:
}
ch <- current
}
t.subLock.Unlock()
}
func (t *Topology) addServer(addr address.Address) error {
if _, ok := t.servers[addr]; ok {
return nil
}
topoFunc := func(desc description.Server) {
t.apply(context.TODO(), desc)
}
svr, err := ConnectServer(addr, topoFunc, t.cfg.serverOpts...)
if err != nil {
return err
}
t.servers[addr] = svr
return nil
}
// String implements the Stringer interface
func (t *Topology) String() string {
desc := t.Description()
str := fmt.Sprintf("Type: %s\nServers:\n", desc.Kind)
for _, s := range t.servers {
str += s.String() + "\n"
}
return str
}
// Subscription is a subscription to updates to the description of the Topology that created this
// Subscription.
type Subscription struct {
C <-chan description.Topology
t *Topology
id uint64
}
// Unsubscribe unsubscribes this Subscription from updates and closes the
// subscription channel.
func (s *Subscription) Unsubscribe() error {
s.t.subLock.Lock()
defer s.t.subLock.Unlock()
if s.t.subscriptionsClosed {
return nil
}
ch, ok := s.t.subscribers[s.id]
if !ok {
return nil
}
close(ch)
delete(s.t.subscribers, s.id)
return nil
}

View File

@@ -0,0 +1,386 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"strings"
"time"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// Option is a configuration option for a topology.
type Option func(*config) error
type config struct {
mode MonitorMode
replicaSetName string
seedList []string
serverOpts []ServerOption
cs connstring.ConnString
serverSelectionTimeout time.Duration
}
func newConfig(opts ...Option) (*config, error) {
cfg := &config{
seedList: []string{"localhost:27017"},
serverSelectionTimeout: 30 * time.Second,
}
for _, opt := range opts {
err := opt(cfg)
if err != nil {
return nil, err
}
}
return cfg, nil
}
// WithConnString configures the topology using the connection string.
func WithConnString(fn func(connstring.ConnString) connstring.ConnString) Option {
return func(c *config) error {
cs := fn(c.cs)
c.cs = cs
if cs.ServerSelectionTimeoutSet {
c.serverSelectionTimeout = cs.ServerSelectionTimeout
}
var connOpts []ConnectionOption
if cs.AppName != "" {
connOpts = append(connOpts, WithAppName(func(string) string { return cs.AppName }))
}
switch cs.Connect {
case connstring.SingleConnect:
c.mode = SingleMode
}
c.seedList = cs.Hosts
if cs.ConnectTimeout > 0 {
c.serverOpts = append(c.serverOpts, WithHeartbeatTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
connOpts = append(connOpts, WithConnectTimeout(func(time.Duration) time.Duration { return cs.ConnectTimeout }))
}
if cs.SocketTimeoutSet {
connOpts = append(
connOpts,
WithReadTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
WithWriteTimeout(func(time.Duration) time.Duration { return cs.SocketTimeout }),
)
}
if cs.HeartbeatInterval > 0 {
c.serverOpts = append(c.serverOpts, WithHeartbeatInterval(func(time.Duration) time.Duration { return cs.HeartbeatInterval }))
}
if cs.MaxConnIdleTime > 0 {
connOpts = append(connOpts, WithIdleTimeout(func(time.Duration) time.Duration { return cs.MaxConnIdleTime }))
}
if cs.MaxPoolSizeSet {
c.serverOpts = append(c.serverOpts, WithMaxConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
c.serverOpts = append(c.serverOpts, WithMaxIdleConnections(func(uint16) uint16 { return cs.MaxPoolSize }))
}
if cs.ReplicaSet != "" {
c.replicaSetName = cs.ReplicaSet
}
var x509Username string
if cs.SSL {
tlsConfig := new(tls.Config)
if cs.SSLCaFileSet {
err := addCACertFromFile(tlsConfig, cs.SSLCaFile)
if err != nil {
return err
}
}
if cs.SSLInsecure {
tlsConfig.InsecureSkipVerify = true
}
if cs.SSLClientCertificateKeyFileSet {
var keyPasswd string
if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
keyPasswd = cs.SSLClientCertificateKeyPassword()
}
s, err := addClientCertFromFile(tlsConfig, cs.SSLClientCertificateKeyFile, keyPasswd)
if err != nil {
return err
}
// The Go x509 package gives the subject with the pairs in the reverse order from what we want.
pairs := strings.Split(s, ",")
b := bytes.NewBufferString("")
for i := len(pairs) - 1; i >= 0; i-- {
b.WriteString(pairs[i])
if i > 0 {
b.WriteString(",")
}
}
x509Username = b.String()
}
connOpts = append(connOpts, WithTLSConfig(func(*tls.Config) *tls.Config { return tlsConfig }))
}
if cs.Username != "" || cs.AuthMechanism == auth.MongoDBX509 || cs.AuthMechanism == auth.GSSAPI {
cred := &auth.Cred{
Source: "admin",
Username: cs.Username,
Password: cs.Password,
PasswordSet: cs.PasswordSet,
Props: cs.AuthMechanismProperties,
}
if cs.AuthSource != "" {
cred.Source = cs.AuthSource
} else {
switch cs.AuthMechanism {
case auth.MongoDBX509:
if cred.Username == "" {
cred.Username = x509Username
}
fallthrough
case auth.GSSAPI, auth.PLAIN:
cred.Source = "$external"
default:
cred.Source = cs.Database
}
}
authenticator, err := auth.CreateAuthenticator(cs.AuthMechanism, cred)
if err != nil {
return err
}
connOpts = append(connOpts, WithHandshaker(func(h Handshaker) Handshaker {
options := &auth.HandshakeOptions{
AppName: cs.AppName,
Authenticator: authenticator,
Compressors: cs.Compressors,
}
if cs.AuthMechanism == "" {
// Required for SASL mechanism negotiation during handshake
options.DBUser = cred.Source + "." + cred.Username
}
return auth.Handshaker(h, options)
}))
} else {
// We need to add a non-auth Handshaker to the connection options
connOpts = append(connOpts, WithHandshaker(func(h driver.Handshaker) driver.Handshaker {
return operation.NewIsMaster().AppName(cs.AppName).Compressors(cs.Compressors)
}))
}
if len(cs.Compressors) > 0 {
connOpts = append(connOpts, WithCompressors(func(compressors []string) []string {
return append(compressors, cs.Compressors...)
}))
for _, comp := range cs.Compressors {
if comp == "zlib" {
connOpts = append(connOpts, WithZlibLevel(func(level *int) *int {
return &cs.ZlibLevel
}))
}
}
c.serverOpts = append(c.serverOpts, WithCompressionOptions(func(opts ...string) []string {
return append(opts, cs.Compressors...)
}))
}
if len(connOpts) > 0 {
c.serverOpts = append(c.serverOpts, WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
return append(opts, connOpts...)
}))
}
return nil
}
}
// WithMode configures the topology's monitor mode.
func WithMode(fn func(MonitorMode) MonitorMode) Option {
return func(cfg *config) error {
cfg.mode = fn(cfg.mode)
return nil
}
}
// WithReplicaSetName configures the topology's default replica set name.
func WithReplicaSetName(fn func(string) string) Option {
return func(cfg *config) error {
cfg.replicaSetName = fn(cfg.replicaSetName)
return nil
}
}
// WithSeedList configures a topology's seed list.
func WithSeedList(fn func(...string) []string) Option {
return func(cfg *config) error {
cfg.seedList = fn(cfg.seedList...)
return nil
}
}
// WithServerOptions configures a topology's server options for when a new server
// needs to be created.
func WithServerOptions(fn func(...ServerOption) []ServerOption) Option {
return func(cfg *config) error {
cfg.serverOpts = fn(cfg.serverOpts...)
return nil
}
}
// WithServerSelectionTimeout configures a topology's server selection timeout.
// A server selection timeout of 0 means there is no timeout for server selection.
func WithServerSelectionTimeout(fn func(time.Duration) time.Duration) Option {
return func(cfg *config) error {
cfg.serverSelectionTimeout = fn(cfg.serverSelectionTimeout)
return nil
}
}
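// Illustrative sketch (not part of the original source): topology Options
// follow the same functional pattern as ServerOptions, so settings can be
// layered when constructing a Topology.
//
//   t, err := New(
//       WithReplicaSetName(func(string) string { return "rs0" }),
//       WithServerSelectionTimeout(func(time.Duration) time.Duration { return 20 * time.Second }),
//   )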
// addCACertFromFile adds a root CA certificate to the configuration given a path
// to the containing file.
func addCACertFromFile(cfg *tls.Config, file string) error {
data, err := ioutil.ReadFile(file)
if err != nil {
return err
}
certBytes, err := loadCert(data)
if err != nil {
return err
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return err
}
if cfg.RootCAs == nil {
cfg.RootCAs = x509.NewCertPool()
}
cfg.RootCAs.AddCert(cert)
return nil
}
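// loadCert returns the DER bytes of the single CERTIFICATE block in the given
// PEM data, skipping over any other block types such as private keys.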
func loadCert(data []byte) ([]byte, error) {
var certBlock *pem.Block
for certBlock == nil {
if len(data) == 0 {
return nil, errors.New(".pem file must have both a CERTIFICATE and an RSA PRIVATE KEY section")
}
block, rest := pem.Decode(data)
if block == nil {
return nil, errors.New("invalid .pem file")
}
switch block.Type {
case "CERTIFICATE":
if certBlock != nil {
return nil, errors.New("multiple CERTIFICATE sections in .pem file")
}
certBlock = block
}
data = rest
}
return certBlock.Bytes, nil
}
// addClientCertFromFile adds a client certificate to the configuration given a path to the
// containing file and returns the certificate's subject name.
func addClientCertFromFile(cfg *tls.Config, clientFile, keyPasswd string) (string, error) {
data, err := ioutil.ReadFile(clientFile)
if err != nil {
return "", err
}
var currentBlock *pem.Block
var certBlock, certDecodedBlock, keyBlock []byte
remaining := data
start := 0
for {
currentBlock, remaining = pem.Decode(remaining)
if currentBlock == nil {
break
}
if currentBlock.Type == "CERTIFICATE" {
certBlock = data[start : len(data)-len(remaining)]
certDecodedBlock = currentBlock.Bytes
start += len(certBlock)
} else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") {
if keyPasswd != "" && x509.IsEncryptedPEMBlock(currentBlock) {
var encoded bytes.Buffer
buf, err := x509.DecryptPEMBlock(currentBlock, []byte(keyPasswd))
if err != nil {
return "", err
}
pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf})
keyBlock = encoded.Bytes()
start = len(data) - len(remaining)
} else {
keyBlock = data[start : len(data)-len(remaining)]
start += len(keyBlock)
}
}
}
if len(certBlock) == 0 {
return "", fmt.Errorf("failed to find CERTIFICATE")
}
if len(keyBlock) == 0 {
return "", fmt.Errorf("failed to find PRIVATE KEY")
}
cert, err := tls.X509KeyPair(certBlock, keyBlock)
if err != nil {
return "", err
}
cfg.Certificates = append(cfg.Certificates, cert)
// The documentation for tls.X509KeyPair indicates that the Leaf certificate is not
// retained, so parse the certificate ourselves to extract the subject.
crt, err := x509.ParseCertificate(certDecodedBlock)
if err != nil {
return "", err
}
return x509CertSubject(crt), nil
}

View File

@@ -0,0 +1,9 @@
// +build go1.10
package topology
import "crypto/x509"
func x509CertSubject(cert *x509.Certificate) string {
return cert.Subject.String()
}

View File

@@ -0,0 +1,13 @@
// +build !go1.10
package topology
import (
"crypto/x509"
)
// We don't support versions older than 1.10, but Evergreen needs to be able to compile the driver
// using version 1.8.
func x509CertSubject(cert *x509.Certificate) string {
return ""
}