Files
common/db/gfdb/gfdb.go
qhd d1f80c3109 refactor: 重构SQL基础实体并集成雪花ID生成器
将主键ID类型从uint64改为int64,移除Bid和Deleter字段;在insertHook中集成Snowflake算法自动生成ID;更新ModuleAssetId为int64类型。
2026-03-19 17:07:01 +08:00

478 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package gfdb
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"time"

	"gitea.com/red-future/common/utils"
	"github.com/bwmarrin/snowflake"
	"github.com/gogf/gf/v2/crypto/gmd5"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/database/gredis"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gcache"
	"github.com/gogf/gf/v2/os/glog"
	"github.com/gogf/gf/v2/text/gstr"
	"github.com/gogf/gf/v2/util/gconv"
	"go.opentelemetry.io/otel/trace"
)
// ==================== Cache manager (singleton) ====================

var (
	// localCache is the process-wide in-memory cache. Always obtain it through
	// getLocalCache so that it is created exactly once.
	localCache *gcache.Cache
	// localCacheOnce guards the lazy creation of localCache.
	localCacheOnce sync.Once
)

// getLocalCache returns the process-wide local cache, creating it on first use.
//
// sync.Once makes the lazy initialisation safe when several requests reach it
// concurrently. An unguarded nil check would race and could create several
// caches, silently losing entries written to the discarded ones.
func getLocalCache() *gcache.Cache {
	localCacheOnce.Do(func() {
		localCache = gcache.New()
	})
	return localCache
}
// getFromCache looks key up in the local cache first and then in Redis.
//
// It returns the cached bytes and true on a hit, or nil and false on a miss.
// On a Redis hit the value is back-filled into the local cache using the
// "cache.localTTL" config value (seconds). That back-fill is best-effort: if it
// fails, the failure is logged and the Redis value is still returned, because
// it is valid data.
func getFromCache(ctx context.Context, key string) ([]byte, bool) {
	// 1. Local cache.
	if val, err := getLocalCache().Get(ctx, key); err == nil && val != nil {
		if data := val.Bytes(); len(data) > 0 {
			return data, true
		}
	}

	// 2. Redis (only when a Redis client is available).
	rds := g.Redis()
	if rds == nil {
		return nil, false
	}
	result, err := rds.Get(ctx, key)
	if err != nil || result.IsEmpty() {
		return nil, false
	}
	data := result.Bytes()

	// Back-fill the local cache; a failure only costs a future Redis round-trip.
	localTTL := time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64()) * time.Second
	if err = getLocalCache().Set(ctx, key, data, localTTL); err != nil {
		glog.Warningf(ctx, "[DB] back-fill local cache %q: %v", key, err)
	}
	return data, true
}
// setToCache 写入缓存(本地缓存 + Redis
func setToCache(ctx context.Context, key string, data []byte) (err error) {
if len(data) == 0 {
return
}
// 1. 写入本地缓存
if err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64())*time.Second); err != nil {
return
}
// 2. 写入Redis缓存
if g.Redis() != nil {
_, err = g.Redis().Set(ctx, key, data, gredis.SetOption{
TTLOption: gredis.TTLOption{
EX: gconv.PtrInt64(g.Cfg().MustGet(ctx, "cache.redisTTL")),
},
})
if err != nil {
return
}
}
return
}
// deleteCacheByPattern removes every cached entry whose key matches the Redis
// glob pattern. It prunes the local cache first, then Redis. The first error
// encountered is returned and aborts the remaining deletions.
//
// Redis keys are enumerated with cursor-based SCAN MATCH instead of KEYS. KEYS
// walks the whole keyspace in a single blocking call, which stalls the Redis
// server on large datasets. Keys are deleted one at a time so that no
// multi-key DEL spans cluster hash slots.
// NOTE(review): on a Redis Cluster a single SCAN only covers the node it is
// sent to; confirm the deployment is a single instance, or scan every master.
func deleteCacheByPattern(ctx context.Context, pattern string) (err error) {
	// 1. Local cache.
	cache := getLocalCache()
	for _, key := range cache.MustKeyStrings(ctx) {
		if !matchPattern(key, pattern) {
			continue
		}
		if _, err = cache.Remove(ctx, key); err != nil {
			return err
		}
	}

	// 2. Redis: SCAN <cursor> MATCH <pattern> COUNT <n> until the cursor wraps to 0.
	rds := g.Redis()
	if rds == nil {
		return nil
	}
	var cursor uint64
	for {
		reply, scanErr := rds.Do(ctx, "SCAN", cursor, "MATCH", pattern, "COUNT", 100)
		if scanErr != nil {
			return scanErr
		}
		// A SCAN reply is a two-element array: [next cursor, [keys...]].
		parts := reply.Slice()
		if len(parts) != 2 {
			return fmt.Errorf("unexpected SCAN reply for pattern %q: %v", pattern, reply.Val())
		}
		cursor = gconv.Uint64(parts[0])
		for _, key := range gconv.Strings(parts[1]) {
			if _, err = rds.Del(ctx, key); err != nil {
				return err
			}
		}
		if cursor == 0 {
			return nil
		}
	}
}
// matchPattern reports whether key matches a Redis glob-style pattern. This lets
// the local cache be pruned with the same pattern that is sent to Redis.
//
// Supported wildcards:
//   - '*' matches any sequence, including the empty one and newlines;
//   - '?' matches exactly one character (one UTF-8 rune).
//
// Every other character, including '[' and '\', is matched literally.
func matchPattern(key string, pattern string) bool {
	// QuoteMeta escapes every regexp metacharacter; the escaped wildcards are
	// then turned back into their regexp equivalents.
	expr := regexp.QuoteMeta(pattern)
	expr = strings.NewReplacer(`\*`, ".*", `\?`, ".").Replace(expr)
	// (?s) lets '.' also match '\n', as Redis wildcards do. The expression is
	// built from QuoteMeta output, so it always compiles and the error is moot.
	matched, _ := regexp.MatchString("(?s)^"+expr+"$", key)
	return matched
}
// ==================== Unified hook entry ====================

// catchSQLHook builds the gdb.HookHandler that dataBase.Model attaches to every
// model it creates:
//   - Insert: snowflake ID and creator/updater filling, then tenant cache invalidation;
//   - Update: updater filling, then tenant cache invalidation;
//   - Delete: tenant cache invalidation;
//   - Select: read-through caching, active only after model.Cache has been
//     called in the same trace.
//
// Usage:
//
//	// hooks are applied automatically, no caching
//	gfdb.DB(ctx).Model(ctx, "user").Insert(data)
//
//	// opt in to caching for this trace
//	gfdb.DB(ctx).Model(ctx, "asset").Cache(ctx).Where("id", 123).Scan(&result)
func catchSQLHook() gdb.HookHandler {
	var handler gdb.HookHandler
	handler.Insert = insertHook
	handler.Update = updateHook
	handler.Delete = deleteHook
	handler.Select = selectHook
	return handler
}
// ==================== Insert钩子 ====================
func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
node, err := snowflake.NewNode(g.Cfg().MustGet(ctx, "server.workerId").Int64())
if err != nil {
return nil, err
}
for i := range in.Data {
if _, ok := in.Data[i]["id"]; ok {
in.Data[i]["id"] = node.Generate().Int64()
}
if !g.IsEmpty(userInfo.UserName) {
if _, ok := in.Data[i]["creator"]; ok {
in.Data[i]["creator"] = userInfo.UserName
}
if _, ok := in.Data[i]["updater"]; ok {
in.Data[i]["updater"] = userInfo.UserName
}
}
}
// 2. 执行插入
result, err = in.Next(ctx)
if err != nil {
return nil, err
}
// 3. 清除相关缓存
if userInfo != nil && userInfo.TenantId != 0 {
if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil {
return nil, err
}
}
return result, nil
}
// ==================== Update hook ====================

// updateHook sets the "updater" column to the current user's name before an
// UPDATE. After the statement succeeds, it invalidates the tenant's cached
// SELECT results for the table.
//
// Only map (gdb.Map) and list (gdb.List) update data are touched, and only
// where an "updater" key is already present. User lookup is best-effort: if
// the user info is unavailable (error or nil), the data is left unchanged and
// no cache is cleared. It never panics on a missing user.
func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) {
	// 1. Fill the updater column (the lookup error is deliberately ignored: best-effort).
	userInfo, _ := utils.GetUserInfo(ctx)
	if userInfo != nil && !g.IsEmpty(userInfo.UserName) {
		switch data := in.Data.(type) {
		case gdb.Map:
			if _, ok := data["updater"]; ok {
				data["updater"] = userInfo.UserName
			}
		case gdb.List:
			for _, row := range data {
				if _, ok := row["updater"]; ok {
					row["updater"] = userInfo.UserName
				}
			}
		}
	}

	// 2. Execute the update.
	result, err = in.Next(ctx)
	if err != nil {
		return nil, err
	}

	// 3. Invalidate this tenant's cached queries for the table.
	if userInfo != nil && userInfo.TenantId != 0 {
		if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil {
			return nil, err
		}
	}
	return result, nil
}
// ==================== Delete hook ====================

// deleteHook runs the DELETE and then drops every cached SELECT result for the
// current tenant's table, so later reads do not return removed rows.
//
// User lookup is best-effort: without user info, or with a zero tenant ID, no
// cache is cleared. A cache invalidation failure is returned as an error even
// though the delete has already run.
func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) {
	if result, err = in.Next(ctx); err != nil {
		return nil, err
	}

	userInfo, _ := utils.GetUserInfo(ctx)
	if userInfo == nil || userInfo.TenantId == 0 {
		return result, nil
	}
	pattern := getCacheKey(userInfo.TenantId, in.Table, true)
	if err = deleteCacheByPattern(ctx, pattern); err != nil {
		return nil, err
	}
	return result, nil
}
// ==================== Select hook (read-through cache) ====================

// selectHook serves SELECT statements from the local/Redis cache when caching
// has been switched on for the current trace (see model.Cache). On a miss it
// queries the database and populates the cache.
//
// Cache key: sql:tenantId-<tenant>:<table>:<selectType>:<md5(full SQL + args)>.
// The digest covers the complete statement. Queries that differ only in
// selected columns, JOINs, GROUP BY, ORDER BY or LIMIT/OFFSET therefore get
// distinct entries.
//
// Locking reads (FOR UPDATE / FOR SHARE / LOCK IN SHARE MODE) always go to the
// database. Empty results are not cached. A failure to write the cache is
// logged and does not fail the query.
func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
	// model.Cache stores the enable flag under the trace ID with a 1s TTL. A
	// lookup error is treated as "not enabled".
	// NOTE(review): the flag is per trace, so while it lives it also enables
	// caching for other SELECTs in the same trace; confirm this is intended.
	enabled, _ := gcache.Get(ctx, getTraceID(ctx))
	if !gconv.Bool(enabled) || isLockingRead(in.Sql) {
		return in.Next(ctx)
	}

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	if user == nil {
		// Without a tenant the key cannot be scoped; never share entries across tenants.
		return in.Next(ctx)
	}

	digest, err := gmd5.Encrypt(fmt.Sprintf("%s:%v", in.Sql, in.Args))
	if err != nil {
		return nil, err
	}
	cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(user.TenantId, in.Table, false), getSelectTypeString(in.SelectType), digest)

	// 1. Cache hit.
	if data, ok := getFromCache(ctx, cacheKey); ok {
		var records gdb.Result
		if jsonErr := json.Unmarshal(data, &records); jsonErr == nil && len(records) > 0 {
			return records, nil
		}
	}

	// 2. Cache miss: query the database.
	if result, err = in.Next(ctx); err != nil {
		return nil, err
	}

	// 3. Populate the cache (best-effort).
	if len(result) > 0 {
		if data, jsonErr := json.Marshal(result); jsonErr == nil {
			if cacheErr := setToCache(ctx, cacheKey, data); cacheErr != nil {
				glog.Warningf(ctx, "[DB] write select cache %q: %v", cacheKey, cacheErr)
			}
		}
	}
	return result, nil
}

// isLockingRead reports whether stmt is a locking read. Such statements must
// reach the database so that the row locks are actually taken.
func isLockingRead(stmt string) bool {
	return gstr.ContainsI(stmt, " FOR UPDATE") ||
		gstr.ContainsI(stmt, " FOR SHARE") ||
		gstr.ContainsI(stmt, " LOCK IN SHARE MODE")
}
// getCacheKey builds the cache key prefix for a tenant's table:
// "sql:tenantId-<tenantId>:<table>". When isBlur is set, ":*" is appended,
// producing a glob that matches every entry stored under that prefix.
func getCacheKey(tenantId uint64, table string, isBlur bool) string {
	prefix := fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table)
	if !isBlur {
		return prefix
	}
	return prefix + ":*"
}
// getSelectTypeString maps a gdb.SelectType to the short label used in cache
// keys: "default", "count", "value" or "array". Any other value is labelled
// "unknown".
func getSelectTypeString(selectType gdb.SelectType) string {
	label := "unknown"
	switch selectType {
	case gdb.SelectTypeDefault:
		label = "default"
	case gdb.SelectTypeCount:
		label = "count"
	case gdb.SelectTypeValue:
		label = "value"
	case gdb.SelectTypeArray:
		label = "array"
	}
	return label
}
// ==================== Public API ====================

// schemaPrefix is prepended to a tenant ID to form that tenant's schema and
// database config group name (e.g. "tenant-42").
var schemaPrefix = "tenant-"

type (
	// Gfdb is the tenant-aware database entry point returned by DB.
	Gfdb interface {
		Model(ctx context.Context, tableNameOrStruct ...any) *model
		Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error
	}

	// cache describes the Cache method that model provides; Cache switches on
	// result caching for the current trace.
	cache interface {
		Cache(ctx context.Context) *gdb.Model
	}

	// model wraps *gdb.Model so helpers such as Cache can be attached to it.
	model struct {
		*gdb.Model
	}

	// dataBase wraps gdb.DB and implements Gfdb.
	dataBase struct {
		gdb.DB
	}
)
// DB returns a Gfdb bound to the current user's tenant.
//
// If the config contains a "database.tenant-<tenantId>" group, that group is
// used with schema "tenant-<tenantId>". Otherwise the "default" group is used,
// with the schema taken from its "name" setting; both the array form
// (database.default.0.name) and the single-node form (database.default.name)
// are supported.
//
// It logs and returns nil when the user info cannot be read.
func DB(ctx context.Context) Gfdb {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err)
		return nil
	}

	schema := fmt.Sprintf("%s%v", schemaPrefix, user.TenantId)
	group := schema
	if g.Cfg().MustGet(ctx, fmt.Sprintf("database.%s", schema)).IsEmpty() {
		// No dedicated tenant group: fall back to "default" and its configured schema name.
		group = "default"
		nameKey := "database.default.name"
		if g.Cfg().MustGet(ctx, "database.default").IsSlice() {
			nameKey = "database.default.0.name"
		}
		schema = g.Cfg().MustGet(ctx, nameKey).String()
	}

	return &dataBase{DB: g.DB(group).Schema(schema)}
}
// Model creates a gdb model for tableNameOrStruct, bound to ctx and prepared
// for the current tenant.
//
// If a "database.tenant-<tenantId>" config group exists, schema sharding is
// enabled through RegionShardingRule with the tenant ID as sharding value.
// Every model gets OmitNilData, OmitNilWhere and the shared SQL hooks from
// catchSQLHook.
//
// It logs and returns nil when the user info cannot be read.
func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err)
		return nil
	}

	m := d.DB.Model(tableNameOrStruct...).Ctx(ctx)
	schema := fmt.Sprintf("%s%v", schemaPrefix, user.TenantId)
	if !g.Cfg().MustGet(ctx, fmt.Sprintf("database.%s", schema)).IsEmpty() {
		// Route this tenant to its dedicated schema.
		m = m.Sharding(gdb.ShardingConfig{
			Schema: gdb.ShardingSchemaConfig{
				Enable: true,
				Prefix: schemaPrefix,
				Rule:   &RegionShardingRule{RegionMapping: user.TenantId},
			},
		}).ShardingValue(user.TenantId)
	}

	// Always reassign chain results: in safe mode gdb.Model methods return a
	// modified clone and leave the receiver untouched, so discarding the return
	// value would silently drop the sharding config and the hooks.
	m = m.OmitNilData().OmitNilWhere().Hook(catchSQLHook())
	return &model{Model: m}
}
// Transaction runs f inside a transaction on the tenant-resolved database by
// delegating to the embedded gdb.DB's Transaction.
//
// NOTE(review): models created from tx inside f do not go through
// dataBase.Model, so they get neither catchSQLHook nor tenant sharding unless
// the caller adds them. Confirm this is intended.
func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error {
	db := d.DB
	return db.Transaction(ctx, f)
}
// Cache switches on read-through caching for SELECTs issued in the current trace
// and returns the underlying *gdb.Model so the chain can continue.
//
// The switch is a true flag stored in the default gcache under the trace ID,
// with a one-second TTL; selectHook reads it. Cache logs and returns nil when
// ctx carries no trace ID or the flag cannot be stored.
func (m *model) Cache(ctx context.Context) *gdb.Model {
	traceID := getTraceID(ctx)
	if len(traceID) == 0 {
		glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty")
		return nil
	}
	err := gcache.Set(ctx, traceID, true, time.Second)
	if err == nil {
		return m.Model
	}
	glog.Errorf(ctx, "[DB] Cache error: %v", err)
	return nil
}
// getTraceID returns the string form of the trace ID carried by the span in
// ctx, or "" when ctx has no span context with a trace ID.
func getTraceID(ctx context.Context) string {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	if !spanCtx.HasTraceID() {
		return ""
	}
	return spanCtx.TraceID().String()
}
// RegionShardingRule is a gdb sharding rule that routes exactly one tenant
// (RegionMapping) to its dedicated schema "<prefix><tenantId>". Any other
// sharding value is routed to "default".
type RegionShardingRule struct {
	// RegionMapping is the tenant ID that owns the dedicated schema.
	RegionMapping uint64
}

// tenantShardingValue converts a sharding value to a tenant ID. Unsigned
// integers and non-negative signed integers are accepted, so callers may pass
// either uint64 or int64 tenant IDs. Any other value reports false.
func tenantShardingValue(value any) (uint64, bool) {
	switch v := value.(type) {
	case uint64:
		return v, true
	case uint:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case int64:
		return uint64(v), v >= 0
	case int:
		return uint64(v), v >= 0
	case int32:
		return uint64(v), v >= 0
	default:
		return 0, false
	}
}

// SchemaName implements the schema part of the gdb sharding rule.
//
// value is the sharding value set through Model.ShardingValue and must be a
// non-negative integer tenant ID. When it equals RegionMapping, the schema
// name is config.Prefix followed by the ID; otherwise "default" is returned.
func (r *RegionShardingRule) SchemaName(ctx context.Context, config gdb.ShardingSchemaConfig, value any) (string, error) {
	region, ok := tenantShardingValue(value)
	if !ok {
		return "", fmt.Errorf("sharding value must be a non-negative integer tenant ID for RegionShardingRule, got %T", value)
	}
	if region != r.RegionMapping {
		return "default", nil
	}
	return config.Prefix + gconv.String(region), nil
}

// TableName implements the table part of the gdb sharding rule. Table sharding
// is not used, so it always returns an empty name.
// NOTE(review): assumes gdb treats "" as "keep the original table name"; confirm.
func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingTableConfig, value any) (string, error) {
	return "", nil
}