package service import ( "context" "encoding/json" "errors" "fmt" "strings" "time" "prompts-core/consts/public" "prompts-core/dao" "prompts-core/model/dto" "prompts-core/model/entity" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/frame/g" ) // ============================================ // 核心业务流程 // ============================================ // ComposeMessages 拼接提示词主流程 func (s *promptService) ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) { var ( epicycleId int64 taskID string history []map[string]any message map[string]any err error taskRecord *entity.ComposeTask ) // 获取模型信息 chatModel, model, err := s.GetModelMessage(ctx, req) if err != nil { return nil, err } // 根据构建类型进行判断处理 switch req.BuildType { //提示词构建 case 1: maxRetryTimes := g.Cfg().MustGet(ctx, "promptsRetry.maxRetryTimes", 3).Int() //1. 获取历史会话 history, err = Session.GetHistoryMessages(ctx, req.SessionId) if err != nil { g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) history = nil // 出错就用空的,不影响主流程 } // 重试循环 for attempt := 0; attempt <= maxRetryTimes; attempt++ { if attempt > 0 { g.Log().Warningf(ctx, "[重试]第 %d/%d 次调用推理模型", attempt, maxRetryTimes) } // 2. 调用推理模型 taskID, err = s.callInferenceModel(ctx, req, chatModel, model, history) if err != nil { g.Log().Errorf(ctx, "调用推理模型失败(第%d次): %v", attempt+1, err) continue } // 3. 保存记录 _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ TaskId: taskID, ModelName: req.ModelName, SkillName: req.SkillName, RequestPayload: mustMarshal(req), Status: public.ComposeStatusPending, }) if err != nil { g.Log().Errorf(ctx, "保存任务记录失败(第%d次): %v", attempt+1, err) continue } // 4. 等待结果 taskRecord, err = s.waitForResult(ctx, taskID) if err != nil { g.Log().Errorf(ctx, "等待结果失败(第%d次): %v", attempt+1, err) continue } // 校验结果 message = s.parsePromptBuild(taskRecord, chatModel) if message != nil && isMessageValid(message) { break } g.Log().Warningf(ctx, "[重试] 推理结果不合法(第%d次),准备重新请求", attempt+1) message = nil } if message == nil { return nil, errors.New("推理模型调用失败,请稍后再试") } //5.创建会话记录 epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ SessionId: req.SessionId, RequestContent: message, }) //节点构建 case 2: //1. 调用推理模型 taskID, err = s.callInferenceModel(ctx, req, chatModel, model, nil) if err != nil { return nil, err } //2. 保存相关记录 _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ TaskId: taskID, ModelName: req.ModelName, SkillName: req.SkillName, RequestPayload: mustMarshal(req), Status: public.ComposeStatusPending, }) //5. 等待结果 taskRecord, err := s.waitForResult(ctx, taskID) if err != nil { return nil, err } message = s.parseNodeBuild(taskRecord) default: epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ SessionId: req.SessionId, Remark: req.Cause, }) return &dto.ComposeMessagesRes{ EpicycleId: epicycleId, }, nil } return &dto.ComposeMessagesRes{ Messages: message, EpicycleId: epicycleId, }, nil } func (s *promptService) Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) // ============ 先查任务是否存在 ============ task, err := dao.ComposeTask.GetByTaskId(ctx, req.TaskId) if err != nil { return err } if task == nil { return fmt.Errorf("任务不存在: %s", req.TaskId) } // ============ 根据状态区分处理 ============ if req.State == 3 { // 失败:直接更新状态 _, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ entity.ComposeTaskCol.Status: public.ComposeStatusFailed, entity.ComposeTaskCol.ErrorMessage: req.ErrorMsg, }) return err } // ====================================== // 成功:解析模型输出 result, err := parseOutput(req.Text) if err != nil { _, updateErr := dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ entity.ComposeTaskCol.Status: public.ComposeStatusFailed, entity.ComposeTaskCol.ErrorMessage: err.Error(), }) if updateErr != nil { g.Log().Warningf(ctx, "[Callback] 更新失败状态出错 taskId=%s err=%v", req.TaskId, updateErr) } return err } // ============ result 可能为 nil ============ var messages any if result != nil { messages = result } // ======================================= _, err = dao.ComposeTask.UpdateByTaskId(ctx, req.TaskId, map[string]any{ entity.ComposeTaskCol.Status: public.ComposeStatusSuccess, entity.ComposeTaskCol.Messages: messages, }) if err != nil { g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) } return err } // GetComposeTask 查询任务结果 func (s *promptService) GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) { record, err := dao.ComposeTask.GetByTaskId(ctx, taskID) if err != nil { return nil, err } if record == nil { return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID) } // 如果 Messages 是字符串,反序列化为 JSON 数组 messages := record.Messages if str, ok := messages.(string); ok && str != "" { var parsed any if err := json.Unmarshal([]byte(str), &parsed); err == nil { messages = parsed } } return &dto.GetComposeTaskRes{ TaskId: record.TaskId, Status: record.Status, ErrorMessage: record.ErrorMessage, Messages: messages, }, nil } // GetModelMessage 获取模型信息 func (s *promptService) GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { // 1. 获取当前用户的会话模型 chatModel, err := dao.Model.GetByIsChatModel(ctx) if err != nil { return nil, nil, err } if chatModel == nil { return nil, nil, errors.New("当前没有对话模型,请添加") } // 2. 获取要构建的模型信息 model, err := dao.Model.GetByModelName(ctx, req.ModelName) if err != nil { return nil, nil, err } if model == nil { return nil, nil, fmt.Errorf("需要构建的模型 %s 不存在", req.ModelName) } return chatModel, model, nil } // callInferenceModel 调用推理模型 func (s *promptService) callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, model *entity.AsynchModel, history []map[string]any) (string, error) { // 构建推理模型请求 taskReq, err := buildInferenceRequest(ctx, req, chatModel, model, history) if err != nil { return "", fmt.Errorf("构建推理请求失败: %w", err) } // 创建网关任务 taskID, err := createGatewayTask(ctx, taskReq) if err != nil { return "", fmt.Errorf("创建网关任务失败: %w", err) } if taskID == "" { return "", errors.New("网关未返回taskId") } return taskID, nil } // ============================================ // 步骤6:等待结果 // ============================================ func (s *promptService) waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond deadline := time.Now().Add(timeout) for { // ===================== 修复点 1:检查上下文是否取消 ===================== select { case <-ctx.Done(): // 请求已被取消,直接返回,不继续查库 return nil, ctx.Err() default: } // 1. 查数据库 record, err := dao.ComposeTask.GetByTaskId(ctx, taskID) if err != nil { // ===================== 修复点 2:如果是上下文取消,直接返回 ===================== if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } return nil, err } if record != nil { switch record.Status { case public.ComposeStatusSuccess: return record, nil case public.ComposeStatusFailed: if strings.TrimSpace(record.ErrorMessage) == "" { return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) } return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) } } // 2. 查网关状态 state, err := queryGatewayTaskState(ctx, taskID) if err != nil { // 网关不可达不终止,继续轮询 g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) } else { switch state { case 2: // 网关成功 // 网关已成功,主动更新数据库 if record != nil { dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ entity.ComposeTaskCol.Status: public.ComposeStatusSuccess, }) } case 3: // 网关失败 if record != nil { dao.ComposeTask.UpdateByTaskId(ctx, taskID, map[string]any{ entity.ComposeTaskCol.Status: public.ComposeStatusFailed, entity.ComposeTaskCol.ErrorMessage: "model-gateway 任务执行失败", }) } return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) } } // 3. 超时检查 if time.Now().After(deadline) { return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) } // ===================== 修复点3:sleep 也要监听 ctx 取消 ===================== select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(pollInterval): } } } // parsePromptBuild 解析提示词构建结果(BuildType == 1) func (s *promptService) parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) map[string]any { if taskRecord == nil { return nil } // 1. 解析 Messages var mapped map[string]any switch v := taskRecord.Messages.(type) { case *gvar.Var: if v != nil { json.Unmarshal([]byte(v.String()), &mapped) } case string: json.Unmarshal([]byte(v), &mapped) case map[string]any: mapped = v default: b, _ := json.Marshal(v) json.Unmarshal(b, &mapped) } // 2. 解析模型 ResponseMapping 获取 content 字段名 contentField := "content" // 默认值 if model != nil { var respMapping map[string]string switch v := model.ResponseMapping.(type) { case *gvar.Var: if v != nil { json.Unmarshal([]byte(v.String()), &respMapping) } case string: json.Unmarshal([]byte(v), &respMapping) case map[string]interface{}: respMapping = make(map[string]string) for k, val := range v { if s, ok := val.(string); ok { respMapping[k] = s } } } // 从映射中找到 content 对应的字段名 for k, v := range respMapping { if strings.Contains(v, "content") { contentField = k break } } } // 3. 提取 content 的值 contentStr, ok := mapped[contentField].(string) if !ok || contentStr == "" { return mapped } // 4. 解析 content 内的 JSON var innerData map[string]any json.Unmarshal([]byte(contentStr), &innerData) return innerData } // parseNodeBuild 解析节点构建结果(BuildType == 2) func (s *promptService) parseNodeBuild(taskRecord *entity.ComposeTask) map[string]any { if taskRecord == nil { return nil } var result map[string]any switch v := taskRecord.Messages.(type) { case *gvar.Var: if v != nil { json.Unmarshal([]byte(v.String()), &result) } case string: json.Unmarshal([]byte(v), &result) case map[string]any: result = v default: b, _ := json.Marshal(v) json.Unmarshal(b, &result) } return result }