Files
prompts-core/service/compose_parser.go
2026-05-12 13:59:15 +08:00

363 lines
8.4 KiB
Go

package service
import (
"encoding/json"
"errors"
"fmt"
"strings"
"prompts-core/model/dto"
)
// ============================================
// 类型定义
// ============================================
// modelOutput 推理模型的标准输出格式
type modelOutput struct {
Messages []dto.Message `json:"messages"`
System any `json:"system"`
User any `json:"user"`
}
// gatewayResponse 模型网关的标准响应格式
type gatewayResponse struct {
Choices []choice `json:"choices"`
}
type choice struct {
Message message `json:"message"`
}
type message struct {
Content string `json:"content"`
}
// ============================================
// 核心解析函数
// ============================================
// ParseModelResponse 解析推理模型的文本响应,返回消息列表
// 支持三种格式:
// 1. 标准 messages 格式: {"messages": [...]}
// 2. 简化 system/user 格式: {"system": "...", "user": "..."}
// 3. 网关包装格式: {"choices": [{"message": {"content": "..."}}]}
func ParseModelResponse(text string) ([]dto.Message, error) {
text = strings.TrimSpace(text)
if text == "" {
return nil, errors.New("模型响应为空")
}
// 1. 尝试解包网关响应
if content := unwrapGatewayResponse(text); content != "" {
text = content
}
// 2. 解析为标准格式
output, err := parseAsModelOutput(text)
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
// 3. 优先使用 messages 字段
if len(output.Messages) > 0 {
messages := normalizeMessageContents(output.Messages)
if err := validateMessageList(messages); err != nil {
return nil, err
}
return messages, nil
}
// 4. 兼容 system/user 格式
return buildMessagesFromSystemUser(output)
}
// ParseStoredMessages 从数据库存储的数据中解析消息列表
func ParseStoredMessages(data any) []dto.Message {
if data == nil {
return nil
}
// 统一序列化
jsonBytes, err := json.Marshal(data)
if err != nil {
return nil
}
// 尝试直接解析
var messages []dto.Message
if err := json.Unmarshal(jsonBytes, &messages); err == nil {
return messages
}
// 尝试解析为 JSON 字符串再解析
var jsonStr string
if err := json.Unmarshal(jsonBytes, &jsonStr); err != nil {
return nil
}
if err := json.Unmarshal([]byte(jsonStr), &messages); err != nil {
return nil
}
return messages
}
// ============================================
// 内部解析函数
// ============================================
// parseAsModelOutput 将文本解析为 modelOutput 结构
func parseAsModelOutput(text string) (*modelOutput, error) {
// 清理可能的 Markdown 代码块标记
text = cleanMarkdownCodeBlock(text)
var output modelOutput
if err := json.Unmarshal([]byte(text), &output); err != nil {
return nil, err
}
return &output, nil
}
// unwrapGatewayResponse 解包网关的标准响应格式
func unwrapGatewayResponse(text string) string {
// 快速检查是否可能是网关响应
if !strings.Contains(text, `"choices"`) {
return ""
}
var resp gatewayResponse
if err := json.Unmarshal([]byte(text), &resp); err != nil {
return ""
}
if len(resp.Choices) == 0 {
return ""
}
content := strings.TrimSpace(resp.Choices[0].Message.Content)
return content
}
// buildMessagesFromSystemUser 从 system/user 字段构建消息列表
func buildMessagesFromSystemUser(output *modelOutput) ([]dto.Message, error) {
messages := make([]dto.Message, 0, 2)
// 添加 user 消息
if !isEmptyValue(output.User) {
messages = append(messages, dto.Message{
Role: "user",
Content: normalizeContent(output.User),
})
}
// 添加 system 消息
if !isEmptyValue(output.System) {
messages = append(messages, dto.Message{
Role: "system",
Content: normalizeContent(output.System),
})
}
if len(messages) == 0 {
return nil, errors.New("未解析到有效的 system 或 user 内容")
}
if err := validateMessageList(messages); err != nil {
return nil, err
}
return messages, nil
}
// ============================================
// 内容规范化
// ============================================
// normalizeMessageContents 规范化消息列表中的所有内容
func normalizeMessageContents(messages []dto.Message) []dto.Message {
for i := range messages {
messages[i].Content = normalizeContent(messages[i].Content)
}
return messages
}
// normalizeContent 规范化单个消息内容
// - 如果是 JSON 字符串,尝试解析为对象/数组
// - 否则保持原样
func normalizeContent(content any) any {
switch v := content.(type) {
case string:
return tryUnmarshalJSON(v)
default:
return content
}
}
// tryUnmarshalJSON 尝试将 JSON 字符串解析为结构化对象
func tryUnmarshalJSON(s string) any {
s = strings.TrimSpace(s)
if s == "" {
return s
}
// 只处理看起来像 JSON 的内容
if !looksLikeJSON(s) {
return s
}
var result any
if err := json.Unmarshal([]byte(s), &result); err != nil || result == nil {
return s
}
return result
}
// looksLikeJSON 判断字符串是否可能是 JSON
func looksLikeJSON(s string) bool {
s = strings.TrimSpace(s)
return strings.HasPrefix(s, "{") || strings.HasPrefix(s, "[")
}
// cleanMarkdownCodeBlock 清理 Markdown 代码块标记
func cleanMarkdownCodeBlock(text string) string {
// 去除可能的 ```json 和 ``` 标记
text = strings.TrimPrefix(text, "```json")
text = strings.TrimPrefix(text, "```JSON")
text = strings.TrimPrefix(text, "```")
text = strings.TrimSuffix(text, "```")
return strings.TrimSpace(text)
}
// ============================================
// 验证
// ============================================
// validateMessageList 验证消息列表的合法性
func validateMessageList(messages []dto.Message) error {
if len(messages) == 0 {
return errors.New("消息列表不能为空")
}
hasUser := false
for i, msg := range messages {
if err := validateMessage(msg); err != nil {
return fmt.Errorf("消息[%d]验证失败: %w", i, err)
}
if msg.Role == "user" {
hasUser = true
}
}
// 至少需要一条 user 消息
if !hasUser {
return errors.New("消息列表必须包含至少一条 user 角色消息")
}
return nil
}
// validateMessage 验证单条消息的合法性
func validateMessage(msg dto.Message) error {
role := strings.TrimSpace(msg.Role)
if role == "" {
return errors.New("role 不能为空")
}
if !isValidRole(role) {
return fmt.Errorf("role 值非法: %s (仅允许 system/user/assistant)", role)
}
// user 角色的 content 不能为空
if role == "user" && isEmptyValue(msg.Content) {
return errors.New("user 角色的 content 不能为空")
}
return nil
}
// isValidRole 判断角色是否合法
func isValidRole(role string) bool {
switch role {
case "system", "user", "assistant":
return true
default:
return false
}
}
// HasUserMessage 判断消息列表中是否包含非空的 user 消息
func HasUserMessage(messages []dto.Message) bool {
for _, msg := range messages {
if msg.Role == "user" && !isEmptyValue(msg.Content) {
return true
}
}
return false
}
// HasSystemMessage 判断消息列表中是否包含非空的 system 消息
func HasSystemMessage(messages []dto.Message) bool {
for _, msg := range messages {
if msg.Role == "system" && !isEmptyValue(msg.Content) {
return true
}
}
return false
}
// ExtractUserContent 提取消息列表中第一个 user 角色的内容
func ExtractUserContent(messages []dto.Message) any {
for _, msg := range messages {
if msg.Role == "user" {
return msg.Content
}
}
return nil
}
// ExtractSystemContent 提取消息列表中第一个 system 角色的内容
func ExtractSystemContent(messages []dto.Message) any {
for _, msg := range messages {
if msg.Role == "system" {
return msg.Content
}
}
return nil
}
// ============================================
// 测试辅助函数 (可选)
// ============================================
// MockModelResponse 创建模拟的模型响应用于测试
func MockModelResponse(systemContent, userContent string) string {
output := modelOutput{
Messages: []dto.Message{
{Role: "system", Content: systemContent},
{Role: "user", Content: userContent},
},
}
bytes, _ := json.Marshal(output)
return string(bytes)
}
// MockGatewayResponse 创建模拟的网关响应用于测试
func MockGatewayResponse(innerJSON string) string {
resp := gatewayResponse{
Choices: []choice{
{
Message: message{
Content: innerJSON,
},
},
},
}
bytes, _ := json.Marshal(resp)
return string(bytes)
}