// Package prompt defines a provider-neutral prompt representation (IR) and
// compiles it into request bodies according to a per-provider protocol
// configuration.
package prompt

import (
	"context"
	"fmt"
	"strings"

	"github.com/gogf/gf/v2/util/gconv"

	"prompts-core/dao"
	"prompts-core/model/entity"
	"prompts-core/service/gateway"
)

// IR is the unified intermediate representation of a prompt. Segments are
// kept in three buckets (system, history, user) so that the final message
// order can be decided later by the provider protocol.
type IR struct {
	System  []Segment `json:"system"`
	History []Segment `json:"history"`
	User    []Segment `json:"user"`
}

// Segment is a single message fragment.
//
// Role is only populated by AddHistory; system and user segments get their
// role from the bucket they are stored in.
type Segment struct {
	Type    string `json:"type"`
	Content string `json:"content"`
	Role    string `json:"role,omitempty"`
}

// ProviderProtocol is the compile-time protocol configuration for a provider,
// built from a stored provider protocol record by parseProtocol.
type ProviderProtocol struct {
	// TargetField is the request key the message list is stored under.
	// An empty value is treated as "messages".
	TargetField string `json:"target_field"`
	// MergeOrder lists the IR buckets ("system", "history", "user") in the
	// order they are emitted. An empty value means system, history, user.
	MergeOrder []string `json:"merge_order"`
	// RoleMapping rewrites role names (source role -> provider role).
	RoleMapping map[string]string `json:"role_mapping"`
	// ContentMapping controls under which key message content is stored.
	ContentMapping ContentMapping `json:"content_mapping"`
	// RequestTemplate holds extra top-level request fields; when non-empty
	// the request is built by renderTemplate.
	RequestTemplate map[string]any `json:"request_template"`
	// SystemPromptTemplate is carried over from the record; it is not used
	// by the compile steps in this file.
	SystemPromptTemplate string `json:"system_prompt_template"`
	// Capabilities is a free-form capability map; renderTemplate reads the
	// "max_tokens" entry from it.
	Capabilities map[string]any `json:"capabilities"`
}

// ContentMapping describes how message content is mapped onto the provider's
// message shape. See mapContent for the exact rules.
type ContentMapping struct {
	Type  string `json:"type"`
	Field string `json:"field"`
}

// NewPromptIR returns an empty IR whose buckets are non-nil empty slices, so
// they encode as [] rather than null in JSON.
func NewPromptIR() *IR {
	return &IR{
		System:  make([]Segment, 0),
		History: make([]Segment, 0),
		User:    make([]Segment, 0),
	}
}

// String renders the whole IR as role-prefixed lines, one segment per line
// ("System: ...", "<role>: ...", "User: ..."). It is intended for token
// counting.
func (ir *IR) String() string {
	var b strings.Builder
	for _, seg := range ir.System {
		b.WriteString("System: ")
		b.WriteString(seg.Content)
		b.WriteString("\n")
	}
	for _, seg := range ir.History {
		b.WriteString(seg.Role)
		b.WriteString(": ")
		b.WriteString(seg.Content)
		b.WriteString("\n")
	}
	for _, seg := range ir.User {
		b.WriteString("User: ")
		b.WriteString(seg.Content)
		b.WriteString("\n")
	}
	return b.String()
}

// GetTotalContent concatenates the content of every segment (system, then
// history, then user), each followed by a newline and without any role
// prefix. It gives a tighter basis for token counting than String.
func (ir *IR) GetTotalContent() string {
	var b strings.Builder
	for _, bucket := range [][]Segment{ir.System, ir.History, ir.User} {
		for _, seg := range bucket {
			b.WriteString(seg.Content)
			b.WriteString("\n")
		}
	}
	return b.String()
}

// AddSystem appends a system text segment. Empty content is ignored.
// It returns ir to allow chaining.
func (ir *IR) AddSystem(content string) *IR {
	if content != "" {
		ir.System = append(ir.System, Segment{Type: "text", Content: content})
	}
	return ir
}

// AddUser appends a user text segment. Empty content is ignored.
// It returns ir to allow chaining.
func (ir *IR) AddUser(content string) *IR {
	if content != "" {
		ir.User = append(ir.User, Segment{Type: "text", Content: content})
	}
	return ir
}

// AddHistory appends a history text segment with the given role. Empty
// content is ignored. It returns ir to allow chaining.
func (ir *IR) AddHistory(role, content string) *IR {
	if content != "" {
		ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
	}
	return ir
}

// ToMessages converts the IR into a list of {"role", "content"} maps in the
// fixed order system, history, user. System segments get role "system", user
// segments get role "user", and history segments keep their own role.
//
// The result is always non-nil, so an empty IR encodes as [] in JSON.
func (ir *IR) ToMessages() []map[string]any {
	messages := make([]map[string]any, 0, len(ir.System)+len(ir.History)+len(ir.User))
	for _, seg := range ir.System {
		messages = append(messages, map[string]any{
			"role":    "system",
			"content": seg.Content,
		})
	}
	for _, seg := range ir.History {
		messages = append(messages, map[string]any{
			"role":    seg.Role,
			"content": seg.Content,
		})
	}
	for _, seg := range ir.User {
		messages = append(messages, map[string]any{
			"role":    "user",
			"content": seg.Content,
		})
	}
	return messages
}

// GetProtocolByProvider looks up the protocol record for providerName with
// Status 1 and converts it into a ProviderProtocol.
//
// If the lookup fails the error is returned unchanged. If the lookup succeeds
// but yields no record, the result is (nil, nil); callers must check the
// returned pointer before use.
func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderProtocol, error) {
	// NOTE(review): Status 1 presumably means "enabled" - confirm against
	// the entity definition.
	record, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
		ProviderName: providerName,
		Status:       1,
	})
	if err != nil {
		return nil, err
	}
	if record == nil {
		// No matching record is not reported as an error.
		return nil, nil
	}
	return parseProtocol(record), nil
}

// parseProtocol converts a stored protocol record into the configuration used
// by Compile. e must not be nil.
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
	return &ProviderProtocol{
		TargetField:          e.TargetField,
		SystemPromptTemplate: e.SystemPromptTemplate,
		MergeOrder:           e.MergeOrder,
		RoleMapping:          gconv.MapStrStr(e.RoleMapping),
		ContentMapping: ContentMapping{
			Type:  gconv.String(e.ContentMapping["type"]),
			Field: gconv.String(e.ContentMapping["field"]),
		},
		RequestTemplate: e.RequestTemplate,
		Capabilities:    e.Capabilities,
	}
}

// targetField returns the request key the message list is stored under,
// falling back to "messages" when the protocol does not configure one.
func (p *ProviderProtocol) targetField() string {
	if p.TargetField == "" {
		return "messages"
	}
	return p.TargetField
}

// Compile turns ir into a provider request body according to p.
//
// The steps are: order the segments (MergeOrder), rename roles (RoleMapping),
// relocate content (ContentMapping), then build the request body. chatModel
// may be nil; it is only consulted when the protocol has a request template.
// An error is returned when ir or p is nil.
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
	if ir == nil || p == nil {
		return nil, fmt.Errorf("ir and protocol are required")
	}
	messages := mergeByOrder(ir, p.MergeOrder)
	messages = mapRoles(messages, p.RoleMapping)
	messages = mapContent(messages, p.ContentMapping)
	return buildRequest(messages, p, chatModel), nil
}

// mergeByOrder flattens the IR buckets into a message list following order.
//
// Recognised entries are "system", "history" and "user"; any other entry
// contributes nothing. System and user messages use the bucket name as their
// role, history messages keep the role stored on the segment. An empty order
// falls back to system, history, user so that a protocol without a configured
// merge order does not silently drop the whole prompt.
func mergeByOrder(ir *IR, order []string) []map[string]any {
	if len(order) == 0 {
		order = []string{"system", "history", "user"}
	}
	buckets := map[string][]Segment{
		"system":  ir.System,
		"history": ir.History,
		"user":    ir.User,
	}

	total := 0
	for _, part := range order {
		total += len(buckets[part])
	}

	messages := make([]map[string]any, 0, total)
	for _, part := range order {
		for _, seg := range buckets[part] {
			role := part
			if part == "history" {
				role = seg.Role
			}
			messages = append(messages, map[string]any{
				"content": seg.Content,
				"role":    role,
			})
		}
	}
	return messages
}

// mapRoles rewrites each message's "role" through mapping, in place. Roles
// without a mapping entry, and messages whose "role" is not a string, are
// left untouched. The same slice is returned for convenience.
func mapRoles(messages []map[string]any, mapping map[string]string) []map[string]any {
	if len(mapping) == 0 {
		return messages
	}
	for _, msg := range messages {
		role, ok := msg["role"].(string)
		if !ok {
			continue
		}
		if mapped, exists := mapping[role]; exists {
			msg["role"] = mapped
		}
	}
	return messages
}

// mapContent moves each message's "content" value to the location described
// by cm, in place, and returns the same slice.
//
// When cm.Field is empty or "content" the messages are returned unchanged.
// Otherwise the "content" key is removed and the value is stored either as
// "parts": [{cm.Field: value}] when cm.Type is "parts", or directly under
// cm.Field for any other type.
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
	if cm.Field == "" || cm.Field == "content" {
		return messages
	}
	for _, msg := range messages {
		content, ok := msg["content"]
		if !ok {
			continue
		}
		delete(msg, "content")
		switch cm.Type {
		case "parts":
			msg["parts"] = []map[string]any{{cm.Field: content}}
		default:
			msg[cm.Field] = content
		}
	}
	return messages
}

// buildRequest assembles the request body. With a non-empty request template
// the body is produced by renderTemplate; otherwise it contains only the
// message list stored under the protocol's target field.
func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gateway.AsynchModel) map[string]any {
	if len(p.RequestTemplate) > 0 {
		return renderTemplate(p, messages, chatModel)
	}
	return map[string]any{
		p.targetField(): messages,
	}
}

// renderTemplate builds a request body from the protocol's request template.
//
// The template's top-level entries are copied into a new map (a shallow copy:
// nested values are shared with the template). On top of that it sets:
//   - "model" to chatModel.ModelName when chatModel is non-nil;
//   - the message list under the protocol's target field, the same key the
//     non-template path in buildRequest uses;
//   - "max_tokens" when Capabilities["max_tokens"] converts to a positive int.
//
// These values overwrite template entries that use the same keys.
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
	result := make(map[string]any, len(p.RequestTemplate)+3)
	for k, v := range p.RequestTemplate {
		result[k] = v
	}
	if chatModel != nil {
		result["model"] = chatModel.ModelName
	}
	result[p.targetField()] = messages
	if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
		result["max_tokens"] = maxTokens
	}
	return result
}