package service

import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/grpool"

	"model-gateway/common/util"
	"model-gateway/dao"
	"model-gateway/model/dto"
	"model-gateway/model/entity"
	"model-gateway/service/gateway"
)

// AsyncWorker is the package-level worker instance used to execute
// asynchronous model tasks.
var AsyncWorker = &asyncWorker{}

// asyncWorker carries no state; it only groups the task-execution methods.
type asyncWorker struct {
}

// RunOnce is triggered by an upper-level scheduler: it claims one batch of
// pending tasks and processes them on a goroutine pool, returning once every
// dispatched task has finished.
//
//   - req.BatchSize:  number of tasks to claim (defaults to 10 when <= 0)
//   - req.Goroutines: pool size / concurrency (defaults to 1 when <= 0)
//
// It returns an error when claiming fails or when there is nothing to run.
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
	if req == nil {
		req = &dto.RunWorkReq{}
	}
	if req.BatchSize <= 0 {
		req.BatchSize = 10
	}
	if req.Goroutines <= 0 {
		req.Goroutines = 1
	}

	tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
	if err != nil {
		return nil, err
	}
	if len(tasks) == 0 {
		return nil, errors.New("no task to run")
	}

	pool := grpool.New(req.Goroutines)
	defer pool.Close()

	claimed := len(tasks)
	// Buffered to the batch size so workers never block on signalling.
	done := make(chan struct{}, claimed)
	dispatched := 0
	for _, t := range tasks {
		task := t
		addErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
			w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
			done <- struct{}{}
		}, func(ctx context.Context, e error) {
			// Only reached when the job panicked; route it through the
			// common failure path so state, queue slot and callback are
			// handled the same way as any other failure.
			if e != nil {
				w.failTask(ctx, task, fmt.Sprintf("worker panic: %v", e))
			}
			done <- struct{}{}
		})
		if addErr != nil {
			// The job was never queued, so it will never signal done.
			// Put the task back instead of waiting for it forever.
			g.Log().Errorf(ctx, "[执行任务][失败] 提交协程池失败 taskId=%s err=%v", task.TaskID, addErr)
			if rbErr := w.rollbackToPending(ctx, task.Id); rbErr != nil {
				g.Log().Errorf(ctx, "[执行任务][失败] 放回队列失败 taskId=%s err=%v", task.TaskID, rbErr)
			}
			continue
		}
		dispatched++
	}

	// Wait only for the jobs that were actually handed to the pool.
	for i := 0; i < dispatched; i++ {
		<-done
	}

	return &dto.RunWorkRes{
		Claimed: claimed,
	}, nil
}

// RunByTaskID tries to execute one specific task right after it was created.
//
//   - Only the pending task identified by taskID is claimed.
//   - If the claim yields no task (e.g. it is no longer pending), it returns
//     nil without doing anything.
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
	task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
	if err != nil {
		return err
	}
	if task == nil {
		return nil
	}
	w.handleOne(ctx, task, req)
	return nil
}

// handleOne runs the full pipeline for a single claimed task:
// load model config -> acquire concurrency slot -> call model ->
// (optional) validate result -> upload to OSS -> persist success + callbacks.
//
// Every failure is reported through failTask; nothing is returned.
// A nil req is treated as an empty request.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
	if req == nil {
		req = &dto.CreateTaskReq{}
	}
	payload := util.ParseStoredPayload(t.RequestPayload)

	g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)

	// 1) Load the model configuration.
	model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil || model == nil {
		w.failTask(ctx, t, "模型不存在或未启用")
		return
	}
	// A negative retry count would skip the retry loops entirely and leave
	// the upload result nil, so clamp it to "one attempt, no retries".
	maxRetry := model.RetryTimes
	if maxRetry < 0 {
		maxRetry = 0
	}

	// 2) Distributed concurrency control.
	semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
	maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
	acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
	if err != nil {
		w.failTask(ctx, t, err.Error())
		return
	}
	if !acquired {
		g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
		if rbErr := w.rollbackToPending(ctx, t.Id); rbErr != nil {
			g.Log().Errorf(ctx, "[执行任务][失败] 放回队列失败 taskId=%s err=%v", t.TaskID, rbErr)
		}
		return
	}
	defer func() {
		_ = releaseSemaphore(ctx, semKey)
	}()

	// 3) Validate request_payload.
	if payload == nil {
		w.failTask(ctx, t, "request_payload 为空")
		return
	}

	// 4) Call the model (or reuse the cached output of a previous run).
	textResult, err := w.callModel(ctx, t, model, payload)
	if err != nil {
		w.failTask(ctx, t, err.Error())
		return
	}

	// 5) Validate the result, re-invoking the model on failure. This runs
	// before the upload so the file sent to OSS is the validated output.
	if req.BuildType == 1 {
		textResult, err = w.validateWithRetry(ctx, t, model, payload, textResult, maxRetry)
		if err != nil {
			w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
			return
		}
	}

	// 6) Upload to OSS (retryable). The tmp file is kept on failure.
	oss, err := w.uploadWithRetry(ctx, t, maxRetry)
	if err != nil {
		_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
		w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
		return
	}

	// 7) Persist success and fire callbacks.
	t.State = 2
	t.OssFile = oss.FileAddressPrefix + oss.FileURL
	t.FileType = oss.FileFormat
	t.TextResult = textResult
	t.FileSize = int64(oss.FileSize)
	t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))

	if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
		// NOTE(review): on this path the queue slot is not released and no
		// callback is sent — confirm whether that is intended.
		g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
		return
	}

	ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
	// Callbacks outlive this call, so detach them from ctx cancellation.
	go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
	if req.EpicycleId != 0 {
		go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
	}
	g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
		t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
	_ = os.Remove(t.TmpFile)
}

// validateWithRetry validates result against the model's request mapping.
// On failure it re-invokes the model (bypassing the tmp-file cache) and
// validates again, up to maxRetry extra attempts. It returns the validated
// result, or the last validation error once attempts are exhausted.
func (w *asyncWorker) validateWithRetry(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel,
	payload map[string]any, result map[string]any, maxRetry int) (map[string]any, error) {
	var err error
	for attempt := 0; ; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
		}
		err = util.ValidatePromptResult(result, m.RequestMapping)
		if err == nil {
			return result, nil
		}
		g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", t.TaskID, attempt, maxRetry, err)
		if attempt >= maxRetry {
			return nil, err
		}
		// Must really call the model again: callModel would just re-read
		// the cached tmp file and hand back the same invalid output.
		fresh, modelErr := w.invokeModelAndCache(ctx, t, m, payload)
		if modelErr != nil {
			g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
				t.TaskID, attempt, maxRetry, modelErr)
			continue
		}
		result = fresh
	}
}

// uploadWithRetry uploads the task's tmp file to OSS, retrying up to
// maxRetry extra times. It returns the last error when all attempts fail.
func (w *asyncWorker) uploadWithRetry(ctx context.Context, t *entity.AsynchTask, maxRetry int) (*gateway.UploadFileResponse, error) {
	var lastErr error
	for attempt := 0; attempt <= maxRetry; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
		}
		oss, err := w.uploadOSS(ctx, t)
		if err == nil && oss == nil {
			// Guard against a nil response so callers can dereference safely.
			err = errors.New("OSS上传返回为空")
		}
		if err == nil {
			return oss, nil
		}
		lastErr = err
		g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", t.TaskID, attempt, maxRetry, err)
	}
	return nil, lastErr
}

// callModel returns the model output for task t as a parsed map.
//
// If the task already has a readable, non-empty tmp file from an earlier
// run (Phase == 1), that cached output is reused and the model is not
// invoked. Otherwise the model is invoked and its output is cached.
// Non-text / non-JSON output yields the map parsed from an empty string.
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
	if data := readCachedResult(t); data != nil {
		contentType, _ := util.DetectFileType(data)
		return textResultFrom(data, contentType), nil
	}
	return w.invokeModelAndCache(ctx, t, m, payload)
}

// readCachedResult returns the bytes of the task's tmp file when the task is
// in phase 1 and the file is readable and non-empty; otherwise nil.
func readCachedResult(t *entity.AsynchTask) []byte {
	if t.Phase != 1 || strings.TrimSpace(t.TmpFile) == "" {
		return nil
	}
	data, err := os.ReadFile(t.TmpFile)
	if err != nil || len(data) == 0 {
		return nil
	}
	return data
}

// invokeModelAndCache always invokes the model, writes the raw output to a
// tmp file (recorded on the task and in the database), and returns the
// parsed text result.
func (w *asyncWorker) invokeModelAndCache(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
	// Request statistics are best-effort; a failure must not fail the task.
	_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)

	var (
		data []byte
		err  error
	)
	data, err = InvokeModel(ctx, m, payload, t.ModelKey)
	if err != nil {
		return nil, err
	}

	// Detect the type before saving so the tmp file gets the real extension.
	contentType, ext := util.DetectFileType(data)

	tmpPath, err := saveTmpResult(t.TaskID, data, ext)
	if err != nil {
		// The upload step reads the tmp file, so there is no point going on.
		return nil, fmt.Errorf("保存临时文件失败: %w", err)
	}
	previous := t.TmpFile
	t.TmpFile = tmpPath
	t.Phase = 1
	_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
	if strings.TrimSpace(previous) != "" && previous != tmpPath {
		// Drop a superseded tmp file left at a different path.
		_ = os.Remove(previous)
	}

	return textResultFrom(data, contentType), nil
}

// textResultFrom parses data into a map when it is valid UTF-8 and its
// content type is text/* or application/json; any other data is parsed as
// an empty string.
func textResultFrom(data []byte, contentType string) map[string]any {
	var text string
	if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
		text = string(data)
	}
	return gjson.New(text).Map()
}

// uploadOSS reads the task's tmp file and uploads its content to OSS.
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
	data, err := os.ReadFile(t.TmpFile)
	if err != nil {
		return nil, fmt.Errorf("读取临时文件失败: %w", err)
	}
	_, ext := util.DetectFileType(data)
	return gateway.UploadByTask(ctx, data, ext)
}

// failTask is the single failure path: it marks the task failed (State = 3)
// with errMsg, persists it, releases the queue slot and fires the callback.
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
	t.State = 3
	t.ErrorMsg = errMsg
	_ = dao.Task.UpdateFailedGlobal(ctx, t)
	ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
	// The callback outlives this call, so detach it from ctx cancellation.
	go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}

// saveTmpResult writes the model output to <os.TempDir()>/model-asynch/<taskID><ext>
// so that a failed OSS upload can be retried without calling the model again.
// An empty ext becomes ".bin"; a missing leading dot is added.
// It returns the path of the written file.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	dir := filepath.Join(os.TempDir(), "model-asynch")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return "", err
	}
	if ext == "" {
		ext = ".bin"
	}
	if ext[0] != '.' {
		ext = "." + ext
	}
	path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
	if err := os.WriteFile(path, data, 0o644); err != nil {
		return "", err
	}
	return path, nil
}

// rollbackToPending puts the task with the given id back into the pending state.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	return dao.Task.RollbackToPendingGlobal(ctx, id)
}

// GetExpendTokens extracts the consumed-token count from result using the
// path in responseTokenField. It returns 0 when no value is found there.
func GetExpendTokens(responseTokenField string, result map[string]any) int {
	val := gjson.New(result).Get(responseTokenField)
	if val == nil || val.IsNil() {
		return 0
	}
	return val.Int()
}