Gorm源码学习-创建行记录

1. 前言

Gorm源码学习系列

此文是Gorm源码学习系列的第二篇,主要梳理下通过Gorm创建表的流程。

 

2. 创建行记录代码示例

gorm提供了以下几个接口来创建行记录

  • 一次创建一行 func (db *DB) Create(value interface{}) (tx *DB)
  • 批量创建 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
  • 数据库不存在主键时创建,存在时更新 func (db *DB) Save(value interface{}) (tx *DB)

详细请看教程及源码finisher_api.go,这里使用func (db *DB) Create(value interface{}) (tx *DB)来说明创建行记录等大致流程。

 

2.1 声明模型

type Stu struct { 	ID     int64 `gorm:"column:id; primary_key" json:"id"` 	Age    int64 `gorm:"column:age;"` 	Height int64 `gorm:"column:height;"` 	Weight int64 `gorm:"column:weight;"` }  // 设置表名 func (Stu) TableName() string { 	return "t_student" }

模型代码的主要用途如下,

  • 申明的表中有哪些列及每列的名称、特性等,如gorm标签指定每个字断对于的表的列名
  • 通过实现Tabler接口指定了固定的表名,接口定义如下
type Tabler interface { 	TableName() string }

关于模型定义中更多的约定和约束等,请看教程

出于分表等业务场景,我们并不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)等方法

来动态指定表名,详情请看教程

 

2.2 创建行

func main() { 	// 数据库连接, 具体查看https://www.cnblogs.com/amos01/p/16890747.html 连接数据库代码示例 	db, _ := dbOpen() 	// 打开调试模式、会打印DML 	db = db.Debug() 	stu := &Stu{ 		Age:    18, 		Height: 185, 		Weight: 70, 	} 	db = db.Create(stu) 	fmt.Printf("Error:%v ID:%v RowsAffected:%vn", db.Error, stu.ID, db.RowsAffected) }

代码输出如下

$ go run main.go 2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33 [1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70) Error:<nil> ID:1027 RowsAffected:1

从代码输出可以看,行记录的ID为1027,连接数据库查询,结果如下。

mysql> select * from t_student where id = 1027G *************************** 1. row ***************************     id: 1027    age: 18 height: 185 weight: 70 1 row in set (0.01 sec)

因此,我们带着以下问题来梳理下Gorm创建行记录的流程

  • 如何从model到DML语句的
  • 如何将ID写入到model的 

 

3. 从Model到DML

func (db *DB) Create(value interface{}) (tx *DB)的实现如下

// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { 	if db.CreateBatchSize > 0 { 		return db.CreateInBatches(value, db.CreateBatchSize) 	}  	tx = db.getInstance() 	tx.Statement.Dest = value 	return tx.callbacks.Create().Execute(tx) }

func (p *processor) Execute(db *DB) *DB 的实现比较长,具体代码见github

总结下来,做了两件主要的事情,

  • 解析model获取表名、每列的定义等
  • 执行钩子函数以及创建行函数

X

3.1 数据结构理解

  • gorm.Statement
查看gorm.Statement代码
// Statement statement type Statement struct { 	*DB 	TableExpr            *clause.Expr 	Table                string      // 表名 	Model                interface{} // model定义 	Unscoped             bool 	Dest                 interface{} // model的另外一种表达,如map 	ReflectValue         reflect.Value 	Clauses              map[string]clause.Clause 	BuildClauses         []string 	Distinct             bool 	Selects              []string // selected columns 	Omits                []string // omit columns 	Joins                []join 	Preloads             map[string][]interface{} 	Settings             sync.Map 	ConnPool             ConnPool       // 数据库连接 	Schema               *schema.Schema // 表结构化信息 	Context              context.Context 	RaiseErrorOnNotFound bool 	SkipHooks            bool 	SQL                  strings.Builder // 最终的DML语句 	Vars                 []interface{}   // DML语句的参数值 	CurDestIndex         int             // 批量创建/更新时,gorm当前操作的数组/slice的下标 	attrs                []interface{} 	assigns              []interface{} 	scopes               []func(*DB) *DB }
  • schema.Schem
查看schema.Schema代码
type Schema struct { 	Name                      string 	ModelType                 reflect.Type 	Table                     string // 表名 	PrioritizedPrimaryField   *Field 	DBNames                   []string // 表每列的名字 	PrimaryFields             []*Field 	PrimaryFieldDBNames       []string // 表的主键列明 	Fields                    []*Field // gorm自定义的model每个字短 	FieldsByName              map[string]*Field 	FieldsByDBName            map[string]*Field 	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database 	Relationships             Relationships 	CreateClauses             []clause.Interface // 创建行的子句 	QueryClauses              []clause.Interface 	UpdateClauses             []clause.Interface 	DeleteClauses             []clause.Interface 	BeforeCreate, AfterCreate bool 	BeforeUpdate, AfterUpdate bool 	BeforeDelete, AfterDelete bool 	BeforeSave, AfterSave     bool 	AfterFind                 bool 	err                       error 	initialized               chan struct{} 	namer                     Namer 	cacheStore                *sync.Map }
  • schema.Field
查看schema.Field代码
// Field is the representation of model schema's field type Field struct { 	Name                   string // model的字段名 	DBName                 string // 对应表的列名 	BindNames              []string 	DataType               DataType 	GORMDataType           DataType 	PrimaryKey             bool 	AutoIncrement          bool 	AutoIncrementIncrement int64 	Creatable              bool 	Updatable              bool 	Readable               bool 	AutoCreateTime         TimeType 	AutoUpdateTime         TimeType 	HasDefaultValue        bool 	DefaultValue           string 	DefaultValueInterface  interface{} 	NotNull                bool 	Unique                 bool 	Comment                string 	Size                   int 	Precision              int 	Scale                  int 	IgnoreMigration        bool 	FieldType              reflect.Type // 反射类型 	IndirectFieldType      reflect.Type // 反射类型 	StructField            reflect.StructField // model字段信息 	Tag                    reflect.StructTag // tag 	TagSettings            map[string]string 	Schema                 *Schema 	EmbeddedSchema         *Schema 	OwnerSchema            *Schema 	ReflectValueOf         func(context.Context, reflect.Value) reflect.Value                  // 通过反射获取该字段的反射对象 	ValueOf                func(context.Context, reflect.Value) (value interface{}, zero bool) // 通过反射获取该字段的值 get方法 	Set                    func(context.Context, reflect.Value, interface{}) error             // 通过反射设置该字段的值 set方法 	Serializer             SerializerInterface 	NewValuePool           FieldNewValuePool }
  • clause.Interfaceclause.Clause

gorm定义了多种clause,包括

查看clause.Interface代码
// Interface clause interface type Interface interface { 	Name() string 	Build(Builder) 	MergeClause(*Clause) }
查看clause.Clause代码
// Clause type Clause struct { 	Name                string // WHERE 	BeforeExpression    Expression 	AfterNameExpression Expression 	AfterExpression     Expression 	Expression          Expression 	Builder             ClauseBuilder }

 

3.2 解析Model

通过调用stmt.Parse(stmt.Model)进行model解析

stmt.Parse(stmt.Model)会调用到函数func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) 进行解析。

详细代码见schema.go,下面列举重要的几个点。

  • 通过反射判断dest interface{}是否为reflect.Struct
  • 通过接口获取表名,其中stu实现了Tabler接口
	// 获取表名 	modelValue := reflect.New(modelType) 	tableName := namer.TableName(modelType.Name()) 	if tabler, ok := modelValue.Interface().(Tabler); ok { 		tableName = tabler.TableName() 	} 	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { 		tableName = tabler.TableName(namer) 	} 	if en, ok := namer.(embeddedNamer); ok { 		tableName = en.Table 	} 	if specialTableName != "" && specialTableName != tableName { 		tableName = specialTableName 	}
  • 解析model每个字段
// 通过反射获取每个字段 for i := 0; i < modelType.NumField(); i++ {     if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {         // 解析每个字段         if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {             schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)         } else {             schema.Fields = append(schema.Fields, field)         }     } }
  • 放到map方便查找,并且通过func (field *Field) setupValuerAndSetter()初始化每个Field的ReflectValueOfValueOfSet方法。
    for _, field := range schema.Fields {     if field.DBName == "" && field.DataType != "" {         field.DBName = namer.ColumnName(schema.Table, field.Name)     }     if field.DBName != "" {         // nonexistence or shortest path or first appear prioritized if has permission         if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {             if _, ok := schema.FieldsByDBName[field.DBName]; !ok {                 schema.DBNames = append(schema.DBNames, field.DBName)             }             // gorm tag字段到field的映射             schema.FieldsByDBName[field.DBName] = field             // model 字段到field的映射             schema.FieldsByName[field.Name] = field             if v != nil && v.PrimaryKey {                 for idx, f := range schema.PrimaryFields {                     if f == v {                         schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)                     }                 }             }             // 主键             if field.PrimaryKey {                 schema.PrimaryFields = append(schema.PrimaryFields, field)             }         }     }     if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {         schema.FieldsByName[field.Name] = field     }     // 挂载字段的set方法和get方法     field.setupValuerAndSetter() }

 

值得一提的是,每个model解析后的结果是一致,可以将结果解析的结构缓存下来,并且通过chan来解决并发的问题。

解析model之后,通过process获取到钩子函数及创建行的函数,具体代码见Github

	for _, f := range p.fns { 		f(db) 	}

 

3.3 执行钩子函数及创建行的函数

创建行的函数及对应的钩子函数位于create.go

  • 创建行记录
if db.Statement.SQL.Len() == 0 {     db.Statement.SQL.Grow(180)     db.Statement.AddClauseIfNotExists(clause.Insert{})     db.Statement.AddClause(ConvertToCreateValues(db.Statement))      db.Statement.Build(db.Statement.BuildClauses...) }

 这里插入两个clause.Clause,分别为clause.Insert以及clause.Values,然后调用这两种clause.Clausebuild方法生成SQL语句。

首先,看下ConvertToCreateValues的实现,这里只截取部分代码

values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} // 获取每一列的名字 for _, db := range stmt.Schema.DBNames {     if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { 	    if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { 		    values.Columns = append(values.Columns, clause.Column{Name: db}) 		} 	} }  // 获取每一列对应的值 switch stmt.ReflectValue.Kind() {     case reflect.Slice, reflect.Array:     case reflect.Struct:         values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}         for idx, column := range values.Columns {             field := stmt.Schema.FieldsByDBName[column.Name]             // func (field *Field) setupValuerAndSetter() 挂载的方法             if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {                 if field.DefaultValueInterface != nil {                     values.Values[0][idx] = field.DefaultValueInterface                     stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))                 } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {                     stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))                     values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)                 }             } else if field.AutoUpdateTime > 0 && updateTrackTime {                 stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))                 values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)             }         }

通过ConvertToCreateValues获取了每一列的名称及对应的值。

接下来,看clause.ClauseSQL语句的过程。

遍历加入clause,此时分别为clause.Insert以及clause.Values

// Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) {     var firstClauseWritten bool     for _, name := range clauses {         if c, ok := stmt.Clauses[name]; ok {             // 代码有删减             c.Build(stmt)         }     } }

接着调用func (c Clause) Build(builder Builder)

// Build build clause  func (c Clause) Build(builder Builder) {     // 有删减     // c为clause.Insert以及clause.Values     if c.Name != "" {         // builder写入 INSERT 或者 VALUES         builder.WriteString(c.Name)         builder.WriteByte(' ')     }     // 通过clause.Insert以及clause.Values的MergeClause函数,c.Expression为clause.Insert以及clause.Values     // 因此,这里调用clause.Insert或者clause.Values的Build的方法     c.Expression.Build(builder) }

接下来分别看clause.Insert以及clause.Values

// Build build insert clause func (insert Insert) Build(builder Builder) {     // builder写入INTO,此时builder为INSERT INTO     builder.WriteString("INTO ")     // builder写入表名     builder.WriteQuoted(currentTable) }

从调用的链路可以得出,这里builderstmt *Statement,并且currentTable类型为clause.Table,因此

// WriteQuoted write quoted value func (stmt *Statement) WriteQuoted(value interface{}) {     stmt.QuoteTo(&stmt.SQL, value) }  // QuoteTo write quoted value to writer 有删减 func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {     write := func(raw bool, str string) {         // mysql驱动Dialector         stmt.DB.Dialector.QuoteTo(writer, str)     }     switch v := field.(type) {     case clause.Table:         write(v.Raw, stmt.Table)     } }

至此,builder已经拼装出INSERT INTO `t_student` ,解析来再看clause.Valuesbuild方法

// Build build from clause func (values Values) Build(builder Builder) {     if len(values.Columns) > 0 {         builder.WriteByte('(')         for idx, column := range values.Columns {             if idx > 0 {                 builder.WriteByte(',')             }             builder.WriteQuoted(column)         }         builder.WriteByte(')')         builder.WriteString(" VALUES ")         for idx, value := range values.Values {             if idx > 0 {                 builder.WriteByte(',')             }             builder.WriteByte('(')             builder.AddVar(builder, value...)             builder.WriteByte(')')         }     } else {         builder.WriteString("DEFAULT VALUES")     } }

func (values Values) Build(builder Builder)取出所有列名和列对应的值

最终builder拼装成例子的完整SQL语句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)

有了SQL语句,就可以执行了

result, err := db.Statement.ConnPool.ExecContext(     db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., )

通过前一面学习,db.Statement.ConnPool的值为sql.DB,实际执行的函数为func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)

至此,从Model到DML到流程已经完成。

 

4. 将ID写入到model的 

看返回参数sql.Result,因此通过LastInsertId() (int64, error)可以获取到插入行的ID值。

// A Result summarizes an executed SQL command. type Result interface { 	// LastInsertId returns the integer generated by the database 	// in response to a command. Typically this will be from an 	// "auto increment" column when inserting a new row. Not all 	// databases support this feature, and the syntax of such 	// statements varies. 	LastInsertId() (int64, error)  	// RowsAffected returns the number of rows affected by an 	// update, insert, or delete. Not every database or database 	// driver may support this. 	RowsAffected() (int64, error) }

获取到刚插入的行ID值,再通过反射写入model的ID字段即可。

db.RowsAffected, _ = result.RowsAffected() if db.RowsAffected != 0 && db.Statement.Schema != nil &&     db.Statement.Schema.PrioritizedPrimaryField != nil &&     db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {     insertID, err := result.LastInsertId()     switch db.Statement.ReflectValue.Kind() {     case reflect.Struct:         _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)         if isZero {             // 通过反射更新ID             db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))         }     } }

 

5. 总结

使用反射解析Model,获得每个成员对应的表的列名、值等信息。

定义SQL各个关键词如INSERTVALUESFROMDELETE的结构体,并实现clause.Interface接口

进而对SQL语句的构造进行抽象封装。

 

发表评论

相关文章