package service

import (
	"context"
	"fmt"
	"strings"
	"sync"
	"time"

	"model-asynch/dao"
	"model-asynch/model/entity"

	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/grpool"
)

// AsyncWorker is the process-wide background worker that claims pending
// asynchronous tasks from the database and runs them on a bounded grpool.
var AsyncWorker = &asyncWorker{}

// asyncWorker owns one goroutine pool plus the poll loop that feeds it.
// All fields are guarded by mu.
type asyncWorker struct {
	mu   sync.Mutex
	pool *grpool.Pool
	// closed is set when Stop closes the pool and cleared by Start.
	closed bool
	// stop is closed to make the poll loop that was started together with
	// pool return; it is nil whenever no loop is running.
	stop chan struct{}
}

// Start launches the worker. It creates a grpool with
// asynch.worker.goroutines workers (default 4, minimum 1) and starts the
// poll loop that claims tasks into it.
//
// Start does nothing when asynch.worker.enabled is false or when a live pool
// already exists, so repeated calls are harmless. The loop runs until ctx is
// cancelled or Stop is called.
func (w *asyncWorker) Start(ctx context.Context) {
	if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
		g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=false,worker 未启动")
		return
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.pool != nil && !w.pool.IsClosed() {
		return
	}
	limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
	if limit <= 0 {
		limit = 1
	}
	pool := grpool.New(limit)
	stop := make(chan struct{})
	w.pool = pool
	w.stop = stop
	w.closed = false
	// The loop gets its own pool and stop channel, so a loop left over from
	// an earlier Start can never submit to, or close, a newer pool.
	go w.pollLoop(ctx, pool, stop)
	g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
}

// Stop tells the poll loop to exit and closes the goroutine pool, so that
// Ctrl+C can shut the process down completely. It is safe to call more than
// once.
//
// NOTE(review): this file does not show whether grpool still runs jobs that
// are queued when Close is called. If it does not, tasks already claimed by
// the poll loop but not yet started stay in their claimed state — confirm.
func (w *asyncWorker) Stop(ctx context.Context) {
	w.mu.Lock()
	defer w.mu.Unlock()
	w.stopLocked(ctx)
}

// stopLocked implements Stop; the caller must hold w.mu.
func (w *asyncWorker) stopLocked(ctx context.Context) {
	if w.stop != nil {
		close(w.stop)
		w.stop = nil
	}
	if w.pool != nil && !w.pool.IsClosed() {
		w.pool.Close()
		w.closed = true
		g.Log().Infof(ctx, "[worker] stopped")
	}
}

// stopIfCurrent stops the worker only if pool is still the active pool. This
// way, a loop whose context is cancelled after a Stop/Start cycle cannot tear
// down the newer pool.
func (w *asyncWorker) stopIfCurrent(ctx context.Context, pool *grpool.Pool) {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.pool == pool {
		w.stopLocked(ctx)
	}
}

// pollLoop claims up to asynch.worker.batchSize pending tasks every
// asynch.worker.pollInterval and submits each one to pool.
//
// It returns when stop is closed (Stop was called). It also returns when ctx
// is cancelled, in which case it first shuts the worker down if pool is still
// the active pool.
func (w *asyncWorker) pollLoop(ctx context.Context, pool *grpool.Pool, stop <-chan struct{}) {
	pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
	pollInterval, err := time.ParseDuration(pollIntervalStr)
	if err != nil {
		// ParseDuration returns 0 on error, so the fallback below applies.
		g.Log().Warningf(ctx, "[worker] invalid asynch.worker.pollInterval %q, falling back to 1s: %v", pollIntervalStr, err)
	}
	if pollInterval <= 0 {
		pollInterval = time.Second
	}
	batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
	if batchSize <= 0 {
		batchSize = 1
	}
	g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	for {
		select {
		case <-stop:
			return
		case <-ctx.Done():
			w.stopIfCurrent(ctx, pool)
			return
		case <-ticker.C:
			w.pollOnce(ctx, pool, batchSize)
		}
	}
}

// pollOnce claims one batch of pending tasks and submits each one to pool.
// A task that cannot be submitted is rolled back to pending, so it is never
// left claimed with nobody running it.
func (w *asyncWorker) pollOnce(ctx context.Context, pool *grpool.Pool, batchSize int) {
	if pool.IsClosed() {
		// Stop raced with this tick; do not claim tasks we could not run.
		return
	}
	tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
	if err != nil {
		g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
		return
	}
	for _, t := range tasks {
		task := t // per-iteration copy so the closures below don't share the loop variable
		if pool.IsClosed() {
			// The pool was closed while this batch was being submitted:
			// hand the task back so a later poll can claim it.
			w.rollbackLogged(ctx, task.Id)
			continue
		}
		submitErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
			w.handleOne(ctx, task)
		}, func(ctx context.Context, exception error) {
			if exception != nil {
				w.markFailed(ctx, task.Id, fmt.Sprintf("worker panic: %v", exception))
			}
		})
		if submitErr != nil {
			// e.g. the pool was closed right after the check above.
			g.Log().Warningf(ctx, "[worker] submit task %d failed, rolling back: %v", task.Id, submitErr)
			w.rollbackLogged(ctx, task.Id)
		}
	}
}

// handleOne runs a single claimed task end to end:
//  1. load the model configuration;
//  2. take a per-(tenant, model) distributed concurrency slot;
//  3. invoke the model;
//  4. upload the result to object storage;
//  5. mark the task as succeeded.
//
// Any failure along the way marks the task as failed. If no concurrency slot
// is free, the task is rolled back to pending instead.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
	// Restore the payload and headers stored with the task (request_payload),
	// so the OSS upload can pass the original auth headers through.
	payload, headers := parseStoredPayload(t.RequestPayload)
	if len(headers) > 0 {
		ctx = setTaskHeadersToCtx(ctx, headers)
	}

	// 1) Load the model configuration.
	m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil {
		w.markFailed(ctx, t.Id, err.Error())
		return
	}
	if m == nil || m.Enabled != 1 {
		w.markFailed(ctx, t.Id, "模型不存在或未启用")
		return
	}

	// 2) Distributed concurrency limit, keyed by tenant+model.
	semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
	leaseSeconds := int64(3600) // fallback lease of one hour
	acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
	if err != nil {
		w.markFailed(ctx, t.Id, err.Error())
		return
	}
	if !acquired {
		// All slots are busy: put the task back in the queue (state=0) so a
		// later poll claims it again.
		w.rollbackLogged(ctx, t.Id)
		return
	}
	defer func() {
		if err := releaseSemaphore(ctx, semKey); err != nil {
			g.Log().Errorf(ctx, "[worker] release semaphore %s error: %v", semKey, err)
		}
	}()

	// 3) Invoke the model service.
	if payload == nil {
		payload = map[string]any{
			"taskId":   t.TaskID,
			"inputRef": t.InputRef,
		}
	}
	data, err := InvokeModel(ctx, m, payload)
	if err != nil {
		w.markFailed(ctx, t.Id, err.Error())
		return
	}

	// 4) Store the result in OSS/MinIO.
	contentType, ext := DetectFileType(data)
	ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
	if err != nil {
		w.markFailed(ctx, t.Id, err.Error())
		return
	}

	// 5) Mark the task as succeeded.
	// expire_at now starts counting only after the result has been downloaded
	// (state=4), so the success transition (state=2) does not write expire_at.
	fileType := strings.TrimPrefix(ext, ".")
	if fileType == "" {
		fileType = contentType
	}
	if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
		g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
	}
}

// markFailed records reason as the failure message of task id. Any error
// from the status update is logged instead of being silently dropped.
func (w *asyncWorker) markFailed(ctx context.Context, id int64, reason string) {
	if err := dao.Task.UpdateFailedGlobal(ctx, id, reason); err != nil {
		g.Log().Errorf(ctx, "[worker] mark task %d failed (%s) error: %v", id, reason, err)
	}
}

// rollbackLogged returns task id to pending and logs (rather than drops) any
// error from the rollback.
func (w *asyncWorker) rollbackLogged(ctx context.Context, id int64) {
	if err := w.rollbackToPending(ctx, id); err != nil {
		g.Log().Errorf(ctx, "[worker] rollback task %d to pending error: %v", id, err)
	}
}

// rollbackToPending puts task id back into the pending queue via
// dao.Task.RollbackToPendingGlobal and returns its error.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	return dao.Task.RollbackToPendingGlobal(ctx, id)
}