package middleware import ( "context" "fmt" "strings" "gitea.redpowerfuture.com/red-future/common/utils" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/ghttp" "github.com/gogf/gf/v2/util/gconv" ) // 限流 Redis Key 常量 const ( RateLimitKeyPrefix = "ragflow:ratelimit:" // 限流Key前缀 RateLimitKeyIP = "ip:%s" // IP限流: ip:192.168.1.1 RateLimitKeyUser = "user:%s" // 用户限流: user:123 或 user:anon:192.168.1.1 RateLimitKeyService = "service:%s" // 服务限流: service:customerService RateLimitKeyGlobal = "global:requests" // 全局限流: global:requests ) func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) { fullKey := RateLimitKeyPrefix + key count, err = g.Redis().Incr(ctx, fullKey) if err != nil { return } // 首次设置过期时间 if count == 1 { g.Redis().Expire(ctx, fullKey, windowSeconds) } return } // GlobalLimiter 全局限流中间件(使用Redis分布式控制) func GlobalLimiter(r *ghttp.Request) { // 从配置文件读取全局限流参数 globalLimit := g.Cfg().MustGet(r.GetCtx(), "rate.limit", 800).Int64() key := RateLimitKeyGlobal // 使用Redis计数器进行全局限流 count, err := IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口 if err != nil { g.Log().Errorf(r.GetCtx(), "全局限流Redis错误: %v", err) r.Middleware.Next() return } if count > globalLimit { g.Log().Warningf(r.GetCtx(), "全局限流触发: count: %d, limit: %d", count, globalLimit) r.Response.WriteStatusExit(429, "系统当前繁忙,请稍后再试") return } r.Middleware.Next() } // IPLimiter IP限流中间件(防DDoS) func IPLimiter(r *ghttp.Request) { ip := r.GetClientIp() key := fmt.Sprintf(RateLimitKeyIP, ip) // 从配置文件读取IP限流参数 ipLimit := g.Cfg().MustGet(r.GetCtx(), "rate.ip.limit", 100).Int64() // 使用Redis计数器 count, err := IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口 if err != nil { g.Log().Errorf(r.GetCtx(), "IP限流Redis错误: %v", err) r.Middleware.Next() return } if count > ipLimit { g.Log().Warningf(r.GetCtx(), "IP限流触发: %s, count: %d, limit: %d", ip, count, ipLimit) r.Response.WriteStatusExit(429, "请求过于频繁,请稍后再试") return } r.Middleware.Next() } // UserLimiter 用户维度限流中间件(防止单用户滥用) func UserLimiter(r *ghttp.Request) { if strings.Contains(r.RequestURI, "/swagger") || strings.Contains(r.RequestURI, "/pub/captcha/get") || strings.Contains(r.RequestURI, "/login") || strings.Contains(r.RequestURI, "/web/socket/") { r.Middleware.Next() return } var userName string user, err := utils.GetUserInfo(r.GetCtx()) if err != nil { r.Response.WriteStatusExit(401, err.Error()) return } userName = gconv.String(user.UserName) // 从配置文件读取用户限流参数 userLimit := g.Cfg().MustGet(r.GetCtx(), "rate.user.limit", 50).Int64() key := fmt.Sprintf(RateLimitKeyUser, userName) count, err := IncrRateLimit(r.GetCtx(), key, 1) if err != nil { g.Log().Errorf(r.GetCtx(), "用户限流Redis错误: %v", err) return } if count > userLimit { r.Response.WriteStatusExit(429, "您的请求过于频繁,请稍后再试") return } r.Middleware.Next() } // ServiceLimiter 服务维度限流中间件(保护微服务) func ServiceLimiter(r *ghttp.Request) { // 从URL路径提取服务名: /customerService/xxx -> customerService pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") if len(pathParts) == 0 { r.Middleware.Next() return } serverName := pathParts[0] // 从配置文件读取服务限流参数 serviceLimitKey := fmt.Sprintf("rate.services.%s.limit", serverName) limit := g.Cfg().MustGet(r.GetCtx(), serviceLimitKey, 0).Int64() // 如果配置为0,说明该服务没有限流配置,跳过限流 if limit == 0 { r.Middleware.Next() return } key := fmt.Sprintf(RateLimitKeyService, serverName) count, err := IncrRateLimit(r.GetCtx(), key, 1) if err != nil { g.Log().Errorf(r.GetCtx(), "服务限流Redis错误: %v", err) r.Middleware.Next() return } if count > limit { g.Log().Warningf(r.GetCtx(), "服务限流触发: %s, count: %d, limit: %d", serverName, count, limit) r.Response.WriteStatusExit(429, fmt.Sprintf("服务 '%s' 当前繁忙,请稍后再试", serverName)) return } r.Middleware.Next() }