// The following lines were captured from the repository web viewer and are
// kept as comments so that the file compiles:
//   Files: common/mongo/mongo.go (1000 lines, 26 KiB, Go) — Raw | Blame | History
//   Warning: this file contains ambiguous Unicode characters, i.e. characters
//   that might be confused with others (e.g. full-width punctuation in string
//   literals). If this is intentional, the warning can be safely ignored.
// =============================================================================
// MongoDB 多数据源支持
// 支持多数据源配置、自动重连、优雅关闭
// 向后兼容原有的单数据源API
// =============================================================================
package mongo
import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"reflect"
"strings"
"sync"
"syscall"
"time"
"gitee.com/red-future---jilin-g/common/beans"
"gitee.com/red-future---jilin-g/common/log/model/entity"
"gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// =============================================================================
// 数据源配置结构
// =============================================================================
// DataSourceConfig describes one MongoDB data source, typically built by
// InitializeFromConfig from an entry under the "mongo" config key.
type DataSourceConfig struct {
	// Name identifies the data source; "default" is the one used by DB()/GetDB().
	Name string `json:"name"`
	// Address is the connection URI handed to options.Client().ApplyURI.
	Address string `json:"address"`
	// Database is the database name; anything from the first "?" on is stripped
	// before use.
	Database string `json:"database"`
	// MaxPoolSize and MinPoolSize bound the driver's connection pool.
	MaxPoolSize int32 `json:"maxPoolSize"`
	MinPoolSize int32 `json:"minPoolSize"`
	// ConnectTimeout is passed to the driver's SetConnectTimeout.
	ConnectTimeout time.Duration `json:"connectTimeout"`
}
// =============================================================================
// 单个数据源接口
// =============================================================================
// DataSource is one named MongoDB connection managed by DataSourceManager.
// RegisterDataSource stores *BaseDataSource values behind this interface.
type DataSource interface {
	// Name returns the configured data source name.
	Name() string
	// Database returns the database handle for the configured database.
	Database() *mongo.Database
	// Client returns the underlying driver client.
	Client() *mongo.Client
	// IsConnected reports whether the data source considers itself connected.
	IsConnected() bool
	// Connect establishes (or re-establishes) the connection.
	Connect(ctx context.Context) error
	// Reconnect re-establishes the connection; called by the health check.
	Reconnect(ctx context.Context) error
	// Close disconnects the client.
	Close(ctx context.Context) error
}
// =============================================================================
// 数据源实现
// =============================================================================
// BaseDataSource is the DataSource implementation created by
// NewBaseDataSource. Connect/Close write its fields under mu and the
// accessors read them under mu.RLock.
type BaseDataSource struct {
	config      *DataSourceConfig // connection settings; not modified by this type
	client      *mongo.Client     // driver client; nil until Connect has run
	database    *mongo.Database   // handle for config.Database (query suffix stripped)
	isConnected bool              // true after a successful Connect; cleared by a failed Connect or Close
	mu          sync.RWMutex
	// lastError/lastErrorTime record the most recent Connect failure.
	// NOTE(review): nothing in this file reads them.
	lastError     error
	lastErrorTime time.Time
}
// NewBaseDataSource wraps config in a BaseDataSource that has not connected
// yet; Connect must be called before the database handle is usable.
func NewBaseDataSource(config *DataSourceConfig) *BaseDataSource {
	ds := new(BaseDataSource)
	ds.config = config
	return ds
}
// Name returns the name from the data source's configuration.
func (d *BaseDataSource) Name() string { return d.config.Name }

// Database returns the current database handle. It is nil until the first
// successful Connect.
func (d *BaseDataSource) Database() *mongo.Database {
	d.mu.RLock()
	db := d.database
	d.mu.RUnlock()
	return db
}

// Client returns the current driver client (nil before Connect has run).
func (d *BaseDataSource) Client() *mongo.Client {
	d.mu.RLock()
	c := d.client
	d.mu.RUnlock()
	return c
}

// IsConnected reports whether a client exists and the last Connect succeeded.
func (d *BaseDataSource) IsConnected() bool {
	d.mu.RLock()
	defer d.mu.RUnlock()
	if d.client == nil {
		return false
	}
	return d.isConnected
}
// Connect creates a driver client for this data source and verifies it with a
// ping (5s timeout, further bounded by ctx). On success it becomes the active
// client and database.
//
// The previous client, if any, is disconnected only after the new one has
// been verified. A failed attempt therefore never leaves the source holding an
// already-closed client. On failure the freshly created client is
// disconnected again, because otherwise its background monitoring goroutines
// and sockets would leak on every retry. The source is then marked
// disconnected and the error is recorded in lastError/lastErrorTime.
func (d *BaseDataSource) Connect(ctx context.Context) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	// The configured name may carry a URI-style query suffix ("app?authSource=admin").
	dbName := d.config.Database
	if idx := strings.Index(dbName, "?"); idx >= 0 {
		dbName = gstr.SubStr(dbName, 0, idx)
	}

	// Negative sizes would wrap around to enormous uint64 values; treat them
	// as 0 (the driver's "no limit" / "no minimum").
	var maxPool, minPool uint64
	if d.config.MaxPoolSize > 0 {
		maxPool = uint64(d.config.MaxPoolSize)
	}
	if d.config.MinPoolSize > 0 {
		minPool = uint64(d.config.MinPoolSize)
	}

	opt := options.Client().
		ApplyURI(d.config.Address).
		SetMaxPoolSize(maxPool).
		SetMinPoolSize(minPool).
		SetConnectTimeout(d.config.ConnectTimeout).
		SetMaxConnecting(10).
		SetServerSelectionTimeout(10 * time.Second).
		SetHeartbeatInterval(10 * time.Second).
		SetMaxConnIdleTime(60 * time.Second).
		SetRetryWrites(true).
		SetRetryReads(true)

	client, err := mongo.Connect(opt)
	if err != nil {
		d.isConnected = false
		d.lastError = err
		d.lastErrorTime = time.Now()
		return fmt.Errorf("datasource [%s] connection failed: %w", d.config.Name, err)
	}

	pingCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	if err = client.Ping(pingCtx, nil); err != nil {
		// Release the unverified client. Use a fresh, bounded context because
		// ctx/pingCtx may already be expired at this point.
		discardCtx, discardCancel := context.WithTimeout(context.Background(), 5*time.Second)
		_ = client.Disconnect(discardCtx) // best effort; the ping error is what gets reported
		discardCancel()

		d.isConnected = false
		d.lastError = err
		d.lastErrorTime = time.Now()
		return fmt.Errorf("datasource [%s] ping failed: %w", d.config.Name, err)
	}

	previous := d.client
	d.client = client
	d.database = client.Database(dbName)
	d.isConnected = true
	d.lastError = nil

	if previous != nil {
		closeCtx, closeCancel := context.WithTimeout(ctx, 5*time.Second)
		if derr := previous.Disconnect(closeCtx); derr != nil {
			glog.Warningf(ctx, "datasource [%s] failed to disconnect previous client: %v", d.config.Name, derr)
		}
		closeCancel()
	}

	glog.Infof(ctx, "✅ datasource [%s] connected successfully", d.config.Name)
	return nil
}
// Reconnect logs the attempt and then delegates entirely to Connect.
func (d *BaseDataSource) Reconnect(ctx context.Context) error {
	name := d.config.Name
	glog.Infof(ctx, "🔄 reconnecting datasource [%s]", name)
	return d.Connect(ctx)
}
// Close disconnects the client, waiting at most 5s (further bounded by ctx),
// and marks the data source as disconnected. If the disconnect fails, the
// wrapped error is returned and the connected flag is left as it was.
func (d *BaseDataSource) Close(ctx context.Context) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	if client := d.client; client != nil {
		closeCtx, cancelClose := context.WithTimeout(ctx, 5*time.Second)
		defer cancelClose()
		if err := client.Disconnect(closeCtx); err != nil {
			return fmt.Errorf("datasource [%s] close failed: %w", d.config.Name, err)
		}
	}

	d.isConnected = false
	glog.Infof(ctx, "datasource [%s] closed", d.config.Name)
	return nil
}
// =============================================================================
// 多数据源管理器
// =============================================================================
// DataSourceManager owns the registered data sources and the background
// health-check loop. GetManager returns the process-wide instance.
type DataSourceManager struct {
	sources map[string]DataSource // keyed by DataSourceConfig.Name; guarded by mu
	mu      sync.RWMutex
	// ctx is cancelled (via cancel) by CloseAll, which stops healthCheckLoop.
	// NOTE(review): storing a context in a struct is unidiomatic; kept for compatibility.
	ctx        context.Context
	cancel     context.CancelFunc
	started    bool // set once StartHealthCheck has launched the loop
	maxRetries int  // initialised to 3 by GetManager; NOTE(review): not read anywhere in this file
}
var (
	globalManager *DataSourceManager // process-wide manager, created lazily by GetManager
	managerOnce   sync.Once          // guards the one-time creation of globalManager
)
// GetManager returns the process-wide DataSourceManager, creating it on first
// use. A new manager has no data sources, health checking not started, and
// maxRetries set to 3.
func GetManager() *DataSourceManager {
	managerOnce.Do(func() {
		rootCtx, stop := context.WithCancel(context.Background())
		globalManager = &DataSourceManager{
			sources:    map[string]DataSource{},
			ctx:        rootCtx,
			cancel:     stop,
			maxRetries: 3,
		}
	})
	return globalManager
}
// RegisterDataSource adds a new, not yet connected data source built from
// config. It returns an error if config is nil (which previously caused a
// nil-pointer panic) or if a data source with the same name already exists.
func (m *DataSourceManager) RegisterDataSource(config *DataSourceConfig) error {
	if config == nil {
		return errors.New("datasource config is nil")
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if _, exists := m.sources[config.Name]; exists {
		return fmt.Errorf("datasource [%s] already exists", config.Name)
	}
	m.sources[config.Name] = NewBaseDataSource(config)
	return nil
}
// GetDataSource looks up a registered data source by name and returns an
// error if there is none.
func (m *DataSourceManager) GetDataSource(name string) (DataSource, error) {
	m.mu.RLock()
	source, found := m.sources[name]
	m.mu.RUnlock()

	if found {
		return source, nil
	}
	return nil, fmt.Errorf("datasource [%s] not found", name)
}
// GetAllDataSourceNames returns the names of all registered data sources in
// unspecified (map iteration) order.
func (m *DataSourceManager) GetAllDataSourceNames() []string {
	m.mu.RLock()
	defer m.mu.RUnlock()

	out := make([]string, len(m.sources))
	i := 0
	for name := range m.sources {
		out[i] = name
		i++
	}
	return out
}
// InitializeFromConfig registers and connects one data source for every entry
// under the "mongo" key of the application config, for example:
//
//	mongo:
//	  default:
//	    address: "mongodb://127.0.0.1:27017"
//	    database: "app"
//	    maxPoolSize: 50
//
// Entries that are not objects or have no "address" are skipped. A missing
// "mongo" section is not an error. A failure to register or connect one entry
// is logged and does not stop the others; the first such error is returned.
func (m *DataSourceManager) InitializeFromConfig(ctx context.Context) error {
	mongoConfig := g.Cfg().MustGet(ctx, "mongo")
	if mongoConfig.IsNil() {
		glog.Warningf(ctx, "no mongo configuration found in config.yml")
		return nil
	}
	configMap := mongoConfig.Map()
	if configMap == nil {
		glog.Warningf(ctx, "mongo configuration is not a map")
		return nil
	}

	var firstErr error
	remember := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}

	for name, subConfig := range configMap {
		subMap, ok := subConfig.(map[string]interface{})
		if !ok {
			continue // scalar entries are not data sources
		}
		config := dataSourceConfigFromMap(name, subMap)
		if config == nil {
			continue // no address configured
		}
		if err := m.RegisterDataSource(config); err != nil {
			glog.Errorf(ctx, "failed to register datasource [%s]: %v", name, err)
			remember(err)
			continue
		}
		source, err := m.GetDataSource(name)
		if err != nil {
			remember(err)
			continue
		}
		if err := source.Connect(ctx); err != nil {
			glog.Errorf(ctx, "failed to initialize datasource [%s]: %v", name, err)
			remember(err)
		}
	}
	return firstErr
}

// dataSourceConfigFromMap builds the DataSourceConfig for one "mongo.<name>"
// config entry and fills in defaults. It returns nil when no address is set.
func dataSourceConfigFromMap(name string, subMap map[string]interface{}) *DataSourceConfig {
	address := gconv.String(subMap["address"])
	if address == "" {
		return nil
	}
	config := &DataSourceConfig{
		Name:        name,
		Address:     address,
		Database:    gconv.String(subMap["database"]),
		MaxPoolSize: int32(gconv.Int(subMap["maxPoolSize"])),
		MinPoolSize: int32(gconv.Int(subMap["minPoolSize"])),
		// NOTE(review): a bare number is converted by gconv.Duration (presumably
		// as nanoseconds); use a string such as "10s" in config.
		ConnectTimeout: gconv.Duration(subMap["connectTimeout"]),
	}
	if config.MaxPoolSize <= 0 {
		config.MaxPoolSize = 100
	}
	if config.MinPoolSize <= 0 {
		// Default to 10, but never above maxPoolSize. The driver's option
		// validation rejects minPoolSize > maxPoolSize, so a configured
		// maxPoolSize below 10 with no explicit minPoolSize would otherwise
		// fail to connect.
		config.MinPoolSize = 10
		if config.MinPoolSize > config.MaxPoolSize {
			config.MinPoolSize = config.MaxPoolSize
		}
	}
	if config.ConnectTimeout <= 0 {
		config.ConnectTimeout = 10 * time.Second
	}
	return config
}
// StartHealthCheck launches the background health-check loop once; later
// calls are no-ops. The started flag is checked and set under m.mu so that
// concurrent callers cannot both start a loop (previously a data race).
func (m *DataSourceManager) StartHealthCheck() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.started {
		return
	}
	m.started = true
	go m.healthCheckLoop()
}
// healthCheckLoop runs checkAndReconnect every 30 seconds until the manager's
// context is cancelled (see CloseAll).
func (m *DataSourceManager) healthCheckLoop() {
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	done := m.ctx.Done()
	for {
		select {
		case <-ticker.C:
			m.checkAndReconnect()
		case <-done:
			return
		}
	}
}
// checkAndReconnect calls Reconnect on every registered data source whose
// IsConnected reports false. Each attempt gets its own 30s timeout, derived
// from the manager context so that CloseAll aborts in-flight attempts.
//
// NOTE(review): IsConnected only reflects the outcome of the last
// Connect/Close. A server that goes away after a successful connect is not
// detected here; presumably the driver's own monitoring handles that case.
func (m *DataSourceManager) checkAndReconnect() {
	// Snapshot under the read lock so that slow reconnects do not block
	// RegisterDataSource/GetDataSource callers for up to 30s per source.
	m.mu.RLock()
	snapshot := make(map[string]DataSource, len(m.sources))
	for name, source := range m.sources {
		snapshot[name] = source
	}
	m.mu.RUnlock()

	for name, source := range snapshot {
		if m.ctx.Err() != nil {
			return // shutting down: do not reopen what CloseAll is closing
		}
		if source.IsConnected() {
			continue
		}
		glog.Warningf(context.Background(), "datasource [%s] disconnected, attempting reconnect", name)

		// Cancel explicitly per iteration; a defer here would keep every
		// attempt's timer alive until the whole sweep returns.
		reconnectCtx, cancel := context.WithTimeout(m.ctx, 30*time.Second)
		if err := source.Reconnect(reconnectCtx); err != nil {
			glog.Errorf(reconnectCtx, "datasource [%s] reconnect failed: %v", name, err)
		} else {
			glog.Infof(reconnectCtx, "✅ datasource [%s] reconnected successfully", name)
		}
		cancel()
	}
}
// CloseAll stops the health-check loop and closes every registered data
// source. Every source is attempted; failures are logged, and the error of the
// last failing source (in map iteration order) is returned.
func (m *DataSourceManager) CloseAll(ctx context.Context) error {
	// Stop the health check first so it does not reopen what we close.
	m.cancel()

	m.mu.RLock()
	defer m.mu.RUnlock()

	var lastErr error
	for name, source := range m.sources {
		err := source.Close(ctx)
		if err == nil {
			continue
		}
		glog.Errorf(ctx, "failed to close datasource [%s]: %v", name, err)
		lastErr = err
	}
	return lastErr
}
// =============================================================================
// 向后兼容的MongoDB结构体
// =============================================================================
// MongoDB is the backward-compatible query helper returned by DB(). Its
// methods run against the data source named by dataSource.
type MongoDB struct {
	// Cache enables the Redis read-through cache used by Find, FindOne and Count.
	Cache      bool
	dataSource string // data source name; "" is treated as "default" (see getDataSource)
}
// DB returns a MongoDB helper bound to the "default" data source. Caching is
// on unless the first optional argument is false.
func DB(cache ...bool) *MongoDB {
	useCache := len(cache) == 0 || cache[0]
	return &MongoDB{Cache: useCache, dataSource: "default"}
}
// WithDataSource selects the named data source for later operations on m.
// It modifies m in place and returns it for chaining, e.g.
// DB().WithDataSource("report").Find(...).
func (m *MongoDB) WithDataSource(name string) *MongoDB {
	m.dataSource = name // an empty name later falls back to "default"
	return m
}
// =============================================================================
// 向后兼容的全局变量和方法
// =============================================================================
// Package-level state shared by the backward-compatible API.
var (
	manager     = GetManager() // the process-wide DataSourceManager
	logPool     *grpool.Pool   // worker pool for async operation logs; created in init via grpool.New(1)
	serverName  string         // config value "server.name", recorded in operation logs
	logRedisKey string         // Redis stream key for operation logs: "log:<serverName>"
)

// PageSize is the default page size used by Find when no page is given.
const PageSize = 20
// GetDB returns the database of the "default" data source, or nil if no such
// source is registered. It is kept for callers of the single-source API.
func GetDB() *mongo.Database {
	if source, err := manager.GetDataSource("default"); err == nil {
		return source.Database()
	}
	return nil
}
// init prepares the package on import. It creates the operation-log pool,
// reads server.name, registers and connects every data source found under
// "mongo" in the config, starts the health check and installs the
// SIGINT/SIGTERM shutdown handler.
// NOTE(review): this does network I/O at import time. Connection failures are
// only logged, so the package still loads with unconnected sources.
func init() {
	logPool = grpool.New(1)
	serverName = g.Cfg().MustGet(context.TODO(), "server.name").String()
	logRedisKey = fmt.Sprintf("log:%s", serverName)
	ctx := context.Background()
	// Initialise all data sources from config; errors are logged, not fatal.
	if err := manager.InitializeFromConfig(ctx); err != nil {
		glog.Errorf(ctx, "❌ Failed to initialize MongoDB datasources: %v", err)
	} else {
		glog.Infof(ctx, "✅ MongoDB datasources initialized: %v", manager.GetAllDataSourceNames())
	}
	// Start the background health check.
	manager.StartHealthCheck()
	// Close all connections on SIGINT/SIGTERM.
	setupGracefulShutdown()
}
// setupGracefulShutdown closes all MongoDB data sources (10s budget) when the
// process receives SIGINT or SIGTERM.
//
// signal.Notify disables Go's default "terminate" action for these signals
// while a channel is registered. The handler therefore unregisters itself
// (signal.Stop) once cleanup is done. Unless the application has installed
// its own handler, a repeated Ctrl+C/SIGTERM then terminates the process as
// usual instead of being swallowed forever. Registration happens
// synchronously so that a signal arriving right after init is not missed.
func setupGracefulShutdown() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	go func() {
		defer signal.Stop(sigCh)
		<-sigCh

		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		glog.Info(ctx, "🔄 Shutting down MongoDB connections...")
		if err := manager.CloseAll(ctx); err != nil {
			glog.Errorf(ctx, "❌ Failed to close MongoDB connections: %v", err)
		} else {
			glog.Info(ctx, "✅ MongoDB connections closed successfully")
		}
	}()
}
// =============================================================================
// MongoDB 操作方法(支持多数据源)
// =============================================================================
// getDataSource resolves the data source this helper is bound to. An empty
// name is normalised to "default" (and stored back on m).
func (m *MongoDB) getDataSource() (DataSource, error) {
	name := m.dataSource
	if name == "" {
		name = "default"
		m.dataSource = name
	}
	return manager.GetDataSource(name)
}
// Find queries collection for documents matching filter and decodes them into
// result, which must be a pointer to a slice (as required by Cursor.All).
//
// filter is modified in place: "isDeleted": false and the caller's
// "tenantId" (from utils.GetUserInfo) are added. page selects the page:
//   - nil means the first PageSize documents;
//   - PageSize == -1 means all documents;
//   - PageNum < 1 is treated as the first page.
// orderBy defaults to createdAt descending.
//
// total is set for unpaged queries (number of decoded documents) and for
// Redis cache hits (number of cached items). For a paged database hit it stays 0.
// NOTE(review): confirm the paged-total behaviour is intended. A page with
// PageSize 0 is sent as limit 0, which MongoDB treats as "no limit".
//
// When m.Cache is true, results are cached in Redis for one hour.
func (m *MongoDB) Find(ctx context.Context, filter bson.M, result interface{}, collection string, page *beans.Page, orderBy []beans.OrderBy) (total int64, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return 0, err
	}
	db := source.Database()
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	filter["isDeleted"] = false
	filterKey := fmt.Sprintf("%+v", filter)
	optionsKey := fmt.Sprintf("%+v%+v", page, orderBy)
	redisKey := fmt.Sprintf(redis.List, user.TenantId, collection, filterKey, optionsKey)
	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !cached.IsEmpty() {
			if err = cached.Structs(result); err != nil {
				return
			}
			total = int64(len(cached.Array()))
			return
		}
	}
	filter["tenantId"] = user.TenantId

	limit := int64(PageSize)
	skip := int64(0)
	if page != nil {
		limit = page.PageSize
		if limit != -1 {
			skip = (page.PageNum - 1) * limit
			if skip < 0 {
				// PageNum < 1 would otherwise send a negative skip, which the
				// server rejects.
				skip = 0
			}
		}
	}

	opt := options.Find().SetSkip(skip)
	if limit != -1 {
		opt.SetLimit(limit)
	} else {
		// Give Count its own copy: Count may rewrite the map it receives
		// (e.g. drop "tenantId"), and sharing filter would make the query
		// below run without the tenant restriction.
		countFilter := make(bson.M, len(filter))
		for k, v := range filter {
			countFilter[k] = v
		}
		total, err = m.Count(ctx, countFilter, collection)
		if err != nil || total == 0 {
			return
		}
	}

	if orderBy == nil {
		opt.SetSort(bson.M{"createdAt": -1})
	} else {
		sortSpec := make(bson.D, 0, len(orderBy))
		for _, o := range orderBy {
			dir := -1
			if o.Order == beans.Asc {
				dir = 1
			}
			sortSpec = append(sortSpec, bson.E{Key: o.Field, Value: dir})
		}
		opt.SetSort(sortSpec)
	}

	cur, err := db.Collection(collection).Find(ctx, filter, opt)
	if err != nil {
		return
	}
	defer cur.Close(ctx)
	if err = cur.All(ctx, result); err != nil {
		return
	}
	if limit == -1 {
		// Report what was actually decoded. cur.RemainingBatchLength() only
		// covers the first server batch and under-reports large results.
		if rv := reflect.ValueOf(result); rv.Kind() == reflect.Ptr && rv.Elem().Kind() == reflect.Slice {
			total = int64(rv.Elem().Len())
		}
	}

	if m.Cache {
		// SetEX takes the TTL in seconds; int64(time.Hour) would be 3.6e12 s.
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return
}
// FindOne decodes the first non-deleted document matching filter (scoped to
// the caller's tenant when one is set) into result.
//
// filter must be non-empty and is modified in place ("isDeleted", and
// "tenantId" when the user has one). "Not found" is not an error: result is
// left untouched. When m.Cache is true, the outcome (including a not-found) is
// cached in Redis for one hour. Real query/decode errors are returned and are
// never cached.
func (m *MongoDB) FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) {
	source, err := m.getDataSource()
	if err != nil {
		return err
	}
	db := source.Database()
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if err = utils.ValidStructPtr(result); err != nil {
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	filter["isDeleted"] = false
	filterKey := fmt.Sprintf("%+v", filter)
	redisKey := fmt.Sprintf(redis.One, user.TenantId, collection, filterKey)
	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			return gconv.Scan(cached, result)
		}
	}

	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}
	err = db.Collection(collection).FindOne(ctx, filter, opts...).Decode(result)
	switch {
	case errors.Is(err, mongo.ErrNoDocuments):
		err = nil
	case err != nil:
		// Surface real failures; caching here would both hide the error and
		// persist an invalid result.
		return err
	}

	if m.Cache {
		// SetEX takes the TTL in seconds; int64(time.Hour) would be 3.6e12 s.
		err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour/time.Second))
	}
	return err
}
// CleanRedis purges the Redis cache entries for collection and tenantId, in
// this order:
//   - every list key (redis.CleanList pattern);
//   - every count key (redis.CleanCount pattern);
//   - the single-document key for filter.
// The single-document key is built from filter with "isDeleted": false and
// without "tenantId", so filter is modified in place. The first Redis error
// aborts the purge.
// NOTE(review): KEYS is O(N) and blocks Redis; consider SCAN for large keyspaces.
func (m *MongoDB) CleanRedis(ctx context.Context, filter bson.M, tenantId interface{}, collection string) (err error) {
	purge := func(pattern string) error {
		keys, keysErr := redis.RedisClient.Keys(ctx, pattern)
		if keysErr != nil {
			return keysErr
		}
		for _, key := range keys {
			if _, delErr := redis.RedisClient.Del(ctx, key); delErr != nil {
				return delErr
			}
		}
		return nil
	}

	patterns := []string{
		fmt.Sprintf(redis.CleanList, tenantId, collection),
		fmt.Sprintf(redis.CleanCount, tenantId, collection),
	}
	for _, pattern := range patterns {
		if err = purge(pattern); err != nil {
			return err
		}
	}

	filter["isDeleted"] = false
	delete(filter, "tenantId")
	oneKey := fmt.Sprintf(redis.One, tenantId, collection, fmt.Sprintf("%+v", filter))
	_, err = redis.RedisClient.Del(ctx, oneKey)
	return err
}
// log asynchronously appends an entity.OperationLog for a write on collection
// to the Redis stream logRedisKey, using logPool. It is best effort: failures
// are logged and never reported to the caller.
//
// filter["_id"] identifies the affected document. It may be a string, a
// bson.ObjectID (stored as hex) or absent (stored as ""). The document id and
// the client IP are captured before the job is queued. The job runs later and
// must not read the request or the caller's filter map concurrently; outside
// an HTTP request the IP is empty.
func (m *MongoDB) log(ctx context.Context, filter bson.M, collection string, data interface{}, userName, tenantId interface{}, operationType string) {
	var collectionID string
	switch id := filter["_id"].(type) {
	case string:
		collectionID = id
	case bson.ObjectID:
		collectionID = id.Hex()
	default:
		collectionID = gconv.String(id)
	}

	var clientIP string
	if r := g.RequestFromCtx(ctx); r != nil {
		clientIP = r.GetClientIp()
	}

	// A queueing error is deliberately ignored: logging must not fail the write.
	_ = logPool.AddWithRecover(ctx, func(ctx context.Context) {
		now := &gtime.Now().Time
		entry := &entity.OperationLog{
			ServiceName:  serverName,
			Collection:   collection,
			CollectionID: collectionID,
			Operation:    operationType,
			IPAddress:    clientIP,
			Data:         data,
		}
		entry.Creator = userName
		entry.CreatedAt = now
		entry.UpdatedAt = now
		entry.TenantId = tenantId
		if _, err := redis.AddToStream(ctx, logRedisKey, entry); err != nil {
			glog.Errorf(ctx, "mongoLog-AddToStream err: %v", err)
		}
	}, func(ctx context.Context, exception error) {
		glog.Errorf(ctx, "mongoLog-AddWithRecover err: %v", exception)
	})
}
// Delete physically removes (DeleteMany) every document of the caller's tenant
// that matches filter, then purges the collection's Redis cache.
// Unlike the read/update helpers it does not restrict to isDeleted=false;
// this is a hard delete, not a soft delete.
// filter must be non-empty and is modified in place ("tenantId" is set).
func (m *MongoDB) Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return 0, err
	}
	db := source.Database()
	if len(filter) == 0 {
		return 0, gerror.New("缺少查询条件")
	}

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return 0, err
	}
	filter["tenantId"] = user.TenantId

	res, err := db.Collection(collection).DeleteMany(ctx, filter, opts...)
	if err != nil {
		return 0, err
	}
	return res.DeletedCount, m.CleanRedis(ctx, filter, user.TenantId, collection)
}
// Update applies update (UpdateMany) to every non-deleted document matching
// filter, scoped to the caller's tenant when one is set. It then purges the
// collection's Redis cache.
//
// update may contain any update operators ($set, $inc, $unset, ...); all of
// them are forwarded. "updater" and "updatedAt" are stamped into its $set
// stage, which is created when absent. A $set of a type other than bson.M is
// rejected with an error instead of panicking. filter and the caller's $set
// map are modified in place; the top-level update map is not.
func (m *MongoDB) Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return nil, err
	}
	db := source.Database()
	if len(filter) == 0 {
		err = gerror.New("缺少查询条件")
		return
	}
	if len(update) == 0 {
		err = gerror.New("缺少更新数据")
		return
	}
	filter["isDeleted"] = false
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}
	if !g.IsEmpty(user.TenantId) {
		filter["tenantId"] = user.TenantId
	}

	var setDoc bson.M
	switch s := update["$set"].(type) {
	case nil:
		// No $set stage: one is created below for the audit fields.
	case bson.M:
		setDoc = s
	default:
		err = gerror.Newf("unsupported $set type %T, expected bson.M", s)
		return
	}
	if setDoc == nil {
		setDoc = bson.M{}
	}
	if !g.IsEmpty(user.UserName) {
		setDoc["updater"] = user.UserName
	}
	setDoc["updatedAt"] = gtime.Now().Time

	// Rebuild the top-level document so that operators other than $set are
	// forwarded instead of being silently dropped.
	doc := make(bson.M, len(update)+1)
	for op, v := range update {
		doc[op] = v
	}
	doc["$set"] = setDoc

	result, err = db.Collection(collection).UpdateMany(ctx, filter, doc, opts...)
	if err != nil {
		return
	}
	err = m.CleanRedis(ctx, filter, user.TenantId, collection)
	return
}
// RandomSoftDelete marks up to limit randomly chosen, not yet deleted
// documents in collection as deleted (isDeleted = true).
//
// opts are forwarded to the UpdateMany call, whose result is returned. When
// limit <= 0 or nothing matches, an empty UpdateResult is returned.
// NOTE(review): unlike the other write helpers this is not scoped to the
// caller's tenant and does not purge the Redis cache. Confirm that is intended.
func (m *MongoDB) RandomSoftDelete(ctx context.Context, limit int, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return nil, err
	}
	if limit <= 0 {
		// The server rejects a non-positive $limit.
		return &mongo.UpdateResult{}, nil
	}
	coll := source.Database().Collection(collection)

	// Filter first so the server can use an index and only tags/sorts the
	// candidate documents with a random key.
	pipeline := mongo.Pipeline{
		{{Key: "$match", Value: bson.D{{Key: "isDeleted", Value: false}}}},
		{{Key: "$addFields", Value: bson.D{{Key: "random", Value: bson.M{"$rand": bson.M{}}}}}},
		{{Key: "$sort", Value: bson.D{{Key: "random", Value: -1}}}},
		{{Key: "$limit", Value: limit}},
		{{Key: "$project", Value: bson.D{{Key: "_id", Value: 1}}}},
	}
	cursor, err := coll.Aggregate(ctx, pipeline)
	if err != nil {
		return nil, err
	}
	defer cursor.Close(ctx)

	// Keep ids as-is: collections are not guaranteed to use ObjectID keys.
	var ids []interface{}
	for cursor.Next(ctx) {
		var doc bson.M
		if err = cursor.Decode(&doc); err != nil {
			return nil, err
		}
		ids = append(ids, doc["_id"])
	}
	if err = cursor.Err(); err != nil {
		return nil, err
	}
	glog.Debugf(ctx, "准备更新的随机文档ID: %v", ids)
	if len(ids) == 0 {
		return &mongo.UpdateResult{}, nil
	}

	filter := bson.D{{Key: "_id", Value: bson.D{{Key: "$in", Value: ids}}}}
	update := bson.D{{Key: "$set", Value: bson.D{{Key: "isDeleted", Value: true}}}}
	return coll.UpdateMany(ctx, filter, update, opts...)
}
// SaveOrUpdate performs one unordered BulkWrite of UpdateOne models: update[i]
// is applied to the first document matching filter[i]. Upsert is on by
// default; the only option honoured from opts is Upsert, with the last one
// that sets it winning.
//
// Each filter is restricted to non-deleted documents of the caller's tenant.
// Each update gets "updater"/"updatedAt" in its $set stage, which is created
// when absent. filter and update elements are modified in place. After the
// write, the Redis cache is purged once per filter; purge failures are only
// logged because the write has already succeeded.
func (m *MongoDB) SaveOrUpdate(ctx context.Context, filter []bson.M, update []bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.BulkWriteResult, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return nil, err
	}
	db := source.Database()
	if len(filter) == 0 || len(update) == 0 {
		err = gerror.New("缺少查询条件或更新数据")
		return
	}
	if len(filter) != len(update) {
		err = gerror.New("查询条件和更新数据的数量必须一致")
		return
	}
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// Resolve the options once (they are identical for every model) and do not
	// ignore errors reported by the option setters.
	upsert := true
	for _, opt := range opts {
		var resolved options.UpdateManyOptions
		for _, apply := range opt.List() {
			if err = apply(&resolved); err != nil {
				return nil, err
			}
		}
		if resolved.Upsert != nil {
			upsert = *resolved.Upsert
		}
	}

	now := gtime.Now().Time
	models := make([]mongo.WriteModel, 0, len(filter))
	for i := range filter {
		if filter[i] == nil || update[i] == nil {
			// Writing into a nil map would panic; an empty filter would match
			// arbitrary documents. Reject the batch instead.
			return nil, gerror.Newf("filter or update at index %d is nil", i)
		}
		filter[i]["isDeleted"] = false
		if !g.IsEmpty(user.TenantId) {
			filter[i]["tenantId"] = user.TenantId
		}

		setDoc, ok := update[i]["$set"].(bson.M)
		if !ok || setDoc == nil {
			setDoc = bson.M{}
			update[i]["$set"] = setDoc
		}
		if !g.IsEmpty(user.UserName) {
			setDoc["updater"] = user.UserName
		}
		setDoc["updatedAt"] = now

		models = append(models, mongo.NewUpdateOneModel().
			SetFilter(filter[i]).
			SetUpdate(update[i]).
			SetUpsert(upsert))
	}

	bulkResult, err := db.Collection(collection).BulkWrite(ctx, models, options.BulkWrite().SetOrdered(false))
	if err != nil {
		return nil, err
	}
	for _, f := range filter {
		if cleanErr := m.CleanRedis(ctx, f, user.TenantId, collection); cleanErr != nil {
			glog.Warning(ctx, "清理Redis缓存失败:", cleanErr)
		}
	}
	return bulkResult, nil
}
// Insert adds documents to collection with InsertMany and returns the
// inserted IDs.
//
// Each document is converted with gconv.Map and its "id" key is dropped.
// "creator"/"updater" and "tenantId" are filled from the caller's user info
// when present. "createdAt" and "updatedAt" share one timestamp and
// "isDeleted" is false. A document that cannot be converted to a map is
// rejected with an error. The Redis cache is purged via CleanRedis afterwards.
func (m *MongoDB) Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return nil, err
	}
	db := source.Database()
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	// One timestamp for the whole batch so createdAt == updatedAt on insert.
	now := gtime.Now().Time
	docs := make([]interface{}, 0, len(documents))
	for i, document := range documents {
		doc := gconv.Map(document)
		if doc == nil {
			// Assigning into a nil map would panic.
			return nil, gerror.Newf("document at index %d cannot be converted to a map", i)
		}
		delete(doc, "id")
		if !g.IsEmpty(user.UserName) {
			doc["creator"] = user.UserName
			doc["updater"] = user.UserName
		}
		if !g.IsEmpty(user.TenantId) {
			doc["tenantId"] = user.TenantId
		}
		doc["createdAt"] = now
		doc["updatedAt"] = now
		doc["isDeleted"] = false
		docs = append(docs, doc)
	}

	r, err := db.Collection(collection).InsertMany(ctx, docs, opts...)
	if err != nil {
		return
	}
	ids = r.InsertedIDs
	err = m.CleanRedis(ctx, bson.M{}, user.TenantId, collection)
	return
}
// Count returns the number of non-deleted documents in collection that match
// filter. The count is restricted to the caller's tenant when one is set.
//
// filter is not modified; Count works on a copy. The Redis key is built from
// that copy with "isDeleted": false and without "tenantId" (the tenant is a
// separate key component, as in CleanRedis). When m.Cache is true, successful
// counts are cached for one hour. A failed CountDocuments is returned and
// never cached.
func (m *MongoDB) Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) {
	source, err := m.getDataSource()
	if err != nil {
		return 0, err
	}
	db := source.Database()
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return
	}

	query := make(bson.M, len(filter)+1)
	for k, v := range filter {
		query[k] = v
	}
	query["isDeleted"] = false
	delete(query, "tenantId")
	redisKey := fmt.Sprintf(redis.Count, user.TenantId, collection, fmt.Sprintf("%+v", query))

	if m.Cache {
		var cached *gvar.Var
		cached, err = redis.RedisClient.Get(ctx, redisKey)
		if err != nil {
			return
		}
		if !g.IsEmpty(cached) {
			return gconv.Int64(cached), nil
		}
	}

	// Scope to the caller's tenant like the other helpers; without this the
	// count would span every tenant's documents.
	if !g.IsEmpty(user.TenantId) {
		query["tenantId"] = user.TenantId
	}
	if count, err = db.Collection(collection).CountDocuments(ctx, query); err != nil {
		return 0, err
	}

	if m.Cache {
		// SetEX takes the TTL in seconds; int64(time.Hour) would be 3.6e12 s.
		err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour/time.Second))
	}
	return
}
// EntityToBson converts an entity (value or pointer) to bson.M, keeping empty
// fields. It is shorthand for EntityToBsonWithFilter(entity, false).
func EntityToBson(entity interface{}) (bson.M, error) {
	const filterEmpty = false
	return EntityToBsonWithFilter(entity, filterEmpty)
}
// EntityToBsonWithFilter converts an entity (value or pointer) to bson.M by
// round-tripping it through bson.Marshal/bson.Unmarshal. When filterEmpty is
// true, top-level fields judged empty by isEmptyWithZero are removed; numeric
// zeros are kept.
func EntityToBsonWithFilter(entity interface{}, filterEmpty bool) (bson.M, error) {
	if entity == nil {
		return nil, fmt.Errorf("传入的 entity 实例为 nil")
	}

	raw, err := bson.Marshal(entity)
	if err != nil {
		return nil, fmt.Errorf("entity 序列化为 BSON 字节流失败:%w", err)
	}

	var doc bson.M
	if err := bson.Unmarshal(raw, &doc); err != nil {
		return nil, fmt.Errorf("BSON 字节流反序列化为 bson.M 失败:%w", err)
	}

	if !filterEmpty {
		return doc, nil
	}
	for key, value := range doc {
		if isEmptyWithZero(value) {
			delete(doc, key) // deleting during range is safe in Go
		}
	}
	return doc, nil
}
// isEmptyWithZero reports whether value should count as empty. It behaves
// like g.IsEmpty, except that values of numeric kind (including pointers to
// numbers) are never empty, so a zero number is kept. nil and nil pointers
// are empty.
func isEmptyWithZero(value interface{}) bool {
	if value == nil {
		return true
	}

	kind := reflect.TypeOf(value).Kind()
	if kind == reflect.Ptr {
		ptr := reflect.ValueOf(value)
		if ptr.IsNil() {
			return true
		}
		kind = ptr.Elem().Kind()
	}

	switch kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return false
	}
	return g.IsEmpty(value)
}