feat: 重构异步模型字段并更新依赖
This commit is contained in:
@@ -13,28 +13,53 @@ import (
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
// ClaimPendingGlobal 后台任务使用:全局抢占 pending 任务(不加 tenant 过滤)
|
||||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||||
// ======================== 查询辅助 ========================
|
||||
|
||||
// taskColumns 查询用的公共字段
|
||||
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
|
||||
|
||||
// ======================== 事务抢占 ========================
|
||||
|
||||
// claimTasks 事务内抢占任务并更新 state=1
|
||||
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
|
||||
r, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var task entity.AsynchTask
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tasks = []*entity.AsynchTask{&task}
|
||||
return nil
|
||||
})
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingGlobal 批量抢占 pending 任务
|
||||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0
|
||||
ORDER BY enqueue_at ASC
|
||||
LIMIT %d
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
batchSize,
|
||||
)
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
|
||||
r, err := tx.GetAll(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
tasks = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Structs(&tasks); err != nil {
|
||||
@@ -42,234 +67,148 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks
|
||||
}
|
||||
now := time.Now()
|
||||
for _, t := range tasks {
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, t.Id,
|
||||
)
|
||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 定向抢占单个 pending 任务(不加 tenant 过滤)
|
||||
// 用于 createTask 创建成功后立即异步尝试执行当前任务,避免只依赖后续 runWork 扫描队列。
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
|
||||
if taskID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
)
|
||||
r, err := tx.GetOne(sql, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
task = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, task.Id,
|
||||
)
|
||||
return err
|
||||
})
|
||||
return
|
||||
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
|
||||
if err != nil || len(tasks) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return tasks[0], nil
|
||||
}
|
||||
|
||||
// ======================== 更新辅助 ========================
|
||||
|
||||
func execSQL(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// updateTask 通用更新
|
||||
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.Id, id).Data(data).Update()
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSuccessGlobal 更新任务成功
|
||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.Id, t.Id).
|
||||
Data(entity.AsynchTask{
|
||||
State: 2,
|
||||
OssFile: t.OssFile,
|
||||
FileType: t.FileType,
|
||||
TextResult: t.TextResult,
|
||||
FileSize: t.FileSize,
|
||||
ErrorMsg: "",
|
||||
FinishedAt: now,
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
ExpendTokens: t.ExpendTokens,
|
||||
}).
|
||||
Update()
|
||||
return err
|
||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
||||
State: 2,
|
||||
OssFile: t.OssFile,
|
||||
FileType: t.FileType,
|
||||
TextResult: t.TextResult,
|
||||
FileSize: t.FileSize,
|
||||
ErrorMsg: "",
|
||||
FinishedAt: gtime.Now(),
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
ExpendTokens: t.ExpendTokens,
|
||||
DurationSeconds: t.DurationSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFailedGlobal 模型调用失败
|
||||
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
||||
now := gtime.Now()
|
||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
||||
State: 3,
|
||||
ErrorMsg: t.ErrorMsg,
|
||||
FinishedAt: gtime.Now(),
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
TextResult: t.TextResult,
|
||||
DurationSeconds: t.DurationSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateFailedKeepTmpGlobal OSS 上传失败
|
||||
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id)
|
||||
}
|
||||
|
||||
// UpdateTmpAfterModelGlobal 写临时文件
|
||||
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id)
|
||||
}
|
||||
|
||||
// RollbackToPendingGlobal 回滚
|
||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// IncRetryCountGlobal 重试计数+1
|
||||
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// RequeueForRetryGlobal 重新入队
|
||||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id)
|
||||
}
|
||||
|
||||
// ======================== 列表查询 ========================
|
||||
|
||||
// ListExpiredDownloadedGlobal
|
||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// ListFailedRetryableGlobal
|
||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// ListFailedExhaustedGlobal
|
||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// ListTimeoutTasksGlobal
|
||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// HardDeleteByIDGlobal
|
||||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// ======================== 内部辅助 ========================
|
||||
|
||||
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var list []*entity.AsynchTask
|
||||
err = r.Structs(&list)
|
||||
return list, err
|
||||
}
|
||||
|
||||
func clampLimit(limit, defaultVal int) int {
|
||||
if limit <= 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
// UpdateColumns 更新指定字段(结构体版)
|
||||
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.Id, t.Id).
|
||||
Data(entity.AsynchTask{
|
||||
State: 3,
|
||||
ErrorMsg: t.ErrorMsg,
|
||||
FinishedAt: now,
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
}).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateFailedKeepTmpGlobal OSS 上传失败:保留 phase/tmp_file,下一轮仅重试 OSS 上传
|
||||
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
errorMsg, now, now, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTmpAfterModelGlobal 模型调用成功后,写入临时文件路径并标记 phase=1
|
||||
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask),
|
||||
tmpFile, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListExpiredDownloadedGlobal 获取已下载(state=4)且过期的任务,用于清理
|
||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
|
||||
gtime.Now(), limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// ListFailedRetryableGlobal 获取失败(state=3)且仍可重试的任务
|
||||
// retry_count 不含首次执行;retry_times 表示失败后最多再重试 N 次
|
||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*,
|
||||
m.retry_queue_max_seconds AS retry_queue_max_seconds
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state = 3
|
||||
AND t.retry_count < m.retry_times
|
||||
ORDER BY t.updated_at ASC
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// RequeueForRetryGlobal 将任务重新入队(state=0),并将 retry_count +1
|
||||
// enqueueAt 用于控制重试任务在队列中的位置:
|
||||
// - enqueueAt 越早,越靠前(ClaimPendingGlobal 按 enqueue_at ASC 抢占)
|
||||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
|
||||
enqueueAt, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListFailedExhaustedGlobal 获取失败(state=3)且超过重试次数的任务,用于硬删除
|
||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state = 3
|
||||
AND t.retry_count >= m.retry_times
|
||||
ORDER BY t.updated_at ASC
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// HardDeleteByIDGlobal 硬删除任务记录
|
||||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListTimeoutTasksGlobal 根据模型配置 expected_seconds 判定超时任务:
|
||||
// - state in (0,1)
|
||||
// - 模型 expected_seconds > 0
|
||||
// - now - created_at >= expected_seconds
|
||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state IN (0,1)
|
||||
AND m.expected_seconds > 0
|
||||
AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user