refactor(prompt): 优化任务等待机制并改进数据结构
This commit is contained in:
@@ -242,87 +242,43 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
|
||||
}
|
||||
|
||||
// waitForResult 等待结果
|
||||
// waitForResult 等待结果(优先channel通知,兜底网关查询)
|
||||
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
||||
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
||||
deadline := time.Now().Add(timeout)
|
||||
// 设置超时context
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
// ===================== 修复点 1:检查上下文是否取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 请求已被取消,直接返回,不继续查库
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
// 优先等待channel通知(来自回调)
|
||||
result, err := TaskWaiter.Wait(ctx, taskID)
|
||||
if err == nil {
|
||||
// 成功收到回调通知
|
||||
return result.(*entity.ComposeTask), nil
|
||||
}
|
||||
// channel等待失败(超时/取消),从数据库读取最终状态作为兜底
|
||||
g.Log().Warningf(ctx, "[waitForResult] channel等待失败,从DB获取最终状态 taskId=%s err=%v", taskID, err)
|
||||
record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return nil, fmt.Errorf("查询数据库失败: %w", dbErr)
|
||||
}
|
||||
|
||||
// 1. 查数据库
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
// ===================== 修复点 2:如果是上下文取消,直接返回 =====================
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if record != nil {
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
}
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// 2. 查网关状态
|
||||
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
||||
if err != nil {
|
||||
// 网关不可达不终止,继续轮询
|
||||
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
|
||||
} else {
|
||||
switch state {
|
||||
case 2: // 网关成功
|
||||
// 网关已成功,主动更新数据库
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
case 3: // 网关失败
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: "model-gateway 任务执行失败",
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 超时检查
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// ===================== 修复点3:sleep 也要监听 ctx 取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
default:
|
||||
// 还在处理中,但已超时
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -331,6 +287,7 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapped := parseTaskMessages(taskRecord.Messages)
|
||||
if mapped == nil {
|
||||
return createDefaultResult(nil)
|
||||
@@ -342,23 +299,50 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
||||
return createDefaultResult(mapped)
|
||||
}
|
||||
|
||||
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
|
||||
// 尝试解析为数组
|
||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: len(roundsArray),
|
||||
Rounds: roundsArray,
|
||||
}
|
||||
}
|
||||
|
||||
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
|
||||
// 尝试解析为单个对象
|
||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{singleRound},
|
||||
Rounds: []map[string]any{singleRound},
|
||||
}
|
||||
}
|
||||
|
||||
// 纯文本,包装为默认格式
|
||||
return createDefaultResult(map[string]any{"content": contentStr})
|
||||
}
|
||||
|
||||
// tryParseAsMapArray 尝试解析JSON字符串为 []map[string]any
|
||||
func tryParseAsMapArray(jsonStr string) []map[string]any {
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(arr) == 0 {
|
||||
return nil
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// tryParseAsMap 尝试解析JSON字符串为 map[string]any
|
||||
func tryParseAsMap(jsonStr string) map[string]any {
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
|
||||
return nil
|
||||
}
|
||||
if len(obj) == 0 {
|
||||
return nil
|
||||
}
|
||||
return obj
|
||||
}
|
||||
|
||||
// parseTaskMessages 解析任务消息
|
||||
func parseTaskMessages(messages any) map[string]any {
|
||||
var mapped map[string]any
|
||||
@@ -399,13 +383,13 @@ func tryParseAsObject(contentStr string) any {
|
||||
}
|
||||
|
||||
// createDefaultResult 创建默认结果
|
||||
func createDefaultResult(data any) *dto.MultiRoundResult {
|
||||
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{data},
|
||||
Rounds: []map[string]any{data},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +444,7 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{result},
|
||||
Rounds: []map[string]any{result},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,7 +452,6 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
@@ -478,47 +461,48 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||
}
|
||||
|
||||
//处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
})
|
||||
// 通知等待者:任务失败
|
||||
notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg))
|
||||
return err
|
||||
}
|
||||
//处理成功
|
||||
if req.State == 2 {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
notifyWaiter(req.TaskId, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
}, err)
|
||||
}
|
||||
|
||||
return handleCallbackSuccess(ctx, req)
|
||||
}
|
||||
|
||||
// handleCallbackFailure 处理回调失败
|
||||
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: errorMsg,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
return fmt.Errorf("解析模型输出失败: %w", err)
|
||||
// notifyWaiter 通知等待者(不影响主流程)
|
||||
func notifyWaiter(taskID string, result interface{}, err error) {
|
||||
notifyErr := TaskWaiter.Notify(taskID, result, err)
|
||||
if notifyErr != nil {
|
||||
// 只记录日志,不影响回调处理结果
|
||||
g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr)
|
||||
}
|
||||
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
|
||||
Reference in New Issue
Block a user