Files
prompts-core/common/util/token.go

230 lines
5.1 KiB
Go

package util
import (
"encoding/json"
"fmt"
"regexp"
"strings"
"unicode"
"github.com/gogf/gf/v2/container/gvar"
)
var (
// enWordRegex matches one maximal run of ASCII letters (A-Z, a-z). It is
// compiled once at package initialization and used by countEnglishWords.
enWordRegex = regexp.MustCompile(`[A-Za-z]+`)
// punctRegex matches a single character of the [[:punct:]] class. In Go's
// regexp syntax that class is ASCII-only, so full-width punctuation such as
// "，" or "。" is not matched. It is used by countPunctuation.
punctRegex = regexp.MustCompile(`[[:punct:]]`)
)
// TokenConfig holds the settings used for token estimation and window
// budgeting. The four ratios are per-unit weights that CalculateTokens
// multiplies by the matching counts; the remaining fields are read by
// GetAvailableWindow and GetMaxWindowSize.
type TokenConfig struct {
// ZhRatio is the weight applied to each Han character in the text.
ZhRatio float64 `json:"zh_ratio"`
// EnRatio is the weight applied to each English word (a run of ASCII letters).
EnRatio float64 `json:"en_ratio"`
// SpaceRatio is the weight applied to each ASCII space character.
SpaceRatio float64 `json:"space_ratio"`
// PunctuationRatio is the weight applied to each punctuation character.
PunctuationRatio float64 `json:"punctuation_ratio"`
// MaxWindowSize is the maximum window size; the reserve is subtracted from it.
MaxWindowSize int `json:"max_window_size"`
// ReserveRatio is the fraction of MaxWindowSize that is held back as reserve.
ReserveRatio float64 `json:"reserve_ratio"`
// MinReserve is the lower bound of the reserve; it wins when the
// ratio-based reserve is smaller.
MinReserve int `json:"min_reserve"`
}
// CalculateTokens estimates the number of tokens in text.
//
// The estimate is a weighted sum of four counts taken from text: Han
// characters, English words, ASCII spaces and punctuation characters, each
// multiplied by the corresponding ratio of the resolved TokenConfig. The sum
// is truncated toward zero by the final int conversion.
//
// tokenConfig may be any value accepted by parseConfig. The function returns
// 0 when text is empty or when no config can be resolved.
func CalculateTokens(text string, tokenConfig any) int {
	if text == "" {
		return 0
	}
	cfg := parseConfig(tokenConfig)
	if cfg == nil {
		return 0
	}

	hanChars := float64(countChineseChars(text))
	enWords := float64(countEnglishWords(text))
	spaces := float64(strings.Count(text, " "))
	puncts := float64(countPunctuation(text))

	return int(
		hanChars*cfg.ZhRatio +
			enWords*cfg.EnRatio +
			spaces*cfg.SpaceRatio +
			puncts*cfg.PunctuationRatio,
	)
}
// CountToken reports whether text fits inside the available token window.
//
// It returns true when the estimate from CalculateTokens is less than or
// equal to the value of GetAvailableWindow, and false when the estimate is
// larger or when tokenConfig cannot be resolved by parseConfig.
func CountToken(text string, tokenConfig any) bool {
	config := parseConfig(tokenConfig)
	if config == nil {
		return false
	}
	// Pass the already-resolved *TokenConfig down. parseConfig returns a
	// *TokenConfig unchanged, so the helpers do not repeat the JSON round trip
	// that map and gvar inputs require.
	return CalculateTokens(text, config) <= GetAvailableWindow(config)
}
// GetAvailableWindow returns the window size that remains after the reserve
// has been subtracted from MaxWindowSize.
//
// The reserve is the larger of MaxWindowSize*ReserveRatio (truncated to int)
// and MinReserve. The result is never negative. When tokenConfig cannot be
// resolved by parseConfig the function returns 4096.
func GetAvailableWindow(tokenConfig any) int {
	cfg := parseConfig(tokenConfig)
	if cfg == nil {
		return 4096
	}

	reserved := int(float64(cfg.MaxWindowSize) * cfg.ReserveRatio)
	if reserved < cfg.MinReserve {
		reserved = cfg.MinReserve
	}

	// Clamp at zero: a reserve larger than the window leaves nothing usable.
	if remaining := cfg.MaxWindowSize - reserved; remaining > 0 {
		return remaining
	}
	return 0
}
// GetMaxWindowSize returns the MaxWindowSize of the resolved config, or 4096
// when tokenConfig cannot be resolved by parseConfig.
func GetMaxWindowSize(tokenConfig any) int {
	if cfg := parseConfig(tokenConfig); cfg != nil {
		return cfg.MaxWindowSize
	}
	return 4096
}
// CheckUserFormWithinWindow checks whether the whole userForm fits inside the
// available token window.
//
// It returns (isValid, exceedTokens, err):
//   - isValid is true when the summed estimate of all items is at most the
//     value of GetAvailableWindow, when userForm is empty, or when
//     tokenConfig cannot be resolved by parseConfig.
//   - exceedTokens is the number of tokens above the window when isValid is
//     false, otherwise 0.
//   - err is always nil in the current implementation.
func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) {
	config := parseConfig(tokenConfig)
	if config == nil || len(userForm) == 0 {
		return true, 0, nil
	}

	// Pass the resolved *TokenConfig down so the helpers do not decode the
	// raw config again (parseConfig returns a *TokenConfig unchanged).
	totalTokens := calculateUserFormTokens(userForm, config)
	if exceed := totalTokens - GetAvailableWindow(config); exceed > 0 {
		return false, exceed, nil
	}
	return true, 0, nil
}
// CheckUserFormBatchWithinWindow greedily packs the items of userForm, in
// order, into batches whose estimated token totals fit the available window.
//
// It returns (batches, batchTokens, err):
//   - batches is the number of batches needed.
//   - batchTokens holds the estimated token total of each batch, in order.
//     For a non-empty userForm with a resolvable config,
//     len(batchTokens) == batches.
//   - err is always nil in the current implementation.
//
// When userForm is empty or tokenConfig cannot be resolved by parseConfig the
// result is (1, nil, nil).
//
// A single item larger than the window cannot be split here; it is placed in
// a batch of its own, and that batch's total exceeds the window.
func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) {
	config := parseConfig(tokenConfig)
	if config == nil || len(userForm) == 0 {
		return 1, nil, nil
	}

	// Resolve the config once and pass the *TokenConfig down, so the raw
	// config is not decoded again for every item.
	availableWindow := GetAvailableWindow(config)

	batchTokens := make([]int, 0, 1)
	currentTokens, currentItems := 0, 0
	for _, item := range userForm {
		itemTokens := CalculateTokens(fmt.Sprintf("%v", item), config)

		// Close the current batch only if it already holds an item. Without
		// that guard an oversized item arriving at an empty batch would emit
		// a phantom 0-token batch and inflate the batch count.
		if currentItems > 0 && currentTokens+itemTokens > availableWindow {
			batchTokens = append(batchTokens, currentTokens)
			currentTokens, currentItems = 0, 0
		}
		currentTokens += itemTokens
		currentItems++
	}

	// userForm is non-empty, so the last batch always holds at least one item.
	batchTokens = append(batchTokens, currentTokens)
	return len(batchTokens), batchTokens, nil
}
// parseConfig resolves a loosely typed config value into a *TokenConfig.
//
// Supported inputs:
//   - *TokenConfig: returned as is (not copied)
//   - TokenConfig: a pointer to the copy held by this call is returned
//   - map[string]any: decoded by parseMapConfig
//   - *gvar.Var: decoded by parseGVarConfig
//
// A nil interface or any other type yields nil.
func parseConfig(tokenConfig any) *TokenConfig {
	switch cfg := tokenConfig.(type) {
	case *TokenConfig:
		return cfg
	case TokenConfig:
		return &cfg
	case map[string]any:
		return parseMapConfig(cfg)
	case *gvar.Var:
		return parseGVarConfig(cfg)
	}
	// A nil interface matches no case above and ends up here as well.
	return nil
}
// parseGVarConfig decodes a *gvar.Var into a *TokenConfig.
//
// It returns nil when v reports IsNil or when v.Map() yields a nil map.
// Otherwise the map is decoded by parseMapConfig, so both input kinds share
// one decoding path instead of two copies of the same marshal/unmarshal code.
func parseGVarConfig(v *gvar.Var) *TokenConfig {
	if v.IsNil() {
		return nil
	}
	mapVal := v.Map()
	if mapVal == nil {
		return nil
	}
	return parseMapConfig(mapVal)
}
// parseMapConfig decodes a generic map into a *TokenConfig by round-tripping
// it through JSON, so keys are matched against the struct's json tags and
// unknown keys are ignored.
//
// It returns nil when the map cannot be marshaled or when a value does not
// fit the target field type. json.Unmarshal does not guarantee that the
// remaining fields are filled after such a mismatch, so a half-populated
// config is rejected rather than returned. Callers already treat nil as
// "no usable config".
//
// A nil or empty map decodes to a zero-valued, non-nil config.
func parseMapConfig(v map[string]any) *TokenConfig {
	data, err := json.Marshal(v)
	if err != nil {
		return nil
	}
	config := &TokenConfig{}
	if err := json.Unmarshal(data, config); err != nil {
		return nil
	}
	return config
}
// countChineseChars returns the number of runes in text that belong to the
// Unicode Han script.
func countChineseChars(text string) int {
	var han int
	for _, ch := range text {
		if !unicode.Is(unicode.Han, ch) {
			continue
		}
		han++
	}
	return han
}
// countEnglishWords returns the number of non-overlapping matches of
// enWordRegex in text, i.e. the number of runs of ASCII letters.
func countEnglishWords(text string) int {
	words := enWordRegex.FindAllString(text, -1)
	return len(words)
}
// countPunctuation returns the number of matches of punctRegex in text, i.e.
// the number of individual punctuation characters that pattern recognizes.
func countPunctuation(text string) int {
	marks := punctRegex.FindAllString(text, -1)
	return len(marks)
}
// calculateUserFormTokens 计算 UserForm 总 token 数
func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int {
totalTokens := 0
for _, item := range userForm {
itemStr := fmt.Sprintf("%v", item)
totalTokens += CalculateTokens(itemStr, tokenConfig)
}
return totalTokens
}