@@ -9,6 +9,7 @@ import (
"strings"
"time"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/bwmarrin/snowflake"
"github.com/gogf/gf/v2/crypto/gmd5"
@@ -280,11 +281,48 @@ func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result
// ==================== Select钩子( 缓存读取) ====================
func selectHook ( ctx context . Context , in * gdb . HookSelectInput ) ( result gdb . Result , err error ) {
traceID := getTraceID ( ctx )
var tenantId uint64
// ===================== 最终版: 安全追加租户ID =====================
tenantEnabled , err := gcache . Get ( ctx , getTraceID ( ctx , noTenantIdKeyPrefix ) )
if err != nil {
return
}
if ! gconv . Bool ( tenantEnabled ) {
user , err := utils . GetUserInfo ( ctx )
if err != nil {
return nil , err
}
tenantId = user . TenantId
// 【关键修复】找到 SQL 中第一个出现的 ORDER BY / GROUP BY / LIMIT 等关键字位置
sql := in . Sql
insertPos := len ( sql )
keywords := [ ] string { " ORDER BY " , " GROUP BY " , " HAVING " , " LIMIT " , " FOR UPDATE" }
for _ , kw := range keywords {
if idx := gstr . PosI ( sql , kw ) ; idx != - 1 {
insertPos = idx
break
}
}
enabled , err := gcache . Get ( ctx , traceID )
// 【正确拼接】把条件插入到关键字之前,而不是直接拼在最后
condition := " " + beans . DefSQLBaseCol . TenantId + " = ?"
if gstr . Contains ( gstr . ToUpper ( sql ) , " WHERE " ) {
// 有 WHERE → 加 AND
in . Sql = sql [ : insertPos ] + " AND" + condition + sql [ insertPos : ]
} else {
// 无 WHERE → 加 WHERE
in . Sql = sql [ : insertPos ] + " WHERE" + condition + sql [ insertPos : ]
}
in . Args = append ( in . Args , tenantId )
}
// ==================================================================
cacheEnabled , err := gcache . Get ( ctx , getTraceID ( ctx , cacheKeyPrefix ) )
if err != nil {
return
}
// 未启用缓存,直接执行查询
if ! gconv . Bool ( e nabled) {
if ! gconv . Bool ( cacheE nabled) {
return in . Next ( ctx )
}
@@ -303,18 +341,13 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result
}
}
user , err := utils . GetUserInfo ( ctx )
if err != nil {
return nil , err
}
encrypt , err := gmd5 . Encrypt ( fmt . Sprintf ( "%s:%s" , whereCondition , in . Args ) )
if err != nil {
return nil , err
}
// 构建缓存key: sql:tenantId:table:where条件:args
cacheKey := fmt . Sprintf ( "%s:%s:%s" , getCacheKey ( user . T enantId, in . Table , false ) , getSelectTypeString ( in . SelectType ) , encrypt )
cacheKey := fmt . Sprintf ( "%s:%s:%s" , getCacheKey ( t enantId, in . Table , false ) , getSelectTypeString ( in . SelectType ) , encrypt )
// 1. 先查缓存
if data , ok := getFromCache ( ctx , cacheKey ) ; ok {
@@ -343,7 +376,12 @@ func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result
}
func getCacheKey ( tenantId uint64 , table string , isBlur bool ) string {
cacheKey := fmt . Sprintf ( "sql:tenantId-%v:%s" , tenantId , table )
var cacheKey string
if g . IsEmpty ( tenantId ) {
cacheKey = fmt . Sprintf ( "sql:%s" , table )
} else {
cacheKey = fmt . Sprintf ( "sql:tenantId-%v:%s" , tenantId , table )
}
if isBlur {
cacheKey = fmt . Sprintf ( "%s:*" , cacheKey )
}
@@ -369,7 +407,9 @@ func getSelectTypeString(selectType gdb.SelectType) string {
// ==================== 调用方法 ====================
var (
schemaPrefix = "tenant-"
schemaPrefix = "tenant-"
cacheKeyPrefix = "cache-"
noTenantIdKeyPrefix = "tenantId-"
)
type Gfdb interface {
@@ -382,49 +422,39 @@ type cache interface {
Cache ( ctx context . Context ) * gdb . Model
}
type model struct {
* gdb . Model
type noTenantId interface {
NoTenantId ( ctx context . Context ) * modelCache
}
type dataBase struct {
gdb . DB
}
func GetTablePrefix ( ctx context . Context ) ( prefix string , err error ) {
tenantId , config , err := checkSchemaConfig ( ctx )
if err != nil {
glog . Errorf ( ctx , "[DB] checkSchemaConfig error: %v" , err )
return
}
if config {
sprintf := fmt . Sprintf ( "database.%s%v.0.prefix" , schemaPrefix , tenantId )
prefix = g . Cfg ( ) . MustGet ( ctx , sprintf ) . String ( )
return
}
prefix = g . Cfg ( ) . MustGet ( ctx , "database.default.0.prefix" ) . String ( )
return
type model struct {
* gdb . Model
}
func checkSchemaConfig ( ctx context . Context ) ( uint64 , bool , error ) {
type modelCache struct {
* model
}
func checkSchemaConfig ( ctx context . Context ) ( uint64 , bool ) {
user , err := utils . GetUserInfo ( ctx )
if err != nil {
glog . Errorf ( ctx , "[DB] GetUserInfo error: %v" , err )
return 0 , false , err
return 0 , false
}
var schema = fmt . Sprintf ( "%s%v" , schemaPrefix , user . TenantId )
sprintf := fmt . Sprintf ( "database.%s" , schema )
if ! g . Cfg ( ) . MustGet ( ctx , sprintf ) . IsEmpty ( ) {
return user . TenantId , true , nil
return user . TenantId , true
}
return user . TenantId , false , nil
return user . TenantId , false
}
func DB ( ctx context . Context ) Gfdb {
tenantId , config , err := checkSchemaConfig ( ctx )
if err != nil {
glog . Errorf ( ctx , "[DB] checkSchemaConfig error: %v" , err )
return nil
}
tenantId , config := checkSchemaConfig ( ctx )
var schema = fmt . Sprintf ( "%s%v" , schemaPrefix , tenantId )
var dbName [ ] string
@@ -450,11 +480,8 @@ func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model {
m := d . DB . Model ( tableNameOrStruct ... ) . Ctx ( ctx )
tenantId , config , err := checkSchemaConfig ( ctx )
if err != nil {
glog . Errorf ( ctx , "[DB] checkSchemaConfig error: %v" , err )
return nil
}
tenantId , config := checkSchemaConfig ( ctx )
if config {
// 创建按地区分库的配置
shardingConfig := gdb . ShardingConfig {
@@ -478,7 +505,7 @@ func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context,
}
func ( d * model ) Cache ( ctx context . Context ) * gdb . Model {
traceID := getTraceID ( ctx )
traceID := getTraceID ( ctx , cacheKeyPrefix )
if traceID == "" {
glog . Errorf ( ctx , "[DB] GetTraceID error: traceID is empty" )
return nil
@@ -490,11 +517,28 @@ func (d *model) Cache(ctx context.Context) *gdb.Model {
return d . Model
}
func ( d * model ) NoTenantId ( ctx context . Context ) * modelCache {
traceID := getTraceID ( ctx , noTenantIdKeyPrefix )
if traceID == "" {
glog . Errorf ( ctx , "[DB] GetTraceID error: traceID is empty" )
return nil
}
if err := gcache . Set ( ctx , traceID , true , time . Second ) ; err != nil {
glog . Errorf ( ctx , "[DB] Cache error: %v" , err )
return nil
}
return & modelCache {
& model {
Model : d . Model ,
} ,
}
}
// getTraceID 从 context 中获取链路追踪 ID
func getTraceID ( ctx context . Context ) string {
func getTraceID ( ctx context . Context , prefix string ) string {
span := trace . SpanFromContext ( ctx )
if span != nil && span . SpanContext ( ) . HasTraceID ( ) {
return span . SpanContext ( ) . TraceID ( ) . String ( )
return fmt . Sprintf ( "%s%v" , prefix , span. SpanContext ( ) . TraceID ( ) . String ( ) )
}
return ""
}
@@ -521,3 +565,19 @@ func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingT
// 这里不实现分表,返回空字符串
return "" , nil
}
func GetTablePrefix ( ctx context . Context ) ( prefix string , err error ) {
tenantId , config := checkSchemaConfig ( ctx )
if config {
sprintf := fmt . Sprintf ( "database.%s%v.0.prefix" , schemaPrefix , tenantId )
prefix = g . Cfg ( ) . MustGet ( ctx , sprintf ) . String ( )
return
}
defaultConfig := g . Cfg ( ) . MustGet ( ctx , "database.default" )
if defaultConfig . IsSlice ( ) {
prefix = g . Cfg ( ) . MustGet ( ctx , "database.default.0.prefix" ) . String ( )
} else {
prefix = g . Cfg ( ) . MustGet ( ctx , "database.default.prefix" ) . String ( )
}
return
}