Files
prompts-core/service/prompt/prompt_ir_service.go

265 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strings"

	"prompts-core/common/util"
	"prompts-core/dao"
	"prompts-core/model/entity"
)
// PromptIR is the provider-neutral intermediate representation of a prompt:
// system instructions, prior conversation turns, and the current user input.
type PromptIR struct {
	System  []Segment `json:"system"`
	History []Segment `json:"history"`
	User    []Segment `json:"user"`
}

// Segment is one piece of message content.
type Segment struct {
	Type    string `json:"type"` // segment kind ("text" or "image"); the builders below only emit "text"
	Content string `json:"content"`
	Role    string `json:"role,omitempty"` // populated by AddHistory; AddSystem/AddUser leave it empty
}

// NewPromptIR returns an empty PromptIR. Every section is a non-nil, empty
// slice, so an unused section encodes to JSON as [] rather than null.
func NewPromptIR() *PromptIR {
	return &PromptIR{
		System:  []Segment{},
		History: []Segment{},
		User:    []Segment{},
	}
}

// AddSystem appends content as a text segment to the system section and
// returns ir so calls can be chained. Empty content is ignored.
func (ir *PromptIR) AddSystem(content string) *PromptIR {
	if content == "" {
		return ir
	}
	ir.System = append(ir.System, Segment{Type: "text", Content: content})
	return ir
}

// AddUser appends content as a text segment to the user section and returns
// ir so calls can be chained. Empty content is ignored.
func (ir *PromptIR) AddUser(content string) *PromptIR {
	if content == "" {
		return ir
	}
	ir.User = append(ir.User, Segment{Type: "text", Content: content})
	return ir
}

// AddHistory appends a prior turn to the history section and returns ir so
// calls can be chained. role is stored verbatim (it is not validated) and
// later becomes the message role. Empty content is ignored.
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
	if content == "" {
		return ir
	}
	ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
	return ir
}
// ToMessages converts the IR into OpenAI-compatible chat messages (the MVP
// default format). The output is every system segment, then every history
// segment with its own role, then every user segment, each as
// {"role": ..., "content": ...}. It returns nil when the IR has no segments.
//
// It delegates to mergeByOrder with an explicit order. This keeps the default
// path and protocol-driven compilation on a single message-building
// implementation instead of two copies that can drift apart.
func (ir *PromptIR) ToMessages() []map[string]any {
	return mergeByOrder(ir, []string{"system", "history", "user"})
}
// GetProtocolByProvider loads the enabled protocol row (Status = 1) for
// providerName and converts it into a compile-ready ProviderProtocol.
//
// It returns (nil, nil) when no enabled row matches, so callers must check the
// returned pointer as well as the error. DAO errors are wrapped with the
// provider name and remain inspectable via errors.Is / errors.As.
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
	// Use a name that does not shadow the entity package.
	row, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
		ProviderName: providerName,
		Status:       1,
	})
	if err != nil {
		return nil, fmt.Errorf("load protocol for provider %q: %w", providerName, err)
	}
	if row == nil {
		return nil, nil
	}

	// Normalise each JSON column once before parseProtocol decodes it into
	// typed fields.
	// NOTE(review): presumably util.ParseJSONField turns raw JSON text into
	// decoded Go values. Confirm this against its implementation.
	row.MergeOrder = util.ParseJSONField(row.MergeOrder)
	row.RoleMapping = util.ParseJSONField(row.RoleMapping)
	row.ContentMapping = util.ParseJSONField(row.ContentMapping)
	row.RequestTemplate = util.ParseJSONField(row.RequestTemplate)
	return parseProtocol(row), nil
}
// parseProtocol converts a DB protocol row into the typed configuration used
// by Compile. Each JSON column (merge_order, role_mapping, content_mapping,
// request_template) is re-encoded as JSON and decoded into its typed field. A
// nil column leaves the field at its zero value.
//
// Decoding is best-effort:
//   - A malformed column is logged and its field is left at the zero value,
//     never half-filled.
//   - The remaining columns are still decoded.
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
	p := &ProviderProtocol{TargetField: e.TargetField}

	results := []struct {
		column string
		err    error
	}{
		{"merge_order", decodeProtocolField(e.MergeOrder, &p.MergeOrder)},
		{"role_mapping", decodeProtocolField(e.RoleMapping, &p.RoleMapping)},
		{"content_mapping", decodeProtocolField(e.ContentMapping, &p.ContentMapping)},
		{"request_template", decodeProtocolField(e.RequestTemplate, &p.RequestTemplate)},
	}
	for _, r := range results {
		if r.err != nil {
			log.Printf("parseProtocol: provider %v: ignoring malformed %s: %v", e.ProviderName, r.column, r.err)
		}
	}
	return p
}

// decodeProtocolField re-encodes src as JSON and decodes it into *dst.
// *dst is assigned only after a fully successful decode, so a type mismatch
// cannot leave it partially populated. A nil src is a no-op.
func decodeProtocolField[T any](src any, dst *T) error {
	if src == nil {
		return nil
	}
	raw, err := json.Marshal(src)
	if err != nil {
		return err
	}
	var decoded T
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return err
	}
	*dst = decoded
	return nil
}
// ProviderProtocol is the protocol configuration used to compile a PromptIR
// for one provider. parseProtocol builds it from the DB JSONB columns of
// entity.ProviderProtocol.
type ProviderProtocol struct {
	// TargetField is the top-level request key that receives the message list
	// when no RequestTemplate is configured (see buildRequest).
	TargetField string `json:"target_field"`
	// MergeOrder lists the IR sections ("system", "history", "user") in the
	// order mergeByOrder concatenates them; unrecognised names are skipped.
	MergeOrder []string `json:"merge_order"`
	// RoleMapping renames message roles (IR role -> provider role); roles
	// without an entry are left unchanged.
	RoleMapping map[string]string `json:"role_mapping"`
	// ContentMapping controls where each message's text is placed.
	ContentMapping ContentMapping `json:"content_mapping"`
	// RequestTemplate, when non-empty, is the request body skeleton. JSON
	// string tokens exactly equal to "{{model}}" and "{{messages}}" are
	// substituted by renderTemplate.
	RequestTemplate map[string]any `json:"request_template"`
}

// ContentMapping describes how mapContent places message text.
type ContentMapping struct {
	Type  string `json:"type"`  // "parts": Gemini-style {"parts": [{Field: text}]}; anything else (e.g. "direct"): {Field: text}
	Field string `json:"field"` // key that receives the text, e.g. "content" or "text"
}
// Compile turns ir into a provider request body according to protocol p.
//
// The pipeline has four steps:
//  1. Concatenate the IR sections in p.MergeOrder.
//  2. Rename roles via p.RoleMapping.
//  3. Reshape message text via p.ContentMapping.
//  4. Wrap the messages under p.TargetField, or render them into
//     p.RequestTemplate.
//
// chatModel supplies the model name for the template and is required whenever
// p.RequestTemplate is non-empty.
//
// It returns an error if:
//   - ir or p is nil;
//   - a template is configured but chatModel is nil;
//   - the template could not be rendered into a request body.
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *entity.AsynchModel) (map[string]any, error) {
	if ir == nil || p == nil {
		return nil, fmt.Errorf("ir and protocol are required")
	}
	// Template rendering reads chatModel for {{model}}, so fail fast with an
	// error instead of dereferencing a nil model further down.
	if len(p.RequestTemplate) > 0 && chatModel == nil {
		return nil, fmt.Errorf("chat model is required to render the request template")
	}

	messages := mergeByOrder(ir, p.MergeOrder)
	messages = mapRoles(messages, p.RoleMapping)
	messages = mapContent(messages, p.ContentMapping)

	req := buildRequest(messages, p, chatModel)
	if req == nil {
		// A nil body means the template could not be turned back into a JSON
		// object. Report it rather than returning (nil, nil).
		return nil, fmt.Errorf("rendering request template produced no request body")
	}
	return req, nil
}
// mergeByOrder flattens ir into chat messages, emitting the sections named in
// order ("system", "history", "user"). Unrecognised names are skipped, and a
// name listed twice is emitted twice. Each message is
// {"role": ..., "content": ...}:
//   - system segments get role "system";
//   - user segments get role "user";
//   - history segments keep their own Role.
//
// An empty order falls back to system, history, user, the same order
// ToMessages produces, instead of yielding no messages at all. Returns nil
// when nothing is emitted.
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
	if len(order) == 0 {
		// A protocol row without merge_order would otherwise compile to a
		// request with an empty message list.
		order = []string{"system", "history", "user"}
	}

	var messages []map[string]any
	for _, part := range order {
		switch part {
		case "system":
			for _, seg := range ir.System {
				messages = append(messages, map[string]any{
					"role":    "system",
					"content": seg.Content,
				})
			}
		case "history":
			for _, seg := range ir.History {
				messages = append(messages, map[string]any{
					"role":    seg.Role,
					"content": seg.Content,
				})
			}
		case "user":
			for _, seg := range ir.User {
				messages = append(messages, map[string]any{
					"role":    "user",
					"content": seg.Content,
				})
			}
		}
	}
	return messages
}
// mapRoles renames message roles in place using mapping (from -> to) and
// returns the same slice. A message is left untouched when its "role" is
// missing, is not a string, or has no entry in mapping. An empty mapping is a
// no-op.
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
	if len(mapping) == 0 {
		return messages
	}
	for _, msg := range messages {
		// Maps are references, so writing through msg updates the slice element.
		if from, isString := msg["role"].(string); isString {
			if to, hit := mapping[from]; hit {
				msg["role"] = to
			}
		}
	}
	return messages
}
// mapContent moves each message's "content" entry, in place, to the location
// described by cm and returns the same slice.
//
//   - cm.Type "parts" (Gemini style): {"parts": [{<Field>: text}]}
//   - any other Type (e.g. "direct"): {<Field>: text}
//
// When cm.Field is empty (for example, content_mapping missing from the
// protocol row), it defaults to "text" for parts and "content" otherwise.
// Without this default the text would be stored under an empty-string key.
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
	field := cm.Field
	if field == "" {
		if cm.Type == "parts" {
			field = "text"
		} else {
			field = "content"
		}
	}

	for _, msg := range messages {
		content := msg["content"]
		// Delete first so a renamed field (e.g. "text") does not leave a stale
		// "content" key behind.
		delete(msg, "content")
		switch cm.Type {
		case "parts":
			msg["parts"] = []map[string]any{{field: content}}
		default:
			msg[field] = content
		}
	}
	return messages
}
// buildRequest assembles the final request body.
//
// If p.RequestTemplate is non-empty, the messages are rendered into it by
// renderTemplate, and its result (possibly nil) is returned as-is. Otherwise
// the messages are placed under p.TargetField. An empty TargetField defaults
// to "messages", the OpenAI-style key, instead of emitting an empty-string key.
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *entity.AsynchModel) map[string]any {
	if len(p.RequestTemplate) > 0 {
		return renderTemplate(p.RequestTemplate, messages, chatModel)
	}
	field := p.TargetField
	if field == "" {
		field = "messages"
	}
	return map[string]any{field: messages}
}
// renderTemplate fills a request template by substitution on its JSON text.
//
// Steps:
//  1. Serialize tmpl to JSON.
//  2. Replace every JSON string token exactly equal to "{{model}}" with the
//     JSON-encoded chatModel.ModelName.
//  3. Replace every "{{messages}}" token with the JSON-encoded messages.
//  4. Decode the text back into a map, so nested values come back as generic
//     JSON types ([]any, map[string]any, float64, ...).
//
// If chatModel is nil, "{{model}}" placeholders are left unreplaced. It returns
// nil if any encoding step fails or the substituted text does not decode into
// a JSON object.
func renderTemplate(tmpl map[string]any, messages []map[string]any, chatModel *entity.AsynchModel) map[string]any {
	raw, err := json.Marshal(tmpl)
	if err != nil {
		return nil
	}
	rendered := string(raw)

	// Substitute the model before the messages are spliced in, so message text
	// that happens to contain "{{model}}" is never rewritten.
	if chatModel != nil {
		// Encode the name as a JSON string so quotes or backslashes in it
		// cannot corrupt the document.
		model, mErr := json.Marshal(chatModel.ModelName)
		if mErr != nil {
			return nil
		}
		rendered = strings.ReplaceAll(rendered, `"{{model}}"`, string(model))
	}

	msgs, err := json.Marshal(messages)
	if err != nil {
		return nil
	}
	rendered = strings.ReplaceAll(rendered, `"{{messages}}"`, string(msgs))

	var result map[string]any
	if err := json.Unmarshal([]byte(rendered), &result); err != nil {
		return nil
	}
	return result
}