From e487b4bb5e1fec53351e8cfd1f84c0fb308c75a8 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 27 May 2026 09:36:25 +0800 Subject: [PATCH] =?UTF-8?q?refactor(task):=20=E9=87=8D=E6=9E=84=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E4=BB=BB=E5=8A=A1=E5=A4=84=E7=90=86=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/headers.go | 22 +- common/util/json.go | 69 ++++++ dao/task_dao_bg.go | 65 +++--- model/dto/task_dto.go | 13 +- model/entity/asynch_task.go | 48 ++--- service/cleaner.go | 3 +- service/gateway/gateway_http_service.go | 45 ++-- service/task_service.go | 6 +- service/worker.go | 265 +++++++++++++----------- 9 files changed, 305 insertions(+), 231 deletions(-) create mode 100644 common/util/json.go diff --git a/common/util/headers.go b/common/util/headers.go index c723646..581ab3e 100644 --- a/common/util/headers.go +++ b/common/util/headers.go @@ -78,23 +78,13 @@ func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context return ctx } -// ParseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers -// 入库格式:{"payload": , "headers": {"Authorization": "...", "X-User-Info":"..."}} -func ParseStoredPayload(v any) (payload any, headers map[string]string) { +// ParseStoredPayload 解析入库的 request_payload,拆出模型调用核心数据 +func ParseStoredPayload(v map[string]any) map[string]any { if v == nil { - return nil, nil + return nil } - m := gconv.Map(v) - if len(m) == 0 { - return v, nil + if p, ok := v["payload"]; ok { + return gconv.Map(p) } - if h, ok := m["headers"]; ok { - headers = gconv.MapStrStr(h) - } - if p, ok := m["payload"]; ok { - payload = p - } else { - payload = v - } - return + return v } diff --git a/common/util/json.go b/common/util/json.go new file mode 100644 index 0000000..c841e73 --- /dev/null +++ b/common/util/json.go @@ -0,0 +1,69 @@ +package util + +import ( + "encoding/json" + "fmt" +) + +// ValidatePromptResult 完整的校验逻辑 +func ValidatePromptResult(raw map[string]any, requestMapping map[string]any) error { + contentStr, ok := raw["content"].(string) + if !ok || contentStr == "" { + return fmt.Errorf("content 字段为空或不是字符串") + } + + var rounds []map[string]any + if err := json.Unmarshal([]byte(contentStr), &rounds); err != nil { + return fmt.Errorf("解析 content JSON 数组失败: %w", err) + } + if len(rounds) == 0 { + return fmt.Errorf("content 数组为空") + } + + // 对 rounds 中的每一个元素进行结构校验 + for i, round := range rounds { + if err := validateStructure(requestMapping, round); err != nil { + return fmt.Errorf("rounds[%d] 结构校验失败: %w", i, err) + } + } + return nil +} + +// validateStructure 递归校验 actual 是否包含 expected 定义的所有字段路径 +func validateStructure(expected any, actual any) error { + switch exp := expected.(type) { + case map[string]any: + act, ok := actual.(map[string]any) + if !ok { + return fmt.Errorf("期望对象,实际类型 %T", actual) + } + for key, expVal := range exp { + actVal, exists := act[key] + if !exists { + return fmt.Errorf("缺少字段: %s", key) + } + if err := validateStructure(expVal, actVal); err != nil { + return fmt.Errorf("%s: %w", key, err) + } + } + return nil + case []any: + act, ok := actual.([]any) + if !ok { + return fmt.Errorf("期望数组,实际类型 %T", actual) + } + if len(exp) == 0 { + return nil // 空数组模板,只校验类型 + } + // 用第一个元素的结构去校验每个实际元素 + for i, actItem := range act { + if err := validateStructure(exp[0], actItem); err != nil { + return fmt.Errorf("[%d]: %w", i, err) + } + } + return nil + default: + // 基本类型,不校验具体值,只检查存在 + return nil + } +} diff --git a/dao/task_dao_bg.go b/dao/task_dao_bg.go index fc0a2fe..dfab6d6 100644 --- a/dao/task_dao_bg.go +++ b/dao/task_dao_bg.go @@ -20,7 +20,7 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { sql := fmt.Sprintf( - `SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file + `SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC @@ -63,7 +63,7 @@ func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { sql := fmt.Sprintf( - `SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file + `SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file FROM %s WHERE deleted_at IS NULL AND state = 0 AND task_id = ? LIMIT 1 @@ -91,43 +91,40 @@ func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) return } -func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType, textResult string, fileSize int64, expireAt *gtime.Time, expendTokens int) error { +// UpdateSuccessGlobal 更新任务成功 +func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error { now := gtime.Now() - _, err := gfdb.DB(ctx).Exec(ctx, - fmt.Sprintf(`UPDATE %s -SET state=2, - oss_file=?, - file_type=?, - text_result=?, - expend_tokens=?, - file_size=?, - error_msg='', - finished_at=?, - duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT, - expire_at=NULL, - phase=0, - tmp_file='', - updated_at=? -WHERE id=?`, public.TableNameTask), - ossFile, fileType, textResult, expendTokens, fileSize, now, now, now, id, - ) + _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty(). + Where(entity.AsynchTaskCol.Id, t.Id). + Data(entity.AsynchTask{ + State: 2, + OssFile: t.OssFile, + FileType: t.FileType, + TextResult: t.TextResult, + FileSize: t.FileSize, + ErrorMsg: "", + FinishedAt: now, + Phase: 0, + TmpFile: "", + ExpendTokens: t.ExpendTokens, + }). + Update() return err } -func (d *taskDao) UpdateFailedGlobal(ctx context.Context, id int64, errorMsg string) error { +// UpdateFailedGlobal 模型调用失败 +func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error { now := gtime.Now() - _, err := gfdb.DB(ctx).Exec(ctx, - fmt.Sprintf(`UPDATE %s -SET state=3, - error_msg=?, - finished_at=?, - duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT, - phase=0, - tmp_file='', - updated_at=? -WHERE id=?`, public.TableNameTask), - errorMsg, now, now, now, id, - ) + _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty(). + Where(entity.AsynchTaskCol.Id, t.Id). + Data(entity.AsynchTask{ + State: 3, + ErrorMsg: t.ErrorMsg, + FinishedAt: now, + Phase: 0, + TmpFile: "", + }). + Update() return err } diff --git a/model/dto/task_dto.go b/model/dto/task_dto.go index 45f7ba7..0efa7d7 100644 --- a/model/dto/task_dto.go +++ b/model/dto/task_dto.go @@ -5,12 +5,13 @@ import "github.com/gogf/gf/v2/frame/g" // CreateTaskReq 创建异步任务 type CreateTaskReq struct { g.Meta `path:"/createTask" method:"post" tags:"任务管理" summary:"创建异步任务" dc:"创建异步任务并返回任务ID;创建成功后会立即异步尝试执行当前任务,执行成功后按回调配置触发钩子"` - ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"` - BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"` - CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"` - InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用(如OSS/文件引用等)"` - RequestPayload any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"` + BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"` + CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"` + InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用(如OSS/文件引用等)"` + RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` + EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"` } type CreateTaskRes struct { diff --git a/model/entity/asynch_task.go b/model/entity/asynch_task.go index 711751e..8b14321 100644 --- a/model/entity/asynch_task.go +++ b/model/entity/asynch_task.go @@ -62,28 +62,28 @@ var AsynchTaskCol = asynchTaskCol{ // AsynchTask 异步任务 type AsynchTask struct { beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - TaskID string `orm:"task_id" json:"taskId"` - BizName string `orm:"biz_name" json:"bizName"` - CallbackURL string `orm:"callback_url" json:"callbackUrl"` - ModelKey string `orm:"model_key" json:"modelKey"` - State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载 - OssFile string `orm:"oss_file" json:"ossFile"` - FileType string `orm:"file_type" json:"fileType"` - FileSize int64 `orm:"file_size" json:"fileSize"` - ErrorMsg string `orm:"error_msg" json:"errorMsg"` - StartedAt *gtime.Time `orm:"started_at" json:"startedAt"` - FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"` - DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"` - ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间 - RetryCount int `orm:"retry_count" json:"retryCount"` - EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"` - Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段 - TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径 - InputRef string `orm:"input_ref" json:"inputRef"` - RequestPayload any `orm:"request_payload" json:"requestPayload"` - TextResult string `orm:"text_result" json:"text"` - EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID(用于标识同一轮次的任务) - ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数 - RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"` + ModelName string `orm:"model_name" json:"modelName"` + TaskID string `orm:"task_id" json:"taskId"` + BizName string `orm:"biz_name" json:"bizName"` + CallbackURL string `orm:"callback_url" json:"callbackUrl"` + ModelKey string `orm:"model_key" json:"modelKey"` + State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载 + OssFile string `orm:"oss_file" json:"ossFile"` + FileType string `orm:"file_type" json:"fileType"` + FileSize int64 `orm:"file_size" json:"fileSize"` + ErrorMsg string `orm:"error_msg" json:"errorMsg"` + StartedAt *gtime.Time `orm:"started_at" json:"startedAt"` + FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"` + DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"` + ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间 + RetryCount int `orm:"retry_count" json:"retryCount"` + EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"` + Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段 + TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径 + InputRef string `orm:"input_ref" json:"inputRef"` + RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"` + TextResult map[string]any `orm:"text_result" json:"text"` + EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID(用于标识同一轮次的任务) + ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数 + RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"` } diff --git a/service/cleaner.go b/service/cleaner.go index ad80d54..39bac53 100644 --- a/service/cleaner.go +++ b/service/cleaner.go @@ -35,7 +35,8 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err) } else { for _, t := range list { - _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败") + t.ErrorMsg = "任务超时自动失败" + _ = dao.Task.UpdateFailedGlobal(ctx, t) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) } g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list)) diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 0c9e24f..5f8c254 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "mime/multipart" "model-gateway/common/util" @@ -15,7 +16,7 @@ import ( "github.com/gogf/gf/v2/util/guid" ) -type uploadFileResponse struct { +type UploadFileResponse struct { FileURL string `json:"fileURL"` // 文件 URL FileSize int `json:"fileSize"` // 文件大小(字节) FileName string `json:"fileName"` // 文件名 @@ -23,7 +24,7 @@ type uploadFileResponse struct { FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀 } -func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) { +func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) { // multipart body := &bytes.Buffer{} writer := multipart.NewWriter(body) @@ -39,41 +40,43 @@ func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileEx filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext) part, err := writer.CreateFormFile("file", filename) if err != nil { - return "", err + return nil, err } if _, err := part.Write(data); err != nil { - return "", err + return nil, err } contentType := writer.FormDataContentType() if err = writer.Close(); err != nil { - return "", err + return nil, err } headers := util.ForwardHeaders(ctx) headers["Content-Type"] = contentType - //fullURL := "oss/file/uploadFile" fullURL := "oss/file/uploadFile" g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) - var resp uploadFileResponse + var resp UploadFileResponse if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil { - return "", err + return nil, err } - g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat) - return resp.FileURL, nil + if &resp == nil { + return nil, errors.New("[OSS] 上传文件失败") + } + g.Log().Infof(ctx, "[OSS] 上传成功 url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat) + return &resp, nil } // CallbackPayload 回调请求体 type CallbackPayload struct { - TaskId string `json:"task_id"` - State int `json:"state"` - OssFile string `json:"oss_file"` - FileType string `json:"file_type"` - Text string `json:"text"` - ErrorMsg string `json:"error_msg"` + TaskId string `json:"task_id"` + State int `json:"state"` + OssFile string `json:"oss_file"` + FileType string `json:"file_type"` + Messages map[string]any `json:"messages"` + ErrorMsg string `json:"error_msg"` } -// TriggerCallback 任务成功后的回调 +// TriggerCallback 任务的回调 func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { headers := util.ForwardHeaders(ctx) var resp struct{} @@ -82,7 +85,7 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { State: t.State, OssFile: t.OssFile, FileType: t.FileType, - Text: t.TextResult, + Messages: t.TextResult, ErrorMsg: t.ErrorMsg, } jsonData, err := json.Marshal(payload) @@ -103,8 +106,8 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { // PromptsCallbackPayload 提示词回调请求体 type PromptsCallbackPayload struct { - EpicycleId int64 `json:"epicycleId"` - Text string `json:"text"` + EpicycleId int64 `json:"epicycleId"` + Messages map[string]any `json:"messages"` } // TriggerPromptsCallback 任务成功后的提示词回调 @@ -114,7 +117,7 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleI var resp struct{} payload := PromptsCallbackPayload{ EpicycleId: epicycleId, - Text: t.TextResult, + Messages: t.TextResult, } jsonData, err := json.Marshal(payload) if err != nil { diff --git a/service/task_service.go b/service/task_service.go index f656be2..9390588 100644 --- a/service/task_service.go +++ b/service/task_service.go @@ -103,7 +103,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * // 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。 // 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。 - go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId) + go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req) return &dto.CreateTaskRes{TaskID: taskID}, nil } @@ -112,7 +112,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * // - 只在任务仍为 pending(state=0) 时继续尝试抢占 // - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止 // - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务 -func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) { +func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) { if taskID == "" { return } @@ -139,7 +139,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, } switch t.State { case 0: - if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil { + if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil { g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err) } else { g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID) diff --git a/service/worker.go b/service/worker.go index 43f0214..331d114 100644 --- a/service/worker.go +++ b/service/worker.go @@ -16,9 +16,9 @@ import ( "model-gateway/dao" "model-gateway/model/entity" + "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/grpool" - "github.com/tidwall/gjson" ) var AsyncWorker = &asyncWorker{} @@ -50,11 +50,12 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt for _, t := range tasks { task := t _ = pool.AddWithRecover(ctx, func(ctx context.Context) { - w.handleOne(ctx, task, 0) + w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0}) done <- struct{}{} }, func(ctx context.Context, e error) { if e != nil { - _ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e)) + task.ErrorMsg = fmt.Sprintf("worker panic: %v", e) + _ = dao.Task.UpdateFailedGlobal(ctx, task) ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) } done <- struct{}{} @@ -71,7 +72,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt // RunByTaskID 创建任务后立即异步尝试执行当前任务: // - 只定向抢占当前 taskId 对应的 pending 任务 // - 若任务已被其它 worker 抢走/已不在 pending,则直接返回 -func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error { +func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error { task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) if err != nil { return err @@ -79,163 +80,175 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId if task == nil { return nil } - w.handleOne(ctx, task, epicycleId) + w.handleOne(ctx, task, req) return nil } -func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { - // 从任务入库的 request_payload 里恢复 payload + headers - payload, headers := util.ParseStoredPayload(t.RequestPayload) - if len(headers) > 0 { - ctx = util.SetTaskHeadersToCtx(ctx, headers) - } +// handleOne 执行一次完整的任务 +func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) { + payload := util.ParseStoredPayload(t.RequestPayload) + maxRetry := 0 // 后面从 model 取 + g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName) - // 1) 拉取模型配置 - m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) - if err != nil { - _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - // ============ 失败回调 ============ - t.State = 3 - t.ErrorMsg = err.Error() - go gateway.TriggerCallback(context.WithoutCancel(ctx), t) - // ================================ - return - } - if m == nil || (m.Enabled != nil && *m.Enabled != 1) { - errMsg := "模型不存在或未启用" - _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - // ============ 失败回调 ============ - t.State = 3 - t.ErrorMsg = errMsg - go gateway.TriggerCallback(context.WithoutCancel(ctx), t) - // ================================ + // 1) 获取模型配置 + model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) + if err != nil || model == nil { + w.failTask(ctx, t, "模型不存在或未启用") return } + maxRetry = model.RetryTimes - // 2) 分布式并发限制 + // 2) 分布式并发控制 semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName) - leaseSeconds := int64(3600) - maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency) - acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds) + maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency) + acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600) if err != nil { - _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - // ============ 失败回调 ============ - t.State = 3 - t.ErrorMsg = err.Error() - go gateway.TriggerCallback(context.WithoutCancel(ctx), t) - // ================================ + w.failTask(ctx, t, err.Error()) return } if !acquired { - // 并发满了:放回排队,不回调(不是失败) + g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID) _ = w.rollbackToPending(ctx, t.Id) return } - defer func() { - _ = releaseSemaphore(ctx, semKey) - }() + defer func() { _ = releaseSemaphore(ctx, semKey) }() - // 3) 调用模型服务 + // 3) request_payload 校验 if payload == nil { - payload = map[string]any{ - "taskId": t.TaskID, - "inputRef": t.InputRef, + w.failTask(ctx, t, "request_payload 为空") + return + } + + // 4) 调用模型(不重试,失败直接回调) + textResult, err := w.callModel(ctx, t, model, payload) + if err != nil { + w.failTask(ctx, t, err.Error()) + return + } + + // 5) 上传 OSS(可重试) + var oss *gateway.UploadFileResponse + for attempt := 0; attempt <= maxRetry; attempt++ { + if attempt > 0 { + g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) + } + oss, err = w.uploadOSS(ctx, t) + if err == nil { + break + } + g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", + t.TaskID, attempt, maxRetry, err) + if attempt == maxRetry { + _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()) + w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err)) + return } } - var ( - data []byte - contentType string - ext string - textResult string - ) - // phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载 + // 6) 解析校验(可重试,失败重新调模型) + if req.BuildType == 1 { + for attempt := 0; attempt <= maxRetry; attempt++ { + if attempt > 0 { + g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) + } + err = util.ValidatePromptResult(textResult, model.RequestMapping) + if err == nil { + break + } + g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", + t.TaskID, attempt, maxRetry, err) + if attempt == maxRetry { + w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) + return + } + // 重新调模型 + newResult, modelErr := w.callModel(ctx, t, model, payload) + if modelErr != nil { + g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", + t.TaskID, attempt, maxRetry, modelErr) + continue + } + textResult = newResult + } + } + + // 7) 成功回调 + t.State = 2 + t.OssFile = oss.FileAddressPrefix + oss.FileURL + t.FileType = oss.FileFormat + t.TextResult = textResult + t.FileSize = int64(oss.FileSize) + t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult)) + + if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil { + g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err) + return + } + + ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) + if req.EpicycleId != 0 { + go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId) + } + + g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s", + t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL) + _ = os.Remove(t.TmpFile) +} + +// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试) +// callModel 调用模型 + 检测文件类型 + 保存临时文件 +func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) { + var data []byte + var contentType, ext, textResult string + var err error + if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" { data, err = os.ReadFile(t.TmpFile) - if err == nil && len(data) > 0 { - contentType, ext = util.DetectFileType(data) - } else { + if err != nil || len(data) == 0 { data = nil } } + if data == nil { - // 统计 _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName) - // 核心调用 data, err = InvokeModel(ctx, m, payload, t.ModelKey) if err != nil { - _ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error()) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - // ============ 失败回调 ============ - t.State = 3 - t.ErrorMsg = err.Error() - go gateway.TriggerCallback(context.WithoutCancel(ctx), t) - // ================================ - return + return nil, err } - contentType, ext = util.DetectFileType(data) - if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { - textResult = string(data) - } - tmpPath, err := saveTmpResult(t.TaskID, data, ext) - if err == nil && tmpPath != "" { + tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext) + if tmpErr == nil && tmpPath != "" { t.TmpFile = tmpPath t.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath) } } - // 4) 存储 OSS - ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType) + contentType, ext = util.DetectFileType(data) + if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { + textResult = string(data) + } + + return gjson.New(textResult).Map(), nil +} + +// uploadOSS 从临时文件上传 OSS +func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { + data, err := os.ReadFile(t.TmpFile) if err != nil { - // OSS 阶段失败:保留临时文件,下一轮仅重试 OSS - _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - // ============ OSS失败不回调(还会重试) ============ - // 注意:OSS失败保留临时文件,下次重试,所以这里不触发最终回调 - // 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败 - return + return nil, fmt.Errorf("读取临时文件失败: %w", err) } + _, ext := util.DetectFileType(data) + return gateway.UploadByTask(ctx, data, ext) +} - // 5) 更新任务状态成功 - fileType := strings.TrimPrefix(ext, ".") - if fileType == "" { - fileType = contentType - } - if err = dao.Task.UpdateSuccessGlobal( - ctx, - t.Id, - ossURL, - fileType, - textResult, - int64(len(data)), - nil, - GetExpendTokens(m.ResponseTokenField, textResult), - ); err != nil { - g.Log().Errorf(ctx, "[worker] update success failed: %v", err) - return - } - - // 成功/失败均不再占用 queue_limit +// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 +func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) { + t.State = 3 + t.ErrorMsg = errMsg + _ = dao.Task.UpdateFailedGlobal(ctx, t) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - - // 6) 成功回调 - t.State = 2 - t.OssFile = ossURL - t.FileType = fileType - t.TextResult = textResult - g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) - // ============ 如果有 epicycleId,也触发业务回调 ============ - if epicycleId != 0 { - go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId) - } - - // 成功后清理临时文件 - _ = os.Remove(t.TmpFile) } // saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 @@ -261,11 +274,11 @@ func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { return dao.Task.RollbackToPendingGlobal(ctx, id) } -// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值 -func GetExpendTokens(responseTokenField string, textResult string) int { - value := gjson.Get(textResult, responseTokenField) - if value.Exists() { - return int(value.Int()) +// GetExpendTokens 根据映射路径从 result 中提取消耗 token 值 +func GetExpendTokens(responseTokenField string, result map[string]any) int { + val := gjson.New(result).Get(responseTokenField) + if val.IsNil() { + return 0 } - return len(textResult) + return val.Int() }