Files
model-asynch/service/worker.go
WangLiZhao f6c70a451e feat: 新增操作日志、任务分页查询与模型失败重试优化
- 新增操作日志表(asynch_op_log)及对应DAO,记录任务创建等操作的审计信息
- 新增任务分页查询接口(ListTask)及对应DTO、Service和DAO方法
- 优化模型调用失败重试逻辑:支持配置重试排队策略(插队到队首或队尾)
- 新增临时文件存储机制,当模型调用成功但OSS上传失败时,下次仅重试OSS上传
- 模型配置新增retry_queue_max_seconds字段,控制失败重试排队策略
- 更新数据库表结构(asynch_models、asynch_task、新增asynch_op_log)及同步更新SQL
- 配置文件调整:超时单位改为秒,更新服务地址和轮询间隔
- 修复模型列表查询支持按名称模糊搜索
2026-04-25 10:42:21 +08:00

206 lines
5.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
// AsyncWorker is the package-level instance of the background task worker.
var AsyncWorker = &asyncWorker{}
// asyncWorker polls for pending tasks and runs them on a goroutine pool.
type asyncWorker struct {
// mu guards pool and closed; Start and Stop hold it while changing them and
// pollLoop takes it to read pool.
mu sync.Mutex
// pool is the goroutine pool tasks are dispatched to; it is nil until Start
// creates it.
pool *grpool.Pool
// closed is set to false by Start and to true by Stop. Nothing in the
// visible part of this file reads it.
closed bool
}
// Start creates the goroutine pool and launches the poll loop.
//
// It does nothing when "asynch.worker.enabled" is false, or when a pool
// already exists and has not been closed. The pool size comes from
// "asynch.worker.goroutines" (default 4); values below 1 are raised to 1.
func (w *asyncWorker) Start(ctx context.Context) {
	enabled := g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool()
	if !enabled {
		g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=falseworker 未启动")
		return
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// A live pool means the worker is already running.
	alreadyRunning := w.pool != nil && !w.pool.IsClosed()
	if alreadyRunning {
		return
	}

	size := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
	if size < 1 {
		size = 1
	}

	w.pool, w.closed = grpool.New(size), false
	go w.pollLoop(ctx)
	g.Log().Infof(ctx, "[worker] started, grpool limit=%d", size)
}
// Stop closes the goroutine pool; it is intended to let the process exit
// completely on Ctrl+C. Calling it when no pool exists, or when the pool is
// already closed, is a no-op.
func (w *asyncWorker) Stop(ctx context.Context) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.pool == nil || w.pool.IsClosed() {
		return
	}

	w.pool.Close()
	w.closed = true
	g.Log().Infof(ctx, "[worker] stopped")
}
// pollLoop periodically claims pending tasks and dispatches them to the
// goroutine pool. It returns when ctx is cancelled (after calling Stop) or
// when the pool it was started with has been closed.
//
// The poll interval ("asynch.worker.pollInterval", default "1s") and batch
// size ("asynch.worker.batchSize", default 5) are read once on entry; an
// unparsable or non-positive interval falls back to one second and a
// non-positive batch size falls back to 1.
func (w *asyncWorker) pollLoop(ctx context.Context) {
	pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
	pollInterval, parseErr := time.ParseDuration(pollIntervalStr)
	if parseErr != nil {
		g.Log().Warningf(ctx, "[worker] invalid asynch.worker.pollInterval %q, using 1s: %v", pollIntervalStr, parseErr)
	}
	if pollInterval <= 0 {
		pollInterval = time.Second
	}
	batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
	if batchSize <= 0 {
		batchSize = 1
	}

	// Bind this loop to the pool that is current when it starts. Once that
	// pool is closed the loop exits instead of claiming tasks only to roll
	// them back, and a later Start (which launches its own loop) does not end
	// up with two loops polling at the same time.
	w.mu.Lock()
	pool := w.pool
	w.mu.Unlock()
	if pool == nil {
		g.Log().Warningf(ctx, "[worker] poll loop not started: pool is nil")
		return
	}

	// requeue puts a claimed task back to pending and logs a failure to do so,
	// because a task that cannot be rolled back stays in its claimed state.
	requeue := func(id int64, reason string) {
		if err := w.rollbackToPending(ctx, id); err != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %d to pending failed (%s): %v", id, reason, err)
		}
	}

	g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)
	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			w.Stop(ctx)
			return
		case <-ticker.C:
		}

		if pool.IsClosed() {
			g.Log().Infof(ctx, "[worker] poll loop exits: pool closed")
			return
		}

		tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
		if err != nil {
			g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
			continue
		}
		if len(tasks) == 0 {
			continue
		}

		for _, t := range tasks {
			task := t // avoid capturing the loop variable in the closures below

			if pool.IsClosed() {
				// The pool was closed while dispatching this batch.
				requeue(task.Id, "pool closed")
				continue
			}

			addErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
				w.handleOne(ctx, task)
			}, func(ctx context.Context, err error) {
				if err == nil {
					return
				}
				g.Log().Errorf(ctx, "[worker] task %d panicked: %v", task.Id, err)
				if uErr := dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err)); uErr != nil {
					g.Log().Errorf(ctx, "[worker] mark task %d failed after panic error: %v", task.Id, uErr)
				}
			})
			if addErr != nil {
				// The job was not accepted by the pool, so nothing will ever
				// run it; hand the task back instead of leaving it claimed.
				g.Log().Errorf(ctx, "[worker] enqueue task %d failed: %v", task.Id, addErr)
				requeue(task.Id, "enqueue failed")
			}
		}
	}
}
// handleOne processes one claimed task:
//
//  1. load the model configuration for the task's tenant and model name;
//  2. acquire a per tenant+model concurrency slot, rolling the task back to
//     pending when none is available;
//  3. obtain the result bytes, either from the temporary file of an earlier
//     attempt (Phase == 1) or by invoking the model;
//  4. upload the result through Storage.UploadByTask;
//  5. record success and delete the temporary file.
//
// Failures are written to the task through the DAO. A failure to write the
// status itself is logged, since the task would otherwise be left in its
// claimed state without any trace.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
	// Restore payload and headers from the stored request_payload; the headers
	// are attached to ctx so the OSS upload can pass authentication through.
	payload, headers := parseStoredPayload(t.RequestPayload)
	if len(headers) > 0 {
		ctx = setTaskHeadersToCtx(ctx, headers)
	}

	// markFailed stores the failure reason on the task and logs when that
	// update itself fails.
	markFailed := func(reason string) {
		if uErr := dao.Task.UpdateFailedGlobal(ctx, t.Id, reason); uErr != nil {
			g.Log().Errorf(ctx, "[worker] mark task %v failed error: %v (reason: %s)", t.Id, uErr, reason)
		}
	}

	// 1) Load the model configuration.
	m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil {
		markFailed(err.Error())
		return
	}
	if m == nil || m.Enabled != 1 {
		markFailed("模型不存在或未启用")
		return
	}

	// 2) Distributed concurrency limit, keyed by tenant + model.
	semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
	leaseSeconds := int64(3600) // fallback lease of one hour
	acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
	if err != nil {
		markFailed(err.Error())
		return
	}
	if !acquired {
		// Concurrency is full: put the task back in the queue so a later poll
		// can claim it again.
		if rErr := w.rollbackToPending(ctx, t.Id); rErr != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %v to pending failed: %v", t.Id, rErr)
		}
		return
	}
	defer func() {
		// Best-effort release; presumably the lease passed to acquireSemaphore
		// frees the slot if this fails (TODO confirm).
		_ = releaseSemaphore(ctx, semKey)
	}()

	// 3) Obtain the model output.
	if payload == nil {
		payload = map[string]any{
			"taskId":   t.TaskID,
			"inputRef": t.InputRef,
		}
	}
	var (
		data        []byte
		contentType string
		ext         string
	)
	// Phase 1 means the model already succeeded but the OSS upload failed:
	// prefer the temporary file so the model is not run again.
	if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
		data, err = loadTmpResult(t.TmpFile)
		if err == nil && len(data) > 0 {
			contentType, ext = DetectFileType(data)
		} else {
			// Temporary file unusable: fall back to invoking the model again.
			if err != nil {
				g.Log().Warningf(ctx, "[worker] load tmp result %q for task %v failed, re-invoking model: %v", t.TmpFile, t.Id, err)
			}
			data = nil
		}
	}
	if data == nil {
		data, err = InvokeModel(ctx, m, payload)
		if err != nil {
			markFailed(err.Error())
			return
		}
		contentType, ext = DetectFileType(data)
		// Persist the model output to a temporary file so that a later OSS
		// failure only needs to retry the upload.
		tmpPath, saveErr := saveTmpResult(t.TaskID, data, ext)
		switch {
		case saveErr != nil:
			g.Log().Warningf(ctx, "[worker] save tmp result for task %v failed: %v", t.Id, saveErr)
		case tmpPath != "":
			t.TmpFile = tmpPath
			t.Phase = 1
			if uErr := dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath); uErr != nil {
				g.Log().Errorf(ctx, "[worker] record tmp file for task %v failed: %v", t.Id, uErr)
			}
		}
	}

	// 4) Upload to OSS.
	ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
	if err != nil {
		// OSS stage failed: keep the temporary file so the next attempt only
		// retries the upload.
		if uErr := dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()); uErr != nil {
			g.Log().Errorf(ctx, "[worker] mark task %v failed (keep tmp) error: %v (reason: %v)", t.Id, uErr, err)
		}
		return
	}

	// 5) Mark the task successful.
	// Note: expire_at is counted from the moment the task is downloaded
	// (state=4), so success (state=2) does not write expire_at; the trailing
	// nil is presumably that value (TODO confirm against the DAO).
	fileType := strings.TrimPrefix(ext, ".")
	if fileType == "" {
		fileType = contentType
	}
	if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
		g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
		return
	}

	// Clean up the temporary file once the task is recorded as successful.
	deleteTmpResult(t.TmpFile)
}
// rollbackToPending hands the task with the given id back to the pending
// queue by delegating to dao.Task.RollbackToPendingGlobal, and returns that
// call's error.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	err := dao.Task.RollbackToPendingGlobal(ctx, id)
	return err
}