refactor(prompt): 优化任务等待机制并改进数据结构

This commit is contained in:
2026-05-21 14:23:34 +08:00
parent 15f5761000
commit a34eb4ea61
4 changed files with 228 additions and 120 deletions

View File

@@ -242,87 +242,43 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
}
// waitForResult 等待结果
// waitForResult 等待结果优先channel通知兜底网关查询
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
deadline := time.Now().Add(timeout)
// 设置超时context
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for {
// ===================== 修复点 1检查上下文是否取消 =====================
select {
case <-ctx.Done():
// 请求已被取消,直接返回,不继续查库
return nil, ctx.Err()
default:
}
// 优先等待channel通知来自回调
result, err := TaskWaiter.Wait(ctx, taskID)
if err == nil {
// 成功收到回调通知
return result.(*entity.ComposeTask), nil
}
// channel等待失败超时/取消),从数据库读取最终状态作为兜底
g.Log().Warningf(ctx, "[waitForResult] channel等待失败从DB获取最终状态 taskId=%s err=%v", taskID, err)
record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if dbErr != nil {
return nil, fmt.Errorf("查询数据库失败: %w", dbErr)
}
// 1. 查数据库
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
}
}
if record == nil {
return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID)
}
// 2. 查网关状态
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
if err != nil {
// 网关不可达不终止,继续轮询
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
} else {
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
default:
// 还在处理中,但已超时
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
}
@@ -331,6 +287,7 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
if taskRecord == nil {
return nil
}
mapped := parseTaskMessages(taskRecord.Messages)
if mapped == nil {
return createDefaultResult(nil)
@@ -342,23 +299,50 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
return createDefaultResult(mapped)
}
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
// 尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{
TotalRounds: len(roundsArray),
Rounds: roundsArray,
}
}
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
// 尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{singleRound},
Rounds: []map[string]any{singleRound},
}
}
// 纯文本,包装为默认格式
return createDefaultResult(map[string]any{"content": contentStr})
}
// tryParseAsMapArray 尝试解析JSON字符串为 []map[string]any
func tryParseAsMapArray(jsonStr string) []map[string]any {
var arr []map[string]any
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
return nil
}
if len(arr) == 0 {
return nil
}
return arr
}
// tryParseAsMap 尝试解析JSON字符串为 map[string]any
func tryParseAsMap(jsonStr string) map[string]any {
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
// parseTaskMessages 解析任务消息
func parseTaskMessages(messages any) map[string]any {
var mapped map[string]any
@@ -399,13 +383,13 @@ func tryParseAsObject(contentStr string) any {
}
// createDefaultResult 创建默认结果
func createDefaultResult(data any) *dto.MultiRoundResult {
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
if data == nil {
data = make(map[string]any)
}
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{data},
Rounds: []map[string]any{data},
}
}
@@ -460,7 +444,7 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
return &dto.MultiRoundResult{
TotalRounds: 1,
Rounds: []any{result},
Rounds: []map[string]any{result},
}
}
@@ -468,7 +452,6 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
})
@@ -478,47 +461,48 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId)
}
//处理失败
if req.State == 3 {
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
// 通知等待者:任务失败
notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg))
return err
}
//处理成功
if req.State == 2 {
result, err := util.ParseOutput(req.Text)
var messages any
if result != nil {
messages = result
}
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
notifyWaiter(req.TaskId, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
}, err)
}
return handleCallbackSuccess(ctx, req)
}
// handleCallbackFailure 处理回调失败
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: errorMsg,
})
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
result, err := util.ParseOutput(req.Text)
if err != nil {
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
return fmt.Errorf("解析模型输出失败: %w", err)
// notifyWaiter 通知等待者(不影响主流程)
func notifyWaiter(taskID string, result interface{}, err error) {
notifyErr := TaskWaiter.Notify(taskID, result, err)
if notifyErr != nil {
// 只记录日志,不影响回调处理结果
g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr)
}
var messages any
if result != nil {
messages = result
}
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
}
// GetComposeTask 查询任务结果