diff --git a/common/util/json.go b/common/util/json.go deleted file mode 100644 index 2da838b..0000000 --- a/common/util/json.go +++ /dev/null @@ -1,28 +0,0 @@ -package util - -import ( - "encoding/json" - - "github.com/gogf/gf/v2/container/gvar" -) - -func ParseJSONField(field any) any { - var v *gvar.Var - switch val := field.(type) { - case *gvar.Var: - v = val - default: - return field - } - - if v == nil || v.IsNil() || v.IsEmpty() { - return nil - } - - str := v.String() - var result any - if json.Unmarshal([]byte(str), &result) == nil { - return result - } - return str -} diff --git a/model/dto/model_dto.go b/model/dto/model_dto.go index 89ba260..693ed5f 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_dto.go @@ -1,6 +1,8 @@ package dto import ( + "model-gateway/model/entity" + "gitea.com/red-future/common/beans" "github.com/gogf/gf/v2/frame/g" ) @@ -8,31 +10,33 @@ import ( // CreateModelReq 添加模型配置 type CreateModelReq struct { g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"` - ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"` - ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"` - BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"` - HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"` - HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"` - IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` - Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"` - IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` - IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` - OperatorName string 
`p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` - TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` - ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"` - Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"` - RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` - ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` - ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"` - ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` - MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"` - QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"` - TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"` - ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"` - RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"` - RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"` - AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"` - Remark string `p:"remark" json:"remark" dc:"备注说明"` + ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"` + ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"` + BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"` + HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"` + HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"` + IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` + Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"` + IsChatModel *int 
`p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` + IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` + ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` + QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` + ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"` + Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"` + RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` + ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` + ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` + MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"` + QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"` + TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"` + ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"` + RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"` + RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"` + AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"` + Remark string `p:"remark" json:"remark" dc:"备注说明"` } type CreateModelRes struct { @@ -41,32 +45,34 @@ type CreateModelRes struct { type UpdateModelReq struct { g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"` - ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"` - ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` - 
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"` - BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"` - HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"` - HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"` - ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"` - Form any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"` - RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"` - ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"` - ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"` - ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` - Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` - IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` - IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` - IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` - OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` - TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` - MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"` - QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"` - TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"` - ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"` - RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"` - RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"` - AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"` - Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"` + ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"` + 
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` + ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"` + BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"` + HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"` + HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"` + ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"` + Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"` + RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"` + ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"` + ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` + Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` + IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` + IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` + IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` + ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` + QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` + MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"` + QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"` + TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"` + ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"` + RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"` + RetryQueueMaxSeconds int 
`p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"` + AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"` + Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"` } type UpdateModelRes struct { @@ -85,13 +91,14 @@ type DeleteModelRes struct { // GetModelReq 获取模型配置详情 type GetModelReq struct { - g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"` - ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` - Creator string `p:"creator" json:"creator" dc:"创建人"` + g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"` + ID int64 `p:"id" json:"id,string" dc:"配置ID"` + Creator string `p:"creator" json:"creator" dc:"创建人"` + ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` } type GetModelRes struct { - Model any `json:"model" dc:"模型配置详情"` + Model *entity.AsynchModel `json:"model" dc:"模型配置详情"` } // ListModelReq 配置列表 diff --git a/model/dto/task_dto.go b/model/dto/task_dto.go index 57f189f..45f7ba7 100644 --- a/model/dto/task_dto.go +++ b/model/dto/task_dto.go @@ -20,7 +20,7 @@ type CreateTaskRes struct { // GetTaskResultReq 获取结果(只返回 oss 地址) type GetTaskResultReq struct { g.Meta `path:"/getTaskResult" method:"get" tags:"任务管理" summary:"获取任务结果" dc:"根据任务ID获取结果(只返回OSS地址)"` - TaskID string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"` + TaskID string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"` } type GetTaskResultRes struct { diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index b1d655e..ab59c7d 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -30,6 +30,8 @@ type asynchModelCol struct { IsOwner string OperatorName string TokenConfig string + ExtendMapping string + QueryConfig string } var AsynchModelCol = asynchModelCol{ @@ -60,35 +62,39 @@ var AsynchModelCol = asynchModelCol{ IsOwner: "is_owner", OperatorName: "operator_name", TokenConfig: 
"token_config", + ExtendMapping: "extend_mapping", + QueryConfig: "query_config", } // AsynchModel 异步模型配置 type AsynchModel struct { beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - ModelType int `orm:"model_type" json:"modelType"` - BaseURL string `orm:"base_url" json:"baseUrl"` - HttpMethod string `orm:"http_method" json:"httpMethod"` - HeadMsg string `orm:"head_msg" json:"headMsg"` - Form any `orm:"form_json" json:"form"` - RequestMapping any `orm:"request_mapping" json:"requestMapping"` - ResponseMapping any `orm:"response_mapping" json:"responseMapping"` - ResponseBody any `orm:"response_body" json:"responseBody"` - ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` - Prompt string `orm:"prompt" json:"prompt"` - IsPrivate *int `orm:"is_private" json:"isPrivate"` - IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` - ApiKey string `orm:"api_key" json:"apiKey"` - Enabled *int `orm:"enabled" json:"enabled"` - MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` - QueueLimit int `orm:"queue_limit" json:"queueLimit"` - TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` - ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` - RetryTimes int `orm:"retry_times" json:"retryTimes"` - RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` - AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` - Remark string `orm:"remark" json:"remark"` - IsOwner *int `json:"isOwner" orm:"is_owner"` - OperatorName string `orm:"operator_name" json:"operatorName"` - TokenConfig any `orm:"token_config" json:"tokenConfig"` + ModelName string `orm:"model_name" json:"modelName"` + ModelType int `orm:"model_type" json:"modelType"` + BaseURL string `orm:"base_url" json:"baseUrl"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + HeadMsg string `orm:"head_msg" json:"headMsg"` + Form map[string]any `orm:"form_json" 
json:"form"` + RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` + ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` + ResponseBody map[string]any `orm:"response_body" json:"responseBody"` + ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` + Prompt string `orm:"prompt" json:"prompt"` + IsPrivate *int `orm:"is_private" json:"isPrivate"` + IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` + ApiKey string `orm:"api_key" json:"apiKey"` + Enabled *int `orm:"enabled" json:"enabled"` + MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` + QueueLimit int `orm:"queue_limit" json:"queueLimit"` + TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` + ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` + RetryTimes int `orm:"retry_times" json:"retryTimes"` + RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` + AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` + Remark string `orm:"remark" json:"remark"` + IsOwner *int `json:"isOwner" orm:"is_owner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` + ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` + QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` } diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index d5c71ee..0c9e24f 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -63,18 +63,27 @@ func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileEx return resp.FileURL, nil } -// TriggerCallback 任务成功后的回调: -// - JSON body 参数:task_id/state/oss_file/file_type/text(可选) +// CallbackPayload 回调请求体 +type CallbackPayload struct { + TaskId string `json:"task_id"` + State int `json:"state"` + OssFile string 
`json:"oss_file"` + FileType string `json:"file_type"` + Text string `json:"text"` + ErrorMsg string `json:"error_msg"` +} + +// TriggerCallback 任务成功后的回调 func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { headers := util.ForwardHeaders(ctx) - var req struct{} - payload := map[string]interface{}{ - "task_id": t.TaskID, - "state": t.State, - "oss_file": t.OssFile, - "file_type": t.FileType, - "text": t.TextResult, - "error_msg": t.ErrorMsg, + var resp struct{} + payload := CallbackPayload{ + TaskId: t.TaskID, + State: t.State, + OssFile: t.OssFile, + FileType: t.FileType, + Text: t.TextResult, + ErrorMsg: t.ErrorMsg, } jsonData, err := json.Marshal(payload) if err != nil { @@ -84,7 +93,7 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(headers), len(jsonData)) - err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData) + err = commonHttp.Post(ctx, t.CallbackURL, headers, &resp, jsonData) if err != nil { g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err) return @@ -92,15 +101,20 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData)) } +// PromptsCallbackPayload 提示词回调请求体 +type PromptsCallbackPayload struct { + EpicycleId int64 `json:"epicycleId"` + Text string `json:"text"` +} + // TriggerPromptsCallback 任务成功后的提示词回调 -// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息) func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { callbackURL := "prompts-core/session/sessionCallback" headers := util.ForwardHeaders(ctx) - var req struct{} - payload := map[string]interface{}{ - "epicycleId": epicycleId, - "text": t.TextResult, + var resp struct{} + payload := PromptsCallbackPayload{ + EpicycleId: epicycleId, + Text: t.TextResult, } jsonData, err := 
json.Marshal(payload) if err != nil { @@ -110,7 +124,7 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleI g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节", t.EpicycleId, callbackURL, len(headers), len(jsonData)) - err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData) + err = commonHttp.Post(ctx, callbackURL, headers, &resp, jsonData) if err != nil { g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err) return diff --git a/service/model_invoker.go b/service/model_invoker.go index 6b1d7d9..8d3ef64 100644 --- a/service/model_invoker.go +++ b/service/model_invoker.go @@ -6,13 +6,12 @@ import ( "encoding/json" "fmt" "io" + "model-gateway/model/entity" "net/http" "net/url" "strings" "time" - "model-gateway/model/entity" - "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/frame/g" "github.com/tidwall/gjson" @@ -196,6 +195,59 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK return mappedResponse, nil } +//// InvokeModel 调用模型服务,返回二进制结果 +//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { +// if m == nil || m.BaseURL == "" { +// return nil, fmt.Errorf("模型配置不完整") +// } +// // 请求参数映射 +// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) +// if err != nil { +// return nil, fmt.Errorf("请求参数映射失败: %w", err) +// } +// // 合并请求头 +// headers := util.ForwardHeaders(ctx) +// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { +// headers[hk] = hv +// } +// for hk, hv := range parseHeadMsgHeaders(modelKey) { +// headers[hk] = hv +// } +// +// // 设置超时 +// timeout := time.Duration(m.TimeoutSeconds) * time.Second +// if timeout <= 0 { +// timeout = 600 * time.Second +// } +// ctx, cancel := context.WithTimeout(ctx, timeout) +// defer cancel() +// +// invokeUrl := strings.TrimRight(m.BaseURL, "/") +// method := 
strings.ToUpper(strings.TrimSpace(m.HttpMethod)) +// if method == "" { +// method = http.MethodPost +// } +// +// var respBytes []byte +// +// switch method { +// case http.MethodGet: +// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) +// default: +// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) +// } +// if err != nil { +// return nil, err +// } +// // 响应参数映射 +// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) +// if err != nil { +// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) +// return respBytes, nil +// } +// return mappedResponse, nil +//} + // ============================================ // 映射相关函数 // ============================================ diff --git a/service/model_service.go b/service/model_service.go index efcba22..7334e13 100644 --- a/service/model_service.go +++ b/service/model_service.go @@ -3,7 +3,6 @@ package service import ( "context" "errors" - "model-gateway/common/util" "model-gateway/consts/public" "model-gateway/dao" "model-gateway/model/dto" @@ -87,6 +86,8 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, + ExtendMapping: req.ExtendMapping, + QueryConfig: req.QueryConfig, }) if err != nil { return nil, err @@ -150,6 +151,8 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, + ExtendMapping: req.ExtendMapping, + QueryConfig: req.QueryConfig, }) if err != nil { return err @@ -195,6 +198,8 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, + ExtendMapping: req.ExtendMapping, + QueryConfig: req.QueryConfig, }) return err } @@ -225,6 +230,8 @@ func (s *modelService) Update(ctx context.Context, req 
*dto.UpdateModelReq) erro IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, + ExtendMapping: req.ExtendMapping, + QueryConfig: req.QueryConfig, }) return err } @@ -237,17 +244,20 @@ func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) erro } func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } model, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + SQLBaseDO: beans.SQLBaseDO{ + Id: req.ID, + Creator: user.UserName, + }, + ModelName: req.ModelName, }) if err != nil { return nil, err } - model.Form = util.ParseJSONField(model.Form) - model.RequestMapping = util.ParseJSONField(model.RequestMapping) - model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) - model.ResponseBody = util.ParseJSONField(model.ResponseBody) - model.TokenConfig = util.ParseJSONField(model.TokenConfig) return &dto.GetModelRes{ Model: model, }, nil @@ -277,14 +287,6 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dt return } - // 处理列表中每条记录的 JSONB 字段 - for _, m := range models { - m.Form = util.ParseJSONField(m.Form) - m.RequestMapping = util.ParseJSONField(m.RequestMapping) - m.ResponseMapping = util.ParseJSONField(m.ResponseMapping) - m.ResponseBody = util.ParseJSONField(m.ResponseBody) - m.TokenConfig = util.ParseJSONField(m.TokenConfig) - } return &dto.ListModelRes{ List: models, Total: total, @@ -381,11 +383,6 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR if model == nil { return nil, nil } - model.Form = util.ParseJSONField(model.Form) - model.RequestMapping = util.ParseJSONField(model.RequestMapping) - model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) - model.ResponseBody = util.ParseJSONField(model.ResponseBody) - model.TokenConfig = util.ParseJSONField(model.TokenConfig) 
return &dto.GetIsChatModelRes{ Model: model, }, nil diff --git a/service/worker.go b/service/worker.go index 51779ba..43f0214 100644 --- a/service/worker.go +++ b/service/worker.go @@ -262,8 +262,8 @@ func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { } // GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值 -func GetExpendTokens(tokenMapping string, textResult string) int { - value := gjson.Get(textResult, tokenMapping) +func GetExpendTokens(responseTokenField string, textResult string) int { + value := gjson.Get(textResult, responseTokenField) if value.Exists() { return int(value.Int()) } diff --git a/update.sql b/update.sql index c916a95..d83b2de 100644 --- a/update.sql +++ b/update.sql @@ -51,6 +51,8 @@ CREATE TABLE IF NOT EXISTS asynch_models ( "reserve_ratio": 0.2, "min_reserve": 512, - }'::jsonb -- Token配置 + }'::jsonb, -- Token配置 + extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, + query_config JSONB NOT NULL DEFAULT '{}'::jsonb ); CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL; CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name); @@ -94,6 +96,8 @@ COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的 COMMENT ON COLUMN asynch_models.remark IS '备注'; COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射'; COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称'; +COMMENT ON COLUMN asynch_models.extend_mapping IS '附加映射(请求时候额外字段)'; +COMMENT ON COLUMN asynch_models.query_config IS '查询结果配置(通过task_id查询结果相关配置)'; COMMENT ON COLUMN asynch_models.token_config IS '{ "zh_ratio": 1.0, // 中文字符→token系数 "en_ratio": 1.3, // 英文单词→token系数