From 92092575bccc7abfdcfe8c8de7928a0da47bd516 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Fri, 22 May 2026 09:49:46 +0800 Subject: [PATCH] =?UTF-8?q?feat(prompt):=20=E9=87=8D=E6=9E=84=E6=8F=90?= =?UTF-8?q?=E7=A4=BA=E8=AF=8D=E6=9C=8D=E5=8A=A1=E5=B9=B6=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=B1=BB=E5=9E=8B=E5=AD=90=E5=88=86=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.yml | 12 +- model/dto/prompt_compose_dto.go | 25 +- model/entity/prompts_compose_task.go | 57 ++- service/gateway/gateway_http_service.go | 80 ++++ service/prompt/prompt_build_service.go | 32 +- service/prompt/prompt_compose_service.go | 543 ++++++++--------------- service/prompt/prompt_session_service.go | 3 +- service/prompt/prompt_task_waiter.go | 125 ------ update.sql | 16 +- 9 files changed, 353 insertions(+), 540 deletions(-) delete mode 100644 service/prompt/prompt_task_waiter.go diff --git a/config.yml b/config.yml index b28d081..9f501b9 100644 --- a/config.yml +++ b/config.yml @@ -60,7 +60,7 @@ jaeger: addr: 192.168.3.30:4318 task: - waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒) + waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒) session: maxRounds: 10 # 最大轮数 @@ -81,23 +81,23 @@ promptsRetry: modelPrompts: types: - 1: | + 100: | 你是一个智能文字处理助手,专注于文本理解、文本创作、文本优化与语言表达任务,能够根据不同场景完成文章撰写、商业文案、报告总结、邮件通知、脚本创作、内容改写、信息提炼、语言翻译等多种文字处理工作,并能够理解上下文语义关系,保持内容逻辑完整、结构清晰、表达自然。 在执行文本任务时,你需要以专业内容创作者、编辑顾问、语言优化专家的身份完成输出,严格保证语言准确性、逻辑连贯性、表达一致性与阅读体验,根据不同用户场景自动适配正式、口语化、专业化、营销化等表达风格,同时避免空洞表达、重复描述与机械化生成内容。 当用户提供具体需求时,需要结合用户输入、上下文信息、参数条件与目标场景生成最终文本结果;若涉及改写、扩写、摘要、总结、标题、营销内容等任务,需要保证核心语义不偏离,并根据用户真实目的完成结构化输出。 - 2: | + 200: | 你是一个智能图片处理助手,专注于视觉内容生成、图像编辑、画面分析与风格控制任务,能够根据文字描述生成不同风格的图片内容,包括写实、插画、动漫、水彩、电影感、商业海报等多种视觉形式,并支持图片局部修改、风格迁移、画面扩展、背景处理与视觉增强等操作。 在执行图片相关任务时,你需要以专业视觉设计师、插画师、摄影指导、美术导演的身份进行画面构建,重点关注主体构图、色彩关系、光影氛围、镜头语言、视觉层次与整体风格统一性,确保生成结果具备明确视觉主题与稳定审美表现,而不是简单关键词堆砌。 当用户提供图片需求时,需要结合用户描述、场景用途、风格方向、尺寸比例、主体元素、氛围要求等信息生成完整视觉方案;若存在图片编辑任务,则必须保留原图核心特征,仅对用户指定区域或效果进行修改。 - 3: | + 300: | 你是一个智能音频处理助手,专注于语音生成、语音识别、音频分析与声音编辑任务,能够完成文字转语音、语音转文本、多语言识别、音频降噪、音色处理、混音剪辑、情绪识别与声音特征分析等多种音频相关工作,并能够根据不同场景匹配对应语音风格与声音表现形式。 在执行音频任务时,你需要以专业配音导演、声音工程师、语音分析专家、后期音频制作人员的身份进行处理,重点保证语音自然度、情绪一致性、识别准确率、音频清晰度与输出稳定性,同时确保不同格式、采样率与播放场景下具备良好兼容性。 当用户提供具体音频需求时,需要结合音色、语速、语言类型、情绪风格、背景环境、输出格式等参数完成对应处理;若涉及语音识别或音频分析,则需要尽可能保留原始语义与声音特征,并明确标注不确定内容。 - 4: | + 400: | 你是一个智能向量化处理助手,专注于文本向量化、语义检索、知识索引、相似度计算与语义聚类任务,能够将文本内容转换为高维语义向量,并基于向量相似度完成语义搜索、知识召回、内容聚类、文档匹配与知识库构建等处理流程。 在执行向量化任务时,你需要以语义检索工程师、知识库架构师、AI检索系统专家的身份进行处理,重点保证语义表达准确性、向量一致性、检索稳定性与召回有效性,同时确保不同文本之间的语义关系能够被正确表达与计算。 当用户提供文本集合、知识内容或检索需求时,需要结合文本上下文、主题方向、检索目标、相似度要求与业务场景生成最终结果;若涉及聚类或知识库构建,则必须明确类别关系、索引结构与召回逻辑。 - 5: | + 500: | 你是一个全模态智能处理助手,能够同时理解、分析与生成文本、图片、音频、视频等多种模态内容,并支持跨模态转换、多模态融合推理、联合内容生成与复杂场景交互,能够根据不同输入形式自动匹配最合理的处理策略与输出方式。 在执行多模态任务时,你需要以全链路AI内容架构师、多模态交互专家、综合内容生成系统的身份完成处理,重点保证不同模态之间的语义一致性、风格统一性、信息完整性与交互连贯性,避免出现跨模态语义断裂或输出不一致的问题。 当用户提供混合输入内容时,需要结合文本、图片、音频、视频等多种信息共同分析用户真实目标,并根据任务场景自动决定最终输出形式;若涉及跨模态生成,则必须保证生成结果能够准确映射原始语义与核心信息。 diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index f03cd37..ffe076f 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -3,21 +3,26 @@ package dto import "github.com/gogf/gf/v2/frame/g" type ComposeMessagesReq struct { - g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"` - ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` - BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点 - SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"` - Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"` - Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` - UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"` - SkillName string `p:"skillName" json:"skillName" dc:"技能名称"` - UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"` + g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"` + ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"` + BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点 + SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"` + Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"` + CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"` + Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"` + UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"` + SkillName string `p:"skillName" json:"skillName" dc:"技能名称"` + UserFiles []string `p:"userFiles" json:"userFiles" dc:"用户附件地址列表"` } type ComposeMessagesRes struct { + TaskId string `json:"taskId" dc:"任务ID"` +} + +/* Messages *MultiRoundResult `json:"messages,omitempty" dc:"最终消息数组"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` -} +*/ // MultiRoundResult 多轮返回结果 type MultiRoundResult struct { diff --git a/model/entity/prompts_compose_task.go b/model/entity/prompts_compose_task.go index b81715a..2645045 100644 --- a/model/entity/prompts_compose_task.go +++ b/model/entity/prompts_compose_task.go @@ -7,39 +7,48 @@ type ComposeTask struct { TaskId string `orm:"task_id" json:"taskId"` ModelName string `orm:"model_name" json:"modelName"` SkillName string `orm:"skill_name" json:"skillName"` - LimitWords int `orm:"limit_words" json:"limitWords"` + BuildType int `orm:"build_type" json:"buildType"` + CallbackUrl string `orm:"callback_url" json:"callbackUrl"` + GatewayState int `orm:"gateway_state" json:"gatewayState"` RequestPayload any `orm:"request_payload" json:"requestPayload"` - CallbackPayload any `orm:"callback_payload" json:"callbackPayload"` - ModelResult any `orm:"model_result" json:"modelResult"` + ResultText string `orm:"result_text" json:"resultText"` Messages any `orm:"messages" json:"messages"` Status string `orm:"status" json:"status"` ErrorMessage string `orm:"error_message" json:"errorMessage"` + OssFile string `orm:"oss_file" json:"ossFile"` + FileType string `orm:"file_type" json:"fileType"` } type composeTaskCol struct { beans.SQLBaseCol - TaskId string - ModelName string - SkillName string - LimitWords string - RequestPayload string - CallbackPayload string - ModelResult string - Messages string - Status string - ErrorMessage string + TaskId string + ModelName string + SkillName string + BuildType string + CallbackUrl string + GatewayState string + RequestPayload string + ResultText string + Messages string + Status string + ErrorMessage string + OssFile string + FileType string } var ComposeTaskCol = composeTaskCol{ - SQLBaseCol: beans.DefSQLBaseCol, - TaskId: "task_id", - ModelName: "model_name", - SkillName: "skill_name", - LimitWords: "limit_words", - RequestPayload: "request_payload", - CallbackPayload: "callback_payload", - ModelResult: "model_result", - Messages: "messages", - Status: "status", - ErrorMessage: "error_message", + SQLBaseCol: beans.DefSQLBaseCol, + TaskId: "task_id", + ModelName: "model_name", + SkillName: "skill_name", + BuildType: "build_type", + CallbackUrl: "callback_url", + GatewayState: "gateway_state", + RequestPayload: "request_payload", + ResultText: "result_text", + Messages: "messages", + Status: "status", + ErrorMessage: "error_message", + OssFile: "oss_file", + FileType: "file_type", } diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 9c35885..bdc755c 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "prompts-core/common/util" + "prompts-core/model/entity" commonHttp "gitea.com/red-future/common/http" + "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" ) @@ -75,3 +77,81 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) { } return &resp, nil } + +// SendCallbackReq 发送回调的请求体 +type SendCallbackReq struct { + TaskId string `json:"taskId"` + Status string `json:"status"` + Messages *MultiRoundResult `json:"messages,omitempty"` + EpicycleId int64 `json:"epicycleId"` + ErrorMsg string `json:"errorMsg,omitempty"` +} +type MultiRoundResult struct { + TotalRounds int `json:"total_rounds"` // 总轮数 + Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型) +} + +// SendCallback 向业务方发送回调 +func SendCallback(ctx context.Context, composeTask *entity.ComposeTask) error { + // 1. 检查回调地址 + if composeTask.CallbackUrl == "" { + return fmt.Errorf("回调地址为空,taskId=%s", composeTask.TaskId) + } + + // 2. 构造请求体 + req := SendCallbackReq{ + TaskId: composeTask.TaskId, + Status: composeTask.Status, + Messages: parseMessagesToResult(composeTask.Messages), // 需要将 JSON 字符串转为结构体 + ErrorMsg: composeTask.ErrorMessage, + } + // 3. 发送 POST 请求 + headers := util.ForwardHeaders(ctx) + var resp struct{} + g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v", + composeTask.TaskId, composeTask.CallbackUrl, req.Messages) + if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil { + return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err) + } + g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s", composeTask.TaskId, composeTask.CallbackUrl) + return nil +} + +// parseMessagesToResult 将 any 类型的 Messages 转为 *MultiRoundResult +func parseMessagesToResult(messages any) *MultiRoundResult { + if messages == nil { + return nil + } + + var result MultiRoundResult + + switch v := messages.(type) { + case *MultiRoundResult: + return v + case MultiRoundResult: + return &v + case string: + if err := json.Unmarshal([]byte(v), &result); err != nil { + return nil + } + case []byte: + if err := json.Unmarshal(v, &result); err != nil { + return nil + } + case map[string]any: + // 通过 JSON 序列化再反序列化 + data, _ := json.Marshal(v) + if err := json.Unmarshal(data, &result); err != nil { + return nil + } + default: + data, err := json.Marshal(v) + if err != nil { + return nil + } + if err = json.Unmarshal(data, &result); err != nil { + return nil + } + } + return &result +} diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 0859454..d9fed61 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -16,8 +16,8 @@ import ( ) // buildInferenceRequest 构建推理请求 -func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, targetModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) { - processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, targetModel) +func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) { + processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel) if err != nil { return nil, fmt.Errorf("处理用户表单分批失败: %w", err) } @@ -26,7 +26,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha switch req.BuildType { case public.BuildTypePrompt: - return buildPromptTypeRequest(ctx, processedReq, targetModel, chatModel, history, ir, totalBatches) + return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches) case public.BuildTypeNode: return buildNodeTypeRequest(ctx, req, chatModel, ir) default: @@ -35,8 +35,8 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha } // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) -func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, targetModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { - systemPrompt := promptBuildWithRounds(ctx, req, targetModel, totalBatches) +func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { + systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches) ir.AddSystem(systemPrompt) for _, msg := range history { @@ -47,26 +47,30 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ta ir.AddHistory(role, gconv.String(msg["content"])) } - userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, targetModel.ModelType)) + userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType)) ir.AddUser(userPrompt) - if !checkOverallContent(ir, targetModel) { - availableWindow := util.GetAvailableWindow(targetModel.TokenConfig) + if !checkOverallContent(ir, aiModel) { + availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) } - - return compileToProviderRequest(ctx, ir, targetModel.OperatorName, targetModel.ModelName, chatModel) + // 记录历史会话 + _, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + SessionId: req.SessionId, + RequestContent: ir.User, + }) + return compileToProviderRequest(ctx, ir, chatModel) } // buildNodeTypeRequest 构建节点类型请求(BuildType=2) func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) - return compileToProviderRequest(ctx, ir, req.ModelName, req.ModelName, chatModel) + return compileToProviderRequest(ctx, ir, chatModel) } // compileToProviderRequest 编译为 Provider 请求 -func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName string, modelName string, chatModel *entity.AsynchModel) (map[string]any, error) { - protocol, err := GetProtocolByProvider(ctx, providerName) +func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) { + protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName) if err != nil { return nil, fmt.Errorf("获取协议配置失败: %w", err) } @@ -79,7 +83,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, providerName st } return map[string]any{ - "modelName": modelName, + "modelName": chatModel.ModelName, "bizName": "prompts-core", "callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"), "requestPayload": providerReq, diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 21c20cf..1157164 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -5,12 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "strings" - "time" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" - "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/frame/g" "prompts-core/common/util" @@ -27,147 +24,19 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com if err != nil { return nil, err } - if err = validateUserForm(ctx, req, aiModel); err != nil { + if err = validateUserForm(req, aiModel); err != nil { return nil, err } - fmt.Printf("req打印%+v", req) switch req.BuildType { case public.BuildTypePrompt: return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建 case public.BuildTypeNode: return handleNodeBuild(ctx, req, chatModel, aiModel) // 节点构建 default: - return handleDefaultCase(ctx, req) + return nil, errors.New("BuildType 不支持") } } -// validateUserForm 校验用户表单 -func validateUserForm(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) error { - if len(req.UserForm) == 0 { - return nil - } - isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig) - if err != nil { - return fmt.Errorf("校验用户表单失败: %w", err) - } - - if !isValid { - availableWindow := util.GetAvailableWindow(model.TokenConfig) - return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试", - exceedTokens, availableWindow) - } - - return nil -} - -// handlePromptBuild 处理提示词构建(BuildType=1) -func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { - maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() - history, err := GetHistoryMessages(ctx, req.SessionId) - if err != nil { - g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) - history = nil - } - - var message *dto.MultiRoundResult - var taskRecord *entity.ComposeTask - for attempt := 0; attempt <= maxRetryTimes; attempt++ { - if attempt > 0 { - g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) - } - - taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history) - if err != nil { - g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) - continue - } - - if err = saveComposeTask(ctx, taskID, req); err != nil { - g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) - continue - } - //等待结果 - taskRecord, err = waitForResult(ctx, taskID) - if err != nil { - g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) - continue - } - //处理结果 - message = parsePromptBuild(taskRecord, chatModel) - if message != nil { - break - } - - g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) - } - - if message == nil { - return nil, errors.New("推理模型调用失败,请稍后再试") - } - epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - RequestContent: message, - }) - if err != nil { - g.Log().Errorf(ctx, "创建会话记录失败: %v", err) - } - return &dto.ComposeMessagesRes{ - Messages: message, - EpicycleId: epicycleId, - }, nil -} - -// handleNodeBuild 处理节点构建(BuildType=2) -func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { - taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) - if err != nil { - return nil, fmt.Errorf("调用推理模型失败: %w", err) - } - - if err := saveComposeTask(ctx, taskID, req); err != nil { - return nil, fmt.Errorf("保存任务记录失败: %w", err) - } - - taskRecord, err := waitForResult(ctx, taskID) - if err != nil { - return nil, fmt.Errorf("等待结果失败: %w", err) - } - - message := parseNodeBuild(taskRecord) - - return &dto.ComposeMessagesRes{ - Messages: message, - EpicycleId: 0, - }, nil -} - -// handleDefaultCase 处理默认情况 -func handleDefaultCase(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { - epicycleId, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - Remark: req.Cause, - }) - if err != nil { - return nil, fmt.Errorf("创建会话记录失败: %w", err) - } - - return &dto.ComposeMessagesRes{ - EpicycleId: epicycleId, - }, nil -} - -// saveComposeTask 保存组合任务 -func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error { - _, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ - TaskId: taskID, - ModelName: req.ModelName, - SkillName: req.SkillName, - RequestPayload: util.MustMarshal(req), - Status: public.ComposeStatusPending, - }) - return err -} - // GetModelMessage 获取模型信息 func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { userInfo, err := utils.GetUserInfo(ctx) @@ -188,6 +57,77 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. return chatModel, aiModel, nil } +// validateUserForm 校验用户表单 +func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error { + if len(req.UserForm) == 0 { + return nil + } + isValid, exceedTokens, err := util.CheckUserFormWithinWindow(req.UserForm, model.TokenConfig) + if err != nil { + return fmt.Errorf("校验用户表单失败: %w", err) + } + + if !isValid { + availableWindow := util.GetAvailableWindow(model.TokenConfig) + return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试", + exceedTokens, availableWindow) + } + + return nil +} + +// handlePromptBuild 处理提示词构建(BuildType=1) +func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { + // 获取历史会话 + history, err := GetHistoryMessages(ctx, req.SessionId) + if err != nil { + g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) + history = nil + } + // 调用推理模型 + taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history) + if err != nil { + return nil, fmt.Errorf("调用推理模型失败: %w", err) + } + // 保存任务记录 + if err = saveComposeTask(ctx, taskID, req); err != nil { + return nil, fmt.Errorf("保存任务记录失败: %w", err) + } + return &dto.ComposeMessagesRes{ + TaskId: taskID, + }, nil +} + +// handleNodeBuild 处理节点构建(BuildType=2) +func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { + taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) + if err != nil { + return nil, fmt.Errorf("调用推理模型失败: %w", err) + } + + if err := saveComposeTask(ctx, taskID, req); err != nil { + return nil, fmt.Errorf("保存任务记录失败: %w", err) + } + + return &dto.ComposeMessagesRes{ + TaskId: taskID, + }, nil +} + +// saveComposeTask 保存组合任务记录 +func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error { + _, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ + TaskId: taskID, + ModelName: req.ModelName, + SkillName: req.SkillName, + BuildType: req.BuildType, + CallbackUrl: req.CallbackUrl, + RequestPayload: util.MustMarshal(req), + Status: public.ComposeStatusPending, + }) + return err +} + // getChatModel 获取聊天模型 func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) { chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ @@ -223,8 +163,8 @@ func getAIModel(ctx context.Context, userName, modelName string) (*entity.Asynch } // callInferenceModel 调用推理模型 -func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { - taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history) +func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) { + taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history) if err != nil { return "", fmt.Errorf("构建推理请求失败: %w", err) } @@ -241,147 +181,6 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo return taskID, nil } -// waitForResult 等待结果 -// waitForResult 等待结果(优先channel通知,兜底网关查询) -func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { - timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second - // 设置超时context - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - // 优先等待channel通知(来自回调) - result, err := TaskWaiter.Wait(ctx, taskID) - if err == nil { - // 成功收到回调通知 - return result.(*entity.ComposeTask), nil - } - // channel等待失败(超时/取消),从数据库读取最终状态作为兜底 - g.Log().Warningf(ctx, "[waitForResult] channel等待失败,从DB获取最终状态 taskId=%s err=%v", taskID, err) - record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ - TaskId: taskID, - }) - if dbErr != nil { - return nil, fmt.Errorf("查询数据库失败: %w", dbErr) - } - - if record == nil { - return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID) - } - - switch record.Status { - case public.ComposeStatusSuccess: - return record, nil - case public.ComposeStatusFailed: - if strings.TrimSpace(record.ErrorMessage) == "" { - return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) - } - return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) - default: - // 还在处理中,但已超时 - return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) - } -} - -// parsePromptBuild 解析提示词构建结果(BuildType == 1) -func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) *dto.MultiRoundResult { - if taskRecord == nil { - return nil - } - - mapped := parseTaskMessages(taskRecord.Messages) - if mapped == nil { - return createDefaultResult(nil) - } - - contentField := getContentField(model) - contentStr, ok := mapped[contentField].(string) - if !ok || contentStr == "" { - return createDefaultResult(mapped) - } - - // 尝试解析为数组 - if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { - return &dto.MultiRoundResult{ - TotalRounds: len(roundsArray), - Rounds: roundsArray, - } - } - - // 尝试解析为单个对象 - if singleRound := tryParseAsMap(contentStr); singleRound != nil { - return &dto.MultiRoundResult{ - TotalRounds: 1, - Rounds: []map[string]any{singleRound}, - } - } - - // 纯文本,包装为默认格式 - return createDefaultResult(map[string]any{"content": contentStr}) -} - -// tryParseAsMapArray 尝试解析JSON字符串为 []map[string]any -func tryParseAsMapArray(jsonStr string) []map[string]any { - var arr []map[string]any - if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil { - return nil - } - if len(arr) == 0 { - return nil - } - return arr -} - -// tryParseAsMap 尝试解析JSON字符串为 map[string]any -func tryParseAsMap(jsonStr string) map[string]any { - var obj map[string]any - if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { - return nil - } - if len(obj) == 0 { - return nil - } - return obj -} - -// parseTaskMessages 解析任务消息 -func parseTaskMessages(messages any) map[string]any { - var mapped map[string]any - - switch v := messages.(type) { - case *gvar.Var: - if v != nil { - json.Unmarshal([]byte(v.String()), &mapped) - } - case string: - json.Unmarshal([]byte(v), &mapped) - case map[string]any: - mapped = v - default: - b, _ := json.Marshal(v) - json.Unmarshal(b, &mapped) - } - - return mapped -} - -// tryParseAsArray 尝试将字符串解析为数组 -func tryParseAsArray(contentStr string) []any { - var roundsArray []any - if err := json.Unmarshal([]byte(contentStr), &roundsArray); err != nil { - return nil - } - return roundsArray -} - -// tryParseAsObject 尝试将字符串解析为对象 -func tryParseAsObject(contentStr string) any { - var singleRound any - if err := json.Unmarshal([]byte(contentStr), &singleRound); err != nil { - return nil - } - return singleRound -} - // createDefaultResult 创建默认结果 func createDefaultResult(data map[string]any) *dto.MultiRoundResult { if data == nil { @@ -393,72 +192,17 @@ func createDefaultResult(data map[string]any) *dto.MultiRoundResult { } } -// getContentField 从模型 ResponseMapping 中获取 content 字段名 -func getContentField(model *entity.AsynchModel) string { - if model == nil { - return "content" - } - - respMapping := parseResponseMapping(model.ResponseMapping) - for k, v := range respMapping { - if strings.Contains(v, "content") { - return k - } - } - - return "content" -} - -// parseResponseMapping 解析响应映射 -func parseResponseMapping(mapping any) map[string]string { - result := make(map[string]string) - - switch v := mapping.(type) { - case *gvar.Var: - if v != nil { - json.Unmarshal([]byte(v.String()), &result) - } - case string: - json.Unmarshal([]byte(v), &result) - case map[string]interface{}: - for k, val := range v { - if s, ok := val.(string); ok { - result[k] = s - } - } - } - - return result -} - -// parseNodeBuild 解析节点构建结果(BuildType == 2) -func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult { - if taskRecord == nil { - return nil - } - - result := parseTaskMessages(taskRecord.Messages) - if result == nil { - result = make(map[string]any) - } - - return &dto.MultiRoundResult{ - TotalRounds: 1, - Rounds: []map[string]any{result}, - } -} - // Callback 回调处理 func Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) - task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: req.TaskId, }) if err != nil { return fmt.Errorf("查询任务失败: %w", err) } - if task == nil { + if composeTask == nil { return fmt.Errorf("任务不存在: %s", req.TaskId) } //处理失败 @@ -467,41 +211,134 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { TaskId: req.TaskId, Status: public.ComposeStatusFailed, ErrorMessage: req.ErrorMsg, + GatewayState: req.State, + OssFile: req.OssFile, + FileType: req.FileType, + ResultText: req.Text, }) - // 通知等待者:任务失败 - notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg)) + // 用更新后的值发送回调 + if composeTask.CallbackUrl != "" { + failedTask := &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusFailed, + ErrorMessage: req.ErrorMsg, + CallbackUrl: composeTask.CallbackUrl, + Messages: composeTask.Messages, + } + gateway.SendCallback(ctx, failedTask) + } return err } //处理成功 if req.State == 2 { - result, err := util.ParseOutput(req.Text) + // 1. 根据 BuildType 解析结果 var messages any - if result != nil { - messages = result + switch composeTask.BuildType { + case public.BuildTypePrompt: // 提示词构建解析 + messages = parsePromptResult(req.Text) + case public.BuildTypeNode: // 节点构建解析 + messages = parseNodeResult(req.Text) + default: + messages = req.Text } + // 2. 更新数据库 _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: req.TaskId, - Status: public.ComposeStatusSuccess, - Messages: messages, + TaskId: req.TaskId, + Status: public.ComposeStatusSuccess, + Messages: messages, + GatewayState: req.State, + OssFile: req.OssFile, + FileType: req.FileType, + ResultText: req.Text, }) if err != nil { - g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) + g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err) + return err + } + // 4. 发送回调给业务方 + if composeTask.CallbackUrl != "" { + successTask := &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusSuccess, + Messages: messages, + CallbackUrl: composeTask.CallbackUrl, + } + gateway.SendCallback(ctx, successTask) } - notifyWaiter(req.TaskId, &entity.ComposeTask{ - TaskId: req.TaskId, - Status: public.ComposeStatusSuccess, - Messages: messages, - }, err) } return err } -// notifyWaiter 通知等待者(不影响主流程) -func notifyWaiter(taskID string, result interface{}, err error) { - notifyErr := TaskWaiter.Notify(taskID, result, err) - if notifyErr != nil { - // 只记录日志,不影响回调处理结果 - g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr) +// parsePromptResult 解析提示词构建结果 +func parsePromptResult(raw string) *dto.MultiRoundResult { + var wrapper map[string]any + if err := json.Unmarshal([]byte(raw), &wrapper); err != nil { + return createDefaultResult(map[string]any{"raw": raw}) + } + + contentStr, ok := wrapper["content"].(string) + if !ok || contentStr == "" { + return createDefaultResult(wrapper) + } + + // 先尝试解析为数组 + if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { + return &dto.MultiRoundResult{ + TotalRounds: len(roundsArray), + Rounds: roundsArray, + } + } + + // 再尝试解析为单个对象 + if singleRound := tryParseAsMap(contentStr); singleRound != nil { + return &dto.MultiRoundResult{ + TotalRounds: 1, + Rounds: []map[string]any{singleRound}, + } + } + + return createDefaultResult(map[string]any{"content": contentStr}) +} + +func tryParseAsMapArray(jsonStr string) []map[string]any { + var arr []map[string]any + if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil { + return nil + } + if len(arr) == 0 { + return nil + } + return arr +} + +func tryParseAsMap(jsonStr string) map[string]any { + var obj map[string]any + if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { + return nil + } + if len(obj) == 0 { + return nil + } + return obj +} + +// parseNodeResult 解析节点构建结果 +func parseNodeResult(raw string) *dto.MultiRoundResult { + var result map[string]any + if err := json.Unmarshal([]byte(raw), &result); err != nil { + return createDefaultResult(map[string]any{"raw": raw}) + } + + if contentStr, ok := result["content"].(string); ok && contentStr != "" { + var inner map[string]any + if err := json.Unmarshal([]byte(contentStr), &inner); err == nil { + result = inner + } + } + + return &dto.MultiRoundResult{ + TotalRounds: 1, + Rounds: []map[string]any{result}, } } diff --git a/service/prompt/prompt_session_service.go b/service/prompt/prompt_session_service.go index 7434013..c72c6f3 100644 --- a/service/prompt/prompt_session_service.go +++ b/service/prompt/prompt_session_service.go @@ -23,8 +23,7 @@ func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.Ses } result["role"] = "assistant" - - if err := updateSessionResponse(ctx, req.EpicycleId, result); err != nil { + if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil { return nil, err } diff --git a/service/prompt/prompt_task_waiter.go b/service/prompt/prompt_task_waiter.go deleted file mode 100644 index 575c300..0000000 --- a/service/prompt/prompt_task_waiter.go +++ /dev/null @@ -1,125 +0,0 @@ -package prompt - -import ( - "context" - "errors" - "sync" -) - -var ( - ErrTaskNotFound = errors.New("task not found") - ErrAlreadyNotified = errors.New("task already notified") - TaskWaiter = NewManager() -) - -// Result 任务结果 -type Result struct { - Data interface{} - Error error -} - -// Manager 管理异步任务等待 -type Manager struct { - mu sync.Mutex - waiters map[string]*waiter -} - -// waiter 单个等待者 -type waiter struct { - result chan Result - closed chan struct{} - notifyOnce sync.Once -} - -// NewManager 创建管理器 -func NewManager() *Manager { - return &Manager{ - waiters: make(map[string]*waiter), - } -} - -// Wait 等待任务结果 -func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) { - w := m.getOrCreate(taskID) - defer m.remove(taskID) - - select { - case result := <-w.result: - if result.Error != nil { - return nil, result.Error - } - return result.Data, nil - case <-ctx.Done(): - return nil, ctx.Err() - case <-w.closed: - // context取消后notify才到达的边缘情况 - select { - case result := <-w.result: - if result.Error != nil { - return nil, result.Error - } - return result.Data, nil - default: - return nil, ctx.Err() - } - } -} - -// Notify 通知任务完成(安全,无阻塞) -func (m *Manager) Notify(taskID string, data interface{}, err error) error { - m.mu.Lock() - w, exists := m.waiters[taskID] - if !exists { - m.mu.Unlock() - return ErrTaskNotFound - } - - var notified bool - w.notifyOnce.Do(func() { - notified = true - close(w.closed) // 先关闭信号channel - // 根据err构造Result - if err != nil { - w.result <- Result{Error: err} - } else { - w.result <- Result{Data: data} - } - }) - m.mu.Unlock() - - if !notified { - return ErrAlreadyNotified - } - return nil -} - -// getOrCreate 获取或创建等待者 -func (m *Manager) getOrCreate(taskID string) *waiter { - m.mu.Lock() - defer m.mu.Unlock() - - if w, exists := m.waiters[taskID]; exists { - return w - } - - w := &waiter{ - result: make(chan Result, 1), - closed: make(chan struct{}), - } - m.waiters[taskID] = w - return w -} - -// remove 安全移除等待者 -func (m *Manager) remove(taskID string) { - m.mu.Lock() - defer m.mu.Unlock() - delete(m.waiters, taskID) -} - -// ActiveCount 当前活跃等待数量 -func (m *Manager) ActiveCount() int { - m.mu.Lock() - defer m.mu.Unlock() - return len(m.waiters) -} diff --git a/update.sql b/update.sql index ac9e6b1..a146009 100644 --- a/update.sql +++ b/update.sql @@ -1,4 +1,4 @@ --- prompts_compose_task 拼接提示词任务记录表 +-- prompts_compose_task 提示词任务记录表 CREATE TABLE IF NOT EXISTS prompts_compose_task ( id BIGINT PRIMARY KEY, tenant_id BIGINT NOT NULL DEFAULT 0, @@ -11,8 +11,9 @@ CREATE TABLE IF NOT EXISTS prompts_compose_task ( task_id VARCHAR(64) NOT NULL, model_name VARCHAR(128) NOT NULL DEFAULT '', skill_name VARCHAR(128) NOT NULL DEFAULT '', + build_type INT NOT NULL DEFAULT 0, + callback_url VARCHAR(512) NOT NULL DEFAULT '', gateway_state INT NOT NULL DEFAULT 0, - limit_words INT NOT NULL DEFAULT 0, request_payload JSONB NOT NULL DEFAULT '{}'::jsonb, result_text TEXT NOT NULL DEFAULT '', messages JSONB NOT NULL DEFAULT '{}'::jsonb, @@ -20,13 +21,15 @@ CREATE TABLE IF NOT EXISTS prompts_compose_task ( error_message TEXT NOT NULL DEFAULT '', oss_file VARCHAR(1024) NOT NULL DEFAULT '', file_type VARCHAR(64) NOT NULL DEFAULT '' -); + ); + -- 索引 CREATE UNIQUE INDEX IF NOT EXISTS uk_prompts_compose_task_task_id ON prompts_compose_task(task_id); CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_status ON prompts_compose_task(status); -CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task +CREATE INDEX IF NOT EXISTS idx_prompts_compose_task_deleted_at ON prompts_compose_task(deleted_at); + -- 注释 -COMMENT ON TABLE prompts_compose_task IS '拼接提示词任务记录表'; +COMMENT ON TABLE prompts_compose_task IS '提示词任务记录表'; COMMENT ON COLUMN prompts_compose_task.id IS '主键ID'; COMMENT ON COLUMN prompts_compose_task.tenant_id IS '租户ID'; COMMENT ON COLUMN prompts_compose_task.creator IS '创建人'; @@ -37,8 +40,9 @@ COMMENT ON COLUMN prompts_compose_task.deleted_at IS '删除时间(软删 COMMENT ON COLUMN prompts_compose_task.task_id IS 'model-gateway 任务ID'; COMMENT ON COLUMN prompts_compose_task.model_name IS '业务模型名称'; COMMENT ON COLUMN prompts_compose_task.skill_name IS '技能名称'; +COMMENT ON COLUMN prompts_compose_task.build_type IS '构建类型:0默认/1提示词构建/2节点构建'; +COMMENT ON COLUMN prompts_compose_task.callback_url IS '回调地址'; COMMENT ON COLUMN prompts_compose_task.gateway_state IS 'model-gateway 状态:0排队/1执行/2成功/3失败/4已下载'; -COMMENT ON COLUMN prompts_compose_task.limit_words IS '提示词限制字数'; COMMENT ON COLUMN prompts_compose_task.request_payload IS '发给 model-gateway 的请求内容'; COMMENT ON COLUMN prompts_compose_task.result_text IS '回调返回的文本结果'; COMMENT ON COLUMN prompts_compose_task.messages IS '最终解析后的 messages';