Files
with/vendor/github.com/lib/pq/conn.go
Bel LaPointe 886c4aabff vendor
2026-03-09 09:42:09 -06:00

1776 lines
45 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package pq
import (
"bufio"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"os"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/lib/pq/internal/pgpass"
"github.com/lib/pq/internal/pqsql"
"github.com/lib/pq/internal/pqutil"
"github.com/lib/pq/internal/proto"
"github.com/lib/pq/oid"
"github.com/lib/pq/scram"
)
// Common error types
var (
	// ErrNotSupported reports an unsupported command.
	ErrNotSupported = errors.New("pq: unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction had
	// already failed; Commit rolls the transaction back before returning it.
	ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction")
	// ErrSSLNotSupported is returned when the server refuses the SSL request
	// (its one-byte reply is not 'S').
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrCouldNotDetectUsername reports that no user name was configured and
	// none could be determined automatically.
	ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly")
	// ErrSSLKeyUnknownOwnership and ErrSSLKeyHasWorldPermissions are
	// re-exported from the internal pqutil package.
	ErrSSLKeyUnknownOwnership    = pqutil.ErrSSLKeyUnknownOwnership
	ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions

	// errUnexpectedReady is returned by simpleExec when ReadyForQuery arrives
	// without any result or error having been received.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are presumably returned by the
	// Result of an empty statement (usage not visible here).
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
// Compile time validation that our types implement the expected interfaces.
// The deprecated driver.Execer and driver.Queryer are still implemented for
// backward compatibility, hence the SA1019 lint suppressions.
var (
	_ driver.Driver             = Driver{}
	_ driver.ConnBeginTx        = (*conn)(nil)
	_ driver.ConnPrepareContext = (*conn)(nil)
	_ driver.Execer             = (*conn)(nil) //lint:ignore SA1019 x
	_ driver.ExecerContext      = (*conn)(nil)
	_ driver.NamedValueChecker  = (*conn)(nil)
	_ driver.Pinger             = (*conn)(nil)
	_ driver.Queryer            = (*conn)(nil) //lint:ignore SA1019 x
	_ driver.QueryerContext     = (*conn)(nil)
	_ driver.SessionResetter    = (*conn)(nil)
	_ driver.Validator          = (*conn)(nil)
	_ driver.StmtExecContext    = (*stmt)(nil)
	_ driver.StmtQueryContext   = (*stmt)(nil)
)
// init makes the driver available to database/sql under the name "postgres".
func init() {
	drv := new(Driver)
	sql.Register("postgres", drv)
}
// debugProto turns on protocol tracing output. It is evaluated once, when the
// package is initialised, from the PQGO_DEBUG environment variable.
//
// Only the exact value "1" enables it (not mere presence of the variable), so
// that other values remain available for options/flags in the future.
var debugProto = os.Getenv("PQGO_DEBUG") == "1"
// Driver is the Postgres database driver.
type Driver struct{}

// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func (Driver) Open(name string) (driver.Conn, error) {
	c, err := Open(name)
	return c, err
}
// Parameters sent by PostgreSQL on startup.
//
// inHotStandby and defaultTransactionReadOnly are NullBools: when Valid is
// false, checkTSA falls back to querying the server for the value instead.
type parameterStatus struct {
	serverVersion   int
	currentLocation *time.Location

	inHotStandby, defaultTransactionReadOnly sql.NullBool
}
// format is a wire-protocol format code for a parameter or result column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

// Pre-built result-column format code lists: an Int16 count followed by that
// many Int16 format codes (see decideColumnFormats).
var (
	// One result-column format code with the value 1 (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}
	// No result-column format codes (i.e. all text).
	colFmtDataAllText = []byte{0, 0}
)
// transactionStatus is the transaction status indicator sent by the backend in
// ReadyForQuery messages.
type transactionStatus byte

const (
	txnStatusIdle                transactionStatus = 'I' // not in a transaction block
	txnStatusIdleInTransaction   transactionStatus = 'T' // in a transaction block
	txnStatusInFailedTransaction transactionStatus = 'E' // in a failed transaction block
)

// String returns a human-readable description of s.
//
// An unknown status is rendered with its numeric value rather than causing a
// panic. The status is formatted into error messages (e.g. by
// checkIsInTransaction), and building an error must never crash the program.
func (s transactionStatus) String() string {
	switch s {
	case txnStatusIdle:
		return "idle"
	case txnStatusIdleInTransaction:
		return "idle in transaction"
	case txnStatusInFailedTransaction:
		return "in a failed transaction"
	default:
		return fmt.Sprintf("unknown transaction status %d", byte(s))
	}
}
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
//
// DialTimeout is used when connect_timeout is set and the dialer does not
// also implement DialerContext; Dial is used otherwise.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is the context-aware dialer interface. When a Dialer also
// implements it, DialContext is preferred over Dial and DialTimeout.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// defaultDialer is the Dialer (and DialerContext) used by Open; every method
// forwards to the embedded net.Dialer.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects without a timeout.
func (dd defaultDialer) Dial(network, address string) (net.Conn, error) {
	return dd.d.Dial(network, address)
}

// DialTimeout connects, giving up once timeout has elapsed.
func (dd defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	timeoutCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return dd.d.DialContext(timeoutCtx, network, address)
}

// DialContext connects, honouring cancellation and deadlines of ctx.
func (dd defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return dd.d.DialContext(ctx, network, address)
}
// conn is a single connection to a PostgreSQL backend; it implements the
// driver interfaces asserted near the top of this file.
type conn struct {
	c         net.Conn      // network connection; replaced by the upgraded connection in ssl
	buf       *bufio.Reader // buffered reader over c, read by recvMessage
	namei     int           // counter used by gname to generate prepared statement names
	scratch   [512]byte     // reused for building outgoing messages and for reading small incoming ones
	txnStatus transactionStatus
	// NOTE(review): txnStatus is presumably updated by processReadyForQuery (not visible here).
	txnFinish func() // if set, called by closeTxn when Commit or Rollback returns

	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	cfg    Config

	parameterStatus parameterStatus

	// A message stashed by saveMessage; the next recvMessage call returns it.
	saveMessageType   proto.ResponseCode
	saveMessageBuffer []byte

	// If an error is set, this connection is bad and all public-facing
	// functions should return the appropriate error by calling get()
	// (ErrBadConn) or getForNext().
	err syncErr

	processID, secretKey int                 // Cancellation key data for use with CancelRequest messages.
	inCopy               bool                // If true this connection is in the middle of a COPY
	noticeHandler        func(*Error)        // If not nil, notices will be synchronously sent here
	notificationHandler  func(*Notification) // If not nil, notifications will be synchronously sent here
	gss                  GSS                 // GSSAPI context, set while handling AuthReqGSS
}
// syncErr holds the first fatal error recorded on a connection. Access is
// guarded by the embedded mutex.
type syncErr struct {
	err error
	sync.Mutex
}

// get returns driver.ErrBadConn once any error has been recorded, nil before.
func (e *syncErr) get() error {
	e.Lock()
	bad := e.err != nil
	e.Unlock()
	if !bad {
		return nil
	}
	return driver.ErrBadConn
}

// getForNext returns the recorded error itself (nil if none). Currently only
// used by rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	recorded := e.err
	e.Unlock()
	return recorded
}

// set records err unless an error was already recorded; the first error wins.
// Passing nil is a programming error and panics.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}
	e.Lock()
	defer e.Unlock()
	if e.err != nil {
		return
	}
	e.err = err
}
// writeBuf begins a new outgoing message of type b in cn's scratch buffer.
// The first five bytes hold the type code plus four bytes for the length;
// pos marks where the message begins after its type byte.
// NOTE(review): the length is presumably filled in by writeBuf.wrap (not visible here).
func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf {
	header := cn.scratch[:5]
	header[0] = byte(b)
	return &writeBuf{pos: 1, buf: header}
}
// Open opens a new connection to the database. dsn is a connection string. Most
// users should only use it through database/sql package from the standard
// library.
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
}
// DialOpen opens a new connection to the database using a dialer.
//
// On failure the returned driver.Conn is a true nil interface, so callers may
// safely compare it with nil.
func DialOpen(d Dialer, dsn string) (driver.Conn, error) {
	c, err := NewConnector(dsn)
	if err != nil {
		return nil, err
	}
	c.Dialer(d)
	cn, err := c.open(context.Background())
	if err != nil {
		// Returning the (*conn)(nil) from open directly would produce a
		// non-nil driver.Conn interface that wraps a nil pointer.
		return nil, err
	}
	return cn, nil
}
// open tries the hosts configured in c.cfg, in the order returned by
// c.cfg.hosts(). It returns the first connection that completes the handshake
// and satisfies target_session_attrs.
//
// For each host, open:
//   - dials;
//   - negotiates SSL;
//   - runs the startup sequence;
//   - clears the connect_timeout deadline set by dial;
//   - checks the session attributes.
//
// On any failure the socket is closed, the error is recorded, and the next host
// is tried.
//
// With target_session_attrs=prefer-standby, a pass that finds no standby is
// retried once with "any". If every host fails, the recorded errors are
// returned. For a single-host configuration the error is unwrapped, so the
// message matches a plain connection attempt; otherwise the errors are joined.
func (c *Connector) open(ctx context.Context) (*conn, error) {
	tsa := c.cfg.TargetSessionAttrs
	for {
		var errs []error
		// fail records a non-nil err for host and reports whether err was non-nil.
		fail := func(err error, host Config) bool {
			if err == nil {
				return false
			}
			if debugProto {
				fmt.Fprintln(os.Stderr, "CONNECT (error)", err)
			}
			errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", host.Host, host.Port, err))
			return true
		}

		for _, cfg := range c.cfg.hosts() {
			if debugProto {
				fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string())
			}
			cn := &conn{cfg: cfg, dialer: c.dialer}
			cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password,
				cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password"))
			// discard closes whatever socket cn currently holds, if any.
			discard := func() {
				if cn.c != nil {
					_ = cn.c.Close()
				}
			}

			var err error
			cn.c, err = dial(ctx, c.dialer, cn.cfg)
			if fail(err, cfg) {
				// dial can hand back a live connection together with an error
				// (when setting the connect_timeout deadline fails); don't leak it.
				discard()
				continue
			}
			if fail(cn.ssl(cn.cfg), cfg) {
				discard()
				continue
			}
			cn.buf = bufio.NewReader(cn.c)
			if fail(cn.startup(cn.cfg), cfg) {
				discard()
				continue
			}
			// Reset the deadline, in case one was set (see dial).
			if cn.cfg.ConnectTimeout > 0 && fail(cn.c.SetDeadline(time.Time{}), cfg) {
				discard()
				continue
			}
			if fail(cn.checkTSA(tsa), cfg) {
				discard()
				continue
			}
			return cn, nil
		}

		// checkTSA treats target_session_attrs=prefer-standby as standby. We ran
		// out of hosts, so none are on standby; fall back to "any". Test tsa
		// (not the configured value): after the fallback it is "any", so the
		// retry happens at most once even when every host is unreachable.
		if tsa == TargetSessionAttrsPreferStandby {
			tsa = TargetSessionAttrsAny
			continue
		}
		if len(errs) == 0 {
			return nil, errors.New("pq: no hosts to connect to")
		}
		if len(c.cfg.Multi) == 0 {
			// Remove the "connecting to [..]" when we have just one host, so the
			// error is identical to what we had before.
			return nil, errors.Unwrap(errs[0])
		}
		return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...))
	}
}
// getBool runs query with the simple query protocol and interprets the first
// column of the first row as a boolean.
//
// A bool value (e.g. from pg_is_in_recovery()) is returned as-is. A string
// value (e.g. from SHOW) is true only when it is exactly "on". Any other type
// is an error.
func (cn *conn) getBool(query string) (bool, error) {
	res, err := cn.simpleQuery(query)
	if err != nil {
		return false, err
	}
	defer res.Close()

	v := make([]driver.Value, 1)
	if err := res.Next(v); err != nil {
		return false, err
	}
	switch vv := v[0].(type) {
	case bool:
		return vv, nil
	case string:
		return vv == "on", nil
	default:
		return false, fmt.Errorf("pq: getBool: unexpected type %T: %[1]v", v[0])
	}
}
// checkTSA returns an error if the session on cn does not satisfy the
// target_session_attrs value tsa.
//
// Hot-standby and read-only state come from the ParameterStatus values reported
// at startup when those are available. Otherwise they are queried from the
// server, at most once per check.
//
// As in libpq, a server counts as read-only when it is in hot standby OR
// defaults to read-only transactions. It counts as read-write only when
// neither holds.
func (cn *conn) checkTSA(tsa TargetSessionAttrs) error {
	hotStandby := func() (bool, error) {
		if cn.parameterStatus.inHotStandby.Valid {
			return cn.parameterStatus.inHotStandby.Bool, nil
		}
		return cn.getBool("select pg_catalog.pg_is_in_recovery()")
	}
	readOnly := func() (bool, error) {
		if cn.parameterStatus.defaultTransactionReadOnly.Valid {
			return cn.parameterStatus.defaultTransactionReadOnly.Bool, nil
		}
		return cn.getBool("show transaction_read_only")
	}

	switch tsa {
	case "", TargetSessionAttrsAny:
		return nil

	case TargetSessionAttrsReadWrite:
		ro, err := readOnly()
		if err != nil {
			return err
		}
		if ro {
			return errors.New("session is read-only")
		}
		hs, err := hotStandby()
		if err != nil {
			return err
		}
		if hs {
			return errors.New("server is in hot standby mode")
		}
		return nil

	case TargetSessionAttrsReadOnly:
		ro, err := readOnly()
		if err != nil {
			return err
		}
		if ro {
			return nil
		}
		// A standby only accepts read-only transactions even when
		// default_transaction_read_only is off.
		hs, err := hotStandby()
		if err != nil {
			return err
		}
		if hs {
			return nil
		}
		return errors.New("session is not read-only")

	case TargetSessionAttrsPrimary, TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby:
		hs, err := hotStandby()
		if err != nil {
			return err
		}
		switch {
		case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs:
			return errors.New("server is not in hot standby mode")
		case tsa == TargetSessionAttrsPrimary && hs:
			return errors.New("server is in hot standby mode")
		default:
			return nil
		}

	default:
		// Should be rejected when the configuration is parsed; report it
		// rather than crash the caller.
		return fmt.Errorf("pq: unknown target_session_attrs %q", tsa)
	}
}
// dial opens the network connection described by cfg using d.
//
// DialContext is preferred when d implements DialerContext. With a positive
// connect_timeout, dial both bounds the dial itself and sets a deadline on the
// returned connection, covering the whole connection-establishment procedure.
// The caller resets that deadline after startup.
//
// On error no connection is returned; a connection whose deadline could not be
// set is closed rather than leaked.
func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) {
	network, address := cfg.network()
	dctx, hasCtx := d.(DialerContext)

	// Zero or not specified means wait indefinitely.
	if cfg.ConnectTimeout <= 0 {
		if hasCtx {
			return dctx.DialContext(ctx, network, address)
		}
		return d.Dial(network, address)
	}

	// Compute the deadline before dialing so that it covers the dial too.
	deadline := time.Now().Add(cfg.ConnectTimeout)
	var (
		conn net.Conn
		err  error
	)
	if hasCtx {
		tctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout)
		defer cancel()
		conn, err = dctx.DialContext(tctx, network, address)
	} else {
		conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout)
	}
	if err != nil {
		return nil, err
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}
// isInTransaction reports whether cn.txnStatus says a transaction block is
// open, whether healthy or failed.
func (cn *conn) isInTransaction() bool {
	switch cn.txnStatus {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	default:
		return false
	}
}

// checkIsInTransaction verifies that cn's transaction state matches intxn. On
// a mismatch it marks the connection bad and returns a descriptive error.
func (cn *conn) checkIsInTransaction(intxn bool) error {
	if cn.isInTransaction() == intxn {
		return nil
	}
	cn.err.set(driver.ErrBadConn)
	return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus)
}
// Begin starts a transaction with default characteristics.
func (cn *conn) Begin() (driver.Tx, error) {
	return cn.begin("")
}

// begin sends "BEGIN"+mode. It checks that the server acknowledged it with a
// BEGIN command tag and is now inside a transaction block. Any protocol
// surprise marks the connection bad. The connection itself serves as the Tx.
func (cn *conn) begin(mode string) (driver.Tx, error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	if err := cn.checkIsInTransaction(false); err != nil {
		return nil, err
	}
	_, tag, err := cn.simpleExec("BEGIN" + mode)
	switch {
	case err != nil:
		return nil, cn.handleError(err)
	case tag != "BEGIN":
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected command tag %s", tag)
	case cn.txnStatus != txnStatusIdleInTransaction:
		cn.err.set(driver.ErrBadConn)
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
	}
	return cn, nil
}
// closeTxn invokes the txnFinish callback when one is installed.
func (cn *conn) closeTxn() {
	if cn.txnFinish == nil {
		return
	}
	cn.txnFinish()
}
// Commit commits the current transaction. A transaction that has already
// failed is rolled back instead, and ErrInFailedTransaction is returned.
// closeTxn runs on every return path.
func (cn *conn) Commit() error {
	defer cn.closeTxn()
	if err := cn.err.get(); err != nil {
		return err
	}
	if err := cn.checkIsInTransaction(true); err != nil {
		return err
	}
	// We don't want the client to think that everything is okay if it tries
	// to commit a failed transaction. However, no matter what we return,
	// database/sql will release this connection back into the free connection
	// pool so we have to abort the current transaction here. Note that you
	// would get the same behaviour if you issued a COMMIT in a failed
	// transaction, so it's also the least surprising thing to do here.
	if cn.txnStatus == txnStatusInFailedTransaction {
		err := cn.rollback()
		if err == nil {
			err = ErrInFailedTransaction
		}
		return err
	}

	_, tag, err := cn.simpleExec("COMMIT")
	if err != nil {
		// Still inside a transaction after a failed COMMIT: state is unknown.
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return cn.handleError(err)
	}
	if tag != "COMMIT" {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", tag)
	}
	return cn.checkIsInTransaction(false)
}
// Rollback aborts the current transaction. closeTxn runs on every return path.
func (cn *conn) Rollback() error {
	defer cn.closeTxn()
	err := cn.err.get()
	if err == nil {
		if err = cn.rollback(); err != nil {
			err = cn.handleError(err)
		}
	}
	return err
}
// rollback sends ROLLBACK and verifies that the server acknowledged it and left
// the transaction block.
//
// As in begin and Commit, an unexpected response marks the connection bad,
// because its transaction state can no longer be trusted.
func (cn *conn) rollback() error {
	if err := cn.checkIsInTransaction(true); err != nil {
		return err
	}
	_, commandTag, err := cn.simpleExec("ROLLBACK")
	if err != nil {
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if commandTag != "ROLLBACK" {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	return cn.checkIsInTransaction(false)
}
// gname returns the next prepared-statement name for this connection: "1",
// "2", "3", and so on.
func (cn *conn) gname() string {
	cn.namei++
	return strconv.Itoa(cn.namei)
}
// simpleExec runs q with the simple query protocol and discards any rows it
// produces. It reads messages until ReadyForQuery and returns:
//   - the Result and command tag of the last CommandComplete (or emptyRows
//     for an empty query);
//   - the ErrorResponse received, if any.
//
// errUnexpectedReady is returned when ReadyForQuery arrives with neither a
// result nor an error. Any unexpected message type marks the connection bad.
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) {
	if debugProto {
		fmt.Fprintf(os.Stderr, " START conn.simpleExec\n")
		defer fmt.Fprintf(os.Stderr, " END conn.simpleExec\n")
	}
	b := cn.writeBuf(proto.Query)
	b.string(q)
	err := cn.send(b)
	if err != nil {
		return nil, "", err
	}
	for {
		t, r, err := cn.recv1()
		if err != nil {
			return nil, "", err
		}
		switch t {
		case proto.CommandComplete:
			// With several statements in q, later results overwrite earlier ones.
			res, commandTag, err = cn.parseComplete(r.string())
			if err != nil {
				return nil, "", err
			}
		case proto.ReadyForQuery:
			cn.processReadyForQuery(r)
			if res == nil && resErr == nil {
				resErr = errUnexpectedReady
			}
			return res, commandTag, resErr
		case proto.ErrorResponse:
			// Keep reading: ReadyForQuery still follows an error.
			resErr = parseError(r, q)
		case proto.EmptyQueryResponse:
			res = emptyRows
		case proto.RowDescription, proto.DataRow:
			// ignore any results
		default:
			cn.err.set(driver.ErrBadConn)
			return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t)
		}
	}
}
// simpleQuery runs q with the simple query protocol.
//
// When a statement produces rows, it returns as soon as the first DataRow
// arrives. That row is stashed with saveMessage, so rows.Next picks it up.
// Otherwise it reads until ReadyForQuery and returns a rows object that is
// already done.
//
// On an ErrorResponse the returned rows is nil and the error is non-nil;
// rows and error are never both non-nil.
func (cn *conn) simpleQuery(q string) (*rows, error) {
	if debugProto {
		fmt.Fprintf(os.Stderr, " START conn.simpleQuery\n")
		defer fmt.Fprintf(os.Stderr, " END conn.simpleQuery\n")
	}
	b := cn.writeBuf(proto.Query)
	b.string(q)
	err := cn.send(b)
	if err != nil {
		return nil, cn.handleError(err, q)
	}

	var (
		res    *rows
		resErr error
	)
	for {
		t, r, err := cn.recv1()
		if err != nil {
			return nil, cn.handleError(err, q)
		}
		switch t {
		case proto.CommandComplete, proto.EmptyQueryResponse:
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if resErr != nil {
				cn.err.set(driver.ErrBadConn)
				return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{cn: cn}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == proto.CommandComplete {
				res.result, res.tag, err = cn.parseComplete(r.string())
				if err != nil {
					return nil, cn.handleError(err, q)
				}
				if res.colNames != nil {
					return res, cn.handleError(resErr, q)
				}
			}
			res.done = true
		case proto.ReadyForQuery:
			cn.processReadyForQuery(r)
			// Only fabricate an empty result when the query succeeded; after an
			// ErrorResponse return nil rows alongside the error.
			if res == nil && resErr == nil {
				res = &rows{done: true}
			}
			return res, cn.handleError(resErr, q) // done
		case proto.ErrorResponse:
			res = nil
			resErr = parseError(r, q)
		case proto.DataRow:
			if res == nil {
				cn.err.set(driver.ErrBadConn)
				return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution")
			}
			return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next
		case proto.RowDescription:
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine and just overwrite it.
			res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)}
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.err.set(driver.ErrBadConn)
			return nil, fmt.Errorf("pq: unknown response for simple query: %q", t)
		}
	}
}
// decideColumnFormats decides which result-column formats to request for a
// prepared statement. colTyps describes the result columns, one entry per
// column.
//
// It returns the per-column formats and the encoded result-format-code list
// for the Bind message (an Int16 count followed by Int16 codes). The shared
// one-code forms are used when all columns are binary or all are text. With
// forceText, or when there are no columns, everything is text.
func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) {
	if len(colTyps) == 0 {
		return nil, colFmtDataAllText, nil
	}
	colFmts = make([]format, len(colTyps))
	if forceText {
		return colFmts, colFmtDataAllText, nil
	}

	allBinary, allText := true, true
	for i, t := range colTyps {
		switch t.OID {
		// This is the list of types to use binary mode for when receiving them
		// through a prepared statement. If a type appears in this list, it
		// must also be implemented in binaryDecode in encode.go.
		case oid.T_bytea, oid.T_int8, oid.T_int4, oid.T_int2, oid.T_uuid:
			colFmts[i] = formatBinary
			allText = false
		default:
			allBinary = false
		}
	}

	switch {
	case allBinary:
		return colFmts, colFmtDataAllBinary, nil
	case allText:
		return colFmts, colFmtDataAllText, nil
	}

	// Mixed formats: one code per column. Validate the count before
	// allocating, since it must fit in the Int16 length prefix.
	if len(colFmts) > math.MaxUint16 {
		return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts))
	}
	colFmtData = make([]byte, 2+2*len(colFmts))
	binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
	for i, f := range colFmts {
		binary.BigEndian.PutUint16(colFmtData[2+2*i:], uint16(f))
	}
	return colFmts, colFmtData, nil
}
// prepareTo parses q on the server as the prepared statement stmtName (""
// selects the unnamed statement) and describes it.
//
// It returns a stmt filled with the parameter types, result columns, and
// chosen result-column formats. Parse, Describe and Sync are pipelined in a
// single write, and the responses are then read in order.
func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) {
	if debugProto {
		fmt.Fprintf(os.Stderr, " START conn.prepareTo\n")
		defer fmt.Fprintf(os.Stderr, " END conn.prepareTo\n")
	}
	st := &stmt{cn: cn, name: stmtName}

	b := cn.writeBuf(proto.Parse)
	b.string(st.name)
	b.string(q)
	b.int16(0) // no parameter types given; the server infers them

	b.next(proto.Describe)
	// Describe target 'S' = prepared statement ('P' would be a portal). It
	// happens to be the same byte as the Sync message code, but is unrelated.
	b.byte('S')
	b.string(st.name)

	b.next(proto.Sync)
	err := cn.send(b)
	if err != nil {
		return nil, err
	}

	err = cn.readParseResponse()
	if err != nil {
		return nil, err
	}
	st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse()
	if err != nil {
		return nil, err
	}
	st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult)
	if err != nil {
		return nil, err
	}
	err = cn.readReadyForQuery()
	if err != nil {
		return nil, err
	}
	return st, nil
}
// Prepare creates a prepared statement for q under a freshly generated name.
// COPY statements are instead set up for COPY IN, which puts the connection
// into copy mode on success.
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	if !pqsql.StartsWithCopy(q) {
		st, err := cn.prepareTo(q, cn.gname())
		if err != nil {
			return nil, cn.handleError(err, q)
		}
		return st, nil
	}
	s, err := cn.prepareCopyIn(q)
	if err == nil {
		cn.inCopy = true
	}
	return s, cn.handleError(err, q)
}
// Close sends Terminate to the server and closes the socket. The socket is
// closed even when sending Terminate fails.
func (cn *conn) Close() error {
	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	if err := cn.sendSimpleMessage(proto.Terminate); err != nil {
		_ = cn.c.Close() // Ensure that cn.c.Close is always run.
		return cn.handleError(err)
	}
	return cn.c.Close()
}
// toNamedValue converts positional driver.Values into unnamed NamedValues,
// with 1-based ordinals.
func toNamedValue(v []driver.Value) []driver.NamedValue {
	named := make([]driver.NamedValue, 0, len(v))
	for i, val := range v {
		named = append(named, driver.NamedValue{Ordinal: i + 1, Value: val})
	}
	return named
}
// CheckNamedValue implements [driver.NamedValueChecker].
//
// The following are left to database/sql's default conversion by returning
// driver.ErrSkip:
//   - driver.Valuer implementations;
//   - nil values and nil pointers;
//   - byte slices, including pointers to them and named []byte types such as
//     json.RawMessage.
//
// Other slices are converted with Array. uint64 values that do not fit in an
// int64 are sent as their decimal string.
func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error {
	// Ignore Valuer, for backward compatibility with pq.Array().
	if _, ok := nv.Value.(driver.Valuer); ok {
		return driver.ErrSkip
	}

	v := reflect.ValueOf(nv.Value)
	if !v.IsValid() {
		return driver.ErrSkip
	}
	t := v.Type()
	for t.Kind() == reflect.Ptr {
		t, v = t.Elem(), v.Elem()
	}

	// Ignore []byte and related types: *[]byte, json.RawMessage, etc.
	if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
		return driver.ErrSkip
	}

	switch v.Kind() {
	default:
		return driver.ErrSkip
	case reflect.Slice:
		var err error
		nv.Value, err = Array(v.Interface()).Value()
		return err
	case reflect.Uint64:
		value := v.Uint()
		// math.MaxInt64 itself still fits in an int64; only larger values
		// need the string form.
		if value > math.MaxInt64 {
			nv.Value = strconv.FormatUint(value, 10)
		} else {
			nv.Value = int64(value)
		}
		return nil
	}
}
// Query implements the deprecated driver.Queryer interface by delegating to
// query with positional arguments.
//
// On failure the returned driver.Rows is a true nil interface rather than one
// wrapping a nil *rows.
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	r, err := cn.query(query, toNamedValue(args))
	if err != nil {
		return nil, err
	}
	return r, nil
}
// query runs query with args and returns rows ready for iteration. It is
// refused while a COPY is in progress.
//
// Three strategies are used:
//   - no args: the simple query protocol;
//   - BinaryParameters: a pipelined Parse/Bind/Describe/Execute sent by
//     sendBinaryModeQuery;
//   - otherwise: prepare the unnamed statement, then execute it.
func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) {
	if debugProto {
		fmt.Fprintf(os.Stderr, " START conn.query\n")
		defer fmt.Fprintf(os.Stderr, " END conn.query\n")
	}
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	if cn.inCopy {
		return nil, errCopyInProgress
	}

	// Check to see if we can use the "simpleQuery" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		return cn.simpleQuery(query)
	}

	if cn.cfg.BinaryParameters {
		err := cn.sendBinaryModeQuery(query, args)
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		// Read the pipelined responses in the order the requests were sent.
		err = cn.readParseResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		err = cn.readBindResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		rows := &rows{cn: cn}
		rows.rowsHeader, err = cn.readPortalDescribeResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		err = cn.postExecuteWorkaround()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		return rows, nil
	}
	// "" = the unnamed prepared statement.
	st, err := cn.prepareTo(query, "")
	if err != nil {
		return nil, cn.handleError(err, query)
	}
	err = st.exec(args)
	if err != nil {
		return nil, cn.handleError(err, query)
	}
	return &rows{
		cn:         cn,
		rowsHeader: st.rowsHeader,
	}, nil
}
// Implement the optional "Execer" interface for one-shot queries
//
// Exec runs query with args and returns its Result. It mirrors query's
// strategies:
//   - no args: the simple query protocol;
//   - BinaryParameters: the pipelined binary-mode path;
//   - otherwise: the unnamed prepared statement.
func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	if err := cn.err.get(); err != nil {
		return nil, err
	}

	// Check to see if we can use the "simpleExec" interface, which is *much*
	// faster than going through prepare/exec
	if len(args) == 0 {
		// ignore commandTag, our caller doesn't care
		r, _, err := cn.simpleExec(query)
		return r, cn.handleError(err, query)
	}

	if cn.cfg.BinaryParameters {
		err := cn.sendBinaryModeQuery(query, toNamedValue(args))
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		// Drain the responses to Parse, Bind and Describe before reading the
		// Execute result.
		err = cn.readParseResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		err = cn.readBindResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		_, err = cn.readPortalDescribeResponse()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		err = cn.postExecuteWorkaround()
		if err != nil {
			return nil, cn.handleError(err, query)
		}
		res, _, err := cn.readExecuteResponse("Execute")
		return res, cn.handleError(err, query)
	}
	// Use the unnamed statement to defer planning until bind time, or else
	// value-based selectivity estimates cannot be used.
	st, err := cn.prepareTo(query, "")
	if err != nil {
		return nil, cn.handleError(err, query)
	}
	r, err := st.Exec(args)
	if err != nil {
		return nil, cn.handleError(err, query)
	}
	return r, nil
}
// safeRetryError wraps a write error that occurred before any byte was sent
// (see send). NOTE(review): presumably lets callers treat the operation as
// safe to retry; handling is not visible here.
type safeRetryError struct{ Err error }

// Error returns the wrapped error's message unchanged.
func (e *safeRetryError) Error() string {
	msg := e.Err.Error()
	return msg
}
// send writes the message(s) in m to the server; m may hold several
// pipelined messages built with writeBuf.next. When the write fails before any
// byte was written, the error is wrapped in safeRetryError.
func (cn *conn) send(m *writeBuf) error {
	if debugProto {
		w := m.wrap()
		for len(w) > 0 { // Can contain multiple messages.
			// Each message is a type byte plus an Int32 length that counts
			// itself but not the type byte.
			c := proto.RequestCode(w[0])
			l := int(binary.BigEndian.Uint32(w[1:5])) - 4
			fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5])
			w = w[l+5:]
		}
	}
	n, err := cn.c.Write(m.wrap())
	if err != nil && n == 0 {
		err = &safeRetryError{Err: err}
	}
	return err
}
// sendStartupPacket writes m without its leading type byte. Startup-phase
// packets (the startup message and the SSL request) carry only a length and
// a payload.
func (cn *conn) sendStartupPacket(m *writeBuf) error {
	packet := m.wrap()
	if debugProto {
		payloadLen := int(binary.BigEndian.Uint32(packet[1:5])) - 4
		fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", "Startup", payloadLen, packet[5:])
	}
	_, err := cn.c.Write(packet[1:])
	return err
}
// sendSimpleMessage sends a payload-less message of type typ straight to the
// socket. It deliberately does not use the scratch buffer.
func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error {
	if debugProto {
		fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", typ, 0, []byte{})
	}
	// Type byte, then an Int32 length of 4 (only the length field itself).
	msg := [5]byte{byte(typ), 0, 0, 0, 4}
	_, err := cn.c.Write(msg[:])
	return err
}
// saveMessage stashes a message and its buffer on cn; the next recvMessage call
// returns them instead of reading from the network. This lets a caller peek at
// the next message (e.g. to see whether it is an error) and hand it on when it
// cannot process the message itself.
//
// Only one message can be stashed. Stashing a second one marks the connection
// bad.
func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error {
	if cn.saveMessageType == 0 {
		cn.saveMessageType, cn.saveMessageBuffer = typ, *buf
		return nil
	}
	cn.err.set(driver.ErrBadConn)
	return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType)
}
// recvMessage receives the next message from the backend into *r and returns
// its type.
//
// A message stashed by saveMessage is returned (and cleared) first, without
// touching the network. Small messages are read into cn.scratch, so *r is only
// valid until the scratch buffer is next reused; larger ones get their own
// allocation.
func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) {
	// workaround for a QueryRow bug, see exec
	if cn.saveMessageType != 0 {
		t := cn.saveMessageType
		*r = cn.saveMessageBuffer
		cn.saveMessageType = 0
		cn.saveMessageBuffer = nil
		return t, nil
	}

	x := cn.scratch[:5]
	if _, err := io.ReadFull(cn.buf, x); err != nil {
		return 0, err
	}

	// Read the type and length of the message that follows. The length
	// includes its own four bytes.
	t := proto.ResponseCode(x[0])
	n := int(binary.BigEndian.Uint32(x[1:])) - 4

	// When PostgreSQL cannot start a backend (e.g., an external process limit),
	// it sends plain text like "Ecould not fork new process [..]", which
	// doesn't use the standard encoding for the Error message.
	//
	// libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)",
	// but check < 4 since n represents bytes remaining to be read after length.
	if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) {
		msg, _ := cn.buf.ReadString('\x00')
		return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00"))
	}

	// A length below 4, or one that wrapped negative converting to int on
	// 32-bit platforms, is a protocol violation. Report it instead of
	// panicking on the slice expressions below.
	if n < 0 {
		return 0, fmt.Errorf("pq: invalid message length %d for message type %q", n+4, t)
	}

	var y []byte
	if n <= len(cn.scratch) {
		y = cn.scratch[:n]
	} else {
		y = make([]byte, n)
	}
	if _, err := io.ReadFull(cn.buf, y); err != nil {
		return 0, err
	}
	*r = y
	if debugProto {
		fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y)
	}
	return t, nil
}
// recv receives a message from the backend. It returns an error if reading
// fails or if the message is an ErrorResponse.
//
// NoticeResponse and NotificationResponse messages are passed to the notice
// and notification handlers when those are set, and are otherwise dropped.
// Every other message, including ParameterStatus, is returned to the caller.
// This function should generally be used only during the startup sequence.
func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) {
	for {
		r := new(readBuf)
		t, err := cn.recvMessage(r)
		if err != nil {
			return 0, nil, err
		}
		switch t {
		case proto.ErrorResponse:
			return 0, nil, parseError(r, "")
		case proto.NoticeResponse:
			if n := cn.noticeHandler; n != nil {
				n(parseError(r, ""))
			}
		case proto.NotificationResponse:
			if n := cn.notificationHandler; n != nil {
				n(recvNotification(r))
			}
		default:
			return t, r, nil
		}
	}
}
// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
// the caller to avoid an allocation.
//
// It handles asynchronous messages internally and returns the first other
// message, including ErrorResponse, which is returned as a normal message type
// rather than converted into an error:
//   - NotificationResponse and NoticeResponse go to the handlers when set;
//   - ParameterStatus is applied via processParameterStatus.
func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) {
	for {
		t, err := cn.recvMessage(r)
		if err != nil {
			return 0, err
		}
		switch t {
		case proto.NotificationResponse:
			if n := cn.notificationHandler; n != nil {
				n(recvNotification(r))
			}
		case proto.NoticeResponse:
			if n := cn.noticeHandler; n != nil {
				n(parseError(r, ""))
			}
		case proto.ParameterStatus:
			cn.processParameterStatus(r)
		default:
			return t, nil
		}
	}
}
// recv1 receives the next non-asynchronous message from the backend. The
// returned error is set only when reading fails.
//
// Notices, notifications and ParameterStatus messages are handled internally
// (see recv1Buf). An ErrorResponse is returned as an ordinary message type for
// the caller to parse.
func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) {
	var r readBuf
	t, err := cn.recv1Buf(&r)
	if err != nil {
		return 0, nil, err
	}
	return t, &r, nil
}
// ssl upgrades cn.c to an SSL connection when cfg asks for one, and does
// nothing otherwise.
//
// With the default negotiation, an SSLRequest is sent first and the server
// must answer 'S', otherwise ErrSSLNotSupported is returned. With
// sslnegotiation=direct (supported by PostgreSQL 17 and above) the handshake
// starts immediately.
func (cn *conn) ssl(cfg Config) error {
	upgrade, err := ssl(cfg)
	if err != nil || upgrade == nil {
		// Either an error, or nothing to do.
		return err
	}

	if cfg.SSLNegotiation != SSLNegotiationDirect {
		w := cn.writeBuf(0)
		w.int32(proto.NegotiateSSLCode)
		if err := cn.sendStartupPacket(w); err != nil {
			return err
		}
		reply := cn.scratch[:1]
		if _, err := io.ReadFull(cn.c, reply); err != nil {
			return err
		}
		if reply[0] != 'S' {
			return ErrSSLNotSupported
		}
	}

	cn.c, err = upgrade(cn.c)
	return err
}
// startup sends the startup message built from cfg (user, database, options,
// application_name, client_encoding and any extra runtime parameters). It then
// processes server messages until ReadyForQuery:
//   - backend key data and parameter statuses are recorded;
//   - authentication requests are answered via auth.
//
// Any error from the server or any unexpected message aborts the startup.
func (cn *conn) startup(cfg Config) error {
	w := cn.writeBuf(0)
	w.int32(proto.ProtocolVersion30)

	// Parameters are sent as name/value string pairs, terminated by an empty
	// name below.
	if cfg.User != "" {
		w.string("user")
		w.string(cfg.User)
	}
	if cfg.Database != "" {
		w.string("database")
		w.string(cfg.Database)
	}
	// w.string("replication") // Sent by libpq, but we don't support that.
	if cfg.Options != "" {
		w.string("options")
		w.string(cfg.Options)
	}
	if cfg.ApplicationName != "" {
		w.string("application_name")
		w.string(cfg.ApplicationName)
	}
	if cfg.ClientEncoding != "" {
		w.string("client_encoding")
		w.string(cfg.ClientEncoding)
	}
	for k, v := range cfg.Runtime {
		w.string(k)
		w.string(v)
	}

	w.string("")
	if err := cn.sendStartupPacket(w); err != nil {
		return err
	}

	for {
		t, r, err := cn.recv()
		if err != nil {
			return err
		}
		switch t {
		case proto.BackendKeyData:
			cn.processBackendKeyData(r)
		case proto.ParameterStatus:
			cn.processParameterStatus(r)
		case proto.AuthenticationRequest:
			// May arrive several times (multi-step methods, then AuthReqOk).
			err := cn.auth(r, cfg)
			if err != nil {
				return err
			}
		case proto.ReadyForQuery:
			cn.processReadyForQuery(r)
			return nil
		default:
			return fmt.Errorf("pq: unknown response for startup: %q", t)
		}
	}
}
// auth handles a single AuthenticationRequest message r received during
// startup, using the credentials in cfg.
//
// startup calls auth once per AuthenticationRequest. For single-round methods
// auth therefore only sends the response, and the server's AuthReqOk or
// ErrorResponse is picked up by startup's loop. SCRAM-SHA-256 performs its
// whole exchange here.
func (cn *conn) auth(r *readBuf, cfg Config) error {
	switch code := proto.AuthCode(r.int32()); code {
	default:
		return fmt.Errorf("pq: unknown authentication response: %s", code)
	case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI:
		return fmt.Errorf("pq: unsupported authentication method: %s", code)
	case proto.AuthReqOk:
		return nil
	case proto.AuthReqPassword:
		// Cleartext password.
		w := cn.writeBuf(proto.PasswordMessage)
		w.string(cfg.Password)
		return cn.send(w)
	case proto.AuthReqMD5:
		// "md5" + md5(md5(password + user) + salt), salt being 4 bytes from
		// the server.
		salt := string(r.next(4))
		w := cn.writeBuf(proto.PasswordMessage)
		w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+salt))
		return cn.send(w)
	case proto.AuthReqGSS:
		return cn.authGSSStart(cfg)
	case proto.AuthReqGSSCont:
		return cn.authGSSContinue(r)
	case proto.AuthReqSASL:
		return cn.authSCRAMSHA256(cfg)
	}
}

// authGSSStart begins GSSAPI authentication: it obtains an initial token from
// the registered provider and sends it to the server.
func (cn *conn) authGSSStart(cfg Config) error {
	if newGss == nil {
		return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)")
	}
	cli, err := newGss()
	if err != nil {
		return fmt.Errorf("pq: kerberos error: %w", err)
	}

	var token []byte
	if cfg.isset("krbspn") {
		// Use the supplied SPN if provided.
		token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn)
	} else {
		// Allow the kerberos service name to be overridden.
		service := "postgres"
		if cfg.isset("krbsrvname") {
			service = cfg.KrbSrvname
		}
		token, err = cli.GetInitToken(cfg.Host, service)
	}
	if err != nil {
		return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err)
	}

	w := cn.writeBuf(proto.GSSResponse)
	w.bytes(token)
	if err := cn.send(w); err != nil {
		return err
	}
	// Store for GSSAPI continue message.
	cn.gss = cli
	return nil
}

// authGSSContinue feeds the server's GSSAPI continuation data to the context
// created by authGSSStart and sends the next token when one is needed.
func (cn *conn) authGSSContinue(r *readBuf) error {
	if cn.gss == nil {
		return errors.New("pq: GSSAPI protocol error")
	}
	done, tokOut, err := cn.gss.Continue([]byte(*r))
	if err == nil && !done {
		w := cn.writeBuf(proto.GSSResponse)
		w.bytes(tokOut)
		if err := cn.send(w); err != nil {
			return err
		}
	}
	// Errors fall through and read the more detailed message from the
	// server.
	return nil
}

// authSCRAMSHA256 performs the complete SCRAM-SHA-256 SASL exchange:
//  1. client-first message, sent in SASLInitialResponse;
//  2. server-first (AuthReqSASLCont) and client-final, sent in SASLResponse;
//  3. server-final (AuthReqSASLFin), which the client verifies.
func (cn *conn) authSCRAMSHA256(cfg Config) error {
	sc := scram.NewClient(sha256.New, cfg.User, cfg.Password)

	// expect reads the next message and returns its SASL data, provided it is
	// an AuthenticationRequest carrying the auth code want.
	expect := func(want proto.AuthCode) ([]byte, error) {
		t, r, err := cn.recv()
		if err != nil {
			return nil, err
		}
		if t != proto.AuthenticationRequest {
			return nil, fmt.Errorf("pq: unexpected password response: %q", t)
		}
		if code := proto.AuthCode(r.int32()); code != want {
			return nil, fmt.Errorf("pq: unexpected authentication response: %s", code)
		}
		return r.next(len(*r)), nil
	}

	sc.Step(nil)
	if sc.Err() != nil {
		return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
	}
	scOut := sc.Out()
	w := cn.writeBuf(proto.SASLInitialResponse)
	w.string("SCRAM-SHA-256")
	w.int32(len(scOut))
	w.bytes(scOut)
	if err := cn.send(w); err != nil {
		return err
	}

	serverFirst, err := expect(proto.AuthReqSASLCont)
	if err != nil {
		return err
	}
	sc.Step(serverFirst)
	if sc.Err() != nil {
		return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
	}
	w = cn.writeBuf(proto.SASLResponse)
	w.bytes(sc.Out())
	if err := cn.send(w); err != nil {
		return err
	}

	serverFinal, err := expect(proto.AuthReqSASLFin)
	if err != nil {
		return err
	}
	sc.Step(serverFinal)
	if sc.Err() != nil {
		return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err())
	}
	return nil
}
// parseComplete parses the "command tag" from a CommandComplete message and
// returns the number of rows affected (if applicable) and a string identifying
// only the command that was executed, e.g. "ALTER TABLE". Returns an error if
// the command tag cannot be parsed; in that case the connection is also
// marked as bad.
func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) {
	// Tags of the form "<COMMAND> <rows>". MERGE (PostgreSQL 15+) reports
	// "MERGE <rows>" too.
	commandsWithAffectedRows := []string{
		"SELECT ",
		// INSERT is handled below
		"UPDATE ",
		"DELETE ",
		"MERGE ",
		"FETCH ",
		"MOVE ",
		"COPY ",
	}
	var (
		rows    string
		hasRows bool
	)
	for _, prefix := range commandsWithAffectedRows {
		if strings.HasPrefix(commandTag, prefix) {
			rows, hasRows = commandTag[len(prefix):], true
			commandTag = strings.TrimSuffix(prefix, " ")
			break
		}
	}

	// INSERT also includes the oid of the inserted row in its command tag
	// ("INSERT <oid> <rows>"). Oids in user tables are deprecated, and the oid
	// is only returned when exactly one row is inserted, so it's unlikely to
	// be of value to any real-world application and we ignore it.
	if !hasRows && strings.HasPrefix(commandTag, "INSERT ") {
		parts := strings.Split(commandTag, " ")
		if len(parts) != 3 {
			cn.err.set(driver.ErrBadConn)
			return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag)
		}
		rows, hasRows = parts[2], true
		commandTag = "INSERT"
	}

	// No row count attached to the tag: return the tag as-is.
	if !hasRows {
		return driver.RowsAffected(0), commandTag, nil
	}

	n, err := strconv.ParseInt(rows, 10, 64)
	if err != nil {
		cn.err.set(driver.ErrBadConn)
		return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err)
	}
	return driver.RowsAffected(n), commandTag, nil
}
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	digest := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", digest[:])
}
// sendBinaryParameters writes the parameter part of a Bind message to b: the
// parameter format codes followed by the parameter count and values.
//
// Arguments of type []byte use binary format (code 1) and everything else
// uses text (code 0). If no argument is a []byte, a format-code count of zero
// is written instead of per-parameter codes. A nil value, or a nil []byte, is
// sent as NULL (length -1). Every other value is encoded with binaryEncode,
// whose error is returned unchanged.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error {
	// The format codes precede the values in the message, so first find out
	// whether any argument needs binary format at all.
	binaryCount := 0
	for _, arg := range args {
		if _, isBytes := arg.Value.([]byte); isBytes {
			binaryCount++
		}
	}

	if binaryCount == 0 {
		b.int16(0)
	} else {
		b.int16(len(args))
		for _, arg := range args {
			code := 0
			if _, isBytes := arg.Value.([]byte); isBytes {
				code = 1
			}
			b.int16(code)
		}
	}

	b.int16(len(args))
	for _, arg := range args {
		raw, isBytes := arg.Value.([]byte)
		if arg.Value == nil || (isBytes && raw == nil) {
			b.int32(-1)
			continue
		}
		datum, err := binaryEncode(arg.Value)
		if err != nil {
			return err
		}
		b.int32(len(datum))
		b.bytes(datum)
	}
	return nil
}
// sendBinaryModeQuery sends query and args in one round trip using the
// extended query protocol. Parse, Bind, Describe (portal), Execute and Sync
// are all written into a single buffer and sent together. []byte arguments
// are sent in binary format (see sendBinaryParameters), and all result
// columns are requested in text format. More than 65535 arguments is
// rejected before anything is sent.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error {
	if len(args) >= 65536 {
		return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
	}

	b := cn.writeBuf(proto.Parse)
	b.byte(0) // unnamed statement
	b.string(query)
	b.int16(0) // no parameter types given; the server infers them

	b.next(proto.Bind)
	b.int16(0) // two empty names: unnamed portal and unnamed statement
	if err := cn.sendBinaryParameters(b, args); err != nil {
		return err
	}
	b.bytes(colFmtDataAllText)

	b.next(proto.Describe)
	// Describe target kind: 'P' for a portal ('S' would be a statement).
	// This is not the Parse message code, even though both are 'P'.
	b.byte('P')
	b.byte(0) // unnamed portal

	b.next(proto.Execute)
	b.byte(0)  // unnamed portal
	b.int32(0) // no row limit

	b.next(proto.Sync)
	return cn.send(b)
}
// processParameterStatus records the ParameterStatus values the driver
// tracks: server_version, TimeZone, in_hot_standby and
// default_transaction_read_only. Other parameters are ignored, as are values
// that fail to parse. A TimeZone that cannot be loaded resets
// currentLocation to nil.
func (cn *conn) processParameterStatus(r *readBuf) {
	name := r.string()
	switch name {
	case "server_version":
		// Encode the first two dotted components X.Y as X*10000 + Y*100.
		var first, second int
		if _, err := fmt.Sscanf(r.string(), "%d.%d", &first, &second); err == nil {
			cn.parameterStatus.serverVersion = first*10000 + second*100
		}
	case "TimeZone":
		loc, err := time.LoadLocation(r.string())
		if err != nil {
			loc = nil
		}
		cn.parameterStatus.currentLocation = loc
	// The next two are stored as sql.NullBool so that "false" can be told
	// apart from "never sent". If a server does not send them, a query is
	// used to get the value instead; it is unclear when that happens, but
	// this mirrors libpq.
	case "in_hot_standby":
		if v, err := pqutil.ParseBool(r.string()); err == nil {
			cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: v}
		}
	case "default_transaction_read_only":
		if v, err := pqutil.ParseBool(r.string()); err == nil {
			cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: v}
		}
	default:
		// Not a parameter the driver tracks.
	}
}
// processReadyForQuery stores the transaction status byte carried by a
// ReadyForQuery message.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := transactionStatus(r.byte())
	cn.txnStatus = status
}
// readReadyForQuery reads one message and expects it to be ReadyForQuery,
// whose transaction status is then recorded. If an ErrorResponse arrives
// instead, the server error is returned. Any other message produces a
// protocol error. Both of those cases mark the connection as bad.
func (cn *conn) readReadyForQuery() error {
	t, r, err := cn.recv1()
	if err != nil {
		return err
	}
	if t == proto.ReadyForQuery {
		cn.processReadyForQuery(r)
		return nil
	}
	if t == proto.ErrorResponse {
		serverErr := parseError(r, "")
		cn.err.set(driver.ErrBadConn)
		return serverErr
	}
	cn.err.set(driver.ErrBadConn)
	return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t)
}
// processBackendKeyData stores the process ID and secret key from a
// BackendKeyData message. The two int32 fields are read in message order.
func (cn *conn) processBackendKeyData(r *readBuf) {
	pid := r.int32()
	key := r.int32()
	cn.processID = pid
	cn.secretKey = key
}
// readParseResponse consumes the server's reply to a Parse message.
// ParseComplete yields nil. On ErrorResponse, the server error is returned
// after draining the following ReadyForQuery. Any other message marks the
// connection as bad.
func (cn *conn) readParseResponse() error {
	t, r, err := cn.recv1()
	switch {
	case err != nil:
		return err
	case t == proto.ParseComplete:
		return nil
	case t == proto.ErrorResponse:
		parseErr := parseError(r, "")
		// The Parse error is what the caller needs; the outcome of reading
		// ReadyForQuery is intentionally ignored.
		_ = cn.readReadyForQuery()
		return parseErr
	default:
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("pq: unexpected Parse response %q", t)
	}
}
// readStatementDescribeResponse reads the reply to a Describe of a prepared
// statement. ParameterDescription messages record the parameter type OIDs,
// and reading continues. NoData ends the reply with no columns, while a
// RowDescription ends it with the statement's column names and types. On
// ErrorResponse, the server error is returned after draining ReadyForQuery.
// Any other message marks the connection as bad.
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) {
	for {
		t, r, err := cn.recv1()
		if err != nil {
			return nil, nil, nil, err
		}

		if t == proto.ParameterDescription {
			paramTyps = make([]oid.Oid, r.int16())
			for i := range paramTyps {
				paramTyps[i] = r.oid()
			}
			continue
		}

		switch t {
		case proto.NoData:
			return paramTyps, nil, nil, nil
		case proto.RowDescription:
			colNames, colTyps = parseStatementRowDescribe(r)
			return paramTyps, colNames, colTyps, nil
		case proto.ErrorResponse:
			describeErr := parseError(r, "")
			_ = cn.readReadyForQuery()
			return nil, nil, nil, describeErr
		default:
			cn.err.set(driver.ErrBadConn)
			return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t)
		}
	}
}
// readPortalDescribeResponse reads the reply to a Describe of a portal.
// A RowDescription is decoded into the returned rowsHeader, and NoData yields
// an empty rowsHeader. On ErrorResponse, the server error is returned after
// draining ReadyForQuery. Any other message marks the connection as bad.
func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) {
	t, r, err := cn.recv1()
	if err != nil {
		return rowsHeader{}, err
	}
	if t == proto.RowDescription {
		return parsePortalRowDescribe(r), nil
	}
	if t == proto.NoData {
		return rowsHeader{}, nil
	}
	if t == proto.ErrorResponse {
		describeErr := parseError(r, "")
		_ = cn.readReadyForQuery()
		return rowsHeader{}, describeErr
	}
	cn.err.set(driver.ErrBadConn)
	return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response %q", t)
}
// readBindResponse consumes the server's reply to a Bind message.
// BindComplete yields nil. On ErrorResponse, the server error is returned
// after draining ReadyForQuery. Any other message marks the connection as
// bad.
func (cn *conn) readBindResponse() error {
	t, r, err := cn.recv1()
	if err != nil {
		return err
	}
	if t == proto.BindComplete {
		return nil
	}
	if t != proto.ErrorResponse {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("pq: unexpected Bind response %q", t)
	}
	bindErr := parseError(r, "")
	_ = cn.readReadyForQuery()
	return bindErr
}
// postExecuteWorkaround reads exactly one message after an Execute so that
// errors surface from the query call rather than from rows.Next.
//
// Background: sql.DB.QueryRow in Go 1.2 and earlier ignores any errors from
// rows.Next, which masks errors that happened during query execution. To
// avoid this in common cases, we wait here for one more message from the
// database. If it is not an error, the query will likely succeed (or already
// has, for CommandComplete), so the message is saved on the conn with
// saveMessage; recv1 returns it next, for rows.Next or rows.Close. If it is
// an error, we drain up to ReadyForQuery and return the error to the caller.
// Any other message marks the connection as bad.
func (cn *conn) postExecuteWorkaround() error {
	t, r, err := cn.recv1()
	if err != nil {
		return err
	}
	switch t {
	case proto.ErrorResponse:
		err = parseError(r, "")
		// The query error is what the caller needs; readReadyForQuery marks
		// the connection bad itself if reading ReadyForQuery fails.
		_ = cn.readReadyForQuery()
		return err
	case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse:
		// The query didn't fail, but this message can't be processed here.
		return cn.saveMessage(t, r)
	default:
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("pq: unexpected message during extended query execution: %q", t)
	}
}
// readExecuteResponse reads the server's replies to an Execute, up to and
// including ReadyForQuery, and returns the result and command tag of the
// executed command. Any row data is discarded, so this is only for Exec().
// protocolState names the protocol phase in the "unknown response" error.
//
// An ErrorResponse does not end the loop; the error is kept and returned
// once ReadyForQuery arrives. A CommandComplete or row data following an
// error, or an unrecognised message, marks the connection as bad. Those
// errors wrap the earlier server error, so errors.Is and errors.As can
// still reach it.
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) {
	for {
		t, r, err := cn.recv1()
		if err != nil {
			return nil, "", err
		}
		switch t {
		case proto.CommandComplete:
			if resErr != nil {
				cn.err.set(driver.ErrBadConn)
				return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %w", resErr)
			}
			res, commandTag, err = cn.parseComplete(r.string())
			if err != nil {
				return nil, "", err
			}
		case proto.ReadyForQuery:
			cn.processReadyForQuery(r)
			if res == nil && resErr == nil {
				// Neither CommandComplete nor EmptyQueryResponse was seen.
				resErr = errUnexpectedReady
			}
			return res, commandTag, resErr
		case proto.ErrorResponse:
			resErr = parseError(r, "")
		case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse:
			if resErr != nil {
				cn.err.set(driver.ErrBadConn)
				return nil, "", fmt.Errorf("pq: unexpected %q after error %w", t, resErr)
			}
			if t == proto.EmptyQueryResponse {
				res = emptyRows
			}
			// Any returned rows are ignored.
		default:
			cn.err.set(driver.ErrBadConn)
			return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t)
		}
	}
}
// parseStatementRowDescribe decodes a RowDescription received for a
// statement Describe and returns each column's name and type information
// (OID, length, modifier). For each field, the 6 bytes after the name are
// skipped, and the trailing format code is ignored. The format code is not
// yet known when a statement (rather than a portal) is described, and is
// always 0.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
	count := r.int16()
	colNames = make([]string, count)
	colTyps = make([]fieldDesc, count)
	for i := range colNames {
		colNames[i] = r.string()
		r.next(6)
		desc := &colTyps[i]
		desc.OID = r.oid()
		desc.Len = r.int16()
		desc.Mod = r.int32()
		r.next(2) // format code
	}
	return colNames, colTyps
}
// parsePortalRowDescribe decodes a RowDescription received for a portal
// Describe into a rowsHeader. The header holds each column's name, result
// format and type information (OID, length, modifier). The 6 bytes after each
// field name are skipped.
func parsePortalRowDescribe(r *readBuf) rowsHeader {
	count := r.int16()
	hdr := rowsHeader{
		colNames: make([]string, count),
		colFmts:  make([]format, count),
		colTyps:  make([]fieldDesc, count),
	}
	for i := range hdr.colNames {
		hdr.colNames[i] = r.string()
		r.next(6)
		desc := &hdr.colTyps[i]
		desc.OID = r.oid()
		desc.Len = r.int16()
		desc.Mod = r.int32()
		hdr.colFmts[i] = format(r.int16())
	}
	return hdr
}
// ResetSession implements driver.SessionResetter by returning the error
// recorded on the connection, if any; ctx is unused.
//
// Bad connections must be reported here as well as in IsValid. Per
// database/sql/driver, a connection that is never returned to the pool but
// is reused immediately gets ResetSession called before reuse, and IsValid
// is not called.
func (cn *conn) ResetSession(ctx context.Context) error {
	recorded := cn.err.get()
	return recorded
}
// IsValid reports whether the connection may be reused: true as long as no
// error has been recorded on it.
func (cn *conn) IsValid() bool {
	recorded := cn.err.get()
	return recorded == nil
}