1 Commits

9 changed files with 212 additions and 137 deletions

View File

@@ -19,17 +19,20 @@ import (
tgjson "github.com/tidwall/gjson"
)
// ParseAndValidate 解析并校验结果
// ParseAndValidate 解析模型响应,并返回标准格式
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[entity.ResponseBody]
if !ok {
return raw, fmt.Errorf("字段 %s 不存在", entity.ResponseBody)
}
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", entity.ResponseBody)
contentStr := gconv.String(raw[entity.ResponseBody])
if strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
}
contentStr = strings.Map(func(r rune) rune {
if r < 32 && r != ' ' {
return -1
}
return r
}, contentStr)
var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err)
@@ -38,17 +41,11 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
return raw, fmt.Errorf("解析后数组为空")
}
// 2) 校验必填字段
if len(model.RequiredFields) > 0 {
for _, field := range model.RequiredFields {
for i, r := range arr {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
round, _ := r.(map[string]any)
if round != nil && gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
}
}

View File

@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
Fields(fields).One()
if err != nil {
return nil, err
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
FROM ` + public.TableNameModel + `
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
`

View File

@@ -7,7 +7,6 @@ import (
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
// ClaimByID 按主键抢占,返回抢占后的任务
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
// 1) 先查任务
var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
r, err := tx.Model(public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Limit(1).
LockUpdate().
One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err = r.Struct(&task); err != nil {
return nil, err
}
// 2) 改为执行中
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
OmitEmpty().
Update()
if err != nil {
return nil, err
}
return &task, nil
}

View File

@@ -9,34 +9,36 @@ import (
// CreateModelReq 添加模型配置
type CreateModelReq struct {
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数默认10"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数默认3"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间默认86400"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type CreateModelRes struct {
@@ -44,35 +46,37 @@ type CreateModelRes struct {
}
type UpdateModelReq struct {
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type UpdateModelRes struct {

View File

@@ -12,7 +12,6 @@ type modelGatewayModelCol struct {
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
RequiredFields string
IsPrivate string
IsChatModel string
@@ -31,6 +30,7 @@ type modelGatewayModelCol struct {
StreamConfig string
FirstFrame string
LastFrame string
MaxTokens string
}
var ModelGatewayModelCol = modelGatewayModelCol{
@@ -61,6 +61,7 @@ var ModelGatewayModelCol = modelGatewayModelCol{
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
MaxTokens: "max_tokens",
}
type ModelGatewayModel struct {
@@ -91,9 +92,10 @@ type ModelGatewayModel struct {
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
}
const ( //ResponseMapping 下的字段
ResponseBody = "response_body" //返回主体
TotalTokens = "total_tokens" //总token数
const (
ResponseBody = "content" //返回主体(必填)
TotalTokens = "total_tokens" //总token数
)

View File

@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
ModelName: req.ModelName,
IsChatModel: req.IsChatModel,
})
if err != nil {
if err != nil || model == nil {
return nil, err
}
return &dto.GetModelRes{

View File

@@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
},
})
// 5) 获取任务信息
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
// 5) 抢占任务:改为执行中
rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
State: public.TaskStatusRunning,
})
if err != nil {
return nil, err
}
if rows == 0 {
return nil, fmt.Errorf("任务不存在: id=%d", id)
}
// 6) 查询任务信息
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
})
if err != nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
// 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
// ============================================
// 2) 调用模型
// ============================================
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error())
for attempt := 0; ; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second)
}
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
err = streamErr
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") {
w.failTask(ctx, task, startTime, err.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
}
// ============================================
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue
}
// 2) 存 token 到数据库,防止后续失败丢失
// 2) 存 token
if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
})
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil {
return parsed, nil
}
lastErr = err
case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, entity.ResponseBody)
return parsed, nil
return util.ParseStructResult(mapped, entity.ResponseBody), nil
default:
return mapped, nil
}
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
}
// 4) 重新调模型(直接调,不走缓存)
// 4) 拼接错误信息到请求体,重调模型
task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
}
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil {
if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil
}
// injectErrorMessage 将错误信息拼接到 user 消息中
func injectErrorMessage(payload map[string]any, err error) map[string]any {
if err == nil {
return payload
}
messages, _ := payload["messages"].([]any)
if len(messages) == 0 {
return payload
}
errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())
// 找到最后一个 role=user 的消息,追加错误提示
for i := len(messages) - 1; i >= 0; i-- {
msg, ok := messages[i].(map[string]any)
if !ok {
continue
}
if gconv.String(msg["role"]) != "user" {
continue
}
switch c := msg["content"].(type) {
case string:
msg["content"] = c + errMsg
case []any:
msg["content"] = append(c, map[string]any{
"type": "text",
"text": errMsg,
})
}
break
}
return payload
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {

View File

@@ -33,6 +33,7 @@ CREATE TABLE IF NOT EXISTS model_gateway_models (
max_concurrency int4 NOT NULL DEFAULT 10,
timeout_seconds int4 NOT NULL DEFAULT 600,
retry_times int2 NOT NULL DEFAULT 3,
auto_clean_seconds int4 NOT NULL DEFAULT 86400,
response_token_field varchar(128) NOT NULL DEFAULT '',
call_mode int2 NOT NULL DEFAULT 0,
required_fields jsonb NOT NULL DEFAULT '[]',
@@ -54,7 +55,6 @@ COMMENT ON COLUMN model_gateway_models.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
@@ -62,11 +62,12 @@ COMMENT ON COLUMN model_gateway_models.base_url IS '模型地址';
COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
@@ -80,7 +81,9 @@ COMMENT ON COLUMN model_gateway_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN model_gateway_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数0 表示不传';