Files
model-gateway/service/worker.go

285 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"model-gateway/model/dto"
"model-gateway/service/gateway"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
"model-gateway/dao"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
// AsyncWorker is the package-level worker instance used to run asynchronous tasks.
var AsyncWorker = new(asyncWorker)

// asyncWorker has no fields; it only groups the task-execution methods.
type asyncWorker struct{}
// RunOnce is triggered by an upper-level scheduled job: it claims one batch of
// pending tasks, processes them on a goroutine pool, and returns once every
// claimed task has been accounted for.
//   - req.BatchSize: number of tasks to claim in this run (defaults to 10 when <= 0)
//   - req.Goroutines: goroutine pool size for this run (defaults to 1 when <= 0)
//
// A nil req is treated as an empty request (all defaults). An error is returned
// when claiming fails or when no task was claimed.
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
	if req == nil {
		req = &dto.RunWorkReq{}
	}
	if req.BatchSize <= 0 {
		req.BatchSize = 10
	}
	if req.Goroutines <= 0 {
		req.Goroutines = 1
	}

	tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
	if err != nil {
		return nil, err
	}
	if len(tasks) == 0 {
		return nil, errors.New("no task to run")
	}

	pool := grpool.New(req.Goroutines)
	defer pool.Close()

	claimed := len(tasks)
	// Buffered to `claimed` so that no sender can ever block: every task
	// produces exactly one signal (normal completion, panic, or submit failure).
	done := make(chan struct{}, claimed)
	for _, t := range tasks {
		task := t
		addErr := pool.AddWithRecover(ctx, func(ctx context.Context) {
			w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
			done <- struct{}{}
		}, func(ctx context.Context, e error) {
			// Runs only when the job panicked: mark the task failed and free its queue slot.
			if e != nil {
				task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
				_ = dao.Task.UpdateFailedGlobal(ctx, task)
				ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
			}
			done <- struct{}{}
		})
		if addErr != nil {
			// The job was never queued, so neither callback above will fire.
			// Put the claimed task back (same handling as the "concurrency full"
			// path in handleOne) and signal completion so the wait loop below
			// cannot block forever.
			g.Log().Errorf(ctx, "[执行任务][失败] 提交协程池失败,放回队列 taskId=%s err=%v", task.TaskID, addErr)
			_ = w.rollbackToPending(ctx, task.Id)
			done <- struct{}{}
		}
	}

	// Wait for one signal per claimed task.
	for i := 0; i < claimed; i++ {
		<-done
	}
	return &dto.RunWorkRes{
		Claimed: claimed,
	}, nil
}
// RunByTaskID tries to execute a single task right after it has been created:
//   - it claims only the pending task identified by taskID;
//   - when the claim yields no task (for example another worker already took
//     it, or it is no longer pending) it returns nil without doing any work.
//
// Only the claim error is returned; the outcome of the task itself is handled
// inside handleOne.
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
	claimed, claimErr := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
	switch {
	case claimErr != nil:
		return claimErr
	case claimed != nil:
		w.handleOne(ctx, claimed, req)
	}
	return nil
}
// handleOne runs one task end to end: load the model configuration, acquire the
// per-model semaphore, call the model, upload the result to OSS, optionally
// validate the parsed result, then persist success and fire the callbacks.
// Every failure path goes through failTask; when the semaphore is full the task
// is rolled back to pending instead.
//
// A nil req is treated as an empty request (BuildType 0, EpicycleId 0).
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
	// Tolerate a nil request so the BuildType / EpicycleId reads below cannot panic.
	if req == nil {
		req = &dto.CreateTaskReq{}
	}
	payload := util.ParseStoredPayload(t.RequestPayload)
	g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)

	// 1) Load the model configuration.
	model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
	if err != nil || model == nil {
		w.failTask(ctx, t, "模型不存在或未启用")
		return
	}
	// A negative retry count would skip the retry loops entirely and leave the
	// OSS result nil, so clamp it to "one attempt, no retries".
	maxRetry := model.RetryTimes
	if maxRetry < 0 {
		maxRetry = 0
	}

	// 2) Distributed concurrency control (per-model semaphore).
	semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
	maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
	acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
	if err != nil {
		w.failTask(ctx, t, err.Error())
		return
	}
	if !acquired {
		g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
		_ = w.rollbackToPending(ctx, t.Id)
		return
	}
	defer func() { _ = releaseSemaphore(ctx, semKey) }()

	// 3) Validate request_payload.
	if payload == nil {
		w.failTask(ctx, t, "request_payload 为空")
		return
	}

	// 4) Call the model (no retry here; a failure fails the task immediately).
	textResult, err := w.callModel(ctx, t, model, payload)
	if err != nil {
		w.failTask(ctx, t, err.Error())
		return
	}

	// 5) Upload to OSS (retryable: one initial attempt plus maxRetry retries).
	var oss *gateway.UploadFileResponse
	for attempt := 0; attempt <= maxRetry; attempt++ {
		if attempt > 0 {
			g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
		}
		oss, err = w.uploadOSS(ctx, t)
		if err == nil {
			break
		}
		g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
			t.TaskID, attempt, maxRetry, err)
		if attempt == maxRetry {
			_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
			w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
			return
		}
	}
	// Guard the dereferences in step 7 against an uploader that reports
	// success without returning a response.
	if oss == nil {
		w.failTask(ctx, t, "OSS上传未返回结果")
		return
	}

	// 6) Validate the parsed result (retryable; a failure re-calls the model).
	if req.BuildType == 1 {
		for attempt := 0; attempt <= maxRetry; attempt++ {
			if attempt > 0 {
				g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
			}
			err = util.ValidatePromptResult(textResult, model.RequestMapping)
			if err == nil {
				break
			}
			g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
				t.TaskID, attempt, maxRetry, err)
			if attempt == maxRetry {
				w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
				return
			}
			// Call the model again.
			// NOTE(review): callModel reads the cached tmp file when t.Phase == 1
			// and t.TmpFile is set, so this call may return the same data rather
			// than a fresh model response; the OSS upload above has also already
			// happened with the first result — confirm the intended behavior.
			newResult, modelErr := w.callModel(ctx, t, model, payload)
			if modelErr != nil {
				g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
					t.TaskID, attempt, maxRetry, modelErr)
				continue
			}
			textResult = newResult
		}
	}

	// 7) Persist success and fire the callbacks.
	t.State = 2
	t.OssFile = oss.FileAddressPrefix + oss.FileURL
	t.FileType = oss.FileFormat
	t.TextResult = textResult
	t.FileSize = int64(oss.FileSize)
	t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
	if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
		// NOTE(review): this path returns without ReleaseQueueSlot or a
		// callback — confirm whether the queue slot should be released here.
		g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
		return
	}
	ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
	// Callbacks run detached from the caller's cancellation.
	go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
	if req.EpicycleId != 0 {
		go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
	}
	g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
		t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
	// Best-effort cleanup of the temporary result file.
	_ = os.Remove(t.TmpFile)
}
// callModel obtains the raw model output for a task, detects its file type,
// and returns the parsed text result.
//
// When the task is already in phase 1 and its tmp file can be read and is
// non-empty, that cached output is reused. Otherwise the model is invoked, the
// output is written to a tmp file, and the task is moved to phase 1.
//
// Returns the output parsed into a map (built from an empty string when the
// output is not valid UTF-8 text/JSON) or the error returned by InvokeModel.
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
	var data []byte
	var contentType, ext, textResult string
	var err error

	// Reuse the output of a previous model call when it is still on disk.
	if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
		data, err = os.ReadFile(t.TmpFile)
		if err != nil || len(data) == 0 {
			data = nil
		}
	}

	if data == nil {
		_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
		data, err = InvokeModel(ctx, m, payload, t.ModelKey)
		if err != nil {
			return nil, err
		}
		// Detect the type before saving so the tmp file gets the real extension
		// (previously ext was still empty here and the file was always ".bin").
		contentType, ext = util.DetectFileType(data)
		tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
		if tmpErr != nil {
			// Best-effort: the task continues, but surface the failure because
			// the later OSS upload reads from this file.
			g.Log().Warningf(ctx, "[执行任务][警告] 保存临时文件失败 taskId=%s err=%v", t.TaskID, tmpErr)
		}
		if tmpErr == nil && tmpPath != "" {
			t.TmpFile = tmpPath
			t.Phase = 1
			_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
		}
	} else {
		contentType, _ = util.DetectFileType(data)
	}

	// Only valid UTF-8 text or JSON output is kept as a text result.
	if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
		textResult = string(data)
	}
	return gjson.New(textResult).Map(), nil
}
// uploadOSS reads the task's temporary file, detects its extension from the
// content, and uploads the bytes through gateway.UploadByTask. A read failure
// is wrapped and returned.
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
	content, readErr := os.ReadFile(t.TmpFile)
	if readErr != nil {
		return nil, fmt.Errorf("读取临时文件失败: %w", readErr)
	}

	_, suffix := util.DetectFileType(content)
	return gateway.UploadByTask(ctx, content, suffix)
}
// failTask is the single failure path for a task: it sets state 3 and the error
// message on t, persists the failure, releases the task's queue slot, and fires
// the callback in a goroutine detached from the caller's cancellation.
//
// Persisting is best-effort: a database error is logged and the slot release
// and callback still happen.
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
	t.State = 3
	t.ErrorMsg = errMsg
	if err := dao.Task.UpdateFailedGlobal(ctx, t); err != nil {
		g.Log().Errorf(ctx, "[执行任务][失败] 更新失败状态入库失败 taskId=%s err=%v", t.TaskID, err)
	}
	ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
	go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// saveTmpResult writes the model output to "<os temp dir>/model-asynch/<taskID><ext>"
// so that a failed OSS upload can be retried without calling the model again.
//
// An empty ext becomes ".bin"; an ext without a leading dot gets one prepended.
// Returns the path of the written file, or the error from creating the
// directory or writing the file.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	// Normalize the extension first.
	suffix := ext
	switch {
	case suffix == "":
		suffix = ".bin"
	case suffix[0] != '.':
		suffix = "." + suffix
	}

	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(baseDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	target := filepath.Join(baseDir, fmt.Sprintf("%s%s", taskID, suffix))
	if writeErr := os.WriteFile(target, data, 0o644); writeErr != nil {
		return "", writeErr
	}
	return target, nil
}
// rollbackToPending delegates to dao.Task.RollbackToPendingGlobal for the task
// with the given id and returns its result.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	err := dao.Task.RollbackToPendingGlobal(ctx, id)
	return err
}
// GetExpendTokens looks up responseTokenField (a gjson path) in result and
// returns the value found there as an int. It returns 0 when the path yields
// no value.
func GetExpendTokens(responseTokenField string, result map[string]any) int {
	if tokens := gjson.New(result).Get(responseTokenField); !tokens.IsNil() {
		return tokens.Int()
	}
	return 0
}