package pq import ( "bufio" "context" "crypto/md5" "crypto/sha256" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" "math" "net" "os" "reflect" "strconv" "strings" "sync" "time" "github.com/lib/pq/internal/pgpass" "github.com/lib/pq/internal/pqsql" "github.com/lib/pq/internal/pqutil" "github.com/lib/pq/internal/proto" "github.com/lib/pq/oid" "github.com/lib/pq/scram" ) // Common error types var ( ErrNotSupported = errors.New("pq: unsupported command") ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly") ErrSSLKeyUnknownOwnership = pqutil.ErrSSLKeyUnknownOwnership ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) // Compile time validation that our types implement the expected interfaces var ( _ driver.Driver = Driver{} _ driver.ConnBeginTx = (*conn)(nil) _ driver.ConnPrepareContext = (*conn)(nil) _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x _ driver.ExecerContext = (*conn)(nil) _ driver.NamedValueChecker = (*conn)(nil) _ driver.Pinger = (*conn)(nil) _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x _ driver.QueryerContext = (*conn)(nil) _ driver.SessionResetter = (*conn)(nil) _ driver.Validator = (*conn)(nil) _ driver.StmtExecContext = (*stmt)(nil) _ driver.StmtQueryContext = (*stmt)(nil) ) func init() { sql.Register("postgres", &Driver{}) } var debugProto = func() bool { // Check for exactly "1" (rather than mere existence) so we can add // options/flags in the future. I don't know if we ever want that, but it's // nice to leave the option open. return os.Getenv("PQGO_DEBUG") == "1" }() // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. func (d Driver) Open(name string) (driver.Conn, error) { return Open(name) } // Parameters sent by PostgreSQL on startup. type parameterStatus struct { serverVersion int currentLocation *time.Location inHotStandby, defaultTransactionReadOnly sql.NullBool } type format int const ( formatText format = 0 formatBinary format = 1 ) var ( // One result-column format code with the value 1 (i.e. all binary). colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). colFmtDataAllText = []byte{0, 0} ) type transactionStatus byte const ( txnStatusIdle transactionStatus = 'I' txnStatusIdleInTransaction transactionStatus = 'T' txnStatusInFailedTransaction transactionStatus = 'E' ) func (s transactionStatus) String() string { switch s { case txnStatusIdle: return "idle" case txnStatusIdleInTransaction: return "idle in transaction" case txnStatusInFailedTransaction: return "in a failed transaction" default: panic(fmt.Sprintf("pq: unknown transactionStatus %d", s)) } } // Dialer is the dialer interface. It can be used to obtain more control over // how pq creates network connections. type Dialer interface { Dial(network, address string) (net.Conn, error) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) } // DialerContext is the context-aware dialer interface. type DialerContext interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } type defaultDialer struct { d net.Dialer } func (d defaultDialer) Dial(network, address string) (net.Conn, error) { return d.d.Dial(network, address) } func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return d.DialContext(ctx, network, address) } func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.d.DialContext(ctx, network, address) } type conn struct { c net.Conn buf *bufio.Reader namei int scratch [512]byte txnStatus transactionStatus txnFinish func() // Save connection arguments to use during CancelRequest. dialer Dialer cfg Config parameterStatus parameterStatus saveMessageType proto.ResponseCode saveMessageBuffer []byte // If an error is set, this connection is bad and all public-facing // functions should return the appropriate error by calling get() // (ErrBadConn) or getForNext(). err syncErr processID, secretKey int // Cancellation key data for use with CancelRequest messages. inCopy bool // If true this connection is in the middle of a COPY noticeHandler func(*Error) // If not nil, notices will be synchronously sent here notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here gss GSS // GSSAPI context } type syncErr struct { err error sync.Mutex } // Return ErrBadConn if connection is bad. func (e *syncErr) get() error { e.Lock() defer e.Unlock() if e.err != nil { return driver.ErrBadConn } return nil } // Return the error set on the connection. Currently only used by rows.Next. func (e *syncErr) getForNext() error { e.Lock() defer e.Unlock() return e.err } // Set error, only if it isn't set yet. func (e *syncErr) set(err error) { if err == nil { panic("attempt to set nil err") } e.Lock() defer e.Unlock() if e.err == nil { e.err = err } } func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf { cn.scratch[0] = byte(b) return &writeBuf{ buf: cn.scratch[:5], pos: 1, } } // Open opens a new connection to the database. dsn is a connection string. Most // users should only use it through database/sql package from the standard // library. func Open(dsn string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, dsn) } // DialOpen opens a new connection to the database using a dialer. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { c, err := NewConnector(dsn) if err != nil { return nil, err } c.Dialer(d) return c.open(context.Background()) } func (c *Connector) open(ctx context.Context) (*conn, error) { tsa := c.cfg.TargetSessionAttrs restart: var ( errs []error app = func(err error, cfg Config) bool { if err != nil { if debugProto { fmt.Println("CONNECT (error)", err) } errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) } return err != nil } ) for _, cfg := range c.cfg.hosts() { if debugProto { fmt.Println("CONNECT ", cfg.string()) } cn := &conn{cfg: cfg, dialer: c.dialer} cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database, cn.cfg.isset("password")) var err error cn.c, err = dial(ctx, c.dialer, cn.cfg) if app(err, cfg) { continue } err = cn.ssl(cn.cfg) if app(err, cfg) { if cn.c != nil { _ = cn.c.Close() } continue } cn.buf = bufio.NewReader(cn.c) err = cn.startup(cn.cfg) if app(err, cfg) { _ = cn.c.Close() continue } // Reset the deadline, in case one was set (see dial) if cn.cfg.ConnectTimeout > 0 { err := cn.c.SetDeadline(time.Time{}) if app(err, cfg) { _ = cn.c.Close() continue } } err = cn.checkTSA(tsa) if app(err, cfg) { _ = cn.c.Close() continue } return cn, nil } // target_session_attrs=prefer-standby is treated as standby in checkTSA; we // ran out of hosts so none are on standby. Clear the setting and try again. if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { tsa = TargetSessionAttrsAny goto restart } if len(c.cfg.Multi) == 0 { // Remove the "connecting to [..]" when we have just one host, so the // error is identical to what we had before. return nil, errors.Unwrap(errs[0]) } return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...)) } func (cn *conn) getBool(query string) (bool, error) { res, err := cn.simpleQuery(query) if err != nil { return false, err } defer res.Close() v := make([]driver.Value, 1) err = res.Next(v) if err != nil { return false, err } switch vv := v[0].(type) { default: return false, fmt.Errorf("parseBool: unknown type %T: %[1]v", v[0]) case bool: return vv, nil case string: vv, ok := v[0].(string) if !ok { return false, err } return vv == "on", nil } } func (cn *conn) checkTSA(tsa TargetSessionAttrs) error { var ( geths = func() (hs bool, err error) { hs = cn.parameterStatus.inHotStandby.Bool if !cn.parameterStatus.inHotStandby.Valid { hs, err = cn.getBool("select pg_catalog.pg_is_in_recovery()") } return hs, err } getro = func() (ro bool, err error) { ro = cn.parameterStatus.defaultTransactionReadOnly.Bool if !cn.parameterStatus.defaultTransactionReadOnly.Valid { ro, err = cn.getBool("show transaction_read_only") } return ro, err } ) switch tsa { default: panic("unreachable") case "", TargetSessionAttrsAny: return nil case TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly: readonly, err := getro() if err != nil { return err } if !cn.parameterStatus.defaultTransactionReadOnly.Valid { var err error readonly, err = cn.getBool("show transaction_read_only") if err != nil { return err } } switch { case tsa == TargetSessionAttrsReadOnly && !readonly: return errors.New("session is not read-only") case tsa == TargetSessionAttrsReadWrite: if readonly { return errors.New("session is read-only") } hs, err := geths() if err != nil { return err } if hs { return errors.New("server is in hot standby mode") } return nil default: return nil } case TargetSessionAttrsPrimary, TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby: hs, err := geths() if err != nil { return err } switch { case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs: return errors.New("server is not in hot standby mode") case tsa == TargetSessionAttrsPrimary && hs: return errors.New("server is in hot standby mode") default: return nil } } } func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) { network, address := cfg.network() // Zero or not specified means wait indefinitely. if cfg.ConnectTimeout > 0 { // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection // establishment and set a deadline for doing the initial handshake. The // deadline is then reset after startup() is done. var ( deadline = time.Now().Add(cfg.ConnectTimeout) conn net.Conn err error ) if dctx, ok := d.(DialerContext); ok { ctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) defer cancel() conn, err = dctx.DialContext(ctx, network, address) } else { conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout) } if err != nil { return nil, err } err = conn.SetDeadline(deadline) return conn, err } if dctx, ok := d.(DialerContext); ok { return dctx.DialContext(ctx, network, address) } return d.Dial(network, address) } func (cn *conn) isInTransaction() bool { return cn.txnStatus == txnStatusIdleInTransaction || cn.txnStatus == txnStatusInFailedTransaction } func (cn *conn) checkIsInTransaction(intxn bool) error { if cn.isInTransaction() != intxn { cn.err.set(driver.ErrBadConn) return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus) } return nil } func (cn *conn) Begin() (_ driver.Tx, err error) { return cn.begin("") } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if err := cn.err.get(); err != nil { return nil, err } if err := cn.checkIsInTransaction(false); err != nil { return nil, err } _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, cn.handleError(err) } if commandTag != "BEGIN" { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil } func (cn *conn) closeTxn() { if finish := cn.txnFinish; finish != nil { finish() } } func (cn *conn) Commit() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } if err := cn.checkIsInTransaction(true); err != nil { return err } // We don't want the client to think that everything is okay if it tries // to commit a failed transaction. However, no matter what we return, // database/sql will release this connection back into the free connection // pool so we have to abort the current transaction here. Note that you // would get the same behaviour if you issued a COMMIT in a failed // transaction, so it's also the least surprising thing to do here. if cn.txnStatus == txnStatusInFailedTransaction { if err := cn.rollback(); err != nil { return err } return ErrInFailedTransaction } _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { cn.err.set(driver.ErrBadConn) } return cn.handleError(err) } if commandTag != "COMMIT" { cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected command tag %s", commandTag) } return cn.checkIsInTransaction(false) } func (cn *conn) Rollback() error { defer cn.closeTxn() if err := cn.err.get(); err != nil { return err } err := cn.rollback() if err != nil { return cn.handleError(err) } return nil } func (cn *conn) rollback() (err error) { if err := cn.checkIsInTransaction(true); err != nil { return err } _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { cn.err.set(driver.ErrBadConn) } return err } if commandTag != "ROLLBACK" { return fmt.Errorf("unexpected command tag %s", commandTag) } return cn.checkIsInTransaction(false) } func (cn *conn) gname() string { cn.namei++ return strconv.FormatInt(int64(cn.namei), 10) } func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { if debugProto { fmt.Fprintf(os.Stderr, " START conn.simpleExec\n") defer fmt.Fprintf(os.Stderr, " END conn.simpleExec\n") } b := cn.writeBuf(proto.Query) b.string(q) err := cn.send(b) if err != nil { return nil, "", err } for { t, r, err := cn.recv1() if err != nil { return nil, "", err } switch t { case proto.CommandComplete: res, commandTag, err = cn.parseComplete(r.string()) if err != nil { return nil, "", err } case proto.ReadyForQuery: cn.processReadyForQuery(r) if res == nil && resErr == nil { resErr = errUnexpectedReady } return res, commandTag, resErr case proto.ErrorResponse: resErr = parseError(r, q) case proto.EmptyQueryResponse: res = emptyRows case proto.RowDescription, proto.DataRow: // ignore any results default: cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t) } } } func (cn *conn) simpleQuery(q string) (*rows, error) { if debugProto { fmt.Fprintf(os.Stderr, " START conn.simpleQuery\n") defer fmt.Fprintf(os.Stderr, " END conn.simpleQuery\n") } b := cn.writeBuf(proto.Query) b.string(q) err := cn.send(b) if err != nil { return nil, cn.handleError(err, q) } var ( res *rows resErr error ) for { t, r, err := cn.recv1() if err != nil { return nil, cn.handleError(err, q) } switch t { case proto.CommandComplete, proto.EmptyQueryResponse: // We allow queries which don't return any results through Query as // well as Exec. We still have to give database/sql a rows object // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if resErr != nil { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t) } if res == nil { res = &rows{cn: cn} } // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. if t == proto.CommandComplete { res.result, res.tag, err = cn.parseComplete(r.string()) if err != nil { return nil, cn.handleError(err, q) } if res.colNames != nil { return res, cn.handleError(resErr, q) } } res.done = true case proto.ReadyForQuery: cn.processReadyForQuery(r) if err == nil && res == nil { res = &rows{done: true} } return res, cn.handleError(resErr, q) // done case proto.ErrorResponse: res = nil resErr = parseError(r, q) case proto.DataRow: if res == nil { cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution") } return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next case proto.RowDescription: // res might be non-nil here if we received a previous // CommandComplete, but that's fine and just overwrite it. res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)} // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: cn.err.set(driver.ErrBadConn) return nil, fmt.Errorf("pq: unknown response for simple query: %q", t) } } } // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) { if len(colTyps) == 0 { return nil, colFmtDataAllText, nil } colFmts = make([]format, len(colTyps)) if forceText { return colFmts, colFmtDataAllText, nil } allBinary := true allText := true for i, t := range colTyps { switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. case oid.T_bytea: fallthrough case oid.T_int8: fallthrough case oid.T_int4: fallthrough case oid.T_int2: fallthrough case oid.T_uuid: colFmts[i] = formatBinary allText = false default: allBinary = false } } if allBinary { return colFmts, colFmtDataAllBinary, nil } else if allText { return colFmts, colFmtDataAllText, nil } else { colFmtData = make([]byte, 2+len(colFmts)*2) if len(colFmts) > math.MaxUint16 { return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts)) } binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) for i, v := range colFmts { binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) } return colFmts, colFmtData, nil } } func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { if debugProto { fmt.Fprintf(os.Stderr, " START conn.prepareTo\n") defer fmt.Fprintf(os.Stderr, " END conn.prepareTo\n") } st := &stmt{cn: cn, name: stmtName} b := cn.writeBuf(proto.Parse) b.string(st.name) b.string(q) b.int16(0) b.next(proto.Describe) b.byte(proto.Sync) b.string(st.name) b.next(proto.Sync) err := cn.send(b) if err != nil { return nil, err } err = cn.readParseResponse() if err != nil { return nil, err } st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() if err != nil { return nil, err } st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult) if err != nil { return nil, err } err = cn.readReadyForQuery() if err != nil { return nil, err } return st, nil } func (cn *conn) Prepare(q string) (driver.Stmt, error) { if err := cn.err.get(); err != nil { return nil, err } if pqsql.StartsWithCopy(q) { s, err := cn.prepareCopyIn(q) if err == nil { cn.inCopy = true } return s, cn.handleError(err, q) } s, err := cn.prepareTo(q, cn.gname()) if err != nil { return nil, cn.handleError(err, q) } return s, nil } func (cn *conn) Close() error { // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. err := cn.sendSimpleMessage(proto.Terminate) if err != nil { _ = cn.c.Close() // Ensure that cn.c.Close is always run. return cn.handleError(err) } return cn.c.Close() } func toNamedValue(v []driver.Value) []driver.NamedValue { v2 := make([]driver.NamedValue, len(v)) for i := range v { v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} } return v2 } // CheckNamedValue implements [driver.NamedValueChecker]. func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { // Ignore Valuer, for backward compatibility with pq.Array(). if _, ok := nv.Value.(driver.Valuer); ok { return driver.ErrSkip } v := reflect.ValueOf(nv.Value) if !v.IsValid() { return driver.ErrSkip } t := v.Type() for t.Kind() == reflect.Ptr { t, v = t.Elem(), v.Elem() } // Ignore []byte and related types: *[]byte, json.RawMessage, etc. if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { return driver.ErrSkip } switch v.Kind() { default: return driver.ErrSkip case reflect.Slice: var err error nv.Value, err = Array(v.Interface()).Value() return err case reflect.Uint64: value := v.Uint() if value >= math.MaxInt64 { nv.Value = strconv.FormatUint(value, 10) } else { nv.Value = int64(value) } return nil } } // Implement the "Queryer" interface func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { return cn.query(query, toNamedValue(args)) } func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { if debugProto { fmt.Fprintf(os.Stderr, " START conn.query\n") defer fmt.Fprintf(os.Stderr, " END conn.query\n") } if err := cn.err.get(); err != nil { return nil, err } if cn.inCopy { return nil, errCopyInProgress } // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec if len(args) == 0 { return cn.simpleQuery(query) } if cn.cfg.BinaryParameters { err := cn.sendBinaryModeQuery(query, args) if err != nil { return nil, cn.handleError(err, query) } err = cn.readParseResponse() if err != nil { return nil, cn.handleError(err, query) } err = cn.readBindResponse() if err != nil { return nil, cn.handleError(err, query) } rows := &rows{cn: cn} rows.rowsHeader, err = cn.readPortalDescribeResponse() if err != nil { return nil, cn.handleError(err, query) } err = cn.postExecuteWorkaround() if err != nil { return nil, cn.handleError(err, query) } return rows, nil } st, err := cn.prepareTo(query, "") if err != nil { return nil, cn.handleError(err, query) } err = st.exec(args) if err != nil { return nil, cn.handleError(err, query) } return &rows{ cn: cn, rowsHeader: st.rowsHeader, }, nil } // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { if err := cn.err.get(); err != nil { return nil, err } // Check to see if we can use the "simpleExec" interface, which is *much* // faster than going through prepare/exec if len(args) == 0 { // ignore commandTag, our caller doesn't care r, _, err := cn.simpleExec(query) return r, cn.handleError(err, query) } if cn.cfg.BinaryParameters { err := cn.sendBinaryModeQuery(query, toNamedValue(args)) if err != nil { return nil, cn.handleError(err, query) } err = cn.readParseResponse() if err != nil { return nil, cn.handleError(err, query) } err = cn.readBindResponse() if err != nil { return nil, cn.handleError(err, query) } _, err = cn.readPortalDescribeResponse() if err != nil { return nil, cn.handleError(err, query) } err = cn.postExecuteWorkaround() if err != nil { return nil, cn.handleError(err, query) } res, _, err := cn.readExecuteResponse("Execute") return res, cn.handleError(err, query) } // Use the unnamed statement to defer planning until bind time, or else // value-based selectivity estimates cannot be used. st, err := cn.prepareTo(query, "") if err != nil { return nil, cn.handleError(err, query) } r, err := st.Exec(args) if err != nil { return nil, cn.handleError(err, query) } return r, nil } type safeRetryError struct{ Err error } func (se *safeRetryError) Error() string { return se.Err.Error() } func (cn *conn) send(m *writeBuf) error { if debugProto { w := m.wrap() for len(w) > 0 { // Can contain multiple messages. c := proto.RequestCode(w[0]) l := int(binary.BigEndian.Uint32(w[1:5])) - 4 fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5]) w = w[l+5:] } } n, err := cn.c.Write(m.wrap()) if err != nil && n == 0 { err = &safeRetryError{Err: err} } return err } func (cn *conn) sendStartupPacket(m *writeBuf) error { if debugProto { w := m.wrap() fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", "Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:]) } _, err := cn.c.Write((m.wrap())[1:]) return err } // Send a message of type typ to the server on the other end of cn. The message // should have no payload. This method does not use the scratch buffer. func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { if debugProto { fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", proto.RequestCode(typ), 0, []byte{}) } _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) return err } // saveMessage memorizes a message and its buffer in the conn struct. // recvMessage will then return these values on the next call to it. This // method is useful in cases where you have to see what the next message is // going to be (e.g. to see whether it's an error or not) but you can't handle // the message yourself. func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error { if cn.saveMessageType != 0 { cn.err.set(driver.ErrBadConn) return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ cn.saveMessageBuffer = *buf return nil } // recvMessage receives any message from the backend, or returns an error if // a problem occurred while reading the message. func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { // workaround for a QueryRow bug, see exec if cn.saveMessageType != 0 { t := cn.saveMessageType *r = cn.saveMessageBuffer cn.saveMessageType = 0 cn.saveMessageBuffer = nil return t, nil } x := cn.scratch[:5] _, err := io.ReadFull(cn.buf, x) if err != nil { return 0, err } // Read the type and length of the message that follows. t := proto.ResponseCode(x[0]) n := int(binary.BigEndian.Uint32(x[1:])) - 4 // When PostgreSQL cannot start a backend (e.g., an external process limit), // it sends plain text like "Ecould not fork new process [..]", which // doesn't use the standard encoding for the Error message. // // libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)", // but check < 4 since n represents bytes remaining to be read after length. if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) { msg, _ := cn.buf.ReadString('\x00') return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00")) } var y []byte if n <= len(cn.scratch) { y = cn.scratch[:n] } else { y = make([]byte, n) } _, err = io.ReadFull(cn.buf, y) if err != nil { return 0, err } *r = y if debugProto { fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) } return t, nil } // recv receives a message from the backend, returning an error if an error // happened while reading the message or the received message an ErrorResponse. // NoticeResponses are ignored. This function should generally be used only // during the startup sequence. func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) { for { r := new(readBuf) t, err := cn.recvMessage(r) if err != nil { return 0, nil, err } switch t { case proto.ErrorResponse: return 0, nil, parseError(r, "") case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { n(parseError(r, "")) } case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } default: return t, r, nil } } } // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by // the caller to avoid an allocation. func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) { for { t, err := cn.recvMessage(r) if err != nil { return 0, err } switch t { case proto.NotificationResponse: if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } case proto.NoticeResponse: if n := cn.noticeHandler; n != nil { n(parseError(r, "")) } case proto.ParameterStatus: cn.processParameterStatus(r) default: return t, nil } } } // recv1 receives a message from the backend, returning an error if an error // happened while reading the message or the received message an ErrorResponse. // All asynchronous messages are ignored, with the exception of ErrorResponse. func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { r := new(readBuf) t, err := cn.recv1Buf(r) if err != nil { return 0, nil, err } return t, r, nil } func (cn *conn) ssl(cfg Config) error { upgrade, err := ssl(cfg) if err != nil { return err } if upgrade == nil { // Nothing to do return nil } // Only negotiate the ssl handshake if requested (which is the default). // sllnegotiation=direct is supported by pg17 and above. if cfg.SSLNegotiation != SSLNegotiationDirect { w := cn.writeBuf(0) w.int32(proto.NegotiateSSLCode) if err = cn.sendStartupPacket(w); err != nil { return err } b := cn.scratch[:1] _, err = io.ReadFull(cn.c, b) if err != nil { return err } if b[0] != 'S' { return ErrSSLNotSupported } } cn.c, err = upgrade(cn.c) return err } func (cn *conn) startup(cfg Config) error { w := cn.writeBuf(0) w.int32(proto.ProtocolVersion30) if cfg.User != "" { w.string("user") w.string(cfg.User) } if cfg.Database != "" { w.string("database") w.string(cfg.Database) } // w.string("replication") // Sent by libpq, but we don't support that. if cfg.Options != "" { w.string("options") w.string(cfg.Options) } if cfg.ApplicationName != "" { w.string("application_name") w.string(cfg.ApplicationName) } if cfg.ClientEncoding != "" { w.string("client_encoding") w.string(cfg.ClientEncoding) } for k, v := range cfg.Runtime { w.string(k) w.string(v) } w.string("") if err := cn.sendStartupPacket(w); err != nil { return err } for { t, r, err := cn.recv() if err != nil { return err } switch t { case proto.BackendKeyData: cn.processBackendKeyData(r) case proto.ParameterStatus: cn.processParameterStatus(r) case proto.AuthenticationRequest: err := cn.auth(r, cfg) if err != nil { return err } case proto.ReadyForQuery: cn.processReadyForQuery(r) return nil default: return fmt.Errorf("pq: unknown response for startup: %q", t) } } } func (cn *conn) auth(r *readBuf, cfg Config) error { switch code := proto.AuthCode(r.int32()); code { default: return fmt.Errorf("pq: unknown authentication response: %s", code) case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: return fmt.Errorf("pq: unsupported authentication method: %s", code) case proto.AuthReqOk: return nil case proto.AuthReqPassword: w := cn.writeBuf(proto.PasswordMessage) w.string(cfg.Password) // Don't need to check AuthOk response here; auth() is called in a loop, // which catches the errors and AuthReqOk responses. return cn.send(w) case proto.AuthReqMD5: s := string(r.next(4)) w := cn.writeBuf(proto.PasswordMessage) w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) // Same here. return cn.send(w) case proto.AuthReqGSS: // GSSAPI, startup if newGss == nil { return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)") } cli, err := newGss() if err != nil { return fmt.Errorf("pq: kerberos error: %w", err) } var token []byte if cfg.isset("krbspn") { // Use the supplied SPN if provided.. token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn) } else { // Allow the kerberos service name to be overridden service := "postgres" if cfg.isset("krbsrvname") { service = cfg.KrbSrvname } token, err = cli.GetInitToken(cfg.Host, service) } if err != nil { return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err) } w := cn.writeBuf(proto.GSSResponse) w.bytes(token) err = cn.send(w) if err != nil { return err } // Store for GSSAPI continue message cn.gss = cli return nil case proto.AuthReqGSSCont: // GSSAPI continue if cn.gss == nil { return errors.New("pq: GSSAPI protocol error") } done, tokOut, err := cn.gss.Continue([]byte(*r)) if err == nil && !done { w := cn.writeBuf(proto.SASLInitialResponse) w.bytes(tokOut) err = cn.send(w) if err != nil { return err } } // Errors fall through and read the more detailed message from the // server. return nil case proto.AuthReqSASL: sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) sc.Step(nil) if sc.Err() != nil { return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut := sc.Out() w := cn.writeBuf(proto.SASLResponse) w.string("SCRAM-SHA-256") w.int32(len(scOut)) w.bytes(scOut) err := cn.send(w) if err != nil { return err } t, r, err := cn.recv() if err != nil { return err } if t != proto.AuthenticationRequest { return fmt.Errorf("pq: unexpected password response: %q", t) } if r.int32() != int(proto.AuthReqSASLCont) { return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep := r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } scOut = sc.Out() w = cn.writeBuf(proto.SASLResponse) w.bytes(scOut) err = cn.send(w) if err != nil { return err } t, r, err = cn.recv() if err != nil { return err } if t != proto.AuthenticationRequest { return fmt.Errorf("pq: unexpected password response: %q", t) } if r.int32() != int(proto.AuthReqSASLFin) { return fmt.Errorf("pq: unexpected authentication response: %q", t) } nextStep = r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) } return nil } } // parseComplete parses the "command tag" from a CommandComplete message, and // returns the number of rows affected (if applicable) and a string identifying // only the command that was executed, e.g. "ALTER TABLE". Returns an error if // the command can cannot be parsed. func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) { commandsWithAffectedRows := []string{ "SELECT ", // INSERT is handled below "UPDATE ", "DELETE ", "FETCH ", "MOVE ", "COPY ", } var affectedRows *string for _, tag := range commandsWithAffectedRows { if strings.HasPrefix(commandTag, tag) { t := commandTag[len(tag):] affectedRows = &t commandTag = tag[:len(tag)-1] break } } // INSERT also includes the oid of the inserted row in its command tag. Oids // in user tables are deprecated, and the oid is only returned when exactly // one row is inserted, so it's unlikely to be of value to any real-world // application and we can ignore it. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] commandTag = "INSERT" } // There should be no affected rows attached to the tag, just return it if affectedRows == nil { return driver.RowsAffected(0), commandTag, nil } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err) } return driver.RowsAffected(n), commandTag, nil } func md5s(s string) string { h := md5.New() h.Write([]byte(s)) return fmt.Sprintf("%x", h.Sum(nil)) } func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error { // Do one pass over the parameters to see if we're going to send any of them // over in binary. If we are, create a paramFormats array at the same time. var paramFormats []int for i, x := range args { _, ok := x.Value.([]byte) if ok { if paramFormats == nil { paramFormats = make([]int, len(args)) } paramFormats[i] = 1 } } if paramFormats == nil { b.int16(0) } else { b.int16(len(paramFormats)) for _, x := range paramFormats { b.int16(x) } } b.int16(len(args)) for _, x := range args { if x.Value == nil { b.int32(-1) } else if xx, ok := x.Value.([]byte); ok && xx == nil { b.int32(-1) } else { datum, err := binaryEncode(x.Value) if err != nil { return err } b.int32(len(datum)) b.bytes(datum) } } return nil } func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error { if len(args) >= 65536 { return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } b := cn.writeBuf(proto.Parse) b.byte(0) // unnamed statement b.string(query) b.int16(0) b.next(proto.Bind) b.int16(0) // unnamed portal and statement err := cn.sendBinaryParameters(b, args) if err != nil { return err } b.bytes(colFmtDataAllText) b.next(proto.Describe) b.byte(proto.Parse) b.byte(0) // unnamed portal b.next(proto.Execute) b.byte(0) b.int32(0) b.next(proto.Sync) return cn.send(b) } func (cn *conn) processParameterStatus(r *readBuf) { switch r.string() { default: // ignore case "server_version": var major1, major2 int _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) if err == nil { cn.parameterStatus.serverVersion = major1*10000 + major2*100 } case "TimeZone": var err error cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { cn.parameterStatus.currentLocation = nil } // Use sql.NullBool so we can distinguish between false and not sent. If // it's not sent we use a query to get the value – I don't know when these // parameters are not sent, but this is what libpq does. case "in_hot_standby": b, err := pqutil.ParseBool(r.string()) if err == nil { cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: b} } case "default_transaction_read_only": b, err := pqutil.ParseBool(r.string()) if err == nil { cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: b} } } } func (cn *conn) processReadyForQuery(r *readBuf) { cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() error { t, r, err := cn.recv1() if err != nil { return err } switch t { case proto.ReadyForQuery: cn.processReadyForQuery(r) return nil case proto.ErrorResponse: err := parseError(r, "") cn.err.set(driver.ErrBadConn) return err default: cn.err.set(driver.ErrBadConn) return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t) } } func (cn *conn) processBackendKeyData(r *readBuf) { cn.processID = r.int32() cn.secretKey = r.int32() } func (cn *conn) readParseResponse() error { t, r, err := cn.recv1() if err != nil { return err } switch t { case proto.ParseComplete: return nil case proto.ErrorResponse: err := parseError(r, "") _ = cn.readReadyForQuery() return err default: cn.err.set(driver.ErrBadConn) return fmt.Errorf("pq: unexpected Parse response %q", t) } } func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) { for { t, r, err := cn.recv1() if err != nil { return nil, nil, nil, err } switch t { case proto.ParameterDescription: nparams := r.int16() paramTyps = make([]oid.Oid, nparams) for i := range paramTyps { paramTyps[i] = r.oid() } case proto.NoData: return paramTyps, nil, nil, nil case proto.RowDescription: colNames, colTyps = parseStatementRowDescribe(r) return paramTyps, colNames, colTyps, nil case proto.ErrorResponse: err := parseError(r, "") _ = cn.readReadyForQuery() return nil, nil, nil, err default: cn.err.set(driver.ErrBadConn) return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t) } } } func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) { t, r, err := cn.recv1() if err != nil { return rowsHeader{}, err } switch t { case proto.RowDescription: return parsePortalRowDescribe(r), nil case proto.NoData: return rowsHeader{}, nil case proto.ErrorResponse: err := parseError(r, "") _ = cn.readReadyForQuery() return rowsHeader{}, err default: cn.err.set(driver.ErrBadConn) return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response %q", t) } } func (cn *conn) readBindResponse() error { t, r, err := cn.recv1() if err != nil { return err } switch t { case proto.BindComplete: return nil case proto.ErrorResponse: err := parseError(r, "") _ = cn.readReadyForQuery() return err default: cn.err.set(driver.ErrBadConn) return fmt.Errorf("pq: unexpected Bind response %q", t) } } func (cn *conn) postExecuteWorkaround() error { // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores // any errors from rows.Next, which masks errors that happened during the // execution of the query. To avoid the problem in common cases, we wait // here for one more message from the database. If it's not an error the // query will likely succeed (or perhaps has already, if it's a // CommandComplete), so we push the message into the conn struct; recv1 // will return it as the next message for rows.Next or rows.Close. // However, if it's an error, we wait until ReadyForQuery and then return // the error to our caller. for { t, r, err := cn.recv1() if err != nil { return err } switch t { case proto.ErrorResponse: err := parseError(r, "") _ = cn.readReadyForQuery() return err case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse: // the query didn't fail, but we can't process this message return cn.saveMessage(t, r) default: cn.err.set(driver.ErrBadConn) return fmt.Errorf("pq: unexpected message during extended query execution: %q", t) } } } // Only for Exec(), since we ignore the returned data func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) { for { t, r, err := cn.recv1() if err != nil { return nil, "", err } switch t { case proto.CommandComplete: if resErr != nil { cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %s", resErr) } res, commandTag, err = cn.parseComplete(r.string()) if err != nil { return nil, "", err } case proto.ReadyForQuery: cn.processReadyForQuery(r) if res == nil && resErr == nil { resErr = errUnexpectedReady } return res, commandTag, resErr case proto.ErrorResponse: resErr = parseError(r, "") case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse: if resErr != nil { cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: unexpected %q after error %s", t, resErr) } if t == proto.EmptyQueryResponse { res = emptyRows } // ignore any results default: cn.err.set(driver.ErrBadConn) return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t) } } } func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) colTyps[i].OID = r.oid() colTyps[i].Len = r.int16() colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() colNames := make([]string, n) colFmts := make([]format, n) colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) colTyps[i].OID = r.oid() colTyps[i].Len = r.int16() colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return rowsHeader{ colNames: colNames, colFmts: colFmts, colTyps: colTyps, } } func (cn *conn) ResetSession(ctx context.Context) error { // Ensure bad connections are reported: From database/sql/driver: // If a connection is never returned to the connection pool but immediately reused, then // ResetSession is called prior to reuse but IsValid is not called. return cn.err.get() } func (cn *conn) IsValid() bool { return cn.err.get() == nil }