Files
assets/dao/base/hook.go
2026-03-18 10:18:03 +08:00

457 lines
11 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package base
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/text/gstr"
"time"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
// ==================== Context keys ====================

// ctxKey is an unexported string type used for context keys, so values stored
// under these keys cannot be read or overwritten via keys of another type.
type ctxKey string

// Context keys consulted by the SQL hooks in this file.
const (
	// ctxKeySkipTenant marks a context whose operations must not get the
	// tenant ID assigned/applied automatically.
	ctxKeySkipTenant = ctxKey("hook_skip_tenant")
	// ctxKeyCacheEnabled marks a context whose SELECTs may be served from cache.
	ctxKeyCacheEnabled = ctxKey("hook_cache_enabled")
	// ctxKeyCachePrefix holds the string prefix used to namespace cache keys.
	ctxKeyCachePrefix = ctxKey("hook_cache_prefix")
)
// ==================== Tenant helpers ====================

// SkipTenantId returns a child of ctx flagged so that the hooks in this file
// skip automatic tenant-ID handling for operations run with it.
func SkipTenantId(ctx context.Context) context.Context {
	flagged := context.WithValue(ctx, ctxKeySkipTenant, true)
	return flagged
}
// isSkipTenant reports whether ctx was flagged by SkipTenantId.
// A nil context, or a non-bool value under the key, counts as "not skipped".
func isSkipTenant(ctx context.Context) bool {
	if ctx != nil {
		if skip, isBool := ctx.Value(ctxKeySkipTenant).(bool); isBool {
			return skip
		}
	}
	return false
}
// ==================== Cache configuration ====================

// CacheConfig holds the expiry times, in seconds, of the two cache tiers.
type CacheConfig struct {
	// LocalTTL is the in-process cache expiry in seconds (default 60).
	LocalTTL int
	// RedisTTL is the Redis cache expiry in seconds (default 300).
	RedisTTL int
}

// DefaultCacheConfig is the configuration read by the cache helpers.
var DefaultCacheConfig = CacheConfig{
	RedisTTL: 300,
	LocalTTL: 60,
}
// isCacheEnabled reports whether ctx carries a true cache-enabled flag.
// A nil context, or a non-bool value under the key, counts as disabled.
func isCacheEnabled(ctx context.Context) bool {
	if ctx != nil {
		if enabled, isBool := ctx.Value(ctxKeyCacheEnabled).(bool); isBool {
			return enabled
		}
	}
	return false
}
// getCachePrefix 获取缓存key前缀
func getCachePrefix(ctx context.Context) string {
if ctx == nil {
return ""
}
v, ok := ctx.Value(ctxKeyCachePrefix).(string)
if !ok {
return ""
}
return v
}
// ==================== Local cache (singleton) ====================

// localCache is the process-wide in-memory cache tier shared by the helpers
// below. It is created at package initialisation instead of lazily: a lazy
// `if localCache == nil { localCache = gcache.New() }` is an unsynchronised
// check-then-write, which is a data race when hooks first run concurrently
// and can hand different goroutines different cache instances.
var localCache = gcache.New()

// getLocalCache returns the shared in-memory cache instance.
func getLocalCache() *gcache.Cache {
	return localCache
}
// buildCacheKey joins prefix, table and every extra part with ":" to form a
// cache key. Each extra part is rendered with fmt's default (%v) format, so
// the result looks like "prefix:table:part1:part2".
func buildCacheKey(prefix string, table string, where ...interface{}) string {
	key := prefix + ":" + table
	for _, part := range where {
		key += ":" + fmt.Sprint(part)
	}
	return key
}
// getFromCache looks key up in the in-process cache first and then in Redis.
// A Redis hit is copied into the in-process cache with the local TTL.
// It returns the cached bytes and true on a hit, or nil and false on a miss;
// empty values count as misses.
func getFromCache(ctx context.Context, key string) ([]byte, bool) {
	localTTL := time.Duration(DefaultCacheConfig.LocalTTL) * time.Second

	// Tier 1: in-process cache.
	if cached, err := getLocalCache().Get(ctx, key); err == nil && cached != nil {
		if payload := cached.Bytes(); len(payload) > 0 {
			glog.Debugf(ctx, "[Cache] Hit local cache: %s", key)
			return payload, true
		}
	}

	// Tier 2: Redis, when a client is available.
	rds := g.Redis()
	if rds == nil {
		return nil, false
	}
	reply, err := rds.Get(ctx, key)
	if err != nil || reply.IsEmpty() {
		return nil, false
	}
	payload := reply.Bytes()
	// Promote the Redis hit so the next read is served locally.
	getLocalCache().Set(ctx, key, payload, localTTL)
	glog.Debugf(ctx, "[Cache] Hit redis cache: %s", key)
	return payload, true
}
// setToCache stores data under key in the in-process cache (LocalTTL) and, if
// a Redis client is available, in Redis with an EX of RedisTTL seconds.
// Empty data is ignored. Redis write failures are logged, not returned.
func setToCache(ctx context.Context, key string, data []byte) {
	if len(data) == 0 {
		return
	}
	cfg := DefaultCacheConfig

	getLocalCache().Set(ctx, key, data, time.Duration(cfg.LocalTTL)*time.Second)

	rds := g.Redis()
	if rds == nil {
		return
	}
	ttlSeconds := int64(cfg.RedisTTL)
	opt := gredis.SetOption{TTLOption: gredis.TTLOption{EX: &ttlSeconds}}
	if _, err := rds.Set(ctx, key, data, opt); err != nil {
		glog.Warningf(ctx, "[Cache] Failed to set redis cache: %s, err: %v", key, err)
	}
}
// deleteCache removes key from the in-process cache and, when a Redis client
// is available, from Redis. Redis failures are logged, not returned.
func deleteCache(ctx context.Context, key string) {
	getLocalCache().Remove(ctx, key)

	rds := g.Redis()
	if rds == nil {
		return
	}
	if _, err := rds.Del(ctx, key); err != nil {
		glog.Warningf(ctx, "[Cache] Failed to delete redis cache: %s, err: %v", key, err)
	}
}
// deleteCacheByPattern invalidates cached entries matching a Redis glob pattern.
//
// The in-process cache cannot be matched by pattern, so it is cleared
// entirely. Other processes' in-process caches are not touched and may keep
// stale entries until LocalTTL expires. Redis keys are found with an
// incremental SCAN ... MATCH and removed batch by batch with DEL. Errors are
// logged and end the scan early; they are not returned.
func deleteCacheByPattern(ctx context.Context, pattern string) {
	getLocalCache().Clear(ctx)

	rds := g.Redis()
	if rds == nil {
		return
	}
	var cursor uint64
	for {
		reply, err := rds.Do(ctx, "SCAN", cursor, "MATCH", pattern, "COUNT", 100)
		if err != nil {
			glog.Warningf(ctx, "[Cache] Failed to scan redis keys: %s, err: %v", pattern, err)
			return
		}
		// SCAN replies with a two-element array: [next-cursor, [key, ...]].
		// (Reading it as a map with "cursor"/"keys" entries always yields
		// nil, which ended the loop after one pass without deleting anything.)
		parts := reply.Interfaces()
		if len(parts) != 2 {
			glog.Warningf(ctx, "[Cache] Unexpected redis SCAN reply for pattern: %s, reply: %v", pattern, parts)
			return
		}
		cursor = gconv.Uint64(parts[0])
		if keys := gconv.Strings(parts[1]); len(keys) > 0 {
			if _, err = rds.Del(ctx, keys...); err != nil {
				glog.Warningf(ctx, "[Cache] Failed to delete redis keys: %v, err: %v", keys, err)
			}
		}
		// A returned cursor of 0 means the iteration is complete.
		if cursor == 0 {
			return
		}
	}
}
// ==================== Unified hook entry point ====================

// CatchSQLHook returns a gdb.HookHandler that wires the insert, update,
// delete and select hooks defined in this file onto a model.
//
// The write hooks fill tenant/audit data and, when the context carries a
// cache prefix, invalidate every cache entry under "<prefix>:*". The select
// hook serves results from the local/Redis cache only when the context has
// both the cache-enabled flag and a non-empty cache prefix.
//
// Example (no caching):
//
//	g.DB().Model("user").Hook(base.CatchSQLHook()).Ctx(ctx).Insert(data)
func CatchSQLHook() gdb.HookHandler {
	var handler gdb.HookHandler
	handler.Select = selectHook
	handler.Insert = insertHook
	handler.Update = updateHook
	handler.Delete = deleteHook
	return handler
}
// ==================== Insert hook ====================

// insertHook fills tenant/audit columns on every row being inserted, runs the
// INSERT, and on success invalidates the cache under the context's prefix.
//
// Columns filled (only when the table has the column and the row's value for
// it is empty, so explicit values are never overwritten):
//   - tenant_id: the current user's TenantId, unless SkipTenantId was applied;
//   - creator, updater: the current user's UserName.
//
// The values are written into in.Data, the row list the INSERT is built from.
// NOTE(review): the former in.Model.Data(...) calls had no effect here,
// because GoFrame has already captured the rows into in.Data before invoking
// the hook.
func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) {
	// 1. Fill tenant/audit columns. A user-info lookup error is treated as
	// "no current user" and nothing is filled in.
	if userInfo, infoErr := utils.GetUserInfo(ctx); infoErr == nil {
		fills := make(map[string]interface{}, 3)
		if !isSkipTenant(ctx) && !g.IsEmpty(userInfo.TenantId) {
			fills["tenant_id"] = userInfo.TenantId
		}
		if !g.IsEmpty(userInfo.UserName) {
			fills["creator"] = userInfo.UserName
			fills["updater"] = userInfo.UserName
		}
		for column, value := range fills {
			// Adding a column the table lacks would make the INSERT fail.
			hasColumn, fieldErr := in.Model.HasField(column)
			if fieldErr != nil {
				glog.Warningf(ctx, "[Hook] Failed to check field %s of table %s, err: %v", column, in.Table, fieldErr)
				continue
			}
			if !hasColumn {
				continue
			}
			for _, row := range in.Data {
				if row != nil && g.IsEmpty(row[column]) {
					row[column] = value
				}
			}
		}
	}

	// 2. Execute the insert.
	result, err = in.Next(ctx)
	if err != nil {
		return nil, err
	}

	// 3. Invalidate cached reads under this context's prefix.
	if prefix := getCachePrefix(ctx); prefix != "" {
		deleteCacheByPattern(ctx, prefix+":*")
		glog.Debugf(ctx, "[Hook] Cache cleared after insert, prefix: %s", prefix)
	}
	return result, nil
}
// ==================== Update hook ====================

// updateHook records the current user as "updater", runs the UPDATE, and on
// success invalidates the cache under the context's prefix.
//
// "updater" is written into the update data map (in.Data) only when:
//   - the data is a map;
//   - the value is empty;
//   - the table has the column.
//
// String-form updates (e.g. "count=count+1") are left untouched. This replaces
// former Where("creator"/"updater", user) calls. Those would have restricted
// the update to the user's own rows instead of recording the modifier.
func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) {
	// 1. Tenant scope and updater stamping. A user-info lookup error is
	// treated as "no current user".
	if userInfo, infoErr := utils.GetUserInfo(ctx); infoErr == nil {
		if !g.IsEmpty(userInfo.TenantId) {
			// NOTE(review): in.Condition/in.Args are already built when the
			// hook runs, so this Where on in.Model presumably does NOT scope
			// the UPDATE by tenant. Real enforcement needs in.Condition and
			// in.Args to be rewritten — TODO.
			in.Model.Where("tenant_id", userInfo.TenantId)
		}
		if data, isMap := in.Data.(gdb.Map); isMap && data != nil &&
			!g.IsEmpty(userInfo.UserName) && g.IsEmpty(data["updater"]) {
			hasColumn, fieldErr := in.Model.HasField("updater")
			switch {
			case fieldErr != nil:
				glog.Warningf(ctx, "[Hook] Failed to check field updater of table %s, err: %v", in.Table, fieldErr)
			case hasColumn:
				data["updater"] = userInfo.UserName
			}
		}
	}

	// 2. Execute the update.
	result, err = in.Next(ctx)
	if err != nil {
		return nil, err
	}

	// 3. Invalidate cached reads under this context's prefix.
	if prefix := getCachePrefix(ctx); prefix != "" {
		deleteCacheByPattern(ctx, prefix+":*")
		glog.Debugf(ctx, "[Hook] Cache cleared after update, prefix: %s", prefix)
	}
	return result, nil
}
// ==================== Delete hook ====================

// deleteHook runs the DELETE and, when it succeeds and the context carries a
// cache prefix, invalidates every cache entry under "<prefix>:*".
func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) {
	if result, err = in.Next(ctx); err != nil {
		return nil, err
	}
	if prefix := getCachePrefix(ctx); prefix != "" {
		deleteCacheByPattern(ctx, prefix+":*")
		glog.Debugf(ctx, "[Hook] Cache cleared after delete, prefix: %s", prefix)
	}
	return result, nil
}
// ==================== Select hook (cache read-through) ====================

// selectHook runs SELECTs, serving and storing results through the
// local/Redis cache when the context has caching enabled and a prefix set.
//
// The cache key covers the whole statement: table, full SQL text, the
// JSON-encoded bound arguments, and the current user's tenant ID. An earlier
// key used only the WHERE text plus %v-formatted args. With that key:
//   - a COUNT query and a row query with the same WHERE collided;
//   - different pages, orders and column lists collided;
//   - args such as ["a b","c"] and ["a","b c"] collided.
//
// Only non-empty results are cached. Cached rows round-trip through JSON, so
// value Go types may differ from a direct database read.
func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
	// tenantScope is part of the cache key so different tenants never share
	// cached rows. It stays nil when the user lookup fails.
	var tenantScope interface{}
	if userInfo, infoErr := utils.GetUserInfo(ctx); infoErr == nil {
		tenantScope = userInfo.TenantId
		if !isSkipTenant(ctx) && !g.IsEmpty(userInfo.TenantId) {
			// NOTE(review): in.Sql/in.Args are already built when the hook
			// runs, so this Where on in.Model presumably does NOT filter the
			// query by tenant. Enforcement would require rewriting in.Sql and
			// in.Args — TODO.
			in.Model.Where("tenant_id", userInfo.TenantId)
		}
	}

	// Caching is opt-in: both the flag and a prefix are required.
	if !isCacheEnabled(ctx) {
		return in.Next(ctx)
	}
	prefix := getCachePrefix(ctx)
	if prefix == "" {
		return in.Next(ctx)
	}

	// JSON is unambiguous (quoted, escaped, bracketed), unlike %v, so distinct
	// argument lists always yield distinct keys. Args that cannot be encoded
	// bypass the cache rather than risk a shared key.
	scope, keyErr := json.Marshal([]interface{}{tenantScope, in.Args})
	if keyErr != nil {
		glog.Debugf(ctx, "[Hook] Cache bypassed, cannot encode args: %v", keyErr)
		return in.Next(ctx)
	}
	cacheKey := buildCacheKey(prefix, in.Table, string(scope), in.Sql)
	glog.Debugf(ctx, "[Hook] Cache key: %s", cacheKey)

	// 1. Serve from cache when a non-empty result is stored.
	if data, ok := getFromCache(ctx, cacheKey); ok {
		var records gdb.Result
		if jsonErr := json.Unmarshal(data, &records); jsonErr == nil && len(records) > 0 {
			glog.Debugf(ctx, "[Hook] Cache hit for key: %s", cacheKey)
			return records, nil
		}
	}

	// 2. Query the database.
	result, err = in.Next(ctx)
	if err != nil {
		return nil, err
	}

	// 3. Store non-empty results.
	if len(result) > 0 {
		if data, jsonErr := json.Marshal(result); jsonErr == nil {
			setToCache(ctx, cacheKey, data)
			glog.Debugf(ctx, "[Hook] Cache set for key: %s", cacheKey)
		}
	}
	return result, nil
}
// extractWhereCondition returns the text after the first " WHERE " in sql
// (matched case-insensitively), truncated before the earliest of
// " ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT " or " FOR UPDATE".
// It returns "" when sql has no " WHERE ".
func extractWhereCondition(sql string) string {
	const whereMarker = " WHERE "
	start := gstr.PosI(sql, whereMarker)
	if start < 0 {
		return ""
	}
	clause := sql[start+len(whereMarker):]

	// Each cut keeps only the text before that terminator, so after the loop
	// the clause ends at whichever terminator appeared first.
	terminators := [...]string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"}
	for _, terminator := range terminators {
		if cut := gstr.PosI(clause, terminator); cut >= 0 {
			clause = clause[:cut]
		}
	}
	return clause
}
// ==================== Convenience wrappers ====================

// gfdb is the minimal database facade returned by DB: it opens Models.
type gfdb interface {
	Model(tableNameOrStruct ...any) *Model
}

// cache is implemented by types that can hand out a cache-enabled gdb.Model.
type cache interface {
	Cache(ctx context.Context) *gdb.Model
}

// Model wraps *gdb.Model so this package can attach extra helpers to it.
type Model struct {
	*gdb.Model
}

// DataBase pairs a gdb.DB with the configuration group name it was built from.
type DataBase struct {
	gdb.DB
	DbName string
}

// DB returns a facade over the database configured under dbName (via g.DB).
func DB(dbName string) gfdb {
	db := &DataBase{DbName: dbName}
	db.DB = g.DB(dbName)
	return db
}

// Model opens a gdb.Model on the given table(s) or struct and wraps it.
func (db *DataBase) Model(tableNameOrStruct ...any) *Model {
	inner := db.DB.Model(tableNameOrStruct...)
	return &Model{Model: inner}
}
// defaultModelCachePrefix is the cache-key prefix (*Model).Cache installs when
// the caller's context does not already carry one.
const defaultModelCachePrefix = "cache"

// Cache returns the wrapped gdb.Model bound to a context that opts in to the
// select cache implemented by CatchSQLHook. It does the following:
//   - sets the cache-enabled flag;
//   - keeps any cache prefix already on ctx, otherwise installs
//     defaultModelCachePrefix;
//   - binds the resulting context to the model via Ctx.
//
// The previous version stored a bool under the prefix key (getCachePrefix
// only accepts strings), never set the enabled flag, and discarded the new
// context, so it had no effect.
//
// NOTE: caching only happens when CatchSQLHook() is attached to the model
// (e.g. .Hook(CatchSQLHook())); this method does not attach it.
func (m *Model) Cache(ctx context.Context) *gdb.Model {
	if ctx == nil {
		ctx = context.Background()
	}
	ctx = context.WithValue(ctx, ctxKeyCacheEnabled, true)
	if getCachePrefix(ctx) == "" {
		ctx = context.WithValue(ctx, ctxKeyCachePrefix, defaultModelCachePrefix)
	}
	return m.Model.Ctx(ctx)
}