package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"prompts-core/consts/public"
	"prompts-core/dao"
	"prompts-core/model/dto"
	"prompts-core/model/entity"

	"github.com/gogf/gf/v2/container/gvar"
	"github.com/gogf/gf/v2/frame/g"
)

// ============================================
// Core business flow
// ============================================

// ComposeMessages runs the main prompt-composition flow.
//
// When req.IsBuilder is false it only inserts a ComposeSession row
// (SessionId=req.SessionId, Remark=req.Cause) and returns its id.
//
// Otherwise it:
//  1. loads the chat (session) model and the model named by req.ModelName,
//  2. loads the session history (best-effort: failures are logged and ignored),
//  3. submits an inference task and records a pending ComposeTask,
//  4. blocks until the task succeeds, fails, times out or ctx is cancelled,
//  5. extracts the composed messages and stores them on the ComposeSession.
func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
	// 1. Non-builder request: just record the session round and return its id.
	if !req.IsBuilder {
		epicycleID, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
			SessionId: req.SessionId,
			Remark:    req.Cause,
		})
		if err != nil {
			return nil, err
		}
		return &dto.ComposeMessagesRes{EpicycleId: epicycleID}, nil
	}

	// 2. Resolve the chat (session) model and the requested model.
	sessionModel, err := dao.Model.GetByIsChatModel(ctx)
	if err != nil {
		return nil, err
	}
	if sessionModel == nil {
		return nil, errors.New("当前没有对话模型,请添加")
	}
	model, err := dao.Model.GetByModelName(ctx, req.ModelName)
	if err != nil {
		return nil, err
	}
	if model == nil {
		// Report the model that was actually requested, not the chat model.
		return nil, fmt.Errorf("模型 %s 不存在", req.ModelName)
	}

	// 3. Load history; a failure here must not block the main flow.
	var historyMessages []Message
	historyMessages, err = Session.GetSessionHistoryForInference(ctx, req.SessionId)
	if err != nil {
		g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err)
		historyMessages = nil
	}

	// 4. Submit the inference task.
	taskID, err := s.callInferenceModel(ctx, req, sessionModel, model, historyMessages)
	if err != nil {
		return nil, err
	}

	// 5. Persist the task so the callback / poller can find it.
	if _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
		TaskId:         taskID,
		ModelName:      req.ModelName,
		SkillName:      req.SkillName,
		RequestPayload: mustMarshal(req),
		Status:         public.ComposeStatusPending,
	}); err != nil {
		return nil, err
	}

	// 6. Wait for the task to reach a terminal state.
	taskRecord, err := s.waitForResult(ctx, taskID)
	if err != nil {
		return nil, err
	}

	// 7. Extract the composed messages from the task output.
	messages := s.processResult(taskRecord)

	// 8. Create the session record if it does not exist yet.
	session, err := dao.ComposeSession.GetBySessionId(ctx, req.SessionId)
	if err != nil {
		return nil, err
	}
	var epicycleID int64
	if session == nil {
		epicycleID, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
			SessionId:      req.SessionId,
			RequestContent: messages,
		})
		if err != nil {
			return nil, err
		}
	}

	// 9. Store the latest composed messages on the session.
	// NOTE(review): when the session already exists, epicycleID is still 0
	// here, so this update targets id 0 rather than the existing row. It
	// should probably use the existing session's primary key — confirm the
	// entity's id field and fix.
	if _, err = dao.ComposeSession.UpdateById(ctx, epicycleID, map[string]any{
		entity.ComposeSessionCol.RequestContent: messages,
	}); err != nil {
		return nil, err
	}

	return &dto.ComposeMessagesRes{
		Messages:   messages,
		EpicycleId: epicycleID,
	}, nil
}

// Callback handles the gateway's completion notification for a compose task.
//
// The task must already exist; otherwise an error is returned. State 3 marks
// the task failed with req.ErrorMsg. Any other state is treated as success:
// req.Text is parsed with parseModelOutput and stored as the task's Messages.
// If parsing fails, the task is marked failed with the parse error, and that
// parse error is returned.
func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error {
	g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
		req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))

	// Make sure the task exists before touching it.
	task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId)
	if err != nil {
		return err
	}
	if task == nil {
		return fmt.Errorf("任务不存在: %s", req.TaskId)
	}

	// State 3 is the failure state (waitForResult treats gateway state 3 the same way).
	if req.State == 3 {
		_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
			entity.ComposeTaskCol.Status:       public.ComposeStatusFailed,
			entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg,
		})
		return err
	}

	// Success: parse the model output.
	result, err := parseModelOutput(req.Text)
	if err != nil {
		_, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
			entity.ComposeTaskCol.Status:       public.ComposeStatusFailed,
			entity.ComposeTaskCol.ErrorMessage: err.Error(),
		})
		if updateErr != nil {
			g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr)
		}
		return err
	}

	// Only box result when it is non-nil. If parseModelOutput returns a
	// pointer/map/slice type, assigning a nil value to an `any` would yield a
	// non-nil interface holding nil.
	var messages any
	if result != nil {
		messages = result
	}

	_, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{
		entity.ComposeTaskCol.Status:   public.ComposeStatusSuccess,
		entity.ComposeTaskCol.Messages: messages,
	})
	if err != nil {
		g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
	}
	return err
}

// GetComposeTask returns the stored state of a compose task.
//
// If Messages is stored as a JSON string (or a *gvar.Var holding one), it is
// decoded so callers receive structured JSON. On decode failure the raw value
// is returned unchanged. Returns an error if the task does not exist.
func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
	record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
	if err != nil {
		return nil, err
	}
	if record == nil {
		return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
	}

	messages := record.Messages
	var raw string
	switch v := messages.(type) {
	case string:
		raw = v
	case *gvar.Var:
		// Same storage shape processResult accepts.
		if v != nil {
			raw = v.String()
		}
	}
	if raw != "" {
		var parsed any
		if err := json.Unmarshal([]byte(raw), &parsed); err == nil {
			messages = parsed
		}
	}

	return &dto.GetComposeTaskRes{
		TaskId:       record.TaskId,
		Status:       record.Status,
		ErrorMessage: record.ErrorMessage,
		Messages:     messages,
	}, nil
}

// ============================================
// Step 4: call the inference model
// ============================================

// callInferenceModel builds the inference request and creates a gateway task,
// returning the gateway-assigned task id. An empty task id is an error.
func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, sessionModel *entity.AsynchModel, model *entity.AsynchModel, historyMessages []Message) (string, error) {
	taskReq, err := buildInferenceRequest(ctx, req, sessionModel, model, historyMessages)
	if err != nil {
		return "", fmt.Errorf("构建推理请求失败: %w", err)
	}

	taskID, err := createGatewayTask(ctx, taskReq)
	if err != nil {
		return "", fmt.Errorf("创建网关任务失败: %w", err)
	}
	if taskID == "" {
		return "", errors.New("网关未返回taskId")
	}
	return taskID, nil
}

// ============================================
// Step 6: wait for the result
// ============================================

// waitForResult polls the task until it reaches a terminal state.
//
// Each round it first checks the database record (success → return it,
// failure → error). It then queries the gateway: state 2 syncs "success" into
// the database so the next round returns the record, and state 3 marks the
// task failed and returns an error. Gateway query errors are logged and
// polling continues. It stops with an error after task.waitTimeoutSeconds
// (default 30) or when ctx is cancelled.
func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
	timeout := time.Duration(getIntConfig(ctx, "task.waitTimeoutSeconds", 30)) * time.Second
	pollInterval := time.Duration(getIntConfig(ctx, "task.pollIntervalMillis", 500)) * time.Millisecond
	if pollInterval <= 0 {
		// Guard against a misconfigured interval turning this into a busy loop.
		pollInterval = 500 * time.Millisecond
	}
	deadline := time.Now().Add(timeout)

	for {
		// 1. Check the database record.
		record, err := dao.ComposeTask.GetByTaskId(ctx, taskID)
		if err != nil {
			return nil, err
		}
		if record != nil {
			switch record.Status {
			case public.ComposeStatusSuccess:
				return record, nil
			case public.ComposeStatusFailed:
				return nil, formatTaskError(taskID, record.ErrorMessage)
			}
		}

		// 2. Check the gateway state.
		state, err := queryGatewayTaskState(ctx, taskID)
		if err != nil {
			// An unreachable gateway is not fatal; keep polling.
			g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
		} else {
			switch state {
			case 2: // gateway succeeded: sync the status into the database
				if record != nil {
					if _, err := dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
						entity.ComposeTaskCol.Status: public.ComposeStatusSuccess,
					}); err != nil {
						g.Log().Warningf(ctx, "[waitForResult] update task status failed taskId=%s err=%v", taskID, err)
					}
				}
			case 3: // gateway failed
				if record != nil {
					if _, err := dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{
						entity.ComposeTaskCol.Status:       public.ComposeStatusFailed,
						entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败",
					}); err != nil {
						g.Log().Warningf(ctx, "[waitForResult] update task status failed taskId=%s err=%v", taskID, err)
					}
				}
				return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
			}
		}

		// 3. Timeout check.
		if time.Now().After(deadline) {
			return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
		}

		// Sleep for one poll interval, but abort promptly if the caller gives up.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-time.After(pollInterval):
		}
	}
}

// ============================================
// Step 7: process the result
// ============================================

// processResult extracts the composed payload from a finished task.
//
// taskRecord.Messages may be a *gvar.Var or string holding a JSON object, or
// an already-decoded map. Its "content" string field is cleaned with
// cleanJSONString and decoded as a JSON object. Any decode failure is
// tolerated and yields a nil map, as does a nil record.
func (s *promptService) processResult(taskRecord *entity.ComposeTask) map[string]any {
	if taskRecord == nil {
		return nil
	}

	// 1. Normalise the stored value to a wrapper map holding "content".
	var (
		raw     string
		wrapper map[string]any
	)
	switch v := taskRecord.Messages.(type) {
	case *gvar.Var:
		if v != nil {
			raw = v.String()
		}
	case string:
		raw = v
	case map[string]any:
		wrapper = v
	}
	if raw != "" {
		// Best-effort: malformed JSON leaves wrapper nil.
		_ = json.Unmarshal([]byte(raw), &wrapper)
	}
	contentStr, _ := wrapper["content"].(string)

	// 2. Clean and decode the inner payload (best-effort, nil on failure).
	contentStr = cleanJSONString(contentStr)
	var innerData map[string]any
	_ = json.Unmarshal([]byte(contentStr), &innerData)
	return innerData
}

// ============================================
// Message processing pipeline
// ============================================

// parseStoredMessages decodes a message list from a value stored in the
// database, handling a few levels of JSON nesting.
//
// The value is re-encoded to JSON, then decoded either directly as
// []dto.Message or, if it is a JSON string, from the string's content.
// Message contents are normalised with deepNormalizeMessages. Returns nil for
// nil input or when neither shape matches.
func parseStoredMessages(data any) []dto.Message {
	if data == nil {
		return nil
	}

	jsonBytes, err := json.Marshal(data)
	if err != nil {
		return nil
	}

	// Layer 1: the value is already a message array.
	var messages []dto.Message
	if err := json.Unmarshal(jsonBytes, &messages); err == nil {
		return deepNormalizeMessages(messages)
	}

	// Layer 2: the value is a JSON string wrapping the array.
	var rawStr string
	if err := json.Unmarshal(jsonBytes, &rawStr); err != nil {
		return nil
	}
	if err := json.Unmarshal([]byte(rawStr), &messages); err == nil {
		return deepNormalizeMessages(messages)
	}
	return nil
}

// deepNormalizeMessages normalises each message's Content in place (see
// deepNormalizeContent) and returns the same slice.
func deepNormalizeMessages(messages []dto.Message) []dto.Message {
	for i, msg := range messages {
		messages[i].Content = deepNormalizeContent(msg.Content)
	}
	return messages
}

// deepNormalizeContent recursively decodes JSON embedded in strings.
//
// Strings are trimmed. A trimmed string that looksLikeJSON and decodes
// successfully is replaced by its (recursively normalised) decoded value.
// Slices and maps of type []any / map[string]any are copied with each element
// normalised. All other values are returned unchanged.
func deepNormalizeContent(content any) any {
	switch v := content.(type) {
	case string:
		v = strings.TrimSpace(v)
		if v == "" {
			return v
		}
		if looksLikeJSON(v) {
			var parsed any
			if err := json.Unmarshal([]byte(v), &parsed); err == nil {
				return deepNormalizeContent(parsed)
			}
		}
		return v
	case []any:
		result := make([]any, len(v))
		for i, item := range v {
			result[i] = deepNormalizeContent(item)
		}
		return result
	case map[string]any:
		result := make(map[string]any, len(v))
		for k, val := range v {
			result[k] = deepNormalizeContent(val)
		}
		return result
	default:
		return content
	}
}

// NormalizeToTwoPart rebuilds messages into a fixed role order:
// system, form, skill (optional), history (optional), user.
//
// For each role the first non-empty content found in messages is used. Missing
// roles fall back to values derived from req: system ← renderFormText(req.Form),
// form ← req.Form, skill ← req.SkillName, user ← renderUserText(req.UserForm,
// req.Form). A nil req disables the fallbacks instead of panicking. The system
// and user entries are always present, possibly with nil content.
func NormalizeToTwoPart(messages []dto.Message, req *dto.ComposeMessagesReq) []dto.Message {
	result := make([]dto.Message, 0, 5)

	// 1. system
	sysContent := extractByRole(messages, "system")
	if sysContent == nil && req != nil {
		sysContent = renderFormText(req.Form, false)
	}
	result = append(result, dto.Message{Role: "system", Content: sysContent})

	// 2. form
	if formContent := extractByRole(messages, "form"); formContent != nil {
		result = append(result, dto.Message{Role: "form", Content: formContent})
	} else if req != nil {
		result = append(result, dto.Message{Role: "form", Content: renderFormJSON(req.Form)})
	}

	// 3. skill
	if skillContent := extractByRole(messages, "skill"); skillContent != nil {
		result = append(result, dto.Message{Role: "skill", Content: skillContent})
	} else if req != nil && req.SkillName != "" {
		result = append(result, dto.Message{Role: "skill", Content: req.SkillName})
	}

	// 4. history (only when the model returned a compressed history)
	if historyContent := extractByRole(messages, "history"); historyContent != nil {
		result = append(result, dto.Message{Role: "history", Content: historyContent})
	}

	// 5. user
	usrContent := extractByRole(messages, "user")
	if usrContent == nil && req != nil {
		usrContent = renderUserText(req.UserForm, req.Form)
	}
	result = append(result, dto.Message{Role: "user", Content: usrContent})

	return result
}

// ============================================
// Helper: first non-empty content for a role
// ============================================

// extractByRole returns the Content of the first message with the given role
// whose content is not empty (per isEmptyValue), or nil if there is none.
func extractByRole(messages []dto.Message, role string) any {
	for _, msg := range messages {
		if msg.Role == role && !isEmptyValue(msg.Content) {
			return msg.Content
		}
	}
	return nil
}

// ============================================
// Helper: render the form as a JSON object
// ============================================

// renderFormJSON returns a shallow copy of form, or nil when form is nil.
func renderFormJSON(form map[string]any) map[string]any {
	if form == nil {
		return nil
	}
	result := make(map[string]any, len(form))
	for k, v := range form {
		result[k] = v
	}
	return result
}

// enrichSystemMessages fills system messages with values from
// extractSystemValues(req).
//
// Schema arrays ([]any or []map[string]any) are passed through
// enrichSchemaWithValues. Map contents gain any missing keys without
// overwriting existing ones; that map is mutated in place. Non-system messages
// are left untouched. The slice is modified in place and returned.
func enrichSystemMessages(messages []dto.Message, req *dto.ComposeMessagesReq) []dto.Message {
	if len(messages) == 0 {
		return messages
	}

	systemValues := extractSystemValues(req)

	for i, msg := range messages {
		if msg.Role != "system" {
			continue
		}
		switch content := msg.Content.(type) {
		case []any:
			messages[i].Content = enrichSchemaWithValues(content, systemValues)
		case []map[string]any:
			arr := make([]any, len(content))
			for j, item := range content {
				arr[j] = item
			}
			messages[i].Content = enrichSchemaWithValues(arr, systemValues)
		case map[string]any:
			// Merge without overwriting existing keys.
			for k, v := range systemValues {
				if _, exists := content[k]; !exists {
					content[k] = v
				}
			}
			messages[i].Content = content
		}
	}
	return messages
}