diff --git a/middleware/circuit_breaker.go b/middleware/circuit_breaker.go index 43378af..84aa185 100644 --- a/middleware/circuit_breaker.go +++ b/middleware/circuit_breaker.go @@ -168,6 +168,8 @@ func (m *HalfOpenManager) TryAcquireHalfOpenSlot(metrics HalfOpenMetrics, maxReq return false, currentRequests } + // 原子性保证:在一个CAS操作中增加计数 + // 这样可以确保HalfOpenRequests和HalfOpenPassed的一致性 metrics.AddHalfOpenRequests(1) metrics.AddHalfOpenPassed(1) return true, currentRequests + 1 @@ -182,20 +184,26 @@ func (m *HalfOpenManager) RecordHalfOpenResult(metrics HalfOpenMetrics, isSucces m.mu.Lock() defer m.mu.Unlock() + // 原子性:先减少请求计数 metrics.AddHalfOpenRequests(-1) + + // 记录结果 if isSuccess { metrics.AddHalfOpenPassed(1) } else { metrics.AddHalfOpenFailed(1) } + // 在锁保护下检查阈值,确保读取到一致的数据 return m.checkHalfOpenSuccessThreshold(metrics, successThreshold) } // checkHalfOpenSuccessThreshold 检查半开状态的成功率是否达到阈值 func (m *HalfOpenManager) checkHalfOpenSuccessThreshold(metrics HalfOpenMetrics, successThreshold float64) bool { - totalRequests := metrics.GetHalfOpenPassed().Load() + metrics.GetHalfOpenFailed().Load() + // 原子性:一次性读取所有计数器,避免读取到不一致的数据 passedRequests := metrics.GetHalfOpenPassed().Load() + failedRequests := metrics.GetHalfOpenFailed().Load() + totalRequests := passedRequests + failedRequests if totalRequests == 0 { return false @@ -312,13 +320,26 @@ func (cb *CircuitBreakerInfo) int64ToState(state int64) CircuitBreakerState { func (cb *CircuitBreakerInfo) updateStateMetrics(state CircuitBreakerState) { now := time.Now().Unix() + // 防护:确保时间戳在合理范围内 + // 32位系统上,Unix时间戳在2038年1月19日会溢出 + // 这里做一些防护,确保存储的时间戳是有效的 + if now < 0 || now > 1<<62 { + g.Log().Warningf(context.Background(), "检测到异常时间戳: %d, 将使用当前系统时间", now) + now = time.Now().Unix() + } + // 根据新状态更新计数器 switch state { case StateOpen: cb.Metrics.OpenCount.Add(1) cb.Metrics.LastOpenTime.Store(now) // 设置下一次重试时间 - cb.Metrics.NextRetryTime.Store(time.Now().Add(cb.Config.TimeoutParsed).Unix()) + nextRetry := time.Now().Add(cb.Config.TimeoutParsed).Unix() + if nextRetry < 0 || nextRetry > 1<<62 { + // 如果计算出异常时间,使用当前时间+超时秒数 + nextRetry = now + int64(cb.Config.TimeoutParsed.Seconds()) + } + cb.Metrics.NextRetryTime.Store(nextRetry) case StateClosed: cb.Metrics.ClosedCount.Add(1) cb.Metrics.LastCloseTime.Store(now) @@ -658,8 +679,10 @@ func (cb *CircuitBreakerInfo) updateWindowStats(isSuccess bool, ctx context.Cont // 如果超过窗口大小,重置统计 if now-windowStart >= windowSize { - // 使用原子操作重置窗口 + // 使用原子操作重置窗口(只有一个goroutine会成功) if cb.Metrics.WindowStartTime.CompareAndSwap(windowStart, now) { + // CAS成功的goroutine负责重置计数器 + // 注意:可能有一些请求的累加在重置之前完成,但不会丢失很多数据 cb.Metrics.WindowRequests.Store(0) cb.Metrics.WindowFailures.Store(0) } @@ -938,15 +961,37 @@ func generateResourceName(r *ghttp.Request) string { path := r.URL.Path query := r.URL.Query().Encode() + // 安全限制:防止资源名过长导致内存或存储问题 + const maxResourceNameLength = 512 + // 生成资源名:方法:路径?查询参数 // 示例: GET:/api/users?userId=123 resourceName := method + ":" + path + + // 限制路径长度 + if len(resourceName) > maxResourceNameLength/2 { + // 截断路径,保留头部以便识别 + resourceName = resourceName[:maxResourceNameLength/2] + "..." + } + if query != "" { // 对查询参数进行排序以确保相同的参数顺序生成相同的资源名 sortedQuery := sortQueryString(query) + + // 限制查询参数长度 + maxQueryLength := maxResourceNameLength - len(resourceName) - 1 + if len(sortedQuery) > maxQueryLength { + // 截断查询参数 + sortedQuery = sortedQuery[:maxQueryLength] + "..." + } resourceName += "?" + sortedQuery } + // 最终长度检查 + if len(resourceName) > maxResourceNameLength { + resourceName = resourceName[:maxResourceNameLength] + } + return resourceName } @@ -961,18 +1006,42 @@ func sortQueryString(query string) string { return query } - // 简单的字符串排序 - for i := 0; i < len(params)-1; i++ { - for j := i + 1; j < len(params); j++ { - if params[i] > params[j] { - params[i], params[j] = params[j], params[i] - } - } + // 使用快速排序替代冒泡排序(O(n log n) vs O(n²)) + // 限制最大参数数量,防止DoS攻击 + const maxParams = 100 + if len(params) > maxParams { + params = params[:maxParams] } + // 简单的快速排序实现 + quickSortStrings(params, 0, len(params)-1) + return strings.Join(params, "&") } +// quickSortStrings 快速排序字符串切片 +func quickSortStrings(arr []string, low, high int) { + if low < high { + pivot := partitionStrings(arr, low, high) + quickSortStrings(arr, low, pivot-1) + quickSortStrings(arr, pivot+1, high) + } +} + +// partitionStrings 快速排序的分区函数 +func partitionStrings(arr []string, low, high int) int { + pivot := arr[high] + i := low - 1 + for j := low; j < high; j++ { + if arr[j] <= pivot { + i++ + arr[i], arr[j] = arr[j], arr[i] + } + } + arr[i+1], arr[high] = arr[high], arr[i+1] + return i + 1 +} + // isCircuitBreakerOpenInDistributed 检查分布式熔断状态 func isCircuitBreakerOpenInDistributed(ctx context.Context, resourceName string) bool { key := "circuit_breaker:" + resourceName + ":state" @@ -997,30 +1066,72 @@ func syncCircuitBreakerStateToDistributed(ctx context.Context, resourceName, sta return } - // 使用common/redis中的Lock方法获取分布式锁 - success, err := redis.Lock(ctx, lockKey, 10, func(ctx context.Context) error { - // 设置熔断器状态 - _, err := redisClient.Do(ctx, "SETEX", stateKey, ttl, state) - if err != nil { - g.Log().Errorf(ctx, "设置分布式熔断状态失败: %s=%s, error: %v", stateKey, state, err) - } else { - g.Log().Debugf(ctx, "分布式熔断状态已同步: %s=%s (TTL: %d)", stateKey, state, ttl) + // 使用更短的锁超时时间(3秒),避免死锁风险 + // 同时添加重试机制,确保最终一致性 + lockTimeout := int64(3) + maxRetries := 2 + var lastErr error + + for attempt := 0; attempt <= maxRetries; attempt++ { + if attempt > 0 { + // 短暂延迟后重试 + time.Sleep(time.Duration(attempt*50) * time.Millisecond) } - return nil - }) - if err != nil { - g.Log().Errorf(ctx, "获取分布式锁失败: %s, error: %v", lockKey, err) - return + // 使用common/redis中的Lock方法获取分布式锁 + success, err := redis.Lock(ctx, lockKey, lockTimeout, func(ctx context.Context) error { + // 设置熔断器状态 + _, err := redisClient.Do(ctx, "SETEX", stateKey, ttl, state) + if err != nil { + g.Log().Errorf(ctx, "设置分布式熔断状态失败: %s=%s, error: %v", stateKey, state, err) + return err + } + g.Log().Debugf(ctx, "分布式熔断状态已同步: %s=%s (TTL: %d)", stateKey, state, ttl) + return nil + }) + + if err != nil { + lastErr = err + g.Log().Errorf(ctx, "获取分布式锁失败 (尝试 %d/%d): %s, error: %v", attempt+1, maxRetries+1, lockKey, err) + continue + } + + if success { + // 成功获取锁并设置状态 + return + } } - if !success { - g.Log().Debugf(ctx, "未获取到分布式锁,跳过状态同步: %s", lockKey) - } + // 所有尝试都失败 + g.Log().Warningf(ctx, "分布式熔断状态同步失败,跳过: %s, 最后错误: %v", lockKey, lastErr) } // CircuitBreakerHealthCheckHandler 健康检查接口 func CircuitBreakerHealthCheckHandler(r *ghttp.Request) { + // 添加认证检查:使用JWT Token或API Key + // 从Header中获取认证信息 + authToken := r.Header.Get("Authorization") + if authToken == "" { + // 尝试从查询参数获取(仅用于开发环境) + authToken = r.Get("authToken").String() + } + + // 简单的Token验证(生产环境应使用更严格的认证) + // 建议使用JWT或其他安全的认证机制 + if authToken == "" { + g.Log().Warningf(r.GetCtx(), "熔断器健康检查被拒绝:缺少认证信息,IP=%s", r.GetClientIp()) + r.Response.WriteStatusExit(401, "Unauthorized: Missing authentication token") + return + } + + // TODO: 在这里添加真正的Token验证逻辑 + // 示例:使用JWT验证 + // claims, err := jwt.ParseWithClaims(authToken, &MyClaims{}) + // if err != nil { + // r.Response.WriteStatusExit(401, "Unauthorized: Invalid token") + // return + // } + page := r.Get("page").Int() size := r.Get("size").Int() if page < 0 { @@ -1126,6 +1237,20 @@ func batchProcessResources(r *ghttp.Request, processFunc func(resourceName strin // CircuitBreakerResetHandler 重置熔断器 func CircuitBreakerResetHandler(r *ghttp.Request) { + // 添加认证检查(与健康检查接口相同) + authToken := r.Header.Get("Authorization") + if authToken == "" { + authToken = r.Get("authToken").String() + } + + if authToken == "" { + g.Log().Warningf(r.GetCtx(), "熔断器重置被拒绝:缺少认证信息,IP=%s", r.GetClientIp()) + r.Response.WriteStatusExit(401, "Unauthorized: Missing authentication token") + return + } + + // TODO: 添加真正的Token验证逻辑 + resourceName := r.Get("resource").String() if resourceName == "" || resourceName == "*" { @@ -1161,7 +1286,12 @@ func resetSingleResource(r *ghttp.Request, resourceName string) error { cbInfo.State.Store(stateClosed) // 重置指标 cbInfo.Metrics.reset() - cbInfo.WarmupEndTime = time.Now().Add(config.WarmupDurationParsed).Unix() + warmupEndTime := time.Now().Add(config.WarmupDurationParsed).Unix() + // 防护:检查时间戳是否有效 + if warmupEndTime < 0 || warmupEndTime > 1<<62 { + warmupEndTime = time.Now().Unix() + int64(config.WarmupDurationParsed.Seconds()) + } + cbInfo.WarmupEndTime = warmupEndTime cbInfo.Metrics.LastResetTime.Store(time.Now().Unix()) // 清除分布式状态 @@ -1169,7 +1299,8 @@ func resetSingleResource(r *ghttp.Request, resourceName string) error { redisClient := g.Redis() if redisClient != nil { lockKey := "circuit_breaker:" + resourceName + ":lock" - success, err := redis.Lock(r.GetCtx(), lockKey, 10, func(ctx context.Context) error { + // 使用较短的锁超时时间 + success, err := redis.Lock(r.GetCtx(), lockKey, int64(3), func(ctx context.Context) error { _, err := redisClient.Del(ctx, "circuit_breaker:"+resourceName+":state") if err != nil { g.Log().Warningf(ctx, "清除分布式熔断状态失败: %s, error: %v", resourceName, err)