130 lines
3.1 KiB
Go
130 lines
3.1 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"time"
|
||
|
||
"model-asynch/dao"
|
||
"model-asynch/model/dto"
|
||
"model-asynch/model/entity"
|
||
|
||
"github.com/gogf/gf/v2/os/gtime"
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
// taskService groups the task operations of this package (Create,
// GetResult, GetBatch). It has no fields and therefore carries no state.
type taskService struct{}

// Task is the package-level *taskService instance that callers use to reach
// the task operations.
var Task = new(taskService)
|
||
|
||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||
// 固化 token/user 等信息
|
||
ctx = asyncCtx(ctx)
|
||
|
||
// 1) 检查模型配置
|
||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if m == nil || m.Enabled != 1 {
|
||
return nil, errors.New("模型不存在或未启用")
|
||
}
|
||
|
||
// 2) 排队上限(近似控制)
|
||
if m.QueueLimit > 0 {
|
||
cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if cnt >= int64(m.QueueLimit) {
|
||
return nil, errors.New("任务排队已满,请稍后再试")
|
||
}
|
||
}
|
||
|
||
taskID := uuid.NewString()
|
||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||
storedPayload := map[string]any{
|
||
"payload": req.RequestPayload,
|
||
"headers": forwardHeaders(ctx),
|
||
}
|
||
|
||
t := &entity.AsynchTask{
|
||
ModelName: req.ModelName,
|
||
TaskID: taskID,
|
||
State: 0,
|
||
InputRef: req.InputRef,
|
||
RequestPayload: storedPayload,
|
||
}
|
||
_, err = dao.Task.Insert(ctx, t)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||
}
|
||
|
||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if t == nil {
|
||
return nil, errors.New("任务不存在")
|
||
}
|
||
return &dto.GetTaskResultRes{
|
||
OssFile: t.OssFile,
|
||
State: t.State,
|
||
}, nil
|
||
}
|
||
|
||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||
if req == nil || len(req.TaskIDs) == 0 {
|
||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||
}
|
||
// 1) 先查当前租户下的任务列表
|
||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||
now := time.Now()
|
||
for _, t := range list {
|
||
if t == nil {
|
||
continue
|
||
}
|
||
if t.State != 2 {
|
||
continue
|
||
}
|
||
// 按模型配置决定保留时间
|
||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
retainSeconds := 86400
|
||
if m != nil && m.AutoCleanSeconds > 0 {
|
||
retainSeconds = m.AutoCleanSeconds
|
||
}
|
||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||
|
||
// 为了本次返回一致性,内存里也更新
|
||
t.State = 4
|
||
t.ExpireAt = expireAt
|
||
}
|
||
|
||
// 3) 组装返回
|
||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||
for _, t := range list {
|
||
if t == nil {
|
||
continue
|
||
}
|
||
items = append(items, dto.GetTaskBatchItem{
|
||
TaskID: t.TaskID,
|
||
State: t.State,
|
||
OssFile: t.OssFile,
|
||
})
|
||
}
|
||
return &dto.GetTaskBatchRes{List: items}, nil
|
||
}
|