From 456bee2cae955930dc6739fc129eae6d20e70e63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=96=8C?= <259278618@qq.com> Date: Thu, 1 Jan 2026 10:48:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=86=94=E6=96=AD=E7=AD=96?= =?UTF-8?q?=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/circuit_breaker.go | 168 ++++++++++++++++++++++------------ 1 file changed, 107 insertions(+), 61 deletions(-) diff --git a/middleware/circuit_breaker.go b/middleware/circuit_breaker.go index e025b56..58a2ee9 100644 --- a/middleware/circuit_breaker.go +++ b/middleware/circuit_breaker.go @@ -18,24 +18,25 @@ import ( type CircuitBreakerState string const ( - StateClosed CircuitBreakerState = "closed" // 关闭:正常状态 - StateOpen CircuitBreakerState = "open" // 开启:熔断状态 - StateHalfOpen CircuitBreakerState = "half-open" // 半开:尝试恢复状态 + StateClosed CircuitBreakerState = "closed" // 关闭:正常状态 + StateOpen CircuitBreakerState = "open" // 开启:熔断状态 ) // CircuitBreakerConfig 熔断器配置 type CircuitBreakerConfig struct { - Enabled bool // 是否启用熔断器 - MaxFailures int // 连续失败次数 - Timeout string // 熔断超时时间 - SuccessStatusCodes []int // 视为成功的HTTP状态码 - SlowRequestThreshold string // 慢请求阈值 - EnableSlidingWindow bool // 是否启用滑动窗口 - FailureRateThreshold float64 // 失败率阈值 - EnableFallback bool // 是否启用降级 - FallbackMessage string // 降级提示消息 - RequestTimeout int // 请求超时时间(毫秒) - DistributedTTL int // 分布式熔断状态TTL(秒) + Enabled bool // 是否启用熔断器 + MaxFailures int // 连续失败次数 + Timeout string // 熔断超时时间 + TimeoutParsed time.Duration // 缓存的超时时间(性能优化) + SuccessStatusCodes []int // 视为成功的HTTP状态码 + SlowRequestThreshold string // 慢请求阈值 + SlowRequestThresholdParsed time.Duration // 缓存的慢请求阈值(性能优化) + EnableSlidingWindow bool // 是否启用滑动窗口 + FailureRateThreshold float64 // 失败率阈值 + EnableFallback bool // 是否启用降级 + FallbackMessage string // 降级提示消息 + RequestTimeout int // 请求超时时间(毫秒) + DistributedTTL int // 分布式熔断状态TTL(秒) } // CircuitBreakerMetrics 熔断器指标 @@ -52,8 +53,6 @@ type CircuitBreakerInfo struct { ResourceName string `json:"resourceName"` // 资源名称 State CircuitBreakerState `json:"state"` // 当前状态 Config *CircuitBreakerConfig `json:"config"` // 配置信息 - FailCount int64 `json:"failCount"` // 失败次数 - TotalCount int64 `json:"totalCount"` // 总请求数 LastOpenTime time.Time `json:"lastOpenTime"` // 上次熔断时间 NextRetryTime time.Time `json:"nextRetryTime"` // 下次重试时间 Metrics *CircuitBreakerMetrics `json:"metrics"` // 指标统计 @@ -65,10 +64,12 @@ var ( circuitBreakers sync.Map // circuitBreakerConfigs 熔断器配置缓存 circuitBreakerConfigs sync.Map - // distributedSyncLock 分布式同步锁 - distributedSyncLock sync.Mutex + // distributedSyncLocks 分布式同步锁(按服务名分片) + distributedSyncLocks sync.Map // stateChangeListeners 状态变化监听器 stateChangeListeners sync.Map + // stateChangeListenersRegistered 默认监听器是否已注册 + stateChangeListenersRegistered sync.Map ) // InitCircuitBreaker 初始化Sentinel熔断器 @@ -78,7 +79,7 @@ func InitCircuitBreaker() error { // 初始化Sentinel err := api.InitDefault() if err != nil { - return fmt.Errorf("Sentinel初始化失败: %v", err) + return fmt.Errorf("sentinel初始化失败: %v", err) } // 注册熔断器状态变化监听器 @@ -90,12 +91,7 @@ func InitCircuitBreaker() error { services := g.Cfg().MustGet(ctx, "circuitBreaker").Map() // 过滤掉非服务配置的key - serviceNames := make([]string, 0) - for key := range services { - if key != "services" && key != "enableDistributed" && key != "requestTimeout" && key != "distributedTTL" { - serviceNames = append(serviceNames, key) - } - } + serviceNames := filterServiceNames(services) if len(serviceNames) == 0 { g.Log().Infof(ctx, "未配置任何服务熔断器") @@ -167,18 +163,33 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { successCodes := g.Cfg().MustGet(ctx, key+".successStatusCodes", "200,201,204").String() statusCodes := parseStatusCodes(successCodes) + // 解析时间(缓存结果,性能优化) + timeoutParsed, err := time.ParseDuration(timeout) + if err != nil { + timeoutParsed = 60 * time.Second + g.Log().Warningf(ctx, "服务 %s 的 timeout 解析失败,使用默认值: %v", serviceName, err) + } + + slowRequestThresholdParsed, err := time.ParseDuration(slowRequestThreshold) + if err != nil { + slowRequestThresholdParsed = 3 * time.Second + g.Log().Warningf(ctx, "服务 %s 的 slowRequestThreshold 解析失败,使用默认值: %v", serviceName, err) + } + return &CircuitBreakerConfig{ - Enabled: enabled, - MaxFailures: maxFailures, - Timeout: timeout, - SuccessStatusCodes: statusCodes, - SlowRequestThreshold: slowRequestThreshold, - EnableSlidingWindow: enableSlidingWindow, - FailureRateThreshold: failureRateThreshold, - EnableFallback: enableFallback, - FallbackMessage: fallbackMessage, - RequestTimeout: requestTimeout, - DistributedTTL: distributedTTL, + Enabled: enabled, + MaxFailures: maxFailures, + Timeout: timeout, + TimeoutParsed: timeoutParsed, + SuccessStatusCodes: statusCodes, + SlowRequestThreshold: slowRequestThreshold, + SlowRequestThresholdParsed: slowRequestThresholdParsed, + EnableSlidingWindow: enableSlidingWindow, + FailureRateThreshold: failureRateThreshold, + EnableFallback: enableFallback, + FallbackMessage: fallbackMessage, + RequestTimeout: requestTimeout, + DistributedTTL: distributedTTL, } } @@ -195,6 +206,24 @@ func parseStatusCodes(str string) []int { return codes } +// filterServiceNames 过滤服务名(排除非服务配置的key) +func filterServiceNames(services map[string]interface{}) []string { + excludeKeys := map[string]bool{ + "services": true, + "enableDistributed": true, + "requestTimeout": true, + "distributedTTL": true, + } + + serviceNames := make([]string, 0, len(services)) + for key := range services { + if !excludeKeys[key] { + serviceNames = append(serviceNames, key) + } + } + return serviceNames +} + // initServiceCircuitBreaker 初始化服务熔断器 func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) error { // 验证配置参数 @@ -202,15 +231,9 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) return fmt.Errorf("配置验证失败: %v", err) } - timeout, err := time.ParseDuration(config.Timeout) - if err != nil { - return fmt.Errorf("解析超时时间失败: %v", err) - } - - slowRequestThreshold, err := time.ParseDuration(config.SlowRequestThreshold) - if err != nil { - return fmt.Errorf("解析慢请求阈值失败: %v", err) - } + // 使用缓存的时间值(性能优化) + timeout := config.TimeoutParsed + slowRequestThreshold := config.SlowRequestThresholdParsed resourceName := fmt.Sprintf("service:%s", serviceName) @@ -243,8 +266,11 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) } } - // 加载规则到Sentinel - _, err = circuitbreaker.LoadRules(rule) + // 先清理旧规则(健壮性改进) + _, _ = circuitbreaker.LoadRulesOfResource(resourceName, []*circuitbreaker.Rule{}) + + // 加载新规则到Sentinel + _, err := circuitbreaker.LoadRules(rule) if err != nil { return fmt.Errorf("加载熔断规则失败: %v", err) } @@ -340,9 +366,8 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { oldState := cbInfo.State cbInfo.State = StateOpen cbInfo.LastOpenTime = time.Now() - if timeout, err := time.ParseDuration(config.Timeout); err == nil { - cbInfo.NextRetryTime = time.Now().Add(timeout) - } + // 使用缓存的时间值(性能优化) + cbInfo.NextRetryTime = time.Now().Add(config.TimeoutParsed) cbInfo.mu.Unlock() // 通知状态变化(如果状态改变) @@ -395,8 +420,6 @@ func sendFallbackResponse(r *ghttp.Request, serviceName string, config *CircuitB } else { // 根据原因返回不同的状态码和消息 switch reason { - case "timeout": - r.Response.WriteStatusExit(504, fmt.Sprintf("服务 '%s' 响应超时", serviceName)) case "blocked": r.Response.WriteStatusExit(503, fmt.Sprintf("服务 '%s' 熔断保护中,请稍后再试", serviceName)) case "distributed": @@ -451,10 +474,19 @@ func isCircuitBreakerOpenInDistributed(ctx context.Context, resourceName string) return state == "open" } +// getDistributedLock 获取分布式锁(按服务名分片) +func getDistributedLock(serviceName string) *sync.Mutex { + lock, _ := distributedSyncLocks.LoadOrStore(serviceName, &sync.Mutex{}) + return lock.(*sync.Mutex) +} + // syncCircuitBreakerStateToDistributed 同步熔断器状态到分布式存储 func syncCircuitBreakerStateToDistributed(ctx context.Context, resourceName, state string, ttl int) { - distributedSyncLock.Lock() - defer distributedSyncLock.Unlock() + // 提取服务名用于锁分片 + serviceName := strings.TrimPrefix(resourceName, "service:") + lock := getDistributedLock(serviceName) + lock.Lock() + defer lock.Unlock() key := fmt.Sprintf("circuit_breaker:%s:state", resourceName) @@ -481,12 +513,30 @@ func validateCircuitBreakerConfig(config *CircuitBreakerConfig) error { if len(config.SuccessStatusCodes) == 0 { return fmt.Errorf("successStatusCodes不能为空") } + if config.RequestTimeout < 0 || config.RequestTimeout > 300000 { + return fmt.Errorf("requestTimeout必须在0-300000毫秒之间") + } + if config.DistributedTTL < 0 || config.DistributedTTL > 3600 { + return fmt.Errorf("distributedTTL必须在0-3600秒之间") + } + // 验证时间字符串格式(如果缓存为空,说明解析失败) + if config.TimeoutParsed == 0 { + return fmt.Errorf("timeout格式错误,应为有效的时间字符串(如30s, 1m)") + } + if config.SlowRequestThresholdParsed == 0 { + return fmt.Errorf("slowRequestThreshold格式错误,应为有效的时间字符串(如3s, 1m)") + } return nil } // registerStateChangeListeners 注册状态变化监听器 func registerStateChangeListeners() { - // 示例:注册默认监听器 + // 检查是否已注册,防止重复注册(健壮性改进) + if _, exists := stateChangeListenersRegistered.LoadOrStore("default", true); exists { + return + } + + // 注册默认监听器 RegisterStateChangeListener("default", func(serviceName string, fromState, toState CircuitBreakerState) { g.Log().Infof(context.Background(), "熔断器状态变化: service=%s, %s -> %s", serviceName, fromState, toState) @@ -534,6 +584,7 @@ func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { openServices++ } + // 从Metrics中读取数据(修复数据准确性问题) status[serviceName] = map[string]interface{}{ "resource": cbInfo.ResourceName, "state": string(cbInfo.State), @@ -646,12 +697,7 @@ func CircuitBreakerReloadHandler(r *ghttp.Request) { services := g.Cfg().MustGet(r.GetCtx(), "circuitBreaker").Map() // 过滤出服务名 - serviceNames := make([]string, 0) - for key := range services { - if key != "services" && key != "enableDistributed" && key != "requestTimeout" && key != "distributedTTL" { - serviceNames = append(serviceNames, key) - } - } + serviceNames := filterServiceNames(services) successCount := 0 failCount := 0