Files
common/middleware/rate_limiter.go

300 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"fmt"
"strings"
"gitee.com/red-future---jilin-g/common/redis"
"gitee.com/red-future---jilin-g/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
// GlobalLimiter 全局限流中间件使用Redis分布式控制
func GlobalLimiter(r *ghttp.Request) {
// 从配置文件读取全局限流参数
globalLimit := g.Cfg().MustGet(r.GetCtx(), "rate.limit", 800).Int64()
key := redis.RateLimitKeyGlobal
// 使用Redis计数器进行全局限流
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "全局限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > globalLimit {
g.Log().Warningf(r.GetCtx(), "全局限流触发: count: %d, limit: %d", count, globalLimit)
r.Response.WriteStatusExit(429, "系统当前繁忙,请稍后再试")
return
}
r.Middleware.Next()
}
// IPLimiter IP限流中间件防DDoS
func IPLimiter(r *ghttp.Request) {
ip := r.GetClientIp()
key := fmt.Sprintf(redis.RateLimitKeyIP, ip)
// 从配置文件读取IP限流参数
ipLimit := g.Cfg().MustGet(r.GetCtx(), "rate.ip.limit", 100).Int64()
// 使用Redis计数器
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "IP限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > ipLimit {
g.Log().Warningf(r.GetCtx(), "IP限流触发: %s, count: %d, limit: %d", ip, count, ipLimit)
r.Response.WriteStatusExit(429, "请求过于频繁,请稍后再试")
return
}
r.Middleware.Next()
}
// UserLimiter 用户维度限流中间件(防止单用户滥用)
func UserLimiter(r *ghttp.Request) {
// 从JWT获取用户ID(如果已登录)
var userId string
var isAuth bool = false
if token := r.Header.Get("Authorization"); token != "" && gstr.HasPrefix(token, "Bearer ") {
// 这里应该解析JWT获取用户ID简化示例中直接使用token
tokenStr := gstr.SubStrFrom(token, "7")
if tokenStr != "" && validateToken(tokenStr) {
userId = tokenStr
isAuth = true
}
}
// 如果没有userId使用IP作为标识
if userId == "" {
userId = "anon:" + r.GetClientIp()
}
// 从配置文件读取用户限流参数
var userLimit int64
if isAuth {
userLimit = g.Cfg().MustGet(r.GetCtx(), "rate.user.authenticated.limit", 50).Int64()
} else {
userLimit = g.Cfg().MustGet(r.GetCtx(), "rate.user.anonymous.limit", 20).Int64()
}
key := fmt.Sprintf(redis.RateLimitKeyUser, userId)
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1)
if err != nil {
g.Log().Errorf(r.GetCtx(), "用户限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > userLimit {
userType := "已登录"
if !isAuth {
userType = "未登录"
}
g.Log().Warningf(r.GetCtx(), "用户限流触发: %s, count: %d, limit: %d, type: %s", userId, count, userLimit, userType)
r.Response.WriteStatusExit(429, "您的请求过于频繁,请稍后再试")
return
}
r.Middleware.Next()
}
// ServiceLimiter 服务维度限流中间件(保护微服务)
func ServiceLimiter(r *ghttp.Request) {
// 从URL路径提取服务名: /customerService/xxx -> customerService
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
if len(pathParts) == 0 {
r.Middleware.Next()
return
}
serverName := pathParts[0]
// 从配置文件读取服务限流参数
serviceLimitKey := fmt.Sprintf("rate.services.%s.limit", serverName)
limit := g.Cfg().MustGet(r.GetCtx(), serviceLimitKey, 0).Int64()
// 如果配置为0说明该服务没有限流配置跳过限流
if limit == 0 {
r.Middleware.Next()
return
}
key := fmt.Sprintf(redis.RateLimitKeyService, serverName)
count, err := redis.IncrRateLimit(r.GetCtx(), key, 1)
if err != nil {
g.Log().Errorf(r.GetCtx(), "服务限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > limit {
g.Log().Warningf(r.GetCtx(), "服务限流触发: %s, count: %d, limit: %d", serverName, count, limit)
r.Response.WriteStatusExit(429, fmt.Sprintf("服务 '%s' 当前繁忙,请稍后再试", serverName))
return
}
r.Middleware.Next()
}
// OrderCreateLimiter 订单创建限流中间件
// 限制: 每个用户每分钟最多创建10个订单
func OrderCreateLimiter(r *ghttp.Request) {
userId := getUserIdFromContext(r) // 从context获取用户ID
if userId == "" {
// 如果无法获取用户信息,跳过限流检查
r.Middleware.Next()
return
}
key := fmt.Sprintf(redis.RateLimitKeyOrder, userId)
// 限制: 每个用户每分钟最多创建10个订单
count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "订单创建限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > 10 {
g.Log().Warningf(r.GetCtx(), "订单创建限流触发: %s, count: %d", userId, count)
r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{
Code: 429,
Message: "下单过于频繁,请稍后再试",
})
return
}
r.Middleware.Next()
}
// WalletTransferLimiter 钱包转账限流中间件
// 限制: 每个用户每分钟最多转账5次
func WalletTransferLimiter(r *ghttp.Request) {
userId := getUserIdFromContext(r) // 从context获取用户ID
if userId == "" {
r.Middleware.Next()
return
}
key := fmt.Sprintf(redis.RateLimitKeyTransfer, userId)
// 限制: 每个用户每分钟最多转账5次
count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "钱包转账限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > 5 {
g.Log().Warningf(r.GetCtx(), "钱包转账限流触发: %s, count: %d", userId, count)
r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{
Code: 429,
Message: "转账操作过于频繁,请稍后再试",
})
return
}
r.Middleware.Next()
}
// CSMessageLimiter 客服消息限流中间件
// 限制: 每个用户每分钟最多发送30条消息
func CSMessageLimiter(r *ghttp.Request) {
userId := getUserIdFromContext(r) // 从context获取用户ID
if userId == "" {
r.Middleware.Next()
return
}
key := fmt.Sprintf(redis.RateLimitKeyMessage, userId)
// 限制: 每个用户每分钟最多发送30条消息
count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "客服消息限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > 30 {
g.Log().Warningf(r.GetCtx(), "客服消息限流触发: %s, count: %d", userId, count)
r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{
Code: 429,
Message: "消息发送过于频繁,请稍后再试",
})
return
}
r.Middleware.Next()
}
// OSSUploadLimiter 文件上传限流中间件
// 限制: 每个用户每分钟最多上传10个文件
func OSSUploadLimiter(r *ghttp.Request) {
userId := getUserIdFromContext(r) // 从context获取用户ID
if userId == "" {
r.Middleware.Next()
return
}
key := fmt.Sprintf(redis.RateLimitKeyUpload, userId)
// 限制: 每个用户每分钟最多上传10个文件
count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口
if err != nil {
g.Log().Errorf(r.GetCtx(), "文件上传限流Redis错误: %v", err)
r.Middleware.Next()
return
}
if count > 10 {
g.Log().Warningf(r.GetCtx(), "文件上传限流触发: %s, count: %d", userId, count)
r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{
Code: 429,
Message: "文件上传过于频繁,请稍后再试",
})
return
}
r.Middleware.Next()
}
// getUserIdFromContext 从请求上下文中获取用户ID
// 使用项目中已有的utils.GetUserInfo方法
func getUserIdFromContext(r *ghttp.Request) string {
// 使用项目中已有的utils.GetUserInfo方法获取用户信息
user, err := utils.GetUserInfo(r.GetCtx())
if err != nil {
// 如果获取用户信息失败,返回空字符串
return ""
}
// 在这个项目中UserName就是用来标识用户的ID
// 转换为字符串类型
if user.UserName != nil {
return gconv.String(user.UserName)
}
return ""
}
// validateToken 验证token有效性
func validateToken(token string) bool {
// 实现 token 验证逻辑
return token == "valid-token"
}