diff --git a/middleware/circuit_breaker.go b/middleware/circuit_breaker.go index 58a2ee9..703e236 100644 --- a/middleware/circuit_breaker.go +++ b/middleware/circuit_breaker.go @@ -37,6 +37,7 @@ type CircuitBreakerConfig struct { FallbackMessage string // 降级提示消息 RequestTimeout int // 请求超时时间(毫秒) DistributedTTL int // 分布式熔断状态TTL(秒) + AdminIPs []string // 允许重置熔断器的管理员IP列表 } // CircuitBreakerMetrics 熔断器指标 @@ -46,6 +47,7 @@ type CircuitBreakerMetrics struct { BlockRequests atomic.Int64 // 阻塞请求数 FailureRequests atomic.Int64 // 失败请求数 OpenCount atomic.Int64 // 熔断开启次数 + LastResetTime atomic.Int64 // 上次重置时间(Unix时间戳) } // CircuitBreakerInfo 熔断器信息 @@ -158,6 +160,7 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { fallbackMessage := g.Cfg().MustGet(ctx, key+".fallbackMessage", "").String() requestTimeout := g.Cfg().MustGet(ctx, key+".requestTimeout", 30000).Int() distributedTTL := g.Cfg().MustGet(ctx, key+".distributedTTL", 300).Int() + adminIPs := g.Cfg().MustGet(ctx, key+".adminIPs", "").String() // 解析成功状态码 successCodes := g.Cfg().MustGet(ctx, key+".successStatusCodes", "200,201,204").String() @@ -176,6 +179,9 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { g.Log().Warningf(ctx, "服务 %s 的 slowRequestThreshold 解析失败,使用默认值: %v", serviceName, err) } + // 解析管理员IP列表 + adminIPList := parseAdminIPs(adminIPs) + return &CircuitBreakerConfig{ Enabled: enabled, MaxFailures: maxFailures, @@ -190,6 +196,7 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { FallbackMessage: fallbackMessage, RequestTimeout: requestTimeout, DistributedTTL: distributedTTL, + AdminIPs: adminIPList, } } @@ -206,6 +213,22 @@ func parseStatusCodes(str string) []int { return codes } +// parseAdminIPs 解析管理员IP列表 +func parseAdminIPs(str string) []string { + if str == "" { + return nil + } + parts := strings.Split(str, ",") + ips := make([]string, 0, len(parts)) + for _, part := range parts { + ip := strings.TrimSpace(part) + if ip != "" { + ips = append(ips, ip) + } + } + return ips +} + // filterServiceNames 过滤服务名(排除非服务配置的key) func filterServiceNames(services map[string]interface{}) []string { excludeKeys := map[string]bool{ @@ -282,6 +305,7 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) Config: config, Metrics: &CircuitBreakerMetrics{}, } + cbInfo.Metrics.LastResetTime.Store(time.Now().Unix()) circuitBreakers.Store(serviceName, cbInfo) strategy := "error_count" @@ -300,15 +324,13 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { startTime := time.Now() ctx := r.GetCtx() - // 从URL路径提取服务名 - pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") - if len(pathParts) == 0 { + // 从URL路径提取服务名(改进提取逻辑) + serviceName := extractServiceName(r.URL.Path) + if serviceName == "" { r.Middleware.Next() return } - serviceName := pathParts[0] - // 获取熔断器配置 val, ok := circuitBreakerConfigs.Load(serviceName) if !ok { @@ -391,7 +413,8 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { statusCode := r.Response.Status duration := time.Since(startTime) - if !isSuccessStatusCode(resourceName, statusCode) { + // 使用提前获取的config判断状态码(性能优化) + if !isSuccessStatusCode(config, statusCode) { // 记录异常 cbInfo.Metrics.FailureRequests.Add(1) api.TraceError(entry, fmt.Errorf("request failed with status: %d", statusCode)) @@ -431,32 +454,39 @@ func sendFallbackResponse(r *ghttp.Request, serviceName string, config *CircuitB } // isSuccessStatusCode 判断HTTP状态码是否成功 -func isSuccessStatusCode(resourceName string, statusCode int) bool { - serviceName := strings.TrimPrefix(resourceName, "service:") - if serviceName == "" { - // 默认只认为2xx是成功 - return statusCode >= 200 && statusCode < 300 - } - - // 从配置中获取成功状态码列表 - var serviceConfig *CircuitBreakerConfig - if val, ok := circuitBreakerConfigs.Load(serviceName); ok { - serviceConfig = val.(*CircuitBreakerConfig) - } - - if serviceConfig != nil && len(serviceConfig.SuccessStatusCodes) > 0 { - for _, code := range serviceConfig.SuccessStatusCodes { +func isSuccessStatusCode(config *CircuitBreakerConfig, statusCode int) bool { + if len(config.SuccessStatusCodes) > 0 { + for _, code := range config.SuccessStatusCodes { if statusCode == code { return true } } return false } - // 默认:2xx状态码为成功 return statusCode >= 200 && statusCode < 300 } +// extractServiceName 从URL路径提取服务名(改进提取逻辑) +func extractServiceName(path string) string { + // 去除首尾斜杠并分割 + path = strings.Trim(path, "/") + if path == "" { + return "" + } + parts := strings.Split(path, "/") + if len(parts) == 0 { + return "" + } + serviceName := parts[0] + + // 验证服务名是否在已配置的熔断器中 + if _, ok := circuitBreakerConfigs.Load(serviceName); ok { + return serviceName + } + return "" +} + // isCircuitBreakerOpenInDistributed 检查分布式熔断状态 func isCircuitBreakerOpenInDistributed(ctx context.Context, resourceName string) bool { key := fmt.Sprintf("circuit_breaker:%s:state", resourceName) @@ -536,10 +566,16 @@ func registerStateChangeListeners() { return } - // 注册默认监听器 + // 注册默认监听器(区分日志级别) RegisterStateChangeListener("default", func(serviceName string, fromState, toState CircuitBreakerState) { - g.Log().Infof(context.Background(), "熔断器状态变化: service=%s, %s -> %s", - serviceName, fromState, toState) + // Open状态使用Warning级别,Closed状态使用Info级别 + if toState == StateOpen { + g.Log().Warningf(context.Background(), "熔断器状态变化: service=%s, %s -> %s", + serviceName, fromState, toState) + } else { + g.Log().Infof(context.Background(), "熔断器状态变化: service=%s, %s -> %s", + serviceName, fromState, toState) + } }) } @@ -585,6 +621,12 @@ func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { } // 从Metrics中读取数据(修复数据准确性问题) + lastResetTime := cbInfo.Metrics.LastResetTime.Load() + var lastResetTimeStr string + if lastResetTime > 0 { + lastResetTimeStr = time.Unix(lastResetTime, 0).Format("2006-01-02 15:04:05") + } + status[serviceName] = map[string]interface{}{ "resource": cbInfo.ResourceName, "state": string(cbInfo.State), @@ -595,6 +637,7 @@ func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { "blockRequests": cbInfo.Metrics.BlockRequests.Load(), "failureRequests": cbInfo.Metrics.FailureRequests.Load(), "openCount": cbInfo.Metrics.OpenCount.Load(), + "lastResetTime": lastResetTimeStr, } cbInfo.mu.RUnlock() @@ -617,6 +660,39 @@ func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { }) } +// isAdminIP 检查请求IP是否在管理员白名单中 +func isAdminIP(r *ghttp.Request) bool { + clientIP := r.GetClientIp() + if clientIP == "" { + return false + } + + // 检查所有服务的adminIPs配置 + var allowedIPs []string + circuitBreakerConfigs.Range(func(key, value interface{}) bool { + config := value.(*CircuitBreakerConfig) + if len(config.AdminIPs) > 0 { + allowedIPs = append(allowedIPs, config.AdminIPs...) + } + return true + }) + + // 如果没有配置白名单,允许所有IP(向后兼容) + if len(allowedIPs) == 0 { + return true + } + + // 检查IP是否在白名单中 + for _, allowedIP := range allowedIPs { + if clientIP == allowedIP { + return true + } + } + + g.Log().Warningf(r.GetCtx(), "熔断器重置请求被拒绝,IP不在白名单中: %s", clientIP) + return false +} + // CircuitBreakerResetHandler 熔断器手动重置接口(仅限管理后台调用) func CircuitBreakerResetHandler(r *ghttp.Request) { serviceName := r.Get("service").String() @@ -628,6 +704,15 @@ func CircuitBreakerResetHandler(r *ghttp.Request) { return } + // 权限验证:检查IP是否在白名单中 + if !isAdminIP(r) { + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 403, + Message: "权限不足,禁止访问", + }) + return + } + resourceName := fmt.Sprintf("service:%s", serviceName) // 获取当前服务的所有规则 @@ -658,13 +743,20 @@ func CircuitBreakerResetHandler(r *ghttp.Request) { } } - // 更新内存状态 + // 更新内存状态并重置指标 if val, ok := circuitBreakers.Load(serviceName); ok { cbInfo := val.(*CircuitBreakerInfo) cbInfo.mu.Lock() cbInfo.State = StateClosed cbInfo.LastOpenTime = time.Time{} cbInfo.NextRetryTime = time.Time{} + // 重置指标 + cbInfo.Metrics.TotalRequests.Store(0) + cbInfo.Metrics.PassRequests.Store(0) + cbInfo.Metrics.BlockRequests.Store(0) + cbInfo.Metrics.FailureRequests.Store(0) + cbInfo.Metrics.OpenCount.Store(0) + cbInfo.Metrics.LastResetTime.Store(time.Now().Unix()) cbInfo.mu.Unlock() }