mongo.go重构

This commit is contained in:
2025-12-30 10:52:12 +08:00
parent 7d05387104
commit f06e050d78
2 changed files with 48 additions and 102 deletions

View File

@@ -11,7 +11,6 @@ import (
"github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/container/gvar"
"gitee.com/red-future---jilin-g/common/consts" "gitee.com/red-future---jilin-g/common/consts"
"gitee.com/red-future---jilin-g/common/do"
"gitee.com/red-future---jilin-g/common/redis" "gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils" "gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/errors/gerror"
@@ -99,6 +98,13 @@ func connect() error {
return nil return nil
} }
// GetDB 获取 MongoDB 数据库实例
func GetDB() *mongo.Database {
mu.RLock()
defer mu.RUnlock()
return db
}
// healthCheck 健康检查协程 // healthCheck 健康检查协程
func healthCheck() { func healthCheck() {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(30 * time.Second)
@@ -261,76 +267,12 @@ func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOne
return return
} }
// getTenantInfo 获取租户信息
// 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表
func getTenantInfo(ctx context.Context) (user do.User, err error) {
// 1. 优先从 token 获取
user, err = utils.GetUserInfo(ctx)
if err == nil {
return
}
// 2. token 获取失败尝试从请求参数或context获取 accountName
var accountName string
// 2.1 尝试从request获取HTTP请求场景
req := g.RequestFromCtx(ctx)
if req != nil {
accountName = req.Get("accountName").String()
if accountName == "" {
accountName = req.Get("account_name").String()
}
// 兼容旧参数名
if accountName == "" {
accountName = req.Get("customerServiceId").String()
}
if accountName == "" {
accountName = req.Get("customer_service_id").String()
}
}
// 2.2 request不存在或未获取到尝试从context.Value获取WebSocket场景
if accountName == "" {
if val := ctx.Value("accountName"); val != nil {
if str, ok := val.(string); ok {
accountName = str
}
}
// 兼容旧参数名
if accountName == "" {
if val := ctx.Value("customerServiceId"); val != nil {
if str, ok := val.(string); ok {
accountName = str
}
}
}
}
if accountName == "" {
return user, gerror.New("无法获取租户信息:无 token 且无 accountName 参数")
}
// 3. 直接查询 customer_service_account 表获取 tenantId
filter := bson.M{"accountName": accountName, "isDeleted": false}
var account struct {
TenantId interface{} `bson:"tenantId"`
}
if findErr := db.Collection("customer_service_account").FindOne(ctx, filter).Decode(&account); findErr != nil {
return user, gerror.Newf("通过 accountName 查询租户失败: %v", findErr)
}
user.TenantId = account.TenantId
user.UserName = accountName
err = nil // 清空之前从token获取时的错误
return
}
// Find 查询多条记录 // Find 查询多条记录
func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) {
if err = utils.ValidStructPtr(result); err != nil { if err = utils.ValidStructPtr(result); err != nil {
return return
} }
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -379,7 +321,7 @@ func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}
if err = utils.ValidStructPtr(result); err != nil { if err = utils.ValidStructPtr(result); err != nil {
return return
} }
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -456,7 +398,7 @@ func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string,
err = gerror.New("缺少查询条件") err = gerror.New("缺少查询条件")
return return
} }
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -477,7 +419,7 @@ func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, coll
return return
} }
filter["isDeleted"] = false filter["isDeleted"] = false
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -557,7 +499,7 @@ func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bs
err = gerror.New("查询条件和更新数据的数量必须一致") err = gerror.New("查询条件和更新数据的数量必须一致")
return return
} }
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -622,7 +564,7 @@ func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bs
// Insert 插入多条记录 // Insert 插入多条记录
func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -655,7 +597,7 @@ func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collectio
// Count 查询总数 // Count 查询总数
func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) { func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) {
user, err := getTenantInfo(ctx) user, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return return
} }

View File

@@ -60,6 +60,8 @@ func GetMonthToday(t time.Time, month int) time.Time {
return target.AddDate(0, 0, t.Day()-1) return target.AddDate(0, 0, t.Day()-1)
} }
func GetUserInfo(ctx context.Context) (user do.User, err error) { func GetUserInfo(ctx context.Context) (user do.User, err error) {
r := g.RequestFromCtx(ctx)
if r != nil {
redisAddr := g.Cfg().MustGet(ctx, "redis.default.address").String() redisAddr := g.Cfg().MustGet(ctx, "redis.default.address").String()
gft := gftoken.NewGfToken( gft := gftoken.NewGfToken(
gftoken.WithCacheKey("gfToken:"), gftoken.WithCacheKey("gfToken:"),
@@ -71,10 +73,6 @@ func GetUserInfo(ctx context.Context) (user do.User, err error) {
Address: redisAddr, Address: redisAddr,
Db: 1, Db: 1,
})) }))
r := g.RequestFromCtx(ctx)
if r == nil {
return
}
// 解析 token // 解析 token
data, err := gft.ParseToken(g.RequestFromCtx(ctx)) data, err := gft.ParseToken(g.RequestFromCtx(ctx))
if err != nil { if err != nil {
@@ -94,6 +92,12 @@ func GetUserInfo(ctx context.Context) (user do.User, err error) {
dataMap := gconv.Map(data.Data) dataMap := gconv.Map(data.Data)
user.UserName = dataMap["userName"] user.UserName = dataMap["userName"]
user.TenantId = dataMap["tenantId"] user.TenantId = dataMap["tenantId"]
} else {
user.TenantId = ctx.Value("tenantId")
}
if user.TenantId == nil {
return user, gerror.New("租户信息为空")
}
return return
} }
func OrderMap(m map[string]interface{}) map[string]interface{} { func OrderMap(m map[string]interface{}) map[string]interface{} {