refactor(service): 重构服务模块结构并优化模型配置
This commit is contained in:
269
service/task/task_service.go
Normal file
269
service/task/task_service.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/service/queue"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var Task = &taskService{}
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
// Create 创建任务
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
startAt := time.Now()
|
||||
taskID := uuid.NewString()
|
||||
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 插入任务记录
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": util.ForwardHeaders(ctx),
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
ModelKey: m.ApiKey,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
EpicycleId: req.EpicycleId,
|
||||
})
|
||||
if err != nil {
|
||||
// 入库失败:回滚闸门占位
|
||||
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = util.GetLocalIP()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
ErrorMsg: "",
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: storedPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
|
||||
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
||||
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
|
||||
// - 目标:尽快把刚创建的任务拉起来执行
|
||||
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
||||
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
||||
// - 不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
|
||||
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
|
||||
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
|
||||
return true
|
||||
}
|
||||
if t == nil {
|
||||
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
|
||||
return true
|
||||
}
|
||||
|
||||
switch t.State {
|
||||
case 0:
|
||||
//RunByTaskID 尝试执行任务
|
||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
|
||||
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
|
||||
}
|
||||
return false
|
||||
case 1:
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
|
||||
return true
|
||||
case 2, 3, 4:
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
|
||||
return true
|
||||
default:
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
|
||||
return true
|
||||
}
|
||||
}
|
||||
// 立即尝试一次
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-pollCtx.Done():
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetResult 获取任务结果
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, errors.New("任务不存在")
|
||||
}
|
||||
return &dto.GetTaskResultRes{
|
||||
OssFile: t.OssFile,
|
||||
State: t.State,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
if req == nil || len(req.TaskIDs) == 0 {
|
||||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||||
}
|
||||
// 1) 先查当前租户下的任务列表
|
||||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||||
now := time.Now()
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.State != 2 {
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: t.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retainSeconds := 86400
|
||||
if m != nil && m.AutoCleanSeconds > 0 {
|
||||
retainSeconds = m.AutoCleanSeconds
|
||||
}
|
||||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||||
|
||||
// 为了本次返回一致性,内存里也更新
|
||||
t.State = 4
|
||||
t.ExpireAt = expireAt
|
||||
}
|
||||
|
||||
// 3) 组装返回
|
||||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, dto.GetTaskBatchItem{
|
||||
TaskID: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
})
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
// List 获取任务列表
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil {
|
||||
if req.PageNum > 0 {
|
||||
pageNum = req.PageNum
|
||||
}
|
||||
if req.PageSize > 0 {
|
||||
pageSize = req.PageSize
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
taskID := ""
|
||||
var state *int
|
||||
if req != nil {
|
||||
modelName = req.ModelName
|
||||
taskID = req.TaskID
|
||||
state = req.State
|
||||
}
|
||||
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListTaskRes{List: list, Total: total}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user