- 新增统计控制器、服务层与数据访问层,提供按天统计接口 - 在 worker 处理任务时原子累加请求计数(仅实际调用模型时计数) - 更新数据库表结构,添加 asynch_model_stat 表及索引 - 更新文档说明统计功能的使用方式与统计口径
216 lines
6.5 KiB
Go
216 lines
6.5 KiB
Go
package dao
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"model-asynch/consts/public"
|
||
"model-asynch/model/entity"
|
||
|
||
"gitea.com/red-future/common/db/gfdb"
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/os/gtime"
|
||
)
|
||
|
||
// ClaimPendingGlobal 后台任务使用:全局抢占 pending 任务(不加 tenant 过滤)
|
||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||
if batchSize <= 0 {
|
||
batchSize = 1
|
||
}
|
||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||
sql := fmt.Sprintf(
|
||
`SELECT id, tenant_id, creator, model_name, task_id, input_ref, request_payload, phase, tmp_file
|
||
FROM %s
|
||
WHERE deleted_at IS NULL AND state = 0
|
||
ORDER BY enqueue_at ASC
|
||
LIMIT %d
|
||
FOR UPDATE SKIP LOCKED`,
|
||
public.TableNameTask,
|
||
batchSize,
|
||
)
|
||
r, err := tx.GetAll(sql)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if r.IsEmpty() {
|
||
tasks = nil
|
||
return nil
|
||
}
|
||
if err := r.Structs(&tasks); err != nil {
|
||
return err
|
||
}
|
||
now := time.Now()
|
||
for _, t := range tasks {
|
||
_, err = tx.Exec(
|
||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||
now, now, t.Id,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
return
|
||
}
|
||
|
||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
||
now := gtime.Now()
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET state=2, oss_file=?, file_type=?, file_size=?, error_msg='', finished_at=?, expire_at=NULL, phase=0, tmp_file='', updated_at=? WHERE id=?`, public.TableNameTask),
|
||
ossFile, fileType, fileSize, now, now, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||
now := gtime.Now()
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=0, tmp_file='', updated_at=? WHERE id=?`, public.TableNameTask),
|
||
errorMsg, now, now, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateFailedKeepTmpGlobal OSS 上传失败:保留 phase/tmp_file,下一轮仅重试 OSS 上传
|
||
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||
now := gtime.Now()
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask),
|
||
errorMsg, now, now, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// UpdateTmpAfterModelGlobal 模型调用成功后,写入临时文件路径并标记 phase=1
|
||
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask),
|
||
tmpFile, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
|
||
taskID,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||
id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// ListExpiredDownloadedGlobal 获取已下载(state=4)且过期的任务,用于清理
|
||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 200
|
||
}
|
||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
|
||
gtime.Now(), limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// ListFailedRetryableGlobal 获取失败(state=3)且仍可重试的任务
|
||
// retry_count 不含首次执行;retry_times 表示失败后最多再重试 N 次
|
||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 200
|
||
}
|
||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||
fmt.Sprintf(`
|
||
SELECT t.*,
|
||
m.retry_queue_max_seconds AS retry_queue_max_seconds
|
||
FROM %s t
|
||
JOIN %s m
|
||
ON t.tenant_id = m.tenant_id
|
||
AND t.model_name = m.model_name
|
||
WHERE t.deleted_at IS NULL
|
||
AND t.state = 3
|
||
AND t.retry_count < m.retry_times
|
||
ORDER BY t.updated_at ASC
|
||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||
limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// RequeueForRetryGlobal 将任务重新入队(state=0),并将 retry_count +1
|
||
// enqueueAt 用于控制重试任务在队列中的位置:
|
||
// - enqueueAt 越早,越靠前(ClaimPendingGlobal 按 enqueue_at ASC 抢占)
|
||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
|
||
enqueueAt, id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
// ListFailedExhaustedGlobal 获取失败(state=3)且超过重试次数的任务,用于硬删除
|
||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 200
|
||
}
|
||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||
fmt.Sprintf(`
|
||
SELECT t.*
|
||
FROM %s t
|
||
JOIN %s m
|
||
ON t.tenant_id = m.tenant_id
|
||
AND t.model_name = m.model_name
|
||
WHERE t.deleted_at IS NULL
|
||
AND t.state = 3
|
||
AND t.retry_count >= m.retry_times
|
||
ORDER BY t.updated_at ASC
|
||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||
limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// HardDeleteByIDGlobal 硬删除任务记录
|
||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
|
||
id,
|
||
)
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 200
|
||
}
|
||
deadline := gtime.New(time.Now().Add(-timeout))
|
||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state IN (0,1) AND updated_at < ? LIMIT ?`, public.TableNameTask),
|
||
deadline, limit,
|
||
)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|