first commit

This commit is contained in:
2026-04-23 13:53:09 +08:00
commit 9de47fa5b8
34 changed files with 2764 additions and 0 deletions

177
service/worker.go Normal file
View File

@@ -0,0 +1,177 @@
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
// AsyncWorker is the package-level worker instance; its zero value holds no
// pool until Start creates one.
var AsyncWorker = &asyncWorker{}
// asyncWorker polls for pending tasks and runs them on a grpool goroutine pool.
type asyncWorker struct {
// mu guards pool and closed.
mu sync.Mutex
// pool is the goroutine pool created by Start and closed by Stop.
pool *grpool.Pool
// closed is set to false by Start and true by Stop.
// NOTE(review): nothing in this file reads it; pool.IsClosed() is used instead.
closed bool
}
// Start creates the goroutine pool and launches the poll loop.
//
// It does nothing when "asynch.worker.enabled" is false (default true) or when
// a pool already exists and is still open. The pool size comes from
// "asynch.worker.goroutines" (default 4); values below 1 are raised to 1.
func (w *asyncWorker) Start(ctx context.Context) {
	enabled := g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool()
	if !enabled {
		g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=falseworker 未启动")
		return
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// A live pool means the worker has already been started.
	if running := w.pool != nil && !w.pool.IsClosed(); running {
		return
	}

	size := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
	if size < 1 {
		size = 1
	}

	w.pool, w.closed = grpool.New(size), false
	go w.pollLoop(ctx)
	g.Log().Infof(ctx, "[worker] started, grpool limit=%d", size)
}
// Stop closes the goroutine pool if one exists and is still open, and records
// the worker as closed. Calling it with no pool, or with an already closed
// pool, is a no-op. The original author's note says this is what lets Ctrl+C
// exit completely.
func (w *asyncWorker) Stop(ctx context.Context) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Nothing to shut down.
	if w.pool == nil || w.pool.IsClosed() {
		return
	}

	w.pool.Close()
	w.closed = true
	g.Log().Infof(ctx, "[worker] stopped")
}
// pollLoop periodically claims pending tasks and hands them to the goroutine
// pool. It returns when ctx is cancelled (after calling Stop) or when the pool
// it was started with has been closed.
//
// Configuration, read once when the loop starts:
//   - "asynch.worker.pollInterval" (default "1s"): tick interval; an
//     unparsable or non-positive value falls back to one second.
//   - "asynch.worker.batchSize" (default 5): maximum tasks claimed per tick;
//     a non-positive value falls back to 1.
func (w *asyncWorker) pollLoop(ctx context.Context) {
	pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
	pollInterval, parseErr := time.ParseDuration(pollIntervalStr)
	if parseErr != nil {
		g.Log().Warningf(ctx, "[worker] invalid asynch.worker.pollInterval %q, using 1s: %v", pollIntervalStr, parseErr)
	}
	if pollInterval <= 0 {
		pollInterval = time.Second
	}

	batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
	if batchSize <= 0 {
		batchSize = 1
	}

	// Bind this loop to the pool that is current when it starts. Once that pool
	// is closed the loop exits, so a stopped worker no longer claims tasks only
	// to roll them back, and a later restart does not leave two loops polling.
	w.mu.Lock()
	pool := w.pool
	w.mu.Unlock()
	if pool == nil {
		g.Log().Warningf(ctx, "[worker] poll loop not started: pool is nil")
		return
	}

	g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			w.Stop(ctx)
			return

		case <-ticker.C:
			// Check before claiming so that no task is claimed needlessly.
			if pool.IsClosed() {
				g.Log().Infof(ctx, "[worker] poll loop exiting: pool closed")
				return
			}

			tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
			if err != nil {
				g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
				continue
			}

			for _, t := range tasks {
				task := t // avoid capturing the loop variable in the closures below

				// The pool may be closed while a batch is being dispatched;
				// return the claimed task to the queue instead of dropping it.
				if pool.IsClosed() {
					if rbErr := w.rollbackToPending(ctx, task.Id); rbErr != nil {
						g.Log().Errorf(ctx, "[worker] rollback task %v error: %v", task.Id, rbErr)
					}
					continue
				}

				addErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
					w.handleOne(ctx, task)
				}, func(ctx context.Context, err error) {
					if err == nil {
						return
					}
					if upErr := dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err)); upErr != nil {
						g.Log().Errorf(ctx, "[worker] mark task %v failed after panic error: %v", task.Id, upErr)
					}
				})
				if addErr != nil {
					// The job was never enqueued, so the task would otherwise
					// stay claimed and never run.
					g.Log().Errorf(ctx, "[worker] enqueue task %v error: %v", task.Id, addErr)
					if rbErr := w.rollbackToPending(ctx, task.Id); rbErr != nil {
						g.Log().Errorf(ctx, "[worker] rollback task %v error: %v", task.Id, rbErr)
					}
				}
			}
		}
	}
}
// handleOne processes one claimed task: it loads the model configuration,
// takes a per-tenant/model concurrency slot, invokes the model, uploads the
// result to storage, and records the task as succeeded.
//
// Any failing step marks the task failed with the error text. When no
// concurrency slot is available the task is rolled back to pending instead.
// Failures to persist a state change are logged, because otherwise the task
// would stay in its claimed state with no trace.
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
	// Restore the payload and headers stored in request_payload; the headers
	// are attached to ctx so the OSS upload can pass authentication through.
	payload, headers := parseStoredPayload(t.RequestPayload)
	if len(headers) > 0 {
		ctx = setTaskHeadersToCtx(ctx, headers)
	}

	// fail marks the task failed and logs when that update itself fails.
	fail := func(reason string) {
		if err := dao.Task.UpdateFailedGlobal(ctx, t.Id, reason); err != nil {
			g.Log().Errorf(ctx, "[worker] mark task %v failed error: %v (reason: %s)", t.Id, err, reason)
		}
	}

	// 1) Load the model configuration.
	m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil {
		fail(err.Error())
		return
	}
	if m == nil || m.Enabled != 1 {
		fail("模型不存在或未启用")
		return
	}

	// 2) Distributed concurrency limit, keyed by tenant and model.
	semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
	leaseSeconds := int64(3600) // fallback lease of one hour
	acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
	if err != nil {
		fail(err.Error())
		return
	}
	if !acquired {
		// Concurrency is saturated: put the task back in the queue (state=0)
		// so a later poll can claim it again.
		if rbErr := w.rollbackToPending(ctx, t.Id); rbErr != nil {
			g.Log().Errorf(ctx, "[worker] rollback task %v error: %v", t.Id, rbErr)
		}
		return
	}
	defer func() {
		// Best effort: the result is deliberately ignored.
		// NOTE(review): presumably the lease above bounds how long a slot
		// stays held if release fails; confirm in acquireSemaphore.
		_ = releaseSemaphore(ctx, semKey)
	}()

	// 3) Invoke the model service, building a minimal payload when none was stored.
	if payload == nil {
		payload = map[string]any{
			"taskId":   t.TaskID,
			"inputRef": t.InputRef,
		}
	}
	data, err := InvokeModel(ctx, m, payload)
	if err != nil {
		fail(err.Error())
		return
	}

	// 4) Store the result in OSS/MinIO.
	contentType, ext := DetectFileType(data)
	ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
	if err != nil {
		fail(err.Error())
		return
	}

	// 5) Record success.
	// Note: expire_at is counted from the moment the result is downloaded
	// (state=4), so success (state=2) does not write expire_at.
	fileType := strings.TrimPrefix(ext, ".")
	if fileType == "" {
		fileType = contentType
	}
	if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
		g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
	}
}
// rollbackToPending delegates to dao.Task.RollbackToPendingGlobal for the
// task with the given id and returns its error unchanged.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	err := dao.Task.RollbackToPendingGlobal(ctx, id)
	return err
}