refactor(prompt): 重构提示词构建服务与数据模型
This commit is contained in:
@@ -20,11 +20,27 @@ type PromptIR struct {
|
||||
|
||||
// Segment 消息片段
|
||||
type Segment struct {
|
||||
Type string `json:"type"` // text/image
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||
type ProviderProtocol struct {
|
||||
TargetField string `json:"target_field"`
|
||||
MergeOrder []string `json:"merge_order"`
|
||||
RoleMapping map[string]string `json:"role_mapping"`
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
SystemPromptTemplate string `json:"system_prompt_template"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
type ContentMapping struct {
|
||||
Type string `json:"type"`
|
||||
Field string `json:"field"`
|
||||
}
|
||||
|
||||
// NewPromptIR 创建空 PromptIR
|
||||
func NewPromptIR() *PromptIR {
|
||||
return &PromptIR{
|
||||
@@ -34,6 +50,54 @@ func NewPromptIR() *PromptIR {
|
||||
}
|
||||
}
|
||||
|
||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||
func (ir *PromptIR) String() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString("System: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Role)
|
||||
builder.WriteString(": ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString("User: ")
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||
func (ir *PromptIR) GetTotalContent() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.History {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
for _, seg := range ir.User {
|
||||
builder.WriteString(seg.Content)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// AddSystem 添加系统提示
|
||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
if content != "" {
|
||||
@@ -62,7 +126,6 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
// 1. 系统消息
|
||||
for _, seg := range ir.System {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
@@ -70,7 +133,6 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
})
|
||||
}
|
||||
|
||||
// 2. 历史消息
|
||||
for _, seg := range ir.History {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": seg.Role,
|
||||
@@ -78,13 +140,13 @@ func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
})
|
||||
}
|
||||
|
||||
// 3. 用户消息
|
||||
for _, seg := range ir.User {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -97,74 +159,35 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
||||
if err != nil || entity == nil {
|
||||
return nil, err
|
||||
}
|
||||
entity.MergeOrder = util.ParseJSONField(entity.MergeOrder)
|
||||
entity.RoleMapping = util.ParseJSONField(entity.RoleMapping)
|
||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
||||
entity.RequestTemplate = util.ParseJSONField(entity.RequestTemplate)
|
||||
entity.ContentMapping = util.ParseJSONField(entity.ContentMapping)
|
||||
fmt.Println("entity打印", entity)
|
||||
return parseProtocol(entity), nil
|
||||
}
|
||||
|
||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
p := &ProviderProtocol{
|
||||
TargetField: e.TargetField,
|
||||
TargetField: e.TargetField,
|
||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||
}
|
||||
|
||||
// MergeOrder: any → []string
|
||||
if e.MergeOrder != nil {
|
||||
b, _ := json.Marshal(e.MergeOrder)
|
||||
json.Unmarshal(b, &p.MergeOrder)
|
||||
}
|
||||
|
||||
// RoleMapping: any → map[string]string
|
||||
if e.RoleMapping != nil {
|
||||
b, _ := json.Marshal(e.RoleMapping)
|
||||
json.Unmarshal(b, &p.RoleMapping)
|
||||
}
|
||||
|
||||
// ContentMapping: any → ContentMapping
|
||||
if e.ContentMapping != nil {
|
||||
b, _ := json.Marshal(e.ContentMapping)
|
||||
json.Unmarshal(b, &p.ContentMapping)
|
||||
}
|
||||
|
||||
// RequestTemplate: any → map[string]any
|
||||
if e.RequestTemplate != nil {
|
||||
b, _ := json.Marshal(e.RequestTemplate)
|
||||
json.Unmarshal(b, &p.RequestTemplate)
|
||||
}
|
||||
fmt.Printf("parseProtocol: %+v\n", p)
|
||||
// 使用通用解析方法处理各个字段
|
||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||
return p
|
||||
}
|
||||
|
||||
// ProviderProtocol 协议编译配置(从 DB JSONB 字段解析)
|
||||
type ProviderProtocol struct {
|
||||
TargetField string `json:"target_field"`
|
||||
MergeOrder []string `json:"merge_order"`
|
||||
RoleMapping map[string]string `json:"role_mapping"`
|
||||
ContentMapping ContentMapping `json:"content_mapping"`
|
||||
RequestTemplate map[string]any `json:"request_template"`
|
||||
}
|
||||
|
||||
// ContentMapping 内容字段映射
|
||||
type ContentMapping struct {
|
||||
Type string `json:"type"` // direct/parts
|
||||
Field string `json:"field"` // content/text
|
||||
}
|
||||
|
||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
|
||||
if ir == nil || p == nil {
|
||||
return nil, fmt.Errorf("ir and protocol are required")
|
||||
}
|
||||
// 1. 按 merge_order 拼接消息
|
||||
|
||||
messages := mergeByOrder(ir, p.MergeOrder)
|
||||
// 2. 角色映射
|
||||
messages = mapRoles(messages, p.RoleMapping)
|
||||
// 3. 内容字段映射
|
||||
messages = mapContent(messages, p.ContentMapping)
|
||||
// 4. 按 target_field + request_template 构建请求体
|
||||
|
||||
return buildRequest(messages, p, chatModel), nil
|
||||
}
|
||||
|
||||
@@ -197,6 +220,7 @@ func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -205,15 +229,18 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
||||
if len(mapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
role, ok := msg["role"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if mapped, exists := mapping[role]; exists {
|
||||
messages[i]["role"] = mapped
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -225,15 +252,14 @@ func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
// Gemini 格式: {"parts": [{"text": "..."}]}
|
||||
msg["parts"] = []map[string]any{
|
||||
{cm.Field: content},
|
||||
}
|
||||
default:
|
||||
// direct: {"content": "..."}
|
||||
msg[cm.Field] = content
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -242,6 +268,7 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *ent
|
||||
if len(p.RequestTemplate) > 0 {
|
||||
return renderTemplate(p.RequestTemplate, messages, chatModel)
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
p.TargetField: messages,
|
||||
}
|
||||
@@ -252,13 +279,13 @@ func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *e
|
||||
b, _ := json.Marshal(tmpl)
|
||||
str := string(b)
|
||||
|
||||
// 替换 {{model}}
|
||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
||||
// 替换 {{messages}}
|
||||
|
||||
msgBytes, _ := json.Marshal(messages)
|
||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||
|
||||
var result map[string]any
|
||||
json.Unmarshal([]byte(str), &result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user