likes
comments
collection
share

GORM 自定义类型 -- Pointer Receiver 踩坑记录

作者站长头像
站长
· 阅读数 172

背景

使用过 GORM 的同学都知道,当业务场景中需要存储的自定义数据类型,与数据库中的数据类型不直接对应时,可以将对自定义类型实现 GORM 提供的 sql.Scanner 和 driver.Valuer 接口,定义如何将 Go 自定义类型转换为数据库内置的类型,以及如何从数据库的内置类型转换回 Go 自定义的类型。

问题复现

有一天,小明同学需要在某张表中存储一个 JSON 字符串数组,即 []string 类型。参考之前同学写过的代码:

type AppConfig struct {
    AppID     string `json:"app_id"`
    AppSecret string `json:"app_secret"`
}

// Scan implements the Scanner interface for StringArray
func (c *AppConfig) Scan(value interface{}) error {
    bytes, ok := value.([]byte)
    if !ok {
       return errors.New("type assertion to []byte failed")
    }

    return json.Unmarshal(bytes, c)
}

// Value implements the driver Valuer interface for StringArray
func (c *AppConfig) Value() (driver.Value, error) {
    return json.Marshal(c)
}

于是小明照葫芦画瓢,按照同样的方式定义了自定义类型,并实现了 sql.Scanner 和 driver.Valuer接口:

// StringArray is a custom type to handle []string
type StringArray []string

// Scan implements the Scanner interface for StringArray
func (sa *StringArray) Scan(value interface{}) error {
    bytes, ok := value.([]byte)
    if !ok {
       return errors.New("type assertion to []byte failed")
    }

    return json.Unmarshal(bytes, sa)
}

// Value implements the driver Valuer interface for StringArray
func (sa *StringArray) Value() (driver.Value, error) {
    return json.Marshal(sa)
}

然后,小明将这个字段放到了表结构定义中:

type DBMyTable struct {
    ID        int64          `gorm:"column:id"`
    SA        StringArray    `gorm:"column:sa"`
    // ... 省略了其他字段
    CreatedAt time.Time      `gorm:"column:created_at"`
    UpdatedAt time.Time      `gorm:"column:updated_at`
    DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:datetime"`
}

func (DBMyTable) TableName() string {
    return "my_table"
}

最后,小明实现了一个操作表更新的函数:

type MyTableUpdateParams struct {
    SA StringArray
    // ...
}

func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
    db := mysql.GetConn(ctx)
    
    updateParams := make(map[string]interface{})
    
    // ... 省略了其他方法
    
    if len(params.SA) > 0 {
        updateParams["sa"] = params.SA
    }
    
    if len(updateParams) == 0 {
        return nil
    }
    err := db.Model(&DBMyTable{}).
        Where("id = ?", id).
        Updates(updateParams).Error
       
   if err != nil {
       return err
   }
   
   return nil
}

然而,当小明测试这段代码的时候,发现事情并没有像自己想象的一样顺利。当传入的 SA 参数为 []string{"xiaoming"} 时,最终执行的 SQL 更新语句却是:

UPDATE `my_table` SET `sa`='xiaoming' WHERE id = 1

这与预期完全不符,SA 类型的字段本应该经过 JSON 序列化,变为 [xiaoming] 字符序列存入表中,然后再读取的时候再进行反序列化。但此时存储表中的却是一个字符串,在反序列化 JSON 的时候将会报错。

问题分析

经过单步调试,我们定位到 GORM 中将输入的变量转为 SQL 语句结构的函数:

// AddVar add var
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
    for idx, v := range vars {
       if idx > 0 {
          writer.WriteByte(',')
       }

       switch v := v.(type) {
       case sql.NamedArg:
          stmt.Vars = append(stmt.Vars, v.Value)
       case clause.Column, clause.Table:
          stmt.QuoteTo(writer, v)
       case Valuer:
          reflectValue := reflect.ValueOf(v)
          if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
             stmt.AddVar(writer, nil)
          } else {
             stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
          }
       case clause.Interface:
          c := clause.Clause{Name: v.Name()}
          v.MergeClause(&c)
          c.Build(stmt)
       case clause.Expression:
          v.Build(stmt)
       case driver.Valuer:
          stmt.Vars = append(stmt.Vars, v)
          stmt.DB.Dialector.BindVarTo(writer, stmt, v)
       case []byte:
          stmt.Vars = append(stmt.Vars, v)
          stmt.DB.Dialector.BindVarTo(writer, stmt, v)
       case []interface{}:
          if len(v) > 0 {
             writer.WriteByte('(')
             stmt.AddVar(writer, v...)
             writer.WriteByte(')')
          } else {
             writer.WriteString("(NULL)")
          }
       case *DB:
          subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
          if v.Statement.SQL.Len() > 0 {
             var (
                vars = subdb.Statement.Vars
                sql  = v.Statement.SQL.String()
             )

             subdb.Statement.Vars = make([]interface{}, 0, len(vars))
             for _, vv := range vars {
                subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
                bindvar := strings.Builder{}
                v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
                sql = strings.Replace(sql, bindvar.String(), "?", 1)
             }

             subdb.Statement.SQL.Reset()
             subdb.Statement.Vars = stmt.Vars
             if strings.Contains(sql, "@") {
                clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
             } else {
                clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
             }
          } else {
             subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
             subdb.callbacks.Query().Execute(subdb)
          }

          writer.WriteString(subdb.Statement.SQL.String())
          stmt.Vars = subdb.Statement.Vars
       default:
          switch rv := reflect.ValueOf(v); rv.Kind() {
          case reflect.Slice, reflect.Array:
             if rv.Len() == 0 {
                writer.WriteString("(NULL)")
             } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
                stmt.Vars = append(stmt.Vars, v)
                stmt.DB.Dialector.BindVarTo(writer, stmt, v)
             } else {
                writer.WriteByte('(')
                for i := 0; i < rv.Len(); i++ {
                   if i > 0 {
                      writer.WriteByte(',')
                   }
                   stmt.AddVar(writer, rv.Index(i).Interface())
                }
                writer.WriteByte(')')
             }
          default:
             stmt.Vars = append(stmt.Vars, v)
             stmt.DB.Dialector.BindVarTo(writer, stmt, v)
          }
       }
    }
}

在这里,我们预期执行的是标红的 case 分支,也就是 StringArray 预期是实现了 driver.Valuer 的接口,但实际上,代码却执行了绿色的分支,进入了 default 分支。且由于 StringArray 本质上是一个切片,GORM 直接遍历了切片的所有元素进行递归。

由于切片的元素类型是字符串类型,在递归后,执行了标黄部分的代码后,GORM 直接将一个字符串而非 StringArray类型的变量存入了语句中。接着,下面的 GORM 代码就将上面转换后的变量转为数据库内置的值:


func IsValue(v any) bool {
    if v == nil {
       return true
    }
    switch v.(type) {
    case []byte, bool, float64, int64, string, time.Time:
       return true
    case decimalDecompose:
       return true
    }
    return false
}

func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
    if driver.IsValue(v) {
       return v, nil
    }

    if vr, ok := v.(driver.Valuer); ok {
       sv, err := callValuerValue(vr)
       if err != nil {
          return nil, err
       }
       if driver.IsValue(sv) {
          return sv, nil
       }
       // A value returned from the Valuer interface can be "a type handled by
       // a database driver's NamedValueChecker interface" so we should accept
       // uint64 here as well.
       if u, ok := sv.(uint64); ok {
          return u, nil
       }
       return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
    }
    rv := reflect.ValueOf(v)
    switch rv.Kind() {
    case reflect.Ptr:
       // indirect pointers
       if rv.IsNil() {
          return nil, nil
       } else {
          return c.ConvertValue(rv.Elem().Interface())
       }
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
       return rv.Int(), nil
    case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
       return rv.Uint(), nil
    case reflect.Float32, reflect.Float64:
       return rv.Float(), nil
    case reflect.Bool:
       return rv.Bool(), nil
    case reflect.Slice:
       switch t := rv.Type(); {
       case t == jsonType:
          return v, nil
       case t.Elem().Kind() == reflect.Uint8:
          return rv.Bytes(), nil
       default:
          return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
       }
    case reflect.String:
       return rv.String(), nil
    }
    return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

可以看到,如果是字符串类型的变量,会直接经过标绿的判断,返回为一个合法的内置数据库的值。而如果实现了 driver.Valuer 的接口,则会经过标红的部分,调用用户自定义的 Value 方法,执行数据类型转换。

func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
    if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
       rv.IsNil() &&
       rv.Type().Elem().Implements(valuerReflectType) {
       return nil, nil
    }
    return vr.Value()
}

至此,经过单步调试,小明总算是清楚了非预期 SQL 语句的来源。但为什么仿照之前同学写的代码出现了问题呢?之前写的代码会不会也有问题?

经过查询 GORM 官方的资料,发现是这样描述的:

自定义的数据类型必须实现 ScannerValuer 接口,以便让 GORM 知道如何将该类型接收、保存到数据库

例如:

type JSON json.RawMessage

// 实现 sql.Scanner 接口,Scan 将 value 扫描至 Jsonb
func (j *JSON) Scan(value interface{}) error {
    bytes, ok := value.([]byte)
    if !ok {
        return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
    }

    result := json.RawMessage{}
    err := json.Unmarshal(bytes, &result)
    *j = JSON(result)
    return err
}

// 实现 driver.Valuer 接口,Value 返回 json value
func (j JSON) Value() (driver.Value, error) {
    if len(j) == 0 {
        return nil, nil
    }
    return json.RawMessage(j).MarshalJSON()
}

来源:gorm.io/zh\_CN/docs…

我们注意到了一个细节,官方的例子在定义 Value 方法时,使用的 receiver 是值类型的,而非指针类型的,但小明在代码中定义的 Value 方法,使用了指针类型的 receiver:

// Value implements the driver Valuer interface for StringArray
func (sa *StringArray) Value() (driver.Value, error) {
    return json.Marshal(sa)
}

这里一定有问题!

问题解决

小明尝试将指针类型的 receiver 换成了值类型的 receiver,问题解决。

// Value implements the driver Valuer interface for StringArray
func (sa StringArray) Value() (driver.Value, error) {
    return json.Marshal(sa)
}

但同样是实现 driver.Valuer 的方法,使用值类型的 receiver 和指针类型的 receiver,为什么差距会这么大?之前写的代码有没有问题?

查阅 Go 语法的定义:

An interface* type* is defined as a set of method signatures.

A value of interface type can hold any value that implements those methods.

Note**:** There is an error in the example code on line 22. Vertex (the value type) doesn't implement Abser because the Abs method is defined only on *Vertex (the pointer type).

package main

import (
        "fmt"
        "math"
)

type Abser interface {
        Abs() float64
}

func main() {
        var a Abser
        f := MyFloat(-math.Sqrt2)
        v := Vertex{3, 4}

        a = f  // a MyFloat implements Abser
        a = &v // a *Vertex implements Abser

        // In the following line, v is a Vertex (not *Vertex)
        // and does NOT implement Abser.
        a = v

        fmt.Println(a.Abs())
}

type MyFloat float64

func (f MyFloat) Abs() float64 {
        if f < 0 {
                return float64(-f)
        }
        return float64(f)
}

type Vertex struct {
        X, Y float64
}

func (v *Vertex) Abs() float64 {
        return math.Sqrt(v.X*v.X + v.Y*v.Y)
}

来源:go.dev/tour/method…

回到小明的例子上,虽然小明定义了 Value 方法,但是使用的是指针类型的 receiver,所以只有 *StringArray 类型才实现了 driver.Valuer 接口,而 StringArray 类型并没有实现 driver.Valuer 接口。不幸的是,小明在更新的时候,赋值的是 StringArray 的对象,而非*StringArray类型的地址。

func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
    db := mysql.GetConn(ctx)
    
    updateParams := make(map[string]interface{})
    
    // ... 省略了其他方法
    
    if len(params.SA) > 0 {
        updateParams["sa"] = params.SA
    }
    
    if len(updateParams) == 0 {
        return nil
    }
    err := db.Model(&DBMyTable{}).
        Where("id = ?", id).
        Updates(updateParams).Error
       
   if err != nil {
       return err
   }
   
   return nil
}

之前同学实现的代码,使用的都是指针类型进行字段赋值的,所以之前按照指针类型的 receiver 去实现 Value 方法没有出现问题。当然,小明也可以用 *StringArray 的地址来进行更新:

type MyTableUpdateParams struct {
    SA *StringArray
    // ...
}

func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
    db := mysql.GetConn(ctx)
    
    updateParams := make(map[string]interface{})
    
    // ... 省略了其他方法
    
    if params.SA != nil && len(*params.SA) > 0 {
        updateParams["sa"] = params.SA
    }
    
    if len(updateParams) == 0 {
        return nil
    }
    err := db.Model(&DBMyTable{}).
        Where("id = ?", id).
        Updates(updateParams).Error
       
   if err != nil {
       return err
   }
   
   return nil
}

延伸

  1. 通过上面的分析,我们知道指针类型的 receiver 不能同等替代值类型的 receiver。但通过定义值类型的 receiver 实现了某个接口,意味着其指针类型也实现了这个接口?对比下面两段代码:

左侧的代码是我们刚讨论过的例子,显然,t 绑定的是 StringArray 类型,没有实现 Valuer接口,所以输出 uh?。但右侧的代码,却可以输出 ok。

也就是说,一旦某个结构体类型通过值 receiver 实现了某个接口,其对应的指针类型也会实现这个接口。这是因为值是可以取地址的,指针类型的 receiver,可以通过值类型的 receiver 来隐式调用。但反过来不行,因为地址可能无法找到对应的值。可以参考这里的解释:go.dev/tour/method…

  1. 下面的代码会不会出现 panic?

只有右侧的代码会出现 panic。当指针变量 t 调用 A.Hello 方法时,左侧的代码会直接通过指针类型的 receiver 调用,且方法内部没有对 *a 进行解引用。而右侧的代码需要通过值类型的 receiver 来调用 Hello,但由于 t 是 nil,一旦解引用为值作为 receiver,就会出现空指针 panic。

转载自:https://juejin.cn/post/7389200711403683890
评论
请登录