diff --git a/middleware/circuit_breaker.go b/middleware/circuit_breaker.go index 913c581..8c615a0 100644 --- a/middleware/circuit_breaker.go +++ b/middleware/circuit_breaker.go @@ -10,6 +10,7 @@ import ( "sync/atomic" "time" + "gitee.com/red-future---jilin-g/common/redis" "github.com/alibaba/sentinel-golang/api" "github.com/alibaba/sentinel-golang/core/circuitbreaker" "github.com/gogf/gf/v2/frame/g" @@ -156,54 +157,70 @@ func (cb *CircuitBreakerInfo) setState(state CircuitBreakerState) CircuitBreaker // setStateWithMetrics 设置熔断器状态并更新指标 func (cb *CircuitBreakerInfo) setStateWithMetrics(state CircuitBreakerState, updateMetrics bool) CircuitBreakerState { - var newState int64 - switch state { - case StateOpen: - newState = stateOpen - case StateHalfOpen: - newState = stateHalfOpen - default: - newState = stateClosed - } - + newState := cb.stateToInt64(state) oldState := cb.State.Swap(newState) - var oldStateEnum CircuitBreakerState - - switch oldState { - case stateOpen: - oldStateEnum = StateOpen - case stateHalfOpen: - oldStateEnum = StateHalfOpen - default: - oldStateEnum = StateClosed - } + oldStateEnum := cb.int64ToState(oldState) // 如果状态发生了变化且需要更新指标 if oldStateEnum != state && updateMetrics { - now := time.Now().Unix() - - // 根据新状态更新计数器 - switch state { - case StateOpen: - cb.Metrics.OpenCount.Add(1) - cb.Metrics.LastOpenTime.Store(now) - case StateClosed: - cb.Metrics.ClosedCount.Add(1) - cb.Metrics.LastCloseTime.Store(now) - case StateHalfOpen: - cb.Metrics.HalfOpenCount.Add(1) - cb.Metrics.LastHalfOpenTime.Store(now) - } - - // 设置下一次重试时间(如果是打开状态) - if state == StateOpen { - cb.Metrics.NextRetryTime.Store(time.Now().Add(cb.Config.TimeoutParsed).Unix()) - } + cb.updateStateMetrics(state) } return oldStateEnum } +// init 初始化熔断器信息 +func (cb *CircuitBreakerInfo) init() { + cb.State.Store(stateClosed) + cb.Metrics.LastResetTime.Store(time.Now().Unix()) + cb.Metrics.LastCloseTime.Store(time.Now().Unix()) + cb.Metrics.WindowStartTime.Store(time.Now().Unix()) +} + +// stateToInt64 将CircuitBreakerState转换为int64状态 +func (cb *CircuitBreakerInfo) stateToInt64(state CircuitBreakerState) int64 { + switch state { + case StateOpen: + return stateOpen + case StateHalfOpen: + return stateHalfOpen + default: + return stateClosed + } +} + +// int64ToState 将int64状态转换为CircuitBreakerState +func (cb *CircuitBreakerInfo) int64ToState(state int64) CircuitBreakerState { + switch state { + case stateOpen: + return StateOpen + case stateHalfOpen: + return StateHalfOpen + default: + return StateClosed + } +} + +// updateStateMetrics 更新状态相关的指标 +func (cb *CircuitBreakerInfo) updateStateMetrics(state CircuitBreakerState) { + now := time.Now().Unix() + + // 根据新状态更新计数器 + switch state { + case StateOpen: + cb.Metrics.OpenCount.Add(1) + cb.Metrics.LastOpenTime.Store(now) + // 设置下一次重试时间 + cb.Metrics.NextRetryTime.Store(time.Now().Add(cb.Config.TimeoutParsed).Unix()) + case StateClosed: + cb.Metrics.ClosedCount.Add(1) + cb.Metrics.LastCloseTime.Store(now) + case StateHalfOpen: + cb.Metrics.HalfOpenCount.Add(1) + cb.Metrics.LastHalfOpenTime.Store(now) + } +} + // InitCircuitBreaker 初始化Sentinel熔断器 func InitCircuitBreaker() error { ctx := context.Background() @@ -301,29 +318,9 @@ func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig { } // 解析时间 - 使用默认值处理解析错误 - timeoutParsed, err := time.ParseDuration(config.Timeout) - if err != nil { - g.Log().Warningf(ctx, "解析timeout失败: %s, 使用默认值 %s, error: %v", config.Timeout, defaultTimeout, err) - timeoutParsed, _ = time.ParseDuration(defaultTimeout) - config.Timeout = defaultTimeout - } - config.TimeoutParsed = timeoutParsed - - slowThresholdParsed, err := time.ParseDuration(config.SlowRequestThreshold) - if err != nil { - g.Log().Warningf(ctx, "解析slowRequestThreshold失败: %s, 使用默认值 %s, error: %v", config.SlowRequestThreshold, defaultSlowRequestThreshold, err) - slowThresholdParsed, _ = time.ParseDuration(defaultSlowRequestThreshold) - config.SlowRequestThreshold = defaultSlowRequestThreshold - } - config.SlowRequestThresholdParsed = slowThresholdParsed - - warmupParsed, err := time.ParseDuration(config.WarmupDuration) - if err != nil { - g.Log().Warningf(ctx, "解析warmupDuration失败: %s, 使用默认值 %s, error: %v", config.WarmupDuration, defaultWarmupDuration, err) - warmupParsed, _ = time.ParseDuration(defaultWarmupDuration) - config.WarmupDuration = defaultWarmupDuration - } - config.WarmupDurationParsed = warmupParsed + config.TimeoutParsed, config.Timeout = parseDurationWithDefault(ctx, config.Timeout, defaultTimeout, "timeout") + config.SlowRequestThresholdParsed, config.SlowRequestThreshold = parseDurationWithDefault(ctx, config.SlowRequestThreshold, defaultSlowRequestThreshold, "slowRequestThreshold") + config.WarmupDurationParsed, config.WarmupDuration = parseDurationWithDefault(ctx, config.WarmupDuration, defaultWarmupDuration, "warmupDuration") // 解析状态码 successCodes := g.Cfg().MustGet(ctx, key+".successStatusCodes", "200,201,204").String() @@ -364,6 +361,76 @@ func parseStrings(str string) []string { return result } +// parseDurationWithDefault 解析持续时间,失败时使用默认值 +func parseDurationWithDefault(ctx context.Context, durationStr, defaultStr, fieldName string) (time.Duration, string) { + durationParsed, err := time.ParseDuration(durationStr) + if err != nil { + g.Log().Warningf(ctx, "解析%s失败: %s, 使用默认值 %s, error: %v", fieldName, durationStr, defaultStr, err) + durationParsed, _ = time.ParseDuration(defaultStr) + return durationParsed, defaultStr + } + return durationParsed, durationStr +} + +// atomicUpdateMin 原子更新最小值 +func atomicUpdateMin(minValue *atomic.Int64, newValue int64) { + for { + currentMin := minValue.Load() + if newValue >= currentMin { + break + } + if minValue.CompareAndSwap(currentMin, newValue) { + break + } + } +} + +// atomicUpdateMax 原子更新最大值 +func atomicUpdateMax(maxValue *atomic.Int64, newValue int64) { + for { + currentMax := maxValue.Load() + if newValue <= currentMax { + break + } + if maxValue.CompareAndSwap(currentMax, newValue) { + break + } + } +} + +// getAllowedIPs 获取允许的IP列表(带锁保护) +func getAllowedIPs() map[string]bool { + allowedAdminIPsMutex.RLock() + defer allowedAdminIPsMutex.RUnlock() + return allowedAdminIPsMap +} + +// getAllowedCIDRs 获取允许的CIDR列表(带锁保护) +func getAllowedCIDRs() []*net.IPNet { + allowedAdminCIDRsMutex.RLock() + defer allowedAdminCIDRsMutex.RUnlock() + return allowedAdminCIDRs +} + +// reset 重置所有指标到初始状态 +func (m *CircuitBreakerMetrics) reset() { + m.TotalRequests.Store(0) + m.PassRequests.Store(0) + m.BlockRequests.Store(0) + m.FailureRequests.Store(0) + m.SlowRequests.Store(0) + m.OpenCount.Store(0) + m.HalfOpenRequests.Store(0) + m.HalfOpenPassed.Store(0) + m.HalfOpenFailed.Store(0) + m.TotalResponseTime.Store(0) + m.MinResponseTime.Store(1<<63 - 1) // 最大int64值作为初始最小值 + m.MaxResponseTime.Store(0) + m.WindowRequests.Store(0) + m.WindowFailures.Store(0) + // 时间戳相关字段不重置,LastResetTime在调用时单独设置 +} + // parseCIDRs 解析CIDR列表 func parseCIDRs(strs []string) ([]*net.IPNet, error) { nets := make([]*net.IPNet, 0, len(strs)) @@ -386,11 +453,8 @@ func parseCIDRs(strs []string) ([]*net.IPNet, error) { // newCircuitBreakerMetrics 创建并初始化熔断器指标 func newCircuitBreakerMetrics() *CircuitBreakerMetrics { - metrics := &CircuitBreakerMetrics{ - MinResponseTime: atomic.Int64{}, - MaxResponseTime: atomic.Int64{}, - } - metrics.MinResponseTime.Store(1<<63 - 1) // 最大int64值作为初始最小值 + metrics := &CircuitBreakerMetrics{} + metrics.reset() return metrics } @@ -530,13 +594,7 @@ func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) AdaptiveThreshold: threshold, WarmupEndTime: time.Now().Add(config.WarmupDurationParsed).Unix(), } - cbInfo.State.Store(stateClosed) - cbInfo.Metrics.LastResetTime.Store(time.Now().Unix()) - cbInfo.Metrics.LastCloseTime.Store(time.Now().Unix()) - cbInfo.Metrics.WindowStartTime.Store(time.Now().Unix()) - - // 初始化响应时间统计 - cbInfo.Metrics.MinResponseTime.Store(1<<63 - 1) // 最大int64值作为初始最小值 + cbInfo.init() circuitBreakers.Store(serviceName, cbInfo) strategy := "error_count" @@ -663,27 +721,9 @@ func CircuitBreakerMiddleware(r *ghttp.Request) { durationNs := duration.Nanoseconds() cbInfo.Metrics.TotalResponseTime.Add(durationNs) - // 原子更新最小响应时间 - for { - currentMin := cbInfo.Metrics.MinResponseTime.Load() - if durationNs >= currentMin { - break - } - if cbInfo.Metrics.MinResponseTime.CompareAndSwap(currentMin, durationNs) { - break - } - } - - // 原子更新最大响应时间 - for { - currentMax := cbInfo.Metrics.MaxResponseTime.Load() - if durationNs <= currentMax { - break - } - if cbInfo.Metrics.MaxResponseTime.CompareAndSwap(currentMax, durationNs) { - break - } - } + // 原子更新最小和最大响应时间 + atomicUpdateMin(&cbInfo.Metrics.MinResponseTime, durationNs) + atomicUpdateMax(&cbInfo.Metrics.MaxResponseTime, durationNs) if duration > config.SlowRequestThresholdParsed { cbInfo.Metrics.SlowRequests.Add(1) @@ -908,77 +948,32 @@ func isCircuitBreakerOpenInDistributed(ctx context.Context, resourceName string) // syncCircuitBreakerStateToDistributed 同步熔断器状态到Redis func syncCircuitBreakerStateToDistributed(ctx context.Context, resourceName, state string, ttl int) { - key := "circuit_breaker:" + resourceName + ":state" + stateKey := "circuit_breaker:" + resourceName + ":state" lockKey := "circuit_breaker:" + resourceName + ":lock" - redis := g.Redis() - if redis == nil { + redisClient := g.Redis() + if redisClient == nil { g.Log().Warningf(ctx, "Redis未初始化,无法同步分布式熔断状态: %s", resourceName) return } - lockValue := fmt.Sprintf("%d", time.Now().UnixNano()) + // 使用common/redis中的Lock方法获取分布式锁 + success, err := redis.Lock(ctx, lockKey, 10, func(ctx context.Context) error { + // 设置熔断器状态 + _, err := redisClient.Do(ctx, "SETEX", stateKey, ttl, state) + if err != nil { + g.Log().Errorf(ctx, "设置分布式熔断状态失败: %s=%s, error: %v", stateKey, state, err) + } else { + g.Log().Debugf(ctx, "分布式熔断状态已同步: %s=%s (TTL: %d)", stateKey, state, ttl) + } + return nil + }) - // 获取分布式锁 - locked, err := redis.Do(ctx, "SET", lockKey, lockValue, "NX", "EX", 10) if err != nil { g.Log().Errorf(ctx, "获取分布式锁失败: %s, error: %v", lockKey, err) return } - // 检查是否获取到锁 - var isLocked bool - if locked != nil && !locked.IsNil() { - isLocked = true - } else { - // 检查锁是否已过期 - currentLock, err := redis.Get(ctx, lockKey) - if err == nil && !currentLock.IsNil() { - lockTime, _ := strconv.ParseInt(currentLock.String(), 10, 64) - // 如果锁已经存在超过10秒(超时),强制获取 - if time.Now().UnixNano()-lockTime > 10*1e9 { - // 使用SETNX方式获取锁,使用Lua脚本保证原子性 - luaAcquire := ` -local current = redis.call("get", KEYS[1]) -if current and tonumber(current) then - local lockTime = tonumber(current) - if redis.call("TIME")[1] * 1000000000 + redis.call("TIME")[2] - lockTime > 10000000000 then - redis.call("del", KEYS[1]) - return redis.call("set", KEYS[1], ARGV[1], "EX", 10) - end - return nil -else - return redis.call("set", KEYS[1], ARGV[1], "EX", 10) -end` - locked, err = redis.Do(ctx, "EVAL", luaAcquire, 1, lockKey, lockValue) - if err != nil { - g.Log().Errorf(ctx, "强制获取分布式锁失败: %s, error: %v", lockKey, err) - return - } - if locked != nil && !locked.IsNil() { - isLocked = true - } - } - } - } - - if isLocked { - defer func() { - // 使用Lua脚本原子性地删除锁,只删除自己创建的锁 - luaScript := `if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end` - _, err := redis.Do(ctx, "EVAL", luaScript, 1, lockKey, lockValue) - if err != nil { - g.Log().Warningf(ctx, "释放分布式锁失败: %s, error: %v", lockKey, err) - } - }() - - // 设置状态 - _, err := redis.Do(ctx, "SETEX", key, ttl, state) - if err != nil { - g.Log().Errorf(ctx, "设置分布式熔断状态失败: %s=%s, error: %v", key, state, err) - } else { - g.Log().Debugf(ctx, "分布式熔断状态已同步: %s=%s (TTL: %d)", key, state, ttl) - } - } else { + if !success { g.Log().Debugf(ctx, "未获取到分布式锁,跳过状态同步: %s", lockKey) } } @@ -1089,34 +1084,25 @@ func isAdminIP(r *ghttp.Request) bool { return false } - allowedAdminIPsMutex.RLock() - allowedIPs := allowedAdminIPsMap - allowedAdminIPsMutex.RUnlock() + allowedIPs := getAllowedIPs() + allowedCIDRs := getAllowedCIDRs() - if len(allowedIPs) == 0 { - allowedAdminCIDRsMutex.RLock() - hasCIDRs := len(allowedAdminCIDRs) > 0 - allowedAdminCIDRsMutex.RUnlock() - if !hasCIDRs { - return true - } + // 如果没有任何限制,允许访问 + if len(allowedIPs) == 0 && len(allowedCIDRs) == 0 { + return true } + // 检查IP白名单 if allowedIPs[clientIP] { return true } - allowedAdminCIDRsMutex.RLock() - cidrNets := allowedAdminCIDRs - allowedAdminCIDRsMutex.RUnlock() - - if len(cidrNets) > 0 { - clientNetIP := net.ParseIP(clientIP) - if clientNetIP != nil { - for _, cidrNet := range cidrNets { - if cidrNet.Contains(clientNetIP) { - return true - } + // 检查CIDR白名单 + clientNetIP := net.ParseIP(clientIP) + if clientNetIP != nil { + for _, cidrNet := range allowedCIDRs { + if cidrNet.Contains(clientNetIP) { + return true } } } @@ -1125,6 +1111,29 @@ func isAdminIP(r *ghttp.Request) bool { return false } +// batchProcessServices 批量处理服务 +func batchProcessServices(r *ghttp.Request, processFunc func(serviceName string) error) (int, int, map[string]string) { + successCount := 0 + failCount := 0 + failures := make(map[string]string) + + serviceNamesMutex.RLock() + slice := serviceNamesSlice + serviceNamesMutex.RUnlock() + + for _, serviceName := range slice { + if err := processFunc(serviceName); err != nil { + g.Log().Errorf(r.GetCtx(), "服务 %s 处理失败: %v", serviceName, err) + failCount++ + failures[serviceName] = err.Error() + } else { + successCount++ + } + } + + return successCount, failCount, failures +} + // CircuitBreakerResetHandler 重置熔断器 func CircuitBreakerResetHandler(r *ghttp.Request) { serviceName := r.Get("service").String() @@ -1135,25 +1144,13 @@ func CircuitBreakerResetHandler(r *ghttp.Request) { } if serviceName == "" || serviceName == "*" { - serviceNamesMutex.RLock() - slice := serviceNamesSlice - serviceNamesMutex.RUnlock() - - successCount := 0 - failCount := 0 - - for _, name := range slice { - if err := resetSingleService(r, name); err != nil { - g.Log().Errorf(r.GetCtx(), "服务 %s 熔断器重置失败: %v", name, err) - failCount++ - } else { - successCount++ - } - } + successCount, failCount, failures := batchProcessServices(r, func(name string) error { + return resetSingleService(r, name) + }) g.Log().Infof(r.GetCtx(), "批量重置熔断器完成: 成功 %d, 失败 %d", successCount, failCount) r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("批量重置完成: 成功 %d, 失败 %d", successCount, failCount), - Data: map[string]interface{}{"success": successCount, "failed": failCount}}) + Data: map[string]interface{}{"success": successCount, "failed": failCount, "failures": failures}}) return } @@ -1184,16 +1181,8 @@ func resetSingleService(r *ghttp.Request, serviceName string) error { if cbInfoVal, ok := circuitBreakers.Load(serviceName); ok { cbInfo := cbInfoVal.(*CircuitBreakerInfo) cbInfo.State.Store(stateClosed) - cbInfo.Metrics.LastOpenTime.Store(0) - cbInfo.Metrics.NextRetryTime.Store(0) - cbInfo.Metrics.TotalRequests.Store(0) - cbInfo.Metrics.PassRequests.Store(0) - cbInfo.Metrics.BlockRequests.Store(0) - cbInfo.Metrics.FailureRequests.Store(0) - cbInfo.Metrics.SlowRequests.Store(0) - cbInfo.Metrics.OpenCount.Store(0) - cbInfo.Metrics.HalfOpenRequests.Store(0) - cbInfo.Metrics.HalfOpenPassed.Store(0) + // 重置指标 + cbInfo.Metrics.reset() cbInfo.WarmupEndTime = time.Now().Add(cbInfo.Config.WarmupDurationParsed).Unix() cbInfo.Metrics.LastResetTime.Store(time.Now().Unix()) } @@ -1224,23 +1213,9 @@ func CircuitBreakerReloadHandler(r *ghttp.Request) { } if serviceName == "" || serviceName == "*" { - serviceNamesMutex.RLock() - slice := serviceNamesSlice - serviceNamesMutex.RUnlock() - - successCount := 0 - failCount := 0 - failures := make(map[string]string) - - for _, service := range slice { - if err := ReloadCircuitBreakerConfig(service); err != nil { - g.Log().Errorf(r.GetCtx(), "服务 %s 配置重载失败: %v", service, err) - failCount++ - failures[service] = err.Error() - } else { - successCount++ - } - } + successCount, failCount, failures := batchProcessServices(r, func(serviceName string) error { + return ReloadCircuitBreakerConfig(serviceName) + }) updateAdminIPsCache() r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("配置重载完成: 成功 %d, 失败 %d", successCount, failCount),