Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 445ee02c5a |
@@ -1,6 +1,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -8,75 +9,107 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名
|
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return "application/octet-stream", ".bin"
|
return "application/octet-stream", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
ct := http.DetectContentType(data)
|
ct := http.DetectContentType(data)
|
||||||
|
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||||
ct = strings.TrimSpace(ct[:idx])
|
ct = strings.TrimSpace(ct[:idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
switch ct {
|
switch ct {
|
||||||
case "audio/mpeg":
|
case "audio/mpeg":
|
||||||
return ct, ".mp3"
|
return ct, ".mp3"
|
||||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||||
return ct, ".wav"
|
return ct, ".wav"
|
||||||
case "audio/mp4", "audio/x-m4a":
|
|
||||||
return ct, ".m4a"
|
|
||||||
case "video/mp4":
|
case "video/mp4":
|
||||||
return ct, ".mp4"
|
return ct, ".mp4"
|
||||||
case "video/webm":
|
|
||||||
return ct, ".webm"
|
|
||||||
case "image/png":
|
case "image/png":
|
||||||
return ct, ".png"
|
return ct, ".png"
|
||||||
case "image/jpeg":
|
case "image/jpeg":
|
||||||
return ct, ".jpg"
|
return ct, ".jpg"
|
||||||
case "image/gif":
|
|
||||||
return ct, ".gif"
|
|
||||||
case "image/webp":
|
|
||||||
return ct, ".webp"
|
|
||||||
case "application/pdf":
|
case "application/pdf":
|
||||||
return ct, ".pdf"
|
return ct, ".pdf"
|
||||||
case "text/plain":
|
case "text/plain":
|
||||||
return ct, ".txt"
|
return ct, ".txt"
|
||||||
case "application/json":
|
case "application/json":
|
||||||
return ct, ".json"
|
return ct, ".json"
|
||||||
case "application/zip":
|
|
||||||
return ct, ".zip"
|
|
||||||
case "application/octet-stream":
|
|
||||||
return ct, ".bin"
|
|
||||||
default:
|
default:
|
||||||
|
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||||
sub := parts[1]
|
sub := parts[1]
|
||||||
|
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||||
sub = strings.TrimSpace(sub[:idx])
|
sub = strings.TrimSpace(sub[:idx])
|
||||||
}
|
}
|
||||||
return ct, "." + sub
|
return ct, "." + sub
|
||||||
}
|
}
|
||||||
return ct, ".bin"
|
return ct, ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTmpResult 将二进制数据写入临时文件
|
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||||
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
return "", fmt.Errorf("创建临时目录失败: %w", err)
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ext == "" {
|
if ext == "" {
|
||||||
ext = ".bin"
|
ext = ".bin"
|
||||||
}
|
}
|
||||||
if ext[0] != '.' {
|
if ext[0] != '.' {
|
||||||
ext = "." + ext
|
ext = "." + ext
|
||||||
}
|
}
|
||||||
|
|
||||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||||
return "", fmt.Errorf("写入临时文件失败: %w", err)
|
return "", err
|
||||||
}
|
}
|
||||||
return path, nil
|
return path, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveTempFileByType
|
||||||
|
// 根据传入的数据自动判断:
|
||||||
|
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
|
||||||
|
// 若是任意结构体/map → 自动转 JSON 保存
|
||||||
|
// 返回:新临时文件路径、错误
|
||||||
|
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
|
||||||
|
// 1. 先清理旧临时文件(统一逻辑)
|
||||||
|
if oldTmpFile != "" {
|
||||||
|
_ = os.Remove(oldTmpFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tmpPath string
|
||||||
|
var tmpErr error
|
||||||
|
|
||||||
|
// 2. 判断是否是二进制音频([]byte + .mp3)
|
||||||
|
if audioData, ok := data.([]byte); ok {
|
||||||
|
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
|
||||||
|
} else {
|
||||||
|
// 3. 其他类型 → 序列化为 JSON 保存
|
||||||
|
mappedBytes, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(mappedBytes) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tmpErr != nil || tmpPath == "" {
|
||||||
|
return "", tmpErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return tmpPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTmpResult 你原有的底层保存文件方法(保留不动)
|
||||||
|
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||||
|
// 你原来实现,比如:
|
||||||
|
filename := taskID + ext
|
||||||
|
tmpPath := filepath.Join(os.TempDir(), filename)
|
||||||
|
err := os.WriteFile(tmpPath, data, 0644)
|
||||||
|
return tmpPath, err
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,20 +19,17 @@ import (
|
|||||||
tgjson "github.com/tidwall/gjson"
|
tgjson "github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseAndValidate 解析模型响应,并返回标准格式
|
// ParseAndValidate 解析并校验结果
|
||||||
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
|
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
|
||||||
contentStr := gconv.String(raw[entity.ResponseBody])
|
// 1) 解析 content 字符串为 rounds 数组
|
||||||
if strings.TrimSpace(contentStr) == "" {
|
contentVal, ok := raw[entity.ResponseBody]
|
||||||
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
|
if !ok {
|
||||||
|
return raw, fmt.Errorf("字段 %s 不存在", entity.ResponseBody)
|
||||||
|
}
|
||||||
|
contentStr, ok := contentVal.(string)
|
||||||
|
if !ok || strings.TrimSpace(contentStr) == "" {
|
||||||
|
return raw, fmt.Errorf("字段 %s 为空或不是字符串", entity.ResponseBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
contentStr = strings.Map(func(r rune) rune {
|
|
||||||
if r < 32 && r != ' ' {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}, contentStr)
|
|
||||||
|
|
||||||
var arr []any
|
var arr []any
|
||||||
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
||||||
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
||||||
@@ -41,11 +38,17 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
|
|||||||
return raw, fmt.Errorf("解析后数组为空")
|
return raw, fmt.Errorf("解析后数组为空")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range model.RequiredFields {
|
// 2) 校验必填字段
|
||||||
|
if len(model.RequiredFields) > 0 {
|
||||||
for i, r := range arr {
|
for i, r := range arr {
|
||||||
round, _ := r.(map[string]any)
|
round, ok := r.(map[string]any)
|
||||||
if round != nil && gjson.New(round).Get(field).IsNil() {
|
if !ok {
|
||||||
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
continue
|
||||||
|
}
|
||||||
|
for _, field := range model.RequiredFields {
|
||||||
|
if gjson.New(round).Get(field).IsNil() {
|
||||||
|
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
10
config.yml
10
config.yml
@@ -28,10 +28,10 @@ database:
|
|||||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||||
model_gateway:
|
model_gateway:
|
||||||
- type: "pgsql"
|
- type: "pgsql"
|
||||||
host: "192.168.3.30"
|
host: "116.204.74.41"
|
||||||
port: "5432"
|
port: "15432"
|
||||||
user: "postgres"
|
user: "postgres"
|
||||||
pass: "123456"
|
pass: "Bjang09@686^*^"
|
||||||
name: "model-gateway"
|
name: "model-gateway"
|
||||||
prefix: ""
|
prefix: ""
|
||||||
role: "master"
|
role: "master"
|
||||||
@@ -39,8 +39,8 @@ database:
|
|||||||
dryRun: false
|
dryRun: false
|
||||||
charset: "utf8"
|
charset: "utf8"
|
||||||
timezone: "Asia/Shanghai"
|
timezone: "Asia/Shanghai"
|
||||||
maxIdle: 15
|
maxIdle: 5
|
||||||
maxOpen: 60
|
maxOpen: 20
|
||||||
maxLifetime: "30s"
|
maxLifetime: "30s"
|
||||||
maxIdleConnTime: "30s"
|
maxIdleConnTime: "30s"
|
||||||
createdAt: "created_at"
|
createdAt: "created_at"
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
|
|||||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||||
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
|
|
||||||
Fields(fields).One()
|
Fields(fields).One()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -123,7 +122,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
|
|||||||
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
||||||
sql := `
|
sql := `
|
||||||
SELECT DISTINCT ON (model_name) *
|
SELECT DISTINCT ON (model_name) *
|
||||||
FROM ` + public.TableNameModel + `
|
FROM asynch_models
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
AND (? = '' OR model_name LIKE ?)
|
AND (? = '' OR model_name LIKE ?)
|
||||||
`
|
`
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||||
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -127,32 +128,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
|
|||||||
|
|
||||||
// ClaimByID 按主键抢占,返回抢占后的任务
|
// ClaimByID 按主键抢占,返回抢占后的任务
|
||||||
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
||||||
// 1) 先查任务
|
|
||||||
var task entity.ModelGatewayTask
|
var task entity.ModelGatewayTask
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
r, err := tx.Model(public.TableNameTask).
|
||||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||||
One()
|
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
||||||
|
Limit(1).
|
||||||
|
LockUpdate().
|
||||||
|
One()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r.IsEmpty() {
|
||||||
|
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
||||||
|
}
|
||||||
|
if err := r.Struct(&task); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Model(public.TableNameTask).
|
||||||
|
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
||||||
|
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||||
|
OmitEmpty().
|
||||||
|
Update()
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if r.IsEmpty() {
|
|
||||||
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
|
||||||
}
|
|
||||||
if err = r.Struct(&task); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2) 改为执行中
|
|
||||||
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
|
||||||
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
|
||||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
|
||||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
|
|
||||||
OmitEmpty().
|
|
||||||
Update()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &task, nil
|
return &task, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,36 +9,34 @@ import (
|
|||||||
|
|
||||||
// CreateModelReq 添加模型配置
|
// CreateModelReq 添加模型配置
|
||||||
type CreateModelReq struct {
|
type CreateModelReq struct {
|
||||||
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
||||||
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
|
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
|
||||||
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
|
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
|
||||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
|
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
|
||||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||||
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
||||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
||||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
||||||
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
||||||
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
||||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
||||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
||||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||||
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
||||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
||||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
||||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
||||||
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
||||||
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||||
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
|
||||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateModelRes struct {
|
type CreateModelRes struct {
|
||||||
@@ -46,37 +44,35 @@ type CreateModelRes struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UpdateModelReq struct {
|
type UpdateModelReq struct {
|
||||||
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
||||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
|
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
|
||||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
|
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
|
||||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
|
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
|
||||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST"`
|
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST"`
|
||||||
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
||||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
||||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
||||||
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
||||||
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
||||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
||||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
||||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||||
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
||||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
||||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
||||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
||||||
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
||||||
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
|
||||||
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
|
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
|
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
|
||||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
|
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
|
|
||||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateModelRes struct {
|
type UpdateModelRes struct {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type modelGatewayModelCol struct {
|
|||||||
FormJSON string
|
FormJSON string
|
||||||
RequestMapping string
|
RequestMapping string
|
||||||
ResponseMapping string
|
ResponseMapping string
|
||||||
|
ResponseBody string
|
||||||
RequiredFields string
|
RequiredFields string
|
||||||
IsPrivate string
|
IsPrivate string
|
||||||
IsChatModel string
|
IsChatModel string
|
||||||
@@ -30,7 +31,6 @@ type modelGatewayModelCol struct {
|
|||||||
StreamConfig string
|
StreamConfig string
|
||||||
FirstFrame string
|
FirstFrame string
|
||||||
LastFrame string
|
LastFrame string
|
||||||
MaxTokens string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModelGatewayModelCol = modelGatewayModelCol{
|
var ModelGatewayModelCol = modelGatewayModelCol{
|
||||||
@@ -61,7 +61,6 @@ var ModelGatewayModelCol = modelGatewayModelCol{
|
|||||||
StreamConfig: "stream_config",
|
StreamConfig: "stream_config",
|
||||||
FirstFrame: "first_frame",
|
FirstFrame: "first_frame",
|
||||||
LastFrame: "last_frame",
|
LastFrame: "last_frame",
|
||||||
MaxTokens: "max_tokens",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelGatewayModel struct {
|
type ModelGatewayModel struct {
|
||||||
@@ -92,10 +91,9 @@ type ModelGatewayModel struct {
|
|||||||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||||
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||||
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||||
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const ( //ResponseMapping 下的字段
|
||||||
ResponseBody = "content" //返回主体(必填)
|
ResponseBody = "response_body" //返回主体
|
||||||
TotalTokens = "total_tokens" //总token数
|
TotalTokens = "total_tokens" //总token数
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
|
"model-gateway/common/util"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -42,25 +43,16 @@ func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *Upload
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err = part.Write(data); err != nil {
|
if _, err := part.Write(data); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
//contentType := writer.FormDataContentType()
|
contentType := writer.FormDataContentType()
|
||||||
if err = writer.Close(); err != nil {
|
if err = writer.Close(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
//headers["Content-Type"] = contentType
|
headers["Content-Type"] = contentType
|
||||||
|
|
||||||
headers := make(map[string]string)
|
|
||||||
headers["Content-Type"] = writer.FormDataContentType()
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
|
||||||
headers["Authorization"] = auth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fullURL := "oss/file/uploadFile"
|
fullURL := "oss/file/uploadFile"
|
||||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||||
|
|
||||||
@@ -86,25 +78,15 @@ type CallbackPayload struct {
|
|||||||
|
|
||||||
// TriggerCallback 任务的回调
|
// TriggerCallback 任务的回调
|
||||||
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
|
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
|
||||||
//headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
headers := make(map[string]string)
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
for k, v := range r.Request.Header {
|
|
||||||
if len(v) > 0 {
|
|
||||||
headers[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var resp struct{}
|
var resp struct{}
|
||||||
payload := CallbackPayload{
|
payload := CallbackPayload{
|
||||||
TaskId: t.TaskID,
|
TaskId: t.TaskID,
|
||||||
State: t.State,
|
State: t.State,
|
||||||
|
OssFile: t.ResultFile.OssFile,
|
||||||
|
FileType: t.ResultFile.FileType,
|
||||||
ErrorMsg: t.ErrorMsg,
|
ErrorMsg: t.ErrorMsg,
|
||||||
}
|
}
|
||||||
if !g.IsEmpty(t.ResultFile) {
|
|
||||||
payload.OssFile = t.ResultFile.OssFile
|
|
||||||
payload.FileType = t.ResultFile.FileType
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(payload)
|
jsonData, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||||
@@ -130,15 +112,7 @@ type PromptsCallbackPayload struct {
|
|||||||
// TriggerPromptsCallback 任务成功后的提示词回调
|
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||||
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
|
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
|
||||||
callbackURL := "prompts-core/session/callback"
|
callbackURL := "prompts-core/session/callback"
|
||||||
//headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
headers := make(map[string]string)
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
for k, v := range r.Request.Header {
|
|
||||||
if len(v) > 0 {
|
|
||||||
headers[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var resp struct{}
|
var resp struct{}
|
||||||
payload := PromptsCallbackPayload{
|
payload := PromptsCallbackPayload{
|
||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
@@ -162,15 +136,7 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epi
|
|||||||
|
|
||||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||||
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||||
//headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
headers := make(map[string]string)
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
for k, v := range r.Request.Header {
|
|
||||||
if len(v) > 0 {
|
|
||||||
headers[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var r = make(map[string]bool)
|
var r = make(map[string]bool)
|
||||||
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
|
|||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
IsChatModel: req.IsChatModel,
|
IsChatModel: req.IsChatModel,
|
||||||
})
|
})
|
||||||
if err != nil || model == nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &dto.GetModelRes{
|
return &dto.GetModelRes{
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"model-gateway/common/util"
|
"model-gateway/common/util"
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
|
"model-gateway/service/queue"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
@@ -27,15 +28,12 @@ type taskService struct{}
|
|||||||
// Create 创建任务
|
// Create 创建任务
|
||||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||||
taskID := uuid.NewString()
|
taskID := uuid.NewString()
|
||||||
startAt := time.Now()
|
|
||||||
|
|
||||||
// 1) 获取用户信息
|
// 1) 检查模型配置,并且获取模型
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
userInfo, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 检查模型配置
|
|
||||||
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
SQLBaseDO: beans.SQLBaseDO{
|
||||||
TenantId: userInfo.TenantId,
|
TenantId: userInfo.TenantId,
|
||||||
@@ -50,63 +48,72 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
return nil, errors.New("模型不存在或未启用")
|
return nil, errors.New("模型不存在或未启用")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: 排队控制暂时关闭,后续需要时取消注释
|
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||||
// limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
||||||
// if limit > 0 {
|
if limit > 0 {
|
||||||
// ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// return nil, err
|
return nil, err
|
||||||
// }
|
}
|
||||||
// if !ok {
|
if !ok {
|
||||||
// return nil, errors.New("任务排队已满,请稍后再试")
|
return nil, errors.New("任务排队已满,请稍后再试")
|
||||||
// }
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
// 3) 构建任务实体
|
|
||||||
task := &entity.ModelGatewayTask{
|
|
||||||
ModelName: model.ModelName,
|
|
||||||
TaskID: taskID,
|
|
||||||
State: public.TaskStatusRunning,
|
|
||||||
BizName: req.BizName,
|
|
||||||
CallbackURL: req.CallbackUrl,
|
|
||||||
RequestPayload: &entity.RequestPayload{
|
|
||||||
Body: req.RequestPayload,
|
|
||||||
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
|
||||||
},
|
|
||||||
EpicycleId: req.EpicycleId,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 插入任务记录
|
// 3) 插入任务记录
|
||||||
id, err := dao.ModelGatewayTask.Insert(ctx, task)
|
requestPayload := entity.RequestPayload{
|
||||||
if err != nil {
|
Body: req.RequestPayload,
|
||||||
// TODO: 恢复排队逻辑后,此处需要回滚排队占位
|
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
||||||
// queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
}
|
||||||
|
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
TaskID: taskID,
|
||||||
|
State: public.TaskStatusPending,
|
||||||
|
BizName: req.BizName,
|
||||||
|
CallbackURL: req.CallbackUrl,
|
||||||
|
RequestPayload: &requestPayload,
|
||||||
|
EpicycleId: req.EpicycleId,
|
||||||
|
})
|
||||||
|
if err != nil { // 入库失败:回滚闸门占位
|
||||||
|
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
task.Id = id
|
|
||||||
|
|
||||||
// 5) 记录操作日志(非关键路径,失败不影响主流程)
|
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||||
ip, ua := "", ""
|
ip := ""
|
||||||
|
ua := ""
|
||||||
|
apiPath := "/task/createTask"
|
||||||
|
httpMethod := "POST"
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
ip = utils.GetLocalIP()
|
ip = utils.GetLocalIP()
|
||||||
ua = r.UserAgent()
|
ua = r.UserAgent()
|
||||||
|
apiPath = r.URL.Path
|
||||||
|
httpMethod = r.Method
|
||||||
}
|
}
|
||||||
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
|
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
UserAgent: ua,
|
UserAgent: ua,
|
||||||
APIPath: "/task/createTask",
|
APIPath: apiPath,
|
||||||
HttpMethod: "POST",
|
HttpMethod: httpMethod,
|
||||||
BizName: req.BizName,
|
BizName: req.BizName,
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
OpType: "createTask",
|
OpType: "createTask",
|
||||||
Success: 1,
|
Success: 1,
|
||||||
CostMs: time.Since(startAt).Milliseconds(),
|
CostMs: time.Since(time.Now()).Milliseconds(),
|
||||||
RequestPayload: task.RequestPayload,
|
RequestPayload: &requestPayload,
|
||||||
ResponsePayload: gdb.Map{"taskId": taskID},
|
ResponsePayload: gdb.Map{
|
||||||
|
"taskId": taskID,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
// 6) 异步执行任务
|
// 5) 获取任务信息
|
||||||
|
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5) 创建成功后立即异步尝试执行当前任务
|
||||||
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
||||||
|
|
||||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,6 +19,7 @@ import (
|
|||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
"model-gateway/service/gateway"
|
"model-gateway/service/gateway"
|
||||||
|
"model-gateway/service/queue"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
@@ -36,52 +38,67 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
body = task.RequestPayload.Body
|
body = task.RequestPayload.Body
|
||||||
maxRetry = model.RetryTimes
|
maxRetry = model.RetryTimes
|
||||||
startTime = time.Now()
|
startTime = time.Now()
|
||||||
rawBytes []byte
|
|
||||||
result map[string]any
|
result map[string]any
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||||
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
|
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 1) 调用模型
|
// 1) 分布式并发控制
|
||||||
// ============================================
|
// ============================================
|
||||||
for attempt := 0; ; attempt++ {
|
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
||||||
if attempt > 0 {
|
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
||||||
g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
|
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||||
time.Sleep(time.Duration(attempt) * time.Second)
|
if err != nil {
|
||||||
}
|
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
|
w.failTask(ctx, task, startTime, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !acquired {
|
||||||
|
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||||
|
State: public.TaskStatusPending,
|
||||||
|
})
|
||||||
|
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||||
|
|
||||||
switch {
|
// ============================================
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
// 2) 调用模型
|
||||||
rawBytes, err = InvokeModel(ctx, model, body)
|
// ============================================
|
||||||
if err == nil {
|
switch {
|
||||||
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||||||
}
|
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
if streamErr != nil {
|
||||||
result, err = w.callModel(ctx, task, model, body)
|
w.failTask(ctx, task, startTime, streamErr.Error())
|
||||||
if err == nil {
|
|
||||||
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
result, err = w.callModel(ctx, task, model, body)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(err.Error(), "Timeout") &&
|
|
||||||
!strings.Contains(err.Error(), "InternalServiceError") {
|
|
||||||
w.failTask(ctx, task, startTime, err.Error())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||||
g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||||
|
result, err = w.callModel(ctx, task, model, body)
|
||||||
|
if err == nil {
|
||||||
|
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
result, err = w.callModel(ctx, task, model, body)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
w.failTask(ctx, task, startTime, err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 2) 解析校验 + 响应映射(可重试)
|
// 3) 缓存临时文件
|
||||||
|
// ============================================
|
||||||
|
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
|
||||||
|
task.TmpFile = tmpPath
|
||||||
|
task.Phase = 1
|
||||||
|
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 4) 解析校验 + 响应映射(可重试)
|
||||||
// ============================================
|
// ============================================
|
||||||
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
|
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -91,26 +108,30 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 3) 上传 OSS(可重试)
|
// 5) 上传 OSS(可重试)
|
||||||
// ============================================
|
// ============================================
|
||||||
var oss *gateway.UploadFileResponse
|
var oss *gateway.UploadFileResponse
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||||
}
|
}
|
||||||
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
|
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
|
task.State = public.TaskStatusFailed
|
||||||
|
task.ErrorMsg = err.Error()
|
||||||
|
task.Phase = 1
|
||||||
|
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 4) 成功收尾
|
// 6) 成功收尾
|
||||||
// ============================================
|
// ============================================
|
||||||
task.State = public.TaskStatusSuccess
|
task.State = public.TaskStatusSuccess
|
||||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
@@ -120,19 +141,52 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
FileSize: int64(oss.FileSize),
|
FileSize: int64(oss.FileSize),
|
||||||
}
|
}
|
||||||
task.TextResult = result
|
task.TextResult = result
|
||||||
|
|
||||||
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
||||||
g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
|
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
|
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||||
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
|
||||||
if req.EpicycleId != 0 {
|
if req.EpicycleId != 0 {
|
||||||
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
|
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
|
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
|
||||||
task.TaskID, task.DurationSeconds, oss.FileFormat)
|
task.TaskID, task.DurationSeconds, oss.FileFormat)
|
||||||
|
|
||||||
|
_ = os.Remove(task.TmpFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||||
|
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
||||||
|
var data []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||||
|
data, err = os.ReadFile(task.TmpFile)
|
||||||
|
if err != nil || len(data) == 0 {
|
||||||
|
data = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if data == nil {
|
||||||
|
data, err = InvokeModel(ctx, model, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
|
||||||
|
if tmpErr == nil && tmpPath != "" {
|
||||||
|
task.TmpFile = tmpPath
|
||||||
|
task.Phase = 1
|
||||||
|
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// asyncResult 异步任务结果
|
// asyncResult 异步任务结果
|
||||||
@@ -187,37 +241,67 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// callModel 调用模型 + 提取文本结果
|
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||||
|
// 返回: 解析后的响应体, error
|
||||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
|
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
|
||||||
data, err := InvokeModel(ctx, model, body)
|
var data []byte
|
||||||
if err != nil {
|
var err error
|
||||||
return nil, err
|
|
||||||
|
// 1) 如果已有临时文件且 phase=1,直接读取
|
||||||
|
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||||
|
data, err = os.ReadFile(task.TmpFile)
|
||||||
|
if err != nil || len(data) == 0 {
|
||||||
|
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
|
||||||
|
data = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 2) 没有可用数据,调用模型
|
||||||
|
if data == nil {
|
||||||
|
data, err = InvokeModel(ctx, model, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 检测文件类型,保存临时文件
|
||||||
|
_, ext := util.DetectFileType(data)
|
||||||
|
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
|
||||||
|
if tmpErr == nil && tmpPath != "" {
|
||||||
|
task.TmpFile = tmpPath
|
||||||
|
task.Phase = 1
|
||||||
|
_, err = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 检测文件类型,提取文本结果
|
||||||
contentType, _ := util.DetectFileType(data)
|
contentType, _ := util.DetectFileType(data)
|
||||||
var textResult string
|
var textResult string
|
||||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||||
textResult = string(data)
|
textResult = string(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5) 非文本内容,返回错误
|
||||||
if textResult == "" {
|
if textResult == "" {
|
||||||
return nil, fmt.Errorf("模型返回非文本内容,contentType=%s", contentType)
|
return nil, fmt.Errorf("模型返回非文本内容,contentType=%s", contentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 6) 解析并返回
|
||||||
return gjson.New(textResult).Map(), nil
|
return gjson.New(textResult).Map(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAndRetry 解析模型返回结果,并重试
|
// parseAndRetry 解析模型返回结果,并重试
|
||||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||||
var lastErr error
|
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1) 响应映射
|
// 1) 响应映射
|
||||||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
|
||||||
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
||||||
@@ -225,10 +309,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 存 token
|
// 2) 先存 token 到数据库,防止后续失败丢失
|
||||||
if _, ok := mapped[entity.TotalTokens]; ok {
|
if _, ok := mapped[entity.TotalTokens]; ok {
|
||||||
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
|
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||||
ExpendTokens: task.ExpendTokens,
|
ExpendTokens: task.ExpendTokens,
|
||||||
})
|
})
|
||||||
@@ -242,9 +326,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
lastErr = err
|
|
||||||
case public.BuildTypeStruct:
|
case public.BuildTypeStruct:
|
||||||
return util.ParseStructResult(mapped, entity.ResponseBody), nil
|
parsed = util.ParseStructResult(mapped, entity.ResponseBody)
|
||||||
|
return parsed, nil
|
||||||
default:
|
default:
|
||||||
return mapped, nil
|
return mapped, nil
|
||||||
}
|
}
|
||||||
@@ -252,20 +336,20 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
|
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
|
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 拼接错误信息到请求体,重调模型
|
// 4) 重新调模型(直接调,不走缓存)
|
||||||
task.RetryCount++
|
task.RetryCount++
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
|
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
|
||||||
|
|
||||||
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
|
|
||||||
rawData, callErr := InvokeModel(ctx, model, body)
|
|
||||||
if callErr != nil {
|
if callErr != nil {
|
||||||
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5) 解析原始响应,覆盖 body 进入下一轮
|
||||||
var rawResp map[string]any
|
var rawResp map[string]any
|
||||||
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
||||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||||
@@ -277,53 +361,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// injectErrorMessage 将错误信息插入到最后一个 user 消息之前
|
|
||||||
func injectErrorMessage(payload map[string]any, err error) map[string]any {
|
|
||||||
if err == nil {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
messages, _ := payload["messages"].([]any)
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
errMsg := fmt.Sprintf("【上一轮输出错误,请修正】%s", err.Error())
|
|
||||||
|
|
||||||
// 找到最后一个 user 的位置
|
|
||||||
lastUserIdx := -1
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
msg, ok := messages[i].(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if gconv.String(msg["role"]) == "user" {
|
|
||||||
lastUserIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastUserIdx == -1 {
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
// 在最后一个 user 之前插入错误消息
|
|
||||||
errMsgObj := map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": []map[string]any{{"type": "text", "text": errMsg}},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 切片插入
|
|
||||||
messages = append(messages[:lastUserIdx], append([]any{errMsgObj}, messages[lastUserIdx:]...)...)
|
|
||||||
payload["messages"] = messages
|
|
||||||
return payload
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvokeModel 调用模型服务,返回二进制结果
|
// InvokeModel 调用模型服务,返回二进制结果
|
||||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||||
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
||||||
// 1) 记录模型调用次数
|
// 1) 记录模型调用次数
|
||||||
//_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
|
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
|
||||||
|
|
||||||
// 2)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
// 2)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||||
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
|
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
|
||||||
@@ -350,20 +392,13 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
|
|||||||
baseURL = baseURL + "?" + q.Encode()
|
baseURL = baseURL + "?" + q.Encode()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 改用独立超时ctx,隔绝外层截止
|
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||||
reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
defer reqCancel()
|
|
||||||
req, err = http.NewRequestWithContext(reqCtx, http.MethodGet, baseURL, nil)
|
|
||||||
//req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
|
||||||
default:
|
default:
|
||||||
bodyBytes, err := json.Marshal(body)
|
bodyBytes, err := json.Marshal(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
reqCtx, reqCancel := context.WithTimeout(context.Background(), timeout)
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||||||
defer reqCancel()
|
|
||||||
req, err = http.NewRequestWithContext(reqCtx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
|
||||||
//req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
// 5)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||||
@@ -452,11 +487,15 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
|
|||||||
// return mappedResponse, nil
|
// return mappedResponse, nil
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// failTask 任务失败统一处理
|
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = errMsg
|
t.ErrorMsg = errMsg
|
||||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态
|
_, err := dao.ModelGatewayTask.Update(ctx, t)
|
||||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
|
||||||
|
}
|
||||||
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ CREATE TABLE IF NOT EXISTS model_gateway_models (
|
|||||||
max_concurrency int4 NOT NULL DEFAULT 10,
|
max_concurrency int4 NOT NULL DEFAULT 10,
|
||||||
timeout_seconds int4 NOT NULL DEFAULT 600,
|
timeout_seconds int4 NOT NULL DEFAULT 600,
|
||||||
retry_times int2 NOT NULL DEFAULT 3,
|
retry_times int2 NOT NULL DEFAULT 3,
|
||||||
auto_clean_seconds int4 NOT NULL DEFAULT 86400,
|
|
||||||
response_token_field varchar(128) NOT NULL DEFAULT '',
|
response_token_field varchar(128) NOT NULL DEFAULT '',
|
||||||
call_mode int2 NOT NULL DEFAULT 0,
|
call_mode int2 NOT NULL DEFAULT 0,
|
||||||
required_fields jsonb NOT NULL DEFAULT '[]',
|
required_fields jsonb NOT NULL DEFAULT '[]',
|
||||||
@@ -55,6 +54,7 @@ COMMENT ON COLUMN model_gateway_models.created_at IS '创建时间';
|
|||||||
COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
|
COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
|
||||||
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
|
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
|
||||||
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
|
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
|
||||||
|
|
||||||
COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
|
COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
|
||||||
COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
|
COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
|
||||||
COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
|
COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
|
||||||
@@ -62,12 +62,11 @@ COMMENT ON COLUMN model_gateway_models.base_url IS '模型地址';
|
|||||||
COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
|
COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
|
||||||
COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
|
COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
|
||||||
COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
|
COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
|
||||||
|
|
||||||
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化:0-私有 1-公共';
|
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化:0-私有 1-公共';
|
||||||
COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用:0-停用 1-启用';
|
COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用:0-停用 1-启用';
|
||||||
COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型:0-否 1-是';
|
COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型:0-否 1-是';
|
||||||
COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
|
COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
|
||||||
|
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式:0-同步 1-异步 2-流式';
|
||||||
COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
|
COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
|
||||||
COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
|
COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
|
||||||
COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
|
COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
|
||||||
@@ -81,9 +80,7 @@ COMMENT ON COLUMN model_gateway_models.last_frame IS '尾帧图片参数';
|
|||||||
COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
|
COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
|
||||||
COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
|
COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
|
||||||
COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
|
COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
|
||||||
COMMENT ON COLUMN model_gateway_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
|
|
||||||
COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
|
COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
|
||||||
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式:0-同步 1-异步 2-流式';
|
|
||||||
COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
|
COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
|
||||||
COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数,0 表示不传';
|
COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数,0 表示不传';
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user