495 lines
16 KiB
Go
495 lines
16 KiB
Go
package task
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
"unicode/utf8"
|
||
|
||
"model-gateway/common/util"
|
||
"model-gateway/consts/public"
|
||
"model-gateway/dao"
|
||
"model-gateway/model/dto"
|
||
"model-gateway/model/entity"
|
||
"model-gateway/service/gateway"
|
||
"model-gateway/service/queue"
|
||
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// asyncWorker groups the methods that execute asynchronous model tasks.
// It has no fields.
type asyncWorker struct{}

// AsyncWorker is the package-level asyncWorker instance.
var AsyncWorker = new(asyncWorker)
|
||
|
||
// handleOne 执行一次完整的任务
|
||
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
|
||
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
|
||
maxRetry := model.RetryTimes // 重试次数
|
||
startTime := time.Now()
|
||
|
||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||
|
||
// 1) 分布式并发控制
|
||
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
||
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||
if err != nil {
|
||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
if !acquired {
|
||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||
_ = w.rollbackToPending(ctx, task.Id)
|
||
return
|
||
}
|
||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||
|
||
// 2) 调用模型
|
||
switch {
|
||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||
rawBytes, err := w.callModelStream(ctx, task, model, body)
|
||
if err != nil {
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||
if err != nil {
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||
body, err = w.callModel(ctx, task, model, body)
|
||
if err != nil {
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
|
||
if err != nil {
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
default:
|
||
body, err = w.callModel(ctx, task, model, body)
|
||
if err != nil {
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
}
|
||
|
||
// 3) 保存临时文件
|
||
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
|
||
if err == nil && tmpPath != "" {
|
||
task.TmpFile = tmpPath
|
||
task.Phase = 1
|
||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||
}
|
||
|
||
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
|
||
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
|
||
if err != nil {
|
||
task.TextResult = body
|
||
w.failTask(ctx, task, startTime, err.Error())
|
||
return
|
||
}
|
||
|
||
// 5) 上传 OSS(可重试)
|
||
var oss *gateway.UploadFileResponse
|
||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||
if attempt > 0 {
|
||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||
}
|
||
oss, err = w.uploadOSS(ctx, task)
|
||
if err == nil {
|
||
break
|
||
}
|
||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
||
task.TaskID, attempt, maxRetry, err)
|
||
if attempt == maxRetry {
|
||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
|
||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||
return
|
||
}
|
||
}
|
||
|
||
// 6) 成功回调
|
||
task.State = 2
|
||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||
task.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||
task.FileType = oss.FileFormat
|
||
task.TextResult = body
|
||
task.FileSize = int64(oss.FileSize)
|
||
|
||
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
|
||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||
return
|
||
}
|
||
|
||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
|
||
if req.EpicycleId != 0 {
|
||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
|
||
}
|
||
|
||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
|
||
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
|
||
|
||
// 7) 删除临时文件
|
||
_ = os.Remove(task.TmpFile)
|
||
}
|
||
|
||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
|
||
var data []byte
|
||
var err error
|
||
|
||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||
data, err = os.ReadFile(task.TmpFile)
|
||
if err != nil || len(data) == 0 {
|
||
data = nil
|
||
}
|
||
}
|
||
|
||
if data == nil {
|
||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
|
||
if tmpErr == nil && tmpPath != "" {
|
||
task.TmpFile = tmpPath
|
||
task.Phase = 1
|
||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||
}
|
||
}
|
||
|
||
return data, nil
|
||
}
|
||
|
||
// asyncResult is the outcome delivered for an asynchronous model task.
type asyncResult struct {
	result map[string]any // result payload passed to NotifyAsyncResult
	err    error          // error passed to NotifyAsyncResult
}

// asyncTaskChan maps a task ID (string) to the chan asyncResult on which
// callModelAsync waits for that task's result.
var asyncTaskChan sync.Map
|
||
|
||
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||
// 1. 提交异步任务
|
||
body, err := w.callModel(ctx, task, model, body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 2. 拿到 task_id
|
||
taskID := gjson.New(body).Get(model.ResponseBody).String()
|
||
|
||
// 3. 创建等待通道
|
||
ch := make(chan asyncResult, 1)
|
||
asyncTaskChan.Store(taskID, ch)
|
||
defer func() {
|
||
asyncTaskChan.Delete(taskID)
|
||
close(ch)
|
||
}()
|
||
|
||
// 4. 阻塞等待回调或超时
|
||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
|
||
|
||
select {
|
||
case res, ok := <-ch:
|
||
if !ok {
|
||
return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID)
|
||
}
|
||
g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID)
|
||
return res.result, res.err
|
||
case <-ctx.Done():
|
||
return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
|
||
}
|
||
}
|
||
|
||
// NotifyAsyncResult 回调接口调用此方法通知结果
|
||
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
|
||
if ch, ok := asyncTaskChan.Load(taskID); ok {
|
||
ch.(chan asyncResult) <- asyncResult{result: result, err: err}
|
||
}
|
||
}
|
||
|
||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||
// 返回: 解析后的响应体, error
|
||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||
var data []byte
|
||
var err error
|
||
|
||
// 1) 如果已有临时文件且 phase=1,直接读取
|
||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||
data, err = os.ReadFile(task.TmpFile)
|
||
if err != nil || len(data) == 0 {
|
||
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
|
||
data = nil
|
||
}
|
||
}
|
||
|
||
// 2) 没有可用数据,调用模型
|
||
if data == nil {
|
||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 3) 检测文件类型,保存临时文件
|
||
_, ext := util.DetectFileType(data)
|
||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
|
||
if tmpErr == nil && tmpPath != "" {
|
||
task.TmpFile = tmpPath
|
||
task.Phase = 1
|
||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||
}
|
||
}
|
||
|
||
// 4) 检测文件类型,提取文本结果
|
||
contentType, _ := util.DetectFileType(data)
|
||
var textResult string
|
||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||
textResult = string(data)
|
||
}
|
||
|
||
// 5) 非文本内容,返回错误
|
||
if textResult == "" {
|
||
return nil, fmt.Errorf("模型返回非文本内容,contentType=%s", contentType)
|
||
}
|
||
|
||
// 6) 解析并返回
|
||
return gjson.New(textResult).Map(), nil
|
||
}
|
||
|
||
// parseAndRetry 解析模型返回结果,并重试
|
||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||
if attempt > 0 {
|
||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||
}
|
||
|
||
// 1) 响应映射
|
||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||
if err != nil {
|
||
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||
if attempt == maxRetry {
|
||
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 2) 先存 token 到数据库,防止后续失败丢失
|
||
if tokens, ok := mapped[model.ResponseTokenField]; ok {
|
||
task.ExpendTokens = gconv.Int64(tokens)
|
||
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
|
||
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
|
||
})
|
||
}
|
||
|
||
// 3) 解析 + 校验
|
||
var parsed map[string]any
|
||
switch req.BuildType {
|
||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||
parsed, err = util.ParseAndValidate(mapped, model)
|
||
if err == nil {
|
||
return parsed, nil
|
||
}
|
||
case public.BuildTypeStruct:
|
||
parsed = util.ParseStructResult(mapped, model.ResponseBody)
|
||
return parsed, nil
|
||
default:
|
||
return mapped, nil
|
||
}
|
||
|
||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||
|
||
if attempt == maxRetry {
|
||
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
|
||
}
|
||
|
||
// 4) 重新调模型(直接调,不走缓存)
|
||
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
|
||
reqBody := util.GetModelBody(task.RequestPayload)
|
||
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
|
||
if callErr != nil {
|
||
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
||
continue
|
||
}
|
||
|
||
// 5) 解析原始响应,覆盖 body 进入下一轮
|
||
var rawResp map[string]any
|
||
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||
continue
|
||
}
|
||
body = rawResp
|
||
}
|
||
|
||
return body, nil
|
||
}
|
||
|
||
// InvokeModel 调用模型服务,返回二进制结果
|
||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
|
||
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
|
||
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||
|
||
// 2)构建请求 URL 和超时
|
||
baseURL := strings.TrimRight(model.BaseURL, "/")
|
||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||
client := &http.Client{Timeout: timeout}
|
||
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
|
||
|
||
// 3)构建 HTTP 请求
|
||
var req *http.Request
|
||
switch method {
|
||
case http.MethodGet:
|
||
q, err := util.BodyToQuery(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(q) > 0 {
|
||
if strings.Contains(baseURL, "?") {
|
||
baseURL = baseURL + "&" + q.Encode()
|
||
} else {
|
||
baseURL = baseURL + "?" + q.Encode()
|
||
}
|
||
}
|
||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||
default:
|
||
bodyBytes, err := json.Marshal(body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||
}
|
||
|
||
// 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||
req.Header.Set(hk, hv)
|
||
}
|
||
if modelKey != "" {
|
||
req.Header.Set("Authorization", "Bearer "+modelKey)
|
||
}
|
||
if method != http.MethodGet {
|
||
req.Header.Set("Content-Type", "application/json")
|
||
}
|
||
|
||
// 5)发送请求
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 6)读取响应体
|
||
b, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 7)检查 HTTP 状态码
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
msg := string(b)
|
||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||
}
|
||
return b, nil
|
||
}
|
||
|
||
// // InvokeModel 调用模型服务,返回二进制结果
|
||
//
|
||
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||
// if m == nil || m.BaseURL == "" {
|
||
// return nil, fmt.Errorf("模型配置不完整")
|
||
// }
|
||
// // 请求参数映射
|
||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||
// if err != nil {
|
||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||
// }
|
||
// // 合并请求头
|
||
// headers := util.ForwardHeaders(ctx)
|
||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||
// headers[hk] = hv
|
||
// }
|
||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||
// headers[hk] = hv
|
||
// }
|
||
//
|
||
// // 设置超时
|
||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||
// if timeout <= 0 {
|
||
// timeout = 600 * time.Second
|
||
// }
|
||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||
// defer cancel()
|
||
//
|
||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||
// if method == "" {
|
||
// method = http.MethodPost
|
||
// }
|
||
//
|
||
// var respBytes []byte
|
||
//
|
||
// switch method {
|
||
// case http.MethodGet:
|
||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||
// default:
|
||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||
// }
|
||
// if err != nil {
|
||
// return nil, err
|
||
// }
|
||
// // 响应参数映射
|
||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||
// if err != nil {
|
||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||
// return respBytes, nil
|
||
// }
|
||
// return mappedResponse, nil
|
||
// }
|
||
|
||
// uploadOSS 从临时文件上传 OSS
|
||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||
data, err := os.ReadFile(t.TmpFile)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取临时文件失败: %w", err)
|
||
}
|
||
_, ext := util.DetectFileType(data)
|
||
return gateway.UploadByTask(ctx, data, ext)
|
||
}
|
||
|
||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
|
||
t.State = 3
|
||
t.ErrorMsg = errMsg
|
||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||
}
|
||
|
||
// rollbackToPending 恢复任务状态为 PENDING
|
||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||
}
|