Files
model-asynch/service/worker.go
WangLiZhao 4e6b98b7d3 feat(stat): 添加模型请求按天统计功能
- 新增统计控制器、服务层与数据访问层,提供按天统计接口
- 在 worker 处理任务时原子累加请求计数(仅实际调用模型时计数)
- 更新数据库表结构,添加 asynch_model_stat 表及索引
- 更新文档说明统计功能的使用方式与统计口径
2026-04-27 10:42:42 +08:00

209 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
// AsyncWorker is the package-level worker instance; its zero value is usable
// and becomes active once Start is called.
var AsyncWorker = &asyncWorker{}
// asyncWorker owns the goroutine pool that runs task handlers, together with
// the mutex that protects it.
type asyncWorker struct {
// mu is held by Start and Stop while they change pool/closed, and by
// pollLoop while it reads pool.
mu sync.Mutex
// pool runs the task handlers; it is created in Start and closed in Stop.
pool *grpool.Pool
// closed is set to false by Start and to true by Stop.
closed bool
}
// Start creates the goroutine pool and launches the polling loop in its own
// goroutine.
//
// It does nothing when the config key "asynch.worker.enabled" (default true)
// is false, or when a pool already exists and has not been closed. The pool
// size is read from "asynch.worker.goroutines" (default 4); non-positive
// values are clamped to 1.
func (w *asyncWorker) Start(ctx context.Context) {
	if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
		// The separator after "false" keeps the config echo and the
		// explanation from running together in the log line.
		g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=false, worker 未启动")
		return
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// Already running: do not create a second pool or a second poll loop.
	if w.pool != nil && !w.pool.IsClosed() {
		return
	}

	limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
	if limit <= 0 {
		limit = 1
	}

	w.pool = grpool.New(limit)
	w.closed = false

	// The poll loop stops itself once ctx is cancelled.
	go w.pollLoop(ctx)
	g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
}
// Stop closes the goroutine pool so that the process can exit cleanly on
// Ctrl+C. Calling it when no pool exists, or when the pool is already closed,
// is a no-op.
func (w *asyncWorker) Stop(ctx context.Context) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Nothing to shut down.
	if w.pool == nil || w.pool.IsClosed() {
		return
	}

	w.pool.Close()
	w.closed = true
	g.Log().Infof(ctx, "[worker] stopped")
}
// pollLoop claims pending tasks on every tick and hands each one to the
// goroutine pool. It returns when ctx is cancelled (after calling Stop) or
// when it finds the pool closed.
//
// Configuration, read once when the loop starts:
//   - "asynch.worker.pollInterval" (default "1s"); an unparsable or
//     non-positive value falls back to one second.
//   - "asynch.worker.batchSize" (default 5); a non-positive value is clamped
//     to 1.
//
// A task that has been claimed but cannot be handed to the pool is rolled
// back to pending so that it is not left stranded in the claimed state.
func (w *asyncWorker) pollLoop(ctx context.Context) {
	pollInterval, err := time.ParseDuration(
		g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String(),
	)
	if err != nil || pollInterval <= 0 {
		pollInterval = time.Second
	}
	batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
	if batchSize <= 0 {
		batchSize = 1
	}
	g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)

	// currentPool reads the pool pointer under the lock, since Start may
	// replace it concurrently.
	currentPool := func() *grpool.Pool {
		w.mu.Lock()
		defer w.mu.Unlock()
		return w.pool
	}

	// requeue puts a claimed task back to pending and reports a failed
	// rollback instead of silently dropping it.
	requeue := func(id int64, reason string) {
		if err := w.rollbackToPending(ctx, id); err != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %d to pending failed (%s): %v", id, reason, err)
		}
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			w.Stop(ctx)
			return
		case <-ticker.C:
		}

		// Once the pool is closed there is nothing to run tasks on. Exit
		// rather than claim tasks only to roll them back every tick; Start
		// launches a new loop whenever it creates a new pool.
		if p := currentPool(); p == nil || p.IsClosed() {
			g.Log().Infof(ctx, "[worker] pool closed, poll loop exiting")
			return
		}

		tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
		if err != nil {
			g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
			continue
		}

		for _, t := range tasks {
			task := t // per-iteration copy so the closures below do not share the loop variable

			p := currentPool()
			if p == nil || p.IsClosed() {
				// The pool was closed while this batch was being dispatched.
				requeue(task.Id, "pool closed")
				continue
			}

			err := p.AddWithRecover(ctx, func(ctx context.Context) {
				w.handleOne(ctx, task)
			}, func(ctx context.Context, err error) {
				if err != nil {
					// Best effort: the handler has already panicked, so the
					// task is marked failed and the update result is ignored.
					_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err))
				}
			})
			if err != nil {
				// The pool refused the job (for example it was closed between
				// the check above and this call); do not leave the task claimed.
				g.Log().Errorf(ctx, "[worker] submit task %d to pool failed: %v", task.Id, err)
				requeue(task.Id, "submit failed")
			}
		}
	}
}
// handleOne processes a single claimed task. In order, it loads the model
// configuration, acquires a concurrency slot for the tenant+model pair,
// obtains the model output (from a previously saved temporary file, or by
// invoking the model), uploads the output through Storage.UploadByTask and
// marks the task successful. It has no return value: failures are written
// back to the task through dao.Task, or logged.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// Restore payload + headers from the request_payload stored with the task;
// the headers are attached to ctx so the OSS upload can pass authentication
// through.
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
}
// 1) Fetch the model configuration.
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
// A missing or disabled model fails the task.
if m == nil || m.Enabled != 1 {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
return
}
// 2) Distributed concurrency limit (per tenant+model).
semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
leaseSeconds := int64(3600) // fallback lease: 1 hour
acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
if !acquired {
// Concurrency limit reached: put the task back in the queue (reset to
// state=0) so that it is claimed again in a later round.
_ = w.rollbackToPending(ctx, t.Id)
return
}
// Release the slot on every exit path below.
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
// 3) Call the model service.
// When no payload was restored, fall back to a minimal one built from the task.
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
}
}
var (
data []byte
contentType string
ext string
)
// phase=1 means the model already succeeded but the OSS upload failed: load
// the result from the temporary file first to avoid running the model again.
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
} else {
// Temporary file unusable: fall back to invoking the model again.
data = nil
}
}
if data == nil {
// Statistics: count +1 only when the model is actually invoked (OSS-only
// retries are not counted).
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
data, err = InvokeModel(ctx, m, payload)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
contentType, ext = DetectFileType(data)
// Write the model output to a temporary file so that a later OSS failure
// can be retried without re-running the model. A failed save is not fatal:
// the task simply proceeds without a temporary file.
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 4) Store the result in OSS.
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
// OSS stage failed: keep the temporary file so the next round retries
// only the OSS upload.
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
return
}
// 5) Mark the task as successful.
// Note: expire_at is now counted from the moment the result is downloaded
// (state=4), so success (state=2) does not write expire_at.
// NOTE(review): the trailing nil argument is presumably that expire_at value; confirm against dao.Task.
// The recorded file type is the extension without its leading dot, or the
// content type when there is no extension.
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
// Clean up the temporary file after success.
deleteTmpResult(t.TmpFile)
}
// rollbackToPending delegates to dao.Task.RollbackToPendingGlobal for the task
// with the given id and returns its error unchanged.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}