From b69e7386e2788d2beddc1d5cc06da228fb22a177 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 10 Jun 2026 14:51:25 +0800 Subject: [PATCH] =?UTF-8?q?refactor(prompts-core):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84=E5=92=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/json.go | 140 ++---------------- common/util/mapping.go | 88 +++-------- consts/public/public.go | 4 + controller/prompt_compose_controller.go | 5 - model/dto/prompt_compose_dto.go | 29 +--- service/prompt/prompt_build_service.go | 52 ++----- service/prompt/prompt_compose_service.go | 128 +++++----------- service/prompt/prompt_files_handle_service.go | 3 + service/prompt/prompt_ir_service.go | 131 ++++++++-------- service/session/prompt_session_service.go | 16 +- 10 files changed, 164 insertions(+), 432 deletions(-) diff --git a/common/util/json.go b/common/util/json.go index 86a97d0..2be1df7 100644 --- a/common/util/json.go +++ b/common/util/json.go @@ -1,197 +1,81 @@ package util import ( - "encoding/json" "fmt" - "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/util/gconv" ) -// ConvertToMessages 将原始数据转换为消息列表 -func ConvertToMessages(raw any) []map[string]any { - if raw == nil { - return nil - } - - j := gjson.New(raw) - messages := j.Get("messages") - if !messages.IsNil() { - return gconv.Maps(messages.Val()) - } - return []map[string]any{j.Map()} -} - -// FormToJSON 将表单数据转换为 JSON 字符串 -func FormToJSON(form []map[string]any) string { - if form == nil { - return "[]" - } - b, _ := json.Marshal(form) - return string(b) -} - -// UserFormToJSON 将用户表单数据转换为 JSON 字符串 -func UserFormToJSON(form []map[string]any) string { - if form == nil { - return "{}" - } - - b, _ := json.Marshal(form) - return string(b) -} - -// MustMarshalToMap 将对象序列化为 map[string]any,失败时返回空 map -func MustMarshalToMap(v any) map[string]any { - b, err := json.Marshal(v) - if err != nil { - return make(map[string]any) - } - var m map[string]any - json.Unmarshal(b, &m) - return m -} - -// JSONPretty 将任意类型转为格式化的 JSON 字符串 -func JSONPretty(v any) string { - if gv, ok := v.(*gvar.Var); ok { - v = gconv.Map(gv.String()) - } - - var tmp map[string]any - if err := gconv.Struct(v, &tmp); err != nil { - return gconv.String(v) - } - - b, _ := json.Marshal(tmp) - return string(b) -} - -// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析 -func ParseJSONFieldFromGvar(source any, target any) { - if source == nil { - return - } - - switch v := source.(type) { - case *gvar.Var: - if v.IsNil() { - return - } - - // 尝试获取 map - if m := v.Map(); len(m) > 0 { - data, _ := json.Marshal(m) - json.Unmarshal(data, target) - return - } - - // 尝试解析 JSON 字符串 - str := v.String() - if str != "" && str != "" { - json.Unmarshal([]byte(str), target) - } - - default: - // 其他类型走原来的逻辑 - data, _ := json.Marshal(source) - json.Unmarshal(data, target) - } -} - // MergeConsult 将 consult 附件合并到模型生成的 messages 结构中 func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any { if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 { return messages } - // 1) 获取 consult 数组 consult := gconv.Interfaces(req["consult"]) if len(consult) == 0 { return messages } - // 2) 获取配置 targetPath := gconv.String(extendMapping["target_content_path"]) - if targetPath == "" { - return messages - } - templates := gconv.Map(extendMapping["attachment_templates"]) - if len(templates) == 0 { + if targetPath == "" || len(templates) == 0 { return messages } - // 3) 转为 gjson 操作 msgJson := gjson.New(messages) - // 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath} - if arr := msgJson.Get("rounds.0").Array(); arr != nil { + // rounds 路径修正 + if !msgJson.Get("rounds.0").IsNil() { targetPath = "rounds.0." + targetPath } - // 4) 遍历 consult,按类型生成附件并追加 + // 遍历追加 for _, item := range consult { itemJson := gjson.New(item) - itemType := itemJson.Get("type").String() - if itemType == "" { - continue - } - - // 查找对应模板 tmpl := gconv.Map(templates[itemType]) - if len(tmpl) == 0 { + if itemType == "" || len(tmpl) == 0 { continue } - // 生成附件对象 attachment := buildAttachment(tmpl, itemJson.Get("url").String()) if attachment == nil { continue } - // 获取当前数组长度,用索引追加 - arr := msgJson.Get(targetPath).Array() - idx := len(arr) - indexPath := fmt.Sprintf("%s.%d", targetPath, idx) - _ = msgJson.Set(indexPath, attachment) + idx := len(msgJson.Get(targetPath).Array()) + _ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment) } return msgJson.Map() } -// buildAttachment 根据模板和用户数据生成附件对象 func buildAttachment(tmpl map[string]any, url string) map[string]any { typ := gconv.String(tmpl["type"]) if typ == "" || url == "" { return nil } - // 深拷贝 body 并填充 url body := gconv.Map(tmpl["body"]) - bodyJson := gjson.New(body) - bodyJson = fillEmpty(bodyJson, url) + fillEmptyInPlace(body, url) return map[string]any{ "type": typ, - typ: bodyJson.Map(), + typ: body, } } -// fillEmpty 递归查找空字符串并替换 -func fillEmpty(j *gjson.Json, value string) *gjson.Json { - m := j.Map() +func fillEmptyInPlace(m map[string]any, value string) { for k, v := range m { switch vv := v.(type) { case string: if vv == "" { - _ = j.Set(k, value) + m[k] = value } case map[string]any: - _ = j.Set(k, fillEmpty(gjson.New(vv), value).Map()) + fillEmptyInPlace(vv, value) } } - return j } diff --git a/common/util/mapping.go b/common/util/mapping.go index e2aea24..b8ade11 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -4,6 +4,7 @@ import ( "strings" "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/util/gconv" ) // ReverseMap 映射 payload 到 mapping @@ -20,80 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { return jsonObj.Map() } -// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构 +// ExtractUserText 从 messages 中提取所有 user 文本 func ExtractUserText(messages map[string]any) map[string]any { - var texts []string + msgJson := gjson.New(messages) - // 1) rounds 结构:遍历每轮 - if rounds, ok := messages["rounds"].([]any); ok { - for _, round := range rounds { - if rm, ok := round.(map[string]any); ok { - if msgs, ok := rm["messages"].([]any); ok { - texts = append(texts, extractTextFromRoleUser(msgs)...) + msgs := msgJson.Get("rounds.0.messages") + if msgs.IsNil() { + msgs = msgJson.Get("messages") + } + var texts []string + for _, m := range msgs.Array() { + msg := gjson.New(m) + if msg.Get("role").String() != "user" { + continue + } + content := msg.Get("content").Val() + switch c := content.(type) { + case string: + texts = append(texts, c) + case []any: + for _, item := range c { + if m, ok := item.(map[string]any); ok { + if t := gconv.String(m["text"]); t != "" { + texts = append(texts, t) + } } } } - } else if msgs, ok := messages["messages"].([]any); ok { - // 2) messages 结构 - texts = extractTextFromRoleUser(msgs) } - // 3) 构建返回结构 return map[string]any{ "role": "user", "content": strings.Join(texts, "\n"), } } - -// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本 -func extractTextFromRoleUser(msgs []any) []string { - var texts []string - for _, msg := range msgs { - m, ok := msg.(map[string]any) - if !ok { - continue - } - if role, _ := m["role"].(string); role != "user" { - continue - } - texts = append(texts, extractAllText(m["content"])...) - } - return texts -} - -// extractAllText 从 content 中提取所有文本(递归,最大兼容) -func extractAllText(content any) []string { - switch c := content.(type) { - case string: - return []string{c} - - case []any: - var texts []string - for _, item := range c { - m, ok := item.(map[string]any) - if !ok { - continue - } - if t, ok := m["text"].(string); ok && t != "" { - texts = append(texts, t) - continue - } - for _, v := range m { - texts = append(texts, extractAllText(v)...) - } - } - return texts - - case map[string]any: - if t, ok := c["text"].(string); ok && t != "" { - return []string{t} - } - var texts []string - for _, v := range c { - texts = append(texts, extractAllText(v)...) - } - return texts - } - - return nil -} diff --git a/consts/public/public.go b/consts/public/public.go index 36bb062..9e380c0 100644 --- a/consts/public/public.go +++ b/consts/public/public.go @@ -11,3 +11,7 @@ const ( BuildTypeNode = 2 //节点构建 BuildTypeStruct = 3 //结构构建 ) + +const ( + ModelTypeInference = 100 // 推理模型 +) diff --git a/controller/prompt_compose_controller.go b/controller/prompt_compose_controller.go index beaa272..96be2e4 100644 --- a/controller/prompt_compose_controller.go +++ b/controller/prompt_compose_controller.go @@ -26,8 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) { return promptService.GetComposeTask(ctx, req.TaskId) } - -// GetPromptText 纯文本prompt调用接口(测试专用) -func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) { - return promptService.GetPromptText(ctx, req) -} diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index b1fb382..bb0a2b8 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -25,12 +25,6 @@ type ComposeMessagesRes struct { TaskId string `json:"taskId" dc:"任务ID"` } -// MultiRoundResult 多轮返回结果 -type MultiRoundResult struct { - TotalRounds int `json:"total_rounds"` // 总轮数 - Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型) -} - type CallbackReq struct { g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"` TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` @@ -51,20 +45,11 @@ type GetComposeTaskReq struct { } type GetComposeTaskRes struct { - TaskId string `json:"taskId" dc:"任务ID"` - Status string `json:"status" dc:"业务状态"` - GatewayState int `json:"gatewayState" dc:"网关状态"` - ErrorMessage string `json:"errorMessage" dc:"错误信息"` - Messages any `json:"messages" dc:"最终消息数组"` - OssFile string `json:"ossFile" dc:"结果文件地址"` - FileType string `json:"fileType" dc:"结果文件类型"` -} - -type GetPromptTextReq struct { - g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"` - Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"` -} - -type GetPromptTextRes struct { - Messages any `json:"messages" dc:"历史消息"` + TaskId string `json:"taskId" dc:"任务ID"` + Status string `json:"status" dc:"业务状态"` + GatewayState int `json:"gatewayState" dc:"网关状态"` + ErrorMessage string `json:"errorMessage" dc:"错误信息"` + Messages map[string]any `json:"messages" dc:"最终消息数组"` + OssFile string `json:"ossFile" dc:"结果文件地址"` + FileType string `json:"fileType" dc:"结果文件类型"` } diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index a4aa396..2180784 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -13,11 +13,10 @@ import ( "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/encoding/gjson" - "github.com/gogf/gf/v2/util/gconv" ) // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) -func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) { +func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) { //1) 构建系统提示词 systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel) ir.AddSystem(systemPrompt) @@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai } // buildNodeTypeRequest 构建节点类型请求(BuildType=2) -func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { +func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) return compileToProviderRequest(ctx, ir, chatModel, req) } // buildStructTypeRequest 构建结构体类型请求(BuildType=3) -func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { - // 提取 userForm 中的 prompt 作为自定义提示词 - var customPrompt string - for _, item := range req.UserForm { - if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" { - customPrompt = gconv.String(prompt) - break - } - } - // 用户消息 +func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) { + customPrompt := gjson.New(req.UserForm).Get("0.prompt").String() ir.AddSystem(customPrompt) ir.AddUser(buildUserPrompt(ctx, req, "")) return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt) } // compileToProviderRequest 编译为 Provider 请求 -func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) { +func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) { protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName) if err != nil || protocol == nil { return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err) @@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate }, nil } +// promptBuildWithRounds 构建提示词 func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string { providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ ProviderName: chatModel.OperatorName, @@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, if err != nil || providerProtocol == nil { return "" } - outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) + outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString() return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, //【输出结构】 %s @@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, } // checkOverallContent 检查整体内容是否超出窗口 -func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool { +func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool { fullContent := ir.String() return util.CountToken(fullContent, model.TokenConfig) } @@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st return b.String() } -// buildUserFormText 构建用户表单内容字符串 func buildUserFormText(form []map[string]any) string { if len(form) == 0 { return "" @@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string { var builder strings.Builder for _, item := range form { for k, v := range item { + builder.WriteString(fmt.Sprintf("%s:\n", k)) switch val := v.(type) { case []any: - // 数组类型:逐条列出 - builder.WriteString(fmt.Sprintf("%s:\n", k)) for i, elem := range val { + builder.WriteString(fmt.Sprintf(" %d. ", i+1)) if m, ok := elem.(map[string]any); ok { - builder.WriteString(fmt.Sprintf(" %d. ", i+1)) for mk, mv := range m { builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv)) } - builder.WriteString("\n") } else { - builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem)) - } - } - case []map[string]any: - builder.WriteString(fmt.Sprintf("%s:\n", k)) - for i, m := range val { - builder.WriteString(fmt.Sprintf(" %d. ", i+1)) - for mk, mv := range m { - builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv)) + builder.WriteString(fmt.Sprint(elem)) } builder.WriteString("\n") } default: - builder.WriteString(fmt.Sprintf("%s:%v\n", k, v)) + builder.WriteString(fmt.Sprintf(" %v\n", v)) } } } @@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string { if promptTpl == "" { return "" } - - formStr := util.FormToJSON(req.Form) - userFormStr := util.UserFormToJSON(req.UserForm) - - return fmt.Sprintf(promptTpl, formStr, userFormStr) + return fmt.Sprintf(promptTpl, + gjson.New(req.Form).MustToJsonString(), + gjson.New(req.UserForm).MustToJsonString(), + ) } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 2a2c6e6..402f8d1 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -2,7 +2,6 @@ package prompt import ( "context" - "encoding/json" "errors" "fmt" "prompts-core/service/session" @@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e // handleBuild 通用构建处理 func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) { // 1) 处理表单分批 - processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel) + processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel) if err != nil { return nil, fmt.Errorf("处理用户表单分批失败: %w", err) } @@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai var taskReq map[string]any switch req.BuildType { case public.BuildTypePrompt: - taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches) + taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir) case public.BuildTypeNode: taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir) case public.BuildTypeStruct: @@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai SkillName: req.SkillName, BuildType: req.BuildType, CallbackUrl: req.CallbackUrl, - RequestPayload: util.MustMarshalToMap(req), + RequestPayload: gconv.Map(req), Status: public.ComposeStatusPending, }); err != nil { return nil, err @@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask return err } +// handleCallbackSuccess 处理回调成功 func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { // 1) 获取模型配置 model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ @@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas Status: 1, }) - // 3) 获取历史消息 + // 3) 获取历史消息 + 保存当前轮 payload := composeTask.RequestPayload sessionId := gconv.String(payload["sessionId"]) nodeId := gconv.String(payload["nodeId"]) var history []dto.FlatMessage - if sessionId != "" && nodeId != "" { + var epicycleId int64 + + if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference { + // 3.1 获取历史 h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ SessionId: sessionId, NodeId: nodeId, @@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if h != nil { history = h.Messages } + + // 3.2 保存当前轮(先存,下次查询就能拿到) + if userMsg := util.ExtractUserText(req.Messages); userMsg != nil { + epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + NodeId: nodeId, + SessionId: sessionId, + RequestContent: userMsg, + }) + } } // 4) 合并附加结构 messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) - // 5) 注入历史到 rounds 中 - if protocol != nil && len(history) > 0 { + // 5) 注入历史 + if len(history) > 0 { messages = InjectHistory(messages, history, protocol) } @@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas return err } - // 7) 存储历史 - var epicycleId int64 - if sessionId != "" && nodeId != "" { - if userMsg := util.ExtractUserText(req.Messages); userMsg != nil { - epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - NodeId: nodeId, - SessionId: sessionId, - RequestContent: userMsg, - }) - } - } - // 8) 回调业务方 if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusSuccess @@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas return nil } -// GetComposeTask 查询任务结果 -func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { - record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ - TaskId: taskID, - }) - if err != nil { - return nil, fmt.Errorf("查询任务失败: %w", err) - } - if record == nil { - return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) - } - - messages := parseMessagesForResponse(record.ResultJson) - - return &dto.GetComposeTaskRes{ - TaskId: record.TaskId, - Status: record.Status, - ErrorMessage: record.ErrorMessage, - Messages: messages, - }, nil -} - -// parseMessagesForResponse 解析用于响应的消息 -func parseMessagesForResponse(messages any) any { - str, ok := messages.(string) - if !ok || str == "" { - return messages - } - - var parsed any - if err := json.Unmarshal([]byte(str), &parsed); err == nil { - return parsed - } - - return messages -} - -func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) { - // 1) 获取协议配置 - protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ - ProviderName: "火山引擎", - Status: 1, - }) - if err != nil { - return nil, err - } - - // 2) 获取历史消息 - history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ - SessionId: "88888888", - NodeId: "node1", - }) - if err != nil { - return nil, err - } - - // 3) 模拟roundsData数据 - task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ - TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc", - }) - if err != nil { - return nil, err - } - fmt.Println("[打印数据]", task.ResultJson) - fmt.Println("[打印历史]", history.Messages) - fmt.Println("[打印协议]", protocol) - return &dto.GetPromptTextRes{ - Messages: InjectHistory(task.ResultJson, history.Messages, protocol), - }, nil -} - +// InjectHistory 插入历史会话 func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any { if protocol == nil || len(history) == 0 { return roundsData @@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco firstRound["messages"] = result return roundsData } + +// GetComposeTask 查询任务结果 +func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { + record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: taskID, + }) + if err != nil { + return nil, fmt.Errorf("查询任务失败: %w", err) + } + return &dto.GetComposeTaskRes{ + TaskId: record.TaskId, + Status: record.Status, + ErrorMessage: record.ErrorMessage, + Messages: record.ResultJson, + }, nil +} diff --git a/service/prompt/prompt_files_handle_service.go b/service/prompt/prompt_files_handle_service.go index 010e99d..f6180e5 100644 --- a/service/prompt/prompt_files_handle_service.go +++ b/service/prompt/prompt_files_handle_service.go @@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str } func SkillMdContent(ctx context.Context, skillName string) string { + if skillName == "" { + return "" + } skillResp, err := gateway.GetSkillUser(ctx, skillName) if err != nil { g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err) diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go index adf3cd0..0db6661 100644 --- a/service/prompt/prompt_ir_service.go +++ b/service/prompt/prompt_ir_service.go @@ -2,9 +2,7 @@ package prompt import ( "context" - "encoding/json" "fmt" - "prompts-core/common/util" "prompts-core/service/gateway" "strings" @@ -14,8 +12,8 @@ import ( "github.com/gogf/gf/v2/util/gconv" ) -// PromptIR 统一 Prompt 中间表示 -type PromptIR struct { +// IR 统一 Prompt 中间表示 +type IR struct { System []Segment `json:"system"` History []Segment `json:"history"` User []Segment `json:"user"` @@ -46,8 +44,8 @@ type ContentMapping struct { } // NewPromptIR 创建空 PromptIR -func NewPromptIR() *PromptIR { - return &PromptIR{ +func NewPromptIR() *IR { + return &IR{ System: make([]Segment, 0), History: make([]Segment, 0), User: make([]Segment, 0), @@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR { } // String 返回 PromptIR 的完整内容字符串(用于 token 计算) -func (ir *PromptIR) String() string { +func (ir *IR) String() string { var builder strings.Builder for _, seg := range ir.System { @@ -81,7 +79,7 @@ func (ir *PromptIR) String() string { } // GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算) -func (ir *PromptIR) GetTotalContent() string { +func (ir *IR) GetTotalContent() string { var builder strings.Builder for _, seg := range ir.System { @@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string { } // AddSystem 添加系统提示 -func (ir *PromptIR) AddSystem(content string) *PromptIR { +func (ir *IR) AddSystem(content string) *IR { if content != "" { ir.System = append(ir.System, Segment{Type: "text", Content: content}) } @@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR { } // AddUser 添加用户消息 -func (ir *PromptIR) AddUser(content string) *PromptIR { +func (ir *IR) AddUser(content string) *IR { if content != "" { ir.User = append(ir.User, Segment{Type: "text", Content: content}) } @@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR { } // AddHistory 添加历史消息 -func (ir *PromptIR) AddHistory(role, content string) *PromptIR { +func (ir *IR) AddHistory(role, content string) *IR { if content != "" { ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role}) } @@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR { } // ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认) -func (ir *PromptIR) ToMessages() []map[string]any { +func (ir *IR) ToMessages() []map[string]any { var messages []map[string]any for _, seg := range ir.System { @@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP // parseProtocol 将 DB entity 转为编译用协议配置 func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { - p := &ProviderProtocol{ + return &ProviderProtocol{ TargetField: e.TargetField, SystemPromptTemplate: e.SystemPromptTemplate, + MergeOrder: e.MergeOrder, + RoleMapping: gconv.MapStrStr(e.RoleMapping), + ContentMapping: ContentMapping{ + Type: gconv.String(e.ContentMapping["type"]), + Field: gconv.String(e.ContentMapping["field"]), + }, + RequestTemplate: e.RequestTemplate, + Capabilities: e.Capabilities, } - - // 使用通用解析方法处理各个字段 - util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder) - util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping) - util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping) - util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate) - util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities) - return p } // Compile 将 PromptIR 按协议配置编译为 Provider Request -func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) { +func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) { if ir == nil || p == nil { return nil, fmt.Errorf("ir and protocol are required") } @@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) } // mergeByOrder 按协议配置顺序拼接消息 -func mergeByOrder(ir *PromptIR, order []string) []map[string]any { - var messages []map[string]any - - for _, part := range order { - switch part { - case "system": - for _, seg := range ir.System { - messages = append(messages, map[string]any{ - "role": "system", - "content": seg.Content, - }) - } - case "history": - for _, seg := range ir.History { - messages = append(messages, map[string]any{ - "role": seg.Role, - "content": seg.Content, - }) - } - case "user": - for _, seg := range ir.User { - messages = append(messages, map[string]any{ - "role": "user", - "content": seg.Content, - }) - } - } +func mergeByOrder(ir *IR, order []string) []map[string]any { + roleMap := map[string][]Segment{ + "system": ir.System, + "history": ir.History, + "user": ir.User, } + var messages []map[string]any + for _, part := range order { + for _, seg := range roleMap[part] { + msg := map[string]any{"content": seg.Content} + if part == "history" { + msg["role"] = seg.Role + } else { + msg["role"] = part + } + messages = append(messages, msg) + } + } return messages } @@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string return messages } -// mapContent 内容字段映射 func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { - for _, msg := range messages { - content := msg["content"] - delete(msg, "content") - - switch cm.Type { - case "parts": - msg["parts"] = []map[string]any{ - {cm.Field: content}, - } - default: - msg[cm.Field] = content - } + if cm.Field == "" || cm.Field == "content" { + return messages } + for i, msg := range messages { + if content, ok := msg["content"]; ok { + delete(msg, "content") + switch cm.Type { + case "parts": + messages[i]["parts"] = []map[string]any{{cm.Field: content}} + default: + messages[i][cm.Field] = content + } + } + } return messages } @@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat } } -// renderTemplate 简单的 {{key}} 模板替换 +// renderTemplate 模板渲染 func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any { - b, _ := json.Marshal(p.RequestTemplate) - str := string(b) - - if chatModel != nil { - str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`) + result := make(map[string]any, len(p.RequestTemplate)+1) + for k, v := range p.RequestTemplate { + result[k] = v } - msgBytes, _ := json.Marshal(messages) - str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) - - var result map[string]any - _ = json.Unmarshal([]byte(str), &result) + if chatModel != nil { + result["model"] = chatModel.ModelName + } + result["messages"] = messages if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 { result["max_tokens"] = maxTokens diff --git a/service/session/prompt_session_service.go b/service/session/prompt_session_service.go index f60bf3b..3d44239 100644 --- a/service/session/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -21,8 +21,8 @@ import ( // Callback 会话回调 func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { + fmt.Println("打印会话回调", req) req.Messages["role"] = "assistant" - // 1) 更新 DB _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, @@ -163,23 +163,15 @@ func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteS // entityToHistoryRound entity → HistoryRound func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound { - reqMsgs := util.ConvertToMessages(s.RequestContent) - respMsgs := util.ConvertToMessages(s.ResponseContent) - - round := &dto.HistoryRound{ + return &dto.HistoryRound{ Id: s.Id, SessionId: s.SessionId, NodeId: s.NodeId, CreatedAt: gconv.String(s.CreatedAt), UpdatedAt: gconv.String(s.UpdatedAt), + User: s.RequestContent, + Assistant: s.ResponseContent, } - if len(reqMsgs) > 0 { - round.User = reqMsgs[0] - } - if len(respMsgs) > 0 { - round.Assistant = respMsgs[0] - } - return round } // sessionsToHistoryRounds 批量转换