Files
prompts-core/service/prompt/prompt_build_service.go

205 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"context"
"errors"
"fmt"
"prompts-core/consts/public"
"strings"
"prompts-core/common/util"
"prompts-core/dao"
"prompts-core/model/dto"
"prompts-core/model/entity"
"github.com/gogf/gf/v2/util/gconv"
)
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
ir := NewPromptIR()
switch req.BuildType {
case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
case public.BuildTypeNode:
return buildNodeTypeRequest(ctx, req, chatModel, ir)
default:
return nil, errors.New("不支持的构建类型")
}
}
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
continue
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, aiModel) {
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
}
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel)
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *entity.AsynchModel) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil {
return nil, fmt.Errorf("获取协议配置失败: %w", err)
}
if protocol == nil {
return nil, errors.New("协议配置不存在")
}
providerReq, err := Compile(ir, protocol, chatModel)
if err != nil {
return nil, fmt.Errorf("编译请求失败: %w", err)
}
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
%s
技能名称: %s
用户文件: %v
`, req.ModelName, formInfo, req.SkillName, req.UserFiles)
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
req.ModelName, // %s 目标模型名称
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
totalRounds, // %d 数组长度(多轮输出要求)
totalRounds, // %d 数组长度(结构铁律)
totalRounds, // %d 数组长度(多轮输出要求第二个)
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
totalRounds, // %d 数组长度(最后一行)
)
}
// buildUserFormContent 构建用户表单内容字符串
func buildUserFormContent(userForm []map[string]any) string {
var builder strings.Builder
for _, item := range userForm {
builder.WriteString(fmt.Sprintf("%v\n", item))
}
return builder.String()
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *entity.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
// buildUserPrompt 构建用户提示词
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
userFormForPayload := prepareUserFormPayload(req.UserForm)
payload := map[string]any{
"model": req.ModelName,
"promptInfo": prompt,
"form": req.Form,
"userForm": userFormForPayload,
"userFiles": req.UserFiles,
"userFilesText": FetchFileTexts(ctx, req.UserFiles),
"skills": SkillMdContent(ctx, req.SkillName),
}
return util.MustMarshal(payload)
}
// prepareUserFormPayload 准备用户表单载荷
func prepareUserFormPayload(userForm []map[string]any) any {
if len(userForm) == 0 {
return nil
}
if _, ok := userForm[0]["batch_index"]; ok {
return userForm
}
return mergeUserFormTexts(userForm)
}
// mergeUserFormTexts 合并 UserForm 中的所有文本内容
func mergeUserFormTexts(userForm []map[string]any) string {
var builder strings.Builder
for i, item := range userForm {
text := getItemText(item)
if i > 0 {
builder.WriteString("\n\n")
}
builder.WriteString(text)
}
return builder.String()
}
// NodeBuild 节点构建
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
promptTpl := util.GetBuildPrompt(ctx)
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
}