refactor(service): 重构服务模块结构并优化模型配置

This commit is contained in:
2026-05-29 17:54:19 +08:00
parent d74559ae74
commit 55eb436639
7 changed files with 204 additions and 53 deletions

View File

@@ -20,7 +20,7 @@ import (
type UserPromptPayload struct {
Model string `json:"model"`
PromptInfo string `json:"promptInfo"`
Form map[string]any `json:"form"`
Form any `json:"form"`
UserForm any `json:"userForm"`
Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"`
@@ -30,6 +30,7 @@ type UserPromptPayload struct {
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
@@ -47,9 +48,10 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
//2) 构建历史对话
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
@@ -57,7 +59,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, aiModel) {
@@ -70,7 +71,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
@@ -90,33 +90,33 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"bizName": util.GetServerName(ctx),
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
// promptBuildWithRounds 构建系统提示词
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
ProviderName: chatModel.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
formContent := buildUserFormContent(req.Form)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormContent)
`, formContent, userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
@@ -129,11 +129,8 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, mod
req.ModelName, // %s 目标模型名称
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
totalRounds, // %d 数组长度(多轮输出要求)
totalRounds, // %d 数组长度(结构铁律)
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
totalRounds, // %d 数组长度(最后一行)
)
}
@@ -157,7 +154,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
payload := UserPromptPayload{
Model: req.ModelName,
PromptInfo: prompt,
Form: req.Form,
Form: prepareUserFormPayload(req.Form),
UserForm: prepareUserFormPayload(req.UserForm),
Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult),

View File

@@ -21,13 +21,16 @@ import (
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
//1) 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
//2) 校验用户表单
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
//3) 处理不同类型
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
@@ -54,7 +57,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
@@ -62,10 +65,10 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
return nil, nil, err
}
if aiModel == nil {
if aiModels == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModel, nil
return chatModel, aiModels, nil
}
// validateUserForm 校验用户表单
@@ -150,12 +153,15 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
id := int64(0)
if req.SessionId != "" {
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {