package task import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "model-gateway/common/util" "model-gateway/model/dto" "model-gateway/service/gateway" "model-gateway/service/queue" "net/http" "os" "strings" "sync" "time" "unicode/utf8" "model-gateway/dao" "model-gateway/model/entity" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/grpool" ) var AsyncWorker = &asyncWorker{} type asyncWorker struct { } // RunOnce 由上层定时任务触发:一次性抢占并处理一批任务 // - batchSize: 本次抢占数量 // - goroutines: 本次并发数(协程池大小) func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) { if req.BatchSize <= 0 { req.BatchSize = 10 } if req.Goroutines <= 0 { req.Goroutines = 1 } tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize) if err != nil { return nil, err } if len(tasks) == 0 { return nil, errors.New("no task to run") } pool := grpool.New(req.Goroutines) defer pool.Close() claimed := len(tasks) done := make(chan struct{}, claimed) for _, t := range tasks { task := t _ = pool.AddWithRecover(ctx, func(ctx context.Context) { //w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0}) done <- struct{}{} }, func(ctx context.Context, e error) { if e != nil { task.ErrorMsg = fmt.Sprintf("worker panic: %v", e) _ = dao.Task.UpdateFailedGlobal(ctx, task) queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) } done <- struct{}{} }) } for i := 0; i < claimed; i++ { <-done } return &dto.RunWorkRes{ Claimed: claimed, }, nil } // handleOne 执行一次完整的任务 func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) { body := util.GetModelBody(task.RequestPayload) //核心请求参数 maxRetry := model.RetryTimes //重试次数 g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) // 1) 分布式并发控制 semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName) maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency) acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) if err != nil { w.failTask(ctx, task, err.Error()) return } if !acquired { g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) _ = w.rollbackToPending(ctx, task.Id) return } defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() // 2) request_payload 校验 if body == nil { w.failTask(ctx, task, "请求模型为空") return } // 3) 调用模型 switch { case model.IsStream != nil && *model.IsStream == 1: // 流式调用 rawBytes, err := w.callModelStream(ctx, task, model, body) if err != nil { w.failTask(ctx, task, err.Error()) return } // 解析流式结果 body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) if err != nil { w.failTask(ctx, task, err.Error()) return } case model.IsAsync != nil && *model.IsAsync == 1: // 异步调用:注入回调地址后提交,拿到 task_id 轮询 // 异步调用:提交任务 body, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, err.Error()) return } // 拿到 task_id,启动轮询 taskID := gjson.New(body).Get(model.ResponseBody).String() body, err = util.PullTaskResult(ctx, taskID, model.QueryConfig) if err != nil { w.failTask(ctx, task, err.Error()) return } default: // 同步调用 body, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, err.Error()) return } } // 5) 解析响应映射 body, err = util.MapResponsePayload(model.ResponseMapping, body) if err != nil { w.failTask(ctx, task, err.Error()) return } // 5) 保存临时文件(通用工具方法) tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, body, task.TmpFile) if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } // 6) 上传 OSS(可重试) var oss *gateway.UploadFileResponse for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } oss, err = w.uploadOSS(ctx, task) if err == nil { break } g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error()) w.failTask(ctx, task, fmt.Sprintf("OSS上传重试耗尽: %v", err)) return } } // 7) 解析校验(可重试,失败重新调模型) if req.BuildType == 1 { for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) } // 6.1) 校验数据 err = util.ValidatePromptResult(body, model) if err == nil { break } g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) if attempt == maxRetry { w.failTask(ctx, task, fmt.Sprintf("JSON解析重试耗尽: %v", err)) return } // 6.2) 重新调模型 newResult, modelErr := w.callModel(ctx, task, model, body) if modelErr != nil { g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, modelErr) continue } body = newResult } } // 8) 成功回调 task.State = 2 task.OssFile = oss.FileAddressPrefix + oss.FileURL task.FileType = oss.FileFormat task.TextResult = body task.FileSize = int64(oss.FileSize) task.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, body)) if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) return } queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), task) if req.EpicycleId != 0 { go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) } g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s", task.TaskID, oss.FileFormat, len(body), task.CallbackURL) // 9) 删除临时文件 _ = os.Remove(task.TmpFile) } // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) { var data []byte var err error if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { data = nil } } if data == nil { _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body, task.ModelKey) if err != nil { return nil, err } tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "") if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } return data, nil } // asyncResult 异步任务结果 type asyncResult struct { result map[string]any err error } // asyncTaskChan 全局异步任务等待通道 var asyncTaskChan = sync.Map{} // taskID → chan asyncResult func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { // 1. 提交异步任务 result, err := w.callModel(ctx, task, model, body) if err != nil { return nil, err } // 2. 拿到 task_id taskID := gjson.New(result).Get(model.ResponseBody).String() // 3. 创建等待通道 ch := make(chan asyncResult, 1) asyncTaskChan.Store(taskID, ch) defer func() { asyncTaskChan.Delete(taskID) close(ch) }() // 4. 阻塞等待回调或超时 timeout := time.Duration(model.TimeoutSeconds) * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout) select { case res, ok := <-ch: if !ok { return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID) } g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID) return res.result, res.err case <-ctx.Done(): return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID) } } // NotifyAsyncResult 回调接口调用此方法通知结果 func NotifyAsyncResult(taskID string, result map[string]any, err error) { if ch, ok := asyncTaskChan.Load(taskID); ok { ch.(chan asyncResult) <- asyncResult{result: result, err: err} } } // 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试) // callModel 调用模型 + 检测文件类型 + 保存临时文件 func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { var data []byte var contentType, ext, textResult string var err error if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { data = nil } } if data == nil { _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body, task.ModelKey) if err != nil { return nil, err } tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext) if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } contentType, ext = util.DetectFileType(data) if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } return gjson.New(textResult).Map(), nil } // InvokeModel 调用模型服务,返回二进制结果 // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) { // 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 //mappedPayload := util.ReverseMap(model.RequestMapping, payload) // 2)构建请求 URL 和超时 baseURL := strings.TrimRight(model.BaseURL, "/") timeout := time.Duration(model.TimeoutSeconds) * time.Second client := &http.Client{Timeout: timeout} method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) // 3)构建 HTTP 请求 var req *http.Request switch method { case http.MethodGet: q, err := util.BodyToQuery(body) if err != nil { return nil, err } if len(q) > 0 { if strings.Contains(baseURL, "?") { baseURL = baseURL + "&" + q.Encode() } else { baseURL = baseURL + "?" + q.Encode() } } req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) default: bodyBytes, err := json.Marshal(body) if err != nil { return nil, err } req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) } // 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { req.Header.Set(hk, hv) } if modelKey != "" { req.Header.Set("Authorization", "Bearer "+modelKey) } if method != http.MethodGet { req.Header.Set("Content-Type", "application/json") } // 5)发送请求 resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() // 6)读取响应体 b, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // 7)检查 HTTP 状态码 if resp.StatusCode < 200 || resp.StatusCode >= 300 { msg := string(b) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) } return b, nil } // // InvokeModel 调用模型服务,返回二进制结果 // // func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { // if m == nil || m.BaseURL == "" { // return nil, fmt.Errorf("模型配置不完整") // } // // 请求参数映射 // mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) // if err != nil { // return nil, fmt.Errorf("请求参数映射失败: %w", err) // } // // 合并请求头 // headers := util.ForwardHeaders(ctx) // for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { // headers[hk] = hv // } // for hk, hv := range parseHeadMsgHeaders(modelKey) { // headers[hk] = hv // } // // // 设置超时 // timeout := time.Duration(m.TimeoutSeconds) * time.Second // if timeout <= 0 { // timeout = 600 * time.Second // } // ctx, cancel := context.WithTimeout(ctx, timeout) // defer cancel() // // invokeUrl := strings.TrimRight(m.BaseURL, "/") // method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) // if method == "" { // method = http.MethodPost // } // // var respBytes []byte // // switch method { // case http.MethodGet: // err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) // default: // err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) // } // if err != nil { // return nil, err // } // // 响应参数映射 // mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) // if err != nil { // g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) // return respBytes, nil // } // return mappedResponse, nil // } // uploadOSS 从临时文件上传 OSS func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { data, err := os.ReadFile(t.TmpFile) if err != nil { return nil, fmt.Errorf("读取临时文件失败: %w", err) } _, ext := util.DetectFileType(data) return gateway.UploadByTask(ctx, data, ext) } // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) { t.State = 3 t.ErrorMsg = errMsg _ = dao.Task.UpdateFailedGlobal(ctx, t) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) } // rollbackToPending 恢复任务状态为 PENDING func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { return dao.Task.RollbackToPendingGlobal(ctx, id) } // GetExpendTokens 根据映射路径从 result 中提取消耗 token 值 func GetExpendTokens(responseTokenField string, result map[string]any) int { val := gjson.New(result).Get(responseTokenField) if val.IsNil() { return 0 } return val.Int() }