230 lines
5.1 KiB
Go
230 lines
5.1 KiB
Go
package util
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/gogf/gf/v2/container/gvar"
|
|
)
|
|
|
|
// enWordRegex matches one maximal run of ASCII letters; each match is counted
// as one English word by countEnglishWords.
var enWordRegex = regexp.MustCompile(`[A-Za-z]+`)

// punctRegex matches a single character of the POSIX punct class, which covers
// ASCII punctuation only; full-width punctuation does not match.
var punctRegex = regexp.MustCompile(`[[:punct:]]`)
|
|
|
|
// TokenConfig holds the settings used to estimate token usage and to derive
// the usable context window.
type TokenConfig struct {
	// Per-category weights: each one is multiplied by the number of matches
	// of its category found in the text, and the products are summed.
	ZhRatio          float64 `json:"zh_ratio"`          // weight per Han character
	EnRatio          float64 `json:"en_ratio"`          // weight per run of ASCII letters
	SpaceRatio       float64 `json:"space_ratio"`       // weight per " " character
	PunctuationRatio float64 `json:"punctuation_ratio"` // weight per ASCII punctuation character

	// Window sizing: the reserve is the larger of MaxWindowSize*ReserveRatio
	// and MinReserve, and is subtracted from MaxWindowSize.
	MaxWindowSize int     `json:"max_window_size"` // total window size, in tokens
	ReserveRatio  float64 `json:"reserve_ratio"`   // fraction of the window held back
	MinReserve    int     `json:"min_reserve"`     // lower bound on the held-back amount
}
|
|
|
|
// CalculateTokens 计算文本token数
|
|
func CalculateTokens(text string, tokenConfig any) int {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil {
|
|
return 0
|
|
}
|
|
|
|
if text == "" {
|
|
return 0
|
|
}
|
|
|
|
zhCount := countChineseChars(text)
|
|
enCount := countEnglishWords(text)
|
|
spaceCount := strings.Count(text, " ")
|
|
punctCount := countPunctuation(text)
|
|
|
|
totalTokens := int(
|
|
float64(zhCount)*config.ZhRatio +
|
|
float64(enCount)*config.EnRatio +
|
|
float64(spaceCount)*config.SpaceRatio +
|
|
float64(punctCount)*config.PunctuationRatio,
|
|
)
|
|
|
|
return totalTokens
|
|
}
|
|
|
|
// CountToken 计算token是否超出窗口限制
|
|
// 返回: true - 未超出(可用), false - 已超出(不可用)
|
|
func CountToken(text string, tokenConfig any) bool {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil {
|
|
return false
|
|
}
|
|
|
|
estimatedTokens := CalculateTokens(text, tokenConfig)
|
|
availableWindow := GetAvailableWindow(tokenConfig)
|
|
|
|
return estimatedTokens <= availableWindow
|
|
}
|
|
|
|
// GetAvailableWindow 获取可用窗口大小
|
|
func GetAvailableWindow(tokenConfig any) int {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil {
|
|
return 4096
|
|
}
|
|
|
|
reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio)
|
|
reserve := reserveByRatio
|
|
|
|
if config.MinReserve > reserve {
|
|
reserve = config.MinReserve
|
|
}
|
|
|
|
available := config.MaxWindowSize - reserve
|
|
if available < 0 {
|
|
available = 0
|
|
}
|
|
|
|
return available
|
|
}
|
|
|
|
// GetMaxWindowSize 获取模型最大窗口大小
|
|
func GetMaxWindowSize(tokenConfig any) int {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil {
|
|
return 4096
|
|
}
|
|
|
|
return config.MaxWindowSize
|
|
}
|
|
|
|
// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内
|
|
// 返回: isValid, exceedTokens, error
|
|
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil || len(userForm) == 0 {
|
|
return true, 0, nil
|
|
}
|
|
|
|
totalTokens := calculateUserFormTokens(userForm, tokenConfig)
|
|
availableWindow := GetAvailableWindow(tokenConfig)
|
|
|
|
if totalTokens > availableWindow {
|
|
return false, totalTokens - availableWindow, nil
|
|
}
|
|
|
|
return true, 0, nil
|
|
}
|
|
|
|
// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内
|
|
// 返回: 需要拆分的批次数, 每批的token数, 错误
|
|
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
|
|
config := parseConfig(tokenConfig)
|
|
if config == nil || len(userForm) == 0 {
|
|
return 1, nil, nil
|
|
}
|
|
|
|
availableWindow := GetAvailableWindow(tokenConfig)
|
|
|
|
batches := 1
|
|
currentTokens := 0
|
|
batchTokens := make([]int, 0)
|
|
|
|
for _, item := range userForm {
|
|
itemStr := fmt.Sprintf("%v", item)
|
|
itemTokens := CalculateTokens(itemStr, tokenConfig)
|
|
|
|
if currentTokens+itemTokens > availableWindow {
|
|
batchTokens = append(batchTokens, currentTokens)
|
|
batches++
|
|
currentTokens = itemTokens
|
|
} else {
|
|
currentTokens += itemTokens
|
|
}
|
|
}
|
|
|
|
if currentTokens > 0 {
|
|
batchTokens = append(batchTokens, currentTokens)
|
|
}
|
|
|
|
return batches, batchTokens, nil
|
|
}
|
|
|
|
// parseConfig 解析配置
|
|
func parseConfig(tokenConfig any) *TokenConfig {
|
|
if tokenConfig == nil {
|
|
return nil
|
|
}
|
|
|
|
switch v := tokenConfig.(type) {
|
|
case *gvar.Var:
|
|
return parseGVarConfig(v)
|
|
case map[string]any:
|
|
return parseMapConfig(v)
|
|
case *TokenConfig:
|
|
return v
|
|
case TokenConfig:
|
|
return &v
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// parseGVarConfig 解析 GVar 配置
|
|
func parseGVarConfig(v *gvar.Var) *TokenConfig {
|
|
if v.IsNil() {
|
|
return nil
|
|
}
|
|
|
|
mapVal := v.Map()
|
|
if mapVal == nil {
|
|
return nil
|
|
}
|
|
|
|
config := &TokenConfig{}
|
|
data, _ := json.Marshal(mapVal)
|
|
json.Unmarshal(data, config)
|
|
|
|
return config
|
|
}
|
|
|
|
// parseMapConfig 解析 Map 配置
|
|
func parseMapConfig(v map[string]any) *TokenConfig {
|
|
config := &TokenConfig{}
|
|
data, _ := json.Marshal(v)
|
|
json.Unmarshal(data, config)
|
|
|
|
return config
|
|
}
|
|
|
|
// countChineseChars returns how many runes in text belong to the Unicode Han
// script.
func countChineseChars(text string) int {
	var han int
	for _, r := range text {
		if !unicode.In(r, unicode.Han) {
			continue
		}
		han++
	}
	return han
}
|
|
|
|
// countEnglishWords 统计英文单词数量
|
|
func countEnglishWords(text string) int {
|
|
return len(enWordRegex.FindAllString(text, -1))
|
|
}
|
|
|
|
// countPunctuation 统计标点符号数量
|
|
func countPunctuation(text string) int {
|
|
return len(punctRegex.FindAllString(text, -1))
|
|
}
|
|
|
|
// calculateUserFormTokens 计算 UserForm 总 token 数
|
|
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
|
|
totalTokens := 0
|
|
for _, item := range userForm {
|
|
itemStr := fmt.Sprintf("%v", item)
|
|
totalTokens += CalculateTokens(itemStr, tokenConfig)
|
|
}
|
|
return totalTokens
|
|
}
|