5 Commits

Author SHA1 Message Date
4cc44bf57c fix(task): 修正错误信息注入逻辑 2026-06-18 16:27:11 +08:00
dc06d1bb9a refactor(config): 更新数据库配置和任务处理逻辑 2026-06-18 14:48:42 +08:00
ecaaa5bdbc refactor(files): 优化文件处理和任务服务逻辑 2026-06-18 13:39:40 +08:00
qhd
b21d7a8dbf fix: 修复请求头转发与任务状态流转问题
移除 util.ForwardHeaders,改为从原始请求精确提取 Authorization 或全部请求头;
任务创建时直接设为 Running 状态,避免二次更新与查询;
模型调用使用独立超时上下文,防止外层取消影响回调;
增加 OSS 上传耗时日志,调整数据库连接池参数。
2026-06-18 10:08:36 +08:00
fddaf36f48 refactor(model): 优化模型网关的数据解析和任务处理逻辑 2026-06-17 14:34:48 +08:00
12 changed files with 346 additions and 385 deletions

View File

@@ -1,7 +1,6 @@
package util package util
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"os" "os"
@@ -9,107 +8,75 @@ import (
"strings" "strings"
) )
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定) // DetectFileType 根据返回的二进制内容推断 contentType + 扩展名
func DetectFileType(data []byte) (contentType string, ext string) { func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 { if len(data) == 0 {
return "application/octet-stream", "" return "application/octet-stream", ".bin"
} }
ct := http.DetectContentType(data) ct := http.DetectContentType(data)
// http.DetectContentType 的返回值可能带 charset 等参数(例如 text/plain; charset=utf-8)
if idx := strings.Index(ct, ";"); idx > 0 { if idx := strings.Index(ct, ";"); idx > 0 {
ct = strings.TrimSpace(ct[:idx]) ct = strings.TrimSpace(ct[:idx])
} }
switch ct { switch ct {
case "audio/mpeg": case "audio/mpeg":
return ct, ".mp3" return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav": case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav" return ct, ".wav"
case "audio/mp4", "audio/x-m4a":
return ct, ".m4a"
case "video/mp4": case "video/mp4":
return ct, ".mp4" return ct, ".mp4"
case "video/webm":
return ct, ".webm"
case "image/png": case "image/png":
return ct, ".png" return ct, ".png"
case "image/jpeg": case "image/jpeg":
return ct, ".jpg" return ct, ".jpg"
case "image/gif":
return ct, ".gif"
case "image/webp":
return ct, ".webp"
case "application/pdf": case "application/pdf":
return ct, ".pdf" return ct, ".pdf"
case "text/plain": case "text/plain":
return ct, ".txt" return ct, ".txt"
case "application/json": case "application/json":
return ct, ".json" return ct, ".json"
case "application/zip":
return ct, ".zip"
case "application/octet-stream":
return ct, ".bin"
default: default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
if parts := strings.Split(ct, "/"); len(parts) == 2 { if parts := strings.Split(ct, "/"); len(parts) == 2 {
sub := parts[1] sub := parts[1]
// 避免出现 "plain; charset=utf-8" 之类的后缀
if idx := strings.Index(sub, ";"); idx > 0 { if idx := strings.Index(sub, ";"); idx > 0 {
sub = strings.TrimSpace(sub[:idx]) sub = strings.TrimSpace(sub[:idx])
} }
return ct, "." + sub return ct, "." + sub
} }
return ct, "" return ct, ".bin"
} }
} }
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 // SaveTmpResult 将二进制数据写入临时文件
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) { func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch") dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err return "", fmt.Errorf("创建临时目录失败: %w", err)
} }
if ext == "" { if ext == "" {
ext = ".bin" ext = ".bin"
} }
if ext[0] != '.' { if ext[0] != '.' {
ext = "." + ext ext = "." + ext
} }
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil { if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err return "", fmt.Errorf("写入临时文件失败: %w", err)
} }
return path, nil return path, nil
} }
// SaveTempFileByType writes data to a temporary file named after taskID and
// returns the path of the new file.
//
// A []byte value is written unchanged with a ".mp3" extension (every []byte is
// treated this way, whatever it actually contains). Any other value is
// serialized with json.Marshal and written with a ".json" extension.
// When oldTmpFile is non-empty it is removed first; a failure of that removal
// is ignored.
//
// It returns ("", err) when marshalling or writing fails, and ("", nil) when
// marshalling yields no bytes.
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
// 1. Remove the previous temporary file; the error is deliberately discarded.
if oldTmpFile != "" {
_ = os.Remove(oldTmpFile)
}
var tmpPath string
var tmpErr error
// 2. Raw bytes are saved as-is and always get the .mp3 extension.
if audioData, ok := data.([]byte); ok {
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
} else {
// 3. Any other type is serialized to JSON before being saved.
mappedBytes, err := json.Marshal(data)
if err != nil {
return "", err
}
if len(mappedBytes) == 0 {
return "", nil
}
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
}
// An empty path is reported as "nothing saved", even when tmpErr is nil.
if tmpErr != nil || tmpPath == "" {
return "", tmpErr
}
return tmpPath, nil
}
// saveTmpResult writes data to a file named taskID+ext directly under
// os.TempDir() with mode 0644. It returns the file path together with the
// error from os.WriteFile; the path is returned even when the write fails.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
// ext is appended verbatim; no leading dot is added here.
filename := taskID + ext
tmpPath := filepath.Join(os.TempDir(), filename)
err := os.WriteFile(tmpPath, data, 0644)
return tmpPath, err
}

View File

@@ -19,17 +19,20 @@ import (
tgjson "github.com/tidwall/gjson" tgjson "github.com/tidwall/gjson"
) )
// ParseAndValidate 解析并校验结果 // ParseAndValidate 解析模型响应,并返回标准格式
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组 contentStr := gconv.String(raw[entity.ResponseBody])
contentVal, ok := raw[entity.ResponseBody] if strings.TrimSpace(contentStr) == "" {
if !ok { return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
return raw, fmt.Errorf("字段 %s 不存在", entity.ResponseBody)
} }
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" { contentStr = strings.Map(func(r rune) rune {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", entity.ResponseBody) if r < 32 && r != ' ' {
return -1
} }
return r
}, contentStr)
var arr []any var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err) return raw, fmt.Errorf("JSON解析失败: %w", err)
@@ -38,20 +41,14 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
return raw, fmt.Errorf("解析后数组为空") return raw, fmt.Errorf("解析后数组为空")
} }
// 2) 校验必填字段
if len(model.RequiredFields) > 0 {
for i, r := range arr {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields { for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() { for i, r := range arr {
round, _ := r.(map[string]any)
if round != nil && gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
} }
} }
} }
}
return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
} }

View File

@@ -28,10 +28,10 @@ database:
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效 timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
model_gateway: model_gateway:
- type: "pgsql" - type: "pgsql"
host: "116.204.74.41" host: "192.168.3.30"
port: "15432" port: "5432"
user: "postgres" user: "postgres"
pass: "Bjang09@686^*^" pass: "123456"
name: "model-gateway" name: "model-gateway"
prefix: "" prefix: ""
role: "master" role: "master"
@@ -39,8 +39,8 @@ database:
dryRun: false dryRun: false
charset: "utf8" charset: "utf8"
timezone: "Asia/Shanghai" timezone: "Asia/Shanghai"
maxIdle: 5 maxIdle: 15
maxOpen: 20 maxOpen: 60
maxLifetime: "30s" maxLifetime: "30s"
maxIdleConnTime: "30s" maxIdleConnTime: "30s"
createdAt: "created_at" createdAt: "created_at"

View File

@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
Where(entity.ModelGatewayModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) { func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := ` sql := `
SELECT DISTINCT ON (model_name) * SELECT DISTINCT ON (model_name) *
FROM asynch_models FROM ` + public.TableNameModel + `
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?) AND (? = '' OR model_name LIKE ?)
` `

View File

@@ -7,7 +7,6 @@ import (
"model-gateway/model/entity" "model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb" "gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
// ClaimByID 按主键抢占,返回抢占后的任务 // ClaimByID 按主键抢占,返回抢占后的任务
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) { func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
// 1) 先查任务
var task entity.ModelGatewayTask var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
r, err := tx.Model(public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id). Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Limit(1).
LockUpdate().
One() One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if r.IsEmpty() {
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err = r.Struct(&task); err != nil {
return nil, err
}
// 2) 改为执行中
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
OmitEmpty().
Update()
if err != nil {
return nil, err
}
return &task, nil return &task, nil
} }

View File

@@ -25,6 +25,8 @@ type CreateModelReq struct {
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"` Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"` OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
@@ -61,6 +63,8 @@ type UpdateModelReq struct {
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"` Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"` OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`

View File

@@ -12,7 +12,6 @@ type modelGatewayModelCol struct {
FormJSON string FormJSON string
RequestMapping string RequestMapping string
ResponseMapping string ResponseMapping string
ResponseBody string
RequiredFields string RequiredFields string
IsPrivate string IsPrivate string
IsChatModel string IsChatModel string
@@ -31,6 +30,7 @@ type modelGatewayModelCol struct {
StreamConfig string StreamConfig string
FirstFrame string FirstFrame string
LastFrame string LastFrame string
MaxTokens string
} }
var ModelGatewayModelCol = modelGatewayModelCol{ var ModelGatewayModelCol = modelGatewayModelCol{
@@ -61,6 +61,7 @@ var ModelGatewayModelCol = modelGatewayModelCol{
StreamConfig: "stream_config", StreamConfig: "stream_config",
FirstFrame: "first_frame", FirstFrame: "first_frame",
LastFrame: "last_frame", LastFrame: "last_frame",
MaxTokens: "max_tokens",
} }
type ModelGatewayModel struct { type ModelGatewayModel struct {
@@ -91,9 +92,10 @@ type ModelGatewayModel struct {
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"` FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"` LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
} }
const ( //ResponseMapping 下的字段 const (
ResponseBody = "response_body" //返回主体 ResponseBody = "content" //返回主体(必填)
TotalTokens = "total_tokens" //总token数 TotalTokens = "total_tokens" //总token数
) )

View File

@@ -7,7 +7,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"mime/multipart" "mime/multipart"
"model-gateway/common/util"
"model-gateway/model/entity" "model-gateway/model/entity"
"time" "time"
@@ -43,16 +42,25 @@ func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *Upload
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err := part.Write(data); err != nil { if _, err = part.Write(data); err != nil {
return nil, err return nil, err
} }
contentType := writer.FormDataContentType() //contentType := writer.FormDataContentType()
if err = writer.Close(); err != nil { if err = writer.Close(); err != nil {
return nil, err return nil, err
} }
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers["Content-Type"] = contentType //headers["Content-Type"] = contentType
headers := make(map[string]string)
headers["Content-Type"] = writer.FormDataContentType()
if r := g.RequestFromCtx(ctx); r != nil {
if auth := r.Header.Get("Authorization"); auth != "" {
headers["Authorization"] = auth
}
}
fullURL := "oss/file/uploadFile" fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
@@ -78,15 +86,25 @@ type CallbackPayload struct {
// TriggerCallback 任务的回调 // TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) { func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp struct{} var resp struct{}
payload := CallbackPayload{ payload := CallbackPayload{
TaskId: t.TaskID, TaskId: t.TaskID,
State: t.State, State: t.State,
OssFile: t.ResultFile.OssFile,
FileType: t.ResultFile.FileType,
ErrorMsg: t.ErrorMsg, ErrorMsg: t.ErrorMsg,
} }
if !g.IsEmpty(t.ResultFile) {
payload.OssFile = t.ResultFile.OssFile
payload.FileType = t.ResultFile.FileType
}
jsonData, err := json.Marshal(payload) jsonData, err := json.Marshal(payload)
if err != nil { if err != nil {
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err) g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
@@ -112,7 +130,15 @@ type PromptsCallbackPayload struct {
// TriggerPromptsCallback 任务成功后的提示词回调 // TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) { func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback" callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var resp struct{} var resp struct{}
payload := PromptsCallbackPayload{ payload := PromptsCallbackPayload{
EpicycleId: epicycleId, EpicycleId: epicycleId,
@@ -136,7 +162,15 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epi
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员 // IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func IsSuperAdmin(ctx context.Context) (res bool, err error) { func IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := util.ForwardHeaders(ctx) //headers := util.ForwardHeaders(ctx)
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
var r = make(map[string]bool) var r = make(map[string]bool)
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil { if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err return false, err

View File

@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
ModelName: req.ModelName, ModelName: req.ModelName,
IsChatModel: req.IsChatModel, IsChatModel: req.IsChatModel,
}) })
if err != nil { if err != nil || model == nil {
return nil, err return nil, err
} }
return &dto.GetModelRes{ return &dto.GetModelRes{

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"model-gateway/common/util" "model-gateway/common/util"
"model-gateway/consts/public" "model-gateway/consts/public"
"model-gateway/service/queue"
"time" "time"
"model-gateway/dao" "model-gateway/dao"
@@ -28,12 +27,15 @@ type taskService struct{}
// Create 创建任务 // Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
taskID := uuid.NewString() taskID := uuid.NewString()
startAt := time.Now()
// 1) 检查模型配置,并且获取模型 // 1) 获取用户信息
userInfo, err := utils.GetUserInfo(ctx) userInfo, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2) 检查模型配置
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId, TenantId: userInfo.TenantId,
@@ -48,72 +50,63 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return nil, errors.New("模型不存在或未启用") return nil, errors.New("模型不存在或未启用")
} }
// 2) 排队上限严格控制(Redis 原子闸门) // TODO: 排队控制暂时关闭,后续需要时取消注释
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2) // limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
if limit > 0 { // if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds) // ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
if err != nil { // if err != nil {
return nil, err // return nil, err
} // }
if !ok { // if !ok {
return nil, errors.New("任务排队已满,请稍后再试") // return nil, errors.New("任务排队已满,请稍后再试")
} // }
} // }
// 3) 插入任务记录 // 3) 构建任务实体
requestPayload := entity.RequestPayload{ task := &entity.ModelGatewayTask{
Body: req.RequestPayload, ModelName: model.ModelName,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
}
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
ModelName: req.ModelName,
TaskID: taskID, TaskID: taskID,
State: public.TaskStatusPending, State: public.TaskStatusRunning,
BizName: req.BizName, BizName: req.BizName,
CallbackURL: req.CallbackUrl, CallbackURL: req.CallbackUrl,
RequestPayload: &requestPayload, RequestPayload: &entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
},
EpicycleId: req.EpicycleId, EpicycleId: req.EpicycleId,
})
if err != nil { // 入库失败:回滚闸门占位
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
} }
// 4) 写操作日志(不影响主流程,失败忽略) // 4) 插入任务记录
ip := "" id, err := dao.ModelGatewayTask.Insert(ctx, task)
ua := "" if err != nil {
apiPath := "/task/createTask" // TODO: 恢复排队逻辑后,此处需要回滚排队占位
httpMethod := "POST" // queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
task.Id = id
// 5) 记录操作日志(非关键路径,失败不影响主流程)
ip, ua := "", ""
if r := g.RequestFromCtx(ctx); r != nil { if r := g.RequestFromCtx(ctx); r != nil {
ip = utils.GetLocalIP() ip = utils.GetLocalIP()
ua = r.UserAgent() ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
} }
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{ _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip, IP: ip,
UserAgent: ua, UserAgent: ua,
APIPath: apiPath, APIPath: "/task/createTask",
HttpMethod: httpMethod, HttpMethod: "POST",
BizName: req.BizName, BizName: req.BizName,
ModelName: req.ModelName, ModelName: req.ModelName,
TaskID: taskID, TaskID: taskID,
OpType: "createTask", OpType: "createTask",
Success: 1, Success: 1,
CostMs: time.Since(time.Now()).Milliseconds(), CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: &requestPayload, RequestPayload: task.RequestPayload,
ResponsePayload: gdb.Map{ ResponsePayload: gdb.Map{"taskId": taskID},
"taskId": taskID,
},
}) })
// 5) 获取任务信息 // 6) 异步执行任务
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -19,7 +18,6 @@ import (
"model-gateway/model/dto" "model-gateway/model/dto"
"model-gateway/model/entity" "model-gateway/model/entity"
"model-gateway/service/gateway" "model-gateway/service/gateway"
"model-gateway/service/queue"
"gitea.redpowerfuture.com/red-future/common/beans" "gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
@@ -38,43 +36,28 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
body = task.RequestPayload.Body body = task.RequestPayload.Body
maxRetry = model.RetryTimes maxRetry = model.RetryTimes
startTime = time.Now() startTime = time.Now()
rawBytes []byte
result map[string]any result map[string]any
err error err error
) )
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
// ============================================ // ============================================
// 1) 分布式并发控制 // 1) 调用模型
// ============================================ // ============================================
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName) for attempt := 0; ; attempt++ {
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency) if attempt > 0 {
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
if err != nil { time.Sleep(time.Duration(attempt) * time.Second)
task.DurationSeconds = int64(time.Since(startTime).Seconds())
w.failTask(ctx, task, startTime, err.Error())
return
} }
if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// ============================================
// 2) 调用模型
// ============================================
switch { switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream: case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body) rawBytes, err = InvokeModel(ctx, model, body)
if streamErr != nil { if err == nil {
w.failTask(ctx, task, startTime, streamErr.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
}
case model.CallMode != nil && *model.CallMode == public.CallModeAsync: case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
if err == nil { if err == nil {
@@ -83,22 +66,22 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
default: default:
result, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
} }
if err != nil {
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") &&
!strings.Contains(err.Error(), "InternalServiceError") {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
// ============================================ g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
// 3) 缓存临时文件
// ============================================
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
} }
// ============================================ // ============================================
// 4) 解析校验 + 响应映射(可重试) // 2) 解析校验 + 响应映射(可重试)
// ============================================ // ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime) result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil { if err != nil {
@@ -108,30 +91,26 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
} }
// ============================================ // ============================================
// 5) 上传 OSS(可重试) // 3) 上传 OSS(可重试)
// ============================================ // ============================================
var oss *gateway.UploadFileResponse var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json") oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil { if err == nil {
break break
} }
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
task.State = public.TaskStatusFailed
task.ErrorMsg = err.Error()
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return return
} }
} }
// ============================================ // ============================================
// 6) 成功收尾 // 4) 成功收尾
// ============================================ // ============================================
task.State = public.TaskStatusSuccess task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds()) task.DurationSeconds = int64(time.Since(startTime).Seconds())
@@ -141,52 +120,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
FileSize: int64(oss.FileSize), FileSize: int64(oss.FileSize),
} }
task.TextResult = result task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil { if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
return return
} }
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
if req.EpicycleId != 0 { if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
} }
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s", g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat) task.TaskID, task.DurationSeconds, oss.FileFormat)
_ = os.Remove(task.TmpFile)
}
// callModelStream returns the raw bytes of the model response for a streaming
// call, without applying any response mapping.
//
// When task.Phase is 1 and task.TmpFile is non-blank, the bytes are read back
// from that file; if the read fails or the file is empty, the model is called
// through InvokeModel instead. A fresh model result is saved with
// util.SaveTmpResult and, when that succeeds, task.TmpFile and task.Phase are
// set and the task is persisted. Failures to save or persist do not fail the
// call: the only error returned is the one from InvokeModel.
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
var data []byte
var err error
// Reuse the output recorded by an earlier attempt, if any.
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
// An unreadable or empty file falls through to a fresh model call.
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
data, err = InvokeModel(ctx, model, body)
if err != nil {
return nil, err
}
// The extension is passed as "" and left for util.SaveTmpResult to choose.
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
// NOTE(review): this branch is reached when the DB update fails, yet the
// message reports a temp-file save failure and prints tmpErr, which is
// always nil here; err is the value that holds the actual failure.
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
}
}
return data, nil
}
// asyncResult 异步任务结果 // asyncResult 异步任务结果
@@ -241,67 +187,37 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
} }
} }
// callModel 调用模型 + 检测文件类型 + 保存临时文件 // callModel 调用模型 + 提取文本结果
// 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) { func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
var data []byte data, err := InvokeModel(ctx, model, body)
var err error
// 1) 如果已有临时文件且 phase=1直接读取
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
data = nil
}
}
// 2) 没有可用数据,调用模型
if data == nil {
data, err = InvokeModel(ctx, model, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 3) 检测文件类型,保存临时文件
_, ext := util.DetectFileType(data)
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
}
}
// 4) 检测文件类型,提取文本结果
contentType, _ := util.DetectFileType(data) contentType, _ := util.DetectFileType(data)
var textResult string var textResult string
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data) textResult = string(data)
} }
// 5) 非文本内容,返回错误
if textResult == "" { if textResult == "" {
return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType) return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType)
} }
// 6) 解析并返回
return gjson.New(textResult).Map(), nil return gjson.New(textResult).Map(), nil
} }
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
// 1) 响应映射 // 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body) mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil { if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err) return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +225,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue continue
} }
// 2) 存 token 到数据库,防止后续失败丢失 // 2) 存 token
if _, ok := mapped[entity.TotalTokens]; ok { if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens]) task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens, ExpendTokens: task.ExpendTokens,
}) })
@@ -326,9 +242,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil { if err == nil {
return parsed, nil return parsed, nil
} }
lastErr = err
case public.BuildTypeStruct: case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, entity.ResponseBody) return util.ParseStructResult(mapped, entity.ResponseBody), nil
return parsed, nil
default: default:
return mapped, nil return mapped, nil
} }
@@ -336,20 +252,20 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err) return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
} }
// 4) 重新调模型(直接调,不走缓存) // 4) 拼接错误信息到请求体,重调模型
task.RetryCount++ task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task) _, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil { if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue continue
} }
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil { if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
@@ -361,11 +277,53 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil return body, nil
} }
// injectErrorMessage returns a copy of payload in which a corrective "user"
// message carrying err's text is inserted immediately before the last message
// whose role is "user".
//
// The input map and its "messages" slice are never modified. Callers pass the
// task's stored request body on every retry, so writing the result back into
// that map (or appending into the original slice's backing array) would stack
// one more corrective message per attempt and leak them into the task record.
// Only the top-level map and the messages slice are copied; the individual
// message values are shared with the input.
//
// payload is returned as-is when err is nil, when payload["messages"] is
// missing, empty or not a []any, or when no message has role "user".
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}
	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	// Scan backwards for the last user message; entries that are not
	// map[string]any are skipped.
	lastUserIdx := -1
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		if gconv.String(msg["role"]) == "user" {
			lastUserIdx = i
			break
		}
	}
	if lastUserIdx == -1 {
		return payload
	}

	errMsg := fmt.Sprintf("【上一轮输出错误,请修正】%s", err.Error())
	errMsgObj := map[string]any{
		"role":    "user",
		"content": []map[string]any{{"type": "text", "text": errMsg}},
	}

	// Build the new message list in a fresh slice so the caller's backing
	// array is left untouched.
	updated := make([]any, 0, len(messages)+1)
	updated = append(updated, messages[:lastUserIdx]...)
	updated = append(updated, errMsgObj)
	updated = append(updated, messages[lastUserIdx:]...)

	// Shallow-copy the top-level map so the caller's payload keeps its
	// original "messages" entry.
	out := make(map[string]any, len(payload)+1)
	for k, v := range payload {
		out[k] = v
	}
	out["messages"] = updated
	return out
}
// InvokeModel 调用模型服务,返回二进制结果 // InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1) 记录模型调用次数 // 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName) //_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 // 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
@@ -392,13 +350,20 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
baseURL = baseURL + "?" + q.Encode() baseURL = baseURL + "?" + q.Encode()
} }
} }
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) // 改用独立超时ctx隔绝外层截止
reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
defer reqCancel()
req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil)
//req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default: default:
bodyBytes, err := json.Marshal(body) bodyBytes, err := json.Marshal(body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
defer reqCancel()
req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
//req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
} }
// 5注入请求头先模型静态配置再动态 modelKey后者可覆盖前者 // 5注入请求头先模型静态配置再动态 modelKey后者可覆盖前者
@@ -487,15 +452,11 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
// return mappedResponse, nil // return mappedResponse, nil
// } // }
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 // failTask 任务失败统一处理
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) { func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3 t.State = 3
t.ErrorMsg = errMsg t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds()) t.DurationSeconds = int64(time.Since(startTime).Seconds())
_, err := dao.ModelGatewayTask.Update(ctx, t) _, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态
if err != nil { go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
} }

View File

@@ -33,6 +33,7 @@ CREATE TABLE IF NOT EXISTS model_gateway_models (
max_concurrency int4 NOT NULL DEFAULT 10, max_concurrency int4 NOT NULL DEFAULT 10,
timeout_seconds int4 NOT NULL DEFAULT 600, timeout_seconds int4 NOT NULL DEFAULT 600,
retry_times int2 NOT NULL DEFAULT 3, retry_times int2 NOT NULL DEFAULT 3,
auto_clean_seconds int4 NOT NULL DEFAULT 86400,
response_token_field varchar(128) NOT NULL DEFAULT '', response_token_field varchar(128) NOT NULL DEFAULT '',
call_mode int2 NOT NULL DEFAULT 0, call_mode int2 NOT NULL DEFAULT 0,
required_fields jsonb NOT NULL DEFAULT '[]', required_fields jsonb NOT NULL DEFAULT '[]',
@@ -54,7 +55,6 @@ COMMENT ON COLUMN model_gateway_models.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_models.updater IS '更新人'; COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间'; COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)'; COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称'; COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型'; COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称'; COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
@@ -62,11 +62,12 @@ COMMENT ON COLUMN model_gateway_models.base_url IS '模型地址';
COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST'; COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息'; COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥'; COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化0-私有 1-公共'; COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用0-停用 1-启用'; COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型0-否 1-是'; COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员'; COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构'; COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射'; COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射'; COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
@@ -80,7 +81,9 @@ COMMENT ON COLUMN model_gateway_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数'; COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)'; COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数'; COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN model_gateway_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射'; COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表'; COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数0 表示不传'; COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数0 表示不传';