package mongo import ( "context" "errors" "fmt" "strings" "sync" "time" "gitee.com/red-future---jilin-g/common/log/model/dto" "github.com/gogf/gf/v2/container/gvar" "gitee.com/red-future---jilin-g/common/redis" "gitee.com/red-future---jilin-g/common/utils" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) type MongoDB struct { Cache bool } func DB(cache ...bool) *MongoDB { b := true if len(cache) > 0 { b = cache[0] } return &MongoDB{ Cache: b, } } var ( db *mongo.Database client *mongo.Client isConnected bool mu sync.RWMutex mongoAddr string dbName string healthCtx context.Context healthCancel context.CancelFunc ) // checkConnected 检查连接状态 func checkConnected() bool { mu.RLock() defer mu.RUnlock() return isConnected } // connect 建立MongoDB连接 func connect() error { mu.Lock() defer mu.Unlock() if client != nil { client.Disconnect(context.Background()) } // 创建连接选项 opt := options.Client(). ApplyURI(mongoAddr). SetMaxPoolSize(100). SetMinPoolSize(10). SetMaxConnecting(10). SetConnectTimeout(10 * time.Second) var err error client, err = mongo.Connect(opt) if err != nil { isConnected = false glog.Error(context.Background(), "MongoDB连接失败", err) return err } // 测试连接 testCtx, testCancel := context.WithTimeout(context.Background(), 5*time.Second) defer testCancel() err = client.Ping(testCtx, nil) if err != nil { isConnected = false glog.Error(testCtx, "MongoDB连接测试失败", err) return err } db = client.Database(dbName) isConnected = true glog.Info(context.Background(), "✅ MongoDB连接成功") return nil } // GetDB 获取 MongoDB 数据库实例 func GetDB() *mongo.Database { mu.RLock() defer mu.RUnlock() return db } // healthCheck 健康检查协程 func healthCheck() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-healthCtx.Done(): return case <-ticker.C: mu.RLock() currentConnected := isConnected currentClient := client mu.RUnlock() if !currentConnected || currentClient == nil { glog.Warning(context.Background(), "MongoDB连接断开,尝试重连") if err := reconnect(); err != nil { glog.Error(context.Background(), "MongoDB重连失败", err) } continue } // 测试连接状态 testCtx, testCancel := context.WithTimeout(context.Background(), 5*time.Second) err := currentClient.Ping(testCtx, nil) testCancel() if err != nil { mu.Lock() isConnected = false mu.Unlock() glog.Warning(context.Background(), "MongoDB连接健康检查失败", err) // 尝试重连 if err := reconnect(); err != nil { glog.Error(context.Background(), "MongoDB重连失败", err) } } else { glog.Debug(context.Background(), "MongoDB连接健康检查通过") } } } } // reconnect 重连函数 func reconnect() error { maxRetries := 3 retryDelay := 2 * time.Second for i := 0; i < maxRetries; i++ { glog.Info(context.Background(), fmt.Sprintf("尝试第%d次重连MongoDB", i+1)) if err := connect(); err == nil { glog.Info(context.Background(), "MongoDB重连成功") return nil } if i < maxRetries-1 { time.Sleep(retryDelay) retryDelay *= 2 // 指数退避 } } return gerror.New("MongoDB重连失败,已达到最大重试次数") } // init 初始化MongoDB连接 func init() { // 按需初始化:没有配置 mongo.address 则跳过 mongoAddr = g.Cfg().MustGet(context.Background(), "mongo.address").String() if mongoAddr == "" { return } // 创建健康检查上下文 healthCtx, healthCancel = context.WithCancel(context.Background()) // 从连接串中解析数据库名 dbName = gstr.SubStr(mongoAddr, strings.LastIndex(mongoAddr, "/")+1, len(mongoAddr)) // 如果连接串带有参数(如 ?retryWrites=true),需要去掉参数部分 if strings.Contains(dbName, "?") { dbName = gstr.SubStr(dbName, 0, strings.Index(dbName, "?")) } go func() { // 初始连接 if err := connect(); err != nil { glog.Error(context.Background(), "MongoDB初始连接失败", err) return } }() // 启动健康检查协程 go healthCheck() } // close 关闭MongoDB连接 func close() { if healthCancel != nil { healthCancel() } mu.Lock() defer mu.Unlock() if client != nil { disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), 5*time.Second) defer disconnectCancel() client.Disconnect(disconnectCtx) } isConnected = false glog.Info(context.Background(), "MongoDB连接已关闭") } func listOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOptions]) (m map[string]interface{}) { // 输出opts参数中的值 m = make(map[string]interface{}) for _, opt := range opts { var findOpts options.FindOptions optFuncs := opt.List() for _, fn := range optFuncs { fn(&findOpts) } if findOpts.Limit != nil { m["limit"] = *findOpts.Limit } if findOpts.Skip != nil { m["skip"] = *findOpts.Skip } if findOpts.Sort != nil { m["sort"] = findOpts.Sort } if findOpts.Projection != nil { m["projection"] = findOpts.Projection } } m = utils.OrderMap(m) return } func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOneOptions]) (m map[string]interface{}) { // 输出opts参数中的值 m = make(map[string]interface{}) for _, opt := range opts { var findOpts options.FindOneOptions optFuncs := opt.List() for _, fn := range optFuncs { fn(&findOpts) } if findOpts.Skip != nil { m["skip"] = *findOpts.Skip } if findOpts.Sort != nil { m["sort"] = findOpts.Sort } if findOpts.Projection != nil { m["projection"] = findOpts.Projection } } m = utils.OrderMap(m) return } // Find 查询多条记录 func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { if err = utils.ValidStructPtr(result); err != nil { return } user, err := utils.GetUserInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) optsMap := listOptionsToMap(ctx, opts...) redisKey := fmt.Sprintf(redis.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap)) if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { err = gconv.Scan(resultStr, result) if err != nil { return err } return } } filter["tenantId"] = user.TenantId cur, err := db.Collection(collection).Find(ctx, filter, opts...) if err != nil { return } defer cur.Close(ctx) if err = cur.All(ctx, result); err != nil { return } if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err } } return } // FindOne 查询1条记录 func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } if err = utils.ValidStructPtr(result); err != nil { return } user, err := utils.GetUserInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) redisKey := fmt.Sprintf(redis.One, user.TenantId, collection, gconv.String(filterMap)) if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { err = gconv.Scan(resultStr, result) if err != nil { return err } return } } if !g.IsEmpty(user.TenantId) { filter["tenantId"] = user.TenantId } cur := db.Collection(collection).FindOne(ctx, filter, opts...) err = cur.Decode(result) if errors.Is(err, mongo.ErrNoDocuments) { err = nil } if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err } } return } func (m *MongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) { listKeys := fmt.Sprintf(redis.CleanList, tenantId, collection) keys, err := redis.RedisClient.Keys(ctx, listKeys) if err != nil { return } for _, key := range keys { _, err = redis.RedisClient.Del(ctx, key) if err != nil { return } } countKeys := fmt.Sprintf(redis.CleanCount, tenantId, collection) keys, err = redis.RedisClient.Keys(ctx, countKeys) if err != nil { return } for _, key := range keys { _, err = redis.RedisClient.Del(ctx, key) if err != nil { return } } filter["isDeleted"] = false delete(filter, "tenantId") filterMap := utils.OrderMap(filter) oneKey := fmt.Sprintf(redis.One, tenantId, collection, gconv.String(filterMap)) _, err = redis.RedisClient.Del(ctx, oneKey) if err != nil { return } return } // Delete 删除记录 func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } user, err := utils.GetUserInfo(ctx) if err != nil { return } filter["tenantId"] = user.TenantId var rows []interface{} if err = m.Find(ctx, filter, &rows, collection); err != nil { return } r, err := db.Collection(collection).DeleteMany(ctx, filter, opts...) if err != nil { return } count = r.DeletedCount err = m.CleanRedis(ctx, filter, user.TenantId, collection) serverName := g.Cfg().MustGet(ctx, "server.name").String() logRedisKey := fmt.Sprintf("log:%s", serverName) if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{ ServiceName: serverName, Collection: collection, Data: rows, }); err != nil { glog.Error(ctx, "mongoLog-AddToStream err: %v", err) } return } // Update 修改记录 func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } filter["isDeleted"] = false user, err := utils.GetUserInfo(ctx) if err != nil { return } if !g.IsEmpty(user.TenantId) { filter["tenantId"] = user.TenantId } setDoc := update["$set"].(bson.M) if !g.IsEmpty(user.UserName) { setDoc["updater"] = user.UserName } setDoc["updatedAt"] = gtime.Now().Time update = bson.M{"$set": setDoc} var rows []interface{} if err = m.Find(ctx, filter, &rows, collection); err != nil { return } result, err = db.Collection(collection).UpdateMany(ctx, filter, update, opts...) if err != nil { return } err = m.CleanRedis(ctx, filter, user.TenantId, collection) serverName := g.Cfg().MustGet(ctx, "server.name").String() logRedisKey := fmt.Sprintf("log:%s", serverName) if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{ ServiceName: serverName, Collection: collection, Data: rows, }); err != nil { glog.Error(ctx, "mongoLog-AddToStream err: %v", err) } return } // RandomSoftDelete 随机软删除个文档的 _id func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { // 步骤 1: 使用聚合管道的 $sample 操作符随机抽取5个文档的 _id pipeline := mongo.Pipeline{ // 阶段1: 为每个文档添加一个 0-1 之间的随机数字段 'random' bson.D{{Key: "$addFields", Value: bson.D{{Key: "random", Value: bson.M{"$rand": bson.M{}}}}}}, // 阶段1: 匹配所有未删除的文档 bson.D{{Key: "$match", Value: bson.D{{Key: "isDeleted", Value: false}}}}, // 阶段2: 按随机数降序排序 bson.D{{Key: "$sort", Value: bson.D{{Key: "random", Value: -1}}}}, // 阶段3: 只取前5个 bson.D{{Key: "$limit", Value: limit}}, // 阶段4: 只投影 _id bson.D{{Key: "$project", Value: bson.D{{Key: "_id", Value: 1}}}}, } cursor, err := db.Collection(collection).Aggregate(ctx, pipeline) if err != nil { return } defer cursor.Close(ctx) // 步骤 2: 从聚合结果中提取 _id 到一个切片中 var idsToUpdate []bson.ObjectID for cursor.Next(ctx) { var result bson.M if err := cursor.Decode(&result); err != nil { return nil, err } // 将 bson.M 中的 _id 断言为 primitive.ObjectID id := result["_id"].(bson.ObjectID) idsToUpdate = append(idsToUpdate, id) } if err := cursor.Err(); err != nil { return nil, err } fmt.Printf("准备更新的随机文档ID: %v\n", idsToUpdate) // 步骤 3: 使用 $in 操作符和 UpdateMany 批量更新选定的文档 if len(idsToUpdate) > 0 { // 过滤条件:匹配 idsToUpdate 切片中的任意一个 _id filter := bson.D{{Key: "_id", Value: bson.D{{Key: "$in", Value: idsToUpdate}}}} // 更新操作:使用 $set 修改字段 update := bson.D{{Key: "$set", Value: bson.D{{Key: "isDeleted", Value: true}}}} _, err = db.Collection(collection).UpdateMany(ctx, filter, update) if err != nil { return } } return } // SaveOrUpdate 批量增加或修改 func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) { if len(filter) == 0 || len(update) == 0 { err = gerror.New("缺少查询条件或更新数据") return } if len(filter) != len(update) { err = gerror.New("查询条件和更新数据的数量必须一致") return } user, err := utils.GetUserInfo(ctx) if err != nil { return } // 构建批量操作模型 var models []mongo.WriteModel for i := 0; i < len(filter); i++ { // 处理过滤器 filter[i]["isDeleted"] = false if !g.IsEmpty(user.TenantId) { filter[i]["tenantId"] = user.TenantId } // 处理更新数据 if setDoc, exists := update[i]["$set"].(bson.M); exists { if !g.IsEmpty(user.UserName) { setDoc["updater"] = user.UserName } setDoc["updatedAt"] = gtime.Now().Time } else { // 如果没有$set字段,则创建一个 setDoc := bson.M{} if !g.IsEmpty(user.UserName) { setDoc["updater"] = user.UserName } setDoc["updatedAt"] = gtime.Now().Time update[i]["$set"] = setDoc } // 创建更新操作模型 updateModel := mongo.NewUpdateOneModel() updateModel.SetFilter(filter[i]) updateModel.SetUpdate(update[i]) updateModel.SetUpsert(true) // 默认不插入新文档 // 处理选项参数 if len(opts) > 0 { for _, opt := range opts { var updateOpts options.UpdateManyOptions optFuncs := opt.List() for _, fn := range optFuncs { fn(&updateOpts) } if updateOpts.Upsert != nil { updateModel.SetUpsert(*updateOpts.Upsert) } } } models = append(models, updateModel) } // 执行批量操作,无序执行提高性能 bulkOpts := options.BulkWrite().SetOrdered(false) bulkResult, err := db.Collection(collection).BulkWrite(ctx, models, bulkOpts) if err != nil { return nil, err } // 清理相关缓存 for _, filterItem := range filter { err = m.CleanRedis(ctx, filterItem, user.TenantId, collection) if err != nil { glog.Warning(ctx, "清理Redis缓存失败:", err) } } return bulkResult, nil } // Insert 插入多条记录 func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { user, err := utils.GetUserInfo(ctx) if err != nil { return } docs := make([]interface{}, 0, len(documents)) for _, document := range documents { doc := gconv.Map(document) delete(doc, "id") if !g.IsEmpty(user.UserName) { doc["creator"] = user.UserName } if !g.IsEmpty(user.UserName) { doc["updater"] = user.UserName } if !g.IsEmpty(user.TenantId) { doc["tenantId"] = user.TenantId } doc["createdAt"] = gtime.Now().Time doc["updatedAt"] = gtime.Now().Time doc["isDeleted"] = false docs = append(docs, doc) } r, err := db.Collection(collection).InsertMany(ctx, docs, opts...) if err != nil { return } ids = r.InsertedIDs err = m.CleanRedis(ctx, bson.M{}, user.TenantId, collection) //写日志 serverName := g.Cfg().MustGet(ctx, "server.name").String() logRedisKey := fmt.Sprintf("log:%s", serverName) if len(ids) == 0 { return } rows := make([]interface{}, 0, len(ids)) if len(ids) == 1 { doc := gconv.Map(documents[0]) doc["id"] = ids[0] rows = append(rows, doc) } else { filter := bson.M{"_id": bson.M{"$in": ids}} if err = m.Find(ctx, filter, &rows, collection); err != nil { return } } if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{ ServiceName: serverName, Collection: collection, Data: rows, }); err != nil { glog.Error(ctx, "mongoLog-AddToStream err: %v", err) } return } // Count 查询总数 func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) { user, err := utils.GetUserInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) redisKey := fmt.Sprintf(redis.Count, user.TenantId, collection, gconv.String(filterMap)) if m.Cache { var resultStr *gvar.Var resultStr, err = redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { count = gconv.Int64(resultStr) return } } // 调用驱动的 CountDocuments,在数据库端执行的 count, err = db.Collection(collection).CountDocuments(ctx, filter) if m.Cache { err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour)) if err != nil { return } } return } // EntityToBSONM 将 *entity/entity 转换为 bson.M // 支持传入值类型或指针类型,返回 bson.M 和错误信息 func EntityToBSONM(entity interface{}) (bson.M, error) { // 第一步:判断入参是否为 nil 或无效类型 if entity == nil { return nil, fmt.Errorf("传入的 entity 实例为 nil") } // 第二步:将 entity 序列化为 BSON 字节流 // bson.Marshal 支持值类型和指针类型,会自动解析结构体的 bson 标签 bsonBytes, err := bson.Marshal(entity) if err != nil { return nil, fmt.Errorf("entity 序列化为 BSON 字节流失败:%w", err) } // 第三步:将 BSON 字节流反序列化为 bson.M var bsonMap bson.M err = bson.Unmarshal(bsonBytes, &bsonMap) if err != nil { return nil, fmt.Errorf("BSON 字节流反序列化为 bson.M 失败:%w", err) } return bsonMap, nil }