first commit

This commit is contained in:
2026-04-23 13:53:09 +08:00
commit 9de47fa5b8
34 changed files with 2764 additions and 0 deletions

224
dao/task_dao.go Normal file
View File

@@ -0,0 +1,224 @@
package dao
import (
"context"
"fmt"
"time"
"model-asynch/consts/public"
"model-asynch/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
var Task = &taskDao{}
type taskDao struct{}
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
if err != nil {
return 0, err
}
return r.LastInsertId()
}
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&t)
return
}
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
if len(taskIDs) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
data := gdb.Map{
entity.AsynchTaskCol.State: 4,
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Where(entity.AsynchTaskCol.State, 2).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 1,
entity.AsynchTaskCol.StartedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 2,
entity.AsynchTaskCol.OssFile: ossFile,
entity.AsynchTaskCol.FileType: fileType,
entity.AsynchTaskCol.FileSize: fileSize,
entity.AsynchTaskCol.ErrorMsg: "",
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 3,
entity.AsynchTaskCol.ErrorMsg: errorMsg,
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.ModelName, modelName).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Count()
return int64(n), err
}
// ClaimPending 抢占 pending 任务state=0并在同一事务中更新为 runningstate=1
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
if batchSize <= 0 {
batchSize = 1
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
FROM %s
WHERE deleted_at IS NULL AND state = 0
ORDER BY created_at ASC
LIMIT %d
FOR UPDATE SKIP LOCKED`,
public.TableNameTask,
batchSize,
)
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
tasks = nil
return nil
}
if err := r.Structs(&tasks); err != nil {
return err
}
// 更新为 running
now := time.Now()
for _, t := range tasks {
// tx.Model 不走 gfdb Hook这里手动更新必要字段
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, t.Id,
)
if err != nil {
return err
}
}
return nil
})
return
}
// ListExpiredSuccess 获取已成功且过期的任务
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.State, 2).
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// ListTimeoutTasks 获取超时的排队/执行中任务
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
deadline := gtime.New(time.Now().Add(-timeout))
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// DebugPing 用于启动时检测数据库连通性(可选)
func (d *taskDao) DebugPing(ctx context.Context) error {
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
return err
}