refactor(prompt): 优化任务等待机制并改进数据结构
This commit is contained in:
@@ -61,7 +61,6 @@ jaeger:
|
||||
|
||||
task:
|
||||
waitTimeoutSeconds: 300 # /composeMessages 同步等待最终结果的最长时间(秒)
|
||||
pollIntervalMillis: 500 # 同步等待期间,轮询本地任务表 / 网关状态的时间间隔(毫秒)
|
||||
|
||||
session:
|
||||
maxRounds: 10 # 最大轮数
|
||||
|
||||
@@ -21,8 +21,8 @@ type ComposeMessagesRes struct {
|
||||
|
||||
// MultiRoundResult 多轮返回结果
|
||||
type MultiRoundResult struct {
|
||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
||||
Rounds []any `json:"rounds"` // 每轮详情(动态类型)
|
||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
|
||||
@@ -242,87 +242,43 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
|
||||
}
|
||||
|
||||
// waitForResult 等待结果
|
||||
// waitForResult 等待结果(优先channel通知,兜底网关查询)
|
||||
func waitForResult(ctx context.Context, taskID string) (*entity.ComposeTask, error) {
|
||||
timeout := time.Duration(g.Cfg().MustGet(ctx, "task.waitTimeoutSeconds", 300).Int()) * time.Second
|
||||
pollInterval := time.Duration(g.Cfg().MustGet(ctx, "task.pollIntervalMillis", 500).Int()) * time.Millisecond
|
||||
deadline := time.Now().Add(timeout)
|
||||
// 设置超时context
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
// ===================== 修复点 1:检查上下文是否取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 请求已被取消,直接返回,不继续查库
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
// 优先等待channel通知(来自回调)
|
||||
result, err := TaskWaiter.Wait(ctx, taskID)
|
||||
if err == nil {
|
||||
// 成功收到回调通知
|
||||
return result.(*entity.ComposeTask), nil
|
||||
}
|
||||
// channel等待失败(超时/取消),从数据库读取最终状态作为兜底
|
||||
g.Log().Warningf(ctx, "[waitForResult] channel等待失败,从DB获取最终状态 taskId=%s err=%v", taskID, err)
|
||||
record, dbErr := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if dbErr != nil {
|
||||
return nil, fmt.Errorf("查询数据库失败: %w", dbErr)
|
||||
}
|
||||
|
||||
// 1. 查数据库
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
// ===================== 修复点 2:如果是上下文取消,直接返回 =====================
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if record != nil {
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
}
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("任务不存在(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// 2. 查网关状态
|
||||
state, err := gateway.QueryGatewayTaskState(ctx, taskID)
|
||||
if err != nil {
|
||||
// 网关不可达不终止,继续轮询
|
||||
g.Log().Warningf(ctx, "[waitForResult] 查询网关失败 taskId=%s err=%v", taskID, err)
|
||||
} else {
|
||||
switch state {
|
||||
case 2: // 网关成功
|
||||
// 网关已成功,主动更新数据库
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
case 3: // 网关失败
|
||||
if record != nil {
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: "model-gateway 任务执行失败",
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[waitForResult] 更新任务状态失败 taskId=%s err=%v", taskID, err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("model-gateway 任务执行失败(taskId=%s)", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 超时检查
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
// ===================== 修复点3:sleep 也要监听 ctx 取消 =====================
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(pollInterval):
|
||||
switch record.Status {
|
||||
case public.ComposeStatusSuccess:
|
||||
return record, nil
|
||||
case public.ComposeStatusFailed:
|
||||
if strings.TrimSpace(record.ErrorMessage) == "" {
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s)", taskID)
|
||||
}
|
||||
return nil, fmt.Errorf("任务失败(taskId=%s): %s", taskID, record.ErrorMessage)
|
||||
default:
|
||||
// 还在处理中,但已超时
|
||||
return nil, fmt.Errorf("等待任务回调超时(taskId=%s)", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -331,6 +287,7 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
||||
if taskRecord == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mapped := parseTaskMessages(taskRecord.Messages)
|
||||
if mapped == nil {
|
||||
return createDefaultResult(nil)
|
||||
@@ -342,23 +299,50 @@ func parsePromptBuild(taskRecord *entity.ComposeTask, model *entity.AsynchModel)
|
||||
return createDefaultResult(mapped)
|
||||
}
|
||||
|
||||
if roundsArray := tryParseAsArray(contentStr); roundsArray != nil {
|
||||
// 尝试解析为数组
|
||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: len(roundsArray),
|
||||
Rounds: roundsArray,
|
||||
}
|
||||
}
|
||||
|
||||
if singleRound := tryParseAsObject(contentStr); singleRound != nil {
|
||||
// 尝试解析为单个对象
|
||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{singleRound},
|
||||
Rounds: []map[string]any{singleRound},
|
||||
}
|
||||
}
|
||||
|
||||
// 纯文本,包装为默认格式
|
||||
return createDefaultResult(map[string]any{"content": contentStr})
|
||||
}
|
||||
|
||||
// tryParseAsMapArray decodes jsonStr as a JSON array of objects.
//
// It returns nil when the input is not valid JSON, is not an array whose
// elements decode into map[string]any, or decodes to an empty (or null)
// array; otherwise it returns the decoded slice.
func tryParseAsMapArray(jsonStr string) []map[string]any {
	var rounds []map[string]any
	if json.Unmarshal([]byte(jsonStr), &rounds) != nil || len(rounds) == 0 {
		return nil
	}
	return rounds
}
|
||||
|
||||
// tryParseAsMap decodes jsonStr as a single JSON object.
//
// It returns nil when the input is not valid JSON, is not an object, or
// decodes to an empty (or null) object; otherwise it returns the decoded map.
func tryParseAsMap(jsonStr string) map[string]any {
	var round map[string]any
	if json.Unmarshal([]byte(jsonStr), &round) != nil || len(round) == 0 {
		return nil
	}
	return round
}
|
||||
|
||||
// parseTaskMessages 解析任务消息
|
||||
func parseTaskMessages(messages any) map[string]any {
|
||||
var mapped map[string]any
|
||||
@@ -399,13 +383,13 @@ func tryParseAsObject(contentStr string) any {
|
||||
}
|
||||
|
||||
// createDefaultResult 创建默认结果
|
||||
func createDefaultResult(data any) *dto.MultiRoundResult {
|
||||
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{data},
|
||||
Rounds: []map[string]any{data},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,7 +444,7 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []any{result},
|
||||
Rounds: []map[string]any{result},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,7 +452,6 @@ func parseNodeBuild(taskRecord *entity.ComposeTask) *dto.MultiRoundResult {
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
@@ -478,47 +461,48 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||
}
|
||||
|
||||
//处理失败
|
||||
if req.State == 3 {
|
||||
return handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: req.ErrorMsg,
|
||||
})
|
||||
// 通知等待者:任务失败
|
||||
notifyWaiter(req.TaskId, nil, fmt.Errorf("任务失败: %s", req.ErrorMsg))
|
||||
return err
|
||||
}
|
||||
//处理成功
|
||||
if req.State == 2 {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
notifyWaiter(req.TaskId, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
}, err)
|
||||
}
|
||||
|
||||
return handleCallbackSuccess(ctx, req)
|
||||
}
|
||||
|
||||
// handleCallbackFailure 处理回调失败
|
||||
func handleCallbackFailure(ctx context.Context, taskID, errorMsg string) error {
|
||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
Status: public.ComposeStatusFailed,
|
||||
ErrorMessage: errorMsg,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq) error {
|
||||
result, err := util.ParseOutput(req.Text)
|
||||
if err != nil {
|
||||
handleCallbackFailure(ctx, req.TaskId, req.ErrorMsg)
|
||||
return fmt.Errorf("解析模型输出失败: %w", err)
|
||||
// notifyWaiter 通知等待者(不影响主流程)
|
||||
func notifyWaiter(taskID string, result interface{}, err error) {
|
||||
notifyErr := TaskWaiter.Notify(taskID, result, err)
|
||||
if notifyErr != nil {
|
||||
// 只记录日志,不影响回调处理结果
|
||||
g.Log().Infof(context.Background(), "[Callback] 通知等待者失败 taskId=%s err=%v", taskID, notifyErr)
|
||||
}
|
||||
|
||||
var messages any
|
||||
if result != nil {
|
||||
messages = result
|
||||
}
|
||||
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
Messages: messages,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[Callback] 更新任务失败 taskId=%s err=%v", req.TaskId, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
|
||||
125
service/prompt/prompt_task_waiter.go
Normal file
125
service/prompt/prompt_task_waiter.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
	// ErrTaskNotFound is returned by Notify when no waiter is registered for the task ID.
	ErrTaskNotFound = errors.New("task not found")
	// ErrAlreadyNotified is returned by Notify when the registered waiter was already notified.
	ErrAlreadyNotified = errors.New("task already notified")
	// TaskWaiter is the package-level Manager instance.
	TaskWaiter = NewManager()
)

// Result is the outcome of a task as published by Notify.
type Result struct {
	Data  any
	Error error
}

// Manager lets one goroutine block in Wait until another goroutine calls
// Notify for the same task ID.
type Manager struct {
	mu      sync.Mutex
	waiters map[string]*waiter
}

// waiter is the rendezvous point for a single task ID.
type waiter struct {
	// done is closed by Notify after result has been stored. Closing a channel
	// wakes every receiver, so all goroutines waiting on this ID observe the
	// same result.
	done chan struct{}
	// result is written exactly once, before done is closed, and must only be
	// read after a receive from done has succeeded.
	result Result
	// notified records that Notify already ran; guarded by Manager.mu.
	notified bool
}

// outcome returns the stored result in Wait's return shape. It must only be
// called after w.done is closed.
func (w *waiter) outcome() (any, error) {
	if w.result.Error != nil {
		return nil, w.result.Error
	}
	return w.result.Data, nil
}

// NewManager returns an empty Manager ready for use.
func NewManager() *Manager {
	return &Manager{
		waiters: make(map[string]*waiter),
	}
}

// Wait blocks until Notify is called for taskID or ctx is done.
//
// It returns the data passed to Notify, or the error passed to Notify, or
// ctx.Err() when the context ends first. The waiter is registered on entry
// and unregistered on return, so a Notify that runs before Wait has been
// called gets ErrTaskNotFound and is not buffered for a later Wait.
func (m *Manager) Wait(ctx context.Context, taskID string) (any, error) {
	w := m.getOrCreate(taskID)
	defer m.remove(taskID)

	select {
	case <-w.done:
		return w.outcome()
	case <-ctx.Done():
		// If the result was published at the same moment the context ended,
		// prefer the result over the cancellation error.
		select {
		case <-w.done:
			return w.outcome()
		default:
			return nil, ctx.Err()
		}
	}
}

// Notify publishes the outcome of taskID to its waiter without blocking.
//
// When err is non-nil only the error is delivered and data is dropped. It
// returns ErrTaskNotFound if nothing is registered for taskID and
// ErrAlreadyNotified if the registered waiter already received an outcome.
func (m *Manager) Notify(taskID string, data any, err error) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	w, exists := m.waiters[taskID]
	if !exists {
		return ErrTaskNotFound
	}
	if w.notified {
		return ErrAlreadyNotified
	}
	w.notified = true

	// Store the result BEFORE closing done: a waiter woken by the close must
	// always find the result in place, otherwise Wait could return (nil, nil).
	if err != nil {
		w.result = Result{Error: err}
	} else {
		w.result = Result{Data: data}
	}
	close(w.done)
	return nil
}

// getOrCreate returns the waiter registered for taskID, creating and
// registering one if none exists.
func (m *Manager) getOrCreate(taskID string) *waiter {
	m.mu.Lock()
	defer m.mu.Unlock()

	if w, exists := m.waiters[taskID]; exists {
		return w
	}

	w := &waiter{done: make(chan struct{})}
	m.waiters[taskID] = w
	return w
}

// remove unregisters whatever waiter is stored under taskID.
//
// NOTE(review): the entry is deleted unconditionally, so when several
// goroutines Wait on the same ID the first one to return unregisters it for
// the others; a later Notify then reports ErrTaskNotFound.
func (m *Manager) remove(taskID string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.waiters, taskID)
}

// ActiveCount reports how many task IDs currently have a registered waiter.
func (m *Manager) ActiveCount() int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return len(m.waiters)
}
|
||||
Reference in New Issue
Block a user