mongo.go重构

This commit is contained in:
2025-12-30 09:29:36 +08:00
parent 836460306e
commit 55e8a829de

View File

@@ -4,11 +4,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/gogf/gf/v2/container/gvar"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gogf/gf/v2/container/gvar"
"gitee.com/red-future---jilin-g/common/consts" "gitee.com/red-future---jilin-g/common/consts"
"gitee.com/red-future---jilin-g/common/do" "gitee.com/red-future---jilin-g/common/do"
"gitee.com/red-future---jilin-g/common/redis" "gitee.com/red-future---jilin-g/common/redis"
@@ -24,6 +25,16 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/options"
) )
type MongoDB struct {
Cache bool
}
func DB(cache bool) *MongoDB {
return &MongoDB{
Cache: cache,
}
}
var ( var (
db *mongo.Database db *mongo.Database
client *mongo.Client client *mongo.Client
@@ -35,15 +46,8 @@ var (
healthCancel context.CancelFunc healthCancel context.CancelFunc
) )
// GetDB 获取 MongoDB 数据库实例 // checkConnected 检查连接状态
func GetDB() *mongo.Database { func checkConnected() bool {
mu.RLock()
defer mu.RUnlock()
return db
}
// IsConnected 检查连接状态
func IsConnected() bool {
mu.RLock() mu.RLock()
defer mu.RUnlock() defer mu.RUnlock()
return isConnected return isConnected
@@ -187,8 +191,8 @@ func init() {
go healthCheck() go healthCheck()
} }
// Close 关闭MongoDB连接 // close 关闭MongoDB连接
func Close() { func close() {
if healthCancel != nil { if healthCancel != nil {
healthCancel() healthCancel()
} }
@@ -253,9 +257,9 @@ func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOne
return return
} }
// GetTenantInfo 获取租户信息 // getTenantInfo 获取租户信息
// 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表 // 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表
func GetTenantInfo(ctx context.Context) (user do.User, err error) { func getTenantInfo(ctx context.Context) (user do.User, err error) {
// 1. 优先从 token 获取 // 1. 优先从 token 获取
user, err = utils.GetUserInfo(ctx) user, err = utils.GetUserInfo(ctx)
if err == nil { if err == nil {
@@ -318,20 +322,19 @@ func GetTenantInfo(ctx context.Context) (user do.User, err error) {
} }
// Find 查询多条记录 // Find 查询多条记录
func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) {
if err = utils.ValidStructPtr(result); err != nil { if err = utils.ValidStructPtr(result); err != nil {
return return
} }
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
filter["isDeleted"] = false filter["isDeleted"] = false
filterMap := utils.OrderMap(filter) filterMap := utils.OrderMap(filter)
optsMap := listOptionsToMap(ctx, opts...) optsMap := listOptionsToMap(ctx, opts...)
redisKey := "" redisKey := fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap))
if !NoCache { if m.Cache {
redisKey = fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap))
var resultStr *gvar.Var var resultStr *gvar.Var
resultStr, err = redis.RedisClient.Get(ctx, redisKey) resultStr, err = redis.RedisClient.Get(ctx, redisKey)
if err != nil { if err != nil {
@@ -350,8 +353,11 @@ func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{},
if err != nil { if err != nil {
return return
} }
err = cur.All(ctx, result) defer cur.Close(ctx)
if !NoCache { if err = cur.All(ctx, result); err != nil {
return
}
if m.Cache {
err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour))
if err != nil { if err != nil {
return err return err
@@ -361,7 +367,7 @@ func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{},
} }
// FindOne 查询1条记录 // FindOne 查询1条记录
func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) { func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) {
if len(filter) == 0 { if len(filter) == 0 {
err = gerror.New("缺少查询条件") err = gerror.New("缺少查询条件")
return return
@@ -369,15 +375,14 @@ func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{
if err = utils.ValidStructPtr(result); err != nil { if err = utils.ValidStructPtr(result); err != nil {
return return
} }
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
filter["isDeleted"] = false filter["isDeleted"] = false
filterMap := utils.OrderMap(filter) filterMap := utils.OrderMap(filter)
redisKey := "" redisKey := fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap))
if !NoCache { if m.Cache {
redisKey := fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap))
var resultStr *gvar.Var var resultStr *gvar.Var
resultStr, err = redis.RedisClient.Get(ctx, redisKey) resultStr, err = redis.RedisClient.Get(ctx, redisKey)
if err != nil { if err != nil {
@@ -399,7 +404,7 @@ func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{
if errors.Is(err, mongo.ErrNoDocuments) { if errors.Is(err, mongo.ErrNoDocuments) {
err = nil err = nil
} }
if !NoCache { if m.Cache {
err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour))
if err != nil { if err != nil {
return err return err
@@ -442,12 +447,12 @@ func cleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collec
} }
// Delete 删除记录 // Delete 删除记录
func Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) { func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) {
if len(filter) == 0 { if len(filter) == 0 {
err = gerror.New("缺少查询条件") err = gerror.New("缺少查询条件")
return return
} }
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -462,13 +467,13 @@ func Delete(ctx context.Context, filter bson.M, collection string, opts ...optio
} }
// Update 修改记录 // Update 修改记录
func Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
if len(filter) == 0 { if len(filter) == 0 {
err = gerror.New("缺少查询条件") err = gerror.New("缺少查询条件")
return return
} }
filter["isDeleted"] = false filter["isDeleted"] = false
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -490,7 +495,7 @@ func Update(ctx context.Context, filter bson.M, update bson.M, collection string
} }
// RandomSoftDelete 随机软删除个文档的 _id // RandomSoftDelete 随机软删除个文档的 _id
func RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
// 步骤 1: 使用聚合管道的 $sample 操作符随机抽取5个文档的 _id // 步骤 1: 使用聚合管道的 $sample 操作符随机抽取5个文档的 _id
pipeline := mongo.Pipeline{ pipeline := mongo.Pipeline{
// 阶段1: 为每个文档添加一个 0-1 之间的随机数字段 'random' // 阶段1: 为每个文档添加一个 0-1 之间的随机数字段 'random'
@@ -539,7 +544,7 @@ func RandomSoftDelete(ctx context.Context, limit int, collection string, opts ..
} }
// SaveOrUpdate 批量增加或修改 // SaveOrUpdate 批量增加或修改
func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) { func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) {
if len(filter) == 0 || len(update) == 0 { if len(filter) == 0 || len(update) == 0 {
err = gerror.New("缺少查询条件或更新数据") err = gerror.New("缺少查询条件或更新数据")
return return
@@ -548,7 +553,7 @@ func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collect
err = gerror.New("查询条件和更新数据的数量必须一致") err = gerror.New("查询条件和更新数据的数量必须一致")
return return
} }
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -612,8 +617,8 @@ func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collect
} }
// Insert 插入多条记录 // Insert 插入多条记录
func Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
@@ -645,16 +650,15 @@ func Insert(ctx context.Context, documents []interface{}, collection string, opt
} }
// Count 查询总数 // Count 查询总数
func Count(ctx context.Context, NoCache bool, filter bson.M, collection string) (count int64, err error) { func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) {
user, err := GetTenantInfo(ctx) user, err := getTenantInfo(ctx)
if err != nil { if err != nil {
return return
} }
filter["isDeleted"] = false filter["isDeleted"] = false
filterMap := utils.OrderMap(filter) filterMap := utils.OrderMap(filter)
redisKey := "" redisKey := fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap))
if !NoCache { if m.Cache {
redisKey = fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap))
var resultStr *gvar.Var var resultStr *gvar.Var
resultStr, err = redis.RedisClient.Get(ctx, redisKey) resultStr, err = redis.RedisClient.Get(ctx, redisKey)
if err != nil { if err != nil {
@@ -667,7 +671,7 @@ func Count(ctx context.Context, NoCache bool, filter bson.M, collection string)
} }
// 调用驱动的 CountDocuments在数据库端执行的 // 调用驱动的 CountDocuments在数据库端执行的
count, err = db.Collection(collection).CountDocuments(ctx, filter) count, err = db.Collection(collection).CountDocuments(ctx, filter)
if !NoCache { if m.Cache {
err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour)) err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour))
if err != nil { if err != nil {
return return