- 新增操作日志表（asynch_op_log）及对应 DAO，记录任务创建等操作的审计信息
- 新增任务分页查询接口（ListTask）及对应的 DTO、Service 和 DAO 方法
- 优化模型调用失败后的重试逻辑：支持配置重试排队策略（插队到队首或排到队尾）
- 新增临时文件存储机制：当模型调用成功但 OSS 上传失败时，下次仅重试 OSS 上传
- 模型配置新增 retry_queue_max_seconds 字段，用于控制失败重试的排队策略
- 更新数据库表结构（asynch_models、asynch_task，新增 asynch_op_log）并同步更新 SQL
- 配置文件调整：超时单位改为秒，更新服务地址和轮询间隔
- 修复模型列表查询，支持按名称模糊搜索
251 lines
7.0 KiB
Go
251 lines
7.0 KiB
Go
package dao
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"model-asynch/consts/public"
|
||
"model-asynch/model/entity"
|
||
|
||
"gitea.com/red-future/common/db/gfdb"
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/os/gtime"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// taskDao groups the data-access methods for the task table
// (public.TableNameTask). It carries no state of its own; callers use the
// package-level Task value.
type taskDao struct{}

// Task is the shared taskDao instance exposed by this package.
var Task = new(taskDao)
|
||
|
||
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return r.LastInsertId()
|
||
}
|
||
|
||
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||
One()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if r.IsEmpty() {
|
||
return nil, nil
|
||
}
|
||
err = r.Struct(&t)
|
||
return
|
||
}
|
||
|
||
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
|
||
if len(taskIDs) == 0 {
|
||
return nil, nil
|
||
}
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
||
All()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间
|
||
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
|
||
data := gdb.Map{
|
||
entity.AsynchTaskCol.State: 4,
|
||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||
entity.AsynchTaskCol.Updater: "",
|
||
}
|
||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.Id, id).
|
||
Where(entity.AsynchTaskCol.State, 2).
|
||
Data(data).
|
||
Update()
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
|
||
now := gtime.Now()
|
||
data := gdb.Map{
|
||
entity.AsynchTaskCol.State: 1,
|
||
entity.AsynchTaskCol.StartedAt: now,
|
||
entity.AsynchTaskCol.Updater: "",
|
||
}
|
||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.Id, id).
|
||
Data(data).
|
||
Update()
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
||
now := gtime.Now()
|
||
data := gdb.Map{
|
||
entity.AsynchTaskCol.State: 2,
|
||
entity.AsynchTaskCol.OssFile: ossFile,
|
||
entity.AsynchTaskCol.FileType: fileType,
|
||
entity.AsynchTaskCol.FileSize: fileSize,
|
||
entity.AsynchTaskCol.ErrorMsg: "",
|
||
entity.AsynchTaskCol.FinishedAt: now,
|
||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||
entity.AsynchTaskCol.Updater: "",
|
||
}
|
||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.Id, id).
|
||
Data(data).
|
||
Update()
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
|
||
now := gtime.Now()
|
||
data := gdb.Map{
|
||
entity.AsynchTaskCol.State: 3,
|
||
entity.AsynchTaskCol.ErrorMsg: errorMsg,
|
||
entity.AsynchTaskCol.FinishedAt: now,
|
||
entity.AsynchTaskCol.Updater: "",
|
||
}
|
||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.Id, id).
|
||
Data(data).
|
||
Update()
|
||
return err
|
||
}
|
||
|
||
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||
Delete()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return r.RowsAffected()
|
||
}
|
||
|
||
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
|
||
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
|
||
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.ModelName, modelName).
|
||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||
Count()
|
||
return int64(n), err
|
||
}
|
||
|
||
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
||
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
||
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||
if modelNameLike != "" {
|
||
m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
|
||
}
|
||
if taskIDLike != "" {
|
||
m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%")
|
||
}
|
||
if state != nil {
|
||
m = m.Where(entity.AsynchTaskCol.State, *state)
|
||
}
|
||
m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt)
|
||
if pageNum > 0 && pageSize > 0 {
|
||
m = m.Page(pageNum, pageSize)
|
||
}
|
||
r, totalInt, err := m.AllAndCount(false)
|
||
if err != nil {
|
||
return nil, 0, err
|
||
}
|
||
total = gconv.Int64(totalInt)
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1)
|
||
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
|
||
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||
if batchSize <= 0 {
|
||
batchSize = 1
|
||
}
|
||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||
sql := fmt.Sprintf(
|
||
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
|
||
FROM %s
|
||
WHERE deleted_at IS NULL AND state = 0
|
||
ORDER BY created_at ASC
|
||
LIMIT %d
|
||
FOR UPDATE SKIP LOCKED`,
|
||
public.TableNameTask,
|
||
batchSize,
|
||
)
|
||
r, err := tx.GetAll(sql)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if r.IsEmpty() {
|
||
tasks = nil
|
||
return nil
|
||
}
|
||
if err := r.Structs(&tasks); err != nil {
|
||
return err
|
||
}
|
||
// 更新为 running
|
||
now := time.Now()
|
||
for _, t := range tasks {
|
||
// tx.Model 不走 gfdb Hook,这里手动更新必要字段
|
||
_, err = tx.Exec(
|
||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||
now, now, t.Id,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return nil
|
||
})
|
||
return
|
||
}
|
||
|
||
// ListExpiredSuccess 获取已成功且过期的任务
|
||
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 100
|
||
}
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
Where(entity.AsynchTaskCol.State, 2).
|
||
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
|
||
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
|
||
Limit(limit).
|
||
All()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// ListTimeoutTasks 获取超时的排队/执行中任务
|
||
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
|
||
if limit <= 0 {
|
||
limit = 100
|
||
}
|
||
deadline := gtime.New(time.Now().Add(-timeout))
|
||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
|
||
Limit(limit).
|
||
All()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
err = r.Structs(&list)
|
||
return
|
||
}
|
||
|
||
// DebugPing 用于启动时检测数据库连通性(可选)
|
||
func (d *taskDao) DebugPing(ctx context.Context) error {
|
||
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
|
||
return err
|
||
}
|