diff --git a/mongo/mongo.go b/mongo/mongo.go index 711b0c0..4c3855f 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/gogf/gf/v2/container/gvar" "strings" "sync" "time" + "github.com/gogf/gf/v2/container/gvar" + "gitee.com/red-future---jilin-g/common/consts" "gitee.com/red-future---jilin-g/common/do" "gitee.com/red-future---jilin-g/common/redis" @@ -24,6 +25,16 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" ) +type MongoDB struct { + Cache bool +} + +func DB(cache bool) *MongoDB { + return &MongoDB{ + Cache: cache, + } +} + var ( db *mongo.Database client *mongo.Client @@ -35,15 +46,8 @@ var ( healthCancel context.CancelFunc ) -// GetDB 获取 MongoDB 数据库实例 -func GetDB() *mongo.Database { - mu.RLock() - defer mu.RUnlock() - return db -} - -// IsConnected 检查连接状态 -func IsConnected() bool { +// checkConnected 检查连接状态 +func checkConnected() bool { mu.RLock() defer mu.RUnlock() return isConnected @@ -187,8 +191,8 @@ func init() { go healthCheck() } -// Close 关闭MongoDB连接 -func Close() { +// close 关闭MongoDB连接 +func close() { if healthCancel != nil { healthCancel() } @@ -253,9 +257,9 @@ func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOne return } -// GetTenantInfo 获取租户信息 +// getTenantInfo 获取租户信息 // 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表 -func GetTenantInfo(ctx context.Context) (user do.User, err error) { +func getTenantInfo(ctx context.Context) (user do.User, err error) { // 1. 优先从 token 获取 user, err = utils.GetUserInfo(ctx) if err == nil { @@ -318,20 +322,19 @@ func GetTenantInfo(ctx context.Context) (user do.User, err error) { } // Find 查询多条记录 -func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { +func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { if err = utils.ValidStructPtr(result); err != nil { return } - user, err := GetTenantInfo(ctx) + user, err := getTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) optsMap := listOptionsToMap(ctx, opts...) - redisKey := "" - if !NoCache { - redisKey = fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap)) + redisKey := fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap)) + if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { @@ -350,8 +353,11 @@ func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{}, if err != nil { return } - err = cur.All(ctx, result) - if !NoCache { + defer cur.Close(ctx) + if err = cur.All(ctx, result); err != nil { + return + } + if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err @@ -361,7 +367,7 @@ func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{}, } // FindOne 查询1条记录 -func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) { +func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return @@ -369,15 +375,14 @@ func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{ if err = utils.ValidStructPtr(result); err != nil { return } - user, err := GetTenantInfo(ctx) + user, err := getTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) - redisKey := "" - if !NoCache { - redisKey := fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap)) + redisKey := fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap)) + if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { @@ -399,7 +404,7 @@ func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{ if errors.Is(err, mongo.ErrNoDocuments) { err = nil } - if !NoCache { + if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err @@ -442,12 +447,12 @@ func cleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collec } // Delete 删除记录 -func Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) { +func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } - user, err := GetTenantInfo(ctx) + user, err := getTenantInfo(ctx) if err != nil { return } @@ -462,13 +467,13 @@ func Delete(ctx context.Context, filter bson.M, collection string, opts ...optio } // Update 修改记录 -func Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { +func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } filter["isDeleted"] = false - user, err := GetTenantInfo(ctx) + user, err := getTenantInfo(ctx) if err != nil { return } @@ -490,7 +495,7 @@ func Update(ctx context.Context, filter bson.M, update bson.M, collection string } // RandomSoftDelete 随机软删除个文档的 _id -func RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { +func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { // 步骤 1: 使用聚合管道的 $sample 操作符随机抽取5个文档的 _id pipeline := mongo.Pipeline{ // 阶段1: 为每个文档添加一个 0-1 之间的随机数字段 'random' @@ -539,7 +544,7 @@ func RandomSoftDelete(ctx context.Context, limit int, collection string, opts .. } // SaveOrUpdate 批量增加或修改 -func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) { +func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) { if len(filter) == 0 || len(update) == 0 { err = gerror.New("缺少查询条件或更新数据") return @@ -548,7 +553,7 @@ func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collect err = gerror.New("查询条件和更新数据的数量必须一致") return } - user, err := GetTenantInfo(ctx) + user, err := getTenantInfo(ctx) if err != nil { return } @@ -612,8 +617,8 @@ func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collect } // Insert 插入多条记录 -func Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { - user, err := GetTenantInfo(ctx) +func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { + user, err := getTenantInfo(ctx) if err != nil { return } @@ -645,16 +650,15 @@ func Insert(ctx context.Context, documents []interface{}, collection string, opt } // Count 查询总数 -func Count(ctx context.Context, NoCache bool, filter bson.M, collection string) (count int64, err error) { - user, err := GetTenantInfo(ctx) +func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) { + user, err := getTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) - redisKey := "" - if !NoCache { - redisKey = fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap)) + redisKey := fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap)) + if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { @@ -667,7 +671,7 @@ func Count(ctx context.Context, NoCache bool, filter bson.M, collection string) } // 调用驱动的 CountDocuments,在数据库端执行的 count, err = db.Collection(collection).CountDocuments(ctx, filter) - if !NoCache { + if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour)) if err != nil { return