From 0c2e36f607474423a344dd0d11e182327bef366e Mon Sep 17 00:00:00 2001 From: qhd <1766646056@qq.com> Date: Thu, 2 Apr 2026 10:37:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20gfdb=E5=A2=9E=E5=8A=A0noTenantId?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- db/gfdb/gfdb.go | 146 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 103 insertions(+), 43 deletions(-) diff --git a/db/gfdb/gfdb.go b/db/gfdb/gfdb.go index ea5d844..3c1c754 100644 --- a/db/gfdb/gfdb.go +++ b/db/gfdb/gfdb.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" "github.com/bwmarrin/snowflake" "github.com/gogf/gf/v2/crypto/gmd5" @@ -280,11 +281,48 @@ func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result // ==================== Select钩子(缓存读取) ==================== func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { - traceID := getTraceID(ctx) + var tenantId uint64 + // ===================== 最终版:安全追加租户ID ===================== + tenantEnabled, err := gcache.Get(ctx, getTraceID(ctx, noTenantIdKeyPrefix)) + if err != nil { + return + } + if !gconv.Bool(tenantEnabled) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + tenantId = user.TenantId + // 【关键修复】找到 SQL 中第一个出现的 ORDER BY / GROUP BY / LIMIT 等关键字位置 + sql := in.Sql + insertPos := len(sql) + keywords := []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} + for _, kw := range keywords { + if idx := gstr.PosI(sql, kw); idx != -1 { + insertPos = idx + break + } + } - enabled, err := gcache.Get(ctx, traceID) + // 【正确拼接】把条件插入到关键字之前,而不是直接拼在最后 + condition := " " + beans.DefSQLBaseCol.TenantId + " = ?" + if gstr.Contains(gstr.ToUpper(sql), " WHERE ") { + // 有 WHERE → 加 AND + in.Sql = sql[:insertPos] + " AND" + condition + sql[insertPos:] + } else { + // 无 WHERE → 加 WHERE + in.Sql = sql[:insertPos] + " WHERE" + condition + sql[insertPos:] + } + in.Args = append(in.Args, tenantId) + } + // ================================================================== + + cacheEnabled, err := gcache.Get(ctx, getTraceID(ctx, cacheKeyPrefix)) + if err != nil { + return + } // 未启用缓存,直接执行查询 - if !gconv.Bool(enabled) { + if !gconv.Bool(cacheEnabled) { return in.Next(ctx) } @@ -303,18 +341,13 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result } } - user, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - encrypt, err := gmd5.Encrypt(fmt.Sprintf("%s:%s", whereCondition, in.Args)) if err != nil { return nil, err } // 构建缓存key:sql:tenantId:table:where条件:args - cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(user.TenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt) + cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(tenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt) // 1. 先查缓存 if data, ok := getFromCache(ctx, cacheKey); ok { @@ -343,7 +376,12 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result } func getCacheKey(tenantId uint64, table string, isBlur bool) string { - cacheKey := fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table) + var cacheKey string + if g.IsEmpty(tenantId) { + cacheKey = fmt.Sprintf("sql:%s", table) + } else { + cacheKey = fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table) + } if isBlur { cacheKey = fmt.Sprintf("%s:*", cacheKey) } @@ -369,7 +407,9 @@ func getSelectTypeString(selectType gdb.SelectType) string { // ==================== 调用方法 ==================== var ( - schemaPrefix = "tenant-" + schemaPrefix = "tenant-" + cacheKeyPrefix = "cache-" + noTenantIdKeyPrefix = "tenantId-" ) type Gfdb interface { @@ -382,49 +422,39 @@ type cache interface { Cache(ctx context.Context) *gdb.Model } -type model struct { - *gdb.Model +type noTenantId interface { + NoTenantId(ctx context.Context) *modelCache } type dataBase struct { gdb.DB } -func GetTablePrefix(ctx context.Context) (prefix string, err error) { - tenantId, config, err := checkSchemaConfig(ctx) - if err != nil { - glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err) - return - } - if config { - sprintf := fmt.Sprintf("database.%s%v.0.prefix", schemaPrefix, tenantId) - prefix = g.Cfg().MustGet(ctx, sprintf).String() - return - } - prefix = g.Cfg().MustGet(ctx, "database.default.0.prefix").String() - return +type model struct { + *gdb.Model } -func checkSchemaConfig(ctx context.Context) (uint64, bool, error) { +type modelCache struct { + *model +} + +func checkSchemaConfig(ctx context.Context) (uint64, bool) { user, err := utils.GetUserInfo(ctx) if err != nil { glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err) - return 0, false, err + return 0, false } var schema = fmt.Sprintf("%s%v", schemaPrefix, user.TenantId) sprintf := fmt.Sprintf("database.%s", schema) if !g.Cfg().MustGet(ctx, sprintf).IsEmpty() { - return user.TenantId, true, nil + return user.TenantId, true } - return user.TenantId, false, nil + return user.TenantId, false } func DB(ctx context.Context) Gfdb { - tenantId, config, err := checkSchemaConfig(ctx) - if err != nil { - glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err) - return nil - } + tenantId, config := checkSchemaConfig(ctx) + var schema = fmt.Sprintf("%s%v", schemaPrefix, tenantId) var dbName []string @@ -450,11 +480,8 @@ func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model { m := d.DB.Model(tableNameOrStruct...).Ctx(ctx) - tenantId, config, err := checkSchemaConfig(ctx) - if err != nil { - glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err) - return nil - } + tenantId, config := checkSchemaConfig(ctx) + if config { // 创建按地区分库的配置 shardingConfig := gdb.ShardingConfig{ @@ -478,7 +505,7 @@ func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context, } func (d *model) Cache(ctx context.Context) *gdb.Model { - traceID := getTraceID(ctx) + traceID := getTraceID(ctx, cacheKeyPrefix) if traceID == "" { glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty") return nil @@ -490,11 +517,28 @@ func (d *model) Cache(ctx context.Context) *gdb.Model { return d.Model } +func (d *model) NoTenantId(ctx context.Context) *modelCache { + traceID := getTraceID(ctx, noTenantIdKeyPrefix) + if traceID == "" { + glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty") + return nil + } + if err := gcache.Set(ctx, traceID, true, time.Second); err != nil { + glog.Errorf(ctx, "[DB] Cache error: %v", err) + return nil + } + return &modelCache{ + &model{ + Model: d.Model, + }, + } +} + // getTraceID 从 context 中获取链路追踪 ID -func getTraceID(ctx context.Context) string { +func getTraceID(ctx context.Context, prefix string) string { span := trace.SpanFromContext(ctx) if span != nil && span.SpanContext().HasTraceID() { - return span.SpanContext().TraceID().String() + return fmt.Sprintf("%s%v", prefix, span.SpanContext().TraceID().String()) } return "" } @@ -521,3 +565,19 @@ func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingT // 这里不实现分表,返回空字符串 return "", nil } + +func GetTablePrefix(ctx context.Context) (prefix string, err error) { + tenantId, config := checkSchemaConfig(ctx) + if config { + sprintf := fmt.Sprintf("database.%s%v.0.prefix", schemaPrefix, tenantId) + prefix = g.Cfg().MustGet(ctx, sprintf).String() + return + } + defaultConfig := g.Cfg().MustGet(ctx, "database.default") + if defaultConfig.IsSlice() { + prefix = g.Cfg().MustGet(ctx, "database.default.0.prefix").String() + } else { + prefix = g.Cfg().MustGet(ctx, "database.default.prefix").String() + } + return +}