GORM 自定义类型 -- Pointer Receiver 踩坑记录
背景
使用过 GORM 的同学都知道,当业务场景中需要存储的自定义数据类型,与数据库中的数据类型不直接对应时,可以将对自定义类型实现 GORM 提供的 sql.Scanner 和 driver.Valuer 接口,定义如何将 Go 自定义类型转换为数据库内置的类型,以及如何从数据库的内置类型转换回 Go 自定义的类型。
问题复现
有一天,小明同学需要在某张表中存储一个 JSON 字符串数组,即 []string 类型。参考之前同学写过的代码:
type AppConfig struct {
AppID string `json:"app_id"`
AppSecret string `json:"app_secret"`
}
// Scan implements the Scanner interface for StringArray
func (c *AppConfig) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, c)
}
// Value implements the driver Valuer interface for StringArray
func (c *AppConfig) Value() (driver.Value, error) {
return json.Marshal(c)
}
于是小明照葫芦画瓢,按照同样的方式定义了自定义类型,并实现了 sql.Scanner 和 driver.Valuer接口:
// StringArray is a custom type to handle []string
type StringArray []string
// Scan implements the Scanner interface for StringArray
func (sa *StringArray) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New("type assertion to []byte failed")
}
return json.Unmarshal(bytes, sa)
}
// Value implements the driver Valuer interface for StringArray
func (sa *StringArray) Value() (driver.Value, error) {
return json.Marshal(sa)
}
然后,小明将这个字段放到了表结构定义中:
type DBMyTable struct {
ID int64 `gorm:"column:id"`
SA StringArray `gorm:"column:sa"`
// ... 省略了其他字段
CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time `gorm:"column:updated_at`
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:datetime"`
}
func (DBMyTable) TableName() string {
return "my_table"
}
最后,小明实现了一个操作表更新的函数:
type MyTableUpdateParams struct {
SA StringArray
// ...
}
func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
db := mysql.GetConn(ctx)
updateParams := make(map[string]interface{})
// ... 省略了其他方法
if len(params.SA) > 0 {
updateParams["sa"] = params.SA
}
if len(updateParams) == 0 {
return nil
}
err := db.Model(&DBMyTable{}).
Where("id = ?", id).
Updates(updateParams).Error
if err != nil {
return err
}
return nil
}
然而,当小明测试这段代码的时候,发现事情并没有像自己想象的一样顺利。当传入的 SA 参数为 []string{"xiaoming"} 时,最终执行的 SQL 更新语句却是:
UPDATE `my_table` SET `sa`='xiaoming' WHERE id = 1
这与预期完全不符,SA 类型的字段本应该经过 JSON 序列化,变为 [xiaoming] 字符序列存入表中,然后再读取的时候再进行反序列化。但此时存储表中的却是一个字符串,在反序列化 JSON 的时候将会报错。
问题分析
经过单步调试,我们定位到 GORM 中将输入的变量转为 SQL 语句结构的函数:
// AddVar add var
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
for idx, v := range vars {
if idx > 0 {
writer.WriteByte(',')
}
switch v := v.(type) {
case sql.NamedArg:
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
stmt.AddVar(writer, nil)
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []byte:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []interface{}:
if len(v) > 0 {
writer.WriteByte('(')
stmt.AddVar(writer, v...)
writer.WriteByte(')')
} else {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default:
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
if i > 0 {
writer.WriteByte(',')
}
stmt.AddVar(writer, rv.Index(i).Interface())
}
writer.WriteByte(')')
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
}
}
}
}
在这里,我们预期执行的是标红的 case 分支,也就是 StringArray 预期是实现了 driver.Valuer 的接口,但实际上,代码却执行了绿色的分支,进入了 default 分支。且由于 StringArray 本质上是一个切片,GORM 直接遍历了切片的所有元素进行递归。
由于切片的元素类型是字符串类型,在递归后,执行了标黄部分的代码后,GORM 直接将一个字符串而非 StringArray类型的变量存入了语句中。接着,下面的 GORM 代码就将上面转换后的变量转为数据库内置的值:
func IsValue(v any) bool {
if v == nil {
return true
}
switch v.(type) {
case []byte, bool, float64, int64, string, time.Time:
return true
case decimalDecompose:
return true
}
return false
}
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
if vr, ok := v.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return nil, err
}
if driver.IsValue(sv) {
return sv, nil
}
// A value returned from the Valuer interface can be "a type handled by
// a database driver's NamedValueChecker interface" so we should accept
// uint64 here as well.
if u, ok := sv.(uint64); ok {
return u, nil
}
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
}
rv := reflect.ValueOf(v)
switch rv.Kind() {
case reflect.Ptr:
// indirect pointers
if rv.IsNil() {
return nil, nil
} else {
return c.ConvertValue(rv.Elem().Interface())
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return rv.Uint(), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:
return rv.Bool(), nil
case reflect.Slice:
switch t := rv.Type(); {
case t == jsonType:
return v, nil
case t.Elem().Kind() == reflect.Uint8:
return rv.Bytes(), nil
default:
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
}
case reflect.String:
return rv.String(), nil
}
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}
可以看到,如果是字符串类型的变量,会直接经过标绿的判断,返回为一个合法的内置数据库的值。而如果实现了 driver.Valuer 的接口,则会经过标红的部分,调用用户自定义的 Value 方法,执行数据类型转换。
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
rv.IsNil() &&
rv.Type().Elem().Implements(valuerReflectType) {
return nil, nil
}
return vr.Value()
}
至此,经过单步调试,小明总算是清楚了非预期 SQL 语句的来源。但为什么仿照之前同学写的代码出现了问题呢?之前写的代码会不会也有问题?
经过查询 GORM 官方的资料,发现是这样描述的:
自定义的数据类型必须实现 Scanner 和 Valuer 接口,以便让 GORM 知道如何将该类型接收、保存到数据库
例如:
type JSON json.RawMessage // 实现 sql.Scanner 接口,Scan 将 value 扫描至 Jsonb func (j *JSON) Scan(value interface{}) error { bytes, ok := value.([]byte) if !ok { return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) } result := json.RawMessage{} err := json.Unmarshal(bytes, &result) *j = JSON(result) return err } // 实现 driver.Valuer 接口,Value 返回 json value func (j JSON) Value() (driver.Value, error) { if len(j) == 0 { return nil, nil } return json.RawMessage(j).MarshalJSON() }
我们注意到了一个细节,官方的例子在定义 Value 方法时,使用的 receiver 是值类型的,而非指针类型的,但小明在代码中定义的 Value 方法,使用了指针类型的 receiver:
// Value implements the driver Valuer interface for StringArray
func (sa *StringArray) Value() (driver.Value, error) {
return json.Marshal(sa)
}
这里一定有问题!
问题解决
小明尝试将指针类型的 receiver 换成了值类型的 receiver,问题解决。
// Value implements the driver Valuer interface for StringArray
func (sa StringArray) Value() (driver.Value, error) {
return json.Marshal(sa)
}
但同样是实现 driver.Valuer 的方法,使用值类型的 receiver 和指针类型的 receiver,为什么差距会这么大?之前写的代码有没有问题?
查阅 Go 语法的定义:
An interface* type* is defined as a set of method signatures.
A value of interface type can hold any value that implements those methods.
Note**:** There is an error in the example code on line 22. Vertex (the value type) doesn't implement Abser because the Abs method is defined only on *Vertex (the pointer type).
package main import ( "fmt" "math" ) type Abser interface { Abs() float64 } func main() { var a Abser f := MyFloat(-math.Sqrt2) v := Vertex{3, 4} a = f // a MyFloat implements Abser a = &v // a *Vertex implements Abser // In the following line, v is a Vertex (not *Vertex) // and does NOT implement Abser. a = v fmt.Println(a.Abs()) } type MyFloat float64 func (f MyFloat) Abs() float64 { if f < 0 { return float64(-f) } return float64(f) } type Vertex struct { X, Y float64 } func (v *Vertex) Abs() float64 { return math.Sqrt(v.X*v.X + v.Y*v.Y) }
回到小明的例子上,虽然小明定义了 Value 方法,但是使用的是指针类型的 receiver,所以只有 *StringArray 类型才实现了 driver.Valuer 接口,而 StringArray 类型并没有实现 driver.Valuer 接口。不幸的是,小明在更新的时候,赋值的是 StringArray 的对象,而非*StringArray类型的地址。
func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
db := mysql.GetConn(ctx)
updateParams := make(map[string]interface{})
// ... 省略了其他方法
if len(params.SA) > 0 {
updateParams["sa"] = params.SA
}
if len(updateParams) == 0 {
return nil
}
err := db.Model(&DBMyTable{}).
Where("id = ?", id).
Updates(updateParams).Error
if err != nil {
return err
}
return nil
}
之前同学实现的代码,使用的都是指针类型进行字段赋值的,所以之前按照指针类型的 receiver 去实现 Value 方法没有出现问题。当然,小明也可以用 *StringArray 的地址来进行更新:
type MyTableUpdateParams struct {
SA *StringArray
// ...
}
func UpdateSA(ctx context.Context, id int64, params MyTableUpdateParams) error {
db := mysql.GetConn(ctx)
updateParams := make(map[string]interface{})
// ... 省略了其他方法
if params.SA != nil && len(*params.SA) > 0 {
updateParams["sa"] = params.SA
}
if len(updateParams) == 0 {
return nil
}
err := db.Model(&DBMyTable{}).
Where("id = ?", id).
Updates(updateParams).Error
if err != nil {
return err
}
return nil
}
延伸
- 通过上面的分析,我们知道指针类型的 receiver 不能同等替代值类型的 receiver。但通过定义值类型的 receiver 实现了某个接口,意味着其指针类型也实现了这个接口?对比下面两段代码:
左侧的代码是我们刚讨论过的例子,显然,t 绑定的是 StringArray 类型,没有实现 Valuer接口,所以输出 uh?。但右侧的代码,却可以输出 ok。
也就是说,一旦某个结构体类型通过值 receiver 实现了某个接口,其对应的指针类型也会实现这个接口。这是因为值是可以取地址的,指针类型的 receiver,可以通过值类型的 receiver 来隐式调用。但反过来不行,因为地址可能无法找到对应的值。可以参考这里的解释:go.dev/tour/method…
- 下面的代码会不会出现 panic?
只有右侧的代码会出现 panic。当指针变量 t 调用 A.Hello 方法时,左侧的代码会直接通过指针类型的 receiver 调用,且方法内部没有对 *a 进行解引用。而右侧的代码需要通过值类型的 receiver 来调用 Hello,但由于 t 是 nil,一旦解引用为值作为 receiver,就会出现空指针 panic。
转载自:https://juejin.cn/post/7389200711403683890