Files
common/mongo/mongo.go

701 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mongo
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"gitee.com/red-future---jilin-g/common/beans"
"gitee.com/red-future---jilin-g/common/log/model/dto"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/os/grpool"
"gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// MongoDB is a lightweight handle for the CRUD helpers defined in this package.
type MongoDB struct {
	// Cache tells the helpers whether to go through the Redis cache; the
	// methods in this file check it before reading from or writing to Redis.
	Cache bool
}

// DB returns a new MongoDB handle.
//
// Caching is enabled by default. Passing a value overrides the default; only
// the first argument is consulted and any further ones are ignored.
func DB(cache ...bool) *MongoDB {
	handle := &MongoDB{Cache: true}
	if len(cache) != 0 {
		handle.Cache = cache[0]
	}
	return handle
}
// Package-level connection state shared by every helper in this file.
var (
// db is the database handle derived from client in connect() and returned by GetDB().
db *mongo.Database
// client is the driver client created by connect().
client *mongo.Client
// isConnected is set to true after a successful ping in connect() and reset
// to false when a ping fails or close() runs.
isConnected bool
// mu is taken by connect(), close(), checkConnected(), GetDB() and
// healthCheck() when they read or write the variables above.
// NOTE(review): the CRUD methods read db without holding mu.
mu sync.RWMutex
// mongoAddr is the connection URI read from the "mongo.address" config key in init().
mongoAddr string
// dbName is the database name parsed out of mongoAddr in init().
dbName string
// healthCtx and healthCancel control the lifetime of the healthCheck
// goroutine; both are created in init() and healthCancel is called by close().
healthCtx context.Context
healthCancel context.CancelFunc
)
// checkConnected reports the connection flag as last recorded by connect(),
// healthCheck() or close(). The flag is read under the shared read lock.
func checkConnected() bool {
	mu.RLock()
	connected := isConnected
	mu.RUnlock()
	return connected
}
// connect (re)creates the package-level MongoDB client and database handle.
//
// It holds the write lock for its whole duration. Any previous client is
// disconnected first, then a new client is created from mongoAddr and verified
// with a ping. isConnected is set to true only after the ping succeeds.
//
// Returns the connect or ping error; both are also logged here.
func connect() error {
	mu.Lock()
	defer mu.Unlock()

	bg := context.Background()

	// Release the previous client before replacing it. The disconnect is
	// bounded (same 5s budget close() uses) because it runs while the write
	// lock is held; an unbounded wait would block every reader of mu.
	if client != nil {
		disconnectCtx, disconnectCancel := context.WithTimeout(bg, 5*time.Second)
		if err := client.Disconnect(disconnectCtx); err != nil {
			glog.Warning(bg, "MongoDB旧连接关闭失败", err)
		}
		disconnectCancel()
	}

	// Connection options: pool sizing and connect timeout.
	opt := options.Client().
		ApplyURI(mongoAddr).
		SetMaxPoolSize(100).
		SetMinPoolSize(10).
		SetMaxConnecting(10).
		SetConnectTimeout(10 * time.Second)

	var err error
	client, err = mongo.Connect(opt)
	if err != nil {
		isConnected = false
		glog.Error(bg, "MongoDB连接失败", err)
		return err
	}

	// Verify the connection with a bounded ping before declaring success.
	testCtx, testCancel := context.WithTimeout(bg, 5*time.Second)
	defer testCancel()
	if err = client.Ping(testCtx, nil); err != nil {
		isConnected = false
		glog.Error(testCtx, "MongoDB连接测试失败", err)
		return err
	}

	db = client.Database(dbName)
	isConnected = true
	glog.Info(bg, "✅ MongoDB连接成功")
	return nil
}
// GetDB returns the current MongoDB database handle, read under the shared
// read lock. The handle is nil until connect() has succeeded at least once.
func GetDB() *mongo.Database {
	mu.RLock()
	current := db
	mu.RUnlock()
	return current
}
// healthCheck is the body of the health-check goroutine.
//
// Every 30 seconds it snapshots the connection state. If the client is missing
// or flagged as disconnected it triggers reconnect(); otherwise it pings the
// server with a 5 second timeout and, on failure, clears the connected flag
// and reconnects. The loop ends when healthCtx is cancelled.
func healthCheck() {
	tick := time.NewTicker(30 * time.Second)
	defer tick.Stop()

	for {
		// Wait for either shutdown or the next tick.
		select {
		case <-healthCtx.Done():
			return
		case <-tick.C:
		}

		// Snapshot the shared state so the lock is not held during I/O.
		mu.RLock()
		connected, cli := isConnected, client
		mu.RUnlock()

		if cli == nil || !connected {
			glog.Warning(context.Background(), "MongoDB连接断开尝试重连")
			if reErr := reconnect(); reErr != nil {
				glog.Error(context.Background(), "MongoDB重连失败", reErr)
			}
			continue
		}

		// Probe the live connection.
		pingCtx, cancelPing := context.WithTimeout(context.Background(), 5*time.Second)
		pingErr := cli.Ping(pingCtx, nil)
		cancelPing()
		if pingErr == nil {
			glog.Debug(context.Background(), "MongoDB连接健康检查通过")
			continue
		}

		mu.Lock()
		isConnected = false
		mu.Unlock()
		glog.Warning(context.Background(), "MongoDB连接健康检查失败", pingErr)

		// The ping failed: try to re-establish the connection.
		if reErr := reconnect(); reErr != nil {
			glog.Error(context.Background(), "MongoDB重连失败", reErr)
		}
	}
}
// reconnect calls connect() up to three times, sleeping between failed
// attempts with a delay that starts at 2 seconds and doubles each time
// (exponential backoff). No sleep follows the final attempt.
//
// Returns nil as soon as one attempt succeeds, otherwise an error stating that
// the retry budget is exhausted.
func reconnect() error {
	const attempts = 3
	backoff := 2 * time.Second

	for attempt := 1; attempt <= attempts; attempt++ {
		glog.Info(context.Background(), fmt.Sprintf("尝试第%d次重连MongoDB", attempt))
		if connect() == nil {
			glog.Info(context.Background(), "MongoDB重连成功")
			return nil
		}
		if attempt == attempts {
			break
		}
		time.Sleep(backoff)
		backoff *= 2
	}
	return gerror.New("MongoDB重连失败已达到最大重试次数")
}
// logPool is a goroutine pool created in init() with grpool.New(10).
// NOTE(review): nothing in the visible part of this file submits work to it.
var logPool *grpool.Pool
// init wires up the package at load time.
//
// It creates logPool, reads the connection URI from the "mongo.address" config
// key and, if the key is empty, stops there (MongoDB is optional). Otherwise
// it derives the database name from the URI, starts the initial connection in
// a goroutine and launches the health-check goroutine.
func init() {
	logPool = grpool.New(10)

	// On-demand initialisation: skip everything when no address is configured.
	mongoAddr = g.Cfg().MustGet(context.Background(), "mongo.address").String()
	if mongoAddr == "" {
		return
	}

	// Context that controls the lifetime of the health-check goroutine.
	healthCtx, healthCancel = context.WithCancel(context.Background())

	// Derive the database name from the URI path. The query string is removed
	// BEFORE looking for the last "/": option values may themselves contain
	// slashes (e.g. "?tlsCAFile=/etc/ssl/ca.pem"), which would otherwise be
	// mistaken for the path separator.
	addrNoQuery := mongoAddr
	if q := strings.Index(addrNoQuery, "?"); q >= 0 {
		addrNoQuery = gstr.SubStr(addrNoQuery, 0, q)
	}
	dbName = gstr.SubStr(addrNoQuery, strings.LastIndex(addrNoQuery, "/")+1, len(addrNoQuery))

	// Initial connection runs asynchronously so package init is not blocked;
	// a failure here is retried later by the health check.
	go func() {
		if err := connect(); err != nil {
			glog.Error(context.Background(), "MongoDB初始连接失败", err)
		}
	}()

	// Start the periodic health check.
	go healthCheck()
}
// close stops the health-check goroutine and disconnects the MongoDB client.
//
// The disconnect is bounded to 5 seconds. A disconnect failure is logged
// rather than dropped; the connected flag is cleared either way.
//
// NOTE(review): this package-level function shadows the builtin close inside
// the package, so channels cannot be closed with the builtin here.
func close() {
	if healthCancel != nil {
		healthCancel()
	}

	mu.Lock()
	defer mu.Unlock()

	if client != nil {
		disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer disconnectCancel()
		if err := client.Disconnect(disconnectCtx); err != nil {
			glog.Warning(context.Background(), "MongoDB连接关闭失败", err)
		}
	}
	isConnected = false
	glog.Info(context.Background(), "MongoDB连接已关闭")
}
// PageSize is the page size Find falls back to when it is called without a page argument.
const PageSize = 20
// Find queries multiple non-deleted documents of the current user's tenant.
//
// Parameters:
//   - filter: query conditions. The map is modified in place ("isDeleted" and
//     "tenantId" are set); a nil filter is treated as empty.
//   - result: pointer to a slice that receives the documents.
//   - page: optional paging; nil means the first PageSize documents, a page
//     size of -1 means "no limit".
//   - orderBy: optional sort fields, applied in the given order; when empty the
//     result is sorted by createdAt descending.
//
// Returns:
//   - total: on a cache hit, the number of cached rows; on a cache miss with
//     page size -1, the number of matching documents.
//     NOTE(review): a paged cache miss leaves total at 0, which differs from
//     the cache-hit path — confirm what callers expect.
//   - err: user lookup, Redis, or MongoDB errors.
func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, page *beans.Page, orderBy []beans.OrderBy) (total int64, err error) {
	// One hour expressed in seconds. int64(time.Hour) would be a nanosecond
	// count and make the entry effectively immortal.
	// NOTE(review): assumes redis.RedisClient.SetEX takes its TTL in seconds,
	// as GoFrame's gredis does — confirm against the wrapper.
	const cacheTTLSeconds = int64(time.Hour / time.Second)

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// Writing to a nil map panics, so normalise it first.
	if filter == nil {
		filter = bson.M{}
	}
	filter["isDeleted"] = false

	// The cache key is derived before tenantId is added to the filter; the
	// tenant is a separate component of the key.
	filterKey := fmt.Sprintf("%+v", filter)
	optionsKey := fmt.Sprintf("%+v%+v", page, orderBy)
	redisKey := fmt.Sprintf(redis.List, user.TenantId, collection, filterKey, optionsKey)

	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !cached.IsEmpty() {
			if err = cached.Structs(result); err != nil {
				return
			}
			total = int64(len(cached.Array()))
			return
		}
	}

	filter["tenantId"] = user.TenantId

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = fmt.Errorf("MongoDB未连接")
		return
	}
	coll := db.Collection(collection)

	// Paging: default to the first page; -1 disables the limit.
	limit, skip := int64(PageSize), int64(0)
	if page != nil {
		limit = page.PageSize
		if limit != -1 {
			skip = (page.PageNum - 1) * limit
		}
	}
	if skip < 0 {
		// A page number below 1 would produce a negative skip, which the
		// server rejects; fall back to the first page.
		skip = 0
	}

	opt := options.Find().SetSkip(skip)
	if limit != -1 {
		opt.SetLimit(limit)
	} else {
		// Count with the tenant-scoped filter directly. Going through
		// m.Count would hand it this same map, and Count removes "tenantId"
		// from the map it receives, which would un-scope the query below.
		// The server-side count is also exact, whereas the cursor's
		// RemainingBatchLength only covers the first batch.
		total, err = coll.CountDocuments(ctx, filter)
		if err != nil || total == 0 {
			return
		}
	}

	// Sort: an ordered document is required so that multi-field sorts keep
	// the precedence given by orderBy (a map has no defined order).
	if len(orderBy) == 0 {
		opt.SetSort(bson.D{{Key: "createdAt", Value: -1}})
	} else {
		sortDoc := make(bson.D, 0, len(orderBy))
		for _, o := range orderBy {
			direction := -1
			if o.Order == beans.Asc {
				direction = 1
			}
			sortDoc = append(sortDoc, bson.E{Key: o.Field, Value: direction})
		}
		opt.SetSort(sortDoc)
	}

	cur, err := coll.Find(ctx, filter, opt)
	if err != nil {
		return
	}
	defer cur.Close(ctx)

	if err = cur.All(ctx, result); err != nil {
		return
	}

	if m.Cache {
		err = redis.RedisClient.SetEX(ctx, redisKey, result, cacheTTLSeconds)
	}
	return
}
// FindOne loads a single non-deleted document into result.
//
// Parameters:
//   - filter: query conditions, must not be empty. The map is modified in
//     place ("isDeleted" is set, and "tenantId" when the user has one).
//   - result: must pass utils.ValidStructPtr; it is left untouched when no
//     document matches.
//   - opts: forwarded to the driver's FindOne.
//     NOTE(review): opts are not part of the cache key, so calls that differ
//     only in options (e.g. projection) share one cache entry.
//
// Returns nil when no document matches. Query/decode errors are returned as
// they are and never cached. "Not found" is not cached either: Insert cannot
// invalidate a per-filter entry, so a cached miss would hide a document
// created afterwards.
func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) {
	// One hour expressed in seconds. int64(time.Hour) would be a nanosecond
	// count and make the entry effectively immortal.
	// NOTE(review): assumes redis.RedisClient.SetEX takes its TTL in seconds,
	// as GoFrame's gredis does — confirm against the wrapper.
	const cacheTTLSeconds = int64(time.Hour / time.Second)

	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	filter["isDeleted"] = false
	filterKey := fmt.Sprintf("%+v", filter)
	redisKey := fmt.Sprintf(redis.One, user.TenantId, collection, filterKey)

	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			err = gconv.Scan(cached, result)
			return
		}
	}

	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	err = db.Collection(collection).FindOne(ctx, filter, opts...).Decode(result)
	if errors.Is(err, mongo.ErrNoDocuments) {
		// No match is not an error for callers; result stays as passed in.
		err = nil
		return
	}
	if err != nil {
		// Return real failures instead of letting the cache write below
		// overwrite them.
		return
	}

	if m.Cache {
		err = redis.RedisClient.SetEX(ctx, redisKey, result, cacheTTLSeconds)
	}
	return
}
// CleanRedis invalidates cache entries of one tenant/collection pair.
//
// It deletes every key matching the list pattern (redis.CleanList) and the
// count pattern (redis.CleanCount), then the single-document key (redis.One)
// derived from filter. For that key the filter is normalised the same way the
// read path builds it: "tenantId" excluded, "isDeleted" forced to false.
//
// The caller's filter is not modified (the key is built from a copy), and a
// nil filter is accepted. Returns the first Redis error encountered.
//
// NOTE(review): key discovery goes through RedisClient.Keys; if that maps to
// the Redis KEYS command it scans the whole keyspace — consider SCAN.
func (m *MongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) {
	patterns := []string{
		fmt.Sprintf(redis.CleanList, tenantId, collection),
		fmt.Sprintf(redis.CleanCount, tenantId, collection),
	}
	for _, pattern := range patterns {
		keys, keysErr := redis.RedisClient.Keys(ctx, pattern)
		if keysErr != nil {
			return keysErr
		}
		for _, key := range keys {
			if _, delErr := redis.RedisClient.Del(ctx, key); delErr != nil {
				return delErr
			}
		}
	}

	// Build the single-document key from a private copy so the caller's map
	// keeps its tenantId and is not given an extra isDeleted entry.
	keyFilter := make(bson.M, len(filter)+1)
	for field, value := range filter {
		if field != "tenantId" {
			keyFilter[field] = value
		}
	}
	keyFilter["isDeleted"] = false

	oneKey := fmt.Sprintf(redis.One, tenantId, collection, fmt.Sprintf("%+v", keyFilter))
	_, err = redis.RedisClient.Del(ctx, oneKey)
	return
}
// Delete hard-deletes every document of the current tenant matching filter.
//
// Parameters:
//   - filter: must not be empty; "tenantId" is set on it in place.
//   - opts: forwarded to the driver's DeleteMany.
//
// Steps: snapshot the matching rows for the audit log, delete, invalidate the
// Redis cache, then publish the snapshot to the "log:<server.name>" stream.
// The snapshot is taken BEFORE deleting; queried afterwards the rows would
// already be gone and the audit entry would always be empty.
//
// Returns the number of deleted documents. If the snapshot query fails the
// error is returned and nothing is deleted. A cache-invalidation failure is
// only logged. A stream failure is logged and returned, with count still set.
//
// NOTE(review): the snapshot goes through Find with no page, so it holds at
// most PageSize rows and skips rows already flagged isDeleted.
func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	filter["tenantId"] = user.TenantId

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	// Snapshot for the audit log. Find writes into the filter it is given,
	// so it gets a copy: the delete filter itself must not gain the
	// isDeleted condition Find adds.
	snapshotFilter := make(bson.M, len(filter)+1)
	for field, value := range filter {
		snapshotFilter[field] = value
	}
	var rows []interface{}
	if _, err = m.Find(ctx, snapshotFilter, &rows, collection, nil, nil); err != nil {
		return
	}

	r, err := db.Collection(collection).DeleteMany(ctx, filter, opts...)
	if err != nil {
		return
	}
	count = r.DeletedCount

	// The delete already happened; a cache failure is reported, not returned.
	if cleanErr := m.CleanRedis(ctx, filter, user.TenantId, collection); cleanErr != nil {
		glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
	}

	// Write the audit log.
	serverName := g.Cfg().MustGet(ctx, "server.name").String()
	logRedisKey := fmt.Sprintf("log:%s", serverName)
	if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{
		ServiceName: serverName,
		Collection:  collection,
		Data:        rows,
	}); err != nil {
		glog.Errorf(ctx, "mongoLog-AddToStream err: %v", err)
	}
	return
}
// Update applies update to every non-deleted document matching filter.
//
// Parameters:
//   - filter: must not be empty; "isDeleted" (and "tenantId" when the user has
//     one) are set on it in place.
//   - update: an update document. Its "$set" part may be absent, a bson.M or a
//     map[string]interface{}; "updater" (when the user has a name) and
//     "updatedAt" are added to it. Any other "$set" type is rejected with an
//     error. Operators other than "$set" are passed through unchanged.
//   - opts: forwarded to the driver's UpdateMany.
//
// After the write the Redis cache is invalidated (a failure there is only
// logged) and the rows matching filter are published to the
// "log:<server.name>" stream.
//
// Returns the driver's update result. A failure while reading the rows or
// publishing them is returned too, with result still set.
//
// NOTE(review): the audit rows are re-read with the same filter after the
// write, so an update that changes a filtered field produces an empty entry.
func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	filter["isDeleted"] = false
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}

	// Resolve the $set part without an unchecked type assertion, which would
	// panic on a missing or differently typed value.
	var setDoc bson.M
	switch v := update["$set"].(type) {
	case nil:
		setDoc = bson.M{}
	case bson.M:
		setDoc = v
	case map[string]interface{}:
		setDoc = bson.M(v)
	default:
		err = gerror.Newf("$set 类型不支持: %T", v)
		return
	}
	if setDoc == nil {
		// A typed nil map cannot be written to.
		setDoc = bson.M{}
	}
	if !g.IsEmpty(user.UserName) {
		setDoc["updater"] = user.UserName
	}
	setDoc["updatedAt"] = gtime.Now().Time

	// Rebuild the update document keeping every operator the caller supplied
	// ($inc, $unset, ...) instead of reducing it to $set alone.
	updateDoc := make(bson.M, len(update)+1)
	for op, value := range update {
		updateDoc[op] = value
	}
	updateDoc["$set"] = setDoc

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	result, err = db.Collection(collection).UpdateMany(ctx, filter, updateDoc, opts...)
	if err != nil {
		return
	}

	// The write already happened; a cache failure is reported, not returned.
	if cleanErr := m.CleanRedis(ctx, filter, user.TenantId, collection); cleanErr != nil {
		glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
	}

	// Write the audit log.
	serverName := g.Cfg().MustGet(ctx, "server.name").String()
	logRedisKey := fmt.Sprintf("log:%s", serverName)
	var rows []interface{}
	if _, err = m.Find(ctx, filter, &rows, collection, nil, nil); err != nil {
		return
	}
	if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{
		ServiceName: serverName,
		Collection:  collection,
		Data:        rows,
	}); err != nil {
		glog.Errorf(ctx, "mongoLog-AddToStream err: %v", err)
	}
	return
}
// RandomSoftDelete marks up to limit randomly chosen non-deleted documents of
// collection as deleted (isDeleted = true).
//
// Parameters:
//   - limit: maximum number of documents to soft-delete.
//   - opts: forwarded to the driver's UpdateMany.
//
// Returns the driver's update result, or (nil, nil) when no document was
// selected. _id values of any type are supported.
//
// NOTE(review): unlike the other helpers this one applies no tenant filter,
// sets no updater/updatedAt and does not invalidate the Redis cache — confirm
// that is intended.
func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		return nil, fmt.Errorf("MongoDB未连接")
	}
	coll := db.Collection(collection)

	// Step 1: pick the random _id values.
	pipeline := mongo.Pipeline{
		// Keep only non-deleted documents. Filtering first means the random
		// value below is computed for candidates only.
		bson.D{{Key: "$match", Value: bson.D{{Key: "isDeleted", Value: false}}}},
		// Attach a random number in [0, 1) to each candidate.
		bson.D{{Key: "$addFields", Value: bson.D{{Key: "random", Value: bson.M{"$rand": bson.M{}}}}}},
		// Order by that random number, descending.
		bson.D{{Key: "$sort", Value: bson.D{{Key: "random", Value: -1}}}},
		// Keep the first limit documents.
		bson.D{{Key: "$limit", Value: limit}},
		// Only the _id is needed.
		bson.D{{Key: "$project", Value: bson.D{{Key: "_id", Value: 1}}}},
	}
	cursor, err := coll.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, err
	}
	defer cursor.Close(ctx)

	// Step 2: collect the ids. They are kept as raw values so collections
	// whose _id is not an ObjectID work as well.
	var picked []bson.M
	if err = cursor.All(ctx, &picked); err != nil {
		return nil, err
	}
	idsToUpdate := make([]interface{}, 0, len(picked))
	for _, doc := range picked {
		idsToUpdate = append(idsToUpdate, doc["_id"])
	}
	fmt.Printf("准备更新的随机文档ID: %v\n", idsToUpdate)

	if len(idsToUpdate) == 0 {
		return nil, nil
	}

	// Step 3: flag the selected documents in one UpdateMany and hand the
	// driver's result back to the caller.
	filter := bson.D{{Key: "_id", Value: bson.D{{Key: "$in", Value: idsToUpdate}}}}
	update := bson.D{{Key: "$set", Value: bson.D{{Key: "isDeleted", Value: true}}}}
	result, err = coll.UpdateMany(ctx, filter, update, opts...)
	return
}
// SaveOrUpdate performs a batch of update-one operations, inserting a document
// when nothing matches (upsert).
//
// Parameters:
//   - filter, update: parallel slices of equal, non-zero length; filter[i]
//     selects the document that update[i] is applied to. Entries are modified
//     in place: "isDeleted" (and "tenantId" when the user has one) are set on
//     each filter, and "updater"/"updatedAt" are added to each "$set".
//   - opts: only the Upsert option is honoured; it overrides the default of
//     upsert = true.
//
// A "$set" part may be absent, a bson.M or a map[string]interface{}; any other
// type is rejected with an error rather than silently replaced. Nil filter or
// update entries are treated as empty documents.
//
// The batch runs unordered. On success the Redis cache is invalidated per
// filter (failures are only logged) and the bulk result is returned.
//
// NOTE(review): a document created by the upsert gets no creator/createdAt.
func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) {
	if len(filter) == 0 || len(update) == 0 {
		err = gerror.New("缺少查询条件或更新数据")
		return
	}
	if len(filter) != len(update) {
		err = gerror.New("查询条件和更新数据的数量必须一致")
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	// Resolve the upsert flag once; it is the same for every model. Upsert is
	// on by default and can be overridden through opts.
	upsert := true
	for _, opt := range opts {
		var updateOpts options.UpdateManyOptions
		for _, fn := range opt.List() {
			fn(&updateOpts)
		}
		if updateOpts.Upsert != nil {
			upsert = *updateOpts.Upsert
		}
	}

	// Build the bulk models.
	models := make([]mongo.WriteModel, 0, len(filter))
	for i := range filter {
		// Filter: writing to a nil map panics, so normalise it first.
		if filter[i] == nil {
			filter[i] = bson.M{}
		}
		filter[i]["isDeleted"] = false
		if !g.IsEmpty(user.TenantId) {
			filter[i]["tenantId"] = user.TenantId
		}

		// Update: locate (or create) the $set part and stamp the audit fields.
		if update[i] == nil {
			update[i] = bson.M{}
		}
		var setDoc bson.M
		switch v := update[i]["$set"].(type) {
		case nil:
			setDoc = bson.M{}
		case bson.M:
			setDoc = v
		case map[string]interface{}:
			setDoc = bson.M(v)
		default:
			// Replacing an unsupported $set would discard the caller's data.
			err = gerror.Newf("$set 类型不支持: %T", v)
			return
		}
		if setDoc == nil {
			// A typed nil map cannot be written to.
			setDoc = bson.M{}
		}
		if !g.IsEmpty(user.UserName) {
			setDoc["updater"] = user.UserName
		}
		setDoc["updatedAt"] = gtime.Now().Time
		update[i]["$set"] = setDoc

		models = append(models, mongo.NewUpdateOneModel().
			SetFilter(filter[i]).
			SetUpdate(update[i]).
			SetUpsert(upsert))
	}

	// Execute unordered: the operations are independent of one another.
	bulkOpts := options.BulkWrite().SetOrdered(false)
	bulkResult, err := db.Collection(collection).BulkWrite(ctx, models, bulkOpts)
	if err != nil {
		return nil, err
	}

	// Invalidate the related cache entries.
	for _, filterItem := range filter {
		if cleanErr := m.CleanRedis(ctx, filterItem, user.TenantId, collection); cleanErr != nil {
			glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
		}
	}
	return bulkResult, nil
}
// Insert stores documents in collection and returns the generated ids.
//
// Each document is converted with gconv.Map; its "id" key is dropped and the
// audit fields are stamped: creator/updater (when the user has a name),
// tenantId (when the user has one), createdAt/updatedAt (one shared timestamp
// for the whole batch) and isDeleted = false. A document that cannot be
// converted to a map is rejected with an error before anything is written.
//
// After the write the Redis cache is invalidated (a failure there is only
// logged) and the inserted rows are published to the "log:<server.name>"
// stream. A failure while reading the rows or publishing them is returned
// together with the ids.
//
// NOTE(review): for multi-document inserts the rows are re-read through Find
// with no page, so the audit entry holds at most PageSize rows.
func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = fmt.Errorf("MongoDB未连接")
		return
	}

	hasUserName := !g.IsEmpty(user.UserName)
	hasTenant := !g.IsEmpty(user.TenantId)
	// One timestamp so createdAt and updatedAt are identical on creation.
	now := gtime.Now().Time

	docs := make([]interface{}, 0, len(documents))
	for i, document := range documents {
		doc := gconv.Map(document)
		if doc == nil {
			// Writing the audit fields into a nil map would panic.
			err = fmt.Errorf("第%d个文档无法转换为map", i+1)
			return
		}
		delete(doc, "id")
		if hasUserName {
			doc["creator"] = user.UserName
			doc["updater"] = user.UserName
		}
		if hasTenant {
			doc["tenantId"] = user.TenantId
		}
		doc["createdAt"] = now
		doc["updatedAt"] = now
		doc["isDeleted"] = false
		docs = append(docs, doc)
	}

	r, err := db.Collection(collection).InsertMany(ctx, docs, opts...)
	if err != nil {
		return
	}
	ids = r.InsertedIDs

	// The insert already happened; a cache failure is reported, not returned.
	if cleanErr := m.CleanRedis(ctx, bson.M{}, user.TenantId, collection); cleanErr != nil {
		glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
	}
	if len(ids) == 0 {
		return
	}

	// Write the audit log.
	serverName := g.Cfg().MustGet(ctx, "server.name").String()
	logRedisKey := fmt.Sprintf("log:%s", serverName)
	rows := make([]interface{}, 0, len(ids))
	if len(ids) == 1 {
		// Single insert: reuse the input document instead of re-reading it.
		doc := gconv.Map(documents[0])
		doc["id"] = ids[0]
		rows = append(rows, doc)
	} else {
		filter := bson.M{"_id": bson.M{"$in": ids}}
		if _, err = m.Find(ctx, filter, &rows, collection, nil, nil); err != nil {
			return
		}
	}
	if _, err = redis.AddToStream(ctx, logRedisKey, &dto.RecordCreateLogReq{
		ServiceName: serverName,
		Collection:  collection,
		Data:        rows,
	}); err != nil {
		glog.Errorf(ctx, "mongoLog-AddToStream err: %v", err)
	}
	return
}
// Count returns the number of non-deleted documents of the current user's
// tenant that match filter.
//
// The caller's filter is not modified and may be nil; the query is built on a
// private copy. Any "tenantId" entry in filter is ignored and replaced by the
// current user's tenant, so the count has the same tenant scope as Find.
//
// With caching enabled the result is stored under a redis.Count key built from
// the tenant, the collection and the filter (without tenantId, with
// isDeleted = false). A failed database count is returned and never cached.
func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) {
	// One hour expressed in seconds. int64(time.Hour) would be a nanosecond
	// count and make the entry effectively immortal.
	// NOTE(review): assumes redis.RedisClient.SetEX takes its TTL in seconds,
	// as GoFrame's gredis does — confirm against the wrapper.
	const cacheTTLSeconds = int64(time.Hour / time.Second)

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// Work on a copy: callers such as Find keep using their filter after this
	// call and must not lose its tenantId.
	query := make(bson.M, len(filter)+2)
	for field, value := range filter {
		if field != "tenantId" {
			query[field] = value
		}
	}
	query["isDeleted"] = false

	// The key carries the tenant separately, so the filter part excludes it.
	filterKey := fmt.Sprintf("%+v", query)
	redisKey := fmt.Sprintf(redis.Count, user.TenantId, collection, filterKey)

	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			count = gconv.Int64(cached)
			return
		}
	}

	// Scope the database query to the tenant the cache key is scoped to.
	query["tenantId"] = user.TenantId

	// db stays nil when MongoDB is not configured or not connected yet.
	if db == nil {
		err = fmt.Errorf("MongoDB未连接")
		return
	}

	// CountDocuments is evaluated on the database side.
	count, err = db.Collection(collection).CountDocuments(ctx, query)
	if err != nil {
		return
	}

	if m.Cache {
		err = redis.RedisClient.SetEX(ctx, redisKey, count, cacheTTLSeconds)
	}
	return
}
// EntityToBSONM converts an entity (value or pointer) into a bson.M by
// marshalling it to BSON bytes and unmarshalling those bytes into a map.
//
// Returns an error when entity is a nil interface, or when either the marshal
// or the unmarshal step fails (the underlying error is wrapped).
func EntityToBSONM(entity interface{}) (bson.M, error) {
	// Reject a nil interface up front.
	if entity == nil {
		return nil, fmt.Errorf("传入的 entity 实例为 nil")
	}

	// Entity -> BSON bytes.
	raw, marshalErr := bson.Marshal(entity)
	if marshalErr != nil {
		return nil, fmt.Errorf("entity 序列化为 BSON 字节流失败:%w", marshalErr)
	}

	// BSON bytes -> bson.M.
	var doc bson.M
	if unmarshalErr := bson.Unmarshal(raw, &doc); unmarshalErr != nil {
		return nil, fmt.Errorf("BSON 字节流反序列化为 bson.M 失败:%w", unmarshalErr)
	}
	return doc, nil
}