refactor(model): 优化模型网关的数据解析和任务处理逻辑

This commit is contained in:
2026-06-17 14:34:48 +08:00
parent b3b111995e
commit fddaf36f48
7 changed files with 231 additions and 166 deletions

View File

@@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
// ============================================
// 2) 调用模型
// ============================================
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error())
for attempt := 0; ; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second)
}
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
err = streamErr
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") {
w.failTask(ctx, task, startTime, err.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
}
// ============================================
@@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate
return nil, err
}
// 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String()
taskID := gjson.New(body).Get(entity.ResponseBody).String()
// 3. 创建等待通道
ch := make(chan asyncResult, 1)
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue
}
// 2) 存 token 到数据库,防止后续失败丢失
if _, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
// 2) 存 token
if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
})
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil {
return parsed, nil
}
lastErr = err
case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody)
return parsed, nil
return util.ParseStructResult(mapped, entity.ResponseBody), nil
default:
return mapped, nil
}
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
}
// 4) 重新调模型(直接调,不走缓存)
// 4) 拼接错误信息到请求体,重调模型
task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
}
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil {
if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil
}
// injectErrorMessage returns a request payload in which the text of err has
// been appended to the last message whose role is "user", so that a retried
// model call can see what was wrong with its previous output.
//
// The input payload is never modified: the payload map, its "messages" slice
// and the patched message are copied before being changed. Callers pass the
// task's stored request body on every retry; patching it in place would make
// the hints accumulate across retries and leak into the stored task payload.
//
// Parameters:
//   - payload: request body; "messages" is expected to be a []any whose
//     elements are map[string]any values carrying "role" and "content".
//   - err: the error to report back to the model.
//
// Returns payload itself, unchanged, when err is nil, when there are no
// messages, when no user message exists, or when the last user message has a
// content value that is neither a string nor a []any. Otherwise returns a
// patched copy.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}

	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())

	// Walk backwards to find the last user message; only that one is patched.
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok || gconv.String(msg["role"]) != "user" {
			continue
		}

		var content any
		switch c := msg["content"].(type) {
		case string:
			content = c + errMsg
		case []any:
			// Copy before appending so that spare capacity in the caller's
			// backing array is never written to.
			parts := make([]any, len(c), len(c)+1)
			copy(parts, c)
			content = append(parts, map[string]any{
				"type": "text",
				"text": errMsg,
			})
		default:
			// Unsupported content shape: nothing to append to.
			return payload
		}

		// Copy-on-write: clone only the containers on the path to the change.
		patchedMsg := make(map[string]any, len(msg))
		for k, v := range msg {
			patchedMsg[k] = v
		}
		patchedMsg["content"] = content

		patchedMessages := make([]any, len(messages))
		copy(patchedMessages, messages)
		patchedMessages[i] = patchedMsg

		patched := make(map[string]any, len(payload))
		for k, v := range payload {
			patched[k] = v
		}
		patched["messages"] = patchedMessages
		return patched
	}

	return payload
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {