285 lines
8.4 KiB
Go
285 lines
8.4 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"model-gateway/common/util"
|
||
"model-gateway/model/dto"
|
||
"model-gateway/service/gateway"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
"unicode/utf8"
|
||
|
||
"model-gateway/dao"
|
||
"model-gateway/model/entity"
|
||
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/os/grpool"
|
||
)
|
||
|
||
// asyncWorker groups the asynchronous task-processing methods of this
// package. It declares no fields, so it holds no per-instance state.
type asyncWorker struct{}

// AsyncWorker is the shared package-level worker instance used to reach
// the asyncWorker methods.
var AsyncWorker = new(asyncWorker)
|
||
|
||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||
// - batchSize: 本次抢占数量
|
||
// - goroutines: 本次并发数(协程池大小)
|
||
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||
if req.BatchSize <= 0 {
|
||
req.BatchSize = 10
|
||
}
|
||
if req.Goroutines <= 0 {
|
||
req.Goroutines = 1
|
||
}
|
||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(tasks) == 0 {
|
||
return nil, errors.New("no task to run")
|
||
}
|
||
pool := grpool.New(req.Goroutines)
|
||
defer pool.Close()
|
||
claimed := len(tasks)
|
||
done := make(chan struct{}, claimed)
|
||
for _, t := range tasks {
|
||
task := t
|
||
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
||
w.handleOne(ctx, task, &dto.CreateTaskReq{EpicycleId: 0})
|
||
done <- struct{}{}
|
||
}, func(ctx context.Context, e error) {
|
||
if e != nil {
|
||
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
|
||
_ = dao.Task.UpdateFailedGlobal(ctx, task)
|
||
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||
}
|
||
done <- struct{}{}
|
||
})
|
||
}
|
||
for i := 0; i < claimed; i++ {
|
||
<-done
|
||
}
|
||
return &dto.RunWorkRes{
|
||
Claimed: claimed,
|
||
}, nil
|
||
}
|
||
|
||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||
// - 只定向抢占当前 taskId 对应的 pending 任务
|
||
// - 若任务已被其它 worker 抢走/已不在 pending,则直接返回
|
||
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, req *dto.CreateTaskReq) error {
|
||
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if task == nil {
|
||
return nil
|
||
}
|
||
w.handleOne(ctx, task, req)
|
||
return nil
|
||
}
|
||
|
||
// handleOne 执行一次完整的任务
|
||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *dto.CreateTaskReq) {
|
||
payload := util.ParseStoredPayload(t.RequestPayload)
|
||
maxRetry := 0 // 后面从 model 取
|
||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", t.TaskID, t.ModelName)
|
||
|
||
// 1) 获取模型配置
|
||
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||
if err != nil || model == nil {
|
||
w.failTask(ctx, t, "模型不存在或未启用")
|
||
return
|
||
}
|
||
maxRetry = model.RetryTimes
|
||
|
||
// 2) 分布式并发控制
|
||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
||
acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
|
||
if err != nil {
|
||
w.failTask(ctx, t, err.Error())
|
||
return
|
||
}
|
||
if !acquired {
|
||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", t.TaskID)
|
||
_ = w.rollbackToPending(ctx, t.Id)
|
||
return
|
||
}
|
||
defer func() { _ = releaseSemaphore(ctx, semKey) }()
|
||
|
||
// 3) request_payload 校验
|
||
if payload == nil {
|
||
w.failTask(ctx, t, "request_payload 为空")
|
||
return
|
||
}
|
||
|
||
// 4) 调用模型(不重试,失败直接回调)
|
||
textResult, err := w.callModel(ctx, t, model, payload)
|
||
if err != nil {
|
||
w.failTask(ctx, t, err.Error())
|
||
return
|
||
}
|
||
|
||
// 5) 上传 OSS(可重试)
|
||
var oss *gateway.UploadFileResponse
|
||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||
if attempt > 0 {
|
||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||
}
|
||
oss, err = w.uploadOSS(ctx, t)
|
||
if err == nil {
|
||
break
|
||
}
|
||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
||
t.TaskID, attempt, maxRetry, err)
|
||
if attempt == maxRetry {
|
||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||
w.failTask(ctx, t, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||
return
|
||
}
|
||
}
|
||
|
||
// 6) 解析校验(可重试,失败重新调模型)
|
||
if req.BuildType == 1 {
|
||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||
if attempt > 0 {
|
||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||
}
|
||
err = util.ValidatePromptResult(textResult, model.RequestMapping)
|
||
if err == nil {
|
||
break
|
||
}
|
||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||
t.TaskID, attempt, maxRetry, err)
|
||
if attempt == maxRetry {
|
||
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||
return
|
||
}
|
||
// 重新调模型
|
||
newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||
if modelErr != nil {
|
||
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||
t.TaskID, attempt, maxRetry, modelErr)
|
||
continue
|
||
}
|
||
textResult = newResult
|
||
}
|
||
}
|
||
|
||
// 7) 成功回调
|
||
t.State = 2
|
||
t.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||
t.FileType = oss.FileFormat
|
||
t.TextResult = textResult
|
||
t.FileSize = int64(oss.FileSize)
|
||
t.ExpendTokens = int64(GetExpendTokens(model.ResponseTokenField, textResult))
|
||
|
||
if err = dao.Task.UpdateSuccessGlobal(ctx, t); err != nil {
|
||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", t.TaskID, err)
|
||
return
|
||
}
|
||
|
||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||
if req.EpicycleId != 0 {
|
||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
|
||
}
|
||
|
||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
|
||
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
|
||
_ = os.Remove(t.TmpFile)
|
||
}
|
||
|
||
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||
var data []byte
|
||
var contentType, ext, textResult string
|
||
var err error
|
||
|
||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||
data, err = os.ReadFile(t.TmpFile)
|
||
if err != nil || len(data) == 0 {
|
||
data = nil
|
||
}
|
||
}
|
||
|
||
if data == nil {
|
||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
|
||
if tmpErr == nil && tmpPath != "" {
|
||
t.TmpFile = tmpPath
|
||
t.Phase = 1
|
||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||
}
|
||
}
|
||
|
||
contentType, ext = util.DetectFileType(data)
|
||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||
textResult = string(data)
|
||
}
|
||
|
||
return gjson.New(textResult).Map(), nil
|
||
}
|
||
|
||
// uploadOSS 从临时文件上传 OSS
|
||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||
data, err := os.ReadFile(t.TmpFile)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取临时文件失败: %w", err)
|
||
}
|
||
_, ext := util.DetectFileType(data)
|
||
return gateway.UploadByTask(ctx, data, ext)
|
||
}
|
||
|
||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg string) {
|
||
t.State = 3
|
||
t.ErrorMsg = errMsg
|
||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||
}
|
||
|
||
// saveTmpResult writes the model output to
// <os.TempDir()>/model-asynch/<taskID><ext> and returns that path, so a
// failed OSS upload can be retried from the file.
//
// An empty ext becomes ".bin", and a leading dot is added when ext lacks
// one. The directory is created on demand; any filesystem error is returned
// together with an empty path.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(baseDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	suffix := ext
	switch {
	case suffix == "":
		suffix = ".bin"
	case !strings.HasPrefix(suffix, "."):
		suffix = "." + suffix
	}

	target := filepath.Join(baseDir, taskID+suffix)
	if wrErr := os.WriteFile(target, data, 0o644); wrErr != nil {
		return "", wrErr
	}
	return target, nil
}
|
||
|
||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||
}
|
||
|
||
// GetExpendTokens 根据映射路径从 result 中提取消耗 token 值
|
||
func GetExpendTokens(responseTokenField string, result map[string]any) int {
|
||
val := gjson.New(result).Get(responseTokenField)
|
||
if val.IsNil() {
|
||
return 0
|
||
}
|
||
return val.Int()
|
||
}
|