feat: gfdb增加noTenantId
This commit is contained in:
146
db/gfdb/gfdb.go
146
db/gfdb/gfdb.go
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/bwmarrin/snowflake"
|
||||
"github.com/gogf/gf/v2/crypto/gmd5"
|
||||
@@ -280,11 +281,48 @@ func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result
|
||||
// ==================== Select钩子(缓存读取) ====================
|
||||
|
||||
func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
|
||||
traceID := getTraceID(ctx)
|
||||
var tenantId uint64
|
||||
// ===================== 最终版:安全追加租户ID =====================
|
||||
tenantEnabled, err := gcache.Get(ctx, getTraceID(ctx, noTenantIdKeyPrefix))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !gconv.Bool(tenantEnabled) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tenantId = user.TenantId
|
||||
// 【关键修复】找到 SQL 中第一个出现的 ORDER BY / GROUP BY / LIMIT 等关键字位置
|
||||
sql := in.Sql
|
||||
insertPos := len(sql)
|
||||
keywords := []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"}
|
||||
for _, kw := range keywords {
|
||||
if idx := gstr.PosI(sql, kw); idx != -1 {
|
||||
insertPos = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
enabled, err := gcache.Get(ctx, traceID)
|
||||
// 【正确拼接】把条件插入到关键字之前,而不是直接拼在最后
|
||||
condition := " " + beans.DefSQLBaseCol.TenantId + " = ?"
|
||||
if gstr.Contains(gstr.ToUpper(sql), " WHERE ") {
|
||||
// 有 WHERE → 加 AND
|
||||
in.Sql = sql[:insertPos] + " AND" + condition + sql[insertPos:]
|
||||
} else {
|
||||
// 无 WHERE → 加 WHERE
|
||||
in.Sql = sql[:insertPos] + " WHERE" + condition + sql[insertPos:]
|
||||
}
|
||||
in.Args = append(in.Args, tenantId)
|
||||
}
|
||||
// ==================================================================
|
||||
|
||||
cacheEnabled, err := gcache.Get(ctx, getTraceID(ctx, cacheKeyPrefix))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 未启用缓存,直接执行查询
|
||||
if !gconv.Bool(enabled) {
|
||||
if !gconv.Bool(cacheEnabled) {
|
||||
return in.Next(ctx)
|
||||
}
|
||||
|
||||
@@ -303,18 +341,13 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result
|
||||
}
|
||||
}
|
||||
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypt, err := gmd5.Encrypt(fmt.Sprintf("%s:%s", whereCondition, in.Args))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建缓存key:sql:tenantId:table:where条件:args
|
||||
cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(user.TenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt)
|
||||
cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(tenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt)
|
||||
|
||||
// 1. 先查缓存
|
||||
if data, ok := getFromCache(ctx, cacheKey); ok {
|
||||
@@ -343,7 +376,12 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result
|
||||
}
|
||||
|
||||
func getCacheKey(tenantId uint64, table string, isBlur bool) string {
|
||||
cacheKey := fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table)
|
||||
var cacheKey string
|
||||
if g.IsEmpty(tenantId) {
|
||||
cacheKey = fmt.Sprintf("sql:%s", table)
|
||||
} else {
|
||||
cacheKey = fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table)
|
||||
}
|
||||
if isBlur {
|
||||
cacheKey = fmt.Sprintf("%s:*", cacheKey)
|
||||
}
|
||||
@@ -369,7 +407,9 @@ func getSelectTypeString(selectType gdb.SelectType) string {
|
||||
// ==================== 调用方法 ====================
|
||||
|
||||
var (
|
||||
schemaPrefix = "tenant-"
|
||||
schemaPrefix = "tenant-"
|
||||
cacheKeyPrefix = "cache-"
|
||||
noTenantIdKeyPrefix = "tenantId-"
|
||||
)
|
||||
|
||||
type Gfdb interface {
|
||||
@@ -382,49 +422,39 @@ type cache interface {
|
||||
Cache(ctx context.Context) *gdb.Model
|
||||
}
|
||||
|
||||
type model struct {
|
||||
*gdb.Model
|
||||
type noTenantId interface {
|
||||
NoTenantId(ctx context.Context) *modelCache
|
||||
}
|
||||
|
||||
type dataBase struct {
|
||||
gdb.DB
|
||||
}
|
||||
|
||||
func GetTablePrefix(ctx context.Context) (prefix string, err error) {
|
||||
tenantId, config, err := checkSchemaConfig(ctx)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err)
|
||||
return
|
||||
}
|
||||
if config {
|
||||
sprintf := fmt.Sprintf("database.%s%v.0.prefix", schemaPrefix, tenantId)
|
||||
prefix = g.Cfg().MustGet(ctx, sprintf).String()
|
||||
return
|
||||
}
|
||||
prefix = g.Cfg().MustGet(ctx, "database.default.0.prefix").String()
|
||||
return
|
||||
type model struct {
|
||||
*gdb.Model
|
||||
}
|
||||
|
||||
func checkSchemaConfig(ctx context.Context) (uint64, bool, error) {
|
||||
type modelCache struct {
|
||||
*model
|
||||
}
|
||||
|
||||
func checkSchemaConfig(ctx context.Context) (uint64, bool) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err)
|
||||
return 0, false, err
|
||||
return 0, false
|
||||
}
|
||||
var schema = fmt.Sprintf("%s%v", schemaPrefix, user.TenantId)
|
||||
sprintf := fmt.Sprintf("database.%s", schema)
|
||||
if !g.Cfg().MustGet(ctx, sprintf).IsEmpty() {
|
||||
return user.TenantId, true, nil
|
||||
return user.TenantId, true
|
||||
}
|
||||
return user.TenantId, false, nil
|
||||
return user.TenantId, false
|
||||
}
|
||||
|
||||
func DB(ctx context.Context) Gfdb {
|
||||
tenantId, config, err := checkSchemaConfig(ctx)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err)
|
||||
return nil
|
||||
}
|
||||
tenantId, config := checkSchemaConfig(ctx)
|
||||
|
||||
var schema = fmt.Sprintf("%s%v", schemaPrefix, tenantId)
|
||||
|
||||
var dbName []string
|
||||
@@ -450,11 +480,8 @@ func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model {
|
||||
|
||||
m := d.DB.Model(tableNameOrStruct...).Ctx(ctx)
|
||||
|
||||
tenantId, config, err := checkSchemaConfig(ctx)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "[DB] checkSchemaConfig error: %v", err)
|
||||
return nil
|
||||
}
|
||||
tenantId, config := checkSchemaConfig(ctx)
|
||||
|
||||
if config {
|
||||
// 创建按地区分库的配置
|
||||
shardingConfig := gdb.ShardingConfig{
|
||||
@@ -478,7 +505,7 @@ func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context,
|
||||
}
|
||||
|
||||
func (d *model) Cache(ctx context.Context) *gdb.Model {
|
||||
traceID := getTraceID(ctx)
|
||||
traceID := getTraceID(ctx, cacheKeyPrefix)
|
||||
if traceID == "" {
|
||||
glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty")
|
||||
return nil
|
||||
@@ -490,11 +517,28 @@ func (d *model) Cache(ctx context.Context) *gdb.Model {
|
||||
return d.Model
|
||||
}
|
||||
|
||||
func (d *model) NoTenantId(ctx context.Context) *modelCache {
|
||||
traceID := getTraceID(ctx, noTenantIdKeyPrefix)
|
||||
if traceID == "" {
|
||||
glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty")
|
||||
return nil
|
||||
}
|
||||
if err := gcache.Set(ctx, traceID, true, time.Second); err != nil {
|
||||
glog.Errorf(ctx, "[DB] Cache error: %v", err)
|
||||
return nil
|
||||
}
|
||||
return &modelCache{
|
||||
&model{
|
||||
Model: d.Model,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// getTraceID 从 context 中获取链路追踪 ID
|
||||
func getTraceID(ctx context.Context) string {
|
||||
func getTraceID(ctx context.Context, prefix string) string {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span != nil && span.SpanContext().HasTraceID() {
|
||||
return span.SpanContext().TraceID().String()
|
||||
return fmt.Sprintf("%s%v", prefix, span.SpanContext().TraceID().String())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -521,3 +565,19 @@ func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingT
|
||||
// 这里不实现分表,返回空字符串
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func GetTablePrefix(ctx context.Context) (prefix string, err error) {
|
||||
tenantId, config := checkSchemaConfig(ctx)
|
||||
if config {
|
||||
sprintf := fmt.Sprintf("database.%s%v.0.prefix", schemaPrefix, tenantId)
|
||||
prefix = g.Cfg().MustGet(ctx, sprintf).String()
|
||||
return
|
||||
}
|
||||
defaultConfig := g.Cfg().MustGet(ctx, "database.default")
|
||||
if defaultConfig.IsSlice() {
|
||||
prefix = g.Cfg().MustGet(ctx, "database.default.0.prefix").String()
|
||||
} else {
|
||||
prefix = g.Cfg().MustGet(ctx, "database.default.prefix").String()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user