From 9410199fbe33e2099580b9f22d30a296a3259c6c Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Tue, 9 Jun 2026 14:00:01 +0800 Subject: [PATCH] =?UTF-8?q?feat(session):=20=E9=87=8D=E6=9E=84=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E7=AE=A1=E7=90=86=E5=92=8CRedis=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/config.go | 21 +- config.yml | 6 +- controller/prompt_session_controller.go | 19 +- dao/compose_session_dao.go | 11 +- model/dto/prompt_session_dto.go | 38 +++- service/prompt/prompt_compose_service.go | 23 +- .../session/prompt_session_redis_service.go | 187 ++++++++------- service/session/prompt_session_service.go | 215 +++++++++++------- 8 files changed, 324 insertions(+), 196 deletions(-) diff --git a/common/util/config.go b/common/util/config.go index 5a7be25..42940ce 100644 --- a/common/util/config.go +++ b/common/util/config.go @@ -2,7 +2,6 @@ package util import ( "context" - "strings" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" @@ -13,16 +12,6 @@ func GetServerName(ctx context.Context) string { return g.Cfg().MustGet(ctx, "server.name", "").String() } -// GetServerPort 从配置获取服务端口 -func GetServerPort(ctx context.Context) string { - address := g.Cfg().MustGet(ctx, "server.address", ":8080").String() - // address 格式如 ":3009",去掉冒号 - if strings.HasPrefix(address, ":") { - return address[1:] - } - return "8080" -} - // GetModelPrompt 获取请求模型的提示词 func GetModelPrompt(ctx context.Context, modelType int) string { key := "modelPrompts.types." + gconv.String(modelType) @@ -33,3 +22,13 @@ func GetModelPrompt(ctx context.Context, modelType int) string { func GetBuildPrompt(ctx context.Context) string { return g.Cfg().MustGet(ctx, "nodePrompts", "").String() } + +// GetMaxRounds 获取最大轮数配置 +func GetMaxRounds(ctx context.Context) int { + return g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() +} + +// GetExpireMinutes 获取过期时间配置 +func GetExpireMinutes(ctx context.Context) int { + return g.Cfg().MustGet(ctx, "session.expireMinutes", 30).Int() +} diff --git a/config.yml b/config.yml index 5ca72ba..a00b2ed 100644 --- a/config.yml +++ b/config.yml @@ -50,14 +50,14 @@ database: redis: default: - address: 116.204.74.41:6379 + address: 192.168.3.30:6379 db: 0 consul: - address: 116.204.74.41:8500 + address: 192.168.3.30:8500 jaeger: - addr: 116.204.74.41:4318 + addr: 192.168.3.30:4318 task: waitTimeoutSeconds: 600 # /composeMessages 同步等待最终结果的最长时间(秒) diff --git a/controller/prompt_session_controller.go b/controller/prompt_session_controller.go index 84c3a94..2dfc535 100644 --- a/controller/prompt_session_controller.go +++ b/controller/prompt_session_controller.go @@ -1,18 +1,31 @@ +// ============================================ +// controller/session.go +// ============================================ + package controller import ( "context" - "prompts-core/model/dto" + "prompts-core/model/dto" sessionService "prompts-core/service/session" ) type session struct{} -// Session 提示词会话控制器 var Session = new(session) -// SessionCallback 会话回调 +// SessionCallback 接收会话回调通知 func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { return sessionService.Callback(ctx, req) } + +// GetHistoryMessages 获取历史消息 +func (c *session) GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (res *dto.GetHistoryMessagesRes, err error) { + return sessionService.GetHistoryMessages(ctx, req) +} + +// DeleteSession 删除会话 +func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) { + return sessionService.DeleteSession(ctx, req) +} diff --git a/dao/compose_session_dao.go b/dao/compose_session_dao.go index d58aaba..b06831c 100644 --- a/dao/compose_session_dao.go +++ b/dao/compose_session_dao.go @@ -6,7 +6,6 @@ import ( "prompts-core/model/entity" "gitea.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/util/gconv" ) var ComposeSession = &composeSessionDao{} @@ -15,13 +14,8 @@ type composeSessionDao struct{} // Insert 插入 func (d *composeSessionDao) Insert(ctx context.Context, req *entity.ComposeSession) (id int64, err error) { - var m = new(entity.ComposeSession) - err = gconv.Struct(req, &m) - if err != nil { - return - } r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). - Insert(m) + Insert(req) if err != nil { return } @@ -54,7 +48,6 @@ func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession OmitEmpty() model.Where(entity.ComposeSessionCol.Creator, req.Creator) model.Where(entity.ComposeSessionCol.SessionId, req.SessionId) - model.Where(entity.ComposeSessionCol.NodeId, req.NodeId) model.OrderDesc(entity.ComposeSessionCol.CreatedAt) model.Page(page, size) r, total, err := model.AllAndCount(false) @@ -70,6 +63,7 @@ func (d *composeSessionDao) Get(ctx context.Context, req *entity.ComposeSession, r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). OmitEmpty(). Where(entity.ComposeSessionCol.Id, req.Id). + Where(entity.ComposeSessionCol.Creator, req.Creator). Where(entity.ComposeSessionCol.SessionId, req.SessionId). Fields(fields).One() if err != nil { @@ -87,6 +81,7 @@ func (d *composeSessionDao) Delete(ctx context.Context, req *entity.ComposeSessi r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameComposeSession). OmitEmpty(). Where(entity.ComposeSessionCol.Id, req.Id). + Where(entity.ComposeSessionCol.Creator, req.Creator). Where(entity.ComposeSessionCol.SessionId, req.SessionId). Delete() if err != nil { diff --git a/model/dto/prompt_session_dto.go b/model/dto/prompt_session_dto.go index 4de558f..694bbe9 100644 --- a/model/dto/prompt_session_dto.go +++ b/model/dto/prompt_session_dto.go @@ -2,13 +2,49 @@ package dto import "github.com/gogf/gf/v2/frame/g" +// SessionCallbackReq 会话回调请求 type SessionCallbackReq struct { - g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"` + g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"` Messages map[string]any `json:"messages" dc:"消息数组"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` } +// SessionCallbackRes 会话回调响应 type SessionCallbackRes struct { Status bool `json:"status" dc:"状态"` SessionId string `json:"sessionId" dc:"会话ID"` } + +// GetHistoryMessagesReq 获取历史消息请求 +type GetHistoryMessagesReq struct { + g.Meta `path:"/history" method:"get" tags:"会话管理" summary:"获取历史消息"` + SessionId string `json:"sessionId" v:"required" dc:"会话ID"` + NodeId string `json:"nodeId" dc:"节点ID"` +} + +// GetHistoryMessagesRes 获取历史消息响应 +type GetHistoryMessagesRes struct { + Messages []HistoryRound `json:"messages" dc:"历史消息列表"` +} + +// HistoryRound 一轮对话 +type HistoryRound struct { + Id int64 `json:"id" dc:"记录ID"` + User map[string]any `json:"user" dc:"用户消息"` + Assistant map[string]any `json:"assistant" dc:"助手回复"` + CreatedAt string `json:"createdAt" dc:"创建时间"` +} + +// DeleteSessionReq 删除会话请求 +type DeleteSessionReq struct { + g.Meta `path:"/delete" method:"post" tags:"会话管理" summary:"删除会话"` + TenantId uint64 `json:"tenantId" dc:"租户ID"` + SessionId string `json:"sessionId" v:"required" dc:"会话ID"` + NodeId string `json:"nodeId" dc:"节点ID"` + MsgIds []int64 `json:"msgIds" dc:"消息ID列表,传则删单条,不传删整个会话"` +} + +// DeleteSessionRes 删除会话响应 +type DeleteSessionRes struct { + Ok bool `json:"ok" dc:"是否成功"` +} diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 929a4cc..a064eea 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -198,13 +198,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas buildType := gconv.Int(payload["buildType"]) if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" { // 4) 获取历史内容并拼接 - history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId) - for _, msg := range history { - role := gconv.String(msg["role"]) - if role != "user" && role != "assistant" { - continue - } - } + _, _ = session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ + SessionId: sessionId, + NodeId: nodeId, + }) // 5) 存储提示词结果作为历史请求 if userMsg := util.ExtractUserText(messages); userMsg != nil { epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ @@ -261,11 +258,15 @@ func parseMessagesForResponse(messages any) any { } func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) { - // 1) 获取基础数据 - // 4) 模拟历史拼接 - history, _ := session.GetHistoryMessages(ctx, "88888888", "node1") + history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ + SessionId: "88888888", + NodeId: "node1", + }) + if err != nil { + return nil, err + } return &dto.GetPromptTextRes{ - Messages: history, + Messages: history.Messages, }, nil } diff --git a/service/session/prompt_session_redis_service.go b/service/session/prompt_session_redis_service.go index bec6553..28a5310 100644 --- a/service/session/prompt_session_redis_service.go +++ b/service/session/prompt_session_redis_service.go @@ -4,134 +4,165 @@ import ( "context" "encoding/json" "fmt" - "prompts-core/model/entity" + "prompts-core/common/util" + "prompts-core/model/dto" "time" "github.com/gogf/gf/v2/frame/g" ) const ( - redisKeyPrefix = "chat:session:%s" + // RedisKeySessionHistory 会话历史缓存 key: session:history:{tenantId}:{sessionId} + RedisKeySessionHistory = "session:history:%d:%s" ) -// formatRedisKey 格式化Redis键 -func formatRedisKey(sessionId string) string { - return fmt.Sprintf(redisKeyPrefix, sessionId) +// formatRedisKey 格式化 Redis key +func formatRedisKey(tenantID uint64, sessionID string) string { + return fmt.Sprintf(RedisKeySessionHistory, tenantID, sessionID) } -// saveToRedis 保存会话数据到Redis -func saveToRedis(ctx context.Context, session *entity.ComposeSession) error { - key := formatRedisKey(session.SessionId) - maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() - expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64() - data := map[string]any{ - "sessionId": session.SessionId, - "requestContent": session.RequestContent, - "responseContent": session.ResponseContent, - "timestamp": time.Now().Unix(), - } - b, err := json.Marshal(data) +// ============================================ +// 写操作 +// ============================================ + +// SaveToRedis 保存一轮对话到 Redis ZSET +func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *dto.HistoryRound) error { + key := formatRedisKey(tenantID, sessionID) + maxRounds := util.GetMaxRounds(ctx) + expireSeconds := int64(util.GetExpireMinutes(ctx) * 60) + + b, err := json.Marshal(round) if err != nil { return fmt.Errorf("序列化会话数据失败: %w", err) } - if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { - return err + + score := float64(time.Now().UnixMilli()) + + if _, err = g.Redis().Do(ctx, "ZADD", key, score, string(b)); err != nil { + return fmt.Errorf("ZADD失败: %w", err) } + if _, err = g.Redis().Do(ctx, "ZREMRANGEBYRANK", key, 0, -(maxRounds + 1)); err != nil { + return fmt.Errorf("裁剪失败: %w", err) + } + if _, err = g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil { + return fmt.Errorf("设置过期失败: %w", err) + } + return nil } -// executeRedisCommands 执行Redis命令 -func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error { - if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil { - return fmt.Errorf("写入Redis失败: %w", err) +// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID) +func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error { + key := formatRedisKey(tenantID, sessionID) + + cursor := "0" + for { + result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10) + if err != nil { + return fmt.Errorf("ZSCAN失败: %w", err) + } + + parts := result.Strings() + if len(parts) < 2 { + break + } + + cursor = parts[0] + members := parts[1:] + + for _, member := range members { + if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil { + g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err) + } + } + + if cursor == "0" { + break + } } - if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil { - return fmt.Errorf("裁剪Redis列表失败: %w", err) - } - if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil { - return fmt.Errorf("设置过期时间失败: %w", err) - } return nil } -// getFromRedis 从Redis获取会话历史 -func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) { - key := formatRedisKey(sessionId) - result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1) +// DeleteSessionHistory 删除整个会话的 Redis 缓存 +func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string) error { + key := formatRedisKey(tenantID, sessionID) + _, err := g.Redis().Do(ctx, "DEL", key) + return err +} + +// ============================================ +// 读操作 +// ============================================ + +// GetFromRedis 从 Redis ZSET 获取会话历史 +func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) { + key := formatRedisKey(tenantID, sessionID) + maxRounds := util.GetMaxRounds(ctx) + + result, err := g.Redis().Do(ctx, "ZREVRANGE", key, 0, maxRounds-1) if err != nil { - return nil, fmt.Errorf("从Redis获取数据失败: %w", err) + return nil, fmt.Errorf("ZREVRANGE失败: %w", err) } if result == nil || result.IsNil() { return []map[string]any{}, nil } - sessions := parseRedisSessions(ctx, result.Strings()) - - reverseSlice(sessions) - - return sessions, nil + return parseRedisRounds(ctx, result.Strings()), nil } -// parseRedisSessions 解析Redis会话数据 -func parseRedisSessions(ctx context.Context, values []string) []map[string]any { - var sessions []map[string]any - - for _, str := range values { - var data map[string]any - if err := json.Unmarshal([]byte(str), &data); err != nil { - g.Log().Warningf(ctx, "[会话] 解析Redis数据失败 err=%v", err) - continue - } - sessions = append(sessions, data) - } - - return sessions -} - -// reverseSlice 反转切片 -func reverseSlice(s []map[string]any) { - for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { - s[i], s[j] = s[j], s[i] - } -} - -// GetSessionHistoryForInference 获取历史会话,返回扁平消息数组(给推理用) -func GetSessionHistoryForInference(ctx context.Context, sessionId string) ([]map[string]any, error) { - historyData, err := getFromRedis(ctx, sessionId) +// GetSessionHistoryForInference 获取扁平消息数组(给推理用) +func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) { + rounds, err := GetFromRedis(ctx, tenantID, sessionID) if err != nil { return nil, fmt.Errorf("获取历史会话失败: %w", err) } - if len(historyData) == 0 { + if len(rounds) == 0 { return []map[string]any{}, nil } - return flattenHistoryMessages(historyData), nil + return flattenRounds(rounds), nil } -// flattenHistoryMessages 扁平化历史消息 -func flattenHistoryMessages(historyData []map[string]any) []map[string]any { - var messages []map[string]any +// ============================================ +// 解析 +// ============================================ - for _, round := range historyData { - appendMessagesFromField(round, "requestContent", &messages) - appendMessagesFromField(round, "responseContent", &messages) +func parseRedisRounds(ctx context.Context, members []string) []map[string]any { + rounds := make([]map[string]any, 0, len(members)) + for _, member := range members { + var data map[string]any + if err := json.Unmarshal([]byte(member), &data); err != nil { + g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err) + continue + } + rounds = append(rounds, data) } + return rounds +} +func flattenRounds(rounds []map[string]any) []map[string]any { + var messages []map[string]any + for i := len(rounds) - 1; i >= 0; i-- { + if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 { + messages = append(messages, user) + } + if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 { + messages = append(messages, assistant) + } + } return messages } -// appendMessagesFromField 从指定字段追加消息 -func appendMessagesFromField(data map[string]any, field string, messages *[]map[string]any) { - msgs, ok := data[field].([]interface{}) +func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) { + msgs, ok := data[field].([]any) if !ok { return } - for _, m := range msgs { - if msg, ok := m.(map[string]interface{}); ok { + if msg, ok := m.(map[string]any); ok { *messages = append(*messages, msg) } } diff --git a/service/session/prompt_session_service.go b/service/session/prompt_session_service.go index 4e2c471..7414936 100644 --- a/service/session/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -5,6 +5,7 @@ import ( "fmt" "gitea.com/red-future/common/beans" + "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" @@ -17,6 +18,7 @@ import ( // Callback 会话回调 func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { req.Messages["role"] = "assistant" + // 1) 更新 DB _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, ResponseContent: req.Messages, @@ -25,121 +27,172 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err) return nil, fmt.Errorf("更新数据库失败: %w", err) } + + // 2) 查询完整记录 session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{ SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, }) - if session == nil { + if err != nil || session == nil { return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId) } - if err != nil { - g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err) - return nil, fmt.Errorf("获取会话数据失败: %w", err) - } - if err = saveToRedis(ctx, session); err != nil { + + // 3) 写入 Redis + if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{ + Id: session.Id, + User: session.RequestContent, + Assistant: req.Messages, + CreatedAt: gconv.String(session.CreatedAt), + }); err != nil { return nil, fmt.Errorf("redis存储失败: %w", err) } - g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", - session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent)) - return &dto.SessionCallbackRes{ - Status: true, - SessionId: session.SessionId, - }, nil + + // 4) 返回 + g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id) + return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil } -// GetHistoryMessages 获取历史信息 -func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) { - // 1) 获取最大轮次 - maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() - - // 2) 从 Redis 获取历史记录 - redisHistory, err := GetSessionHistoryForInference(ctx, sessionId) - if err == nil && len(redisHistory) > 0 { - return redisHistory, nil +// GetHistoryMessages 获取历史消息 +func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err } - // 3) Redis 没有,从数据库查最新 maxRounds 条 + // 1) Redis + redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId) + if err == nil && len(redisRounds) > 0 { + g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds)) + return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil + } + + // 2) DB + maxRounds := util.GetMaxRounds(ctx) sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ - SessionId: sessionId, - NodeId: nodeId, + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + SessionId: req.SessionId, }, 1, maxRounds) if err != nil { return nil, fmt.Errorf("DB获取历史失败: %w", err) } - // 4) 为空返回报错 if len(sessions) == 0 { - return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId) - } - // 5) 提取为统一格式 - messages := extractMessagesFromSessions(sessions) - - // 6) 缓存 Redis 半小时 - //_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute) - - return messages, nil -} - -// getHistoryFromDatabase 从数据库获取历史记录 -func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int) ([]map[string]any, error) { - sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{ - SessionId: sessionId, - }, 1, maxRounds) - if err != nil { - return nil, fmt.Errorf("DB获取历史失败: %w", err) + return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil } - messages := extractMessagesFromSessions(sessions) + // 3) 转换 + 异步回种 + rounds := sessionsToHistoryRounds(sessions) + go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions) - cacheSessionsToRedis(ctx, sessions) - - return messages, nil + return &dto.GetHistoryMessagesRes{Messages: rounds}, nil } -// extractMessagesFromSessions 从会话列表中提取消息 +// parseHistoryRounds Redis 数据转为 HistoryRound +func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound { + rounds := make([]dto.HistoryRound, 0, len(redisRounds)) + for _, r := range redisRounds { + round := dto.HistoryRound{ + Id: gconv.Int64(r["id"]), + CreatedAt: gconv.String(r["createdAt"]), + } + if user, ok := r["user"].(map[string]any); ok { + round.User = user + } + if assistant, ok := r["assistant"].(map[string]any); ok { + round.Assistant = assistant + } + rounds = append(rounds, round) + } + return rounds +} + +// sessionsToHistoryRounds DB 数据转为 HistoryRound +func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound { + rounds := make([]dto.HistoryRound, 0, len(sessions)) + for _, s := range sessions { + reqMsgs := util.ConvertToMessages(s.RequestContent) + respMsgs := util.ConvertToMessages(s.ResponseContent) + + round := dto.HistoryRound{ + Id: s.Id, + CreatedAt: gconv.String(s.CreatedAt), + } + if len(reqMsgs) > 0 { + round.User = reqMsgs[0] + } + if len(respMsgs) > 0 { + if respMsgs[0]["role"] == nil { + respMsgs[0]["role"] = "assistant" + } + round.Assistant = respMsgs[0] + } + rounds = append(rounds, round) + } + return rounds +} + +// DeleteSession 删除会话 +func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) { + hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0 + + deleteReq := &entity.ComposeSession{ + SessionId: req.SessionId, + NodeId: req.NodeId, + } + if hasMsgID { + deleteReq.Id = req.MsgIds[0] + } + + if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil { + return nil, fmt.Errorf("DB删除失败: %w", err) + } + + if hasMsgID { + if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil { + g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err) + } + } else { + if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil { + g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err) + } + } + + return &dto.DeleteSessionRes{Ok: true}, nil +} + +// ============================================ +// 内部方法 +// ============================================ + func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any { var messages []map[string]any - for _, session := range sessions { - appendRequestMessages(session.RequestContent, &messages) - appendResponseMessages(session.ResponseContent, &messages) + for i := len(sessions) - 1; i >= 0; i-- { + appendRoleMessages(sessions[i].RequestContent, "user", &messages) + appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages) } return messages } -// appendRequestMessages 追加请求消息 -func appendRequestMessages(requestContent any, messages *[]map[string]any) { - reqMsgs := util.ConvertToMessages(requestContent) - for _, m := range reqMsgs { - role := gconv.String(m["role"]) - if role == "user" || role == "assistant" { - *messages = append(*messages, m) - } - } -} - -// appendResponseMessages 追加响应消息 -func appendResponseMessages(responseContent any, messages *[]map[string]any) { - respMsgs := util.ConvertToMessages(responseContent) - for _, m := range respMsgs { - if m["role"] == nil { - m["role"] = "assistant" +func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) { + msgs := util.ConvertToMessages(content) + for _, m := range msgs { + if m["role"] == nil || gconv.String(m["role"]) == "" { + m["role"] = defaultRole } *messages = append(*messages, m) } } -// cacheSessionsToRedis 将会话缓存到Redis -func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession) { - for _, session := range sessions { - reqMsgs := util.ConvertToMessages(session.RequestContent) - respMsgs := util.ConvertToMessages(session.ResponseContent) - - for i := range respMsgs { - if respMsgs[i]["role"] == nil { - respMsgs[i]["role"] = "assistant" - } - } - +// asyncCacheToRedis 异步缓存会话数据到 Redis +func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) { + for _, s := range sessions { + reqMsgs := util.ConvertToMessages(s.RequestContent) + respMsgs := util.ConvertToMessages(s.ResponseContent) if len(reqMsgs) > 0 || len(respMsgs) > 0 { - _ = saveToRedis(ctx, session) + _ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{ + Id: s.Id, + User: s.RequestContent, + Assistant: s.ResponseContent, + CreatedAt: gconv.String(s.CreatedAt), + }) } } }