GORM Gen 类型安全 ORM
第十五章:GORM Gen 类型安全 ORM
15.1 GORM Gen 简介
GORM Gen 是基于 GORM 的代码生成工具,提供类型安全的查询 API,解决了原生 GORM 的以下问题:
| 原生 GORM 问题 | GORM Gen 解决方案 |
|---|---|
| 字符串字段名易出错 | 编译时检查的类型安全 API |
| 缺少 IDE 智能提示 | 完全的类型推导和补全 |
| 重构时容易遗漏 | 编译错误提示需要修改的地方 |
| 复杂查询可读性差 | 链式调用,流畅表达 |
15.2 安装与配置
安装
go get -u gorm.io/gen
代码生成配置
package main
import (
"gorm.io/driver/mysql"
"gorm.io/gen"
"gorm.io/gorm"
)
func main() {
g := gen.NewGenerator(gen.Config{
OutPath: "./dao/query", // 生成代码的输出路径
OutFile: "./dao/query/gen.go", // 输出文件名
// 模型包名
ModelPkgPath: "./dao/model",
// 生成模式
Mode: gen.WithoutContext | // 不使用 context
gen.WithDefaultQuery | // 生成默认查询对象
gen.WithQueryInterface, // 生成接口
// 字段可为空时的类型(指针)
FieldNullable: true,
// 生成字段覆盖率标签
FieldCoverable: true,
// 生成字段签名标签
FieldSignable: true,
// 生成 gorm 标签
FieldWithIndexTag: true,
FieldWithTypeTag: true,
})
// 连接数据库
dsn := "user:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
panic(err)
}
g.UseDB(db)
// 生成所有表的代码
g.ApplyBasic(g.GenerateAllTable()...)
// 执行生成
g.Execute()
}
运行生成
go run generate/gen.go
15.3 生成的代码结构
dao/
├── query/ # 查询代码
│ ├── gen.go # 入口文件
│ ├── user.gen.go # User 查询
│ ├── order.gen.go # Order 查询
│ └── ...
└── model/ # 模型代码
├── user.gen.go
├── order.gen.go
└── ...
15.4 基础使用
初始化
package main
import (
"myapp/dao/query"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
func main() {
dsn := "user:password@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
panic(err)
}
// 设置全局 DB
query.SetDefault(db)
}
CRUD 操作
import (
"myapp/dao/query"
"myapp/dao/model"
)
func UserCRUD() {
u := query.User
// ===== 创建 =====
user := &model.User{
Name: "张三",
Email: "zhangsan@example.com",
Age: 25,
}
err := u.Create(user)
// 批量创建
users := []*model.User{
{Name: "李四", Email: "lisi@example.com"},
{Name: "王五", Email: "wangwu@example.com"},
}
err = u.CreateInBatches(users, 100)
// ===== 查询 =====
// 根据主键查询
user, err = u.FirstByID(1)
// 条件查询
user, err = u.Where(u.Name.Eq("张三")).First()
// 多条件
user, err = u.Where(
u.Name.Eq("张三"),
u.Age.Gte(18),
).First()
// 获取所有
allUsers, err := u.Find()
// 分页
list, total, err := u.FindByPage(0, 10) // offset, limit
// ===== 更新 =====
// 更新单列
_, err = u.Where(u.ID.Eq(1)).Update(u.Name, "张三_new")
// 更新多列
_, err = u.Where(u.ID.Eq(1)).Updates(&model.User{
Name: "张三_new",
Age: 26,
})
// 使用 Map 更新(零值有效)
_, err = u.Where(u.ID.Eq(1)).UpdateSimple(
u.Name.Value("张三_new"),
u.Age.Value(0), // 可以更新为 0
)
// ===== 删除 =====
_, err = u.Where(u.ID.Eq(1)).Delete()
// 软删除(如果模型有 DeletedAt)
_, err = u.Where(u.ID.Eq(1)).Delete()
}
15.5 高级查询
条件组合
func AdvancedQuery() {
u := query.User
// AND
users, err := u.Where(
u.Age.Gte(18),
u.Age.Lte(60),
).Find()
// OR
users, err = u.Where(
u.Where(u.Name.Eq("张三")).Or(u.Name.Eq("李四")),
).Find()
// IN
users, err = u.Where(u.ID.In(1, 2, 3)).Find()
// LIKE
users, err = u.Where(u.Name.Like("%张%")).Find()
// BETWEEN
users, err = u.Where(u.Age.Between(18, 30)).Find()
// IS NULL
users, err = u.Where(u.Email.IsNull()).Find()
// NOT
users, err = u.Where(u.Status.NotIn(0, 2)).Find()
}
排序和分页
func OrderAndPage() {
u := query.User
// 排序
users, err := u.Order(u.Age.Desc()).Find()
users, err = u.Order(u.Age.Desc(), u.Name.Asc()).Find()
// 分页
users, err = u.Limit(10).Offset(20).Find()
// 去重
ages, err := u.Distinct().Pluck(u.Age)
// 选择字段
users, err = u.Select(u.Name, u.Email).Find()
}
聚合查询
func Aggregation() {
u := query.User
// Count
count, err := u.Count()
count, err = u.Where(u.Status.Eq(1)).Count()
// Sum
totalAge, err := u.Select(u.Age.Sum()).FindAgg()
// AVG / MAX / MIN
result, err := u.Select(
u.Age.Avg().As("avg_age"),
u.Age.Max().As("max_age"),
u.Age.Min().As("min_age"),
).FindAgg()
// 分组
type AgeGroup struct {
Status int
Count int64
}
var groups []AgeGroup
err = u.Select(u.Status, u.ID.Count().As("count")).
Group(u.Status).
Scan(&groups)
}
Join 查询
func JoinQuery() {
u := query.User
o := query.Order
// 内连接
results, err := u.Join(o, u.ID.EqCol(o.UserID)).
Where(o.Amount.Gt(100)).
Find()
// 左连接
results, err = u.LeftJoin(o, u.ID.EqCol(o.UserID)).Find()
// 预加载(类似 GORM Preload)
users, err := u.Preload(u.Orders).Find()
}
子查询
func SubQuery() {
u := query.User
o := query.Order
// 有订单的用户
users, err := u.Where(
u.Columns(u.ID).In(
o.Select(o.UserID),
),
).Find()
// 订单数大于平均值的用户
subQuery := o.Select(o.UserID).
Group(o.UserID).
Having(o.ID.Count().Gt(
o.Select(o.ID.Count().Avg()),
))
users, err = u.Where(u.Columns(u.ID).In(subQuery)).Find()
}
15.6 事务处理
func Transaction() error {
u := query.User
o := query.Order
return query.Q.Transaction(func(tx *query.Query) error {
// 使用 tx 进行事务内操作
user := &model.User{Name: "张三"}
if err := tx.User.Create(user); err != nil {
return err
}
order := &model.Order{UserID: user.ID, Amount: 100}
if err := tx.Order.Create(order); err != nil {
return err
}
return nil
})
}
15.7 自定义查询方法
在生成时添加自定义方法
func main() {
g := gen.NewGenerator(gen.Config{
OutPath: "./dao/query",
})
// 自定义字段类型
g.WithOpts(gen.FieldOpts{
gen.FieldType("phone", "string"),
gen.FieldGORMTag("phone", "uniqueIndex"),
})
// 自定义查询方法
user := g.GenerateModelAs("user", "User",
gen.FieldIgnore("password"), // 忽略字段
gen.FieldType("status", "int16"),
gen.FieldNewTag("created_at", "json:", "created_at,string"),
)
// 应用自定义方法
g.ApplyInterface(func(method userMethod) {}, user)
g.Execute()
}
// 自定义方法接口
type userMethod interface {
// 根据手机号查询
FindByPhone(phone string) (*gen.T, error)
// 根据状态统计
CountByStatus(status int) (int64, error)
}
生成后的自定义
// 在生成的文件基础上扩展(使用新文件,不要修改生成文件)
package query
// FindActiveUsers 查询活跃用户
func (u user) FindActiveUsers() ([]*model.User, error) {
return u.Where(u.Status.Eq(1), u.LastLoginAt.Gte(time.Now().AddDate(0, -1, 0))).Find()
}
15.8 与原生 GORM 混用
func MixWithGORM(db *gorm.DB) {
u := query.User
// Gen 查询
users, _ := u.Where(u.Status.Eq(1)).Find()
// 原生 GORM 复杂操作
var results []struct{
UserName string
OrderCount int64
TotalAmount float64
}
db.Raw(`
SELECT u.name as user_name,
COUNT(o.id) as order_count,
SUM(o.amount) as total_amount
FROM users u
LEFT JOIN orders o ON u.id = o.user_id
WHERE u.status = ?
GROUP BY u.id
`, 1).Scan(&results)
// 从 Gen 获取 *gorm.DB
genDB := u.UnderlyingDB()
genDB.Exec("...")
}
15.9 完整实战示例
package main
import (
"fmt"
"myapp/dao/model"
"myapp/dao/query"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
func main() {
// 初始化
dsn := "user:password@tcp(127.0.0.1:3306)/myapp?charset=utf8mb4&parseTime=True&loc=Local"
db, _ := gorm.Open(mysql.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
})
query.SetDefault(db)
// ===== 业务操作示例 =====
// 1. 创建用户
user := &model.User{
Name: "张三",
Email: "zhangsan@example.com",
Phone: "13800138000",
Age: 25,
Status: 1,
}
if err := query.User.Create(user); err != nil {
panic(err)
}
fmt.Printf("创建用户: ID=%d\n", user.ID)
// 2. 查询用户
found, err := query.User.Where(query.User.Email.Eq("zhangsan@example.com")).First()
if err != nil {
panic(err)
}
fmt.Printf("查询用户: %+v\n", found)
// 3. 更新用户
_, err = query.User.Where(query.User.ID.Eq(user.ID)).Update(query.User.Age, 26)
if err != nil {
panic(err)
}
// 4. 分页查询
users, total, err := query.User.Where(query.User.Status.Eq(1)).FindByPage(0, 10)
if err != nil {
panic(err)
}
fmt.Printf("分页查询: 总数=%d, 本页=%d\n", total, len(users))
// 5. 事务处理
err = query.Q.Transaction(func(tx *query.Query) error {
// 创建订单
order := &model.Order{
UserID: user.ID,
Amount: 99.99,
Status: 1,
}
if err := tx.Order.Create(order); err != nil {
return err
}
// 更新用户订单数
_, err := tx.User.Where(tx.User.ID.Eq(user.ID)).
UpdateColumn(tx.User.OrderCount, tx.User.OrderCount.Add(1))
return err
})
if err != nil {
panic(err)
}
fmt.Println("操作完成!")
}
15.10 Gen vs 原生 GORM 对比
| 场景 | 推荐 | 原因 |
|---|---|---|
| 简单 CRUD | Gen | 类型安全,开发效率高 |
| 复杂报表查询 | 原生 GORM | 复杂 SQL 直接用原生更灵活 |
| 动态条件构建 | Gen | 类型安全的条件构建 |
| 需要高度定制的查询 | 原生 GORM | 完全控制 SQL |
| 新项目 | Gen | 减少运行时错误 |
| 已有项目改造 | 混用 | 渐进式迁移 |
15.11 练习题
- 使用 Gen 重构一个已有的原生 GORM 项目
- 实现一个带复杂条件(多表关联、分组统计)的查询
- 编写一个自定义方法生成插件
15.12 小结
GORM Gen 通过代码生成提供了类型安全的 ORM 体验,大幅减少运行时错误,提高开发效率。推荐在新项目中优先使用,已有项目可以渐进式迁移。
本文代码地址:https://github.com/LittleMoreInteresting/gorm_study
欢迎关注公众号,一起学习进步!