From 3fa2896fc3c3b2efb35669cc09e185a5cde4154d Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 3 Jun 2026 13:30:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor(util):=20=E9=87=8D=E6=9E=84=E6=98=A0?= =?UTF-8?q?=E5=B0=84=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=BC=82=E6=AD=A5=E4=BB=BB=E5=8A=A1=E8=BD=AE=E8=AF=A2?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/mapping.go | 36 --------------- controller/prompt_compose_controller.go | 3 +- controller/prompt_session_controller.go | 5 -- dao/model_dao.go | 13 +----- service/gateway/gateway_http_service.go | 54 ++++++++++++++++++++++ service/prompt/prompt_build_service.go | 13 +++--- service/prompt/prompt_compose_service.go | 20 ++++---- service/prompt/prompt_ir_service.go | 7 +-- service/prompt/prompt_user_form_batches.go | 4 +- 9 files changed, 80 insertions(+), 75 deletions(-) diff --git a/common/util/mapping.go b/common/util/mapping.go index 5616af3..60ee815 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -1,49 +1,13 @@ package util import ( - "fmt" "net/url" - "prompts-core/model/entity" "strings" "github.com/gogf/gf/v2/encoding/gjson" - "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) -// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 -// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段 -func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { - // 1) 获取校验配置,并取值 - requestMapping := model.RequestMapping - contentStr, ok := raw[model.ResponseBody].(string) - if !ok || contentStr == "" { - return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody) - } - - // 2) 解析 content 为 JSON 数组 - var rounds []map[string]any - if err := gjson.DecodeTo(contentStr, &rounds); err != nil { - return fmt.Errorf("解析 content JSON 数组失败: %w", err) - } - if len(rounds) == 0 { - return fmt.Errorf("content 数组为空") - } - - // 3) 逐条校验:只检查默认值为空的必填字段是否存在 - for i, round := range rounds { - for path, defaultValue := range requestMapping { - if !g.IsEmpty(defaultValue) { - continue - } - if gjson.New(round).Get(path).IsNil() { - return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path) - } - } - } - return nil -} - // ReverseMap 映射 payload 到 mapping func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { jsonObj := gjson.New("{}") diff --git a/controller/prompt_compose_controller.go b/controller/prompt_compose_controller.go index 1996b94..d23fb98 100644 --- a/controller/prompt_compose_controller.go +++ b/controller/prompt_compose_controller.go @@ -6,6 +6,7 @@ import ( "prompts-core/dao" "prompts-core/model/dto" "prompts-core/model/entity" + "prompts-core/service/gateway" promptService "prompts-core/service/prompt" @@ -42,7 +43,7 @@ func (c *prompt) Text(ctx context.Context, req *dto.TextReq) (res *dto.TextRes, if err != nil { return } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, ModelName: composeTask.ModelName, }) diff --git a/controller/prompt_session_controller.go b/controller/prompt_session_controller.go index bb5dae4..84c3a94 100644 --- a/controller/prompt_session_controller.go +++ b/controller/prompt_session_controller.go @@ -16,8 +16,3 @@ var Session = new(session) func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { return sessionService.Callback(ctx, req) } - -//TODO:后期历史相关服务可能拆分(三个接口) -// 1. 添加历史会话 -// 2. 获取历史会话 -// 3. 更新历史信息 diff --git a/dao/model_dao.go b/dao/model_dao.go index c3aa59a..7aabdc7 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -16,6 +16,7 @@ type modelDao struct{} func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). OmitEmpty(). + Where(entity.AsynchModelCol.Id, req.Id). Where(entity.AsynchModelCol.Creator, req.Creator). Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). Where(entity.AsynchModelCol.ModelName, req.ModelName). @@ -26,15 +27,3 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s err = r.Struct(&m) return } - -// GetsByModelName 批量获取模型 -func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) { - err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). - OmitEmpty(). - Where(entity.AsynchModelCol.Creator, creator). - WhereIn(entity.AsynchModelCol.ModelName, modelNames). - Fields(fields). - Scan(&list) - - return -} diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index bdc755c..f8fbc2b 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -7,6 +7,7 @@ import ( "prompts-core/common/util" "prompts-core/model/entity" + "gitea.com/red-future/common/beans" commonHttp "gitea.com/red-future/common/http" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" @@ -37,6 +38,59 @@ func CreateGatewayTask(ctx context.Context, payload map[string]any) (string, err return req.TaskId, nil } +type GetModelConfigResp struct { + Model *AsynchModel `json:"model"` +} + +type AsynchModel struct { + beans.SQLBaseDO `orm:",inline"` + ModelName string `orm:"model_name" json:"modelName"` + ModelType int `orm:"model_type" json:"modelType"` + BaseURL string `orm:"base_url" json:"baseUrl"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + HeadMsg map[string]any `orm:"head_msg" json:"headMsg"` + Form []map[string]any `orm:"form_json" json:"form"` + RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` + ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` + ResponseBody string `orm:"response_body" json:"responseBody"` + ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` + IsPrivate *int `orm:"is_private" json:"isPrivate"` + IsChatModel int `orm:"is_chat_model" json:"isChatModel"` + IsAsync *int `orm:"is_async" json:"isAsync"` + IsStream *int `orm:"is_stream" json:"isStream"` + ApiKey string `orm:"api_key" json:"apiKey"` + Enabled *int `orm:"enabled" json:"enabled"` + MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` + TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` + RetryTimes int `orm:"retry_times" json:"retryTimes"` + AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` + IsOwner *int `json:"isOwner" orm:"is_owner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` + ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` + QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` + StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` + FirstFrame string `orm:"first_frame" json:"firstFrame"` + LastFrame string `orm:"last_frame" json:"lastFrame"` + CallbackUrl string `orm:"callback_url" json:"callbackUrl"` +} + +// GetModelConfig 获取模型配置 +func GetModelConfig(ctx context.Context, req *AsynchModel) (model *AsynchModel, err error) { + fmt.Println("req参数", req) + fullURL := fmt.Sprintf("model-gateway/model/getModel?creator=%s&modelName=%s&isChatModel=%d", + req.Creator, req.ModelName, req.IsChatModel) + headers := util.ForwardHeaders(ctx) + var resp GetModelConfigResp + if err = commonHttp.Get(ctx, fullURL, headers, &resp, nil); err != nil { + return nil, fmt.Errorf("获取模型配置失败: %w", err) + } + if resp.Model == nil { + return nil, fmt.Errorf("模型不存在: creator=%s modelName=%s isChatModel=%d", req.Creator, req.ModelName, req.IsChatModel) + } + return resp.Model, nil +} + // GetTaskResultRes 任务结果响应 type GetTaskResultRes struct { OssFile string `json:"ossFile" dc:"结果文件OSS地址"` diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index d86a4ec..a73816d 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "prompts-core/consts/public" + "prompts-core/service/gateway" "strings" "prompts-core/common/util" @@ -29,7 +30,7 @@ type UserPromptPayload struct { } // buildInferenceRequest 构建推理请求 -func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) { +func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (map[string]any, error) { //1) 处理表单分批 processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel) if err != nil { @@ -47,7 +48,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha } // buildPromptTypeRequest 构建提示词类型请求(BuildType=1) -func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { +func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) { //1) 构建系统提示词 systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches) ir.AddSystem(systemPrompt) @@ -69,13 +70,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai } // buildNodeTypeRequest 构建节点类型请求(BuildType=2) -func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) { +func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { ir.AddUser(NodeBuild(ctx, req)) return compileToProviderRequest(ctx, ir, chatModel) } // compileToProviderRequest 编译为 Provider 请求 -func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) { +func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel) (map[string]any, error) { protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName) if err != nil { return nil, fmt.Errorf("获取协议配置失败: %w", err) @@ -97,7 +98,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti } // promptBuildWithRounds 构建系统提示词 -func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string { +func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, batches int) string { providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ ProviderName: chatModel.OperatorName, Status: 1, @@ -144,7 +145,7 @@ func buildUserFormContent(userForm []map[string]any) string { } // checkOverallContent 检查整体内容是否超出窗口 -func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool { +func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool { fullContent := ir.String() return util.CountToken(fullContent, model.TokenConfig) } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 999c128..9076471 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -42,14 +42,14 @@ func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.Com } // GetModelMessage 获取模型信息 -func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.AsynchModel, *entity.AsynchModel, error) { +func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*gateway.AsynchModel, *gateway.AsynchModel, error) { userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) } - chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + chatModel, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, - IsChatModel: new(1), + IsChatModel: 1, }) if err != nil { return nil, nil, err @@ -57,8 +57,8 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. if chatModel == nil { return nil, nil, errors.New("当前没有对话模型,请添加") } - aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + aiModels, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{TenantId: userInfo.TenantId, Creator: userInfo.UserName}, ModelName: req.ModelName, }) if err != nil { @@ -72,7 +72,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. } // validateUserForm 校验用户表单 -func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) error { +func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) error { if len(req.UserForm) == 0 { return nil } @@ -90,7 +90,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er } // handlePromptBuild 处理提示词构建(BuildType=1) -func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { +func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) { // 获取历史会话 history, err := session.GetHistoryMessages(ctx, req.SessionId) if err != nil { @@ -123,7 +123,7 @@ func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatMod } // handleNodeBuild 处理节点构建(BuildType=2) -func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { +func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) { taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) if err != nil { return nil, fmt.Errorf("调用推理模型失败: %w", err) @@ -148,7 +148,7 @@ func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel } // callInferenceModel 调用推理模型 -func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) { +func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel, history []map[string]any) (string, int64, error) { taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history) if err != nil { return "", 0, fmt.Errorf("构建推理请求失败: %w", err) @@ -186,7 +186,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { if err != nil { return fmt.Errorf("查询任务失败: %w", err) } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator}, ModelName: composeTask.ModelName, }) diff --git a/service/prompt/prompt_ir_service.go b/service/prompt/prompt_ir_service.go index d9d554e..7bdac56 100644 --- a/service/prompt/prompt_ir_service.go +++ b/service/prompt/prompt_ir_service.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "prompts-core/common/util" + "prompts-core/service/gateway" "strings" "prompts-core/dao" @@ -178,7 +179,7 @@ func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { } // Compile 将 PromptIR 按协议配置编译为 Provider Request -func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) { +func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) { if ir == nil || p == nil { return nil, fmt.Errorf("ir and protocol are required") } @@ -262,7 +263,7 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { } // buildRequest 按 target_field 和 request_template 构建请求体 -func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any { +func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any { if len(p.RequestTemplate) > 0 { return renderTemplate(p.RequestTemplate, messages, chatModel) } @@ -273,7 +274,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent } // renderTemplate 简单的 {{key}} 模板替换 -func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any { +func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any { b, _ := json.Marshal(tmpl) str := string(b) diff --git a/service/prompt/prompt_user_form_batches.go b/service/prompt/prompt_user_form_batches.go index d5feb17..740d896 100644 --- a/service/prompt/prompt_user_form_batches.go +++ b/service/prompt/prompt_user_form_batches.go @@ -3,17 +3,17 @@ package prompt import ( "context" "fmt" + "prompts-core/service/gateway" "strings" "github.com/gogf/gf/v2/frame/g" "prompts-core/common/util" "prompts-core/model/dto" - "prompts-core/model/entity" ) // ProcessUserFormBatches 处理 UserForm 分批(按 token 大小拼接内容) -func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel) (*dto.ComposeMessagesReq, int, error) { +func ProcessUserFormBatches(ctx context.Context, req *dto.ComposeMessagesReq, model *gateway.AsynchModel) (*dto.ComposeMessagesReq, int, error) { if model.TokenConfig == nil || len(req.UserForm) == 0 { return req, 1, nil }