363 lines
8.4 KiB
Go
363 lines
8.4 KiB
Go
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"prompts-core/model/dto"
|
|
)
|
|
|
|
// ============================================
|
|
// 类型定义
|
|
// ============================================
|
|
|
|
// modelOutput 推理模型的标准输出格式
|
|
type modelOutput struct {
|
|
Messages []dto.Message `json:"messages"`
|
|
System any `json:"system"`
|
|
User any `json:"user"`
|
|
}
|
|
|
|
// gatewayResponse 模型网关的标准响应格式
|
|
type gatewayResponse struct {
|
|
Choices []choice `json:"choices"`
|
|
}
|
|
|
|
type choice struct {
|
|
Message message `json:"message"`
|
|
}
|
|
|
|
type message struct {
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
// ============================================
|
|
// 核心解析函数
|
|
// ============================================
|
|
|
|
// ParseModelResponse 解析推理模型的文本响应,返回消息列表
|
|
// 支持三种格式:
|
|
// 1. 标准 messages 格式: {"messages": [...]}
|
|
// 2. 简化 system/user 格式: {"system": "...", "user": "..."}
|
|
// 3. 网关包装格式: {"choices": [{"message": {"content": "..."}}]}
|
|
func ParseModelResponse(text string) ([]dto.Message, error) {
|
|
text = strings.TrimSpace(text)
|
|
if text == "" {
|
|
return nil, errors.New("模型响应为空")
|
|
}
|
|
|
|
// 1. 尝试解包网关响应
|
|
if content := unwrapGatewayResponse(text); content != "" {
|
|
text = content
|
|
}
|
|
|
|
// 2. 解析为标准格式
|
|
output, err := parseAsModelOutput(text)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
|
}
|
|
|
|
// 3. 优先使用 messages 字段
|
|
if len(output.Messages) > 0 {
|
|
messages := normalizeMessageContents(output.Messages)
|
|
if err := validateMessageList(messages); err != nil {
|
|
return nil, err
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
// 4. 兼容 system/user 格式
|
|
return buildMessagesFromSystemUser(output)
|
|
}
|
|
|
|
// ParseStoredMessages 从数据库存储的数据中解析消息列表
|
|
func ParseStoredMessages(data any) []dto.Message {
|
|
if data == nil {
|
|
return nil
|
|
}
|
|
|
|
// 统一序列化
|
|
jsonBytes, err := json.Marshal(data)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// 尝试直接解析
|
|
var messages []dto.Message
|
|
if err := json.Unmarshal(jsonBytes, &messages); err == nil {
|
|
return messages
|
|
}
|
|
|
|
// 尝试解析为 JSON 字符串再解析
|
|
var jsonStr string
|
|
if err := json.Unmarshal(jsonBytes, &jsonStr); err != nil {
|
|
return nil
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(jsonStr), &messages); err != nil {
|
|
return nil
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// ============================================
|
|
// 内部解析函数
|
|
// ============================================
|
|
|
|
// parseAsModelOutput 将文本解析为 modelOutput 结构
|
|
func parseAsModelOutput(text string) (*modelOutput, error) {
|
|
// 清理可能的 Markdown 代码块标记
|
|
text = cleanMarkdownCodeBlock(text)
|
|
|
|
var output modelOutput
|
|
if err := json.Unmarshal([]byte(text), &output); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &output, nil
|
|
}
|
|
|
|
// unwrapGatewayResponse 解包网关的标准响应格式
|
|
func unwrapGatewayResponse(text string) string {
|
|
// 快速检查是否可能是网关响应
|
|
if !strings.Contains(text, `"choices"`) {
|
|
return ""
|
|
}
|
|
|
|
var resp gatewayResponse
|
|
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
|
return ""
|
|
}
|
|
|
|
if len(resp.Choices) == 0 {
|
|
return ""
|
|
}
|
|
|
|
content := strings.TrimSpace(resp.Choices[0].Message.Content)
|
|
return content
|
|
}
|
|
|
|
// buildMessagesFromSystemUser 从 system/user 字段构建消息列表
|
|
func buildMessagesFromSystemUser(output *modelOutput) ([]dto.Message, error) {
|
|
messages := make([]dto.Message, 0, 2)
|
|
|
|
// 添加 user 消息
|
|
if !isEmptyValue(output.User) {
|
|
messages = append(messages, dto.Message{
|
|
Role: "user",
|
|
Content: normalizeContent(output.User),
|
|
})
|
|
}
|
|
|
|
// 添加 system 消息
|
|
if !isEmptyValue(output.System) {
|
|
messages = append(messages, dto.Message{
|
|
Role: "system",
|
|
Content: normalizeContent(output.System),
|
|
})
|
|
}
|
|
|
|
if len(messages) == 0 {
|
|
return nil, errors.New("未解析到有效的 system 或 user 内容")
|
|
}
|
|
|
|
if err := validateMessageList(messages); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return messages, nil
|
|
}
|
|
|
|
// ============================================
|
|
// 内容规范化
|
|
// ============================================
|
|
|
|
// normalizeMessageContents 规范化消息列表中的所有内容
|
|
func normalizeMessageContents(messages []dto.Message) []dto.Message {
|
|
for i := range messages {
|
|
messages[i].Content = normalizeContent(messages[i].Content)
|
|
}
|
|
return messages
|
|
}
|
|
|
|
// normalizeContent 规范化单个消息内容
|
|
// - 如果是 JSON 字符串,尝试解析为对象/数组
|
|
// - 否则保持原样
|
|
func normalizeContent(content any) any {
|
|
switch v := content.(type) {
|
|
case string:
|
|
return tryUnmarshalJSON(v)
|
|
default:
|
|
return content
|
|
}
|
|
}
|
|
|
|
// tryUnmarshalJSON 尝试将 JSON 字符串解析为结构化对象
|
|
func tryUnmarshalJSON(s string) any {
|
|
s = strings.TrimSpace(s)
|
|
if s == "" {
|
|
return s
|
|
}
|
|
|
|
// 只处理看起来像 JSON 的内容
|
|
if !looksLikeJSON(s) {
|
|
return s
|
|
}
|
|
|
|
var result any
|
|
if err := json.Unmarshal([]byte(s), &result); err != nil || result == nil {
|
|
return s
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// looksLikeJSON 判断字符串是否可能是 JSON
|
|
func looksLikeJSON(s string) bool {
|
|
s = strings.TrimSpace(s)
|
|
return strings.HasPrefix(s, "{") || strings.HasPrefix(s, "[")
|
|
}
|
|
|
|
// cleanMarkdownCodeBlock 清理 Markdown 代码块标记
|
|
func cleanMarkdownCodeBlock(text string) string {
|
|
// 去除可能的 ```json 和 ``` 标记
|
|
text = strings.TrimPrefix(text, "```json")
|
|
text = strings.TrimPrefix(text, "```JSON")
|
|
text = strings.TrimPrefix(text, "```")
|
|
text = strings.TrimSuffix(text, "```")
|
|
return strings.TrimSpace(text)
|
|
}
|
|
|
|
// ============================================
|
|
// 验证
|
|
// ============================================
|
|
|
|
// validateMessageList 验证消息列表的合法性
|
|
func validateMessageList(messages []dto.Message) error {
|
|
if len(messages) == 0 {
|
|
return errors.New("消息列表不能为空")
|
|
}
|
|
|
|
hasUser := false
|
|
for i, msg := range messages {
|
|
if err := validateMessage(msg); err != nil {
|
|
return fmt.Errorf("消息[%d]验证失败: %w", i, err)
|
|
}
|
|
if msg.Role == "user" {
|
|
hasUser = true
|
|
}
|
|
}
|
|
|
|
// 至少需要一条 user 消息
|
|
if !hasUser {
|
|
return errors.New("消息列表必须包含至少一条 user 角色消息")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateMessage 验证单条消息的合法性
|
|
func validateMessage(msg dto.Message) error {
|
|
role := strings.TrimSpace(msg.Role)
|
|
if role == "" {
|
|
return errors.New("role 不能为空")
|
|
}
|
|
|
|
if !isValidRole(role) {
|
|
return fmt.Errorf("role 值非法: %s (仅允许 system/user/assistant)", role)
|
|
}
|
|
|
|
// user 角色的 content 不能为空
|
|
if role == "user" && isEmptyValue(msg.Content) {
|
|
return errors.New("user 角色的 content 不能为空")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isValidRole 判断角色是否合法
|
|
func isValidRole(role string) bool {
|
|
switch role {
|
|
case "system", "user", "assistant":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// HasUserMessage 判断消息列表中是否包含非空的 user 消息
|
|
func HasUserMessage(messages []dto.Message) bool {
|
|
for _, msg := range messages {
|
|
if msg.Role == "user" && !isEmptyValue(msg.Content) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// HasSystemMessage 判断消息列表中是否包含非空的 system 消息
|
|
func HasSystemMessage(messages []dto.Message) bool {
|
|
for _, msg := range messages {
|
|
if msg.Role == "system" && !isEmptyValue(msg.Content) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// ExtractUserContent 提取消息列表中第一个 user 角色的内容
|
|
func ExtractUserContent(messages []dto.Message) any {
|
|
for _, msg := range messages {
|
|
if msg.Role == "user" {
|
|
return msg.Content
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ExtractSystemContent 提取消息列表中第一个 system 角色的内容
|
|
func ExtractSystemContent(messages []dto.Message) any {
|
|
for _, msg := range messages {
|
|
if msg.Role == "system" {
|
|
return msg.Content
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ============================================
|
|
// 测试辅助函数 (可选)
|
|
// ============================================
|
|
|
|
// MockModelResponse 创建模拟的模型响应用于测试
|
|
func MockModelResponse(systemContent, userContent string) string {
|
|
output := modelOutput{
|
|
Messages: []dto.Message{
|
|
{Role: "system", Content: systemContent},
|
|
{Role: "user", Content: userContent},
|
|
},
|
|
}
|
|
|
|
bytes, _ := json.Marshal(output)
|
|
return string(bytes)
|
|
}
|
|
|
|
// MockGatewayResponse 创建模拟的网关响应用于测试
|
|
func MockGatewayResponse(innerJSON string) string {
|
|
resp := gatewayResponse{
|
|
Choices: []choice{
|
|
{
|
|
Message: message{
|
|
Content: innerJSON,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
bytes, _ := json.Marshal(resp)
|
|
return string(bytes)
|
|
}
|