first commit

This commit is contained in:
2026-04-23 13:53:09 +08:00
commit 9de47fa5b8
34 changed files with 2764 additions and 0 deletions

129
service/task_service.go Normal file
View File

@@ -0,0 +1,129 @@
package service
import (
"context"
"errors"
"time"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, err
}
if m == nil || m.Enabled != 1 {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限(近似控制)
if m.QueueLimit > 0 {
cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
if err != nil {
return nil, err
}
if cnt >= int64(m.QueueLimit) {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
taskID := uuid.NewString()
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
InputRef: req.InputRef,
RequestPayload: storedPayload,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
return nil, err
}
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}