package pq

import (
	"bytes"
	"database/sql"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"math"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/lib/pq/oid"
)

// time2400Regex matches a time of day written as 24:00, 24:00:00 or
// 24:00:00.000..., optionally followed by a zone suffix that starts with
// 'Z', '+' or '-'. The first submatch is the 24:00 part without the zone.
var time2400Regex = regexp.MustCompile(`^(24:00(?::00(?:\.0+)?)?)(?:[Z+-].*)?$`)

// binaryEncode encodes x for a parameter sent in binary format. Byte slices
// are passed through untouched; every other value falls back to its text
// encoding (see encode) with an unknown type OID.
func binaryEncode(x any) ([]byte, error) {
	switch v := x.(type) {
	case []byte:
		return v, nil
	default:
		return encode(x, oid.T_unknown)
	}
}

// encode converts x to its Postgres text representation.
//
// pgtypOid is the type of the target parameter. It only matters for []byte
// and string values, which are hex-encoded when the target is bytea.
// A nil []byte yields a nil result (presumably sent as SQL NULL by the
// caller — confirm). Unsupported Go types produce an error.
func encode(x any, pgtypOid oid.Oid) ([]byte, error) {
	switch v := x.(type) {
	case int64:
		return strconv.AppendInt(nil, v, 10), nil
	case float64:
		return strconv.AppendFloat(nil, v, 'f', -1, 64), nil
	case []byte:
		if v == nil {
			return nil, nil
		}
		if pgtypOid == oid.T_bytea {
			return encodeBytea(v), nil
		}
		return v, nil
	case string:
		if pgtypOid == oid.T_bytea {
			return encodeBytea([]byte(v)), nil
		}
		return []byte(v), nil
	case bool:
		return strconv.AppendBool(nil, v), nil
	case time.Time:
		return formatTS(v), nil
	default:
		return nil, fmt.Errorf("pq: encode: unknown type for %T", v)
	}
}

// decode converts a value received from the server into a Go value,
// dispatching on the wire format f.
func decode(ps *parameterStatus, s []byte, typ oid.Oid, f format) (any, error) {
	switch f {
	case formatBinary:
		return binaryDecode(s, typ)
	case formatText:
		return textDecode(ps, s, typ)
	default:
		panic("unreachable")
	}
}

// binaryDecode converts a binary-format value of type typ.
//
// Only bytea, int8, int4, int2 and uuid are supported; other types produce
// an error. An integer whose byte length does not match the type's fixed
// width is rejected with an error instead of panicking on the read.
func binaryDecode(s []byte, typ oid.Oid) (any, error) {
	switch typ {
	case oid.T_bytea:
		return s, nil
	case oid.T_int8:
		if len(s) != 8 {
			return nil, fmt.Errorf("pq: unable to decode int8; bad length: %d", len(s))
		}
		return int64(binary.BigEndian.Uint64(s)), nil
	case oid.T_int4:
		if len(s) != 4 {
			return nil, fmt.Errorf("pq: unable to decode int4; bad length: %d", len(s))
		}
		return int64(int32(binary.BigEndian.Uint32(s))), nil
	case oid.T_int2:
		if len(s) != 2 {
			return nil, fmt.Errorf("pq: unable to decode int2; bad length: %d", len(s))
		}
		return int64(int16(binary.BigEndian.Uint16(s))), nil
	case oid.T_uuid:
		// decodeUUIDBinary's errors already carry the "pq: " prefix, so they
		// are returned as-is rather than prefixed a second time.
		b, err := decodeUUIDBinary(s)
		if err != nil {
			return nil, err
		}
		return b, nil
	default:
		return nil, fmt.Errorf("pq: don't know how to decode binary parameter of type %d", uint32(typ))
	}
}

// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format.
func decodeUUIDBinary(src []byte) ([]byte, error) { if len(src) != 16 { return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) } dst := make([]byte, 36) dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' hex.Encode(dst[0:], src[0:4]) hex.Encode(dst[9:], src[4:6]) hex.Encode(dst[14:], src[6:8]) hex.Encode(dst[19:], src[8:10]) hex.Encode(dst[24:], src[10:16]) return dst, nil } func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { switch typ { case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text: return string(s), nil case oid.T_bytea: b, err := parseBytea(s) if err != nil { err = errors.New("pq: " + err.Error()) } return b, err case oid.T_timestamptz: return parseTS(ps.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: return parseTS(nil, string(s)) case oid.T_time: return parseTime("15:04:05", typ, s) case oid.T_timetz: return parseTime("15:04:05-07", typ, s) case oid.T_bool: return s[0] == 't', nil case oid.T_int8, oid.T_int4, oid.T_int2: i, err := strconv.ParseInt(string(s), 10, 64) if err != nil { err = errors.New("pq: " + err.Error()) } return i, err case oid.T_float4, oid.T_float8: // We always use 64 bit parsing, regardless of whether the input text is for // a float4 or float8, because clients expect float64s for all float datatypes // and returning a 32-bit parsed float64 produces lossy results. 
f, err := strconv.ParseFloat(string(s), 64) if err != nil { err = errors.New("pq: " + err.Error()) } return f, err } return s, nil } // appendEncodedText encodes item in text format as required by COPY // and appends to buf func appendEncodedText(buf []byte, x any) ([]byte, error) { switch v := x.(type) { case int64: return strconv.AppendInt(buf, v, 10), nil case float64: return strconv.AppendFloat(buf, v, 'f', -1, 64), nil case []byte: encodedBytea := encodeBytea(v) return appendEscapedText(buf, string(encodedBytea)), nil case string: return appendEscapedText(buf, v), nil case bool: return strconv.AppendBool(buf, v), nil case time.Time: return append(buf, formatTS(v)...), nil case nil: return append(buf, "\\N"...), nil default: return nil, fmt.Errorf("pq: encode: unknown type for %T", v) } } func appendEscapedText(buf []byte, text string) []byte { escapeNeeded := false startPos := 0 var c byte // check if we need to escape for i := 0; i < len(text); i++ { c = text[i] if c == '\\' || c == '\n' || c == '\r' || c == '\t' { escapeNeeded = true startPos = i break } } if !escapeNeeded { return append(buf, text...) } // copy till first char to escape, iterate the rest result := append(buf, text[:startPos]...) for i := startPos; i < len(text); i++ { c = text[i] switch c { case '\\': result = append(result, '\\', '\\') case '\n': result = append(result, '\\', 'n') case '\r': result = append(result, '\\', 'r') case '\t': result = append(result, '\\', 't') default: result = append(result, c) } } return result } func parseTime(f string, typ oid.Oid, s []byte) (time.Time, error) { str := string(s) // Check for a minute and second offset in the timezone. if typ == oid.T_timestamptz || typ == oid.T_timetz { for i := 3; i <= 6; i += 3 { if str[len(str)-i] == ':' { f += ":00" continue } break } } // Special case for 24:00 time. // Unfortunately, golang does not parse 24:00 as a proper time. // In this case, we want to try "round to the next day", to differentiate. 
// As such, we find if the 24:00 time matches at the beginning; if so, // we default it back to 00:00 but add a day later. var is2400Time bool switch typ { case oid.T_timetz, oid.T_time: if matches := time2400Regex.FindStringSubmatch(str); matches != nil { // Concatenate timezone information at the back. str = "00:00:00" + str[len(matches[1]):] is2400Time = true } } t, err := time.Parse(f, str) if err != nil { return time.Time{}, errors.New("pq: " + err.Error()) } if is2400Time { t = t.Add(24 * time.Hour) } return t, nil } var errInvalidTimestamp = errors.New("invalid timestamp") type timestampParser struct { err error } func (p *timestampParser) expect(str string, char byte, pos int) { if p.err != nil { return } if pos+1 > len(str) { p.err = errInvalidTimestamp return } if c := str[pos]; c != char && p.err == nil { p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) } } func (p *timestampParser) mustAtoi(str string, begin int, end int) int { if p.err != nil { return 0 } if begin < 0 || end < 0 || begin > end || end > len(str) { p.err = errInvalidTimestamp return 0 } result, err := strconv.Atoi(str[begin:end]) if err != nil { if p.err == nil { p.err = fmt.Errorf("expected number; got '%v'", str) } return 0 } return result } // The location cache caches the time zones typically used by the client. type locationCache struct { cache map[int]*time.Location lock sync.Mutex } // All connections share the same list of timezones. Benchmarking shows that // about 5% speed could be gained by putting the cache in the connection and // losing the mutex, at the cost of a small amount of memory and a somewhat // significant increase in code complexity. var globalLocationCache = newLocationCache() func newLocationCache() *locationCache { return &locationCache{cache: make(map[int]*time.Location)} } // Returns the cached timezone for the specified offset, creating and caching // it if necessary. 
func (c *locationCache) getLocation(offset int) *time.Location { c.lock.Lock() defer c.lock.Unlock() location, ok := c.cache[offset] if !ok { location = time.FixedZone("", offset) c.cache[offset] = location } return location } var ( infinityTSEnabled = false infinityTSNegative time.Time infinityTSPositive time.Time ) const ( infinityTSEnabledAlready = "pq: infinity timestamp enabled already" infinityTSNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" ) // EnableInfinityTs controls the handling of Postgres' "-infinity" and // "infinity" "timestamp"s. // // If EnableInfinityTs is not called, "-infinity" and "infinity" will return // []byte("-infinity") and []byte("infinity") respectively, and potentially // cause error "sql: Scan error on column index 0: unsupported driver -> Scan // pair: []uint8 -> *time.Time", when scanning into a time.Time value. // // Once EnableInfinityTs has been called, all connections created using this // driver will decode Postgres' "-infinity" and "infinity" for "timestamp", // "timestamp with time zone" and "date" types to the predefined minimum and // maximum times, respectively. When encoding time.Time values, any time which // equals or precedes the predefined minimum time will be encoded to // "-infinity". Any values at or past the maximum time will similarly be // encoded to "infinity". // // If EnableInfinityTs is called with negative >= positive, it will panic. // Calling EnableInfinityTs after a connection has been established results in // undefined behavior. If EnableInfinityTs is called more than once, it will // panic. 
func EnableInfinityTs(negative time.Time, positive time.Time) {
	switch {
	case infinityTSEnabled:
		panic(infinityTSEnabledAlready)
	case !negative.Before(positive):
		panic(infinityTSNegativeMustBeSmaller)
	}

	infinityTSEnabled = true
	infinityTSNegative, infinityTSPositive = negative, positive
}

// disableInfinityTS switches infinity handling back off so that tests can
// toggle infinityTSEnabled.
func disableInfinityTS() {
	infinityTSEnabled = false
}

// parseTS parses a timestamp, timestamptz or date value in the Postgres
// default DateStyle ("ISO, MDY"), the only style currently supported.
// "-infinity" and "infinity" map to the configured bounds when
// EnableInfinityTs has been called and are otherwise returned as raw
// bytes. Parse errors are prefixed with "pq: ".
func parseTS(currentLocation *time.Location, str string) (any, error) {
	if str == "-infinity" || str == "infinity" {
		if !infinityTSEnabled {
			return []byte(str), nil
		}
		if str[0] == '-' {
			return infinityTSNegative, nil
		}
		return infinityTSPositive, nil
	}

	t, err := ParseTimestamp(currentLocation, str)
	if err != nil {
		return t, errors.New("pq: " + err.Error())
	}
	return t, nil
}

// ParseTimestamp parses Postgres' text format. It returns a time.Time in
// currentLocation iff that time's offset agrees with the offset sent from the
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
// fixed offset provided by the Postgres server.
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { p := timestampParser{} monSep := strings.IndexRune(str, '-') // this is Gregorian year, not ISO Year // In Gregorian system, the year 1 BC is followed by AD 1 year := p.mustAtoi(str, 0, monSep) daySep := monSep + 3 month := p.mustAtoi(str, monSep+1, daySep) p.expect(str, '-', daySep) timeSep := daySep + 3 day := p.mustAtoi(str, daySep+1, timeSep) minLen := monSep + len("01-01") + 1 isBC := strings.HasSuffix(str, " BC") if isBC { minLen += 3 } var hour, minute, second int if len(str) > minLen { p.expect(str, ' ', timeSep) minSep := timeSep + 3 p.expect(str, ':', minSep) hour = p.mustAtoi(str, timeSep+1, minSep) secSep := minSep + 3 p.expect(str, ':', secSep) minute = p.mustAtoi(str, minSep+1, secSep) secEnd := secSep + 3 second = p.mustAtoi(str, secSep+1, secEnd) } remainderIdx := monSep + len("01-01 00:00:00") + 1 // Three optional (but ordered) sections follow: the // fractional seconds, the time zone offset, and the BC // designation. We set them up here and adjust the other // offsets if the preceding sections exist. nanoSec := 0 tzOff := 0 if remainderIdx < len(str) && str[remainderIdx] == '.' 
{ fracStart := remainderIdx + 1 fracOff := strings.IndexAny(str[fracStart:], "-+Z ") if fracOff < 0 { fracOff = len(str) - fracStart } fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) remainderIdx += fracOff + 1 } if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { // time zone separator is always '-' or '+' or 'Z' (UTC is +00) var tzSign int switch c := str[tzStart]; c { case '-': tzSign = -1 case '+': tzSign = +1 default: return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) } tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) remainderIdx += 3 var tzMin, tzSec int if remainderIdx < len(str) && str[remainderIdx] == ':' { tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } if remainderIdx < len(str) && str[remainderIdx] == ':' { tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } else if tzStart < len(str) && str[tzStart] == 'Z' { // time zone Z separator indicates UTC is +00 remainderIdx += 1 } var isoYear int if isBC { isoYear = 1 - year remainderIdx += 3 } else { isoYear = year } if remainderIdx < len(str) { return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) } t := time.Date(isoYear, time.Month(month), day, hour, minute, second, nanoSec, globalLocationCache.getLocation(tzOff)) if currentLocation != nil { // Set the location of the returned Time based on the session's // TimeZone value, but only if the local time zone database agrees with // the remote database on the offset. lt := t.In(currentLocation) _, newOff := lt.Zone() if newOff == tzOff { t = lt } } return t, p.err } // formatTS formats t into a format postgres understands. func formatTS(t time.Time) []byte { if infinityTSEnabled { // t <= -infinity : ! 
(t > -infinity) if !t.After(infinityTSNegative) { return []byte("-infinity") } // t >= infinity : ! (!t < infinity) if !t.Before(infinityTSPositive) { return []byte("infinity") } } return FormatTimestamp(t) } // FormatTimestamp formats t into Postgres' text format for timestamps. func FormatTimestamp(t time.Time) []byte { // Need to send dates before 0001 A.D. with " BC" suffix, instead of the // minus sign preferred by Go. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on bc := false if t.Year() <= 0 { // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() offset %= 60 if offset != 0 { // RFC3339Nano already printed the minus sign if offset < 0 { offset = -offset } b = append(b, ':') if offset < 10 { b = append(b, '0') } b = strconv.AppendInt(b, int64(offset), 10) } if bc { b = append(b, " BC"...) } return b } // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. func parseBytea(s []byte) (result []byte, err error) { if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { // bytea_output = hex s = s[2:] // trim off leading "\\x" result = make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(result, s) if err != nil { return nil, err } } else { // bytea_output = escape for len(s) > 0 { if s[0] == '\\' { // escaped '\\' if len(s) >= 2 && s[1] == '\\' { result = append(result, '\\') s = s[2:] continue } // '\\' followed by an octal number if len(s) < 4 { return nil, fmt.Errorf("invalid bytea sequence %v", s) } r, err := strconv.ParseUint(string(s[1:4]), 8, 8) if err != nil { return nil, fmt.Errorf("could not parse bytea value: %w", err) } result = append(result, byte(r)) s = s[4:] } else { // We hit an unescaped, raw byte. Try to read in as many as // possible in one go. i := bytes.IndexByte(s, '\\') if i == -1 { result = append(result, s...) 
break } result = append(result, s[:i]...) s = s[i:] } } } return result, nil } func encodeBytea(v []byte) (result []byte) { result = make([]byte, 2+hex.EncodedLen(len(v))) result[0] = '\\' result[1] = 'x' hex.Encode(result[2:], v) return result } // NullTime represents a [time.Time] that may be null. // NullTime implements the [sql.Scanner] interface so // it can be used as a scan destination, similar to [sql.NullString]. // // Deprecated: this is an alias for [sql.NullTime]. type NullTime = sql.NullTime