package mongo import ( "context" "errors" "fmt" "strings" "time" "gitee.com/red-future---jilin-g/common/consts" "gitee.com/red-future---jilin-g/common/do" "gitee.com/red-future---jilin-g/common/redis" "gitee.com/red-future---jilin-g/common/utils" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) var db *mongo.Database // GetDB 获取 MongoDB 数据库实例 func GetDB() *mongo.Database { return db } func init() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() mongoAddr := g.Cfg().MustGet(context.Background(), "mongo.address").String() opt := options.Client().ApplyURI(mongoAddr) client, err := mongo.Connect(opt) if err != nil { glog.Error(ctx, "mongodb连接失败", err) } // 从连接串中解析数据库名 dbName := gstr.SubStr(mongoAddr, strings.LastIndex(mongoAddr, "/")+1, len(mongoAddr)) // 如果连接串带有参数(如 ?retryWrites=true),需要去掉参数部分 if strings.Contains(dbName, "?") { dbName = gstr.SubStr(dbName, 0, strings.Index(dbName, "?")) } db = client.Database(dbName) } func listOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOptions]) (m map[string]interface{}) { // 输出opts参数中的值 m = make(map[string]interface{}) for _, opt := range opts { var findOpts options.FindOptions optFuncs := opt.List() for _, fn := range optFuncs { fn(&findOpts) } if findOpts.Limit != nil { m["limit"] = *findOpts.Limit } if findOpts.Skip != nil { m["skip"] = *findOpts.Skip } if findOpts.Sort != nil { m["sort"] = findOpts.Sort } if findOpts.Projection != nil { m["projection"] = findOpts.Projection } } m = utils.OrderMap(m) return } func oneOptionsToMap(ctx context.Context, opts ...options.Lister[options.FindOneOptions]) (m map[string]interface{}) { // 输出opts参数中的值 m = make(map[string]interface{}) for _, opt := range opts { var findOpts options.FindOneOptions optFuncs := opt.List() for _, fn := range optFuncs { fn(&findOpts) } if findOpts.Skip != nil { m["skip"] = *findOpts.Skip } if findOpts.Sort != nil { m["sort"] = findOpts.Sort } if findOpts.Projection != nil { m["projection"] = findOpts.Projection } } m = utils.OrderMap(m) return } // GetTenantInfo 获取租户信息 // 优先从 token 获取,失败则从请求参数 customerServiceId 查询 customer_service_account 表 func GetTenantInfo(ctx context.Context) (user do.User, err error) { // 1. 优先从 token 获取 user, err = utils.GetUserInfo(ctx) if err == nil { return } // 2. token 获取失败,尝试从请求参数获取 customerServiceId req := g.RequestFromCtx(ctx) if req == nil { return user, gerror.New("无法获取租户信息:无 token 且无 request") } customerServiceId := req.Get("customerServiceId").String() if customerServiceId == "" { customerServiceId = req.Get("customer_service_id").String() } if customerServiceId == "" { return user, gerror.New("无法获取租户信息:无 token 且无 customerServiceId 参数") } // 3. 直接查询 customer_service_account 表获取 tenantId filter := bson.M{"customerServiceId": customerServiceId, "isDeleted": false} var account struct { TenantId interface{} `bson:"tenantId"` } if findErr := db.Collection("customer_service_account").FindOne(ctx, filter).Decode(&account); findErr != nil { return user, gerror.Newf("通过 customerServiceId 查询租户失败: %v", findErr) } user.TenantId = account.TenantId user.UserName = customerServiceId return } // Find 查询多条记录 func Find(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOptions]) (err error) { if err = utils.ValidStructPtr(result); err != nil { return } user, err := GetTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) optsMap := listOptionsToMap(ctx, opts...) redisKey := fmt.Sprintf(consts.List, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap)) resultStr, err := redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { err = gconv.Scan(resultStr, result) if err != nil { return err } return } filter["tenantId"] = user.TenantId cur, err := db.Collection(collection).Find(ctx, filter, opts...) if err != nil { return } err = cur.All(ctx, result) err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err } return } // FindOne 查询1条记录 func FindOne(ctx context.Context, filter bson.M, result interface{}, collection string, opts ...options.Lister[options.FindOneOptions]) (err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } if err = utils.ValidStructPtr(result); err != nil { return } user, err := GetTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) optsMap := oneOptionsToMap(ctx, opts...) redisKey := fmt.Sprintf(consts.One, user.TenantId, collection, gconv.String(filterMap), gconv.String(optsMap)) resultStr, err := redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { err = gconv.Scan(resultStr, result) if err != nil { return err } return } filter["tenantId"] = user.TenantId cur := db.Collection(collection).FindOne(ctx, filter, opts...) err = cur.Decode(result) if errors.Is(err, mongo.ErrNoDocuments) { err = nil } err = redis.RedisClient.SetEX(ctx, redisKey, result, int64(time.Hour)) if err != nil { return err } return } func cleanRedis(ctx context.Context, tenantId interface{}, collection string) (err error) { listKeys := fmt.Sprintf(consts.CleanList, tenantId, collection) keys, err := redis.RedisClient.Keys(ctx, listKeys) if err != nil { return } for _, key := range keys { _, err = redis.RedisClient.Del(ctx, key) if err != nil { return } } countKeys := fmt.Sprintf(consts.CleanCount, tenantId, collection) keys, err = redis.RedisClient.Keys(ctx, countKeys) if err != nil { return } for _, key := range keys { _, err = redis.RedisClient.Del(ctx, key) if err != nil { return } } oneKey := fmt.Sprintf(consts.One, tenantId, collection) _, err = redis.RedisClient.Del(ctx, oneKey) if err != nil { return } return } // Delete 删除记录 func Delete(ctx context.Context, filter bson.M, collection string, opts ...options.Lister[options.DeleteManyOptions]) (count int64, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } user, err := GetTenantInfo(ctx) if err != nil { return } filter["tenantId"] = user.TenantId r, err := db.Collection(collection).DeleteMany(ctx, filter, opts...) if err != nil { return } count = r.DeletedCount err = cleanRedis(ctx, user.TenantId, collection) return } // Update 修改记录 func Update(ctx context.Context, filter bson.M, update bson.M, collection string, opts ...options.Lister[options.UpdateManyOptions]) (result *mongo.UpdateResult, err error) { if len(filter) == 0 { err = gerror.New("缺少查询条件") return } filter["isDeleted"] = false user, err := GetTenantInfo(ctx) if err != nil { return } filter["tenantId"] = user.TenantId setDoc := update["$set"].(bson.M) setDoc["updater"] = user.UserName setDoc["updatedAt"] = gtime.Now().Time update = bson.M{"$set": setDoc} result, err = db.Collection(collection).UpdateMany(ctx, filter, update, opts...) if err != nil { return } err = cleanRedis(ctx, user.TenantId, collection) return } // Insert 插入多条记录 func Insert(ctx context.Context, documents []interface{}, collection string, opts ...options.Lister[options.InsertManyOptions]) (ids []interface{}, err error) { user, err := GetTenantInfo(ctx) if err != nil { return } docs := make([]interface{}, 0, len(documents)) for _, document := range documents { doc := gconv.Map(document) delete(doc, "id") doc["creator"] = user.UserName doc["createdAt"] = gtime.Now().Time doc["updater"] = user.UserName doc["updatedAt"] = gtime.Now().Time doc["tenantId"] = user.TenantId doc["isDeleted"] = false docs = append(docs, doc) } r, err := db.Collection(collection).InsertMany(ctx, docs, opts...) if err != nil { return } ids = r.InsertedIDs err = cleanRedis(ctx, user.TenantId, collection) return } // Count 查询总数 func Count(ctx context.Context, filter bson.M, collection string) (count int64, err error) { user, err := GetTenantInfo(ctx) if err != nil { return } filter["isDeleted"] = false filterMap := utils.OrderMap(filter) redisKey := fmt.Sprintf(consts.Count, user.TenantId, collection, gconv.String(filterMap)) resultStr, err := redis.RedisClient.Get(ctx, redisKey) if err != nil { return } if !g.IsEmpty(resultStr) { count = gconv.Int64(resultStr) return } // 调用驱动的 CountDocuments,在数据库端执行的 count, err = db.Collection(collection).CountDocuments(ctx, filter) err = redis.RedisClient.SetEX(ctx, redisKey, count, int64(time.Hour)) if err != nil { return } return }