From 35bc3bd6ec50c00d78ec5d7ef5e46d914255ce50 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 20 May 2026 11:36:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor(prompt):=20=E9=87=8D=E6=9E=84=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E6=9E=84=E5=BB=BA=E6=9C=8D=E5=8A=A1=E4=B8=8E?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 72 +- common/util/config.go | 11 +- common/util/files.go | 86 +-- common/util/headers.go | 53 +- common/util/json.go | 120 +++- common/util/token.go | 229 ++++++ config.yml | 100 ++- consts/public/public.go | 5 + .../{prompt => }/prompt_compose_controller.go | 10 +- .../{prompt => }/prompt_session_controller.go | 6 +- dao/compose_session_dao.go | 2 +- dao/provider_protocol_dao.go | 9 +- main.go | 6 +- model/dto/{prompt => }/prompt_compose_dto.go | 28 +- model/dto/{prompt => }/prompt_session_dto.go | 2 +- model/entity/asynch_model.go | 15 +- service/prompt/prompt_build_service.go | 203 ++++-- service/prompt/prompt_compose_service.go | 662 +++++++++++------- service/prompt/prompt_files_handle_service.go | 226 +++--- .../prompt_files_handle_service.markdown | 75 ++ service/prompt/prompt_ir_service.go | 141 ++-- .../prompt/prompt_session_redis_service.go | 93 ++- service/prompt/prompt_session_service.go | 152 ++-- service/prompt/prompt_user_form_batches.go | 135 ++++ 24 files changed, 1682 insertions(+), 759 deletions(-) create mode 100644 common/util/token.go rename controller/{prompt => }/prompt_compose_controller.go (51%) rename controller/{prompt => }/prompt_session_controller.go (54%) rename model/dto/{prompt => }/prompt_compose_dto.go (57%) rename model/dto/{prompt => }/prompt_session_dto.go (95%) create mode 100644 service/prompt/prompt_files_handle_service.markdown create mode 100644 service/prompt/prompt_user_form_batches.go diff --git a/README.md b/README.md index 4c0d610..90d2286 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,54 @@ -# prompts-core(提示词服务)[2026.5.12前,暂时弃置] +# Prompts-Core 提示词核心服务 +## 项目简介 +Prompts-Core 是基于 Go 语言开发的**多模态 AI 提示词构建与管理系统**,专注于统一管理各类 AI 模型的提示词模板、维护智能会话上下文、适配主流模型协议,并支持文件解析与外部技能集成,为 AI 应用提供标准化、高效的提示词服务。 -## 1. 功能范围(当前阶段) -- 仅做提示词配置的基础 CRUD(最小可用版本) -- 表:`prompts_model_prompt` +## 核心功能 +1. **提示词构建引擎** + 支持文字/图片/音频/向量化/全模态 5 类任务提示词生成,提供完整流程、分步节点两种构建模式,支持超大内容按 Token 自动分批处理。 +2. **智能会话管理** + 基于缓存实现高效会话存储,自动控制会话轮数与过期时间,保障上下文连贯性。 +3. **多模型协议适配** + 动态适配 OpenAI、DeepSeek、Qwen、Gemini 等主流 AI 模型协议,支持角色、字段、消息顺序灵活映射。 +4. **文件与技能集成** + 自动提取文本、ZIP 压缩包内容,支持加载外部 Markdown 技能配置,扩展服务能力。 +5. **异步任务调度** + 支持异步任务处理、状态轮询与回调通知,自带可配置重试机制。 -## 2. 接口 -> 路由注册方式与参考项目一致:使用 `common/http.RouteRegister` 注册 controller。 +## 技术架构 +- 开发语言:Go 1.26.0 +- Web 框架:GoFrame v2.10.0 +- 核心存储:Redis(会话缓存) +- 服务组件:Consul(服务注册)、Jaeger(链路追踪) +- 调用链路:客户端 → Prompts-Core → 模型网关 → AI 模型 -- `POST /composeMessages`:按 `modelTypeId` 读取 `prompt_info + response_json_schema`,`modelName` 作为实际调用的网关模型;结合前端 `form(role/value)` 与 `userfiles` 调用 `model-gateway /task/createTask`,同步等待回调后直接返回最终 `messages` -- `GET /composeMessagesCallback/prompts-core`:`model-gateway` 成功回调接口(真实地址由 `callbackUrl + /bizName` 组成) -- `GET /getComposeTask`:按 `taskId` 查询拼接任务状态和结果 -- `POST /createPrompt`:创建(默认启用) -- `PUT /updatePrompt`:更新 -- `DELETE /deletePrompt`:删除 -- `GET /getPrompt`:详情 -- `POST /listPrompt`:列表分页 +## 快速开始 +### 环境要求 +Go 1.26+、Redis、已部署模型网关服务 -## 3. 数据库初始化 -执行根目录 `update.sql`。 +### 启动步骤 +1. 克隆项目代码 +2. 完成项目配置文件修改 +3. 执行命令启动服务: +```bash +go run main.go +``` -## 4. 运行配置 -配置文件:`config.yml` +## API 接口 +### 基础信息 +- 服务地址:`http://{host}:3009` +- 请求类型:`application/json` +- 认证方式:请求头携带 `Authorization`、`X-User` -### 新增说明 -- `prompts_model_prompt` 去除了 `limit_length` -- 新增 `response_json_schema` -- 新增任务记录表 `prompts_compose_task` -- `callbackUrl` 必须填写 prompts-core 的绝对地址基路径,例如:`http://127.0.0.1:8002/composeMessagesCallback` -- `model-gateway` 实际回调地址为:`callbackUrl/{bizName}`,本项目固定为:`/composeMessagesCallback/prompts-core` +### 核心接口 +1. **提示词拼接接口** + - 地址:`POST /composeMessages` + - 功能:构建提示词并调用模型服务,同步返回结果 +2. **任务状态查询** + - 地址:`GET /getComposeTask` + - 功能:根据任务 ID 查询处理状态与结果 +3. **任务回调接口** + - 地址:`GET /composeMessagesCallback/prompts-core` + - 功能:接收模型服务处理完成回调 +4. **会话同步接口** + - 地址:`POST /sessionCallback` + - 功能:同步更新会话上下文历史 \ No newline at end of file diff --git a/common/util/config.go b/common/util/config.go index 55bfbf3..417de8a 100644 --- a/common/util/config.go +++ b/common/util/config.go @@ -8,11 +8,12 @@ import ( ) // GetModelPrompt 获取请求模型的提示词 -func GetModelPrompt(ctx context.Context, Type int) string { - return g.Cfg().MustGet(ctx, "modelPrompts.types."+gconv.String(Type), "").String() +func GetModelPrompt(ctx context.Context, modelType int) string { + key := "modelPrompts.types." + gconv.String(modelType) + return g.Cfg().MustGet(ctx, key, "").String() } -// GetBuildPrompt 获取构建提示词 -func GetBuildPrompt(ctx context.Context, Type int) string { - return g.Cfg().MustGet(ctx, "buildProject.types."+gconv.String(Type), "").String() +// GetBuildPrompt 获取节点构建提示词 +func GetBuildPrompt(ctx context.Context) string { + return g.Cfg().MustGet(ctx, "nodePrompts", "").String() } diff --git a/common/util/files.go b/common/util/files.go index f59d410..1e7325f 100644 --- a/common/util/files.go +++ b/common/util/files.go @@ -6,38 +6,41 @@ import ( "strings" ) -// AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀 -var AllowedMIMEPrefixes = []string{ - "text/", - "application/json", - "application/xml", - "application/javascript", - "application/x-yaml", - "application/yaml", - "application/toml", - "application/x-httpd-php", - "application/x-sh", - "application/x-python", - "application/x-perl", - "application/x-ruby", -} +var ( + // AllowedMIMEPrefixes 允许的文本类 MIME 类型前缀 + AllowedMIMEPrefixes = []string{ + "text/", + "application/json", + "application/xml", + "application/javascript", + "application/x-yaml", + "application/yaml", + "application/toml", + "application/x-httpd-php", + "application/x-sh", + "application/x-python", + "application/x-perl", + "application/x-ruby", + } -// BannedExtensions 禁止的文件扩展名 -var BannedExtensions = map[string]bool{ - ".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true, - ".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true, - ".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true, - ".wma": true, ".m4a": true, - ".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true, - ".flv": true, ".webm": true, - ".tar": true, ".gz": true, ".rar": true, ".7z": true, - ".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true, - ".class": true, ".pyc": true, - ".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true, - ".ppt": true, ".pptx": true, -} + // BannedExtensions 禁止的文件扩展名 + BannedExtensions = map[string]bool{ + ".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true, + ".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true, + ".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true, + ".wma": true, ".m4a": true, + ".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true, + ".flv": true, ".webm": true, + ".tar": true, ".gz": true, ".rar": true, ".7z": true, + ".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true, + ".class": true, ".pyc": true, + ".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true, + ".ppt": true, ".pptx": true, + } -var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`) + symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`) + multiNewlines = regexp.MustCompile(`\n{3,}`) +) // SanitizeURL 清洗 URL 字符串 func SanitizeURL(raw string) string { @@ -51,25 +54,19 @@ func CleanSymbols(text string) string { text = symbolCleaner.ReplaceAllString(text, "") text = strings.ReplaceAll(text, "\r\n", "\n") text = strings.ReplaceAll(text, "\r", "\n") - text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n") + text = multiNewlines.ReplaceAllString(text, "\n\n") return strings.TrimSpace(text) } // IsBannedExtension 判断是否为禁止的文件扩展名 func IsBannedExtension(url string) bool { - ext := strings.ToLower(filepath.Ext(url)) - if idx := strings.Index(ext, "?"); idx != -1 { - ext = ext[:idx] - } + ext := extractExtension(url) return BannedExtensions[ext] } // IsZipExtension 判断是否为 zip 文件 func IsZipExtension(url string) bool { - ext := strings.ToLower(filepath.Ext(url)) - if idx := strings.Index(ext, "?"); idx != -1 { - ext = ext[:idx] - } + ext := extractExtension(url) return ext == ".zip" } @@ -78,11 +75,22 @@ func IsReadableContentType(contentType string) bool { if contentType == "" { return false } + ct := strings.ToLower(contentType) for _, prefix := range AllowedMIMEPrefixes { if strings.HasPrefix(ct, prefix) { return true } } + return false } + +// extractExtension 提取文件扩展名并清理查询参数 +func extractExtension(url string) string { + ext := strings.ToLower(filepath.Ext(url)) + if idx := strings.Index(ext, "?"); idx != -1 { + ext = ext[:idx] + } + return ext +} diff --git a/common/util/headers.go b/common/util/headers.go index 5615d45..6eedffa 100644 --- a/common/util/headers.go +++ b/common/util/headers.go @@ -10,6 +10,7 @@ import ( // AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失 func AsyncCtx(ctx context.Context) context.Context { asyncCtx := context.WithoutCancel(ctx) + if r := g.RequestFromCtx(ctx); r != nil { if token := r.Header.Get("Authorization"); token != "" { asyncCtx = context.WithValue(asyncCtx, "token", token) @@ -18,9 +19,11 @@ func AsyncCtx(ctx context.Context) context.Context { asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo) } } + if user, err := utils.GetUserInfo(ctx); err == nil && user != nil { asyncCtx = context.WithValue(asyncCtx, "user", user) } + return asyncCtx } @@ -28,25 +31,37 @@ func AsyncCtx(ctx context.Context) context.Context { func ForwardHeaders(ctx context.Context) map[string]string { headers := make(map[string]string) - if token, ok := ctx.Value("token").(string); ok && token != "" { - headers["Authorization"] = token - } - if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" { - headers["X-User-Info"] = x - } + setHeaderFromContext(headers, ctx, "Authorization", "token") + setHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo") + + fallbackToRequestHeaders(headers, ctx) - // 兜底:从请求头获取 - if r := g.RequestFromCtx(ctx); r != nil { - if headers["Authorization"] == "" { - if token := r.Header.Get("Authorization"); token != "" { - headers["Authorization"] = token - } - } - if headers["X-User-Info"] == "" { - if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { - headers["X-User-Info"] = userInfo - } - } - } return headers } + +// setHeaderFromContext 从上下文中设置 header +func setHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) { + if value, ok := ctx.Value(ctxKey).(string); ok && value != "" { + headers[headerKey] = value + } +} + +// fallbackToRequestHeaders 从请求头中获取作为兜底 +func fallbackToRequestHeaders(headers map[string]string, ctx context.Context) { + r := g.RequestFromCtx(ctx) + if r == nil { + return + } + + if headers["Authorization"] == "" { + if token := r.Header.Get("Authorization"); token != "" { + headers["Authorization"] = token + } + } + + if headers["X-User-Info"] == "" { + if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { + headers["X-User-Info"] = userInfo + } + } +} diff --git a/common/util/json.go b/common/util/json.go index c31d600..ed16af0 100644 --- a/common/util/json.go +++ b/common/util/json.go @@ -15,6 +15,7 @@ func ParseOutput(text string) (map[string]any, error) { if err != nil { return nil, fmt.Errorf("解析模型输出失败: %w", err) } + return j.Map(), nil } @@ -23,26 +24,17 @@ func ConvertToMessages(raw any) []map[string]any { if raw == nil { return nil } + j, err := gjson.LoadJson(gconv.Bytes(raw)) if err != nil { return nil } - // 如果有 messages 字段,直接返回 + if j.Contains("messages") { return gconv.Maps(j.Get("messages").Array()) } - // 否则当成单条 message - return []map[string]any{ - j.Map(), - } -} -// IsMessageValid 校验推理结果是否合法 -func IsMessageValid(message map[string]any) bool { - if message == nil { - return false - } - return true + return []map[string]any{j.Map()} } // FormToJSON 将表单数据转换为 JSON 字符串 @@ -50,6 +42,17 @@ func FormToJSON(form map[string]any) string { if form == nil { return "{}" } + + b, _ := json.Marshal(form) + return string(b) +} + +// UserFormToJSON 将用户表单数据转换为 JSON 字符串 +func UserFormToJSON(form []map[string]any) string { + if form == nil { + return "{}" + } + b, _ := json.Marshal(form) return string(b) } @@ -60,39 +63,16 @@ func MustMarshal(v any) string { if err != nil { return "{}" } + return string(b) } -// ParseJSONField 解析 JSON 字段 -func ParseJSONField(field any) any { - var v *gvar.Var - switch val := field.(type) { - case *gvar.Var: - v = val - default: - return field - } - - if v == nil || v.IsNil() || v.IsEmpty() { - return nil - } - - str := v.String() - var result any - if json.Unmarshal([]byte(str), &result) == nil { - return result - } - return str -} - // JSONPretty 将任意类型转为格式化的 JSON 字符串 func JSONPretty(v any) string { - // 处理 *gvar.Var 类型 if gv, ok := v.(*gvar.Var); ok { v = gconv.Map(gv.String()) } - // 统一转 map 再美化 var tmp map[string]any if err := gconv.Struct(v, &tmp); err != nil { return gconv.String(v) @@ -101,3 +81,71 @@ func JSONPretty(v any) string { b, _ := json.MarshalIndent(tmp, "", " ") return string(b) } + +// GvarToMap 将 *gvar.Var 类型转换为 map[string]any +func GvarToMap(v *gvar.Var) map[string]any { + if v == nil || v.IsNil() { + return nil + } + + result := make(map[string]any) + + // 方法1:尝试获取 map 值 + if m := v.Map(); len(m) > 0 { + return m + } + + // 方法2:尝试解析 JSON 字符串 + str := v.String() + if str != "" && str != "" { + json.Unmarshal([]byte(str), &result) + if len(result) > 0 { + return result + } + } + + // 方法3:尝试获取 interface 再转换 + if val := v.Val(); val != nil { + switch val.(type) { + case map[string]any: + return val.(map[string]any) + default: + data, _ := json.Marshal(val) + json.Unmarshal(data, &result) + } + } + + return result +} + +// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析 +func ParseJSONFieldFromGvar(source any, target any) { + if source == nil { + return + } + + switch v := source.(type) { + case *gvar.Var: + if v.IsNil() { + return + } + + // 尝试获取 map + if m := v.Map(); len(m) > 0 { + data, _ := json.Marshal(m) + json.Unmarshal(data, target) + return + } + + // 尝试解析 JSON 字符串 + str := v.String() + if str != "" && str != "" { + json.Unmarshal([]byte(str), target) + } + + default: + // 其他类型走原来的逻辑 + data, _ := json.Marshal(source) + json.Unmarshal(data, target) + } +} diff --git a/common/util/token.go b/common/util/token.go new file mode 100644 index 0000000..167ff73 --- /dev/null +++ b/common/util/token.go @@ -0,0 +1,229 @@ +package util + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + "unicode" + + "github.com/gogf/gf/v2/container/gvar" +) + +var ( + enWordRegex = regexp.MustCompile(`[A-Za-z]+`) + punctRegex = regexp.MustCompile(`[[:punct:]]`) +) + +// TokenConfig Token计算配置 +type TokenConfig struct { + ZhRatio float64 `json:"zh_ratio"` + EnRatio float64 `json:"en_ratio"` + SpaceRatio float64 `json:"space_ratio"` + PunctuationRatio float64 `json:"punctuation_ratio"` + MaxWindowSize int `json:"max_window_size"` + ReserveRatio float64 `json:"reserve_ratio"` + MinReserve int `json:"min_reserve"` +} + +// CalculateTokens 计算文本token数 +func CalculateTokens(text string, tokenConfig any) int { + config := parseConfig(tokenConfig) + if config == nil { + return 0 + } + + if text == "" { + return 0 + } + + zhCount := countChineseChars(text) + enCount := countEnglishWords(text) + spaceCount := strings.Count(text, " ") + punctCount := countPunctuation(text) + + totalTokens := int( + float64(zhCount)*config.ZhRatio + + float64(enCount)*config.EnRatio + + float64(spaceCount)*config.SpaceRatio + + float64(punctCount)*config.PunctuationRatio, + ) + + return totalTokens +} + +// CountToken 计算token是否超出窗口限制 +// 返回: true - 未超出(可用), false - 已超出(不可用) +func CountToken(text string, tokenConfig any) bool { + config := parseConfig(tokenConfig) + if config == nil { + return false + } + + estimatedTokens := CalculateTokens(text, tokenConfig) + availableWindow := GetAvailableWindow(tokenConfig) + + return estimatedTokens <= availableWindow +} + +// GetAvailableWindow 获取可用窗口大小 +func GetAvailableWindow(tokenConfig any) int { + config := parseConfig(tokenConfig) + if config == nil { + return 4096 + } + + reserveByRatio := int(float64(config.MaxWindowSize) * config.ReserveRatio) + reserve := reserveByRatio + + if config.MinReserve > reserve { + reserve = config.MinReserve + } + + available := config.MaxWindowSize - reserve + if available < 0 { + available = 0 + } + + return available +} + +// GetMaxWindowSize 获取模型最大窗口大小 +func GetMaxWindowSize(tokenConfig any) int { + config := parseConfig(tokenConfig) + if config == nil { + return 4096 + } + + return config.MaxWindowSize +} + +// CheckUserFormWithinWindow 校验 UserForm 是否在窗口大小内 +// 返回: isValid, exceedTokens, error +func CheckUserFormWithinWindow(userForm []map[string]any, tokenConfig any) (bool, int, error) { + config := parseConfig(tokenConfig) + if config == nil || len(userForm) == 0 { + return true, 0, nil + } + + totalTokens := calculateUserFormTokens(userForm, tokenConfig) + availableWindow := GetAvailableWindow(tokenConfig) + + if totalTokens > availableWindow { + return false, totalTokens - availableWindow, nil + } + + return true, 0, nil +} + +// CheckUserFormBatchWithinWindow 检查 UserForm 分批是否在窗口内 +// 返回: 需要拆分的批次数, 每批的token数, 错误 +func CheckUserFormBatchWithinWindow(userForm []map[string]any, tokenConfig any) (int, []int, error) { + config := parseConfig(tokenConfig) + if config == nil || len(userForm) == 0 { + return 1, nil, nil + } + + availableWindow := GetAvailableWindow(tokenConfig) + + batches := 1 + currentTokens := 0 + batchTokens := make([]int, 0) + + for _, item := range userForm { + itemStr := fmt.Sprintf("%v", item) + itemTokens := CalculateTokens(itemStr, tokenConfig) + + if currentTokens+itemTokens > availableWindow { + batchTokens = append(batchTokens, currentTokens) + batches++ + currentTokens = itemTokens + } else { + currentTokens += itemTokens + } + } + + if currentTokens > 0 { + batchTokens = append(batchTokens, currentTokens) + } + + return batches, batchTokens, nil +} + +// parseConfig 解析配置 +func parseConfig(tokenConfig any) *TokenConfig { + if tokenConfig == nil { + return nil + } + + switch v := tokenConfig.(type) { + case *gvar.Var: + return parseGVarConfig(v) + case map[string]any: + return parseMapConfig(v) + case *TokenConfig: + return v + case TokenConfig: + return &v + default: + return nil + } +} + +// parseGVarConfig 解析 GVar 配置 +func parseGVarConfig(v *gvar.Var) *TokenConfig { + if v.IsNil() { + return nil + } + + mapVal := v.Map() + if mapVal == nil { + return nil + } + + config := &TokenConfig{} + data, _ := json.Marshal(mapVal) + json.Unmarshal(data, config) + + return config +} + +// parseMapConfig 解析 Map 配置 +func parseMapConfig(v map[string]any) *TokenConfig { + config := &TokenConfig{} + data, _ := json.Marshal(v) + json.Unmarshal(data, config) + + return config +} + +// countChineseChars 统计中文字符数量 +func countChineseChars(text string) int { + count := 0 + for _, r := range text { + if unicode.Is(unicode.Han, r) { + count++ + } + } + return count +} + +// countEnglishWords 统计英文单词数量 +func countEnglishWords(text string) int { + return len(enWordRegex.FindAllString(text, -1)) +} + +// countPunctuation 统计标点符号数量 +func countPunctuation(text string) int { + return len(punctRegex.FindAllString(text, -1)) +} + +// calculateUserFormTokens 计算 UserForm 总 token 数 +func calculateUserFormTokens(userForm []map[string]any, tokenConfig any) int { + totalTokens := 0 + for _, item := range userForm { + itemStr := fmt.Sprintf("%v", item) + totalTokens += CalculateTokens(itemStr, tokenConfig) + } + return totalTokens +} diff --git a/config.yml b/config.yml index 27a9507..a7e28ef 100644 --- a/config.yml +++ b/config.yml @@ -103,55 +103,51 @@ modelPrompts: 在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。 当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。 -buildProject: - types: - 1: | - 你是专业的JSON结构生成专家,必须严格遵守以下全部规则。 - 【强制规则】 - 必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致; - 完整阅读所有文本、规则、表单内容,禁止跳读、漏读; - 完整读取UserForm所有字段,不得忽略任何字段; - 如果有skill相关内容必须完整的将内容拼接到system角色描述中; - 理解全部语义后再输出,禁止断章取义; - UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。 - 【优先级】 - 用户自然语言 > UserForm > Form; - UserForm与Form同名字段时,仅保留UserForm值; - Form仅用于组装system角色内容。 - 【表单处理】 - Form:系统提示词、默认参数、基础配置 → 专属填充system角色; - UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content; - 自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和(示例:10条文案各配5张图 → 总50张,parameters.n=50),用户没有相关数量必须默认1; - 图片尺寸为空时自动填充size=1024*1024。 - 【结构铁律】 - 严格沿用固定输出结构,不增删字段或修改层级; - messages元素必须按结构返回; - 禁止将role对象转为字符串、禁止嵌套错乱; - 输出纯净JSON:无多余转义符、无换行符、无额外字符; - 所有括号、引号必须成对闭合,保证JSON合法。 - 【参数赋值】 - model固定沿用传入值; - 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值; - history历史信息必须结合UserForm里的内容对用户描述部分进行补充; - 从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。 - 【输出要求】 - 仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号; - 完整合UserForm全部字段语义到user描述; - 生成后自检JSON语法、结构、数量;错误则自动重新生成。 - 【输出结构】 - %s - 【字段映射】 - %s - 【完整输入信息】 - %s - 直接输出最终JSON: - 2: | - 你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。 - 规则: - 1. 只允许从下面的可选节点ID列表中选择一个返回 - 2. 不要返回任何多余文字、标点、解释、标题 - 3. 只返回纯节点ID - 可选节点ID(ID: 节点描述): - %s - 上下文内容: - %s \ No newline at end of file +nodePrompts: | + 你是流程路由助手,你的任务是根据上下文,选择一个正确的节点ID返回。 + 规则: + 1. 只允许从下面的可选节点ID列表中选择一个返回 + 2. 不要返回任何多余文字、标点、解释、标题 + 3. 只返回纯节点ID + 可选节点ID(ID: 节点描述): + %s + 上下文内容: + %s + +#你是专业的JSON结构生成专家,必须严格遵守以下全部规则。 +# 【强制规则】 +# 必须根据【输出结构】里面返回的JSON结构进行生成,不得任何更改,最终内容与输出结构返回一致; +# 完整阅读所有文本、规则、表单内容,禁止跳读、漏读; +# 完整读取UserForm所有字段,不得忽略任何字段; +# 如果有skill相关内容必须完整的将内容拼接到system角色描述中; +# 理解全部语义后再输出,禁止断章取义; +# UserForm所有字段内容必须完整拼接赋值到user角色描述中,不得有任何遗漏。 +# 【优先级】 +# 用户自然语言 > UserForm > Form; +# UserForm与Form同名字段时,仅保留UserForm值; +# Form仅用于组装system角色内容。 +# 【表单处理】 +# Form:系统提示词、默认参数、基础配置 → 专属填充system角色; +# UserForm:用户业务输入、文案、配图数量、比例、prompt等 → 全部解析后拼接进user角色content; +# 自动提取UserForm中每条文案的配图数量,总图片数 = 各文案配图数累加求和,用户没有相关数量必须默认1; +# 图片尺寸为空时自动填充size=1024*1024。 +# 【结构铁律】 +# 严格沿用固定输出结构,不增删字段或修改层级; +# messages元素必须按结构返回; +# 禁止将role对象转为字符串、禁止嵌套错乱; +# 输出纯净JSON:无多余转义符、无换行符、无额外字符; +# 所有括号、引号必须成对闭合,保证JSON合法。 +# 【参数赋值】 +# model固定沿用传入值; +# 返回结构里面的参数,需要根据语意进行赋值,缺失补默认值; +# history历史信息必须结合UserForm里的内容对用户描述部分进行补充; +# 从UserForm提取信息整合进user描述,确保数量、尺寸、文案语义无遗漏。 +# 【输出要求】 +# 仅输出单行纯净JSON,无任何解释、备注、Markdown或多余符号; +# 完整合UserForm全部字段语义到user描述; +# 生成后自检JSON语法、结构、数量;错误则自动重新生成。 +# 【输出结构】 +# %s +# 【完整输入信息】 +# %s +# 直接输出最终JSON: \ No newline at end of file diff --git a/consts/public/public.go b/consts/public/public.go index b0936bf..acf3ee1 100644 --- a/consts/public/public.go +++ b/consts/public/public.go @@ -5,3 +5,8 @@ const ( ComposeStatusSuccess = "success" ComposeStatusFailed = "failed" ) + +const ( + BuildTypePrompt = 1 //提示词构建 + BuildTypeNode = 2 //节点构建 +) diff --git a/controller/prompt/prompt_compose_controller.go b/controller/prompt_compose_controller.go similarity index 51% rename from controller/prompt/prompt_compose_controller.go rename to controller/prompt_compose_controller.go index 638a0d8..3d3e200 100644 --- a/controller/prompt/prompt_compose_controller.go +++ b/controller/prompt_compose_controller.go @@ -1,9 +1,9 @@ -package prompt +package controller import ( "context" + "prompts-core/model/dto" - promptDto "prompts-core/model/dto/prompt" promptService "prompts-core/service/prompt" ) @@ -13,17 +13,17 @@ type prompt struct{} var Prompt = new(prompt) // ComposeMessages 调用 model-gateway 异步任务并同步等待结果, -func (c *prompt) ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (res *promptDto.ComposeMessagesRes, err error) { +func (c *prompt) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (res *dto.ComposeMessagesRes, err error) { return promptService.ComposeMessages(ctx, req) } // Callback model-gateway 提示词回调 -func (c *prompt) Callback(ctx context.Context, req *promptDto.CallbackReq) (res *promptDto.CallbackRes, err error) { +func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.CallbackRes, err error) { err = promptService.Callback(ctx, req) return } // GetComposeTask 查询拼接任务结果 -func (c *prompt) GetComposeTask(ctx context.Context, req *promptDto.GetComposeTaskReq) (res *promptDto.GetComposeTaskRes, err error) { +func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) { return promptService.GetComposeTask(ctx, req.TaskId) } diff --git a/controller/prompt/prompt_session_controller.go b/controller/prompt_session_controller.go similarity index 54% rename from controller/prompt/prompt_session_controller.go rename to controller/prompt_session_controller.go index d08b0ce..4bc7613 100644 --- a/controller/prompt/prompt_session_controller.go +++ b/controller/prompt_session_controller.go @@ -1,9 +1,9 @@ -package prompt +package controller import ( "context" + "prompts-core/model/dto" - promptDto "prompts-core/model/dto/prompt" promptService "prompts-core/service/prompt" ) @@ -13,6 +13,6 @@ type session struct{} var Session = new(session) // SessionCallback 会话回调 -func (c *session) SessionCallback(ctx context.Context, req *promptDto.SessionCallbackReq) (res *promptDto.SessionCallbackRes, err error) { +func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { return promptService.SessionCallback(ctx, req) } diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index c7a9787..4477e29 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -15,7 +15,7 @@ type composeSessionDao struct{} // Insert 插入 func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) { - var m = new(entity.ComposeTask) + var m = new(entity.ComposeSession) err = gconv.Struct(req, &m) if err != nil { return diff --git a/dao/provider_protocol_dao.go b/dao/provider_protocol_dao.go index d4808db..e0b970d 100644 --- a/dao/provider_protocol_dao.go +++ b/dao/provider_protocol_dao.go @@ -6,6 +6,7 @@ import ( "prompts-core/model/entity" "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/util/gconv" ) var ProviderProtocol = &providerProtocolDao{} @@ -14,7 +15,13 @@ type providerProtocolDao struct{} // Insert 新增协议配置 func (d *providerProtocolDao) Insert(ctx context.Context, req *entity.ProviderProtocol) (id int64, err error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol).OmitEmpty().Data(req).Insert() + var m = new(entity.ProviderProtocol) + err = gconv.Struct(req, &m) + if err != nil { + return + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameProviderProtocol). + Insert(m) if err != nil { return 0, err } diff --git a/main.go b/main.go index 4c5ac6b..7faf891 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,7 @@ import ( "context" "os" "os/signal" - "prompts-core/controller/prompt" + "prompts-core/controller" "syscall" "gitea.com/red-future/common/http" @@ -21,8 +21,8 @@ func main() { defer jaeger.ShutDown(ctx) // 注册路由 http.RouteRegister([]interface{}{ - prompt.Prompt, - prompt.Session, + controller.Prompt, + controller.Session, }) // 监听退出信号,确保 Ctrl+C 能完整退出并关闭 gateway server diff --git a/model/dto/prompt/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go similarity index 57% rename from model/dto/prompt/prompt_compose_dto.go rename to model/dto/prompt_compose_dto.go index 2e0b211..06d0ae1 100644 --- a/model/dto/prompt/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -1,22 +1,28 @@ -package prompt +package dto import "github.com/gogf/gf/v2/frame/g" type ComposeMessagesReq struct { g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"` - ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` - BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点 - SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"` - Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"` - Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` - UserForm map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"` - SkillName string `p:"skillName" json:"skillName" dc:"技能名称"` - UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"` + ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` + BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点 + SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"` + Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"` + Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` + UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"` + SkillName string `p:"skillName" json:"skillName" dc:"技能名称"` + UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"` } type ComposeMessagesRes struct { - Messages any `json:"messages,omitempty" dc:"最终消息数组"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"` + EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` +} + +// MultiRoundResult 多轮返回结果 +type MultiRoundResult struct { + TotalRounds int `json:"total_rounds"` // 总轮数 + Rounds []any `json:"rounds"` // 每轮详情(动态类型) } type CallbackReq struct { diff --git a/model/dto/prompt/prompt_session_dto.go b/model/dto/prompt_session_dto.go similarity index 95% rename from model/dto/prompt/prompt_session_dto.go rename to model/dto/prompt_session_dto.go index 5901ed7..c0974c6 100644 --- a/model/dto/prompt/prompt_session_dto.go +++ b/model/dto/prompt_session_dto.go @@ -1,4 +1,4 @@ -package prompt +package dto import "github.com/gogf/gf/v2/frame/g" diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index c704254..3232cbf 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -16,10 +16,10 @@ type AsynchModel struct { ResponseBody any `orm:"response_body" json:"responseBody"` TokenMapping string `orm:"token_mapping" json:"tokenMapping"` Prompt string `orm:"prompt" json:"prompt"` - IsPrivate int `orm:"is_private" json:"isPrivate"` - IsChatModel int `orm:"is_chat_model" json:"isChatModel"` + IsPrivate *int `orm:"is_private" json:"isPrivate"` + IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` ApiKey string `orm:"api_key" json:"apiKey"` - Enabled int `orm:"enabled" json:"enabled"` + Enabled *int `orm:"enabled" json:"enabled"` MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` QueueLimit int `orm:"queue_limit" json:"queueLimit"` TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` @@ -28,6 +28,9 @@ type AsynchModel struct { RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` Remark string `orm:"remark" json:"remark"` + IsOwner *int `json:"isOwner" orm:"is_owner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig any `orm:"token_config" json:"tokenConfig"` } type asynchModelCol struct { @@ -55,6 +58,9 @@ type asynchModelCol struct { RetryQueueMaxSecs string AutoCleanSeconds string Remark string + IsOwner string + OperatorName string + TokenConfig string } var AsynchModelCol = asynchModelCol{ @@ -82,4 +88,7 @@ var AsynchModelCol = asynchModelCol{ RetryQueueMaxSecs: "retry_queue_max_seconds", AutoCleanSeconds: "auto_clean_seconds", Remark: "remark", + IsOwner: "is_owner", + OperatorName: "operator_name", + TokenConfig: "token_config", } diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 412556d..bd62a98 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -4,65 +4,113 @@ import ( "context" "errors" "fmt" + "prompts-core/consts/public" "strings" "prompts-core/common/util" "prompts-core/dao" - "prompts-core/model/dto/prompt" + "prompts-core/model/dto" "prompts-core/model/entity" "github.com/gogf/gf/v2/util/gconv" ) -// buildInferenceRequest 构建返回请求 -func buildInferenceRequest(ctx context.Context, req *prompt.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (map[string]any, error) { +// buildInferenceRequest 构建推理请求 +func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) { + processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel) + if err != nil { + return nil, fmt.Errorf("处理用户表单分批失败: %w", err) + } + ir := NewPromptIR() - // 1. 统一 Prompt IR + switch req.BuildType { - case 1: //构建提示词请求 - ir.AddSystem(promptBuild(ctx, req, model)) - for _, msg := range history { - role := gconv.String(msg["role"]) - if role != "user" && role != "assistant" { - continue - } - ir.AddHistory(role, gconv.String(msg["content"])) - } - ir.AddUser(buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, model.ModelType))) - case 2: //构建节点请求 - ir.AddUser(NodeBuild(ctx, req)) + case public.BuildTypePrompt: + return buildPromptTypeRequest(ctx, processedReq, targetModel, history, ir, totalBatches) + case public.BuildTypeNode: + return buildNodeTypeRequest(ctx, req, ir) default: return nil, errors.New("不支持的构建类型") } +} - // 2. 获取协议配置 - protocol, err := GetProtocolByProvider(ctx, "qwen") +// buildPromptTypeRequest 构建提示词类型请求(BuildType=1) +func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { + systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches) + ir.AddSystem(systemPrompt) + + for _, msg := range history { + role := gconv.String(msg["role"]) + if role != "user" && role != "assistant" { + continue + } + ir.AddHistory(role, gconv.String(msg["content"])) + } + + userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType)) + ir.AddUser(userPrompt) + + if !checkOverallContent(ir, targetModel) { + availableWindow := util.GetAvailableWindow(targetModel.TokenConfig) + return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) + } + + return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel) +} + +// buildNodeTypeRequest 构建节点类型请求(BuildType=2) +func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ir *PromptIR) (map[string]any, error) { + ir.AddUser(NodeBuild(ctx, req)) + + protocol, err := GetProtocolByProvider(ctx, req.ModelName) if err != nil { - return nil, err + return nil, fmt.Errorf("获取协议配置失败: %w", err) } if protocol == nil { return nil, errors.New("协议配置不存在") } - // 3. 编译为 Provider Request - providerReq, err := Compile(ir, protocol, chatModel) + providerReq, err := Compile(ir, protocol, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("编译请求失败: %w", err) } - // 4. 构建请求体 return map[string]any{ - "modelName": chatModel.ModelName, + "modelName": req.ModelName, "bizName": "prompts-core", "callbackUrl": "/prompt/callback", "requestPayload": providerReq, }, nil } -// promptBuild 构建系统提示词 -func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *entity.AsynchModel) string { +// compileToProviderRequest 编译为 Provider 请求 +func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, model *entity.AsynchModel) (map[string]any, error) { + protocol, err := GetProtocolByProvider(ctx, providerName) + if err != nil { + return nil, fmt.Errorf("获取协议配置失败: %w", err) + } + if protocol == nil { + return nil, errors.New("协议配置不存在") + } + + providerReq, err := Compile(ir, protocol, model) + if err != nil { + return nil, fmt.Errorf("编译请求失败: %w", err) + } + + fmt.Println("providerReq打印:", util.MustMarshal(providerReq)) + return map[string]any{ + "modelName": model.ModelName, + "bizName": "prompts-core", + "callbackUrl": "/prompt/callback", + "requestPayload": providerReq, + }, nil +} + +// promptBuildWithRounds 构建系统提示词(包含轮次信息) +func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string { providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ - ProviderName: "qwen", + ProviderName: model.OperatorName, Status: 1, }) if err != nil || providerProtocol == nil { @@ -70,43 +118,104 @@ func promptBuild(ctx context.Context, req *prompt.ComposeMessagesReq, model *ent } outputJSON := util.JSONPretty(model.RequestMapping) - var userFormContent strings.Builder - for k, v := range req.UserForm { - userFormContent.WriteString(fmt.Sprintf("%s=%v;", k, v)) - } - userFormFullText := strings.TrimSuffix(userFormContent.String(), ";") + maxWindowSize := util.GetMaxWindowSize(model.TokenConfig) + availableWindow := util.GetAvailableWindow(model.TokenConfig) + userFormContent := buildUserFormContent(req.UserForm) formInfo := fmt.Sprintf(` 【系统表单(系统提示词/参数)】 %s 【用户表单全文(必须完整阅读,全部作为用户提示词来源)】 %s -`, util.FormToJSON(req.Form), userFormFullText) +`, util.FormToJSON(req.Form), userFormContent) - return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, formInfo) + inputInfo := fmt.Sprintf(` +目标模型: %s +%s +技能名称: %s +用户文件: %v +`, req.ModelName, formInfo, req.SkillName, req.UserFiles) + + return fmt.Sprintf(providerProtocol.SystemPromptTemplate, + req.ModelName, + maxWindowSize, + availableWindow, + totalRounds, + totalRounds, + totalRounds, + outputJSON, + inputInfo, + totalRounds, + ) } -// 构建用户提示词 -func buildUserPrompt(ctx context.Context, req *prompt.ComposeMessagesReq, prompt string) string { - payload := map[string]any{ - "model": req.ModelName, // 请求模型名称 - "promptInfo": prompt, // 数据库提示信息 - "form": req.Form, // 系统表单 - "userForm": req.UserForm, // 用户表单 - "userFiles": req.UserFiles, //文件url - "userFilesText": FetchFileTexts(ctx, req.UserFiles), //解读文件(只支持可读类型 如:xml,json,yaml) - "skills": SkillMdContent(ctx, req.SkillName), //skill 相关(根据传入的 skillName 获取 zip 内所有 md 文件拼接内容) +// buildUserFormContent 构建用户表单内容字符串 +func buildUserFormContent(userForm []map[string]any) string { + var builder strings.Builder + for _, item := range userForm { + builder.WriteString(fmt.Sprintf("%v\n", item)) } + return builder.String() +} + +// checkOverallContent 检查整体内容是否超出窗口 +func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool { + fullContent := ir.String() + return util.CountToken(fullContent, model.TokenConfig) +} + +// buildUserPrompt 构建用户提示词 +func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string { + userFormForPayload := prepareUserFormPayload(req.UserForm) + + payload := map[string]any{ + "model": req.ModelName, + "promptInfo": prompt, + "form": req.Form, + "userForm": userFormForPayload, + "userFiles": req.UserFiles, + "userFilesText": FetchFileTexts(ctx, req.UserFiles), + "skills": SkillMdContent(ctx, req.SkillName), + } + return util.MustMarshal(payload) } +// prepareUserFormPayload 准备用户表单载荷 +func prepareUserFormPayload(userForm []map[string]any) any { + if len(userForm) == 0 { + return nil + } + + if _, ok := userForm[0]["batch_index"]; ok { + return userForm + } + + return mergeUserFormTexts(userForm) +} + +// mergeUserFormTexts 合并 UserForm 中的所有文本内容 +func mergeUserFormTexts(userForm []map[string]any) string { + var builder strings.Builder + for i, item := range userForm { + text := getItemText(item) + if i > 0 { + builder.WriteString("\n\n") + } + builder.WriteString(text) + } + return builder.String() +} + // NodeBuild 节点构建 -func NodeBuild(ctx context.Context, req *prompt.ComposeMessagesReq) string { - promptTpl := util.GetBuildPrompt(ctx, req.BuildType) +func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string { + promptTpl := util.GetBuildPrompt(ctx) if promptTpl == "" { return "" } + formStr := util.FormToJSON(req.Form) - userFormStr := util.FormToJSON(req.UserForm) + userFormStr := util.UserFormToJSON(req.UserForm) + return fmt.Sprintf(promptTpl, formStr, userFormStr) } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 8f8a354..9711b33 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -5,171 +5,229 @@ import ( "encoding/json" "errors" "fmt" - "prompts-core/dao" - "prompts-core/model/entity" "strings" "time" - "prompts-core/common/util" - "prompts-core/consts/public" - promptDto "prompts-core/model/dto/prompt" - "prompts-core/service/gateway" - "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/frame/g" + + "prompts-core/common/util" + "prompts-core/consts/public" + "prompts-core/dao" + "prompts-core/model/dto" + "prompts-core/model/entity" + "prompts-core/service/gateway" ) // ComposeMessages 核心拼接提示词主流程 -func ComposeMessages(ctx context.Context, req *promptDto.ComposeMessagesReq) (*promptDto.ComposeMessagesRes, error) { - var ( - epicycleId int64 - taskID string - history []map[string]any - message map[string]any - err error - taskRecord *entity.ComposeTask - ) - // 获取模型信息 +func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { chatModel, aiModel, err := GetModelMessage(ctx, req) if err != nil { return nil, err } - // 根据构建类型进行判断处理 - switch req.BuildType { - //提示词构建 - case 1: - maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() - //1. 获取历史会话 - history, err = GetHistoryMessages(ctx, req.SessionId) - if err != nil { - g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) - history = nil // 出错就用空的,不影响主流程 - } - // 重试循环 - for attempt := 0; attempt <= 0; attempt++ { - if attempt > 0 { - g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) - } - // 2. 调用推理模型 - taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, history) - if err != nil { - g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) - continue - } - - // 3. 保存记录 - _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ - TaskId: taskID, - ModelName: req.ModelName, - SkillName: req.SkillName, - RequestPayload: util.MustMarshal(req), - Status: public.ComposeStatusPending, - }) - if err != nil { - g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) - continue - } - - // 4. 等待结果 - taskRecord, err = waitForResult(ctx, taskID) - if err != nil { - g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) - continue - } - // 校验结果 - message = parsePromptBuild(taskRecord, chatModel) - if message != nil && util.IsMessageValid(message) { - break - } - g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) - message = nil - } - if message == nil { - return nil, errors.New("推理模型调用失败,请稍后再试") - } - //5.创建会话记录 - epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - RequestContent: message, - }) - //节点构建 - case 2: - //1. 调用推理模型 - taskID, err = callInferenceModel(ctx, req, chatModel, aiModel, nil) - if err != nil { - return nil, err - } - //2. 保存相关记录 - _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ - TaskId: taskID, - ModelName: req.ModelName, - SkillName: req.SkillName, - RequestPayload: util.MustMarshal(req), - Status: public.ComposeStatusPending, - }) - //5. 等待结果 - taskRecord, err := waitForResult(ctx, taskID) - if err != nil { - return nil, err - } - message = parseNodeBuild(taskRecord) - default: - epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - Remark: req.Cause, - }) - return &promptDto.ComposeMessagesRes{ - EpicycleId: epicycleId, - }, nil + if err = validateUserForm(ctx, req, aiModel); err != nil { + return nil, err } - return &promptDto.ComposeMessagesRes{ + switch req.BuildType { + case public.BuildTypePrompt: + return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建 + case public.BuildTypeNode: + return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建 + default: + return handleDefaultCase(ctx, req) + } +} + +// validateUserForm 校验用户表单 +func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error { + if len(req.UserForm) == 0 { + return nil + } + isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig) + if err != nil { + return fmt.Errorf("校验用户表单失败: %w", err) + } + + if !isValid { + availableWindow := util.GetAvailableWindow(model.TokenConfig) + return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试", + exceedTokens, availableWindow) + } + + return nil +} + +// handlePromptBuild 处理提示词构建(BuildType=1) +func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { + maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() + history, err := GetHistoryMessages(ctx, req.SessionId) + if err != nil { + g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) + history = nil + } + + var message *dto.MultiRoundResult + var taskRecord *entity.ComposeTask + for attempt := 0; attempt <= 0; attempt++ { + if attempt > 0 { + g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) + } + + taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history) + if err != nil { + g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) + continue + } + + if err = saveComposeTask(ctx, taskID, req); err != nil { + g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) + continue + } + + taskRecord, err = waitForResult(ctx, taskID) + if err != nil { + g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) + continue + } + + message = parsePromptBuild(taskRecord, chatModel) + if message != nil { + break + } + + g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) + } + + if message == nil { + return nil, errors.New("推理模型调用失败,请稍后再试") + } + epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + SessionId: req.SessionId, + RequestContent: message, + }) + if err != nil { + g.Log().Errorf(ctx, "创建会话记录失败: %v", err) + } + return &dto.ComposeMessagesRes{ Messages: message, EpicycleId: epicycleId, }, nil } +// handleNodeBuild 处理节点构建(BuildType=2) +func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { + taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) + if err != nil { + return nil, fmt.Errorf("调用推理模型失败: %w", err) + } + + if err := saveComposeTask(ctx, taskID, req); err != nil { + return nil, fmt.Errorf("保存任务记录失败: %w", err) + } + + taskRecord, err := waitForResult(ctx, taskID) + if err != nil { + return nil, fmt.Errorf("等待结果失败: %w", err) + } + + message := parseNodeBuild(taskRecord) + + return &dto.ComposeMessagesRes{ + Messages: message, + EpicycleId: 0, + }, nil +} + +// handleDefaultCase 处理默认情况 +func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { + epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + SessionId: req.SessionId, + Remark: req.Cause, + }) + if err != nil { + return nil, fmt.Errorf("创建会话记录失败: %w", err) + } + + return &dto.ComposeMessagesRes{ + EpicycleId: epicycleId, + }, nil +} + +// saveComposeTask 保存组合任务 +func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error { + _, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ + TaskId: taskID, + ModelName: req.ModelName, + SkillName: req.SkillName, + RequestPayload: util.MustMarshal(req), + Status: public.ComposeStatusPending, + }) + return err +} + // GetModelMessage 获取模型信息 -func GetModelMessage(ctx context.Context, req *promptDto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { +func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { userInfo, err := utils.GetUserInfo(ctx) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) } - // 1. 获取当前用户的会话模型 - chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, - IsChatModel: 1, - }) + + chatModel, err := getChatModel(ctx, userInfo.UserName) if err != nil { return nil, nil, err } - if chatModel == nil { - return nil, nil, errors.New("当前没有对话模型,请添加") - } - // 2. 获取要构建的模型信息 - aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, - ModelName: req.ModelName, - }) + + aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName) if err != nil { return nil, nil, err } - if aiModel == nil { - return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName) - } + return chatModel, aiModel, nil } +// getChatModel 获取聊天模型 +func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) { + chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userName}, + IsChatModel: new(1), + }) + if err != nil { + return nil, fmt.Errorf("查询聊天模型失败: %w", err) + } + + if chatModel == nil { + return nil, errors.New("当前没有对话模型,请添加") + } + + return chatModel, nil +} + +// getAIModel 获取AI模型 +func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) { + aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userName}, + ModelName: modelName, + }) + if err != nil { + return nil, fmt.Errorf("查询AI模型失败: %w", err) + } + + if aiModel == nil { + return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName) + } + + return aiModel, nil +} + // callInferenceModel 调用推理模型 -func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { - // 构建推理模型请求 +func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history) if err != nil { return "", fmt.Errorf("构建推理请求失败: %w", err) } - // 创建网关任务 taskID, err := gateway.CreateGatewayTask(ctx, taskReq) if err != nil { return "", fmt.Errorf("创建网关任务失败: %w", err) @@ -186,96 +244,131 @@ func callInferenceModel(ctx context.Context, req *promptDto.ComposeMessagesReq, func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() for { - // ===================== 修复点 1:检查上下文是否取消 ===================== - select { - case <-ctx.Done(): - // 请求已被取消,直接返回,不继续查库 - return nil, ctx.Err() - default: - } - - // 1. 查数据库 record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: taskID, }) if err != nil { - // ===================== 修复点 2:如果是上下文取消,直接返回 ===================== if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - return nil, err + return nil, fmt.Errorf("查询任务失败: %w", err) } + if record != nil { - switch record.Status { - case public.ComposeStatusSuccess: - return record, nil - case public.ComposeStatusFailed: - if strings.TrimSpace(record.ErrorMessage) == "" { - return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) - } - return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) + if completed, result := checkTaskCompletion(record); completed { + return result, nil } } - // 2. 查网关状态 - state, err := gateway.QueryGatewayTaskState(ctx, taskID) - if err != nil { - // 网关不可达不终止,继续轮询 - g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) - } else { - switch state { - case 2: // 网关成功 - // 网关已成功,主动更新数据库 - if record != nil { - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: taskID, - Status: public.ComposeStatusSuccess, - }) - if err != nil { - g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) - } - } - case 3: // 网关失败 - if record != nil { - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: taskID, - Status: public.ComposeStatusFailed, - ErrorMessage: "model-gateway 任务执行失败", - }) - if err != nil { - g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) - } - } - return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) - } + if err = syncGatewayTaskState(ctx, taskID, record); err != nil { + g.Log().Warningf(ctx, "[waitForResult] 同步网关状态失败 taskId=%s err=%v", taskID, err) } - // 3. 超时检查 if time.Now().After(deadline) { return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) } - // ===================== 修复点3:sleep 也要监听 ctx 取消 ===================== select { case <-ctx.Done(): return nil, ctx.Err() - case <-time.After(pollInterval): + case <-ticker.C: } } } +// checkTaskCompletion 检查任务是否完成 +func checkTaskCompletion(record *entity.ComposeTask) (bool, *entity.ComposeTask) { + if record == nil { + return false, nil + } + switch record.Status { + case public.ComposeStatusSuccess: + return true, record + case public.ComposeStatusFailed: + errMsg := strings.TrimSpace(record.ErrorMessage) + if errMsg == "" { + return true, nil + } + return true, nil + default: + return false, nil + } +} + +// syncGatewayTaskState 同步网关任务状态 +func syncGatewayTaskState(ctx context.Context, taskID string, record *entity.ComposeTask) error { + state, err := gateway.QueryGatewayTaskState(ctx, taskID) + if err != nil { + return fmt.Errorf("查询网关状态失败: %w", err) + } + switch state { + case 2: + return updateTaskStatus(ctx, taskID, public.ComposeStatusSuccess, "") + case 3: + updateTaskStatus(ctx, taskID, public.ComposeStatusFailed, "model-gateway 任务执行失败") + return fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) + } + return nil +} + +// updateTaskStatus 更新任务状态 +func updateTaskStatus(ctx context.Context, taskID string, status string, errorMsg string) error { + task := &entity.ComposeTask{ + TaskId: taskID, + Status: status, + } + if errorMsg != "" { + task.ErrorMessage = errorMsg + } + + _, err := dao.ComposeTask.Update(ctx, task) + return err +} + // parsePromptBuild 解析提示词构建结果(BuildType == 1) -func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any { +func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult { if taskRecord == nil { return nil } + mapped := parseTaskMessages(taskRecord.Messages) + if mapped == nil { + return createDefaultResult(nil) + } - // 1. 解析 Messages + contentField := getContentField(model) + contentStr, ok := mapped[contentField].(string) + if !ok || contentStr == "" { + return createDefaultResult(mapped) + } + + if roundsArray := tryParseAsArray(contentStr); roundsArray != nil { + return &dto.MultiRoundResult{ + TotalRounds: len(roundsArray), + Rounds: roundsArray, + } + } + + if singleRound := tryParseAsObject(contentStr); singleRound != nil { + return &dto.MultiRoundResult{ + TotalRounds: 1, + Rounds: []any{singleRound}, + } + } + + return createDefaultResult(map[string]any{"content": contentStr}) +} + +// parseTaskMessages 解析任务消息 +func parseTaskMessages(messages any) map[string]any { var mapped map[string]any - switch v := taskRecord.Messages.(type) { + + switch v := messages.(type) { case *gvar.Var: if v != nil { json.Unmarshal([]byte(v.String()), &mapped) @@ -289,115 +382,137 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) json.Unmarshal(b, &mapped) } - // 2. 解析模型 ResponseMapping 获取 content 字段名 - contentField := "content" // 默认值 - if model != nil { - var respMapping map[string]string - switch v := model.ResponseMapping.(type) { - case *gvar.Var: - if v != nil { - json.Unmarshal([]byte(v.String()), &respMapping) - } - case string: - json.Unmarshal([]byte(v), &respMapping) - case map[string]interface{}: - respMapping = make(map[string]string) - for k, val := range v { - if s, ok := val.(string); ok { - respMapping[k] = s - } - } - } - // 从映射中找到 content 对应的字段名 - for k, v := range respMapping { - if strings.Contains(v, "content") { - contentField = k - break - } - } - } - - // 3. 提取 content 的值 - contentStr, ok := mapped[contentField].(string) - if !ok || contentStr == "" { - return mapped - } - - // 4. 解析 content 内的 JSON - var innerData map[string]any - json.Unmarshal([]byte(contentStr), &innerData) - - return innerData + return mapped } -// parseNodeBuild 解析节点构建结果(BuildType == 2) -func parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any { - if taskRecord == nil { +// tryParseAsArray 尝试将字符串解析为数组 +func tryParseAsArray(contentStr string) []any { + var roundsArray []any + if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil { return nil } - var result map[string]any - switch v := taskRecord.Messages.(type) { + return roundsArray +} + +// tryParseAsObject 尝试将字符串解析为对象 +func tryParseAsObject(contentStr string) any { + var singleRound any + if err := json.Unmarshal([]byte(contentStr), &singleRound); err != nil { + return nil + } + return singleRound +} + +// createDefaultResult 创建默认结果 +func createDefaultResult(data any) *dto.MultiRoundResult { + if data == nil { + data = make(map[string]any) + } + return &dto.MultiRoundResult{ + TotalRounds: 1, + Rounds: []any{data}, + } +} + +// getContentField 从模型 ResponseMapping 中获取 content 字段名 +func getContentField(model *entity.AsynchModel) string { + if model == nil { + return "content" + } + + respMapping := parseResponseMapping(model.ResponseMapping) + for k, v := range respMapping { + if strings.Contains(v, "content") { + return k + } + } + + return "content" +} + +// parseResponseMapping 解析响应映射 +func parseResponseMapping(mapping any) map[string]string { + result := make(map[string]string) + + switch v := mapping.(type) { case *gvar.Var: if v != nil { json.Unmarshal([]byte(v.String()), &result) } case string: json.Unmarshal([]byte(v), &result) - case map[string]any: - result = v - default: - b, _ := json.Marshal(v) - json.Unmarshal(b, &result) + case map[string]interface{}: + for k, val := range v { + if s, ok := val.(string); ok { + result[k] = s + } + } } + return result } +// parseNodeBuild 解析节点构建结果(BuildType == 2) +func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult { + if taskRecord == nil { + return nil + } + + result := parseTaskMessages(taskRecord.Messages) + if result == nil { + result = make(map[string]any) + } + + return &dto.MultiRoundResult{ + TotalRounds: 1, + Rounds: []any{result}, + } +} + // Callback 回调处理 -func Callback(ctx context.Context, req *promptDto.CallbackReq) error { +func Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) - // ============ 先查任务是否存在 ============ task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: req.TaskId, }) if err != nil { - return err + return fmt.Errorf("查询任务失败: %w", err) } if task == nil { return fmt.Errorf("任务不存在: %s", req.TaskId) } - // ============ 根据状态区分处理 ============ + if req.State == 3 { - // 失败:直接更新状态 - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: req.TaskId, - Status: public.ComposeStatusFailed, - ErrorMessage: req.ErrorMsg, - }) - return err - } - // ====================================== - // 成功:解析模型输出 - result, err := util.ParseOutput(req.Text) - if err != nil { - _, updateErr := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: req.TaskId, - Status: public.ComposeStatusFailed, - ErrorMessage: req.ErrorMsg, - }) - if updateErr != nil { - g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr) - } - return err + return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) + } + + return handleCallbackSuccess(ctx, req) +} + +// handleCallbackFailure 处理回调失败 +func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error { + _, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: taskID, + Status: public.ComposeStatusFailed, + ErrorMessage: errorMsg, + }) + return err +} + +// handleCallbackSuccess 处理回调成功 +func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error { + result, err := util.ParseOutput(req.Text) + if err != nil { + handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) + return fmt.Errorf("解析模型输出失败: %w", err) } - // ============ result 可能为 nil ============ var messages any if result != nil { messages = result } - // ======================================= _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, @@ -407,34 +522,43 @@ func Callback(ctx context.Context, req *promptDto.CallbackReq) error { if err != nil { g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) } + return err } // GetComposeTask 查询任务结果 -func GetComposeTask(ctx context.Context, taskID string) (*promptDto.GetComposeTaskRes, error) { +func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: taskID, }) if err != nil { - return nil, err + return nil, fmt.Errorf("查询任务失败: %w", err) } if record == nil { return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) } - // 如果 Messages 是字符串,反序列化为 JSON 数组 - messages := record.Messages - if str, ok := messages.(string); ok && str != "" { - var parsed any - if err := json.Unmarshal([]byte(str), &parsed); err == nil { - messages = parsed - } - } + messages := parseMessagesForResponse(record.Messages) - return &promptDto.GetComposeTaskRes{ + return &dto.GetComposeTaskRes{ TaskId: record.TaskId, Status: record.Status, ErrorMessage: record.ErrorMessage, Messages: messages, }, nil } + +// parseMessagesForResponse 解析用于响应的消息 +func parseMessagesForResponse(messages any) any { + str, ok := messages.(string) + if !ok || str == "" { + return messages + } + + var parsed any + if err := json.Unmarshal([]byte(str), &parsed); err == nil { + return parsed + } + + return messages +} diff --git a/service/prompt/prompt_files_handle_service.go b/service/prompt/prompt_files_handle_service.go index d8aef2f..8f749de 100644 --- a/service/prompt/prompt_files_handle_service.go +++ b/service/prompt/prompt_files_handle_service.go @@ -10,10 +10,15 @@ import ( "strings" "time" + "github.com/gogf/gf/v2/frame/g" + "prompts-core/common/util" "prompts-core/service/gateway" +) - "github.com/gogf/gf/v2/frame/g" +const ( + bytesPerKB = 1024 + bytesPerMB = 1024 * 1024 ) // FetchFileTexts 从 URL 列表获取文件内容,支持 zip 内文件 @@ -24,51 +29,49 @@ func FetchFileTexts(ctx context.Context, urls []string) map[string]string { return result } - client := &http.Client{ - Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second, - } + client := createHTTPClient(ctx, "userFiles.httpTimeoutSec", 8) for _, rawURL := range urls { url := util.SanitizeURL(rawURL) - if url == "" { - continue - } - - if util.IsBannedExtension(url) { + if url == "" || util.IsBannedExtension(url) { continue } if util.IsZipExtension(url) { - zipTexts := fetchZipFileTexts(ctx, client, url) - for k, v := range zipTexts { - result[k] = v - } + mergeMap(result, fetchZipFileTexts(ctx, client, url)) continue } - text, err := fetchFileContent(ctx, client, url) - if err != nil { - continue + if text := fetchAndCleanFileContent(ctx, client, url); text != "" { + result[url] = text } - - if text == "" { - continue - } - - text = util.CleanSymbols(text) - result[url] = text } return result } +// mergeMap 合并 map +func mergeMap(dst, src map[string]string) { + for k, v := range src { + dst[k] = v + } +} + +// fetchAndCleanFileContent 获取并清理文件内容 +func fetchAndCleanFileContent(ctx context.Context, client *http.Client, url string) string { + text, err := fetchFileContent(ctx, client, url) + if err != nil || text == "" { + return "" + } + return util.CleanSymbols(text) +} + // fetchZipFileTexts 下载并解压 zip 文件,提取可读文本内容 func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string { result := make(map[string]string) - zipBytes, err := downloadFile(client, url, - int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024, - ) + maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB + zipBytes, err := downloadFile(client, url, maxSize) if err != nil { return result } @@ -78,61 +81,61 @@ func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map return result } - entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024 + entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * bytesPerKB for _, file := range reader.File { - if file.FileInfo().IsDir() { + if shouldSkipZipEntry(file.Name) { continue } - fileName := file.Name - - if util.IsBannedExtension(fileName) { - continue + if text := extractZipEntryContent(file, entryMaxSize); text != "" { + result[url+"::"+file.Name] = text } - - if util.IsZipExtension(fileName) { - continue - } - - rc, err := file.Open() - if err != nil { - continue - } - - content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize)) - rc.Close() - if err != nil { - continue - } - - contentType := http.DetectContentType(content) - if !util.IsReadableContentType(contentType) { - continue - } - - text := util.CleanSymbols(string(content)) - if text == "" { - continue - } - - key := url + "::" + fileName - result[key] = text } return result } +// shouldSkipZipEntry 判断是否应该跳过 zip 条目 +func shouldSkipZipEntry(fileName string) bool { + return util.IsBannedExtension(fileName) || util.IsZipExtension(fileName) +} + +// extractZipEntryContent 提取 zip 条目内容 +func extractZipEntryContent(file *zip.File, maxSize int64) string { + rc, err := file.Open() + if err != nil { + return "" + } + defer rc.Close() + + content, err := io.ReadAll(io.LimitReader(rc, maxSize)) + if err != nil { + return "" + } + + if !util.IsReadableContentType(http.DetectContentType(content)) { + return "" + } + + text := util.CleanSymbols(string(content)) + if text == "" { + return "" + } + + return text +} + // downloadFile 下载文件,限制最大大小 func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("创建请求失败: %w", err) } resp, err := client.Do(req) if err != nil { - return nil, err + return nil, fmt.Errorf("执行请求失败: %w", err) } defer resp.Body.Close() @@ -140,19 +143,24 @@ func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error return nil, fmt.Errorf("HTTP %d", resp.StatusCode) } - return io.ReadAll(io.LimitReader(resp.Body, maxSize)) + body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + return body, nil } // fetchFileContent 获取单个文本文件内容 func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return "", err + return "", fmt.Errorf("创建请求失败: %w", err) } resp, err := client.Do(req) if err != nil { - return "", err + return "", fmt.Errorf("执行请求失败: %w", err) } defer resp.Body.Close() @@ -162,16 +170,13 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str contentType := resp.Header.Get("Content-Type") if !util.IsReadableContentType(contentType) { - return "", fmt.Errorf("unreadable content-type: %s", contentType) + return "", fmt.Errorf("不可读的内容类型: %s", contentType) } - body, err := io.ReadAll( - io.LimitReader(resp.Body, - int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024, - ), - ) + maxSize := int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int()) * bytesPerKB + body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize)) if err != nil { - return "", err + return "", fmt.Errorf("读取响应失败: %w", err) } return strings.TrimSpace(string(body)), nil @@ -186,27 +191,26 @@ func SkillMdContent(ctx context.Context, skillName string) string { fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl - client := &http.Client{ - Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second, - } + client := createHTTPClient(ctx, "skillFiles.httpTimeoutSec", 30) + maxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int()) * bytesPerMB - zipBytes, err := downloadFile(client, fullUrl, - int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024, - ) + zipBytes, err := downloadFile(client, fullUrl, maxSize) if err != nil { return "" } mdContents, err := extractMdFiles(ctx, zipBytes) - if err != nil { + if err != nil || len(mdContents) == 0 { return "" } - if len(mdContents) == 0 { - return "" - } + return buildSkillMarkdown(skillResp, mdContents) +} +// buildSkillMarkdown 构建技能 Markdown 内容 +func buildSkillMarkdown(skillResp *gateway.SkillUserVO, mdContents map[string]string) string { var builder strings.Builder + builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name)) if skillResp.Description != "" { builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description)) @@ -227,35 +231,53 @@ func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, er reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes))) if err != nil { - return nil, err + return nil, fmt.Errorf("创建 zip 阅读器失败: %w", err) } - entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024 + entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * bytesPerKB for _, file := range reader.File { - if file.FileInfo().IsDir() { + if file.FileInfo().IsDir() || !isMarkdownFile(file.Name) { continue } - if !strings.HasSuffix(strings.ToLower(file.Name), ".md") { - continue - } - - rc, err := file.Open() - if err != nil { - continue - } - - content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize)) - rc.Close() - if err != nil { - continue - } - - if len(content) > 0 { - result[file.Name] = strings.TrimSpace(string(content)) + if content := readMarkdownFileContent(file, entryMaxSize); content != "" { + result[file.Name] = content } } return result, nil } + +// isMarkdownFile 判断是否为 Markdown 文件 +func isMarkdownFile(fileName string) bool { + return strings.HasSuffix(strings.ToLower(fileName), ".md") +} + +// readMarkdownFileContent 读取 Markdown 文件内容 +func readMarkdownFileContent(file *zip.File, maxSize int64) string { + rc, err := file.Open() + if err != nil { + return "" + } + defer rc.Close() + + content, err := io.ReadAll(io.LimitReader(rc, maxSize)) + if err != nil { + return "" + } + + if len(content) == 0 { + return "" + } + + return strings.TrimSpace(string(content)) +} + +// createHTTPClient 创建 HTTP 客户端 +func createHTTPClient(ctx context.Context, configKey string, defaultSeconds int) *http.Client { + timeout := time.Duration(g.Cfg().MustGet(ctx, configKey, defaultSeconds).Int()) * time.Second + return &http.Client{ + Timeout: timeout, + } +} diff --git a/service/prompt/prompt_files_handle_service.markdown b/service/prompt/prompt_files_handle_service.markdown new file mode 100644 index 0000000..bd6cbaf --- /dev/null +++ b/service/prompt/prompt_files_handle_service.markdown @@ -0,0 +1,75 @@ +# Prompts-Core(提示词核心服务) + +> 智能提示词构建与管理系统,支持多模态 AI 模型的提示词组装、会话管理和协议适配。 + +--- + +## 项目简介 + +**Prompts-Core** 是一个基于 Go 语言开发的提示词核心服务,作为 AI 应用层与模型网关之间的桥梁,负责将业务需求转换为标准化的模型请求。 + +### 核心价值 +- **统一提示词管理**:集中化管理不同模型类型的提示词模板 +- **智能会话维护**:基于 Redis + PostgreSQL 的双层会话存储 +- **多协议适配**:支持 OpenAI、DeepSeek、Qwen、Gemini 等多种模型协议 +- **文件处理能力**:自动提取文本文件和 ZIP 压缩包内容 +- **技能系统集成**:支持从外部加载 Markdown 格式的技能描述 + +--- + +## 核心功能 + +### 1. 提示词构建引擎 + +#### 多模态支持 +| 类型 | 说明 | 适用场景 | +|------|------|----------| +| Type 1 | 文字处理助手 | 文章撰写、文案优化、翻译等 | +| Type 2 | 图片处理助手 | 图像生成、风格迁移等 | +| Type 3 | 音频处理助手 | 语音合成、识别、降噪等 | +| Type 4 | 向量化处理助手 | 语义检索、知识索引等 | +| Type 5 | 全模态助手 | 跨模态转换、多模态融合等 | + +#### 构建模式 +- **BuildType 1(提示词构建)**:完整流程,包含系统提示词、历史会话、用户输入的智能组装 +- **BuildType 2(节点构建)**:工作流路由决策,根据上下文选择节点 ID + +#### 分批处理 +当用户表单内容超出模型窗口限制时,自动按 Token 大小分批处理。 + +### 2. 会话管理系统 + +- **双层存储**:Redis 缓存(最近 N 轮)+ PostgreSQL 持久化 +- **自动管理**:最大轮数控制(默认 10 轮)、自动过期(默认 30 分钟) + +### 3. 协议适配器 + +通过配置动态支持多种模型协议: +- 角色映射:system/user/assistant → 目标协议角色 +- 内容字段映射:content → parts.text 等 +- 消息顺序控制:灵活配置拼接顺序 +- 请求模板渲染:支持占位符替换 + +### 4. 任务调度 + +- **异步流程**:创建网关任务 → 轮询等待 → 接收回调 → 返回结果 +- **重试机制**:可配置最大重试次数(默认 3 次) +- **超时保护**:默认 300 秒超时 + +--- + +## 技术架构 + +### 技术栈 + +| 组件 | 版本 | 用途 | +|------|------|------| +| Go | 1.26.0 | 编程语言 | +| GoFrame | v2.10.0 | Web 框架 | +| PostgreSQL | - | 关系型数据库 | +| Redis | - | 缓存与会话存储 | +| Consul | - | 服务注册与发现 | +| Jaeger | - | 分布式链路追踪 | + +### 架构图 + diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go index cdca0ea..33d22b3 100644 --- a/service/prompt/prompt_ir_service.go +++ b/service/prompt/prompt_ir_service.go @@ -20,11 +20,27 @@ type PromptIR struct { // Segment 消息片段 type Segment struct { - Type string `json:"type"` // text/image + Type string `json:"type"` Content string `json:"content"` Role string `json:"role,omitempty"` } +// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析) +type ProviderProtocol struct { + TargetField string `json:"target_field"` + MergeOrder []string `json:"merge_order"` + RoleMapping map[string]string `json:"role_mapping"` + ContentMapping ContentMapping `json:"content_mapping"` + RequestTemplate map[string]any `json:"request_template"` + SystemPromptTemplate string `json:"system_prompt_template"` +} + +// ContentMapping 内容字段映射 +type ContentMapping struct { + Type string `json:"type"` + Field string `json:"field"` +} + // NewPromptIR 创建空 PromptIR func NewPromptIR() *PromptIR { return &PromptIR{ @@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR { } } +// String 返回 PromptIR 的完整内容字符串(用于 token 计算) +func (ir *PromptIR) String() string { + var builder strings.Builder + + for _, seg := range ir.System { + builder.WriteString("System: ") + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + for _, seg := range ir.History { + builder.WriteString(seg.Role) + builder.WriteString(": ") + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + for _, seg := range ir.User { + builder.WriteString("User: ") + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + return builder.String() +} + +// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算) +func (ir *PromptIR) GetTotalContent() string { + var builder strings.Builder + + for _, seg := range ir.System { + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + for _, seg := range ir.History { + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + for _, seg := range ir.User { + builder.WriteString(seg.Content) + builder.WriteString("\n") + } + + return builder.String() +} + // AddSystem 添加系统提示 func (ir *PromptIR) AddSystem(content string) *PromptIR { if content != "" { @@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR { func (ir *PromptIR) ToMessages() []map[string]any { var messages []map[string]any - // 1. 系统消息 for _, seg := range ir.System { messages = append(messages, map[string]any{ "role": "system", @@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any { }) } - // 2. 历史消息 for _, seg := range ir.History { messages = append(messages, map[string]any{ "role": seg.Role, @@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any { }) } - // 3. 用户消息 for _, seg := range ir.User { messages = append(messages, map[string]any{ "role": "user", "content": seg.Content, }) } + return messages } @@ -97,74 +159,35 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP if err != nil || entity == nil { return nil, err } - entity.MergeOrder = util.ParseJSONField(entity.MergeOrder) - entity.RoleMapping = util.ParseJSONField(entity.RoleMapping) - entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) - entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate) - entity.ContentMapping = util.ParseJSONField(entity.ContentMapping) + fmt.Println("entity打印", entity) return parseProtocol(entity), nil } // parseProtocol 将 DB entity 转为编译用协议配置 func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { p := &ProviderProtocol{ - TargetField: e.TargetField, + TargetField: e.TargetField, + SystemPromptTemplate: e.SystemPromptTemplate, } - // MergeOrder: any → []string - if e.MergeOrder != nil { - b, _ := json.Marshal(e.MergeOrder) - json.Unmarshal(b, &p.MergeOrder) - } - - // RoleMapping: any → map[string]string - if e.RoleMapping != nil { - b, _ := json.Marshal(e.RoleMapping) - json.Unmarshal(b, &p.RoleMapping) - } - - // ContentMapping: any → ContentMapping - if e.ContentMapping != nil { - b, _ := json.Marshal(e.ContentMapping) - json.Unmarshal(b, &p.ContentMapping) - } - - // RequestTemplate: any → map[string]any - if e.RequestTemplate != nil { - b, _ := json.Marshal(e.RequestTemplate) - json.Unmarshal(b, &p.RequestTemplate) - } - fmt.Printf("parseProtocol: %+v\n", p) + // 使用通用解析方法处理各个字段 + util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder) + util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping) + util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping) + util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate) return p } -// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析) -type ProviderProtocol struct { - TargetField string `json:"target_field"` - MergeOrder []string `json:"merge_order"` - RoleMapping map[string]string `json:"role_mapping"` - ContentMapping ContentMapping `json:"content_mapping"` - RequestTemplate map[string]any `json:"request_template"` -} - -// ContentMapping 内容字段映射 -type ContentMapping struct { - Type string `json:"type"` // direct/parts - Field string `json:"field"` // content/text -} - // Compile 将 PromptIR 按协议配置编译为 Provider Request func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) { if ir == nil || p == nil { return nil, fmt.Errorf("ir and protocol are required") } - // 1. 按 merge_order 拼接消息 + messages := mergeByOrder(ir, p.MergeOrder) - // 2. 角色映射 messages = mapRoles(messages, p.RoleMapping) - // 3. 内容字段映射 messages = mapContent(messages, p.ContentMapping) - // 4. 按 target_field + request_template 构建请求体 + return buildRequest(messages, p, chatModel), nil } @@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any { } } } + return messages } @@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string if len(mapping) == 0 { return messages } + for i, msg := range messages { role, ok := msg["role"].(string) if !ok { continue } + if mapped, exists := mapping[role]; exists { messages[i]["role"] = mapped } } + return messages } @@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { switch cm.Type { case "parts": - // Gemini 格式: {"parts": [{"text": "..."}]} msg["parts"] = []map[string]any{ {cm.Field: content}, } default: - // direct: {"content": "..."} msg[cm.Field] = content } } + return messages } @@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent if len(p.RequestTemplate) > 0 { return renderTemplate(p.RequestTemplate, messages, chatModel) } + return map[string]any{ p.TargetField: messages, } @@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e b, _ := json.Marshal(tmpl) str := string(b) - // 替换 {{model}} str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) - // 替换 {{messages}} + msgBytes, _ := json.Marshal(messages) str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) var result map[string]any json.Unmarshal([]byte(str), &result) + return result } diff --git a/service/prompt/prompt_session_redis_service.go b/service/prompt/prompt_session_redis_service.go index 16ebb17..4024628 100644 --- a/service/prompt/prompt_session_redis_service.go +++ b/service/prompt/prompt_session_redis_service.go @@ -9,15 +9,16 @@ import ( "github.com/gogf/gf/v2/frame/g" ) -// ==================== Redis 操作 ==================== +const ( + redisKeyPrefix = "chat:session:%s" +) // saveToRedis 保存会话数据到Redis func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { - key := fmt.Sprintf("chat:session:%s", sessionId) + key := formatRedisKey(sessionId) maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64() - expireTime := time.Duration(expireSeconds) * time.Second data := map[string]any{ "sessionId": sessionId, @@ -31,18 +32,29 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st return fmt.Errorf("序列化会话数据失败: %w", err) } - _, err = g.Redis().Do(ctx, "LPUSH", key, string(b)) - if err != nil { + if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { + return err + } + + return nil +} + +// formatRedisKey 格式化Redis键 +func formatRedisKey(sessionId string) string { + return fmt.Sprintf(redisKeyPrefix, sessionId) +} + +// executeRedisCommands 执行Redis命令 +func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error { + if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil { return fmt.Errorf("写入Redis失败: %w", err) } - _, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1) - if err != nil { + if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil { return fmt.Errorf("裁剪Redis列表失败: %w", err) } - _, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds())) - if err != nil { + if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil { return fmt.Errorf("设置过期时间失败: %w", err) } @@ -51,7 +63,7 @@ func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[st // getFromRedis 从Redis获取会话历史 func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { - key := fmt.Sprintf("chat:session:%s", sessionId) + key := formatRedisKey(sessionId) result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) if err != nil { @@ -62,8 +74,17 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro return []map[string]any{}, nil } + sessions := parseRedisSessions(ctx, result.Strings()) + + reverseSlice(sessions) + + return sessions, nil +} + +// parseRedisSessions 解析Redis会话数据 +func parseRedisSessions(ctx context.Context, values []string) []map[string]any { var sessions []map[string]any - values := result.Strings() + for _, str := range values { var data map[string]any if err := json.Unmarshal([]byte(str), &data); err != nil { @@ -73,12 +94,14 @@ func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, erro sessions = append(sessions, data) } - // 反转(Redis 最新在前 → 时间正序) - for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 { - sessions[i], sessions[j] = sessions[j], sessions[i] - } + return sessions +} - return sessions, nil +// reverseSlice 反转切片 +func reverseSlice(s []map[string]any) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } } // GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用) @@ -92,23 +115,31 @@ func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map return []map[string]any{}, nil } + return flattenHistoryMessages(historyData), nil +} + +// flattenHistoryMessages 扁平化历史消息 +func flattenHistoryMessages(historyData []map[string]any) []map[string]any { var messages []map[string]any + for _, round := range historyData { - if reqMsgs, ok := round["requestContent"].([]interface{}); ok { - for _, m := range reqMsgs { - if msg, ok := m.(map[string]interface{}); ok { - messages = append(messages, msg) - } - } - } - if respMsgs, ok := round["responseContent"].([]interface{}); ok { - for _, m := range respMsgs { - if msg, ok := m.(map[string]interface{}); ok { - messages = append(messages, msg) - } - } - } + appendMessagesFromField(round, "requestContent", &messages) + appendMessagesFromField(round, "responseContent", &messages) } - return messages, nil + return messages +} + +// appendMessagesFromField 从指定字段追加消息 +func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) { + msgs, ok := data[field].([]interface{}) + if !ok { + return + } + + for _, m := range msgs { + if msg, ok := m.(map[string]interface{}); ok { + *messages = append(*messages, msg) + } + } } diff --git a/service/prompt/prompt_session_service.go b/service/prompt/prompt_session_service.go index b0d9fce..7434013 100644 --- a/service/prompt/prompt_session_service.go +++ b/service/prompt/prompt_session_service.go @@ -3,112 +3,164 @@ package prompt import ( "context" "fmt" - sessionDao "prompts-core/dao" - "prompts-core/model/entity" - - "prompts-core/common/util" - sessionDto "prompts-core/model/dto/prompt" "gitea.com/red-future/common/beans" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" + + "prompts-core/common/util" + "prompts-core/dao" + "prompts-core/model/dto" + "prompts-core/model/entity" ) -func SessionCallback(ctx context.Context, req *sessionDto.SessionCallbackReq) (res *sessionDto.SessionCallbackRes, err error) { - // 1. 解析AI返回的文本 +// SessionCallback 会话回调 +func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { result, err := util.ParseOutput(req.Text) if err != nil { g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err) - return nil, err + return nil, fmt.Errorf("解析模型输出失败: %w", err) } - // 2. 更新数据库 result["role"] = "assistant" - _, err = sessionDao.ComposeSession.Update(ctx, &entity.ComposeSession{ - SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, - ResponseContent: result, - }) - if err != nil { - g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err) + + if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil { return nil, err } - // 3. 获取当前轮次完整数据 - session, err := sessionDao.ComposeSession.Get(ctx, &entity.ComposeSession{ - SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, - }) + session, err := getSessionById(ctx, req.EpicycleId) if err != nil { - g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err) return nil, err } - // 4. 转换 json 并存入 Redis + if err := saveSessionToRedis(ctx, session); err != nil { + return nil, err + } + requestMessages := util.ConvertToMessages(session.RequestContent) responseMessages := util.ConvertToMessages(session.ResponseContent) - if err = saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { - g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", - session.SessionId, session.Id, err) - return nil, err - } - g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", session.SessionId, session.Id, len(requestMessages), len(responseMessages)) - return &sessionDto.SessionCallbackRes{}, nil + + return &dto.SessionCallbackRes{}, nil +} + +// updateSessionResponse 更新会话响应 +func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error { + _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ + SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, + ResponseContent: response, + }) + if err != nil { + g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err) + return fmt.Errorf("更新数据库失败: %w", err) + } + return nil +} + +// getSessionById 根据ID获取会话 +func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) { + session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{ + SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, + }) + if err != nil { + g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err) + return nil, fmt.Errorf("获取会话数据失败: %w", err) + } + return session, nil +} + +// saveSessionToRedis 保存会话到Redis +func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error { + requestMessages := util.ConvertToMessages(session.RequestContent) + responseMessages := util.ConvertToMessages(session.ResponseContent) + + if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { + g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", + session.SessionId, session.Id, err) + return fmt.Errorf("Redis存储失败: %w", err) + } + + return nil } // GetHistoryMessages 获取历史信息 func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() - // 1. 先从 Redis 拿 redisHistory, err := GetSessionHistoryForInference(ctx, sessionId) if err == nil && len(redisHistory) > 0 { return redisHistory, nil } - // 2. Redis 没有 → fallback DB - sessions, _, err := sessionDao.ComposeSession.List(ctx, &entity.ComposeSession{ + return getHistoryFromDatabase(ctx, sessionId, maxRounds) +} + +// getHistoryFromDatabase 从数据库获取历史记录 +func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) { + sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ SessionId: sessionId, }, 1, maxRounds) if err != nil { return nil, fmt.Errorf("DB获取历史失败: %w", err) } + messages := extractMessagesFromSessions(sessions) + + cacheSessionsToRedis(ctx, sessions) + + return messages, nil +} + +// extractMessagesFromSessions 从会话列表中提取消息 +func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any { var messages []map[string]any for _, session := range sessions { - // request - reqMsgs := util.ConvertToMessages(session.RequestContent) - for _, m := range reqMsgs { - role := gconv.String(m["role"]) - if role == "user" || role == "assistant" { - messages = append(messages, m) - } - } - - // response - respMsgs := util.ConvertToMessages(session.ResponseContent) - for _, m := range respMsgs { - if m["role"] == nil { - m["role"] = "assistant" - } - messages = append(messages, m) - } + appendRequestMessages(session.RequestContent, &messages) + appendResponseMessages(session.ResponseContent, &messages) } - // 3. 回写 Redis + return messages +} + +// appendRequestMessages 追加请求消息 +func appendRequestMessages(requestContent any, messages *[]map[string]any) { + reqMsgs := util.ConvertToMessages(requestContent) + for _, m := range reqMsgs { + role := gconv.String(m["role"]) + if role == "user" || role == "assistant" { + *messages = append(*messages, m) + } + } +} + +// appendResponseMessages 追加响应消息 +func appendResponseMessages(responseContent any, messages *[]map[string]any) { + respMsgs := util.ConvertToMessages(responseContent) + for _, m := range respMsgs { + if m["role"] == nil { + m["role"] = "assistant" + } + *messages = append(*messages, m) + } +} + +// cacheSessionsToRedis 将会话缓存到Redis +func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) { for _, session := range sessions { reqMsgs := util.ConvertToMessages(session.RequestContent) respMsgs := util.ConvertToMessages(session.ResponseContent) + for i := range respMsgs { if respMsgs[i]["role"] == nil { respMsgs[i]["role"] = "assistant" } } + if len(reqMsgs) > 0 || len(respMsgs) > 0 { _ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) } } - return messages, nil } diff --git a/service/prompt/prompt_user_form_batches.go b/service/prompt/prompt_user_form_batches.go new file mode 100644 index 0000000..d5feb17 --- /dev/null +++ b/service/prompt/prompt_user_form_batches.go @@ -0,0 +1,135 @@ +package prompt + +import ( + "context" + "fmt" + "strings" + + "github.com/gogf/gf/v2/frame/g" + + "prompts-core/common/util" + "prompts-core/model/dto" + "prompts-core/model/entity" +) + +// ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容) +func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) { + if model.TokenConfig == nil || len(req.UserForm) == 0 { + return req, 1, nil + } + + availableWindow := util.GetAvailableWindow(model.TokenConfig) + batches := splitUserFormByTokenSize(req.UserForm, availableWindow, model.TokenConfig) + + if len(batches) <= 1 { + return req, 1, nil + } + + newUserForm := buildBatchedUserForm(batches) + + newReq := *req + newReq.UserForm = newUserForm + + g.Log().Infof(ctx, "[ProcessUserFormBatches] UserForm分批完成: 原始%d条 -> %d批 (按token大小拼接)", + len(req.UserForm), len(batches)) + + return &newReq, len(batches), nil +} + +// buildBatchedUserForm 构建分批后的用户表单 +func buildBatchedUserForm(batches [][]map[string]any) []map[string]any { + newUserForm := make([]map[string]any, 0, len(batches)) + + for i, batch := range batches { + combinedText := combineBatchText(batch) + newUserForm = append(newUserForm, map[string]any{ + "batch_index": i + 1, + "total_batches": len(batches), + "text": combinedText, + "item_count": len(batch), + }) + } + + return newUserForm +} + +// combineBatchText 合并批次中的所有文本(合并所有字段的值) +func combineBatchText(batch []map[string]any) string { + var builder strings.Builder + + for j, item := range batch { + itemText := getItemText(item) + if itemText == "" { + continue + } + + if j > 0 { + builder.WriteString("\n\n") + } + builder.WriteString(itemText) + } + + return builder.String() +} + +// splitUserFormByTokenSize 按 token 大小将 UserForm 内容拼接后分批 +func splitUserFormByTokenSize(userForm []map[string]any, maxTokens int, tokenConfig any) [][]map[string]any { + if len(userForm) == 0 { + return [][]map[string]any{} + } + + batches := make([][]map[string]any, 0) + currentBatch := make([]map[string]any, 0) + currentTokens := 0 + + for i, item := range userForm { + itemText := getItemText(item) + itemTokens := util.CalculateTokens(itemText, tokenConfig) + + // 单个元素超过窗口,单独成一批 + if itemTokens > maxTokens { + if len(currentBatch) > 0 { + batches = append(batches, currentBatch) + currentBatch = make([]map[string]any, 0) + currentTokens = 0 + } + batches = append(batches, []map[string]any{item}) + continue + } + + // 判断是否需要新开一批 + if currentTokens+itemTokens > maxTokens && len(currentBatch) > 0 { + batches = append(batches, currentBatch) + currentBatch = make([]map[string]any, 0) + currentTokens = 0 + } + + currentBatch = append(currentBatch, item) + currentTokens += itemTokens + + // 最后一批 + if i == len(userForm)-1 && len(currentBatch) > 0 { + batches = append(batches, currentBatch) + } + } + + return batches +} + +// getItemText 获取 item 中的所有文本内容(合并所有字段的值) +func getItemText(item map[string]any) string { + if len(item) == 0 { + return "" + } + + var parts []string + for key, value := range item { + // 跳过分批时添加的元数据字段 + if key == "batch_index" || key == "total_batches" || key == "item_count" { + continue + } + parts = append(parts, fmt.Sprintf("%v", value)) + } + + return strings.Join(parts, "\n") +}