Files
prompts-core/service/session_redis_service.go
2026-05-12 13:59:15 +08:00

182 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/gogf/gf/v2/frame/g"
)
// Message 消息结构content 支持 string 或 []string
type Message struct {
Role string `json:"role"` // user / assistant / system
Content any `json:"content"` // 内容string 或 []string
Type string `json:"type,omitempty"` // text / file可选扩展
}
// GetContentString 获取 Content 的字符串形式
func (m Message) GetContentString() string {
switch v := m.Content.(type) {
case string:
return v
case []interface{}:
var parts []string
for _, item := range v {
if s, ok := item.(string); ok {
parts = append(parts, s)
}
}
return strings.Join(parts, "\n")
default:
b, _ := json.Marshal(m.Content)
return string(b)
}
}
// SessionRoundData Redis存储的单轮会话数据
type SessionRoundData struct {
SessionId string `json:"sessionId"` // 会话ID
RequestContent []Message `json:"requestContent"` // 用户请求会话
ResponseContent []Message `json:"responseContent"` // AI回调会话
Timestamp int64 `json:"timestamp"` // 存入时间戳
}
// GetSessionHistory 获取多轮会话历史(供推理时使用)
func (s *sessionService) GetSessionHistory(ctx context.Context, sessionId string) ([]SessionRoundData, error) {
return s.getFromRedis(ctx, sessionId)
}
// BuildMessages 根据Redis历史构建完整的Messages数组
func (s *sessionService) BuildMessages(ctx context.Context, sessionId string, currentMessages []Message) ([]Message, error) {
// 获取历史会话
history, err := s.getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
var allMessages []Message
// 按时间顺序拼接历史消息
for _, round := range history {
allMessages = append(allMessages, round.RequestContent...)
allMessages = append(allMessages, round.ResponseContent...)
}
// 添加当前轮次的请求消息
allMessages = append(allMessages, currentMessages...)
return allMessages, nil
}
// ==================== Redis 操作 ====================
// saveToRedis 保存会话数据到Redis
// sessionId: 会话ID作为key
// 最大10轮超出替换最早的过期时间30分钟
func (s *sessionService) saveToRedis(ctx context.Context, sessionId string, requestMessages []Message, responseMessages []Message) error {
key := fmt.Sprintf("chat:session:%s", sessionId)
// 从配置读取,提供默认值
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
expireTime := time.Duration(expireSeconds) * time.Second
// 构造存储数据
data := SessionRoundData{
SessionId: sessionId,
RequestContent: requestMessages,
ResponseContent: responseMessages,
Timestamp: time.Now().Unix(),
}
// 序列化
b, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err)
}
// 写入 RedisLPUSH 添加到最前面,新的在前)
_, err = g.Redis().Do(ctx, "LPUSH", key, string(b))
if err != nil {
return fmt.Errorf("写入Redis失败: %w", err)
}
// 裁剪到最新10轮保留前10条
_, err = g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1)
if err != nil {
return fmt.Errorf("裁剪Redis列表失败: %w", err)
}
// 重置过期时间
_, err = g.Redis().Do(ctx, "EXPIRE", key, int64(expireTime.Seconds()))
if err != nil {
return fmt.Errorf("设置过期时间失败: %w", err)
}
return nil
}
// getFromRedis 从Redis获取会话历史
func (s *sessionService) getFromRedis(ctx context.Context, sessionId string) ([]SessionRoundData, error) {
key := fmt.Sprintf("chat:session:%s", sessionId)
// 获取列表中所有数据最多10条
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
if err != nil {
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
}
if result == nil || result.IsNil() {
return []SessionRoundData{}, nil
}
// 解析数据
var sessions []SessionRoundData
// 将结果转换为字符串数组
values := result.Strings()
for _, str := range values {
var data SessionRoundData
if err := json.Unmarshal([]byte(str), &data); err != nil {
g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err)
continue
}
sessions = append(sessions, data)
}
// 反转顺序Redis存储最新在前使用时按时间正序
for i, j := 0, len(sessions)-1; i < j; i, j = i+1, j-1 {
sessions[i], sessions[j] = sessions[j], sessions[i]
}
return sessions, nil
}
// GetSessionHistoryForInference 获取历史会话直接返回Message数组给推理用
func (s *sessionService) GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]Message, error) {
// 从Redis获取历史会话数据
historyData, err := s.getFromRedis(ctx, sessionId)
if err != nil {
return nil, fmt.Errorf("获取历史会话失败: %w", err)
}
// 如果没有任何历史数据,返回空
if len(historyData) == 0 {
return []Message{}, nil
}
// 把SessionRoundData转换成扁平的Message数组
var messages []Message
for _, round := range historyData {
// 先加用户的请求
messages = append(messages, round.RequestContent...)
// 再加AI的回答
messages = append(messages, round.ResponseContent...)
}
return messages, nil
}