Files
cid/service/rate_limit_service.go

137 lines
4.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"time"
"cid/consts"
"github.com/gogf/gf/v2/frame/g"
)
type rateLimit struct{}
// RateLimit 限流服务
var RateLimit = new(rateLimit)
// TenantRateLimitConfig 租户限流配置
type TenantRateLimitConfig struct {
TenantID int64 // 租户ID
RequestsPerSecond float64 // 每秒请求数
Burst int // 突发请求数
Window time.Duration // 时间窗口
}
// CheckTenantRequestLimit 检查租户请求次数限制
func (s *rateLimit) CheckTenantRequestLimit(ctx context.Context, tenantID int64, config *TenantRateLimitConfig) (bool, error) {
if config == nil {
// 使用默认配置
config = s.getDefaultTenantRateLimitConfig(tenantID)
}
// 构建Redis键 - 使用当前小时的键,确保按小时计数
now := time.Now()
hourKey := fmt.Sprintf("%s%d:%d", consts.AdRequestLimitKeyPrefix, tenantID, now.Hour())
// 获取当前计数
currentCountVar, err := g.Redis().Get(ctx, hourKey)
if err != nil && err.Error() != "redis: nil" {
return false, err
}
currentCount := currentCountVar.Int64()
// 如果是第一次请求,设置计数和过期时间(到下一个小时)
if currentCount == 0 {
// 设置过期时间为到下一个小时的剩余时间
nextHour := now.Truncate(time.Hour).Add(time.Hour)
ttl := nextHour.Sub(now)
// 使用SetEX一次性设置值和过期时间
err = g.Redis().SetEX(ctx, hourKey, 1, int64(ttl.Seconds()))
if err != nil {
return false, err
}
return true, nil
}
// 检查是否超过限制
maxRequests := int64(config.RequestsPerSecond * config.Window.Seconds())
if currentCount >= maxRequests {
return false, nil
}
// 增加计数
_, err = g.Redis().Incr(ctx, hourKey)
if err != nil {
return false, err
}
return true, nil
}
// GetDefaultTenantRateLimitConfig 获取默认的租户限流配置
func (s *rateLimit) getDefaultTenantRateLimitConfig(tenantID int64) *TenantRateLimitConfig {
// 从配置文件中读取限流参数
ctx := context.Background()
// 检查是否启用租户限流
enabled := g.Cfg().MustGet(ctx, "tenantRateLimit.enabled", false).Bool()
if !enabled {
// 如果未启用,返回一个很大的限制值,相当于不限制
return &TenantRateLimitConfig{
TenantID: tenantID,
RequestsPerSecond: 10000, // 每秒10000个请求相当于不限制
Burst: 20000, // 突发20000个请求
Window: time.Hour,
}
}
// 从配置文件中获取限流参数
requestsPerHour := g.Cfg().MustGet(ctx, "tenantRateLimit.requestsPerHour", 3600).Int64()
windowSeconds := g.Cfg().MustGet(ctx, "tenantRateLimit.window", 3600).Int64()
burst := g.Cfg().MustGet(ctx, "tenantRateLimit.burst", 100).Int()
// 转换为每秒请求数
requestsPerSecond := float64(requestsPerHour) / float64(windowSeconds)
return &TenantRateLimitConfig{
TenantID: tenantID,
RequestsPerSecond: requestsPerSecond,
Burst: burst,
Window: time.Duration(windowSeconds) * time.Second,
}
}
// SetTenantRateLimitConfig 设置租户限流配置
func (s *rateLimit) SetTenantRateLimitConfig(ctx context.Context, config *TenantRateLimitConfig) error {
// 注意实际使用的是config.yml中的全局配置此方法仅用于兼容旧API
// 实际限流参数请修改config.yml中的tenantRateLimit部分
return nil
}
// GetTenantCurrentUsage 获取租户当前请求使用情况
func (s *rateLimit) GetTenantCurrentUsage(ctx context.Context, tenantID int64, config *TenantRateLimitConfig) (current int64, max int64, err error) {
if config == nil {
config = s.getDefaultTenantRateLimitConfig(tenantID)
}
// 构建当前小时的Redis键
now := time.Now()
hourKey := fmt.Sprintf("%s%d:%d", consts.AdRequestLimitKeyPrefix, tenantID, now.Hour())
// 获取当前计数
currentVar, err := g.Redis().Get(ctx, hourKey)
if err != nil && err.Error() == "redis: nil" {
current = 0
err = nil
} else if err != nil {
return 0, 0, err
} else {
current = currentVar.Int64()
}
max = int64(config.RequestsPerSecond * config.Window.Seconds())
return current, max, nil
}