prompts-core
This commit is contained in:
362
service/compose_parser.go
Normal file
362
service/compose_parser.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"prompts-core/model/dto"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 类型定义
|
||||
// ============================================
|
||||
|
||||
// modelOutput 推理模型的标准输出格式
|
||||
type modelOutput struct {
|
||||
Messages []dto.Message `json:"messages"`
|
||||
System any `json:"system"`
|
||||
User any `json:"user"`
|
||||
}
|
||||
|
||||
// gatewayResponse 模型网关的标准响应格式
|
||||
type gatewayResponse struct {
|
||||
Choices []choice `json:"choices"`
|
||||
}
|
||||
|
||||
type choice struct {
|
||||
Message message `json:"message"`
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 核心解析函数
|
||||
// ============================================
|
||||
|
||||
// ParseModelResponse 解析推理模型的文本响应,返回消息列表
|
||||
// 支持三种格式:
|
||||
// 1. 标准 messages 格式: {"messages": [...]}
|
||||
// 2. 简化 system/user 格式: {"system": "...", "user": "..."}
|
||||
// 3. 网关包装格式: {"choices": [{"message": {"content": "..."}}]}
|
||||
func ParseModelResponse(text string) ([]dto.Message, error) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil, errors.New("模型响应为空")
|
||||
}
|
||||
|
||||
// 1. 尝试解包网关响应
|
||||
if content := unwrapGatewayResponse(text); content != "" {
|
||||
text = content
|
||||
}
|
||||
|
||||
// 2. 解析为标准格式
|
||||
output, err := parseAsModelOutput(text)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 优先使用 messages 字段
|
||||
if len(output.Messages) > 0 {
|
||||
messages := normalizeMessageContents(output.Messages)
|
||||
if err := validateMessageList(messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// 4. 兼容 system/user 格式
|
||||
return buildMessagesFromSystemUser(output)
|
||||
}
|
||||
|
||||
// ParseStoredMessages 从数据库存储的数据中解析消息列表
|
||||
func ParseStoredMessages(data any) []dto.Message {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 统一序列化
|
||||
jsonBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试直接解析
|
||||
var messages []dto.Message
|
||||
if err := json.Unmarshal(jsonBytes, &messages); err == nil {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 尝试解析为 JSON 字符串再解析
|
||||
var jsonStr string
|
||||
if err := json.Unmarshal(jsonBytes, &jsonStr); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &messages); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 内部解析函数
|
||||
// ============================================
|
||||
|
||||
// parseAsModelOutput 将文本解析为 modelOutput 结构
|
||||
func parseAsModelOutput(text string) (*modelOutput, error) {
|
||||
// 清理可能的 Markdown 代码块标记
|
||||
text = cleanMarkdownCodeBlock(text)
|
||||
|
||||
var output modelOutput
|
||||
if err := json.Unmarshal([]byte(text), &output); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &output, nil
|
||||
}
|
||||
|
||||
// unwrapGatewayResponse 解包网关的标准响应格式
|
||||
func unwrapGatewayResponse(text string) string {
|
||||
// 快速检查是否可能是网关响应
|
||||
if !strings.Contains(text, `"choices"`) {
|
||||
return ""
|
||||
}
|
||||
|
||||
var resp gatewayResponse
|
||||
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||||
return content
|
||||
}
|
||||
|
||||
// buildMessagesFromSystemUser 从 system/user 字段构建消息列表
|
||||
func buildMessagesFromSystemUser(output *modelOutput) ([]dto.Message, error) {
|
||||
messages := make([]dto.Message, 0, 2)
|
||||
|
||||
// 添加 user 消息
|
||||
if !isEmptyValue(output.User) {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: "user",
|
||||
Content: normalizeContent(output.User),
|
||||
})
|
||||
}
|
||||
|
||||
// 添加 system 消息
|
||||
if !isEmptyValue(output.System) {
|
||||
messages = append(messages, dto.Message{
|
||||
Role: "system",
|
||||
Content: normalizeContent(output.System),
|
||||
})
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return nil, errors.New("未解析到有效的 system 或 user 内容")
|
||||
}
|
||||
|
||||
if err := validateMessageList(messages); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 内容规范化
|
||||
// ============================================
|
||||
|
||||
// normalizeMessageContents 规范化消息列表中的所有内容
|
||||
func normalizeMessageContents(messages []dto.Message) []dto.Message {
|
||||
for i := range messages {
|
||||
messages[i].Content = normalizeContent(messages[i].Content)
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// normalizeContent 规范化单个消息内容
|
||||
// - 如果是 JSON 字符串,尝试解析为对象/数组
|
||||
// - 否则保持原样
|
||||
func normalizeContent(content any) any {
|
||||
switch v := content.(type) {
|
||||
case string:
|
||||
return tryUnmarshalJSON(v)
|
||||
default:
|
||||
return content
|
||||
}
|
||||
}
|
||||
|
||||
// tryUnmarshalJSON 尝试将 JSON 字符串解析为结构化对象
|
||||
func tryUnmarshalJSON(s string) any {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
|
||||
// 只处理看起来像 JSON 的内容
|
||||
if !looksLikeJSON(s) {
|
||||
return s
|
||||
}
|
||||
|
||||
var result any
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil || result == nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// looksLikeJSON 判断字符串是否可能是 JSON
|
||||
func looksLikeJSON(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
return strings.HasPrefix(s, "{") || strings.HasPrefix(s, "[")
|
||||
}
|
||||
|
||||
// cleanMarkdownCodeBlock 清理 Markdown 代码块标记
|
||||
func cleanMarkdownCodeBlock(text string) string {
|
||||
// 去除可能的 ```json 和 ``` 标记
|
||||
text = strings.TrimPrefix(text, "```json")
|
||||
text = strings.TrimPrefix(text, "```JSON")
|
||||
text = strings.TrimPrefix(text, "```")
|
||||
text = strings.TrimSuffix(text, "```")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 验证
|
||||
// ============================================
|
||||
|
||||
// validateMessageList 验证消息列表的合法性
|
||||
func validateMessageList(messages []dto.Message) error {
|
||||
if len(messages) == 0 {
|
||||
return errors.New("消息列表不能为空")
|
||||
}
|
||||
|
||||
hasUser := false
|
||||
for i, msg := range messages {
|
||||
if err := validateMessage(msg); err != nil {
|
||||
return fmt.Errorf("消息[%d]验证失败: %w", i, err)
|
||||
}
|
||||
if msg.Role == "user" {
|
||||
hasUser = true
|
||||
}
|
||||
}
|
||||
|
||||
// 至少需要一条 user 消息
|
||||
if !hasUser {
|
||||
return errors.New("消息列表必须包含至少一条 user 角色消息")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateMessage 验证单条消息的合法性
|
||||
func validateMessage(msg dto.Message) error {
|
||||
role := strings.TrimSpace(msg.Role)
|
||||
if role == "" {
|
||||
return errors.New("role 不能为空")
|
||||
}
|
||||
|
||||
if !isValidRole(role) {
|
||||
return fmt.Errorf("role 值非法: %s (仅允许 system/user/assistant)", role)
|
||||
}
|
||||
|
||||
// user 角色的 content 不能为空
|
||||
if role == "user" && isEmptyValue(msg.Content) {
|
||||
return errors.New("user 角色的 content 不能为空")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidRole 判断角色是否合法
|
||||
func isValidRole(role string) bool {
|
||||
switch role {
|
||||
case "system", "user", "assistant":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// HasUserMessage 判断消息列表中是否包含非空的 user 消息
|
||||
func HasUserMessage(messages []dto.Message) bool {
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "user" && !isEmptyValue(msg.Content) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasSystemMessage 判断消息列表中是否包含非空的 system 消息
|
||||
func HasSystemMessage(messages []dto.Message) bool {
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" && !isEmptyValue(msg.Content) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractUserContent 提取消息列表中第一个 user 角色的内容
|
||||
func ExtractUserContent(messages []dto.Message) any {
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "user" {
|
||||
return msg.Content
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtractSystemContent 提取消息列表中第一个 system 角色的内容
|
||||
func ExtractSystemContent(messages []dto.Message) any {
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
return msg.Content
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 测试辅助函数 (可选)
|
||||
// ============================================
|
||||
|
||||
// MockModelResponse 创建模拟的模型响应用于测试
|
||||
func MockModelResponse(systemContent, userContent string) string {
|
||||
output := modelOutput{
|
||||
Messages: []dto.Message{
|
||||
{Role: "system", Content: systemContent},
|
||||
{Role: "user", Content: userContent},
|
||||
},
|
||||
}
|
||||
|
||||
bytes, _ := json.Marshal(output)
|
||||
return string(bytes)
|
||||
}
|
||||
|
||||
// MockGatewayResponse 创建模拟的网关响应用于测试
|
||||
func MockGatewayResponse(innerJSON string) string {
|
||||
resp := gatewayResponse{
|
||||
Choices: []choice{
|
||||
{
|
||||
Message: message{
|
||||
Content: innerJSON,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bytes, _ := json.Marshal(resp)
|
||||
return string(bytes)
|
||||
}
|
||||
Reference in New Issue
Block a user