159 lines
5.6 KiB
Go
159 lines
5.6 KiB
Go
package prompt
|
||
|
||
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/dto"
	"prompts-core/model/entity"
	"prompts-core/service/gateway"

	"gitea.com/red-future/common/utils"
	"github.com/gogf/gf/v2/encoding/gjson"
)
|
||
|
||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||
//1) 构建系统提示词
|
||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||
ir.AddSystem(systemPrompt)
|
||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||
ir.AddUser(userPrompt)
|
||
//2) 检查整体内容是否超出窗口
|
||
if !checkOverallContent(ir, aiModel) {
|
||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||
}
|
||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||
}
|
||
|
||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||
ir.AddUser(NodeBuild(ctx, req))
|
||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||
}
|
||
|
||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||
ir.AddSystem(customPrompt)
|
||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||
}
|
||
|
||
// compileToProviderRequest 编译为 Provider 请求
|
||
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||
if err != nil || protocol == nil {
|
||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||
}
|
||
// 如果传了自定义提示词,替换掉协议模板
|
||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||
protocol.SystemPromptTemplate = customPrompt[0] +
|
||
"【核心铁律】" +
|
||
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||
}
|
||
providerReq, err := Compile(ir, protocol, chatModel)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("编译请求失败: %w", err)
|
||
}
|
||
return map[string]any{
|
||
"modelName": chatModel.ModelName,
|
||
"bizName": util.GetServerName(ctx),
|
||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||
"requestPayload": providerReq,
|
||
"buildType": req.BuildType,
|
||
}, nil
|
||
}
|
||
|
||
// promptBuildWithRounds 构建提示词
|
||
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||
ProviderName: chatModel.OperatorName,
|
||
Status: 1,
|
||
})
|
||
if err != nil || providerProtocol == nil {
|
||
return ""
|
||
}
|
||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||
|
||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||
outputJSON, //【输出结构】 %s
|
||
)
|
||
}
|
||
|
||
// checkOverallContent 检查整体内容是否超出窗口
|
||
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||
fullContent := ir.String()
|
||
return util.CountToken(fullContent, model.TokenConfig)
|
||
}
|
||
|
||
// buildUserPrompt 构建用户提示词
|
||
func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt string) string {
|
||
var b strings.Builder
|
||
b.WriteString(fmt.Sprintf("目标模型:%s\n", req.ModelName))
|
||
if prompt != "" {
|
||
b.WriteString(fmt.Sprintf("系统提示词:%s\n", prompt))
|
||
}
|
||
if skills := SkillMdContent(ctx, req.SkillName); skills != "" {
|
||
b.WriteString(fmt.Sprintf("技能内容:\n%s\n", skills))
|
||
}
|
||
if formText := buildUserFormText(req.Form); formText != "" {
|
||
b.WriteString(fmt.Sprintf("系统参数:\n%s\n", formText))
|
||
}
|
||
if userFormText := buildUserFormText(req.UserForm); userFormText != "" {
|
||
b.WriteString(fmt.Sprintf("用户需求:\n%s\n", userFormText))
|
||
}
|
||
if len(req.Consult) > 0 {
|
||
b.WriteString(fmt.Sprintf("参考附件:%s\n", gjson.New(req.Consult).String()))
|
||
}
|
||
if fileTexts := ExtractFileTexts(ctx, req.Consult); fileTexts != "" {
|
||
b.WriteString(fmt.Sprintf("附件内容:\n%s\n", fileTexts))
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// buildUserFormText renders a list of form entries as plain text.
//
// Each key of each entry becomes a "key:" heading line. A []any value is
// written as a numbered list, one element per line; elements that are
// map[string]any are written as space-separated "key:value" pairs. Any other
// value is written on a single line using %v.
//
// Keys are visited in sorted order, both at the top level and inside nested
// maps, so the same form always produces the same text (Go map iteration
// order is otherwise randomized). Entries keep their slice order.
//
// The result has surrounding whitespace trimmed; an empty form yields "".
func buildUserFormText(form []map[string]any) string {
	if len(form) == 0 {
		return ""
	}

	var builder strings.Builder
	for _, item := range form {
		for _, k := range sortedFormKeys(item) {
			fmt.Fprintf(&builder, "%s:\n", k)

			switch val := item[k].(type) {
			case []any:
				for i, elem := range val {
					fmt.Fprintf(&builder, " %d. ", i+1)
					if m, ok := elem.(map[string]any); ok {
						for _, mk := range sortedFormKeys(m) {
							fmt.Fprintf(&builder, "%s:%v ", mk, m[mk])
						}
					} else {
						fmt.Fprint(&builder, elem)
					}
					builder.WriteString("\n")
				}
			default:
				fmt.Fprintf(&builder, " %v\n", val)
			}
		}
	}
	return strings.TrimSpace(builder.String())
}

// sortedFormKeys returns the keys of m in ascending order.
func sortedFormKeys(m map[string]any) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
|
||
|
||
// NodeBuild 节点构建
|
||
func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||
promptTpl := util.GetBuildPrompt(ctx)
|
||
if promptTpl == "" {
|
||
return ""
|
||
}
|
||
return fmt.Sprintf(promptTpl,
|
||
gjson.New(req.Form).MustToJsonString(),
|
||
gjson.New(req.UserForm).MustToJsonString(),
|
||
)
|
||
}
|