refactor(prompt): 优化任务等待机制并改进数据结构

This commit is contained in:
2026-05-21 14:23:34 +08:00
parent 15f5761000
commit a34eb4ea61
4 changed files with 228 additions and 120 deletions

View File

@@ -61,7 +61,6 @@ jaeger:
task: task:
waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒) waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒)
pollIntervalMillis: 500 # 同步等待期间,轮询本地任务表 / 网关状态的时间间隔(毫秒)
session: session:
maxRounds: 10 # 最大轮数 maxRounds: 10 # 最大轮数

View File

@@ -21,8 +21,8 @@ type ComposeMessagesRes struct {
// MultiRoundResult 多轮返回结果 // MultiRoundResult 多轮返回结果
type MultiRoundResult struct { type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数 TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []any `json:"rounds"` // 每轮详情(动态类型) Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
} }
type CallbackReq struct { type CallbackReq struct {

View File

@@ -242,87 +242,43 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
} }
// waitForResult 等待结果 // waitForResult 等待结果
// waitForResult 等待结果优先channel通知兜底网关查询
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) { func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond // 设置超时context
deadline := time.Now().Add(timeout) ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for { // 优先等待channel通知来自回调
// ===================== 修复点 1检查上下文是否取消 ===================== result, err := TaskWaiter.Wait(ctx, taskID)
select { if err == nil {
case <-ctx.Done(): // 成功收到回调通知
// 请求已被取消,直接返回,不继续查库 return result.(*entity.ComposeTask), nil
return nil, ctx.Err() }
default: // channel等待失败超时/取消),从数据库读取最终状态作为兜底
} g.Log().Warningf(ctx, "[waitForResult] channel等待失败从DB获取最终状态 taskId=%s err=%v", taskID, err)
record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if dbErr != nil {
return nil, fmt.Errorf("查询数据库失败: %w", dbErr)
}
// 1. 查数据库 if record == nil {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID)
TaskId: taskID, }
})
if err != nil {
// ===================== 修复点 2如果是上下文取消直接返回 =====================
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil, err
}
return nil, err
}
if record != nil {
switch record.Status {
case public.ComposeStatusSuccess:
return record, nil
case public.ComposeStatusFailed:
if strings.TrimSpace(record.ErrorMessage) == "" {
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
}
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
}
}
// 2. 查网关状态 switch record.Status {
state, err := gateway.QueryGatewayTaskState(ctx, taskID) case public.ComposeStatusSuccess:
if err != nil { return record, nil
// 网关不可达不终止,继续轮询 case public.ComposeStatusFailed:
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err) if strings.TrimSpace(record.ErrorMessage) == "" {
} else { return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
switch state {
case 2: // 网关成功
// 网关已成功,主动更新数据库
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusSuccess,
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
case 3: // 网关失败
if record != nil {
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: "model-gateway 任务执行失败",
})
if err != nil {
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
}
}
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
}
}
// 3. 超时检查
if time.Now().After(deadline) {
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
}
// ===================== 修复点3sleep 也要监听 ctx 取消 =====================
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(pollInterval):
} }
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
default:
// 还在处理中,但已超时
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
} }
} }
@@ -331,6 +287,7 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
if taskRecord == nil { if taskRecord == nil {
return nil return nil
} }
mapped := parseTaskMessages(taskRecord.Messages) mapped := parseTaskMessages(taskRecord.Messages)
if mapped == nil { if mapped == nil {
return createDefaultResult(nil) return createDefaultResult(nil)
@@ -342,23 +299,50 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
return createDefaultResult(mapped) return createDefaultResult(mapped)
} }
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil { // 尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return &dto.MultiRoundResult{ return &dto.MultiRoundResult{
TotalRounds: len(roundsArray), TotalRounds: len(roundsArray),
Rounds: roundsArray, Rounds: roundsArray,
} }
} }
if singleRound := tryParseAsObject(contentStr); singleRound != nil { // 尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return &dto.MultiRoundResult{ return &dto.MultiRoundResult{
TotalRounds: 1, TotalRounds: 1,
Rounds: []any{singleRound}, Rounds: []map[string]any{singleRound},
} }
} }
// 纯文本,包装为默认格式
return createDefaultResult(map[string]any{"content": contentStr}) return createDefaultResult(map[string]any{"content": contentStr})
} }
// tryParseAsMapArray 尝试解析JSON字符串为 []map[string]any
func tryParseAsMapArray(jsonStr string) []map[string]any {
var arr []map[string]any
if err := json.Unmarshal([]byte(jsonStr), &arr); err != nil {
return nil
}
if len(arr) == 0 {
return nil
}
return arr
}
// tryParseAsMap 尝试解析JSON字符串为 map[string]any
func tryParseAsMap(jsonStr string) map[string]any {
var obj map[string]any
if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil {
return nil
}
if len(obj) == 0 {
return nil
}
return obj
}
// parseTaskMessages 解析任务消息 // parseTaskMessages 解析任务消息
func parseTaskMessages(messages any) map[string]any { func parseTaskMessages(messages any) map[string]any {
var mapped map[string]any var mapped map[string]any
@@ -399,13 +383,13 @@ func tryParseAsObject(contentStr string) any {
} }
// createDefaultResult 创建默认结果 // createDefaultResult 创建默认结果
func createDefaultResult(data any) *dto.MultiRoundResult { func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
if data == nil { if data == nil {
data = make(map[string]any) data = make(map[string]any)
} }
return &dto.MultiRoundResult{ return &dto.MultiRoundResult{
TotalRounds: 1, TotalRounds: 1,
Rounds: []any{data}, Rounds: []map[string]any{data},
} }
} }
@@ -460,7 +444,7 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
return &dto.MultiRoundResult{ return &dto.MultiRoundResult{
TotalRounds: 1, TotalRounds: 1,
Rounds: []any{result}, Rounds: []map[string]any{result},
} }
} }
@@ -468,7 +452,6 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
func Callback(ctx context.Context, req *dto.CallbackReq) error { func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId, TaskId: req.TaskId,
}) })
@@ -478,47 +461,48 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
if task == nil { if task == nil {
return fmt.Errorf("任务不存在: %s", req.TaskId) return fmt.Errorf("任务不存在: %s", req.TaskId)
} }
//处理失败
if req.State == 3 { if req.State == 3 {
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) _, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusFailed,
ErrorMessage: req.ErrorMsg,
})
// 通知等待者:任务失败
notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg))
return err
}
//处理成功
if req.State == 2 {
result, err := util.ParseOutput(req.Text)
var messages any
if result != nil {
messages = result
}
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
notifyWaiter(req.TaskId, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
}, err)
} }
return handleCallbackSuccess(ctx, req)
}
// handleCallbackFailure 处理回调失败
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: taskID,
Status: public.ComposeStatusFailed,
ErrorMessage: errorMsg,
})
return err return err
} }
// handleCallbackSuccess 处理回调成功 // notifyWaiter 通知等待者(不影响主流程)
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error { func notifyWaiter(taskID string, result interface{}, err error) {
result, err := util.ParseOutput(req.Text) notifyErr := TaskWaiter.Notify(taskID, result, err)
if err != nil { if notifyErr != nil {
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg) // 只记录日志,不影响回调处理结果
return fmt.Errorf("解析模型输出失败: %w", err) g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr)
} }
var messages any
if result != nil {
messages = result
}
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
Status: public.ComposeStatusSuccess,
Messages: messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
}
return err
} }
// GetComposeTask 查询任务结果 // GetComposeTask 查询任务结果

View File

@@ -0,0 +1,125 @@
package prompt
import (
"context"
"errors"
"sync"
)
var (
	// ErrTaskNotFound is returned by Notify when no waiter is currently
	// registered for the given task ID.
	ErrTaskNotFound = errors.New("task not found")
	// ErrAlreadyNotified is returned by Notify when the registered waiter
	// has already been given a result.
	ErrAlreadyNotified = errors.New("task already notified")
	// TaskWaiter is the package-level Manager instance.
	TaskWaiter = NewManager()
)

// Result is the outcome published to the goroutines waiting on a task.
// Exactly one of Data or Error is populated by Notify.
type Result struct {
	Data  interface{}
	Error error
}

// Manager lets goroutines block until another goroutine reports the outcome
// of a task identified by a string ID.
//
// A waiter only exists while at least one Wait call is in progress for its
// task ID: a Notify that arrives when nobody is waiting returns
// ErrTaskNotFound and its result is not retained.
type Manager struct {
	mu      sync.Mutex
	waiters map[string]*waiter
}

// waiter is the rendezvous point shared by every Wait call on one task ID.
type waiter struct {
	// result is written exactly once, inside notifyOnce, before closed is
	// closed. Readers must only access it after receiving from closed.
	result Result
	// closed is closed once result has been published; closing (rather than
	// sending) lets any number of waiters observe the same result.
	closed     chan struct{}
	notifyOnce sync.Once
	// refs counts the Wait calls currently registered on this waiter.
	// It is guarded by Manager.mu.
	refs int
}

// NewManager returns an empty, ready-to-use Manager.
func NewManager() *Manager {
	return &Manager{
		waiters: make(map[string]*waiter),
	}
}

// Wait blocks until Notify publishes a result for taskID or ctx is done.
//
// On notification it returns the published data, or the published error if
// one was supplied. If ctx ends first it returns ctx.Err(). When both have
// happened by the time Wait wakes up, the published result wins, so a result
// that raced with cancellation is not discarded. Wait never returns
// (nil, nil) unless Notify was called with nil data and a nil error.
func (m *Manager) Wait(ctx context.Context, taskID string) (interface{}, error) {
	w := m.getOrCreate(taskID)
	defer m.remove(taskID)

	select {
	case <-w.closed:
	case <-ctx.Done():
		// The notification may have landed at the same moment the context
		// ended; prefer the real result over the cancellation error.
		select {
		case <-w.closed:
		default:
			return nil, ctx.Err()
		}
	}

	// Receiving from closed happens after the write to w.result, so this
	// read is race-free.
	if w.result.Error != nil {
		return nil, w.result.Error
	}
	return w.result.Data, nil
}

// Notify publishes the outcome of taskID to every goroutine currently
// waiting on it. It never blocks on the waiters.
//
// If err is non-nil the waiters receive err and data is dropped; otherwise
// they receive data. Notify returns ErrTaskNotFound when nobody is waiting
// on taskID, ErrAlreadyNotified when the waiter already has a result, and
// nil on success.
func (m *Manager) Notify(taskID string, data interface{}, err error) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	w, exists := m.waiters[taskID]
	if !exists {
		return ErrTaskNotFound
	}

	notified := false
	w.notifyOnce.Do(func() {
		notified = true
		if err != nil {
			w.result = Result{Error: err}
		} else {
			w.result = Result{Data: data}
		}
		// Publish the result BEFORE signalling: a waiter woken by closed
		// must always find the result in place.
		close(w.closed)
	})
	if !notified {
		return ErrAlreadyNotified
	}
	return nil
}

// getOrCreate returns the waiter registered for taskID, creating it if
// needed, and records one more reference to it. Every call must be paired
// with a call to remove.
func (m *Manager) getOrCreate(taskID string) *waiter {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.waiters == nil {
		// Keep a zero-value Manager usable instead of panicking on a nil map.
		m.waiters = make(map[string]*waiter)
	}
	w, exists := m.waiters[taskID]
	if !exists {
		w = &waiter{closed: make(chan struct{})}
		m.waiters[taskID] = w
	}
	w.refs++
	return w
}

// remove drops one reference to the waiter registered for taskID and deletes
// the waiter once no Wait call is using it any more. It is a no-op when no
// waiter is registered.
func (m *Manager) remove(taskID string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	w, exists := m.waiters[taskID]
	if !exists {
		return
	}
	w.refs--
	if w.refs <= 0 {
		delete(m.waiters, taskID)
	}
}

// ActiveCount reports how many distinct task IDs currently have a registered
// waiter.
func (m *Manager) ActiveCount() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return len(m.waiters)
}