package service

import (
	"context"
	"errors"
	"time"

	"model-gateway/dao"
	"model-gateway/model/dto"
	"model-gateway/model/entity"

	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gtime"
	"github.com/google/uuid"
)

// Task is the shared asynchronous-task service instance.
var Task = &taskService{}

// taskService creates, queries and lists asynchronous model tasks.
type taskService struct{}

// Create registers a new asynchronous task for req.ModelName.
//
// Steps: validate that the model exists and is enabled, reserve a queue slot when
// the model has a positive runtime queue limit, insert the task in the pending
// state (0), write a best-effort operation log, and finally start a background
// goroutine that tries to run the new task right away.
//
// It returns the generated task ID. An error is returned when req is nil, the
// model lookup fails, the model is missing or disabled, the queue is full, the
// slot reservation fails, or the task cannot be inserted.
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
	if req == nil {
		return nil, errors.New("请求参数不能为空")
	}
	// Debug level only: the request may carry a large payload.
	g.Log().Debugf(ctx, "[task-create][request] %+v", req)
	startAt := time.Now()

	// Pin token/user and similar request information into ctx via asyncCtx
	// (presumably so it stays available to the background goroutine below).
	ctx = asyncCtx(ctx)

	// 1) Check the model configuration.
	m, err := dao.Model.GetByModelName(ctx, req.ModelName)
	if err != nil {
		return nil, err
	}
	if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
		return nil, errors.New("模型不存在或未启用")
	}

	taskID := uuid.NewString()

	// 2) Enforce the queue limit through the atomic gate (Redis-backed, per the
	// original design). Remember whether a slot was really taken so a failed insert
	// only rolls back a reservation that exists.
	slotAcquired := false
	if limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit); limit > 0 {
		ok, acqErr := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
		if acqErr != nil {
			return nil, acqErr
		}
		if !ok {
			return nil, errors.New("任务排队已满,请稍后再试")
		}
		slotAcquired = true
	}

	// Persist the model-call payload together with the forwarded headers in
	// request_payload so the background worker can issue the call later.
	// NOTE(review): this map is also written to the operation log; confirm that
	// forwardHeaders excludes credentials that must not be stored.
	storedPayload := map[string]any{
		"payload": req.RequestPayload,
		"headers": forwardHeaders(ctx),
	}

	t := &entity.AsynchTask{
		ModelName:      req.ModelName,
		TaskID:         taskID,
		State:          0,
		BizName:        req.BizName,
		CallbackURL:    req.CallbackUrl,
		ModelKey:       m.ApiKey,
		InputRef:       req.InputRef,
		RequestPayload: storedPayload,
		EpicycleId:     req.EpicycleId,
	}
	if _, err = dao.Task.Insert(ctx, t); err != nil {
		// Insert failed: give back the gate reservation, but only if we hold one.
		if slotAcquired {
			ReleaseQueueSlot(ctx, req.ModelName, taskID)
		}
		return nil, err
	}

	// 3) Operation log: best effort, must not affect the main flow.
	s.recordCreateOpLog(ctx, req, taskID, storedPayload, startAt)

	// 4) Immediately try to run the new task in the background. Polling continues
	// only while the task is still pending (state=0) and stops once it is
	// running/success/failed/downloaded or the max wait elapses. WithoutCancel keeps
	// the goroutine alive after the HTTP request finishes.
	go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)

	return &dto.CreateTaskRes{TaskID: taskID}, nil
}

// recordCreateOpLog writes the "createTask" operation-log entry for taskID.
//
// Client IP, user agent, path and method are read from the HTTP request carried
// by ctx when there is one; otherwise "/task/createTask" and "POST" are recorded.
// The write is best effort: a failure is logged and never returned.
func (s *taskService) recordCreateOpLog(ctx context.Context, req *dto.CreateTaskReq, taskID string, storedPayload map[string]any, startAt time.Time) {
	ip, ua := "", ""
	apiPath, httpMethod := "/task/createTask", "POST"
	if r := g.RequestFromCtx(ctx); r != nil {
		ip = r.GetClientIp()
		ua = r.UserAgent()
		apiPath = r.URL.Path
		httpMethod = r.Method
	}

	_, err := dao.OpLog.Insert(ctx, &entity.LogsModelOp{
		IP:             ip,
		UserAgent:      ua,
		APIPath:        apiPath,
		HttpMethod:     httpMethod,
		BizName:        req.BizName,
		ModelName:      req.ModelName,
		TaskID:         taskID,
		OpType:         "createTask",
		Success:        1,
		ErrorMsg:       "",
		CostMs:         time.Since(startAt).Milliseconds(),
		RequestPayload: storedPayload,
		ResponsePayload: gdb.Map{
			"taskId": taskID,
		},
	})
	if err != nil {
		g.Log().Warningf(ctx, "[task-create][oplog] taskId=%s err=%v", taskID, err)
	}
}

// pollAndRunUntilPicked is the lightweight, targeted polling started by Create:
//   - goal: get the freshly created task running as soon as possible;
//   - it keeps trying only while the task is pending (state=0);
//   - it stops as soon as the task is running(1) / success(2) / failed(3) /
//     downloaded(4), on an unknown state, or when the task cannot be read;
//   - it also stops after asynch.worker.autoRunMaxSeconds (default 600s), so a
//     task that stays pending cannot keep this goroutine alive forever. runWork
//     remains responsible for the backlog and for tasks not handled here.
//
// The poll interval comes from asynch.worker.intervalSeconds (default 5s).
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
	const (
		defaultIntervalSeconds = 5
		defaultMaxWaitSeconds  = 600
	)

	if taskID == "" {
		return
	}

	interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
	if interval <= 0 {
		interval = defaultIntervalSeconds
	}
	maxWait := g.Cfg().MustGet(ctx, "asynch.worker.autoRunMaxSeconds").Int()
	if maxWait <= 0 {
		maxWait = defaultMaxWaitSeconds
	}

	g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds maxWait=%ds", taskID, interval, maxWait)

	ticker := time.NewTicker(time.Duration(interval) * time.Second)
	defer ticker.Stop()
	// A separate timer rather than a ctx deadline: RunByTaskID receives ctx, and a
	// deadline on it could cancel work that has already been started.
	deadline := time.NewTimer(time.Duration(maxWait) * time.Second)
	defer deadline.Stop()

	// tryRun reports whether polling should stop.
	tryRun := func() bool {
		t, err := dao.Task.GetByTaskID(ctx, taskID)
		if err != nil {
			g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
			return true
		}
		if t == nil {
			g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
			return true
		}

		switch t.State {
		case 0:
			if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
				g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
			} else {
				g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
			}
			return false
		case 1:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
			return true
		case 2, 3, 4:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
			return true
		default:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
			return true
		}
	}

	// First attempt right away, without waiting for the first tick.
	if tryRun() {
		return
	}

	for {
		select {
		case <-ctx.Done():
			// Never fires for the context.WithoutCancel ctx passed by Create (its Done
			// channel is nil); kept for callers that pass a cancellable context.
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
			return
		case <-deadline.C:
			g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=max_wait_exceeded maxWait=%ds", taskID, maxWait)
			return
		case <-ticker.C:
			if tryRun() {
				return
			}
		}
	}
}

// GetResult returns the OSS file and state of the task identified by taskID.
// It returns an error when the lookup fails or the task does not exist.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	t, err := dao.Task.GetByTaskID(ctx, taskID)
	if err != nil {
		return nil, err
	}
	if t == nil {
		return nil, errors.New("任务不存在")
	}
	return &dto.GetTaskResultRes{
		OssFile: t.OssFile,
		State:   t.State,
	}, nil
}

// GetBatch looks up several tasks at once.
//
// Every task in the success state (2) is marked as downloaded (4). Its expire_at
// is set to now plus the model's AutoCleanSeconds, or 86400s when the model has no
// positive value. If marking a task fails, the error is logged and that task is
// reported with its stored state, so the response never claims an update that did
// not happen. A nil request or an empty ID list yields an empty (non-nil) list.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	const defaultRetainSeconds = 86400

	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	// 1) Load the tasks. The original notes this is scoped to the current tenant;
	// presumably enforced inside dao.Task.ListByTaskIDs — confirm there.
	list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	// 2) Mark successful tasks as downloaded. Retention is resolved once per model
	// name instead of once per task.
	now := time.Now()
	retainByModel := make(map[string]int)
	for _, t := range list {
		if t == nil || t.State != 2 {
			continue
		}

		retainSeconds, cached := retainByModel[t.ModelName]
		if !cached {
			m, lookupErr := dao.Model.GetByModelName(ctx, t.ModelName)
			if lookupErr != nil {
				return nil, lookupErr
			}
			retainSeconds = defaultRetainSeconds
			if m != nil && m.AutoCleanSeconds > 0 {
				retainSeconds = m.AutoCleanSeconds
			}
			retainByModel[t.ModelName] = retainSeconds
		}

		expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
		if markErr := dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt); markErr != nil {
			// Best effort, but not silent: keep the stored state in the response.
			g.Log().Warningf(ctx, "[task-batch][mark-downloaded] taskId=%s err=%v", t.TaskID, markErr)
			continue
		}
		// Mirror the successful DB update so this response matches storage.
		t.State = 4
		t.ExpireAt = expireAt
	}

	// 3) Build the response.
	items := make([]dto.GetTaskBatchItem, 0, len(list))
	for _, t := range list {
		if t == nil {
			continue
		}
		items = append(items, dto.GetTaskBatchItem{
			TaskID:  t.TaskID,
			State:   t.State,
			OssFile: t.OssFile,
		})
	}
	return &dto.GetTaskBatchRes{List: items}, nil
}

// List returns one page of tasks, optionally filtered by model name, task ID and
// state. Page number defaults to 1 and page size to 10 when req is nil or the
// values are not positive.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
	pageNum, pageSize := 1, 10
	var (
		modelName, taskID string
		state             *int
	)
	if req != nil {
		if req.PageNum > 0 {
			pageNum = req.PageNum
		}
		if req.PageSize > 0 {
			pageSize = req.PageSize
		}
		modelName, taskID, state = req.ModelName, req.TaskID, req.State
	}

	list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
	if err != nil {
		return nil, err
	}
	return &dto.ListTaskRes{List: list, Total: total}, nil
}