refactor(prompts-core): 重构代码结构和优化工具函数

This commit is contained in:
2026-06-10 14:51:25 +08:00
parent 1c1db7e30c
commit b69e7386e2
10 changed files with 164 additions and 432 deletions

View File

@@ -13,11 +13,10 @@ import (
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt)
@@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
// 提取 userForm 中的 prompt 作为自定义提示词
var customPrompt string
for _, item := range req.UserForm {
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
customPrompt = gconv.String(prompt)
break
}
}
// 用户消息
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil || protocol == nil {
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
@@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
}, nil
}
// promptBuildWithRounds 构建提示词
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName,
@@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
outputJSON, //【输出结构】 %s
@@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
@@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
return b.String()
}
// buildUserFormText 构建用户表单内容字符串
func buildUserFormText(form []map[string]any) string {
if len(form) == 0 {
return ""
@@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string {
var builder strings.Builder
for _, item := range form {
for k, v := range item {
builder.WriteString(fmt.Sprintf("%s\n", k))
switch val := v.(type) {
case []any:
// 数组类型:逐条列出
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, elem := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
if m, ok := elem.(map[string]any); ok {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
}
builder.WriteString("\n")
} else {
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
}
}
case []map[string]any:
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, m := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
builder.WriteString(fmt.Sprint(elem))
}
builder.WriteString("\n")
}
default:
builder.WriteString(fmt.Sprintf("%s%v\n", k, v))
builder.WriteString(fmt.Sprintf(" %v\n", v))
}
}
}
@@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
return fmt.Sprintf(promptTpl,
gjson.New(req.Form).MustToJsonString(),
gjson.New(req.UserForm).MustToJsonString(),
)
}