diff --git a/config.yml b/config.yml index a7e28ef..b28d081 100644 --- a/config.yml +++ b/config.yml @@ -61,7 +61,6 @@ jaeger: task: waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒) - pollIntervalMillis: 500 # 同步等待期间,轮询本地任务表 / 网关状态的时间间隔(毫秒) session: maxRounds: 10 # 最大轮数 diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index 06d0ae1..f03cd37 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -21,8 +21,8 @@ type ComposeMessagesRes struct { // MultiRoundResult 多轮返回结果 type MultiRoundResult struct { - TotalRounds int `json:"total_rounds"` // 总轮数 - Rounds []any `json:"rounds"` // 每轮详情(动态类型) + TotalRounds int `json:"total_rounds"` // 总轮数 + Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型) } type CallbackReq struct { diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index ef9a19f..21c20cf 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -242,87 +242,43 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo } // waitForResult 等待结果 +// waitForResult 等待结果(优先channel通知,兜底网关查询) func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second - pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond - deadline := time.Now().Add(timeout) + // 设置超时context + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() - for { - // ===================== 修复点 1:检查上下文是否取消 ===================== - select { - case <-ctx.Done(): - // 请求已被取消,直接返回,不继续查库 - return nil, ctx.Err() - default: - } + // 优先等待channel通知(来自回调) + result, err := TaskWaiter.Wait(ctx, taskID) + if err == nil { + // 成功收到回调通知 + return result.(*entity.ComposeTask), nil + } + // channel等待失败(超时/取消),从数据库读取最终状态作为兜底 + g.Log().Warningf(ctx, "[waitForResult] channel等待失败,从DB获取最终状态 taskId=%s err=%v", taskID, err) + record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ + TaskId: taskID, + }) + if dbErr != nil { + return nil, fmt.Errorf("查询数据库失败: %w", dbErr) + } - // 1. 查数据库 - record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ - TaskId: taskID, - }) - if err != nil { - // ===================== 修复点 2:如果是上下文取消,直接返回 ===================== - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return nil, err - } - return nil, err - } - if record != nil { - switch record.Status { - case public.ComposeStatusSuccess: - return record, nil - case public.ComposeStatusFailed: - if strings.TrimSpace(record.ErrorMessage) == "" { - return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) - } - return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) - } - } + if record == nil { + return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID) + } - // 2. 查网关状态 - state, err := gateway.QueryGatewayTaskState(ctx, taskID) - if err != nil { - // 网关不可达不终止,继续轮询 - g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) - } else { - switch state { - case 2: // 网关成功 - // 网关已成功,主动更新数据库 - if record != nil { - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: taskID, - Status: public.ComposeStatusSuccess, - }) - if err != nil { - g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) - } - } - case 3: // 网关失败 - if record != nil { - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: taskID, - Status: public.ComposeStatusFailed, - ErrorMessage: "model-gateway 任务执行失败", - }) - if err != nil { - g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err) - } - } - return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID) - } - } - - // 3. 超时检查 - if time.Now().After(deadline) { - return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) - } - - // ===================== 修复点3:sleep 也要监听 ctx 取消 ===================== - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(pollInterval): + switch record.Status { + case public.ComposeStatusSuccess: + return record, nil + case public.ComposeStatusFailed: + if strings.TrimSpace(record.ErrorMessage) == "" { + return nil, fmt.Errorf("任务失败(taskId=%s)", taskID) } + return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage) + default: + // 还在处理中,但已超时 + return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID) } } @@ -331,6 +287,7 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) if taskRecord == nil { return nil } + mapped := parseTaskMessages(taskRecord.Messages) if mapped == nil { return createDefaultResult(nil) @@ -342,23 +299,50 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel) return createDefaultResult(mapped) } - if roundsArray := tryParseAsArray(contentStr); roundsArray != nil { + // 尝试解析为数组 + if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { return &dto.MultiRoundResult{ TotalRounds: len(roundsArray), Rounds: roundsArray, } } - if singleRound := tryParseAsObject(contentStr); singleRound != nil { + // 尝试解析为单个对象 + if singleRound := tryParseAsMap(contentStr); singleRound != nil { return &dto.MultiRoundResult{ TotalRounds: 1, - Rounds: []any{singleRound}, + Rounds: []map[string]any{singleRound}, } } + // 纯文本,包装为默认格式 return createDefaultResult(map[string]any{"content": contentStr}) } +// tryParseAsMapArray 尝试解析JSON字符串为 []map[string]any +func tryParseAsMapArray(jsonStr string) []map[string]any { + var arr []map[string]any + if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil { + return nil + } + if len(arr) == 0 { + return nil + } + return arr +} + +// tryParseAsMap 尝试解析JSON字符串为 map[string]any +func tryParseAsMap(jsonStr string) map[string]any { + var obj map[string]any + if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { + return nil + } + if len(obj) == 0 { + return nil + } + return obj +} + // parseTaskMessages 解析任务消息 func parseTaskMessages(messages any) map[string]any { var mapped map[string]any @@ -399,13 +383,13 @@ func tryParseAsObject(contentStr string) any { } // createDefaultResult 创建默认结果 -func createDefaultResult(data any) *dto.MultiRoundResult { +func createDefaultResult(data map[string]any) *dto.MultiRoundResult { if data == nil { data = make(map[string]any) } return &dto.MultiRoundResult{ TotalRounds: 1, - Rounds: []any{data}, + Rounds: []map[string]any{data}, } } @@ -460,7 +444,7 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult { return &dto.MultiRoundResult{ TotalRounds: 1, - Rounds: []any{result}, + Rounds: []map[string]any{result}, } } @@ -468,7 +452,6 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult { func Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) - task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: req.TaskId, }) @@ -478,47 +461,48 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { if task == nil { return fmt.Errorf("任务不存在: %s", req.TaskId) } - + //处理失败 if req.State == 3 { - return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusFailed, + ErrorMessage: req.ErrorMsg, + }) + // 通知等待者:任务失败 + notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg)) + return err + } + //处理成功 + if req.State == 2 { + result, err := util.ParseOutput(req.Text) + var messages any + if result != nil { + messages = result + } + _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusSuccess, + Messages: messages, + }) + if err != nil { + g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) + } + notifyWaiter(req.TaskId, &entity.ComposeTask{ + TaskId: req.TaskId, + Status: public.ComposeStatusSuccess, + Messages: messages, + }, err) } - - return handleCallbackSuccess(ctx, req) -} - -// handleCallbackFailure 处理回调失败 -func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error { - _, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: taskID, - Status: public.ComposeStatusFailed, - ErrorMessage: errorMsg, - }) return err } -// handleCallbackSuccess 处理回调成功 -func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error { - result, err := util.ParseOutput(req.Text) - if err != nil { - handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) - return fmt.Errorf("解析模型输出失败: %w", err) +// notifyWaiter 通知等待者(不影响主流程) +func notifyWaiter(taskID string, result interface{}, err error) { + notifyErr := TaskWaiter.Notify(taskID, result, err) + if notifyErr != nil { + // 只记录日志,不影响回调处理结果 + g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr) } - - var messages any - if result != nil { - messages = result - } - - _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{ - TaskId: req.TaskId, - Status: public.ComposeStatusSuccess, - Messages: messages, - }) - if err != nil { - g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err) - } - - return err } // GetComposeTask 查询任务结果 diff --git a/service/prompt/prompt_task_waiter.go b/service/prompt/prompt_task_waiter.go new file mode 100644 index 0000000..575c300 --- /dev/null +++ b/service/prompt/prompt_task_waiter.go @@ -0,0 +1,125 @@ +package prompt + +import ( + "context" + "errors" + "sync" +) + +var ( + ErrTaskNotFound = errors.New("task not found") + ErrAlreadyNotified = errors.New("task already notified") + TaskWaiter = NewManager() +) + +// Result 任务结果 +type Result struct { + Data interface{} + Error error +} + +// Manager 管理异步任务等待 +type Manager struct { + mu sync.Mutex + waiters map[string]*waiter +} + +// waiter 单个等待者 +type waiter struct { + result chan Result + closed chan struct{} + notifyOnce sync.Once +} + +// NewManager 创建管理器 +func NewManager() *Manager { + return &Manager{ + waiters: make(map[string]*waiter), + } +} + +// Wait 等待任务结果 +func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) { + w := m.getOrCreate(taskID) + defer m.remove(taskID) + + select { + case result := <-w.result: + if result.Error != nil { + return nil, result.Error + } + return result.Data, nil + case <-ctx.Done(): + return nil, ctx.Err() + case <-w.closed: + // context取消后notify才到达的边缘情况 + select { + case result := <-w.result: + if result.Error != nil { + return nil, result.Error + } + return result.Data, nil + default: + return nil, ctx.Err() + } + } +} + +// Notify 通知任务完成(安全,无阻塞) +func (m *Manager) Notify(taskID string, data interface{}, err error) error { + m.mu.Lock() + w, exists := m.waiters[taskID] + if !exists { + m.mu.Unlock() + return ErrTaskNotFound + } + + var notified bool + w.notifyOnce.Do(func() { + notified = true + close(w.closed) // 先关闭信号channel + // 根据err构造Result + if err != nil { + w.result <- Result{Error: err} + } else { + w.result <- Result{Data: data} + } + }) + m.mu.Unlock() + + if !notified { + return ErrAlreadyNotified + } + return nil +} + +// getOrCreate 获取或创建等待者 +func (m *Manager) getOrCreate(taskID string) *waiter { + m.mu.Lock() + defer m.mu.Unlock() + + if w, exists := m.waiters[taskID]; exists { + return w + } + + w := &waiter{ + result: make(chan Result, 1), + closed: make(chan struct{}), + } + m.waiters[taskID] = w + return w +} + +// remove 安全移除等待者 +func (m *Manager) remove(taskID string) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.waiters, taskID) +} + +// ActiveCount 当前活跃等待数量 +func (m *Manager) ActiveCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.waiters) +}