Files
common/mongo/mongo.go

735 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mongo
import (
"context"
"errors"
"fmt"
"reflect"
"strings"
"sync"
"time"
"gitee.com/red-future---jilin-g/common/beans"
"gitee.com/red-future---jilin-g/common/log/model/entity"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/os/grpool"
"gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// MongoDB is the handle through which the CRUD helpers of this package are called.
// Cache controls whether the read helpers (Find, FindOne, Count) consult and fill the
// Redis cache.
type MongoDB struct {
	Cache bool
}

// DB returns a fresh *MongoDB. Caching is enabled unless the first optional argument is
// false; any further arguments are ignored.
func DB(cache ...bool) *MongoDB {
	useCache := len(cache) == 0 || cache[0]
	return &MongoDB{Cache: useCache}
}
// Package-level connection state.
var (
	// connect, close and healthCheck read/write client, db and isConnected under mu;
	// other readers go through GetDB / checkConnected.
	mu          sync.RWMutex
	client      *mongo.Client
	db          *mongo.Database
	isConnected bool

	// Connection settings, assigned in init from the mongo.address config entry.
	mongoAddr string
	dbName    string

	// Lifetime of the healthCheck goroutine; close calls healthCancel.
	healthCtx    context.Context
	healthCancel context.CancelFunc
)
// checkConnected reports whether the last connection attempt / health check left the
// package in the connected state. The flag is read under mu.
func checkConnected() bool {
	mu.RLock()
	connected := isConnected
	mu.RUnlock()
	return connected
}
// connect 建立MongoDB连接
func connect() error {
mu.Lock()
defer mu.Unlock()
if client != nil {
client.Disconnect(context.Background())
}
// 创建连接选项
opt := options.Client().
ApplyURI(mongoAddr).
SetMaxPoolSize(100).
SetMinPoolSize(10).
SetMaxConnecting(10).
SetConnectTimeout(10 * time.Second)
var err error
client, err = mongo.Connect(opt)
if err != nil {
isConnected = false
glog.Error(context.Background(), "MongoDB连接失败", err)
return err
}
// 测试连接
testCtx, testCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer testCancel()
err = client.Ping(testCtx, nil)
if err != nil {
isConnected = false
glog.Error(testCtx, "MongoDB连接测试失败", err)
return err
}
db = client.Database(dbName)
isConnected = true
glog.Info(context.Background(), "✅ MongoDB连接成功")
return nil
}
// GetDB returns the current MongoDB database handle, read under mu. It is nil until the
// first successful connect, and stays nil when mongo.address is not configured.
func GetDB() *mongo.Database {
	mu.RLock()
	current := db
	mu.RUnlock()
	return current
}
// healthCheck runs until healthCtx is cancelled. Every 30 seconds it checks whether the
// package is marked connected with a non-nil client. If so, it pings (5s timeout). When
// the package is marked down, or the ping fails (which also marks it down), it calls
// reconnect and logs any failure.
func healthCheck() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-healthCtx.Done():
			return
		case <-ticker.C:
		}

		mu.RLock()
		connected, cli := isConnected, client
		mu.RUnlock()

		if connected && cli != nil {
			pingCtx, cancelPing := context.WithTimeout(context.Background(), 5*time.Second)
			pingErr := cli.Ping(pingCtx, nil)
			cancelPing()
			if pingErr == nil {
				glog.Debug(context.Background(), "MongoDB连接健康检查通过")
				continue
			}
			mu.Lock()
			isConnected = false
			mu.Unlock()
			glog.Warning(context.Background(), "MongoDB连接健康检查失败", pingErr)
		} else {
			glog.Warning(context.Background(), "MongoDB连接断开尝试重连")
		}

		// Shared recovery path for "marked down" and "ping failed".
		if err := reconnect(); err != nil {
			glog.Error(context.Background(), "MongoDB重连失败", err)
		}
	}
}
// reconnect 重连函数
func reconnect() error {
maxRetries := 3
retryDelay := 2 * time.Second
for i := 0; i < maxRetries; i++ {
glog.Info(context.Background(), fmt.Sprintf("尝试第%d次重连MongoDB", i+1))
if err := connect(); err == nil {
glog.Info(context.Background(), "MongoDB重连成功")
return nil
}
if i < maxRetries-1 {
time.Sleep(retryDelay)
retryDelay *= 2 // 指数退避
}
}
return gerror.New("MongoDB重连失败已达到最大重试次数")
}
// logPool runs the asynchronous operation-log jobs submitted by (*MongoDB).log on a
// single worker.
var logPool *grpool.Pool

// init prepares the package from the mongo.address config entry.
//
// When mongo.address is empty the package stays unconnected (GetDB returns nil) and no
// goroutines are started. Otherwise the database name is taken from the connection
// string, the first connect runs in the background, and healthCheck is started.
func init() {
	logPool = grpool.New(1)

	// Initialise on demand: without mongo.address there is nothing to connect to.
	mongoAddr = g.Cfg().MustGet(context.Background(), "mongo.address").String()
	if mongoAddr == "" {
		return
	}

	healthCtx, healthCancel = context.WithCancel(context.Background())

	dbName = parseMongoDBName(mongoAddr)
	if dbName == "" {
		glog.Warning(context.Background(), "mongo.address 中未指定数据库名")
	}

	go func() {
		// Initial connection; on failure healthCheck keeps retrying.
		if err := connect(); err != nil {
			glog.Error(context.Background(), "MongoDB初始连接失败", err)
		}
	}()

	go healthCheck()
}

// parseMongoDBName returns the database name from a MongoDB connection string of the
// form mongodb[+srv]://[user:pass@]host[:port][,host...][/[database][?options]], or ""
// when none is given.
//
// Examples: "mongodb://h:27017/app?authSource=admin" → "app";
// "mongodb://h:27017" → ""; "mongodb://h/app?x=a/b" → "app".
func parseMongoDBName(uri string) string {
	// Drop the scheme so the "//" after it is not mistaken for the path separator.
	if i := strings.Index(uri, "://"); i >= 0 {
		uri = uri[i+len("://"):]
	}
	// Cut the options off first: option values may themselves contain "/".
	if i := strings.Index(uri, "?"); i >= 0 {
		uri = gstr.SubStr(uri, 0, i)
	}
	// Credentials and hosts must percent-encode "/", so the first "/" starts the path.
	i := strings.Index(uri, "/")
	if i < 0 {
		return ""
	}
	return gstr.SubStr(uri, i+1)
}
// close stops the health-check goroutine, disconnects the current client (5s timeout)
// and marks the package disconnected. A Disconnect failure is logged rather than silently
// dropped. The client is cleared afterwards, so calling close again does not disconnect
// it twice.
//
// NOTE(review): within this package the name shadows the built-in close; it is kept for
// compatibility with existing callers.
func close() {
	if healthCancel != nil {
		healthCancel()
	}
	mu.Lock()
	defer mu.Unlock()
	if client != nil {
		disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer disconnectCancel()
		if err := client.Disconnect(disconnectCtx); err != nil {
			glog.Error(disconnectCtx, "MongoDB断开连接失败", err)
		}
		client = nil
	}
	isConnected = false
	glog.Info(context.Background(), "MongoDB连接已关闭")
}
// PageSize is the number of documents Find returns when the caller passes no page.
const PageSize = 20

// Find loads the documents of collection that match filter into result. result is handed
// to the driver's Cursor.All, so it must be a pointer to a slice.
//
// The query is always restricted to isDeleted=false and to the caller's tenantId (from the
// user in ctx).
//
// page selects the window:
//   - nil means the first PageSize documents;
//   - PageSize == -1 means no limit;
//   - a PageNum below 1 is treated as 1.
//
// orderBy defaults to createdAt descending.
//
// total is the number of documents loaded into result, whether it came from the cache or
// from the database. With m.Cache enabled the result is read from and written to Redis
// (one-hour TTL).
//
// NOTE: filter is modified in place (isDeleted and tenantId are written into it).
func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, page *beans.Page, orderBy []beans.OrderBy) (total int64, err error) {
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	filter["isDeleted"] = false

	// The cache key is derived before tenantId is added; the tenant is its own key part.
	filterKey := fmt.Sprintf("%+v", filter)
	optionsKey := fmt.Sprintf("%+v%+v", page, orderBy)
	redisKey := fmt.Sprintf(redis.List, user.TenantId, collection, filterKey, optionsKey)
	if m.Cache {
		var resultStr *gvar.Var
		resultStr, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !resultStr.IsEmpty() {
			if err = resultStr.Structs(result); err != nil {
				return
			}
			total = int64(len(resultStr.Array()))
			return
		}
	}

	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	// Paging.
	limit := int64(PageSize)
	skip := int64(0)
	if page != nil {
		limit = page.PageSize
		if limit != -1 {
			pageNum := page.PageNum
			if pageNum < 1 {
				// The server rejects a negative skip.
				pageNum = 1
			}
			skip = (pageNum - 1) * limit
		}
	}
	opt := options.Find().SetSkip(skip)
	if limit != -1 {
		opt.SetLimit(limit)
	} else {
		// Unlimited read: skip the query entirely when nothing matches.
		var n int64
		if n, err = m.Count(ctx, filter, collection); err != nil || n == 0 {
			return
		}
	}

	// Set the tenant only now. Count deletes "tenantId" from the map it receives, so
	// setting it earlier would let this query run without tenant isolation.
	filter["tenantId"] = user.TenantId

	if orderBy == nil {
		opt.SetSort(bson.M{"createdAt": -1})
	} else {
		sortDoc := make(bson.D, 0, len(orderBy))
		for _, o := range orderBy {
			dir := -1 // descending
			if o.Order == beans.Asc {
				dir = 1 // ascending
			}
			sortDoc = append(sortDoc, bson.E{Key: o.Field, Value: dir})
		}
		opt.SetSort(sortDoc)
	}

	cur, err := database.Collection(collection).Find(ctx, filter, opt)
	if err != nil {
		return
	}
	defer cur.Close(ctx)
	if err = cur.All(ctx, result); err != nil {
		return
	}

	// Count what was actually decoded. Cursor.RemainingBatchLength would only cover the
	// first server batch.
	if rv := reflect.Indirect(reflect.ValueOf(result)); rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array {
		total = int64(rv.Len())
	}

	if m.Cache {
		// SetEX takes the TTL in seconds.
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return
}
// FindOne decodes the first live (isDeleted=false) document of collection that matches
// filter into result. When the user in ctx has a tenant, the query is also restricted to
// that tenantId.
//
// A missing document is not an error: result is left untouched and nil is returned. Any
// other driver or decode error is returned as-is.
//
// With m.Cache enabled, found documents are read from and written to Redis (one-hour
// TTL). Misses are not cached, so a later insert is not hidden behind a stale empty entry.
//
// An empty filter is rejected.
// NOTE: filter is modified in place (isDeleted and tenantId are written into it).
func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	filter["isDeleted"] = false

	// The cache key is derived before tenantId is added; the tenant is its own key part.
	filterKey := fmt.Sprintf("%+v", filter)
	redisKey := fmt.Sprintf(redis.One, user.TenantId, collection, filterKey)
	if m.Cache {
		var resultStr *gvar.Var
		resultStr, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(resultStr) {
			return gconv.Scan(resultStr, result)
		}
	}

	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}
	database := GetDB()
	if database == nil {
		return gerror.New("MongoDB未连接")
	}
	err = database.Collection(collection).FindOne(ctx, filter, opts...).Decode(result)
	if errors.Is(err, mongo.ErrNoDocuments) {
		return nil
	}
	if err != nil {
		// Do not let the cache write below swallow a real error (or cache a partial result).
		return err
	}

	if m.Cache {
		// SetEX takes the TTL in seconds.
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return
}
// CleanRedis removes the Redis cache entries of collection for tenantId after a write:
//   - every key matching the redis.CleanList pattern (cached lists);
//   - every key matching the redis.CleanCount pattern (cached counts);
//   - the redis.One key built from filter.
//
// Before the One key is built, filter is normalised the same way FindOne builds its key
// (isDeleted=false, tenantId removed). The map is modified in place. The first Redis
// error stops the cleanup and is returned.
func (m *MongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) {
	deleteMatching := func(pattern string) error {
		matched, keysErr := redis.RedisClient.Keys(ctx, pattern)
		if keysErr != nil {
			return keysErr
		}
		for _, k := range matched {
			if _, delErr := redis.RedisClient.Del(ctx, k); delErr != nil {
				return delErr
			}
		}
		return nil
	}

	for _, pattern := range []string{
		fmt.Sprintf(redis.CleanList, tenantId, collection),
		fmt.Sprintf(redis.CleanCount, tenantId, collection),
	} {
		if err = deleteMatching(pattern); err != nil {
			return
		}
	}

	filter["isDeleted"] = false
	delete(filter, "tenantId")
	_, err = redis.RedisClient.Del(ctx, fmt.Sprintf(redis.One, tenantId, collection, fmt.Sprintf("%+v", filter)))
	return
}
var (
	// serverName is this service's name (config key server.name); it tags operation logs.
	serverName = g.Cfg().MustGet(context.TODO(), "server.name").String()
	// logRedisKey is the Redis stream that operation logs are appended to.
	logRedisKey = "log:" + serverName
)
// log asynchronously appends an entity.OperationLog entry for a write on collection to the
// Redis stream logRedisKey. The job runs on logPool.
//
// The target id is taken from filter["_id"]:
//   - a string is used as-is;
//   - a bson.ObjectID is rendered as hex;
//   - any other value is converted with gconv.String;
//   - a missing _id (or nil filter) gives an empty id.
//
// The client IP is read from the HTTP request in ctx, if any. Both values are captured in
// the caller's goroutine, because the pooled job may run after the request has finished.
// Failures are logged, never returned.
//
// NOTE(review): the job still uses ctx for the Redis write; if ctx is a request context
// that is cancelled when the request ends, the write may fail — consider a detached ctx.
func (m *MongoDB) log(ctx context.Context, filter bson.M, collection string, data interface{}, userName, tenantId interface{}, operationType string) {
	collectionID := ""
	switch id := filter["_id"].(type) {
	case nil:
	case string:
		collectionID = id
	case bson.ObjectID:
		collectionID = id.Hex()
	default:
		collectionID = gconv.String(id)
	}

	ipAddress := ""
	if r := g.RequestFromCtx(ctx); r != nil {
		ipAddress = r.GetClientIp()
	}

	addErr := logPool.AddWithRecover(ctx, func(ctx context.Context) {
		opLog := &entity.OperationLog{
			ServiceName:  serverName,
			Collection:   collection,
			CollectionID: collectionID,
			Operation:    operationType,
			IPAddress:    ipAddress,
			Data:         data,
		}
		opLog.Creator = userName
		now := &gtime.Now().Time
		opLog.CreatedAt = now
		opLog.UpdatedAt = now
		opLog.TenantId = tenantId
		if _, err := redis.AddToStream(ctx, logRedisKey, opLog); err != nil {
			glog.Errorf(ctx, "mongoLog-AddToStream err: %v", err)
		}
	}, func(ctx context.Context, exception error) {
		glog.Errorf(ctx, "mongoLog-AddWithRecover err: %v", exception)
	})
	if addErr != nil {
		glog.Errorf(ctx, "mongoLog-AddWithRecover err: %v", addErr)
	}
}
// Delete permanently removes (DeleteMany) every document of collection that matches
// filter within the caller's tenant. It then clears the tenant's cached lists/counts for
// the collection via CleanRedis. count is the number of documents deleted.
//
// An empty filter is rejected, so a call cannot delete all of a tenant's documents by
// accident. If the database is not connected an error is returned instead of panicking.
// NOTE: filter is modified in place (tenantId is written into it).
func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	// Read the handle under the package lock; healthCheck may be swapping it.
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	filter["tenantId"] = user.TenantId
	r, err := database.Collection(collection).DeleteMany(ctx, filter, opts...)
	if err != nil {
		return
	}
	count = r.DeletedCount
	err = m.CleanRedis(ctx, filter, user.TenantId, collection)
	// Operation logging is currently disabled:
	// m.log(ctx, filter, collection, nil, user.UserName, user.TenantId, "delete")
	return
}
// Update applies update (UpdateMany) to every live (isDeleted=false) document of
// collection that matches filter. When the user in ctx has a tenant, it is restricted to
// that tenantId. It then clears the tenant's cached entries for the collection.
//
// update may carry any update operators; all of them are sent to the server. $set must be
// a bson.M when present and is created when absent. updater (when the user has a name)
// and updatedAt are stamped into $set.
//
// An empty filter or update, or a $set of another type, is rejected with an error instead
// of panicking.
// NOTE: filter and update (including its $set document) are modified in place.
func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if len(update) == 0 {
		err = gerror.New("缺少更新数据")
		return
	}
	filter["isDeleted"] = false
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}

	var setDoc bson.M
	switch v := update["$set"].(type) {
	case nil:
		setDoc = bson.M{}
		update["$set"] = setDoc
	case bson.M:
		setDoc = v
	default:
		err = gerror.Newf("$set 必须为 bson.M, 实际为 %T", v)
		return
	}
	if !g.IsEmpty(user.UserName) {
		setDoc["updater"] = user.UserName
	}
	setDoc["updatedAt"] = gtime.Now().Time

	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	// Send the caller's full update document so $inc, $unset, $push, ... are not dropped.
	result, err = database.Collection(collection).UpdateMany(ctx, filter, update, opts...)
	if err != nil {
		return
	}
	err = m.CleanRedis(ctx, filter, user.TenantId, collection)
	// Operation logging is currently disabled:
	// m.log(ctx, filter, collection, update, user.UserName, user.TenantId, "update")
	return
}
// RandomSoftDelete soft-deletes up to limit randomly chosen documents of collection by
// setting isDeleted=true on them.
//
// Candidates are drawn with $sample from the documents whose isDeleted is false. Their
// _ids are then updated in one UpdateMany call, to which opts are applied. result is that
// call's result, or nil when no document was picked. limit must be positive.
//
// NOTE(review): unlike the other write helpers this is not restricted to the caller's
// tenant and does not clean the Redis cache — confirm that is intended.
func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	if limit <= 0 {
		err = gerror.New("limit 必须大于0")
		return
	}
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	coll := database.Collection(collection)

	// Step 1: pick up to limit random live documents, keeping only their _id.
	// $match runs first so the sample is drawn from non-deleted documents only.
	pipeline := mongo.Pipeline{
		bson.D{{Key: "$match", Value: bson.D{{Key: "isDeleted", Value: false}}}},
		bson.D{{Key: "$sample", Value: bson.D{{Key: "size", Value: limit}}}},
		bson.D{{Key: "$project", Value: bson.D{{Key: "_id", Value: 1}}}},
	}
	cursor, err := coll.Aggregate(ctx, pipeline)
	if err != nil {
		return
	}
	defer cursor.Close(ctx)

	// Decode _id as interface{} so collections whose keys are not ObjectIDs work too.
	var picked []struct {
		ID interface{} `bson:"_id"`
	}
	if err = cursor.All(ctx, &picked); err != nil {
		return
	}
	if len(picked) == 0 {
		return
	}
	ids := make([]interface{}, len(picked))
	for i := range picked {
		ids[i] = picked[i].ID
	}
	glog.Debugf(ctx, "准备更新的随机文档ID: %v", ids)

	// Step 2: flag the picked documents as deleted in a single UpdateMany.
	filter := bson.D{{Key: "_id", Value: bson.D{{Key: "$in", Value: ids}}}}
	update := bson.D{{Key: "$set", Value: bson.D{{Key: "isDeleted", Value: true}}}}
	result, err = coll.UpdateMany(ctx, filter, update, opts...)
	return
}
// SaveOrUpdate upserts a batch of documents. Each pair (filter[i], update[i]) becomes one
// UpdateOne model, and all models are sent in a single unordered BulkWrite.
//
// Each filter is restricted to isDeleted=false and, when the user has one, to its
// tenantId.
//
// Each update is stamped with audit fields:
//   - $set (created when absent) gets updater and updatedAt;
//   - $setOnInsert gets creator and createdAt, so inserted documents carry the same audit
//     fields Insert writes. Fields the caller already sets in $set are not repeated there.
//
// Upsert defaults to true; an Upsert value supplied through opts overrides it.
//
// filter and update must be non-empty and of equal length, every filter[i] must be
// non-empty, and a present $set must be a bson.M. Cleaning the Redis cache afterwards is
// best-effort: failures are logged, not returned.
// NOTE: the filter and update maps are modified in place.
func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) {
	if len(filter) == 0 || len(update) == 0 {
		err = gerror.New("缺少查询条件或更新数据")
		return
	}
	if len(filter) != len(update) {
		err = gerror.New("查询条件和更新数据的数量必须一致")
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	// Resolve opts once (they are the same for every model); only Upsert applies here.
	upsert := true
	for _, opt := range opts {
		var updateOpts options.UpdateManyOptions
		for _, apply := range opt.List() {
			if err = apply(&updateOpts); err != nil {
				return
			}
		}
		if updateOpts.Upsert != nil {
			upsert = *updateOpts.Upsert
		}
	}

	now := gtime.Now().Time
	models := make([]mongo.WriteModel, 0, len(filter))
	for i := range filter {
		// An empty filter would upsert into an arbitrary document of the tenant.
		if len(filter[i]) == 0 {
			err = gerror.New("缺少查询条件")
			return
		}
		filter[i]["isDeleted"] = false
		if !g.IsEmpty(user.TenantId) {
			filter[i]["tenantId"] = user.TenantId
		}
		if update[i] == nil {
			update[i] = bson.M{}
		}

		var setDoc bson.M
		switch v := update[i]["$set"].(type) {
		case nil:
			setDoc = bson.M{}
			update[i]["$set"] = setDoc
		case bson.M:
			setDoc = v
		default:
			// Replacing it would silently drop the caller's data.
			err = gerror.Newf("$set 必须为 bson.M, 实际为 %T", v)
			return
		}
		if !g.IsEmpty(user.UserName) {
			setDoc["updater"] = user.UserName
		}
		setDoc["updatedAt"] = now

		// Creation audit fields for documents the upsert inserts. MongoDB rejects an update
		// that targets the same field from both $set and $setOnInsert, so fields already in
		// $set are skipped. A non-bson.M $setOnInsert is left exactly as the caller wrote it.
		onInsert, isMap := update[i]["$setOnInsert"].(bson.M)
		if update[i]["$setOnInsert"] == nil {
			onInsert, isMap = bson.M{}, true
		}
		if isMap {
			addOnInsert := func(key string, value interface{}) {
				_, inSet := setDoc[key]
				_, present := onInsert[key]
				if !inSet && !present {
					onInsert[key] = value
				}
			}
			addOnInsert("createdAt", now)
			if !g.IsEmpty(user.UserName) {
				addOnInsert("creator", user.UserName)
			}
			if len(onInsert) > 0 {
				update[i]["$setOnInsert"] = onInsert
			}
		}

		models = append(models, mongo.NewUpdateOneModel().
			SetFilter(filter[i]).
			SetUpdate(update[i]).
			SetUpsert(upsert))
	}

	// Unordered execution for throughput.
	result, err = database.Collection(collection).BulkWrite(ctx, models, options.BulkWrite().SetOrdered(false))
	if err != nil {
		return nil, err
	}

	// Best-effort cache invalidation; the write itself already succeeded.
	for _, filterItem := range filter {
		if cleanErr := m.CleanRedis(ctx, filterItem, user.TenantId, collection); cleanErr != nil {
			glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
		}
	}
	return result, nil
}
// Insert inserts documents into collection (InsertMany) and returns the inserted _ids.
//
// Each document is converted to a map with gconv.Map and its "id" key is dropped. Then:
//   - creator/updater are set when the user has a name;
//   - tenantId is set when the user has one;
//   - createdAt and updatedAt are set from one shared timestamp;
//   - isDeleted is set to false.
//
// A document that gconv.Map cannot turn into a map is rejected with an error instead of
// panicking. After a successful insert the tenant's cached lists/counts for the
// collection are cleared via CleanRedis.
func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	now := gtime.Now().Time
	docs := make([]interface{}, 0, len(documents))
	for i, document := range documents {
		doc := gconv.Map(document)
		if doc == nil {
			// Writing audit fields into a nil map would panic.
			err = gerror.Newf("第%d条数据无法转换为文档", i+1)
			return
		}
		delete(doc, "id")
		if !g.IsEmpty(user.UserName) {
			doc["creator"] = user.UserName
			doc["updater"] = user.UserName
		}
		if !g.IsEmpty(user.TenantId) {
			doc["tenantId"] = user.TenantId
		}
		doc["createdAt"] = now
		doc["updatedAt"] = now
		doc["isDeleted"] = false
		docs = append(docs, doc)
	}

	r, err := database.Collection(collection).InsertMany(ctx, docs, opts...)
	if err != nil {
		return
	}
	ids = r.InsertedIDs
	err = m.CleanRedis(ctx, bson.M{}, user.TenantId, collection)
	// Operation logging is currently disabled:
	// m.log(ctx, nil, collection, ids, user.UserName, user.TenantId, "insert")
	return
}
// Count returns the number of live (isDeleted=false) documents of collection that match
// filter. When the user in ctx has a tenant, only that tenant's documents are counted.
//
// With m.Cache enabled the count is read from and written to Redis (one-hour TTL). The
// cache key is built from the filter without tenantId, since the tenant is a separate
// part of the key. Count therefore first removes any tenantId from filter, then adds the
// user's tenantId back for the query itself.
//
// A CountDocuments failure is returned and is not cached.
// NOTE: filter is modified in place.
func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	filter["isDeleted"] = false
	delete(filter, "tenantId")
	filterKey := fmt.Sprintf("%+v", filter)
	redisKey := fmt.Sprintf(redis.Count, user.TenantId, collection, filterKey)
	if m.Cache {
		var resultStr *gvar.Var
		resultStr, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(resultStr) {
			count = gconv.Int64(resultStr)
			return
		}
	}

	// Restore tenant isolation for the query. Without this the count spans every tenant
	// while being cached per tenant.
	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	// CountDocuments runs the count on the server.
	if count, err = database.Collection(collection).CountDocuments(ctx, filter); err != nil {
		return 0, err
	}
	if m.Cache {
		// SetEX takes the TTL in seconds.
		err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour/time.Second))
	}
	return
}
// EntityToBson converts entity (a struct value or a pointer to one) to a bson.M, keeping
// every field. It is EntityToBsonWithFilter with filterEmpty set to false.
func EntityToBson(entity interface{}) (bson.M, error) {
	const keepEmptyFields = false
	return EntityToBsonWithFilter(entity, keepEmptyFields)
}
// EntityToBsonWithFilter converts entity (a struct value or a pointer to one) to a bson.M.
// It marshals the entity with its bson tags and unmarshals the bytes into a map.
//
// When filterEmpty is true, every entry for which isEmptyWithZero reports true is removed:
// nil, empty strings, empty slices/maps and the like. Numbers are always kept, including 0.
//
// After the BSON round trip the map holds no Go pointers. A nil pointer field presumably
// arrives as nil (BSON null) and is removed; a pointer to 0 arrives as a number and is kept.
//
// NOTE(review): g.IsEmpty presumably treats false as empty, so bool fields set to false
// are likely removed as well when filterEmpty is true — confirm that is wanted.
//
// An error is returned for a nil entity (including a typed nil pointer) and when
// marshalling or unmarshalling fails.
func EntityToBsonWithFilter(entity interface{}, filterEmpty bool) (bson.M, error) {
	// A typed nil pointer such as (*T)(nil) is a non-nil interface value, so it needs the
	// reflect check to get the same clear error as a plain nil.
	if rv := reflect.ValueOf(entity); entity == nil || (rv.Kind() == reflect.Ptr && rv.IsNil()) {
		return nil, fmt.Errorf("传入的 entity 实例为 nil")
	}

	// bson.Marshal accepts values and pointers and honours the struct's bson tags.
	bsonBytes, err := bson.Marshal(entity)
	if err != nil {
		return nil, fmt.Errorf("entity 序列化为 BSON 字节流失败:%w", err)
	}

	var bsonMap bson.M
	if err = bson.Unmarshal(bsonBytes, &bsonMap); err != nil {
		return nil, fmt.Errorf("BSON 字节流反序列化为 bson.M 失败:%w", err)
	}

	if filterEmpty {
		// Deleting the current key while ranging over a map is safe in Go.
		for key, value := range bsonMap {
			if isEmptyWithZero(value) {
				delete(bsonMap, key)
			}
		}
	}
	return bsonMap, nil
}
// isEmptyWithZero reports whether value should count as "not provided" when filtering a
// document, while never treating a number as empty:
//   - nil, and a nil pointer (e.g. *int(nil), *consts.AssetStatus(nil)) → true (dropped)
//   - any int/uint/float kind, or a non-nil pointer to one, including 0 → false (kept)
//   - anything else is decided by g.IsEmpty on the original value
func isEmptyWithZero(value interface{}) bool {
	if value == nil {
		return true
	}
	v := reflect.ValueOf(value)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return true
		}
		// Classify by the pointed-to value's kind.
		v = v.Elem()
	}
	switch v.Kind() {
	case reflect.Float32, reflect.Float64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return false
	}
	return g.IsEmpty(value)
}