refactor(service): 重构模型网关服务结构
This commit is contained in:
@@ -24,7 +24,6 @@ import (
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -34,11 +33,13 @@ type asyncWorker struct {
|
||||
}
|
||||
|
||||
// handleOne 执行一次完整的任务
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
|
||||
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
|
||||
maxRetry := model.RetryTimes // 重试次数
|
||||
startTime := time.Now()
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
|
||||
var (
|
||||
body = task.RequestPayload.Body // 核心请求参数
|
||||
maxRetry = model.RetryTimes // 重试次数
|
||||
startTime = time.Now()
|
||||
modelMessages = map[string]any{}
|
||||
)
|
||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||
|
||||
// 1) 分布式并发控制
|
||||
@@ -51,8 +52,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Id: task.Id,
|
||||
},
|
||||
State: public.TaskStatusPending,
|
||||
})
|
||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||||
_ = w.rollbackToPending(ctx, task.Id)
|
||||
return
|
||||
}
|
||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||
@@ -65,24 +71,24 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
modelMessages, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
modelMessages, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
|
||||
modelMessages, err = util.PullTaskResult(ctx, modelMessages, model.QueryConfig, model.HeadMsg)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
default:
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
modelMessages, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
@@ -90,20 +96,20 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
}
|
||||
|
||||
// 3) 保存临时文件
|
||||
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
|
||||
tmpPath, err := util.SaveTempFileByType(task.TaskID, modelMessages, task.TmpFile)
|
||||
if err == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_, err = dao.Task.Update(ctx, task)
|
||||
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
|
||||
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
|
||||
modelMessages, err = w.parseAndRetry(ctx, modelMessages, task, model, req, maxRetry, startTime)
|
||||
if err != nil {
|
||||
task.TextResult = body
|
||||
task.TextResult = modelMessages
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -123,9 +129,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
if attempt == maxRetry {
|
||||
task.State = 3
|
||||
task.ErrorMsg = err.Error()
|
||||
task.FinishedAt = gtime.Now()
|
||||
task.Phase = 1
|
||||
_, err = dao.Task.Update(ctx, task)
|
||||
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
}
|
||||
@@ -137,12 +142,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
// 6) 成功回调
|
||||
task.State = 2
|
||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
task.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||||
task.FileType = oss.FileFormat
|
||||
task.TextResult = body
|
||||
task.FileSize = int64(oss.FileSize)
|
||||
|
||||
if _, err = dao.Task.Update(ctx, task); err != nil {
|
||||
task.ResultFile = &entity.ResultFile{
|
||||
OssFile: oss.FileAddressPrefix + oss.FileURL,
|
||||
FileType: oss.FileFormat,
|
||||
FileSize: int64(oss.FileSize),
|
||||
}
|
||||
task.TextResult = modelMessages
|
||||
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
return
|
||||
}
|
||||
@@ -161,7 +167,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
||||
}
|
||||
|
||||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
|
||||
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
@@ -173,8 +179,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
data, err = InvokeModel(ctx, model, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,7 +187,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_, err = dao.Task.Update(ctx, task)
|
||||
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||
}
|
||||
@@ -201,7 +206,7 @@ type asyncResult struct {
|
||||
// asyncTaskChan 全局异步任务等待通道
|
||||
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
|
||||
|
||||
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
|
||||
// 1. 提交异步任务
|
||||
body, err := w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
@@ -246,7 +251,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
|
||||
|
||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||
// 返回: 解析后的响应体, error
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
@@ -261,8 +266,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
||||
|
||||
// 2) 没有可用数据,调用模型
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
data, err = InvokeModel(ctx, model, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -273,7 +277,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_, err = dao.Task.Update(ctx, task)
|
||||
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||
}
|
||||
@@ -297,7 +301,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
|
||||
}
|
||||
|
||||
// parseAndRetry 解析模型返回结果,并重试
|
||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
@@ -316,7 +320,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
||||
// 2) 先存 token 到数据库,防止后续失败丢失
|
||||
if _, ok := mapped[model.ResponseTokenField]; ok {
|
||||
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
|
||||
_, err = dao.Task.Update(ctx, &entity.AsynchTask{
|
||||
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||
ExpendTokens: task.ExpendTokens,
|
||||
})
|
||||
@@ -344,9 +348,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
||||
}
|
||||
|
||||
// 4) 重新调模型(直接调,不走缓存)
|
||||
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
|
||||
reqBody := util.GetModelBody(task.RequestPayload)
|
||||
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
|
||||
task.RetryCount++
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
|
||||
|
||||
if callErr != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
||||
continue
|
||||
@@ -354,7 +359,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
||||
|
||||
// 5) 解析原始响应,覆盖 body 进入下一轮
|
||||
var rawResp map[string]any
|
||||
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
||||
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||
continue
|
||||
}
|
||||
@@ -366,18 +371,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
|
||||
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
||||
// 1) 记录模型调用次数
|
||||
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
|
||||
|
||||
// 2)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
|
||||
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||||
|
||||
// 2)构建请求 URL 和超时
|
||||
// 3)构建请求 URL 和超时
|
||||
baseURL := strings.TrimRight(model.BaseURL, "/")
|
||||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||
client := &http.Client{Timeout: timeout}
|
||||
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
|
||||
|
||||
// 3)构建 HTTP 请求
|
||||
// 4)构建 HTTP 请求
|
||||
var req *http.Request
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
@@ -401,31 +409,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||
// 5)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if modelKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+modelKey)
|
||||
if model.ApiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+model.ApiKey)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// 5)发送请求
|
||||
// 6)发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 6)读取响应体
|
||||
// 7)读取响应体
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 7)检查 HTTP 状态码
|
||||
// 8)检查 HTTP 状态码
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
@@ -488,7 +496,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
|
||||
// }
|
||||
|
||||
// uploadOSS 从临时文件上传 OSS
|
||||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.ModelGatewayTask) (*gateway.UploadFileResponse, error) {
|
||||
data, err := os.ReadFile(t.TmpFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取临时文件失败: %w", err)
|
||||
@@ -498,19 +506,14 @@ func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gat
|
||||
}
|
||||
|
||||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
|
||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
_, err := dao.Task.Update(ctx, t)
|
||||
_, err := dao.ModelGatewayTask.Update(ctx, t)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
|
||||
}
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
}
|
||||
|
||||
// rollbackToPending 恢复任务状态为 PENDING
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user