146 lines
4.5 KiB
Go
146 lines
4.5 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"gitea.com/red-future/common/beans"
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
|
|
"prompts-core/common/util"
|
|
"prompts-core/dao"
|
|
"prompts-core/model/dto"
|
|
"prompts-core/model/entity"
|
|
)
|
|
|
|
// Callback 会话回调
|
|
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
|
req.Messages["role"] = "assistant"
|
|
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
|
ResponseContent: req.Messages,
|
|
})
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
return nil, fmt.Errorf("更新数据库失败: %w", err)
|
|
}
|
|
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
|
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
|
})
|
|
if session == nil {
|
|
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
|
}
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
return nil, fmt.Errorf("获取会话数据失败: %w", err)
|
|
}
|
|
if err = saveToRedis(ctx, session); err != nil {
|
|
return nil, fmt.Errorf("redis存储失败: %w", err)
|
|
}
|
|
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
|
session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
|
|
return &dto.SessionCallbackRes{
|
|
Status: true,
|
|
SessionId: session.SessionId,
|
|
}, nil
|
|
}
|
|
|
|
// GetHistoryMessages 获取历史信息
|
|
func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) {
|
|
// 1) 获取最大轮次
|
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|
|
|
// 2) 从 Redis 获取历史记录
|
|
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
|
if err == nil && len(redisHistory) > 0 {
|
|
return redisHistory, nil
|
|
}
|
|
|
|
// 3) Redis 没有,从数据库查最新 maxRounds 条
|
|
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
|
SessionId: sessionId,
|
|
NodeId: nodeId,
|
|
}, 1, maxRounds)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
|
}
|
|
// 4) 为空返回报错
|
|
if len(sessions) == 0 {
|
|
return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId)
|
|
}
|
|
// 5) 提取为统一格式
|
|
messages := extractMessagesFromSessions(sessions)
|
|
|
|
// 6) 缓存 Redis 半小时
|
|
//_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute)
|
|
|
|
return messages, nil
|
|
}
|
|
|
|
// getHistoryFromDatabase 从数据库获取历史记录
|
|
func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) {
|
|
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
|
SessionId: sessionId,
|
|
}, 1, maxRounds)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
|
}
|
|
|
|
messages := extractMessagesFromSessions(sessions)
|
|
|
|
cacheSessionsToRedis(ctx, sessions)
|
|
|
|
return messages, nil
|
|
}
|
|
|
|
// extractMessagesFromSessions 从会话列表中提取消息
|
|
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
|
var messages []map[string]any
|
|
for _, session := range sessions {
|
|
appendRequestMessages(session.RequestContent, &messages)
|
|
appendResponseMessages(session.ResponseContent, &messages)
|
|
}
|
|
return messages
|
|
}
|
|
|
|
// appendRequestMessages 追加请求消息
|
|
func appendRequestMessages(requestContent any, messages *[]map[string]any) {
|
|
reqMsgs := util.ConvertToMessages(requestContent)
|
|
for _, m := range reqMsgs {
|
|
role := gconv.String(m["role"])
|
|
if role == "user" || role == "assistant" {
|
|
*messages = append(*messages, m)
|
|
}
|
|
}
|
|
}
|
|
|
|
// appendResponseMessages 追加响应消息
|
|
func appendResponseMessages(responseContent any, messages *[]map[string]any) {
|
|
respMsgs := util.ConvertToMessages(responseContent)
|
|
for _, m := range respMsgs {
|
|
if m["role"] == nil {
|
|
m["role"] = "assistant"
|
|
}
|
|
*messages = append(*messages, m)
|
|
}
|
|
}
|
|
|
|
// cacheSessionsToRedis 将会话缓存到Redis
|
|
func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) {
|
|
for _, session := range sessions {
|
|
reqMsgs := util.ConvertToMessages(session.RequestContent)
|
|
respMsgs := util.ConvertToMessages(session.ResponseContent)
|
|
|
|
for i := range respMsgs {
|
|
if respMsgs[i]["role"] == nil {
|
|
respMsgs[i]["role"] = "assistant"
|
|
}
|
|
}
|
|
|
|
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
|
_ = saveToRedis(ctx, session)
|
|
}
|
|
}
|
|
}
|