Files
model-asynch/service/task_service.go
2026-04-23 13:53:09 +08:00

130 lines
3.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"time"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, err
}
if m == nil || m.Enabled != 1 {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限(近似控制)
if m.QueueLimit > 0 {
cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
if err != nil {
return nil, err
}
if cnt >= int64(m.QueueLimit) {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
taskID := uuid.NewString()
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
InputRef: req.InputRef,
RequestPayload: storedPayload,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
return nil, err
}
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}