diff --git a/middleware/circuit_breaker.go b/middleware/circuit_breaker.go index 703e236..36c1832 100644 --- a/middleware/circuit_breaker.go +++ b/middleware/circuit_breaker.go @@ -38,6 +38,8 @@ type CircuitBreakerConfig struct { RequestTimeout int // 请求超时时间(毫秒) DistributedTTL int // 分布式熔断状态TTL(秒) AdminIPs []string // 允许重置熔断器的管理员IP列表 + StatIntervalMs int // 统计窗口时长(毫秒),默认1000ms + MinRequestAmount int // 最小请求数量,默认与MaxFailures相同 } // CircuitBreakerMetrics 熔断器指标 @@ -72,6 +74,10 @@ var ( stateChangeListeners sync.Map // stateChangeListenersRegistered 默认监听器是否已注册 stateChangeListenersRegistered sync.Map + // allowedAdminIPsCache 缓存的所有管理员IP白名单(性能优化) + allowedAdminIPsCache []string + // allowedAdminIPsCacheMutex 保护白名单缓存的并发访问 + allowedAdminIPsCacheMutex sync.RWMutex ) // InitCircuitBreaker 初始化Sentinel熔断器 @@ -89,6 +95,9 @@ func InitCircuitBreaker() error { g.Log().Infof(ctx, "Sentinel熔断器初始化成功") + // 更新管理员IP白名单缓存 + updateAdminIPsCache() + // 扫描配置文件中所有配置了熔断器的服务 services := g.Cfg().MustGet(ctx, "circuitBreaker").Map() @@ -161,6 +170,8 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { requestTimeout := g.Cfg().MustGet(ctx, key+".requestTimeout", 30000).Int() distributedTTL := g.Cfg().MustGet(ctx, key+".distributedTTL", 300).Int() adminIPs := g.Cfg().MustGet(ctx, key+".adminIPs", "").String() + statIntervalMs := g.Cfg().MustGet(ctx, key+".statIntervalMs", 1000).Int() + minRequestAmount := g.Cfg().MustGet(ctx, key+".minRequestAmount", 0).Int() // 解析成功状态码 successCodes := g.Cfg().MustGet(ctx, key+".successStatusCodes", "200,201,204").String() @@ -182,6 +193,11 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { // 解析管理员IP列表 adminIPList := parseAdminIPs(adminIPs) + // 如果minRequestAmount未配置,则使用maxFailures作为默认值 + if minRequestAmount == 0 { + minRequestAmount = maxFailures + } + return &CircuitBreakerConfig{ Enabled: enabled, MaxFailures: maxFailures, @@ -197,6 +213,8 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { RequestTimeout: requestTimeout, DistributedTTL: distributedTTL, AdminIPs: adminIPList, + StatIntervalMs: statIntervalMs, + MinRequestAmount: minRequestAmount, } } @@ -213,6 +231,31 @@ func parseStatusCodes(str string) []int { return codes } +// updateAdminIPsCache 更新管理员IP白名单缓存(性能优化) +func updateAdminIPsCache() { + var ipList []string + ipSet := make(map[string]bool) + + // 收集所有服务的adminIPs配置 + circuitBreakerConfigs.Range(func(key, value interface{}) bool { + config := value.(*CircuitBreakerConfig) + if len(config.AdminIPs) > 0 { + for _, ip := range config.AdminIPs { + if !ipSet[ip] { + ipSet[ip] = true + ipList = append(ipList, ip) + } + } + } + return true + }) + + // 更新缓存 + allowedAdminIPsCacheMutex.Lock() + allowedAdminIPsCache = ipList + allowedAdminIPsCacheMutex.Unlock() +} + // parseAdminIPs 解析管理员IP列表 func parseAdminIPs(str string) []string { if str == "" { @@ -268,8 +311,8 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) Resource: resourceName, Strategy: circuitbreaker.SlowRequestRatio, RetryTimeoutMs: uint32(timeout.Milliseconds()), - MinRequestAmount: uint64(config.MaxFailures), - StatIntervalMs: 1000, + MinRequestAmount: uint64(config.MinRequestAmount), + StatIntervalMs: uint32(config.StatIntervalMs), StatSlidingWindowBucketCount: 10, MaxAllowedRtMs: uint64(slowRequestThreshold.Milliseconds()), Threshold: config.FailureRateThreshold, @@ -282,8 +325,8 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) Resource: resourceName, Strategy: circuitbreaker.ErrorCount, RetryTimeoutMs: uint32(timeout.Milliseconds()), - MinRequestAmount: uint64(config.MaxFailures), - StatIntervalMs: 1000, // 1秒统计窗口 + MinRequestAmount: uint64(config.MinRequestAmount), + StatIntervalMs: uint32(config.StatIntervalMs), Threshold: float64(config.MaxFailures), }, } @@ -355,6 +398,9 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { cbInfo := cbInfoVal.(*CircuitBreakerInfo) cbInfo.Metrics.TotalRequests.Add(1) + // 提前构造resourceName(性能优化) + resourceName := fmt.Sprintf("service:%s", serviceName) + // 设置请求超时(使用服务独立配置) if config.RequestTimeout > 0 { ctx, cancel := context.WithTimeout(ctx, time.Duration(config.RequestTimeout)*time.Millisecond) @@ -362,14 +408,11 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { defer cancel() } - resourceName := fmt.Sprintf("service:%s", serviceName) - // 检查是否启用分布式熔断 if config.DistributedTTL > 0 { if isCircuitBreakerOpenInDistributed(ctx, resourceName) { cbInfo.Metrics.BlockRequests.Add(1) g.Log().Warningf(ctx, "分布式熔断触发: %s", resourceName) - notifyStateChange(serviceName, StateOpen, StateOpen) sendFallbackResponse(r, serviceName, config, "distributed") return } @@ -535,26 +578,32 @@ func syncCircuitBreakerStateToDistributed(ctx context.Context, resourceName, sta // validateCircuitBreakerConfig 验证熔断器配置 func validateCircuitBreakerConfig(config *CircuitBreakerConfig) error { if config.MaxFailures <= 0 { - return fmt.Errorf("maxFailures必须大于0") + return fmt.Errorf("maxFailures必须大于0,当前值: %d", config.MaxFailures) } if config.FailureRateThreshold < 0 || config.FailureRateThreshold > 1 { - return fmt.Errorf("failureRateThreshold必须在0.0-1.0之间") + return fmt.Errorf("failureRateThreshold必须在0.0-1.0之间,当前值: %.2f", config.FailureRateThreshold) } if len(config.SuccessStatusCodes) == 0 { return fmt.Errorf("successStatusCodes不能为空") } if config.RequestTimeout < 0 || config.RequestTimeout > 300000 { - return fmt.Errorf("requestTimeout必须在0-300000毫秒之间") + return fmt.Errorf("requestTimeout必须在0-300000毫秒之间,当前值: %d", config.RequestTimeout) } if config.DistributedTTL < 0 || config.DistributedTTL > 3600 { - return fmt.Errorf("distributedTTL必须在0-3600秒之间") + return fmt.Errorf("distributedTTL必须在0-3600秒之间,当前值: %d", config.DistributedTTL) + } + if config.StatIntervalMs < 100 || config.StatIntervalMs > 60000 { + return fmt.Errorf("statIntervalMs必须在100-60000毫秒之间,当前值: %d", config.StatIntervalMs) + } + if config.MinRequestAmount < 1 || config.MinRequestAmount > 10000 { + return fmt.Errorf("minRequestAmount必须在1-10000之间,当前值: %d", config.MinRequestAmount) } // 验证时间字符串格式(如果缓存为空,说明解析失败) if config.TimeoutParsed == 0 { - return fmt.Errorf("timeout格式错误,应为有效的时间字符串(如30s, 1m)") + return fmt.Errorf("timeout格式错误,应为有效的时间字符串(如30s, 1m),当前值: %s", config.Timeout) } if config.SlowRequestThresholdParsed == 0 { - return fmt.Errorf("slowRequestThreshold格式错误,应为有效的时间字符串(如3s, 1m)") + return fmt.Errorf("slowRequestThreshold格式错误,应为有效的时间字符串(如3s, 1m),当前值: %s", config.SlowRequestThreshold) } return nil } @@ -660,22 +709,17 @@ func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { }) } -// isAdminIP 检查请求IP是否在管理员白名单中 +// isAdminIP 检查请求IP是否在管理员白名单中(使用缓存优化性能) func isAdminIP(r *ghttp.Request) bool { clientIP := r.GetClientIp() if clientIP == "" { return false } - // 检查所有服务的adminIPs配置 - var allowedIPs []string - circuitBreakerConfigs.Range(func(key, value interface{}) bool { - config := value.(*CircuitBreakerConfig) - if len(config.AdminIPs) > 0 { - allowedIPs = append(allowedIPs, config.AdminIPs...) - } - return true - }) + // 读取缓存的白名单(性能优化) + allowedAdminIPsCacheMutex.RLock() + allowedIPs := allowedAdminIPsCache + allowedAdminIPsCacheMutex.RUnlock() // 如果没有配置白名单,允许所有IP(向后兼容) if len(allowedIPs) == 0 { @@ -689,7 +733,7 @@ func isAdminIP(r *ghttp.Request) bool { } } - g.Log().Warningf(r.GetCtx(), "熔断器重置请求被拒绝,IP不在白名单中: %s", clientIP) + g.Log().Warningf(r.GetCtx(), "熔断器操作请求被拒绝,IP不在白名单中: %s", clientIP) return false } @@ -784,6 +828,15 @@ func CircuitBreakerResetHandler(r *ghttp.Request) { func CircuitBreakerReloadHandler(r *ghttp.Request) { serviceName := r.Get("service").String() + // 权限验证:检查IP是否在白名单中(P0级别安全问题) + if !isAdminIP(r) { + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 403, + Message: "权限不足,禁止访问", + }) + return + } + if serviceName == "" { // 重载所有服务 - 扫描配置文件中所有服务 services := g.Cfg().MustGet(r.GetCtx(), "circuitBreaker").Map() @@ -804,6 +857,9 @@ func CircuitBreakerReloadHandler(r *ghttp.Request) { } } + // 更新管理员IP白名单缓存 + updateAdminIPsCache() + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ Code: 200, Message: fmt.Sprintf("配置重载完成: 成功 %d, 失败 %d", successCount, failCount),