package dao import ( "context" "fmt" "time" "model-gateway/consts/public" "model-gateway/model/entity" "gitea.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) var Task = &taskDao{} type taskDao struct{} func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) { r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert() if err != nil { return 0, err } return r.LastInsertId() } func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) { r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.TaskID, taskID). One() if err != nil { return nil, err } if r.IsEmpty() { return nil, nil } err = r.Struct(&t) return } // ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据) func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) { if len(taskIDs) == 0 { return nil, nil } r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). WhereIn(entity.AsynchTaskCol.TaskID, taskIDs). All() if err != nil { return nil, err } err = r.Structs(&list) return } // MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间 func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error { data := gdb.Map{ entity.AsynchTaskCol.State: 4, entity.AsynchTaskCol.ExpireAt: expireAt, entity.AsynchTaskCol.Updater: "", } _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Where(entity.AsynchTaskCol.State, 2). Data(data). Update() return err } func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error { now := gtime.Now() data := gdb.Map{ entity.AsynchTaskCol.State: 1, entity.AsynchTaskCol.StartedAt: now, entity.AsynchTaskCol.Updater: "", } _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Data(data). Update() return err } func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error { now := gtime.Now() data := gdb.Map{ entity.AsynchTaskCol.State: 2, entity.AsynchTaskCol.OssFile: ossFile, entity.AsynchTaskCol.FileType: fileType, entity.AsynchTaskCol.FileSize: fileSize, entity.AsynchTaskCol.ErrorMsg: "", entity.AsynchTaskCol.FinishedAt: now, entity.AsynchTaskCol.ExpireAt: expireAt, entity.AsynchTaskCol.Updater: "", } _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Data(data). Update() return err } func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error { now := gtime.Now() data := gdb.Map{ entity.AsynchTaskCol.State: 3, entity.AsynchTaskCol.ErrorMsg: errorMsg, entity.AsynchTaskCol.FinishedAt: now, entity.AsynchTaskCol.Updater: "", } _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Data(data). Update() return err } func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) { r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.TaskID, taskID). Delete() if err != nil { return 0, err } return r.RowsAffected() } // CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值) func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) { n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.ModelName, modelName). WhereIn(entity.AsynchTaskCol.State, []int{0, 1}). Count() return int64(n), err } // List 任务分页查询(受 gfdb 租户 Hook 影响) func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) { m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL") if modelNameLike != "" { m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%") } if taskIDLike != "" { m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%") } if state != nil { m = m.Where(entity.AsynchTaskCol.State, *state) } m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt) if pageNum > 0 && pageSize > 0 { m = m.Page(pageNum, pageSize) } r, totalInt, err := m.AllAndCount(false) if err != nil { return nil, 0, err } total = gconv.Int64(totalInt) err = r.Structs(&list) return } // ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1) // 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费 func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) { if batchSize <= 0 { batchSize = 1 } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { sql := fmt.Sprintf( `SELECT id, tenant_id, model_name, task_id, input_ref, request_payload FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY created_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, public.TableNameTask, batchSize, ) r, err := tx.GetAll(sql) if err != nil { return err } if r.IsEmpty() { tasks = nil return nil } if err := r.Structs(&tasks); err != nil { return err } // 更新为 running now := time.Now() for _, t := range tasks { // tx.Model 不走 gfdb Hook,这里手动更新必要字段 _, err = tx.Exec( fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id, ) if err != nil { return err } } return nil }) return } // ListExpiredSuccess 获取已成功且过期的任务 func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) { if limit <= 0 { limit = 100 } r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.State, 2). Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL"). Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()). Limit(limit). All() if err != nil { return nil, err } err = r.Structs(&list) return } // ListTimeoutTasks 获取超时的排队/执行中任务 func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) { if limit <= 0 { limit = 100 } deadline := gtime.New(time.Now().Add(-timeout)) r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). WhereIn(entity.AsynchTaskCol.State, []int{0, 1}). Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline). Limit(limit). All() if err != nil { return nil, err } err = r.Structs(&list) return } // DebugPing 用于启动时检测数据库连通性(可选) func (d *taskDao) DebugPing(ctx context.Context) error { _, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1") return err }