From e1461cf0f01d1c48e88fa0eb800c760fc8cb8104 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Mon, 8 Jun 2026 18:01:54 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=AD=97=E6=AE=B5=E5=B9=B6=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/json.go | 178 +++++++----------- common/util/mapping.go | 140 +++++++------- controller/prompt_compose_controller.go | 5 + dao/compose_session_dao.go | 1 + go.mod | 6 +- go.sum | 13 +- model/dto/prompt_compose_dto.go | 9 + service/gateway/gateway_http_service.go | 3 +- service/prompt/prompt_build_service.go | 23 +-- service/prompt/prompt_compose_service.go | 146 ++++---------- .../session/prompt_session_redis_service.go | 4 - service/session/prompt_session_service.go | 26 ++- 12 files changed, 219 insertions(+), 335 deletions(-) diff --git a/common/util/json.go b/common/util/json.go index 3894ec7..86a97d0 100644 --- a/common/util/json.go +++ b/common/util/json.go @@ -2,13 +2,11 @@ package util import ( "encoding/json" - "strconv" + "fmt" "github.com/gogf/gf/v2/container/gvar" - gfgjson "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/util/gconv" - tGjson "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) // ConvertToMessages 将原始数据转换为消息列表 @@ -17,7 +15,7 @@ func ConvertToMessages(raw any) []map[string]any { return nil } - j := gfgjson.New(raw) + j := gjson.New(raw) messages := j.Get("messages") if !messages.IsNil() { return gconv.Maps(messages.Val()) @@ -66,7 +64,7 @@ func JSONPretty(v any) string { return gconv.String(v) } - b, _ := json.MarshalIndent(tmp, "", " ") + b, _ := json.Marshal(tmp) return string(b) } @@ -102,132 +100,98 @@ func ParseJSONFieldFromGvar(source any, target any) { } } -// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中。 -// -// 参数说明: -// - req: 请求参数 map,需包含 "consult" 字段,值为 []any,每个元素是 {"type":"xxx","url":"..."} -// - messages: 模型生成的返回结构(如 rounds[...].messages[...].content 数组) -// - extendMapping: 附加映射配置,格式: -// {"attachments": {"image": {"template": {...}, "target_path": "...", "field_mapping": {...}}, ...}} -// -// 返回值:合并后的完整 map。 +// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中 func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any { if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 { return messages } - reqJSON, _ := json.Marshal(req) - msgJSON, _ := json.Marshal(messages) - extJSON, _ := json.Marshal(extendMapping) - - reqStr := string(reqJSON) - msgStr := string(msgJSON) - extStr := string(extJSON) - - // 获取 consult 数组 - consultResult := tGjson.Get(reqStr, "consult") - if !consultResult.Exists() || !consultResult.IsArray() { + // 1) 获取 consult 数组 + consult := gconv.Interfaces(req["consult"]) + if len(consult) == 0 { return messages } - // 获取 attachments 配置 - attachmentsResult := tGjson.Get(extStr, "attachments") - if !attachmentsResult.Exists() || !attachmentsResult.IsObject() { + // 2) 获取配置 + targetPath := gconv.String(extendMapping["target_content_path"]) + if targetPath == "" { return messages } - consultArr := consultResult.Array() - attachmentsMap := attachmentsResult.Map() + templates := gconv.Map(extendMapping["attachment_templates"]) + if len(templates) == 0 { + return messages + } - for _, consultItem := range consultArr { - if !consultItem.IsObject() { - continue - } + // 3) 转为 gjson 操作 + msgJson := gjson.New(messages) - itemType := consultItem.Get("type").String() + // 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath} + if arr := msgJson.Get("rounds.0").Array(); arr != nil { + targetPath = "rounds.0." + targetPath + } + + // 4) 遍历 consult,按类型生成附件并追加 + for _, item := range consult { + itemJson := gjson.New(item) + + itemType := itemJson.Get("type").String() if itemType == "" { continue } - // 查找对应类型的附件配置 - attachResult, ok := attachmentsMap[itemType] - if !ok || !attachResult.IsObject() { + // 查找对应模板 + tmpl := gconv.Map(templates[itemType]) + if len(tmpl) == 0 { continue } - // 获取模板 - templateResult := attachResult.Get("template") - if !templateResult.Exists() || !templateResult.IsObject() { + // 生成附件对象 + attachment := buildAttachment(tmpl, itemJson.Get("url").String()) + if attachment == nil { continue } - // 深拷贝模板 - filledTemplateStr := templateResult.Raw + // 获取当前数组长度,用索引追加 + arr := msgJson.Get(targetPath).Array() + idx := len(arr) + indexPath := fmt.Sprintf("%s.%d", targetPath, idx) + _ = msgJson.Set(indexPath, attachment) + } - // 应用字段映射 - fieldMappingResult := attachResult.Get("field_mapping") - if fieldMappingResult.Exists() && fieldMappingResult.IsObject() { - fieldMapping := fieldMappingResult.Map() - for fieldPath, valueSource := range fieldMapping { - sourceKey := valueSource.String() - valueResult := consultItem.Get(sourceKey) - if valueResult.Exists() { - var err error - filledTemplateStr, err = sjson.SetRaw(filledTemplateStr, fieldPath, valueResult.Raw) - if err != nil { - continue - } - } + return msgJson.Map() +} + +// buildAttachment 根据模板和用户数据生成附件对象 +func buildAttachment(tmpl map[string]any, url string) map[string]any { + typ := gconv.String(tmpl["type"]) + if typ == "" || url == "" { + return nil + } + + // 深拷贝 body 并填充 url + body := gconv.Map(tmpl["body"]) + bodyJson := gjson.New(body) + bodyJson = fillEmpty(bodyJson, url) + + return map[string]any{ + "type": typ, + typ: bodyJson.Map(), + } +} + +// fillEmpty 递归查找空字符串并替换 +func fillEmpty(j *gjson.Json, value string) *gjson.Json { + m := j.Map() + for k, v := range m { + switch vv := v.(type) { + case string: + if vv == "" { + _ = j.Set(k, value) } - } - - // 获取目标路径 - targetPath := attachResult.Get("target_path").String() - if targetPath == "" { - continue - } - - // 检查目标路径是否存在且为数组 - targetResult := tGjson.Get(msgStr, targetPath) - if !targetResult.Exists() || !targetResult.IsArray() { - continue - } - - // 追加到数组末尾 - arrLen := len(targetResult.Array()) - appendPath := targetPath + "." + strconv.Itoa(arrLen) - var err error - msgStr, err = sjson.SetRaw(msgStr, appendPath, filledTemplateStr) - if err != nil { - continue + case map[string]any: + _ = j.Set(k, fillEmpty(gjson.New(vv), value).Map()) } } - - // 转回 map[string]any - var result map[string]any - if err := json.Unmarshal([]byte(msgStr), &result); err != nil { - return messages - } - return result -} - -// GetUserMessage 获取用户消息 -func GetUserMessage(taskReq map[string]any) map[string]any { - // 先取 requestPayload - rp, ok := taskReq["requestPayload"].(map[string]any) - if !ok { - return nil - } - // 再取 messages - messages, ok := rp["messages"].([]any) - if !ok { - return nil - } - for _, msg := range messages { - m, ok := msg.(map[string]any) - if ok && m["role"] == "user" { - return m - } - } - return nil + return j } diff --git a/common/util/mapping.go b/common/util/mapping.go index 60ee815..e2aea24 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -1,11 +1,9 @@ package util import ( - "net/url" "strings" "github.com/gogf/gf/v2/encoding/gjson" - "github.com/gogf/gf/v2/util/gconv" ) // ReverseMap 映射 payload 到 mapping @@ -22,86 +20,80 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { return jsonObj.Map() } -// MapResponsePayload 映射模型响应为标准格式 -func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) { - if len(mapping) == 0 { - return responseBytes, nil - } +// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构 +func ExtractUserText(messages map[string]any) map[string]any { + var texts []string - responseJson := gjson.New(responseBytes) - resultJson := gjson.New("{}") - - for standardField, modelPath := range mapping { - path := gconv.String(modelPath) - if path == "" { - continue - } - val := responseJson.Get(path) - if val.IsNil() { - continue - } - resultJson.Set(standardField, val.Val()) - } - - return []byte(resultJson.String()), nil -} - -// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔: -// 示例: -// - X-API-Key:qwen3-tts-key,operation:true,count:123 -// - X-API-Key:"qwen3-tts-key",operation:"true" -// -// 说明: -// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。 -// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。 -func ParseHeadMsgHeaders(headMsg string) map[string]string { - headMsg = strings.TrimSpace(headMsg) - if headMsg == "" { - return nil - } - out := map[string]string{} - parts := strings.Split(headMsg, ",") - for _, p := range parts { - p = strings.TrimSpace(p) - if p == "" { - continue - } - // HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容) - if strings.Contains(p, ":") { - kv := strings.SplitN(p, ":", 2) - k := strings.TrimSpace(kv[0]) - v := strings.TrimSpace(kv[1]) - v = strings.Trim(v, "\"") - if k != "" && v != "" { - out[k] = v + // 1) rounds 结构:遍历每轮 + if rounds, ok := messages["rounds"].([]any); ok { + for _, round := range rounds { + if rm, ok := round.(map[string]any); ok { + if msgs, ok := rm["messages"].([]any); ok { + texts = append(texts, extractTextFromRoleUser(msgs)...) + } } - continue - } - if strings.Contains(p, "=") { - kv := strings.SplitN(p, "=", 2) - k := strings.TrimSpace(kv[0]) - v := strings.TrimSpace(kv[1]) - v = strings.Trim(v, "\"") - if k != "" && v != "" { - out[k] = v - } - continue } + } else if msgs, ok := messages["messages"].([]any); ok { + // 2) messages 结构 + texts = extractTextFromRoleUser(msgs) } - if len(out) == 0 { - return nil + + // 3) 构建返回结构 + return map[string]any{ + "role": "user", + "content": strings.Join(texts, "\n"), } - return out } -// PayloadToQuery 将 payload 转为 url.Values -func PayloadToQuery(payload map[string]any) (url.Values, error) { - q := url.Values{} - for k, v := range payload { - if v == nil { +// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本 +func extractTextFromRoleUser(msgs []any) []string { + var texts []string + for _, msg := range msgs { + m, ok := msg.(map[string]any) + if !ok { continue } - q.Set(k, gconv.String(v)) + if role, _ := m["role"].(string); role != "user" { + continue + } + texts = append(texts, extractAllText(m["content"])...) } - return q, nil + return texts +} + +// extractAllText 从 content 中提取所有文本(递归,最大兼容) +func extractAllText(content any) []string { + switch c := content.(type) { + case string: + return []string{c} + + case []any: + var texts []string + for _, item := range c { + m, ok := item.(map[string]any) + if !ok { + continue + } + if t, ok := m["text"].(string); ok && t != "" { + texts = append(texts, t) + continue + } + for _, v := range m { + texts = append(texts, extractAllText(v)...) + } + } + return texts + + case map[string]any: + if t, ok := c["text"].(string); ok && t != "" { + return []string{t} + } + var texts []string + for _, v := range c { + texts = append(texts, extractAllText(v)...) + } + return texts + } + + return nil } diff --git a/controller/prompt_compose_controller.go b/controller/prompt_compose_controller.go index 96be2e4..beaa272 100644 --- a/controller/prompt_compose_controller.go +++ b/controller/prompt_compose_controller.go @@ -26,3 +26,8 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) { return promptService.GetComposeTask(ctx, req.TaskId) } + +// GetPromptText 纯文本prompt调用接口(测试专用) +func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) { + return promptService.GetPromptText(ctx, req) +} diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index cdd4caf..d58aaba 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -54,6 +54,7 @@ func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession OmitEmpty() model.Where(entity.ComposeSessionCol.Creator, req.Creator) model.Where(entity.ComposeSessionCol.SessionId, req.SessionId) + model.Where(entity.ComposeSessionCol.NodeId, req.NodeId) model.OrderDesc(entity.ComposeSessionCol.CreatedAt) model.Page(page, size) r, total, err := model.AllAndCount(false) diff --git a/go.mod b/go.mod index b2e668e..e9f4a1f 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,10 @@ module prompts-core go 1.26.1 require ( - gitea.com/red-future/common v0.0.20 + gitea.com/red-future/common v0.0.21 github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 github.com/gogf/gf/v2 v2.10.2 - github.com/tidwall/gjson v1.19.0 - github.com/tidwall/sjson v1.2.5 ) require ( @@ -65,8 +63,6 @@ require ( github.com/r3labs/diff/v2 v2.15.1 // indirect github.com/redis/go-redis/v9 v9.12.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect github.com/tiger1103/gfast-token v1.0.10 // indirect github.com/vcaesar/cedar v0.30.0 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect diff --git a/go.sum b/go.sum index a24f1f7..edd9261 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -gitea.com/red-future/common v0.0.20 h1:KlKINnJFmOVkDzgkptEAFsdpMUZb0zK9BTdiXRxVfAo= -gitea.com/red-future/common v0.0.20/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo= +gitea.com/red-future/common v0.0.21 h1:8w30HmCVmFG/hphH3ODJs1KxDEGmRpq+/PXI0pQjJKc= +gitea.com/red-future/common v0.0.21/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= @@ -288,15 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU= -github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s= github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index fd41d55..cfc6580 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -59,3 +59,12 @@ type GetComposeTaskRes struct { OssFile string `json:"ossFile" dc:"结果文件地址"` FileType string `json:"fileType" dc:"结果文件类型"` } + +type GetPromptTextReq struct { + g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"` + Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"` +} + +type GetPromptTextRes struct { + Messages any `json:"messages" dc:"最终消息数组"` +} diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 0f3c7fd..e466ba6 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -58,8 +58,7 @@ type AsynchModel struct { ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` IsPrivate *int `orm:"is_private" json:"isPrivate"` IsChatModel int `orm:"is_chat_model" json:"isChatModel"` - IsAsync *int `orm:"is_async" json:"isAsync"` - IsStream *int `orm:"is_stream" json:"isStream"` + CallModel int `orm:"call_model" json:"callModel"` ApiKey string `orm:"api_key" json:"apiKey"` Enabled *int `orm:"enabled" json:"enabled"` MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 65a10b8..a4aa396 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -28,13 +28,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) } - return compileToProviderRequest(ctx, ir, chatModel) + return compileToProviderRequest(ctx, ir, chatModel, req) } // buildNodeTypeRequest 构建节点类型请求(BuildType=2) func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) - return compileToProviderRequest(ctx, ir, chatModel) + return compileToProviderRequest(ctx, ir, chatModel, req) } // buildStructTypeRequest 构建结构体类型请求(BuildType=3) @@ -50,18 +50,20 @@ func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ch // 用户消息 ir.AddSystem(customPrompt) ir.AddUser(buildUserPrompt(ctx, req, "")) - return compileToProviderRequest(ctx, ir, chatModel, customPrompt) + return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt) } // compileToProviderRequest 编译为 Provider 请求 -func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) { +func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) { protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName) if err != nil || protocol == nil { return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err) } // 如果传了自定义提示词,替换掉协议模板 if len(customPrompt) > 0 && customPrompt[0] != "" { - protocol.SystemPromptTemplate = customPrompt[0] + protocol.SystemPromptTemplate = customPrompt[0] + + "【核心铁律】" + + "1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。" } providerReq, err := Compile(ir, protocol, chatModel) if err != nil { @@ -72,6 +74,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate "bizName": util.GetServerName(ctx), "callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"), "requestPayload": providerReq, + "buildType": req.BuildType, }, nil } @@ -84,20 +87,12 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, return "" } outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) + return fmt.Sprintf(providerProtocol.SystemPromptTemplate, outputJSON, //【输出结构】 %s ) } -// buildUserFormContent 构建用户表单内容字符串 -func buildUserFormContent(userForm []map[string]any) string { - var builder strings.Builder - for _, item := range userForm { - builder.WriteString(fmt.Sprintf("%v\n", item)) - } - return builder.String() -} - // checkOverallContent 检查整体内容是否超出窗口 func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool { fullContent := ir.String() diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 2a6ac9d..929a4cc 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "prompts-core/service/session" "prompts-core/common/util" "prompts-core/consts/public" @@ -173,24 +174,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if err != nil { return fmt.Errorf("查询模型失败: %w", err) } - // 2) 根据运营商获取协议配置 - //protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ - // ProviderName: model.OperatorName, - //}) - // 2) 解析结果 - var messages map[string]any - switch composeTask.BuildType { - case public.BuildTypePrompt, public.BuildTypeNode: - messages = ParseResult(req.Messages, model.ResponseBody) - case public.BuildTypeStruct: - messages = ParseStructResult(req.Messages, model.ResponseBody) - default: - messages = req.Messages - } - // 3) 合并附加结构 - messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping) - // 4) 更新数据库 + // 2) 合并附加结构 + messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) + // 3) 更新数据库 _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ TaskId: req.TaskId, Status: public.ComposeStatusSuccess, @@ -203,21 +190,31 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas if err != nil { return err } - // 5) 存储提示词结果作为历史请求 + //var userHistoryMsg map[string]any var epicycleId int64 payload := composeTask.RequestPayload sessionId := gconv.String(payload["sessionId"]) nodeId := gconv.String(payload["nodeId"]) buildType := gconv.Int(payload["buildType"]) if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" { - epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - NodeId: nodeId, - SessionId: sessionId, - RequestContent: messages, - }) + // 4) 获取历史内容并拼接 + history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId) + for _, msg := range history { + role := gconv.String(msg["role"]) + if role != "user" && role != "assistant" { + continue + } + } + // 5) 存储提示词结果作为历史请求 + if userMsg := util.ExtractUserText(messages); userMsg != nil { + epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + NodeId: nodeId, + SessionId: sessionId, + RequestContent: userMsg, + }) + } } - // 6) 拼接历史内容 - // 7) 回调业务方 + // 6) 回调业务方 if composeTask.CallbackUrl != "" { composeTask.Status = public.ComposeStatusSuccess composeTask.Messages = messages @@ -226,95 +223,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas return nil } -// ParseResult 解析结果 -func ParseResult(raw map[string]any, responseBody string) map[string]any { - if responseBody == "" { - return raw - } - - contentVal := raw[responseBody] - if contentVal == nil { - return raw - } - - // 已经是数组 - if arr, ok := contentVal.([]any); ok { - rounds := gconv.Maps(arr) - if len(rounds) > 0 { - return map[string]any{"total_rounds": len(rounds), "rounds": rounds} - } - return raw - } - - // 是字符串 - contentStr := gconv.String(contentVal) - if contentStr == "" { - return raw - } - - // 尝试解析为数组 - var arr []map[string]any - if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 { - return map[string]any{"total_rounds": len(arr), "rounds": arr} - } - - // 尝试解析为单对象 - var obj map[string]any - if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 { - return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}} - } - - return map[string]any{"content": contentStr} -} - -func ParseStructResult(raw map[string]any, responseBody string) map[string]any { - // 如果外层已有 rounds,直接返回 - if _, ok := raw["rounds"]; ok { - return raw - } - - contentVal := raw[responseBody] - - var rounds []map[string]any - - // 是字符串,尝试解析 - contentStr := gconv.String(contentVal) - if contentStr == "" || contentStr == "0" { - rounds = append(rounds, map[string]any{responseBody: raw}) - return map[string]any{ - "total_rounds": 1, - "rounds": rounds, - } - } - - // 尝试解析为数组 - var arr []any - if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 { - rounds = append(rounds, map[string]any{responseBody: arr}) - return map[string]any{ - "total_rounds": len(rounds), - "rounds": rounds, - } - } - - // 尝试解析为单个对象 - var parsed any - if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil { - rounds = append(rounds, map[string]any{responseBody: parsed}) - return map[string]any{ - "total_rounds": 1, - "rounds": rounds, - } - } - - // 兜底:原始字符串作为内容 - rounds = append(rounds, map[string]any{responseBody: contentStr}) - return map[string]any{ - "total_rounds": 1, - "rounds": rounds, - } -} - // GetComposeTask 查询任务结果 func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ @@ -351,3 +259,13 @@ func parseMessagesForResponse(messages any) any { return messages } + +func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) { + // 1) 获取基础数据 + + // 4) 模拟历史拼接 + history, _ := session.GetHistoryMessages(ctx, "88888888", "node1") + return &dto.GetPromptTextRes{ + Messages: history, + }, nil +} diff --git a/service/session/prompt_session_redis_service.go b/service/session/prompt_session_redis_service.go index 0599e42..bec6553 100644 --- a/service/session/prompt_session_redis_service.go +++ b/service/session/prompt_session_redis_service.go @@ -34,7 +34,6 @@ func saveToRedis(ctx context.Context, session *entity.ComposeSession) error { if err != nil { return fmt.Errorf("序列化会话数据失败: %w", err) } - if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { return err } @@ -50,18 +49,15 @@ func executeRedisCommands(ctx context.Context, key string, value string, maxRoun if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil { return fmt.Errorf("裁剪Redis列表失败: %w", err) } - if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil { return fmt.Errorf("设置过期时间失败: %w", err) } - return nil } // getFromRedis 从Redis获取会话历史 func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { key := formatRedisKey(sessionId) - result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) if err != nil { return nil, fmt.Errorf("从Redis获取数据失败: %w", err) diff --git a/service/session/prompt_session_service.go b/service/session/prompt_session_service.go index 9933457..4e2c471 100644 --- a/service/session/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -47,15 +47,35 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal } // GetHistoryMessages 获取历史信息 -func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) { +func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) { + // 1) 获取最大轮次 maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() + // 2) 从 Redis 获取历史记录 redisHistory, err := GetSessionHistoryForInference(ctx, sessionId) if err == nil && len(redisHistory) > 0 { return redisHistory, nil } - return getHistoryFromDatabase(ctx, sessionId, maxRounds) + // 3) Redis 没有,从数据库查最新 maxRounds 条 + sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ + SessionId: sessionId, + NodeId: nodeId, + }, 1, maxRounds) + if err != nil { + return nil, fmt.Errorf("DB获取历史失败: %w", err) + } + // 4) 为空返回报错 + if len(sessions) == 0 { + return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId) + } + // 5) 提取为统一格式 + messages := extractMessagesFromSessions(sessions) + + // 6) 缓存 Redis 半小时 + //_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute) + + return messages, nil } // getHistoryFromDatabase 从数据库获取历史记录 @@ -77,12 +97,10 @@ func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int // extractMessagesFromSessions 从会话列表中提取消息 func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any { var messages []map[string]any - for _, session := range sessions { appendRequestMessages(session.RequestContent, &messages) appendResponseMessages(session.ResponseContent, &messages) } - return messages }