GORM Gen 类型安全 ORM


第十五章:GORM Gen 类型安全 ORM

15.1 GORM Gen 简介

GORM Gen 是基于 GORM 的代码生成工具,提供类型安全的查询 API,解决了原生 GORM 的以下问题:

原生 GORM 问题 GORM Gen 解决方案
字符串字段名易出错 编译时检查的类型安全 API
缺少 IDE 智能提示 完全的类型推导和补全
重构时容易遗漏 编译错误提示需要修改的地方
复杂查询可读性差 链式调用,流畅表达

15.2 安装与配置

安装

go get -u gorm.io/gen

代码生成配置

package main

import (
    "gorm.io/driver/mysql"
    "gorm.io/gen"
    "gorm.io/gorm"
)

func main() {
    g := gen.NewGenerator(gen.Config{
        OutPath: "./dao/query",        // 生成代码的输出路径
        OutFile: "./dao/query/gen.go", // 输出文件名
        
        // 模型包名
        ModelPkgPath: "./dao/model",
        
        // 生成模式
        Mode: gen.WithoutContext |     // 不使用 context
            gen.WithDefaultQuery |      // 生成默认查询对象
            gen.WithQueryInterface,     // 生成接口
        
        // 字段可为空时的类型(指针)
        FieldNullable: true,
        
        // 生成字段覆盖率标签
        FieldCoverable: true,
        
        // 生成字段签名标签
        FieldSignable: true,
        
        // 生成 gorm 标签
        FieldWithIndexTag: true,
        FieldWithTypeTag:  true,
    })

    // 连接数据库
    dsn := "user:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
    db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
    if err != nil {
        panic(err)
    }
    g.UseDB(db)

    // 生成所有表的代码
    g.ApplyBasic(g.GenerateAllTable()...)

    // 执行生成
    g.Execute()
}

运行生成

go run generate/gen.go

15.3 生成的代码结构

dao/
├── query/              # 查询代码
│   ├── gen.go          # 入口文件
│   ├── user.gen.go     # User 查询
│   ├── order.gen.go    # Order 查询
│   └── ...
└── model/              # 模型代码
    ├── user.gen.go
    ├── order.gen.go
    └── ...

15.4 基础使用

初始化

package main

import (
    "myapp/dao/query"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

func main() {
    dsn := "user:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
    db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
    if err != nil {
        panic(err)
    }
    
    // 设置全局 DB
    query.SetDefault(db)
}

CRUD 操作

import (
    "myapp/dao/query"
    "myapp/dao/model"
)

func UserCRUD() {
    u := query.User
    
    // ===== 创建 =====
    user := &model.User{
        Name:  "张三",
        Email: "zhangsan@example.com",
        Age:   25,
    }
    err := u.Create(user)
    
    // 批量创建
    users := []*model.User{
        {Name: "李四", Email: "lisi@example.com"},
        {Name: "王五", Email: "wangwu@example.com"},
    }
    err = u.CreateInBatches(users, 100)
    
    // ===== 查询 =====
    // 根据主键查询
    user, err = u.FirstByID(1)
    
    // 条件查询
    user, err = u.Where(u.Name.Eq("张三")).First()
    
    // 多条件
    user, err = u.Where(
        u.Name.Eq("张三"),
        u.Age.Gte(18),
    ).First()
    
    // 获取所有
    allUsers, err := u.Find()
    
    // 分页
    list, total, err := u.FindByPage(0, 10)  // offset, limit
    
    // ===== 更新 =====
    // 更新单列
    _, err = u.Where(u.ID.Eq(1)).Update(u.Name, "张三_new")
    
    // 更新多列
    _, err = u.Where(u.ID.Eq(1)).Updates(&model.User{
        Name: "张三_new",
        Age:  26,
    })
    
    // 使用 Map 更新(零值有效)
    _, err = u.Where(u.ID.Eq(1)).UpdateSimple(
        u.Name.Value("张三_new"),
        u.Age.Value(0),  // 可以更新为 0
    )
    
    // ===== 删除 =====
    _, err = u.Where(u.ID.Eq(1)).Delete()
    
    // 软删除(如果模型有 DeletedAt)
    _, err = u.Where(u.ID.Eq(1)).Delete()
}

15.5 高级查询

条件组合

func AdvancedQuery() {
    u := query.User
    
    // AND
    users, err := u.Where(
        u.Age.Gte(18),
        u.Age.Lte(60),
    ).Find()
    
    // OR
    users, err = u.Where(
        u.Where(u.Name.Eq("张三")).Or(u.Name.Eq("李四")),
    ).Find()
    
    // IN
    users, err = u.Where(u.ID.In(1, 2, 3)).Find()
    
    // LIKE
    users, err = u.Where(u.Name.Like("%张%")).Find()
    
    // BETWEEN
    users, err = u.Where(u.Age.Between(18, 30)).Find()
    
    // IS NULL
    users, err = u.Where(u.Email.IsNull()).Find()
    
    // NOT
    users, err = u.Where(u.Status.NotIn(0, 2)).Find()
}

排序和分页

func OrderAndPage() {
    u := query.User
    
    // 排序
    users, err := u.Order(u.Age.Desc()).Find()
    users, err = u.Order(u.Age.Desc(), u.Name.Asc()).Find()
    
    // 分页
    users, err = u.Limit(10).Offset(20).Find()
    
    // 去重
    ages, err := u.Distinct().Pluck(u.Age)
    
    // 选择字段
    users, err = u.Select(u.Name, u.Email).Find()
}

聚合查询

func Aggregation() {
    u := query.User
    
    // Count
    count, err := u.Count()
    count, err = u.Where(u.Status.Eq(1)).Count()
    
    // Sum
    totalAge, err := u.Select(u.Age.Sum()).FindAgg()
    
    // AVG / MAX / MIN
    result, err := u.Select(
        u.Age.Avg().As("avg_age"),
        u.Age.Max().As("max_age"),
        u.Age.Min().As("min_age"),
    ).FindAgg()
    
    // 分组
    type AgeGroup struct {
        Status int
        Count  int64
    }
    var groups []AgeGroup
    err = u.Select(u.Status, u.ID.Count().As("count")).
        Group(u.Status).
        Scan(&groups)
}

Join 查询

func JoinQuery() {
    u := query.User
    o := query.Order
    
    // 内连接
    results, err := u.Join(o, u.ID.EqCol(o.UserID)).
        Where(o.Amount.Gt(100)).
        Find()
    
    // 左连接
    results, err = u.LeftJoin(o, u.ID.EqCol(o.UserID)).Find()
    
    // 预加载(类似 GORM Preload)
    users, err := u.Preload(u.Orders).Find()
}

子查询

func SubQuery() {
    u := query.User
    o := query.Order
    
    // 有订单的用户
    users, err := u.Where(
        u.Columns(u.ID).In(
            o.Select(o.UserID),
        ),
    ).Find()
    
    // 订单数大于平均值的用户
    subQuery := o.Select(o.UserID).
        Group(o.UserID).
        Having(o.ID.Count().Gt(
            o.Select(o.ID.Count().Avg()),
        ))
    
    users, err = u.Where(u.Columns(u.ID).In(subQuery)).Find()
}

15.6 事务处理

func Transaction() error {
    u := query.User
    o := query.Order
    
    return query.Q.Transaction(func(tx *query.Query) error {
        // 使用 tx 进行事务内操作
        user := &model.User{Name: "张三"}
        if err := tx.User.Create(user); err != nil {
            return err
        }
        
        order := &model.Order{UserID: user.ID, Amount: 100}
        if err := tx.Order.Create(order); err != nil {
            return err
        }
        
        return nil
    })
}

15.7 自定义查询方法

在生成时添加自定义方法

func main() {
    g := gen.NewGenerator(gen.Config{
        OutPath: "./dao/query",
    })
    
    // 自定义字段类型
    g.WithOpts(gen.FieldOpts{
        gen.FieldType("phone", "string"),
        gen.FieldGORMTag("phone", "uniqueIndex"),
    })
    
    // 自定义查询方法
    user := g.GenerateModelAs("user", "User",
        gen.FieldIgnore("password"),  // 忽略字段
        gen.FieldType("status", "int16"),
        gen.FieldNewTag("created_at", "json:", "created_at,string"),
    )
    
    // 应用自定义方法
    g.ApplyInterface(func(method userMethod) {}, user)
    
    g.Execute()
}

// 自定义方法接口
type userMethod interface {
    // 根据手机号查询
    FindByPhone(phone string) (*gen.T, error)
    
    // 根据状态统计
    CountByStatus(status int) (int64, error)
}

生成后的自定义

// 在生成的文件基础上扩展(使用新文件,不要修改生成文件)

package query

// FindActiveUsers 查询活跃用户
func (u user) FindActiveUsers() ([]*model.User, error) {
    return u.Where(u.Status.Eq(1), u.LastLoginAt.Gte(time.Now().AddDate(0, -1, 0))).Find()
}

15.8 与原生 GORM 混用

func MixWithGORM(db *gorm.DB) {
    u := query.User
    
    // Gen 查询
    users, _ := u.Where(u.Status.Eq(1)).Find()
    
    // 原生 GORM 复杂操作
    var results []struct{
        UserName string
        OrderCount int64
        TotalAmount float64
    }
    db.Raw(`
        SELECT u.name as user_name, 
               COUNT(o.id) as order_count,
               SUM(o.amount) as total_amount
        FROM users u
        LEFT JOIN orders o ON u.id = o.user_id
        WHERE u.status = ?
        GROUP BY u.id
    `, 1).Scan(&results)
    
    // 从 Gen 获取 *gorm.DB
    genDB := u.UnderlyingDB()
    genDB.Exec("...")
}

15.9 完整实战示例

package main

import (
    "fmt"
    "myapp/dao/model"
    "myapp/dao/query"
    "gorm.io/driver/mysql"
    "gorm.io/gorm"
    "gorm.io/gorm/logger"
)

func main() {
    // 初始化
    dsn := "user:password@tcp(127.0.0.1:3306)/myapp?charset=utf8mb4&parseTime=True&loc=Local"
    db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{
        Logger: logger.Default.LogMode(logger.Info),
    })
    query.SetDefault(db)
    
    // ===== 业务操作示例 =====
    
    // 1. 创建用户
    user := &model.User{
        Name:     "张三",
        Email:    "zhangsan@example.com",
        Phone:    "13800138000",
        Age:      25,
        Status:   1,
    }
    if err := query.User.Create(user); err != nil {
        panic(err)
    }
    fmt.Printf("创建用户: ID=%d\n", user.ID)
    
    // 2. 查询用户
    found, err := query.User.Where(query.User.Email.Eq("zhangsan@example.com")).First()
    if err != nil {
        panic(err)
    }
    fmt.Printf("查询用户: %+v\n", found)
    
    // 3. 更新用户
    _, err = query.User.Where(query.User.ID.Eq(user.ID)).Update(query.User.Age, 26)
    if err != nil {
        panic(err)
    }
    
    // 4. 分页查询
    users, total, err := query.User.Where(query.User.Status.Eq(1)).FindByPage(0, 10)
    if err != nil {
        panic(err)
    }
    fmt.Printf("分页查询: 总数=%d, 本页=%d\n", total, len(users))
    
    // 5. 事务处理
    err = query.Q.Transaction(func(tx *query.Query) error {
        // 创建订单
        order := &model.Order{
            UserID: user.ID,
            Amount: 99.99,
            Status: 1,
        }
        if err := tx.Order.Create(order); err != nil {
            return err
        }
        
        // 更新用户订单数
        _, err := tx.User.Where(tx.User.ID.Eq(user.ID)).
            UpdateColumn(tx.User.OrderCount, tx.User.OrderCount.Add(1))
        return err
    })
    if err != nil {
        panic(err)
    }
    
    fmt.Println("操作完成!")
}

15.10 Gen vs 原生 GORM 对比

场景 推荐 原因
简单 CRUD Gen 类型安全,开发效率高
复杂报表查询 原生 GORM 复杂 SQL 直接用原生更灵活
动态条件构建 Gen 类型安全的条件构建
需要高度定制的查询 原生 GORM 完全控制 SQL
新项目 Gen 减少运行时错误
已有项目改造 混用 渐进式迁移

15.11 练习题

  1. 使用 Gen 重构一个已有的原生 GORM 项目
  2. 实现一个带复杂条件(多表关联、分组统计)的查询
  3. 编写一个自定义方法生成插件

15.12 小结

GORM Gen 通过代码生成提供了类型安全的 ORM 体验,大幅减少运行时错误,提高开发效率。推荐在新项目中优先使用,已有项目可以渐进式迁移。


本文代码地址:https://github.com/LittleMoreInteresting/gorm_study

欢迎关注公众号,一起学习进步!

如有疑问关注公众号给我留言
wx

关注公众号

©2017-2023 鲁ICP备17023316号-1 Powered by Hugo