package base import ( "context" "database/sql" "encoding/json" "fmt" "github.com/gogf/gf/v2/text/gstr" "time" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gcache" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/util/gconv" ) // ==================== 上下文键定义 ==================== type ctxKey string const ( // ctxKeySkipTenant 跳过租户ID自动赋值的上下文键 ctxKeySkipTenant ctxKey = "hook_skip_tenant" // ctxKeyCacheEnabled 缓存启用标记的上下文键 ctxKeyCacheEnabled ctxKey = "hook_cache_enabled" // ctxKeyCachePrefix 缓存key前缀的上下文键 ctxKeyCachePrefix ctxKey = "hook_cache_prefix" ) // ==================== 租户相关 ==================== // SkipTenantId 在上下文中标记跳过租户ID自动赋值 func SkipTenantId(ctx context.Context) context.Context { return context.WithValue(ctx, ctxKeySkipTenant, true) } // isSkipTenant 检查是否跳过租户ID func isSkipTenant(ctx context.Context) bool { if ctx == nil { return false } v, ok := ctx.Value(ctxKeySkipTenant).(bool) return ok && v } // ==================== 缓存配置 ==================== // CacheConfig 缓存配置 type CacheConfig struct { // 本地缓存过期时间(秒),默认60秒 LocalTTL int // Redis缓存过期时间(秒),默认300秒 RedisTTL int } // DefaultCacheConfig 默认缓存配置 var DefaultCacheConfig = CacheConfig{ LocalTTL: 60, RedisTTL: 300, } // isCacheEnabled 检查是否启用缓存 func isCacheEnabled(ctx context.Context) bool { if ctx == nil { return false } v, ok := ctx.Value(ctxKeyCacheEnabled).(bool) return ok && v } // getCachePrefix 获取缓存key前缀 func getCachePrefix(ctx context.Context) string { if ctx == nil { return "" } v, ok := ctx.Value(ctxKeyCachePrefix).(string) if !ok { return "" } return v } // ==================== 缓存管理器(单例) ==================== var ( localCache *gcache.Cache ) // getLocalCache 获取本地缓存实例 func getLocalCache() *gcache.Cache { if localCache == nil { localCache = gcache.New() } return localCache } // buildCacheKey 构建缓存key // 根据表名和查询条件自动生成key func buildCacheKey(prefix string, table string, where ...interface{}) string { // 基础key: prefix:table key := fmt.Sprintf("%s:%s", prefix, table) // 如果有where条件,追加到key中 if len(where) > 0 { for _, w := range where { key = fmt.Sprintf("%s:%v", key, w) } } return key } // getFromCache 从缓存获取数据(本地缓存 -> Redis) func getFromCache(ctx context.Context, key string) ([]byte, bool) { config := DefaultCacheConfig // 1. 先查本地缓存 if val, err := getLocalCache().Get(ctx, key); err == nil && val != nil { if data := val.Bytes(); len(data) > 0 { glog.Debugf(ctx, "[Cache] Hit local cache: %s", key) return data, true } } // 2. 再查Redis缓存 if g.Redis() != nil { result, err := g.Redis().Get(ctx, key) if err == nil && !result.IsEmpty() { data := result.Bytes() // 写入本地缓存 getLocalCache().Set(ctx, key, data, time.Duration(config.LocalTTL)*time.Second) glog.Debugf(ctx, "[Cache] Hit redis cache: %s", key) return data, true } } return nil, false } // setToCache 写入缓存(本地缓存 + Redis) func setToCache(ctx context.Context, key string, data []byte) { if len(data) == 0 { return } config := DefaultCacheConfig // 1. 写入本地缓存 getLocalCache().Set(ctx, key, data, time.Duration(config.LocalTTL)*time.Second) // 2. 写入Redis缓存 if g.Redis() != nil { expire := int64(config.RedisTTL) _, err := g.Redis().Set(ctx, key, data, gredis.SetOption{ TTLOption: gredis.TTLOption{ EX: &expire, }, }) if err != nil { glog.Warningf(ctx, "[Cache] Failed to set redis cache: %s, err: %v", key, err) } } } // deleteCache 删除缓存 func deleteCache(ctx context.Context, key string) { // 1. 删除本地缓存 getLocalCache().Remove(ctx, key) // 2. 删除Redis缓存 if g.Redis() != nil { _, err := g.Redis().Del(ctx, key) if err != nil { glog.Warningf(ctx, "[Cache] Failed to delete redis cache: %s, err: %v", key, err) } } } // deleteCacheByPattern 根据模式删除缓存 func deleteCacheByPattern(ctx context.Context, pattern string) { // 1. 清空本地缓存(简单实现:清空所有) getLocalCache().Clear(ctx) // 2. 删除Redis缓存(使用SCAN+DEL) if g.Redis() != nil { var cursor uint64 = 0 for { result, err := g.Redis().Do(ctx, "SCAN", cursor, "MATCH", pattern, "COUNT", 100) if err != nil { glog.Warningf(ctx, "[Cache] Failed to scan redis keys: %s, err: %v", pattern, err) break } resultMap := result.Map() cursor = gconv.Uint64(resultMap["cursor"]) keys := gconv.Strings(resultMap["keys"]) if len(keys) > 0 { args := make([]interface{}, len(keys)) for i, k := range keys { args[i] = k } _, err = g.Redis().Do(ctx, "DEL", args...) if err != nil { glog.Warningf(ctx, "[Cache] Failed to delete redis keys: %v, err: %v", keys, err) } } if cursor == 0 { break } } } } // ==================== 统一Hook入口 ==================== // CatchSQLHook 返回统一的 HookHandler(包含租户自动赋值和缓存) // 使用示例: // // // 基础使用(自动租户赋值,无缓存) // g.DB().Model("user").Hook(base.CatchSQLHook()).Ctx(ctx).Insert(data) // // // 启用缓存(用户无感知,自动处理缓存key) // ctx = base.WithCacheEnabled(ctx, "asset") // Asset.CtxWithCache(ctx).Where("id", 123).Scan(&result) func CatchSQLHook() gdb.HookHandler { return gdb.HookHandler{ Insert: insertHook, Update: updateHook, Delete: deleteHook, Select: selectHook, } } // ==================== Insert钩子 ==================== func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { // 1. 自动赋值租户字段 userInfo, _ := utils.GetUserInfo(ctx) if !g.IsEmpty(userInfo.TenantId) { in.Model.Data("tenant_id", userInfo.TenantId) } if !g.IsEmpty(userInfo.UserName) { in.Model.Data("creator", userInfo.UserName) in.Model.Data("updater", userInfo.UserName) } //for i := range in.Data { // if !g.IsEmpty(userInfo.TenantId) { // if _, ok := in.Data[i]["tenant_id"]; !ok { // in.Data[i]["tenant_id"] = userInfo.TenantId // } // } // if !g.IsEmpty(userInfo.UserId) { // if _, ok := in.Data[i]["creator"]; !ok { // in.Data[i]["creator"] = userInfo.UserId // } // if _, ok := in.Data[i]["updater"]; !ok { // in.Data[i]["updater"] = userInfo.UserId // } // } //} // 2. 执行插入 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 清除相关缓存 prefix := getCachePrefix(ctx) if prefix != "" { deleteCacheByPattern(ctx, prefix+":*") glog.Debugf(ctx, "[Hook] Cache cleared after insert, prefix: %s", prefix) } return result, nil } // ==================== Update钩子 ==================== func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { // 1. 自动赋值修改人 userInfo, _ := utils.GetUserInfo(ctx) if !g.IsEmpty(userInfo.TenantId) { in.Model.Where("tenant_id", userInfo.TenantId) } if !g.IsEmpty(userInfo.UserName) { in.Model.Where("creator", userInfo.UserName) in.Model.Where("updater", userInfo.UserName) } //switch data := in.Data.(type) { //case gdb.Map: // if !g.IsEmpty(userInfo.UserId) { // if _, ok := data["updater"]; !ok { // data["updater"] = userInfo.UserId // } // } //case gdb.List: // for i := range data { // if !g.IsEmpty(userInfo.UserId) { // if _, ok := data[i]["updater"]; !ok { // data[i]["updater"] = userInfo.UserId // } // } // } //} // 2. 执行更新 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 清除相关缓存 prefix := getCachePrefix(ctx) if prefix != "" { deleteCacheByPattern(ctx, prefix+":*") glog.Debugf(ctx, "[Hook] Cache cleared after update, prefix: %s", prefix) } return result, nil } // ==================== Delete钩子 ==================== func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { // 1. 执行删除 result, err = in.Next(ctx) if err != nil { return nil, err } // 2. 清除相关缓存 prefix := getCachePrefix(ctx) if prefix != "" { deleteCacheByPattern(ctx, prefix+":*") glog.Debugf(ctx, "[Hook] Cache cleared after delete, prefix: %s", prefix) } return result, nil } // ==================== Select钩子(缓存读取) ==================== func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { userInfo, _ := utils.GetUserInfo(ctx) if !isSkipTenant(ctx) && !g.IsEmpty(userInfo.TenantId) { in.Model.Where("tenant_id", userInfo.TenantId) } // 未启用缓存,直接执行查询 if !isCacheEnabled(ctx) { return in.Next(ctx) } prefix := getCachePrefix(ctx) if prefix == "" { return in.Next(ctx) } // 从 SQL 字符串中提取 WHERE 条件部分 whereCondition := extractWhereCondition(in.Sql) // 构建缓存key:prefix:table:where条件:args cacheKey := buildCacheKey(prefix, in.Table, whereCondition, in.Args) glog.Debugf(ctx, "[Hook] Cache key: %s", cacheKey) // 1. 先查缓存 if data, ok := getFromCache(ctx, cacheKey); ok { var records gdb.Result if err := json.Unmarshal(data, &records); err == nil && len(records) > 0 { glog.Debugf(ctx, "[Hook] Cache hit for key: %s", cacheKey) return records, nil } } // 2. 执行数据库查询 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 写入缓存 if len(result) > 0 { if data, err := json.Marshal(result); err == nil { setToCache(ctx, cacheKey, data) glog.Debugf(ctx, "[Hook] Cache set for key: %s", cacheKey) } } return result, nil } // extractWhereCondition 从 SQL 语句中提取 WHERE 条件部分 func extractWhereCondition(sql string) string { // 查找 WHERE 关键字(不区分大小写) whereIndex := gstr.PosI(sql, " WHERE ") if whereIndex == -1 { return "" } // 提取 WHERE 之后的内容 whereClause := sql[whereIndex+7:] // 移除 ORDER BY, GROUP BY, HAVING, LIMIT 等后续子句 for _, keyword := range []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} { if idx := gstr.PosI(whereClause, keyword); idx != -1 { whereClause = whereClause[:idx] } } return whereClause } // ==================== 快捷方法 ==================== type gfdb interface { Model(tableNameOrStruct ...any) *Model } type cache interface { Cache(ctx context.Context) *gdb.Model } type Model struct { *gdb.Model } type DataBase struct { gdb.DB DbName string } func DB(dbName string) gfdb { return &DataBase{ DB: g.DB(dbName), DbName: dbName, } } func (d *DataBase) Model(tableNameOrStruct ...any) *Model { return &Model{ Model: d.DB.Model(tableNameOrStruct...), } } func (d *Model) Cache(ctx context.Context) *gdb.Model { ctx = context.WithValue(ctx, ctxKeyCachePrefix, true) return d.Model }