// Source: common/mongo/mongo.go (captured from a repository web viewer).
// Viewer metadata: "700 lines, 19 KiB, Go"; the viewer also warned that the
// file contains ambiguous Unicode characters (e.g. the ✅ in a log message).
// Kept only as a comment so the file remains valid Go.

package mongo
import (
"context"
"errors"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"strings"
"sync"
"time"
"gitee.com/red-future---jilin-g/common/consts"
"gitee.com/red-future---jilin-g/common/do"
"gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// Connection state shared by the whole package. GetDB / IsConnected read it
// under mu.RLock; connect and Close update it under mu.Lock.
var (
	mu          sync.RWMutex
	client      *mongo.Client
	db          *mongo.Database
	isConnected bool
)

// Settings derived from the mongo.address config entry in init, before any
// background goroutine is started.
var (
	mongoAddr string
	dbName    string
)

// Lifetime of the background health-check goroutine; healthCancel is invoked
// by Close.
var (
	healthCtx    context.Context
	healthCancel context.CancelFunc
)
// GetDB returns the current MongoDB database handle, or nil while no
// connection has been established yet (mongo.address not configured, or the
// initial connect has not succeeded).
func GetDB() *mongo.Database {
	mu.RLock()
	current := db
	mu.RUnlock()
	return current
}
// IsConnected reports whether the package currently considers MongoDB
// reachable: set by a successful connect, cleared by a failed connect, a
// failed health-check ping, or Close.
func IsConnected() bool {
	mu.RLock()
	connected := isConnected
	mu.RUnlock()
	return connected
}
// connect 建立MongoDB连接
func connect() error {
mu.Lock()
defer mu.Unlock()
if client != nil {
client.Disconnect(context.Background())
}
// 创建连接选项
opt := options.Client().
ApplyURI(mongoAddr).
SetMaxPoolSize(100).
SetMinPoolSize(10).
SetMaxConnecting(10).
SetConnectTimeout(10 * time.Second)
var err error
client, err = mongo.Connect(opt)
if err != nil {
isConnected = false
glog.Error(context.Background(), "MongoDB连接失败", err)
return err
}
// 测试连接
testCtx, testCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer testCancel()
err = client.Ping(testCtx, nil)
if err != nil {
isConnected = false
glog.Error(testCtx, "MongoDB连接测试失败", err)
return err
}
db = client.Database(dbName)
isConnected = true
glog.Info(context.Background(), "✅ MongoDB连接成功")
return nil
}
// healthCheck runs until healthCtx is cancelled. Every 30 seconds it either
// triggers reconnect (when the package is marked disconnected or has no
// client) or pings the current client; a failed ping marks the package
// disconnected and then triggers reconnect.
func healthCheck() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-healthCtx.Done():
			return
		case <-ticker.C:
		}

		// Snapshot the shared state; the ping itself runs without the lock.
		mu.RLock()
		connected, c := isConnected, client
		mu.RUnlock()

		if connected && c != nil {
			pingCtx, cancelPing := context.WithTimeout(context.Background(), 5*time.Second)
			pingErr := c.Ping(pingCtx, nil)
			cancelPing()
			if pingErr == nil {
				glog.Debug(context.Background(), "MongoDB连接健康检查通过")
				continue
			}
			mu.Lock()
			isConnected = false
			mu.Unlock()
			glog.Warning(context.Background(), "MongoDB连接健康检查失败", pingErr)
		} else {
			glog.Warning(context.Background(), "MongoDB连接断开尝试重连")
		}

		if err := reconnect(); err != nil {
			glog.Error(context.Background(), "MongoDB重连失败", err)
		}
	}
}
// reconnect 重连函数
func reconnect() error {
maxRetries := 3
retryDelay := 2 * time.Second
for i := 0; i < maxRetries; i++ {
glog.Info(context.Background(), fmt.Sprintf("尝试第%d次重连MongoDB", i+1))
if err := connect(); err == nil {
glog.Info(context.Background(), "MongoDB重连成功")
return nil
}
if i < maxRetries-1 {
time.Sleep(retryDelay)
retryDelay *= 2 // 指数退避
}
}
return gerror.New("MongoDB重连失败已达到最大重试次数")
}
// init 初始化MongoDB连接
func init() {
// 按需初始化:没有配置 mongo.address 则跳过
mongoAddr = g.Cfg().MustGet(context.Background(), "mongo.address").String()
if mongoAddr == "" {
return
}
// 创建健康检查上下文
healthCtx, healthCancel = context.WithCancel(context.Background())
// 从连接串中解析数据库名
dbName = gstr.SubStr(mongoAddr, strings.LastIndex(mongoAddr, "/")+1, len(mongoAddr))
// 如果连接串带有参数(如 ?retryWrites=true需要去掉参数部分
if strings.Contains(dbName, "?") {
dbName = gstr.SubStr(dbName, 0, strings.Index(dbName, "?"))
}
go func() {
// 初始连接
if err := connect(); err != nil {
glog.Error(context.Background(), "MongoDB初始连接失败", err)
return
}
}()
// 启动健康检查协程
go healthCheck()
}
// Close stops the health-check goroutine, disconnects the current client
// (with a 5s timeout) and marks the package disconnected.
//
// healthCancel is called before taking mu so that a concurrent connect, which
// checks healthCtx under mu, either sees the cancellation or installs its
// client before this function disconnects it. The client reference is dropped
// after disconnecting, so calling Close twice is harmless. Disconnect errors
// are logged rather than silently ignored.
func Close() {
	if healthCancel != nil {
		healthCancel()
	}

	mu.Lock()
	defer mu.Unlock()
	if client != nil {
		disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer disconnectCancel()
		if err := client.Disconnect(disconnectCtx); err != nil {
			glog.Warning(context.Background(), "MongoDB断开连接失败", err)
		}
		client = nil
	}
	isConnected = false
	glog.Info(context.Background(), "MongoDB连接已关闭")
}
// listOptionsToMap collects the result-shaping fields of the given Find
// options (limit, skip, sort, projection) into a map passed through
// utils.OrderMap. Find uses the result as part of its list cache key. When
// several options set the same field, the later option wins. ctx is unused.
func listOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOptions]) (m map[string]interface{}) {
	fields := make(map[string]interface{})
	for _, lister := range opts {
		var fo options.FindOptions
		for _, apply := range lister.List() {
			apply(&fo)
		}
		if v := fo.Limit; v != nil {
			fields["limit"] = *v
		}
		if v := fo.Skip; v != nil {
			fields["skip"] = *v
		}
		if fo.Sort != nil {
			fields["sort"] = fo.Sort
		}
		if fo.Projection != nil {
			fields["projection"] = fo.Projection
		}
	}
	return utils.OrderMap(fields)
}
// oneOptionsToMap is the FindOne counterpart of listOptionsToMap: it collects
// skip, sort and projection from the given options into a map passed through
// utils.OrderMap, later options overriding earlier ones. ctx is unused.
//
// NOTE(review): FindOne in this file does not call this, so its cache key
// ignores these options; confirm whether that is intended.
func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOneOptions]) (m map[string]interface{}) {
	fields := make(map[string]interface{})
	for _, lister := range opts {
		var fo options.FindOneOptions
		for _, apply := range lister.List() {
			apply(&fo)
		}
		if v := fo.Skip; v != nil {
			fields["skip"] = *v
		}
		if fo.Sort != nil {
			fields["sort"] = fo.Sort
		}
		if fo.Projection != nil {
			fields["projection"] = fo.Projection
		}
	}
	return utils.OrderMap(fields)
}
// GetTenantInfo resolves the tenant for the current request.
//
// It first tries the login token (utils.GetUserInfo). If that fails it looks
// for an account name, in this order:
//   - HTTP request parameters accountName, account_name, customerServiceId,
//     customer_service_id;
//   - ctx values "accountName", then "customerServiceId" (non-HTTP callers
//     such as WebSocket handlers).
//
// It then reads the tenantId of the non-deleted account from the
// customer_service_account collection. On that fallback path it sets
// user.TenantId and user.UserName (= the account name).
//
// The database handle is obtained through GetDB (under the package lock); an
// error is returned instead of panicking when MongoDB is not connected.
//
// NOTE(review): the fallback trusts a caller-supplied accountName, so any
// request without a token that knows an account name is attributed to that
// account's tenant. Confirm this exposure is intended.
func GetTenantInfo(ctx context.Context) (user do.User, err error) {
	// 1. Prefer the token.
	user, err = utils.GetUserInfo(ctx)
	if err == nil {
		return
	}

	// 2. Fall back to an account name from the HTTP request...
	var accountName string
	if req := g.RequestFromCtx(ctx); req != nil {
		// The customerServiceId spellings are legacy parameter names.
		for _, key := range []string{"accountName", "account_name", "customerServiceId", "customer_service_id"} {
			if accountName = req.Get(key).String(); accountName != "" {
				break
			}
		}
	}
	// ...or from ctx values (no HTTP request, e.g. WebSocket).
	if accountName == "" {
		for _, key := range []string{"accountName", "customerServiceId"} {
			if str, ok := ctx.Value(key).(string); ok && str != "" {
				accountName = str
				break
			}
		}
	}
	if accountName == "" {
		return user, gerror.New("无法获取租户信息:无 token 且无 accountName 参数")
	}

	database := GetDB()
	if database == nil {
		return user, gerror.New("MongoDB未连接,无法查询租户信息")
	}

	// 3. Look up the tenant of the (non-deleted) account.
	filter := bson.M{"accountName": accountName, "isDeleted": false}
	var account struct {
		TenantId interface{} `bson:"tenantId"`
	}
	if findErr := database.Collection("customer_service_account").FindOne(ctx, filter).Decode(&account); findErr != nil {
		return user, gerror.Newf("通过 accountName 查询租户失败: %v", findErr)
	}
	user.TenantId = account.TenantId
	user.UserName = accountName
	// Clear the error left over from the token attempt.
	return user, nil
}
// Find loads every document of collection that matches filter into result
// (validated with utils.ValidStructPtr). The query is scoped to the caller's
// tenant and excludes soft-deleted documents.
//
// Unless NoCache is set, results are cached in Redis under a consts.List key
// built from the tenant, collection, filter and limit/skip/sort/projection
// options, and later calls are served from that entry. The caller's filter is
// not modified, and a nil filter matches all of the tenant's documents.
func Find(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) {
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	// Work on a copy so isDeleted/tenantId do not leak into the caller's map,
	// and so a nil filter does not panic.
	query := make(bson.M, len(filter)+2)
	for k, v := range filter {
		query[k] = v
	}
	query["isDeleted"] = false
	filterMap := utils.OrderMap(query)
	optsMap := listOptionsToMap(ctx, opts...)

	redisKey := ""
	if !NoCache {
		redisKey = fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap))
		var cached *gvar.Var
		if cached, err = redis.RedisClient.Get(ctx, redisKey); err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			return gconv.Scan(cached, result)
		}
	}

	database := GetDB()
	if database == nil {
		return gerror.New("MongoDB未连接")
	}
	query["tenantId"] = user.TenantId
	cur, err := database.Collection(collection).Find(ctx, query, opts...)
	if err != nil {
		return
	}
	// All drains and closes the cursor. Stop here on error so that a partial
	// result is never cached.
	if err = cur.All(ctx, result); err != nil {
		return
	}
	if !NoCache {
		// SetEX takes its TTL in seconds (gredis convention; confirm against
		// redis.RedisClient). int64(time.Hour) would be 3.6e12 "seconds".
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return
}
// FindOne decodes the first non-deleted document of collection that matches
// filter into result (validated with utils.ValidStructPtr). The query is
// scoped to the caller's tenant when one is known. An empty filter is
// rejected. When nothing matches, result is left untouched and nil is returned.
//
// Unless NoCache is set, a found document is cached in Redis under a
// consts.One key built from tenant, collection and filter. The caller's
// filter is not modified.
//
// NOTE(review): the cache key does not include opts (sort/projection/skip),
// so calls that differ only in options share one entry. It is kept this way
// so that cleanRedis can still compute the same key when invalidating.
func FindOne(ctx context.Context, NoCache bool, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	query := make(bson.M, len(filter)+2)
	for k, v := range filter {
		query[k] = v
	}
	query["isDeleted"] = false
	filterMap := utils.OrderMap(query)

	redisKey := ""
	if !NoCache {
		// Plain assignment: the original used ":=" here, which shadowed redisKey,
		// so the value was later written under the empty key "" and never read back.
		redisKey = fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap))
		var cached *gvar.Var
		if cached, err = redis.RedisClient.Get(ctx, redisKey); err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			return gconv.Scan(cached, result)
		}
	}

	database := GetDB()
	if database == nil {
		return gerror.New("MongoDB未连接")
	}
	if !g.IsEmpty(user.TenantId) {
		query["tenantId"] = user.TenantId
	}
	err = database.Collection(collection).FindOne(ctx, query, opts...).Decode(result)
	if errors.Is(err, mongo.ErrNoDocuments) {
		// Not found is not an error. Misses are not cached, so a later insert
		// is not hidden behind a cached empty result.
		return nil
	}
	if err != nil {
		return
	}
	if !NoCache {
		// SetEX takes its TTL in seconds (gredis convention; confirm against
		// redis.RedisClient).
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return
}
// cleanRedis invalidates the cached query results for collection of the
// given tenant after a write:
//   - every list key matching the consts.CleanList pattern,
//   - every count key matching the consts.CleanCount pattern,
//   - the single FindOne key computed from filter (with isDeleted=false and
//     without tenantId, mirroring how FindOne builds its key).
//
// Only a FindOne entry whose filter equals this one is removed. filter is
// read, not modified (the original added isDeleted and removed tenantId from
// the caller's map).
//
// NOTE(review): Redis KEYS is O(N) over the keyspace and blocks the server;
// consider SCAN if the cache grows large.
func cleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) {
	// deleteByPattern removes every key matching pattern, one DEL per key.
	deleteByPattern := func(pattern string) error {
		keys, keysErr := redis.RedisClient.Keys(ctx, pattern)
		if keysErr != nil {
			return keysErr
		}
		for _, key := range keys {
			if _, delErr := redis.RedisClient.Del(ctx, key); delErr != nil {
				return delErr
			}
		}
		return nil
	}

	if err = deleteByPattern(fmt.Sprintf(consts.CleanList, tenantId, collection)); err != nil {
		return
	}
	if err = deleteByPattern(fmt.Sprintf(consts.CleanCount, tenantId, collection)); err != nil {
		return
	}

	oneFilter := make(bson.M, len(filter)+1)
	for k, v := range filter {
		if k == "tenantId" {
			continue
		}
		oneFilter[k] = v
	}
	oneFilter["isDeleted"] = false
	oneKey := fmt.Sprintf(consts.One, tenantId, collection, gconv.String(utils.OrderMap(oneFilter)))
	_, err = redis.RedisClient.Del(ctx, oneKey)
	return
}
// Delete permanently removes (DeleteMany, not a soft delete) every document
// of the caller's tenant in collection that matches filter, then invalidates
// the tenant's cached queries for the collection. An empty filter is rejected.
// It returns the number of deleted documents; if cache invalidation fails, the
// count is still returned together with that error. The caller's filter is
// not modified.
func Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}
	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}

	query := make(bson.M, len(filter)+1)
	for k, v := range filter {
		query[k] = v
	}
	query["tenantId"] = user.TenantId

	r, err := database.Collection(collection).DeleteMany(ctx, query, opts...)
	if err != nil {
		return
	}
	count = r.DeletedCount
	err = cleanRedis(ctx, query, user.TenantId, collection)
	return
}
// Update applies update to every non-deleted document in collection that
// matches filter. The filter is scoped to the caller's tenant when one is
// known. An empty filter is rejected.
//
// The $set stage is extended with updatedAt (now) and, when known, updater
// (the caller's user name); a $set is created if update has none. $set may be
// a bson.M, a map[string]interface{} or a bson.D; any other type returns an
// error. All other operators in update ($inc, $unset, $push, ...) are passed
// through unchanged, whereas the original rebuilt update as {"$set": ...} and
// dropped them. Neither filter nor update is modified. On success the
// tenant's cached queries for the collection are invalidated.
func Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	query := make(bson.M, len(filter)+2)
	for k, v := range filter {
		query[k] = v
	}
	query["isDeleted"] = false
	if !g.IsEmpty(user.TenantId) {
		query["tenantId"] = user.TenantId
	}

	// Copy $set instead of asserting bson.M (which panicked on absence or any
	// other map type) and instead of writing into the caller's document.
	setDoc := bson.M{}
	switch s := update["$set"].(type) {
	case nil:
		// No $set supplied: one is created for the audit fields.
	case bson.M:
		for k, v := range s {
			setDoc[k] = v
		}
	case map[string]interface{}:
		for k, v := range s {
			setDoc[k] = v
		}
	case bson.D:
		for _, e := range s {
			setDoc[e.Key] = e.Value
		}
	default:
		err = gerror.Newf("不支持的$set类型: %T", s)
		return
	}
	if !g.IsEmpty(user.UserName) {
		setDoc["updater"] = user.UserName
	}
	setDoc["updatedAt"] = gtime.Now().Time

	doc := make(bson.M, len(update)+1)
	for k, v := range update {
		doc[k] = v
	}
	doc["$set"] = setDoc

	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	result, err = database.Collection(collection).UpdateMany(ctx, query, doc, opts...)
	if err != nil {
		return
	}
	err = cleanRedis(ctx, query, user.TenantId, collection)
	return
}
// RandomSoftDelete picks up to limit random non-deleted documents of
// collection and soft-deletes them by setting isDeleted=true. It returns the
// UpdateMany result, which the original always returned as nil. opts are
// forwarded to UpdateMany; the original ignored them. A non-positive limit is
// a no-op returning (nil, nil). It also returns (nil, nil) when no document
// is selected.
//
// Selection: $match on isDeleted=false first, then tag each document with
// $rand, sort by it and keep the first limit _ids. $match comes first so the
// random field is only computed for candidates. _id values of any BSON type
// are supported, whereas the original asserted ObjectID and panicked otherwise.
//
// NOTE(review): unlike the other helpers this is not scoped to the caller's
// tenant and does not invalidate the Redis query cache; confirm this is
// intended (e.g. an admin/test-only tool).
func RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	if limit <= 0 {
		return nil, nil
	}
	database := GetDB()
	if database == nil {
		return nil, gerror.New("MongoDB未连接")
	}
	coll := database.Collection(collection)

	// Step 1: choose random _ids.
	pipeline := mongo.Pipeline{
		{{Key: "$match", Value: bson.D{{Key: "isDeleted", Value: false}}}},
		{{Key: "$addFields", Value: bson.D{{Key: "random", Value: bson.M{"$rand": bson.M{}}}}}},
		{{Key: "$sort", Value: bson.D{{Key: "random", Value: -1}}}},
		{{Key: "$limit", Value: limit}},
		{{Key: "$project", Value: bson.D{{Key: "_id", Value: 1}}}},
	}
	cursor, err := coll.Aggregate(ctx, pipeline)
	if err != nil {
		return
	}
	var picked []struct {
		ID interface{} `bson:"_id"`
	}
	// All decodes every document and closes the cursor.
	if err = cursor.All(ctx, &picked); err != nil {
		return nil, err
	}

	// Step 2: collect the ids.
	ids := make([]interface{}, 0, len(picked))
	for _, p := range picked {
		ids = append(ids, p.ID)
	}
	glog.Debugf(ctx, "准备更新的随机文档ID: %v", ids)
	if len(ids) == 0 {
		return nil, nil
	}

	// Step 3: soft-delete them in one UpdateMany.
	filter := bson.D{{Key: "_id", Value: bson.D{{Key: "$in", Value: ids}}}}
	update := bson.D{{Key: "$set", Value: bson.D{{Key: "isDeleted", Value: true}}}}
	return coll.UpdateMany(ctx, filter, update, opts...)
}
// SaveOrUpdate applies update[i] to the first document matching filter[i]
// for every i, as a single unordered BulkWrite of UpdateOne models. filter
// and update must be non-empty and of equal length.
//
// Each filter is restricted to non-deleted documents and, when known, the
// caller's tenant. Each update's $set is extended with updatedAt and, when
// known, updater; a $set is created if missing. $set may be a bson.M, a
// map[string]interface{} or a bson.D. Previously any non-bson.M $set was
// silently replaced, losing its fields; an unsupported type now returns an
// error. Other operators pass through unchanged.
//
// Upsert defaults to true, so missing documents are inserted. The last
// option in opts that sets Upsert overrides it; other UpdateManyOptions
// fields are ignored. Neither the caller's filters nor updates are modified.
// After a successful write the tenant's cached queries are invalidated;
// invalidation failures are only logged.
func SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) {
	if len(filter) == 0 || len(update) == 0 {
		err = gerror.New("缺少查询条件或更新数据")
		return
	}
	if len(filter) != len(update) {
		err = gerror.New("查询条件和更新数据的数量必须一致")
		return
	}
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	// Resolve the upsert flag once instead of re-parsing opts for every model.
	upsert := true
	for _, opt := range opts {
		var updateOpts options.UpdateManyOptions
		for _, fn := range opt.List() {
			fn(&updateOpts)
		}
		if updateOpts.Upsert != nil {
			upsert = *updateOpts.Upsert
		}
	}

	now := gtime.Now().Time
	queries := make([]bson.M, len(filter))
	models := make([]mongo.WriteModel, 0, len(filter))
	for i := range filter {
		query := make(bson.M, len(filter[i])+2)
		for k, v := range filter[i] {
			query[k] = v
		}
		query["isDeleted"] = false
		if !g.IsEmpty(user.TenantId) {
			query["tenantId"] = user.TenantId
		}
		queries[i] = query

		setDoc := bson.M{}
		switch s := update[i]["$set"].(type) {
		case nil:
			// No $set supplied: one is created for the audit fields.
		case bson.M:
			for k, v := range s {
				setDoc[k] = v
			}
		case map[string]interface{}:
			for k, v := range s {
				setDoc[k] = v
			}
		case bson.D:
			for _, e := range s {
				setDoc[e.Key] = e.Value
			}
		default:
			return nil, gerror.Newf("第%d条更新数据的$set类型不支持: %T", i, s)
		}
		if !g.IsEmpty(user.UserName) {
			setDoc["updater"] = user.UserName
		}
		setDoc["updatedAt"] = now

		doc := make(bson.M, len(update[i])+1)
		for k, v := range update[i] {
			doc[k] = v
		}
		doc["$set"] = setDoc

		models = append(models, mongo.NewUpdateOneModel().
			SetFilter(query).
			SetUpdate(doc).
			SetUpsert(upsert))
	}

	database := GetDB()
	if database == nil {
		return nil, gerror.New("MongoDB未连接")
	}
	// Unordered: the server may apply the models in any order, which is faster.
	bulkResult, err := database.Collection(collection).BulkWrite(ctx, models, options.BulkWrite().SetOrdered(false))
	if err != nil {
		return nil, err
	}

	for _, query := range queries {
		if cleanErr := cleanRedis(ctx, query, user.TenantId, collection); cleanErr != nil {
			glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
		}
	}
	return bulkResult, nil
}
// Insert converts each document to a map with gconv.Map and drops any "id"
// key. It then stamps each document with:
//   - creator and updater (the caller's user name, when known),
//   - tenantId (when known),
//   - createdAt and updatedAt (one shared timestamp, so both are equal on a
//     fresh document),
//   - isDeleted=false.
//
// The documents are inserted with InsertMany, the tenant's cached queries for
// the collection are invalidated, and the inserted ids are returned. A
// document that cannot be converted to a map yields an error instead of a
// nil-map panic.
func Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	now := gtime.Now().Time
	docs := make([]interface{}, 0, len(documents))
	for i, document := range documents {
		doc := gconv.Map(document)
		if doc == nil {
			err = gerror.Newf("第%d条文档无法转换为map", i)
			return
		}
		delete(doc, "id")
		if !g.IsEmpty(user.UserName) {
			doc["creator"] = user.UserName
			doc["updater"] = user.UserName
		}
		if !g.IsEmpty(user.TenantId) {
			doc["tenantId"] = user.TenantId
		}
		doc["createdAt"] = now
		doc["updatedAt"] = now
		doc["isDeleted"] = false
		docs = append(docs, doc)
	}

	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	r, err := database.Collection(collection).InsertMany(ctx, docs, opts...)
	if err != nil {
		return
	}
	ids = r.InsertedIDs
	err = cleanRedis(ctx, bson.M{}, user.TenantId, collection)
	return
}
// Count returns the number of non-deleted documents in collection that match
// filter and belong to the caller's tenant.
//
// Unless NoCache is set, the count is cached in Redis under a consts.Count
// key built from tenant, collection and filter. The caller's filter is not
// modified, and a nil filter counts all of the tenant's documents.
func Count(ctx context.Context, NoCache bool, filter bson.M, collection string) (count int64, err error) {
	user, err := GetTenantInfo(ctx)
	if err != nil {
		return
	}

	query := make(bson.M, len(filter)+2)
	for k, v := range filter {
		query[k] = v
	}
	query["isDeleted"] = false
	filterMap := utils.OrderMap(query)

	redisKey := ""
	if !NoCache {
		redisKey = fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap))
		var cached *gvar.Var
		if cached, err = redis.RedisClient.Get(ctx, redisKey); err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			count = gconv.Int64(cached)
			return
		}
	}

	database := GetDB()
	if database == nil {
		err = gerror.New("MongoDB未连接")
		return
	}
	// Scope to the tenant as Find does. The original counted every tenant's
	// documents while caching the result under this tenant's key.
	query["tenantId"] = user.TenantId
	// CountDocuments runs on the server. On error, return before caching a
	// bogus 0.
	if count, err = database.Collection(collection).CountDocuments(ctx, query); err != nil {
		return
	}
	if !NoCache {
		// SetEX takes its TTL in seconds (gredis convention; confirm against
		// redis.RedisClient).
		err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour/time.Second))
	}
	return
}
// EntityToBSONM converts an entity (a struct value or a pointer to one) into
// a bson.M by round-tripping it through bson.Marshal and bson.Unmarshal, so
// the entity's bson struct tags determine the resulting keys. A nil entity or
// any (un)marshalling failure yields a nil map and an error; marshalling
// errors are wrapped with %w.
func EntityToBSONM(entity interface{}) (bson.M, error) {
	if entity == nil {
		return nil, fmt.Errorf("传入的 entity 实例为 nil")
	}

	raw, marshalErr := bson.Marshal(entity)
	if marshalErr != nil {
		return nil, fmt.Errorf("entity 序列化为 BSON 字节流失败:%w", marshalErr)
	}

	var out bson.M
	if unmarshalErr := bson.Unmarshal(raw, &out); unmarshalErr != nil {
		return nil, fmt.Errorf("BSON 字节流反序列化为 bson.M 失败:%w", unmarshalErr)
	}
	return out, nil
}