package service

import (
	"context"
	"errors"
	"time"

	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gtime"
	"github.com/google/uuid"

	"model-asynch/dao"
	"model-asynch/model/dto"
	"model-asynch/model/entity"
)

// Task is the package-level instance of the asynchronous task service.
var Task = &taskService{}

// taskService creates asynchronous model tasks and queries their results.
type taskService struct{}

// Create validates the requested model, enforces the model's queue limit,
// persists a new task in state 0 and returns its generated task ID.
//
// The task's request_payload stores both the caller's payload and the headers
// returned by forwardHeaders, for use by the background worker. An
// operation-log row is then written on a best-effort basis; a failure to write
// it does not fail the request.
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
	startAt := time.Now()

	// Pin token/user information into the context (see asyncCtx).
	ctx = asyncCtx(ctx)

	// 1) The model must exist and be enabled.
	m, err := dao.Model.GetByModelName(ctx, req.ModelName)
	if err != nil {
		return nil, err
	}
	if m == nil || m.Enabled != 1 {
		return nil, errors.New("模型不存在或未启用")
	}

	// 2) Queue limit. This is only approximate: the count and the insert below
	// are separate calls, so concurrent Create calls can overshoot the limit.
	if m.QueueLimit > 0 {
		cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
		if err != nil {
			return nil, err
		}
		if cnt >= int64(m.QueueLimit) {
			return nil, errors.New("任务排队已满,请稍后再试")
		}
	}

	taskID := uuid.NewString()

	// Store the model-call payload together with the pass-through headers in
	// request_payload so the background worker can use both.
	storedPayload := map[string]any{
		"payload": req.RequestPayload,
		"headers": forwardHeaders(ctx),
	}

	t := &entity.AsynchTask{
		ModelName:      req.ModelName,
		TaskID:         taskID,
		State:          0,
		InputRef:       req.InputRef,
		RequestPayload: storedPayload,
	}
	if _, err = dao.Task.Insert(ctx, t); err != nil {
		return nil, err
	}

	// 3) Operation log. Best-effort so it never affects the main flow; the
	// insert error is deliberately ignored.
	ip := ""
	ua := ""
	apiPath := "/task/createTask"
	httpMethod := "POST"
	if r := g.RequestFromCtx(ctx); r != nil {
		ip = r.GetClientIp()
		ua = r.UserAgent()
		apiPath = r.URL.Path
		httpMethod = r.Method
	}
	_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
		IP:              ip,
		UserAgent:       ua,
		APIPath:         apiPath,
		HttpMethod:      httpMethod,
		BizName:         req.BizName,
		ModelName:       req.ModelName,
		TaskID:          taskID,
		OpType:          "createTask",
		Success:         1,
		ErrorMsg:        "",
		CostMs:          time.Since(startAt).Milliseconds(),
		RequestPayload:  storedPayload,
		ResponsePayload: gdb.Map{"taskId": taskID},
	})

	return &dto.CreateTaskRes{TaskID: taskID}, nil
}

// GetResult returns the state and OSS file of a single task. It returns an
// error when the lookup fails or when no task with taskID exists.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	t, err := dao.Task.GetByTaskID(ctx, taskID)
	if err != nil {
		return nil, err
	}
	if t == nil {
		return nil, errors.New("任务不存在")
	}
	return &dto.GetTaskResultRes{
		OssFile: t.OssFile,
		State:   t.State,
	}, nil
}

// GetBatch looks up several tasks at once and returns their state and OSS
// file.
//
// Tasks found in state 2 (succeeded) are marked as downloaded (state 4). Their
// expire_at is set to now plus the model's AutoCleanSeconds, or 86400 seconds
// when the model is missing or has no positive value. If marking a task fails,
// the failure is logged and the task is reported with its unchanged state, so
// the response matches what is stored. A nil request or an empty ID list
// yields an empty, non-nil list.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	// 1) Load the tasks. NOTE(review): the original comment says this is scoped
	// to the current tenant; presumably ListByTaskIDs applies that filter — confirm.
	list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	// 2) Mark succeeded (state=2) tasks as downloaded (state=4) and set expire_at.
	now := time.Now()
	// Retention is cached per model so each distinct model is looked up once
	// per call instead of once per task.
	retainByModel := make(map[string]int)
	for _, t := range list {
		if t == nil || t.State != 2 {
			continue
		}

		retainSeconds, cached := retainByModel[t.ModelName]
		if !cached {
			m, err := dao.Model.GetByModelName(ctx, t.ModelName)
			if err != nil {
				return nil, err
			}
			retainSeconds = 86400
			if m != nil && m.AutoCleanSeconds > 0 {
				retainSeconds = m.AutoCleanSeconds
			}
			retainByModel[t.ModelName] = retainSeconds
		}

		expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
		if err := dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt); err != nil {
			// Best-effort: keep the task in state 2 so the response agrees with
			// the database and a later GetBatch call can retry the update.
			g.Log().Warningf(ctx, "mark task %v as downloaded failed: %v", t.TaskID, err)
			continue
		}
		// Mirror the successful update in memory so this response reflects it.
		t.State = 4
		t.ExpireAt = expireAt
	}

	// 3) Build the response.
	items := make([]dto.GetTaskBatchItem, 0, len(list))
	for _, t := range list {
		if t == nil {
			continue
		}
		items = append(items, dto.GetTaskBatchItem{
			TaskID:  t.TaskID,
			State:   t.State,
			OssFile: t.OssFile,
		})
	}
	return &dto.GetTaskBatchRes{List: items}, nil
}

// List returns one page of tasks together with the total count.
//
// PageNum defaults to 1 and PageSize to 10 when the request, its Page, or the
// individual value is absent or non-positive. Filters (model name, task ID,
// state) are passed straight to dao.Task.List. A nil request passes empty
// strings and a nil state; presumably dao treats these as "no filter" — confirm.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
	pageNum, pageSize := 1, 10
	if req != nil && req.Page != nil {
		if req.Page.PageNum > 0 {
			pageNum = int(req.Page.PageNum)
		}
		if req.Page.PageSize > 0 {
			pageSize = int(req.Page.PageSize)
		}
	}

	modelName := ""
	taskID := ""
	var state *int
	if req != nil {
		modelName = req.ModelName
		taskID = req.TaskID
		state = req.State
	}

	list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
	if err != nil {
		return nil, err
	}
	return &dto.ListTaskRes{List: list, Total: total}, nil
}