gotosocial/vendor/github.com/uptrace/bun/schema/scan.go

517 lines
9.9 KiB
Go

package schema
import (
"bytes"
"database/sql"
"fmt"
"net"
"reflect"
"strconv"
"strings"
"time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/extra/bunjson"
"github.com/uptrace/bun/internal"
)
var scannerType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
type ScannerFunc func(dest reflect.Value, src interface{}) error
var scanners []ScannerFunc
func init() {
scanners = []ScannerFunc{
reflect.Bool: scanBool,
reflect.Int: scanInt64,
reflect.Int8: scanInt64,
reflect.Int16: scanInt64,
reflect.Int32: scanInt64,
reflect.Int64: scanInt64,
reflect.Uint: scanUint64,
reflect.Uint8: scanUint64,
reflect.Uint16: scanUint64,
reflect.Uint32: scanUint64,
reflect.Uint64: scanUint64,
reflect.Uintptr: scanUint64,
reflect.Float32: scanFloat64,
reflect.Float64: scanFloat64,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Interface: scanInterface,
reflect.Map: scanJSON,
reflect.Ptr: nil,
reflect.Slice: scanJSON,
reflect.String: scanString,
reflect.Struct: scanJSON,
reflect.UnsafePointer: nil,
}
}
var scannerCache = xsync.NewMapOf[reflect.Type, ScannerFunc]()
func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("msgpack") {
return scanMsgpack
}
if field.Tag.HasOption("json_use_number") {
return scanJSONUseNumber
}
if field.StructField.Type.Kind() == reflect.Interface {
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON, sqltype.JSONB:
return scanJSONIntoInterface
}
}
return Scanner(field.StructField.Type)
}
func Scanner(typ reflect.Type) ScannerFunc {
if v, ok := scannerCache.Load(typ); ok {
return v
}
fn := scanner(typ)
if v, ok := scannerCache.LoadOrStore(typ, fn); ok {
return v
}
return fn
}
func scanner(typ reflect.Type) ScannerFunc {
kind := typ.Kind()
if kind == reflect.Ptr {
if fn := Scanner(typ.Elem()); fn != nil {
return PtrScanner(fn)
}
}
switch typ {
case bytesType:
return scanBytes
case timeType:
return scanTime
case ipType:
return scanIP
case ipNetType:
return scanIPNet
case jsonRawMessageType:
return scanBytes
}
if typ.Implements(scannerType) {
return scanScanner
}
if kind != reflect.Ptr {
ptr := reflect.PointerTo(typ)
if ptr.Implements(scannerType) {
return addrScanner(scanScanner)
}
}
if typ.Kind() == reflect.Slice && typ.Elem().Kind() == reflect.Uint8 {
return scanBytes
}
return scanners[kind]
}
func scanBool(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetBool(false)
return nil
case bool:
dest.SetBool(src)
return nil
case int64:
dest.SetBool(src != 0)
return nil
case []byte:
f, err := strconv.ParseBool(internal.String(src))
if err != nil {
return err
}
dest.SetBool(f)
return nil
case string:
f, err := strconv.ParseBool(src)
if err != nil {
return err
}
dest.SetBool(f)
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanInt64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetInt(0)
return nil
case int64:
dest.SetInt(src)
return nil
case uint64:
dest.SetInt(int64(src))
return nil
case []byte:
n, err := strconv.ParseInt(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
case string:
n, err := strconv.ParseInt(src, 10, 64)
if err != nil {
return err
}
dest.SetInt(n)
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanUint64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetUint(0)
return nil
case uint64:
dest.SetUint(src)
return nil
case int64:
dest.SetUint(uint64(src))
return nil
case []byte:
n, err := strconv.ParseUint(internal.String(src), 10, 64)
if err != nil {
return err
}
dest.SetUint(n)
return nil
case string:
n, err := strconv.ParseUint(src, 10, 64)
if err != nil {
return err
}
dest.SetUint(n)
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanFloat64(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetFloat(0)
return nil
case float64:
dest.SetFloat(src)
return nil
case []byte:
f, err := strconv.ParseFloat(internal.String(src), 64)
if err != nil {
return err
}
dest.SetFloat(f)
return nil
case string:
f, err := strconv.ParseFloat(src, 64)
if err != nil {
return err
}
dest.SetFloat(f)
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanString(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetString("")
return nil
case string:
dest.SetString(src)
return nil
case []byte:
dest.SetString(string(src))
return nil
case time.Time:
dest.SetString(src.Format(time.RFC3339Nano))
return nil
case int64:
dest.SetString(strconv.FormatInt(src, 10))
return nil
case uint64:
dest.SetString(strconv.FormatUint(src, 10))
return nil
case float64:
dest.SetString(strconv.FormatFloat(src, 'G', -1, 64))
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanBytes(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
dest.SetBytes(nil)
return nil
case string:
dest.SetBytes([]byte(src))
return nil
case []byte:
clone := make([]byte, len(src))
copy(clone, src)
dest.SetBytes(clone)
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanTime(dest reflect.Value, src interface{}) error {
switch src := src.(type) {
case nil:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = time.Time{}
return nil
case time.Time:
destTime := dest.Addr().Interface().(*time.Time)
*destTime = src
return nil
case string:
srcTime, err := internal.ParseTime(src)
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
case []byte:
srcTime, err := internal.ParseTime(internal.String(src))
if err != nil {
return err
}
destTime := dest.Addr().Interface().(*time.Time)
*destTime = srcTime
return nil
default:
return scanError(dest.Type(), src)
}
}
func scanScanner(dest reflect.Value, src interface{}) error {
return dest.Interface().(sql.Scanner).Scan(src)
}
func scanMsgpack(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := msgpack.GetDecoder()
defer msgpack.PutDecoder(dec)
dec.Reset(bytes.NewReader(b))
return dec.DecodeValue(dest)
}
func scanJSON(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
return bunjson.Unmarshal(b, dest.Addr().Interface())
}
func scanJSONUseNumber(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
dec := bunjson.NewDecoder(bytes.NewReader(b))
dec.UseNumber()
return dec.Decode(dest.Addr().Interface())
}
func scanIP(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
ip := net.ParseIP(internal.String(b))
if ip == nil {
return fmt.Errorf("bun: invalid ip: %q", b)
}
ptr := dest.Addr().Interface().(*net.IP)
*ptr = ip
return nil
}
func scanIPNet(dest reflect.Value, src interface{}) error {
if src == nil {
return scanNull(dest)
}
b, err := toBytes(src)
if err != nil {
return err
}
_, ipnet, err := net.ParseCIDR(internal.String(b))
if err != nil {
return err
}
ptr := dest.Addr().Interface().(*net.IPNet)
*ptr = *ipnet
return nil
}
func addrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if !dest.CanAddr() {
return fmt.Errorf("bun: Scan(nonaddressable %T)", dest.Interface())
}
return fn(dest.Addr(), src)
}
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return internal.Bytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}
func PtrScanner(fn ScannerFunc) ScannerFunc {
return func(dest reflect.Value, src interface{}) error {
if src == nil {
if !dest.CanAddr() {
if dest.IsNil() {
return nil
}
return fn(dest.Elem(), src)
}
if !dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return nil
}
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
if dest.Kind() == reflect.Map {
return fn(dest, src)
}
return fn(dest.Elem(), src)
}
}
func scanNull(dest reflect.Value) error {
if nilable(dest.Kind()) && dest.IsNil() {
return nil
}
dest.Set(reflect.New(dest.Type()).Elem())
return nil
}
func scanJSONIntoInterface(dest reflect.Value, src interface{}) error {
if dest.IsNil() {
if src == nil {
return nil
}
b, err := toBytes(src)
if err != nil {
return err
}
return bunjson.Unmarshal(b, dest.Addr().Interface())
}
dest = dest.Elem()
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
return scanError(dest.Type(), src)
}
func scanInterface(dest reflect.Value, src interface{}) error {
if dest.IsNil() {
if src == nil {
return nil
}
dest.Set(reflect.ValueOf(src))
return nil
}
dest = dest.Elem()
if fn := Scanner(dest.Type()); fn != nil {
return fn(dest, src)
}
return scanError(dest.Type(), src)
}
func nilable(kind reflect.Kind) bool {
switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
return false
}
func scanError(dest reflect.Type, src interface{}) error {
return fmt.Errorf("bun: can't scan %#v (%T) into %s", src, src, dest.String())
}