package util import ( "encoding/json" "fmt" "regexp" "strings" "unicode" "github.com/gogf/gf/v2/container/gvar" ) var ( enWordRegex = regexp.MustCompile(`[A-Za-z]+`) punctRegex = regexp.MustCompile(`[[:punct:]]`) ) // TokenConfig Token计算配置 type TokenConfig struct { ZhRatio float64 `json:"zh_ratio"` EnRatio float64 `json:"en_ratio"` SpaceRatio float64 `json:"space_ratio"` PunctuationRatio float64 `json:"punctuation_ratio"` MaxWindowSize int `json:"max_window_size"` ReserveRatio float64 `json:"reserve_ratio"` MinReserve int `json:"min_reserve"` } // CalculateTokens 计算文本token数 func CalculateTokens(text string, tokenConfig any) int { config := parseConfig(tokenConfig) if config == nil { return 0 } if text == "" { return 0 } zhCount := countChineseChars(text) enCount := countEnglishWords(text) spaceCount := strings.Count(text, " ") punctCount := countPunctuation(text) totalTokens := int( float64(zhCount)*config.ZhRatio + float64(enCount)*config.EnRatio + float64(spaceCount)*config.SpaceRatio + float64(punctCount)*config.PunctuationRatio, ) return totalTokens } // CountToken 计算token是否超出窗口限制 // 返回: true - 未超出(可用), false - 已超出(不可用) func CountToken(text string, tokenConfig any) bool { config := parseConfig(tokenConfig) if config == nil { return false } estimatedTokens := CalculateTokens(text, tokenConfig) availableWindow := GetAvailableWindow(tokenConfig) return estimatedTokens <= availableWindow } // GetAvailableWindow 获取可用窗口大小 func GetAvailableWindow(tokenConfig any) int { config := parseConfig(tokenConfig) if config == nil { return 4096 } reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio) reserve := reserveByRatio if config.MinReserve > reserve { reserve = config.MinReserve } available := config.MaxWindowSize - reserve if available < 0 { available = 0 } return available } // GetMaxWindowSize 获取模型最大窗口大小 func GetMaxWindowSize(tokenConfig any) int { config := parseConfig(tokenConfig) if config == nil { return 4096 } return config.MaxWindowSize } // CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内 // 返回: isValid, exceedTokens, error func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) { config := parseConfig(tokenConfig) if config == nil || len(userForm) == 0 { return true, 0, nil } totalTokens := calculateUserFormTokens(userForm, tokenConfig) availableWindow := GetAvailableWindow(tokenConfig) if totalTokens > availableWindow { return false, totalTokens - availableWindow, nil } return true, 0, nil } // CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内 // 返回: 需要拆分的批次数, 每批的token数, 错误 func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) { config := parseConfig(tokenConfig) if config == nil || len(userForm) == 0 { return 1, nil, nil } availableWindow := GetAvailableWindow(tokenConfig) batches := 1 currentTokens := 0 batchTokens := make([]int, 0) for _, item := range userForm { itemStr := fmt.Sprintf("%v", item) itemTokens := CalculateTokens(itemStr, tokenConfig) if currentTokens+itemTokens > availableWindow { batchTokens = append(batchTokens, currentTokens) batches++ currentTokens = itemTokens } else { currentTokens += itemTokens } } if currentTokens > 0 { batchTokens = append(batchTokens, currentTokens) } return batches, batchTokens, nil } // parseConfig 解析配置 func parseConfig(tokenConfig any) *TokenConfig { if tokenConfig == nil { return nil } switch v := tokenConfig.(type) { case *gvar.Var: return parseGVarConfig(v) case map[string]any: return parseMapConfig(v) case *TokenConfig: return v case TokenConfig: return &v default: return nil } } // parseGVarConfig 解析 GVar 配置 func parseGVarConfig(v *gvar.Var) *TokenConfig { if v.IsNil() { return nil } mapVal := v.Map() if mapVal == nil { return nil } config := &TokenConfig{} data, _ := json.Marshal(mapVal) json.Unmarshal(data, config) return config } // parseMapConfig 解析 Map 配置 func parseMapConfig(v map[string]any) *TokenConfig { config := &TokenConfig{} data, _ := json.Marshal(v) json.Unmarshal(data, config) return config } // countChineseChars 统计中文字符数量 func countChineseChars(text string) int { count := 0 for _, r := range text { if unicode.Is(unicode.Han, r) { count++ } } return count } // countEnglishWords 统计英文单词数量 func countEnglishWords(text string) int { return len(enWordRegex.FindAllString(text, -1)) } // countPunctuation 统计标点符号数量 func countPunctuation(text string) int { return len(punctRegex.FindAllString(text, -1)) } // calculateUserFormTokens 计算 UserForm 总 token 数 func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int { totalTokens := 0 for _, item := range userForm { itemStr := fmt.Sprintf("%v", item) totalTokens += CalculateTokens(itemStr, tokenConfig) } return totalTokens }