138 lines
4.1 KiB
Go
138 lines
4.1 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"cidservice/consts"
|
||
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
)
|
||
|
||
var (
|
||
RateLimit = rateLimitService{}
|
||
)
|
||
|
||
type rateLimitService struct{}
|
||
|
||
// TenantRateLimitConfig 租户限流配置
|
||
type TenantRateLimitConfig struct {
|
||
TenantID int64 // 租户ID
|
||
RequestsPerSecond float64 // 每秒请求数
|
||
Burst int // 突发请求数
|
||
Window time.Duration // 时间窗口
|
||
}
|
||
|
||
// CheckTenantRequestLimit 检查租户请求次数限制
|
||
func (s *rateLimitService) CheckTenantRequestLimit(ctx context.Context, tenantID int64, config *TenantRateLimitConfig) (bool, error) {
|
||
if config == nil {
|
||
// 使用默认配置
|
||
config = s.getDefaultTenantRateLimitConfig(tenantID)
|
||
}
|
||
|
||
// 构建Redis键 - 使用当前小时的键,确保按小时计数
|
||
now := time.Now()
|
||
hourKey := fmt.Sprintf("%s%d:%d", consts.AdRequestLimitKeyPrefix, tenantID, now.Hour())
|
||
|
||
// 获取当前计数
|
||
currentCountVar, err := g.Redis().Get(ctx, hourKey)
|
||
if err != nil && err.Error() != "redis: nil" {
|
||
return false, err
|
||
}
|
||
|
||
currentCount := currentCountVar.Int64()
|
||
|
||
// 如果是第一次请求,设置计数和过期时间(到下一个小时)
|
||
if currentCount == 0 {
|
||
// 设置过期时间为到下一个小时的剩余时间
|
||
nextHour := now.Truncate(time.Hour).Add(time.Hour)
|
||
ttl := nextHour.Sub(now)
|
||
// 使用SetEX一次性设置值和过期时间
|
||
err = g.Redis().SetEX(ctx, hourKey, 1, int64(ttl.Seconds()))
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// 检查是否超过限制
|
||
maxRequests := int64(config.RequestsPerSecond * config.Window.Seconds())
|
||
if currentCount >= maxRequests {
|
||
return false, nil
|
||
}
|
||
|
||
// 增加计数
|
||
_, err = g.Redis().Incr(ctx, hourKey)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return true, nil
|
||
}
|
||
|
||
// GetDefaultTenantRateLimitConfig 获取默认的租户限流配置
|
||
func (s *rateLimitService) getDefaultTenantRateLimitConfig(tenantID int64) *TenantRateLimitConfig {
|
||
// 从配置文件中读取限流参数
|
||
ctx := context.Background()
|
||
|
||
// 检查是否启用租户限流
|
||
enabled := g.Cfg().MustGet(ctx, "tenantRateLimit.enabled", false).Bool()
|
||
if !enabled {
|
||
// 如果未启用,返回一个很大的限制值,相当于不限制
|
||
return &TenantRateLimitConfig{
|
||
TenantID: tenantID,
|
||
RequestsPerSecond: 10000, // 每秒10000个请求,相当于不限制
|
||
Burst: 20000, // 突发20000个请求
|
||
Window: time.Hour,
|
||
}
|
||
}
|
||
|
||
// 从配置文件中获取限流参数
|
||
requestsPerHour := g.Cfg().MustGet(ctx, "tenantRateLimit.requestsPerHour", 3600).Int64()
|
||
windowSeconds := g.Cfg().MustGet(ctx, "tenantRateLimit.window", 3600).Int64()
|
||
burst := g.Cfg().MustGet(ctx, "tenantRateLimit.burst", 100).Int()
|
||
|
||
// 转换为每秒请求数
|
||
requestsPerSecond := float64(requestsPerHour) / float64(windowSeconds)
|
||
|
||
return &TenantRateLimitConfig{
|
||
TenantID: tenantID,
|
||
RequestsPerSecond: requestsPerSecond,
|
||
Burst: burst,
|
||
Window: time.Duration(windowSeconds) * time.Second,
|
||
}
|
||
}
|
||
|
||
// SetTenantRateLimitConfig 设置租户限流配置
|
||
func (s *rateLimitService) SetTenantRateLimitConfig(ctx context.Context, config *TenantRateLimitConfig) error {
|
||
// 注意:实际使用的是config.yml中的全局配置,此方法仅用于兼容旧API
|
||
// 实际限流参数请修改config.yml中的tenantRateLimit部分
|
||
return nil
|
||
}
|
||
|
||
// GetTenantCurrentUsage 获取租户当前请求使用情况
|
||
func (s *rateLimitService) GetTenantCurrentUsage(ctx context.Context, tenantID int64, config *TenantRateLimitConfig) (current int64, max int64, err error) {
|
||
if config == nil {
|
||
config = s.getDefaultTenantRateLimitConfig(tenantID)
|
||
}
|
||
|
||
// 构建当前小时的Redis键
|
||
now := time.Now()
|
||
hourKey := fmt.Sprintf("%s%d:%d", consts.AdRequestLimitKeyPrefix, tenantID, now.Hour())
|
||
|
||
// 获取当前计数
|
||
currentVar, err := g.Redis().Get(ctx, hourKey)
|
||
if err != nil && err.Error() == "redis: nil" {
|
||
current = 0
|
||
err = nil
|
||
} else if err != nil {
|
||
return 0, 0, err
|
||
} else {
|
||
current = currentVar.Int64()
|
||
}
|
||
|
||
max = int64(config.RequestsPerSecond * config.Window.Seconds())
|
||
return current, max, nil
|
||
}
|