feat(session): 重构会话管理和消息存储功能
This commit is contained in:
@@ -15,17 +15,22 @@ type session struct{}
|
|||||||
|
|
||||||
var Session = new(session)
|
var Session = new(session)
|
||||||
|
|
||||||
// SessionCallback 接收会话回调通知
|
// SessionCallback 会话回调
|
||||||
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
|
||||||
return sessionService.Callback(ctx, req)
|
return sessionService.Callback(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHistoryMessages 获取历史消息
|
// GetHistoryList 获取历史列表(前端列表)
|
||||||
func (c *session) GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (res *dto.GetHistoryMessagesRes, err error) {
|
func (c *session) GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (res *dto.GetHistoryListRes, err error) {
|
||||||
return sessionService.GetHistoryMessages(ctx, req)
|
return sessionService.GetHistoryList(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSession 删除会话
|
// DeleteMessages 批量删除消息
|
||||||
|
func (c *session) DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (res *dto.DeleteMessagesRes, err error) {
|
||||||
|
return sessionService.DeleteMessages(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSession 删除整个会话
|
||||||
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
func (c *session) DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (res *dto.DeleteSessionRes, err error) {
|
||||||
return sessionService.DeleteSession(ctx, req)
|
return sessionService.DeleteSession(ctx, req)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,5 +66,5 @@ type GetPromptTextReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetPromptTextRes struct {
|
type GetPromptTextRes struct {
|
||||||
Messages any `json:"messages" dc:"最终消息数组"`
|
Messages any `json:"messages" dc:"历史消息"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,22 @@ package dto
|
|||||||
|
|
||||||
import "github.com/gogf/gf/v2/frame/g"
|
import "github.com/gogf/gf/v2/frame/g"
|
||||||
|
|
||||||
|
// HistoryRound 一轮对话
|
||||||
|
type HistoryRound struct {
|
||||||
|
Id int64 `json:"id" dc:"记录ID"`
|
||||||
|
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||||
|
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||||
|
User map[string]any `json:"user" dc:"用户消息"`
|
||||||
|
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
||||||
|
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
||||||
|
UpdatedAt string `json:"updatedAt" dc:"更新时间"`
|
||||||
|
}
|
||||||
|
|
||||||
// SessionCallbackReq 会话回调请求
|
// SessionCallbackReq 会话回调请求
|
||||||
type SessionCallbackReq struct {
|
type SessionCallbackReq struct {
|
||||||
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
g.Meta `path:"/callback" method:"post" tags:"会话管理" summary:"会话回调"`
|
||||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
Messages map[string]any `json:"messages" v:"required" dc:"消息数组"`
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
EpicycleId int64 `json:"epicycleId" v:"required" dc:"轮次ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionCallbackRes 会话回调响应
|
// SessionCallbackRes 会话回调响应
|
||||||
@@ -15,36 +26,55 @@ type SessionCallbackRes struct {
|
|||||||
SessionId string `json:"sessionId" dc:"会话ID"`
|
SessionId string `json:"sessionId" dc:"会话ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHistoryMessagesReq 获取历史消息请求
|
// GetHistoryListReq 获取历史列表请求(前端)
|
||||||
|
type GetHistoryListReq struct {
|
||||||
|
g.Meta `path:"/historyList" method:"get" tags:"会话管理" summary:"获取历史列表"`
|
||||||
|
Page int `json:"page" d:"1" dc:"页码"`
|
||||||
|
Size int `json:"size" d:"10" dc:"每页条数"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryListRes 获取历史列表响应
|
||||||
|
type GetHistoryListRes struct {
|
||||||
|
List []HistoryRound `json:"list" dc:"历史列表"`
|
||||||
|
Total int `json:"total" dc:"总数"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHistoryMessagesReq 获取历史消息请求(提示词拼接)
|
||||||
type GetHistoryMessagesReq struct {
|
type GetHistoryMessagesReq struct {
|
||||||
g.Meta `path:"/history" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
g.Meta `path:"/historyMessages" method:"get" tags:"会话管理" summary:"获取历史消息"`
|
||||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
NodeId string `json:"nodeId" dc:"节点ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHistoryMessagesRes 获取历史消息响应
|
// GetHistoryMessagesRes 获取历史消息响应
|
||||||
type GetHistoryMessagesRes struct {
|
type GetHistoryMessagesRes struct {
|
||||||
Messages []HistoryRound `json:"messages" dc:"历史消息列表"`
|
Messages []FlatMessage `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// HistoryRound 一轮对话
|
type FlatMessage struct {
|
||||||
type HistoryRound struct {
|
Role string `json:"role"`
|
||||||
Id int64 `json:"id" dc:"记录ID"`
|
Content string `json:"content"`
|
||||||
User map[string]any `json:"user" dc:"用户消息"`
|
|
||||||
Assistant map[string]any `json:"assistant" dc:"助手回复"`
|
|
||||||
CreatedAt string `json:"createdAt" dc:"创建时间"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSessionReq 删除会话请求
|
// DeleteMessagesReq 批量删除消息请求
|
||||||
type DeleteSessionReq struct {
|
type DeleteMessagesReq struct {
|
||||||
g.Meta `path:"/delete" method:"post" tags:"会话管理" summary:"删除会话"`
|
g.Meta `path:"/deleteMessages" method:"post" tags:"会话管理" summary:"批量删除消息"`
|
||||||
TenantId uint64 `json:"tenantId" dc:"租户ID"`
|
|
||||||
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
NodeId string `json:"nodeId" dc:"节点ID"`
|
MsgIds []int64 `json:"msgIds" v:"required" dc:"消息ID列表"`
|
||||||
MsgIds []int64 `json:"msgIds" dc:"消息ID列表,传则删单条,不传删整个会话"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSessionRes 删除会话响应
|
// DeleteMessagesRes 批量删除消息响应
|
||||||
|
type DeleteMessagesRes struct {
|
||||||
|
Ok bool `json:"ok" dc:"是否成功"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSessionReq 删除整个会话请求
|
||||||
|
type DeleteSessionReq struct {
|
||||||
|
g.Meta `path:"/deleteSession" method:"post" tags:"会话管理" summary:"删除整个会话"`
|
||||||
|
SessionId string `json:"sessionId" v:"required" dc:"会话ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteSessionRes 删除整个会话响应
|
||||||
type DeleteSessionRes struct {
|
type DeleteSessionRes struct {
|
||||||
Ok bool `json:"ok" dc:"是否成功"`
|
Ok bool `json:"ok" dc:"是否成功"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ type ComposeTask struct {
|
|||||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||||
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
GatewayState int `orm:"gateway_state" json:"gatewayState"`
|
||||||
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
||||||
ResultText map[string]any `orm:"result_text" json:"resultText"`
|
ResultJson map[string]any `orm:"result_json" json:"resultJson"`
|
||||||
Messages map[string]any `orm:"messages" json:"messages"`
|
|
||||||
Status string `orm:"status" json:"status"`
|
Status string `orm:"status" json:"status"`
|
||||||
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
ErrorMessage string `orm:"error_message" json:"errorMessage"`
|
||||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||||
@@ -28,8 +27,7 @@ type composeTaskCol struct {
|
|||||||
CallbackUrl string
|
CallbackUrl string
|
||||||
GatewayState string
|
GatewayState string
|
||||||
RequestPayload string
|
RequestPayload string
|
||||||
ResultText string
|
ResultJson string
|
||||||
Messages string
|
|
||||||
Status string
|
Status string
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
OssFile string
|
OssFile string
|
||||||
@@ -45,8 +43,7 @@ var ComposeTaskCol = composeTaskCol{
|
|||||||
CallbackUrl: "callback_url",
|
CallbackUrl: "callback_url",
|
||||||
GatewayState: "gateway_state",
|
GatewayState: "gateway_state",
|
||||||
RequestPayload: "request_payload",
|
RequestPayload: "request_payload",
|
||||||
ResultText: "result_text",
|
ResultJson: "result_json",
|
||||||
Messages: "messages",
|
|
||||||
Status: "status",
|
Status: "status",
|
||||||
ErrorMessage: "error_message",
|
ErrorMessage: "error_message",
|
||||||
OssFile: "oss_file",
|
OssFile: "oss_file",
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
|||||||
req := SendCallbackReq{
|
req := SendCallbackReq{
|
||||||
TaskId: composeTask.TaskId,
|
TaskId: composeTask.TaskId,
|
||||||
Status: composeTask.Status,
|
Status: composeTask.Status,
|
||||||
Messages: composeTask.Messages,
|
Messages: composeTask.ResultJson,
|
||||||
ErrorMsg: composeTask.ErrorMessage,
|
ErrorMsg: composeTask.ErrorMessage,
|
||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
GatewayState: req.State,
|
GatewayState: req.State,
|
||||||
OssFile: req.OssFile,
|
OssFile: req.OssFile,
|
||||||
FileType: req.FileType,
|
FileType: req.FileType,
|
||||||
ResultText: req.Messages,
|
ResultJson: req.Messages,
|
||||||
})
|
})
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusFailed
|
composeTask.Status = public.ComposeStatusFailed
|
||||||
@@ -181,11 +181,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
Status: public.ComposeStatusSuccess,
|
Status: public.ComposeStatusSuccess,
|
||||||
Messages: messages,
|
|
||||||
GatewayState: req.State,
|
GatewayState: req.State,
|
||||||
OssFile: req.OssFile,
|
OssFile: req.OssFile,
|
||||||
FileType: req.FileType,
|
FileType: req.FileType,
|
||||||
ResultText: req.Messages,
|
ResultJson: messages,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -214,7 +213,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
// 6) 回调业务方
|
// 6) 回调业务方
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusSuccess
|
composeTask.Status = public.ComposeStatusSuccess
|
||||||
composeTask.Messages = messages
|
composeTask.ResultJson = messages
|
||||||
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
_ = gateway.SendCallback(ctx, composeTask, epicycleId)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -232,7 +231,7 @@ func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes,
|
|||||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := parseMessagesForResponse(record.Messages)
|
messages := parseMessagesForResponse(record.ResultJson)
|
||||||
|
|
||||||
return &dto.GetComposeTaskRes{
|
return &dto.GetComposeTaskRes{
|
||||||
TaskId: record.TaskId,
|
TaskId: record.TaskId,
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -51,33 +52,34 @@ func SaveToRedis(ctx context.Context, tenantID uint64, sessionID string, round *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSingleMessage 删除 Redis 中单条消息(按消息ID)
|
// DeleteRedisMessages 批量删除 Redis 中多条消息(按消息ID列表)
|
||||||
func DeleteSingleMessage(ctx context.Context, tenantID uint64, sessionID string, msgID int64) error {
|
func DeleteRedisMessages(ctx context.Context, tenantID uint64, sessionID string, msgIDs []int64) error {
|
||||||
key := formatRedisKey(tenantID, sessionID)
|
key := formatRedisKey(tenantID, sessionID)
|
||||||
|
|
||||||
cursor := "0"
|
for _, msgID := range msgIDs {
|
||||||
for {
|
cursor := "0"
|
||||||
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
for {
|
||||||
if err != nil {
|
result, err := g.Redis().Do(ctx, "ZSCAN", key, cursor, "MATCH", fmt.Sprintf("*\"id\":%d*", msgID), "COUNT", 10)
|
||||||
return fmt.Errorf("ZSCAN失败: %w", err)
|
if err != nil {
|
||||||
}
|
g.Log().Warningf(ctx, "[会话Redis] ZSCAN失败 msgID=%d err=%v", msgID, err)
|
||||||
|
break
|
||||||
parts := result.Strings()
|
|
||||||
if len(parts) < 2 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
cursor = parts[0]
|
|
||||||
members := parts[1:]
|
|
||||||
|
|
||||||
for _, member := range members {
|
|
||||||
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[会话Redis] ZREM单条失败 key=%s err=%v", key, err)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if cursor == "0" {
|
parts := result.Strings()
|
||||||
break
|
if len(parts) < 2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = parts[0]
|
||||||
|
for _, member := range parts[1:] {
|
||||||
|
if _, err := g.Redis().Do(ctx, "ZREM", key, member); err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[会话Redis] ZREM失败 err=%v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cursor == "0" {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,8 +97,8 @@ func DeleteSessionHistory(ctx context.Context, tenantID uint64, sessionID string
|
|||||||
// 读操作
|
// 读操作
|
||||||
// ============================================
|
// ============================================
|
||||||
|
|
||||||
// GetFromRedis 从 Redis ZSET 获取会话历史
|
// GetFromRedis 从 Redis ZSET 获取会话历史,返回 HistoryRound 切片
|
||||||
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]dto.HistoryRound, error) {
|
||||||
key := formatRedisKey(tenantID, sessionID)
|
key := formatRedisKey(tenantID, sessionID)
|
||||||
maxRounds := util.GetMaxRounds(ctx)
|
maxRounds := util.GetMaxRounds(ctx)
|
||||||
|
|
||||||
@@ -106,64 +108,46 @@ func GetFromRedis(ctx context.Context, tenantID uint64, sessionID string) ([]map
|
|||||||
}
|
}
|
||||||
|
|
||||||
if result == nil || result.IsNil() {
|
if result == nil || result.IsNil() {
|
||||||
return []map[string]any{}, nil
|
return []dto.HistoryRound{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseRedisRounds(ctx, result.Strings()), nil
|
return parseRounds(result.Strings()), nil
|
||||||
}
|
|
||||||
|
|
||||||
// GetSessionHistoryForInference 获取扁平消息数组(给推理用)
|
|
||||||
func GetSessionHistoryForInference(ctx context.Context, tenantID uint64, sessionID string) ([]map[string]any, error) {
|
|
||||||
rounds, err := GetFromRedis(ctx, tenantID, sessionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("获取历史会话失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rounds) == 0 {
|
|
||||||
return []map[string]any{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return flattenRounds(rounds), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 解析
|
// 解析
|
||||||
// ============================================
|
// ============================================
|
||||||
|
|
||||||
func parseRedisRounds(ctx context.Context, members []string) []map[string]any {
|
// parseRounds 解析 Redis ZSET members 为 HistoryRound 切片
|
||||||
rounds := make([]map[string]any, 0, len(members))
|
func parseRounds(members []string) []dto.HistoryRound {
|
||||||
|
rounds := make([]dto.HistoryRound, 0, len(members))
|
||||||
for _, member := range members {
|
for _, member := range members {
|
||||||
var data map[string]any
|
var round dto.HistoryRound
|
||||||
if err := json.Unmarshal([]byte(member), &data); err != nil {
|
if err := json.Unmarshal([]byte(member), &round); err != nil {
|
||||||
g.Log().Warningf(ctx, "[会话Redis] 解析数据失败 err=%v", err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rounds = append(rounds, data)
|
if round.User != nil || round.Assistant != nil {
|
||||||
|
rounds = append(rounds, round)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return rounds
|
return rounds
|
||||||
}
|
}
|
||||||
|
|
||||||
func flattenRounds(rounds []map[string]any) []map[string]any {
|
func flattenRounds(rounds []dto.HistoryRound) []dto.FlatMessage {
|
||||||
var messages []map[string]any
|
var messages []dto.FlatMessage
|
||||||
for i := len(rounds) - 1; i >= 0; i-- {
|
for i := len(rounds) - 1; i >= 0; i-- {
|
||||||
if user, ok := rounds[i]["user"].(map[string]any); ok && len(user) > 0 {
|
if rounds[i].User != nil && gconv.String(rounds[i].User["content"]) != "" {
|
||||||
messages = append(messages, user)
|
messages = append(messages, dto.FlatMessage{
|
||||||
|
Role: gconv.String(rounds[i].User["role"]),
|
||||||
|
Content: gconv.String(rounds[i].User["content"]),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
if assistant, ok := rounds[i]["assistant"].(map[string]any); ok && len(assistant) > 0 {
|
if rounds[i].Assistant != nil && gconv.String(rounds[i].Assistant["content"]) != "" {
|
||||||
messages = append(messages, assistant)
|
messages = append(messages, dto.FlatMessage{
|
||||||
|
Role: gconv.String(rounds[i].Assistant["role"]),
|
||||||
|
Content: gconv.String(rounds[i].Assistant["content"]),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendFieldToMessages(data map[string]any, field string, messages *[]map[string]any) {
|
|
||||||
msgs, ok := data[field].([]any)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, m := range msgs {
|
|
||||||
if msg, ok := m.(map[string]any); ok {
|
|
||||||
*messages = append(*messages, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -15,9 +15,14 @@ import (
|
|||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 回调存储
|
||||||
|
// ============================================
|
||||||
|
|
||||||
// Callback 会话回调
|
// Callback 会话回调
|
||||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||||
req.Messages["role"] = "assistant"
|
req.Messages["role"] = "assistant"
|
||||||
|
|
||||||
// 1) 更新 DB
|
// 1) 更新 DB
|
||||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||||
@@ -36,22 +41,42 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
|
|||||||
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 写入 Redis
|
// 3) entity → HistoryRound → 写入 Redis
|
||||||
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, &dto.HistoryRound{
|
round := entityToHistoryRound(session)
|
||||||
Id: session.Id,
|
round.Assistant = req.Messages
|
||||||
User: session.RequestContent,
|
if err = SaveToRedis(ctx, session.TenantId, session.SessionId, round); err != nil {
|
||||||
Assistant: req.Messages,
|
|
||||||
CreatedAt: gconv.String(session.CreatedAt),
|
|
||||||
}); err != nil {
|
|
||||||
return nil, fmt.Errorf("redis存储失败: %w", err)
|
return nil, fmt.Errorf("redis存储失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 返回
|
|
||||||
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d", session.SessionId, session.Id)
|
||||||
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
return &dto.SessionCallbackRes{Status: true, SessionId: session.SessionId}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHistoryMessages 获取历史消息
|
// ============================================
|
||||||
|
// 场景1:前端历史列表(按 creator)
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// GetHistoryList 获取历史列表
|
||||||
|
func GetHistoryList(ctx context.Context, req *dto.GetHistoryListReq) (*dto.GetHistoryListRes, error) {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sessions, total, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
}, req.Page, req.Size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("DB获取历史列表失败: %w", err)
|
||||||
|
}
|
||||||
|
rounds := sessionsToHistoryRounds(sessions)
|
||||||
|
return &dto.GetHistoryListRes{List: rounds, Total: total}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================
|
||||||
|
// 场景2:提示词拼接(按 sessionId + nodeId)
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
// GetHistoryMessages 获取历史消息(Redis → DB → 异步回种)
|
||||||
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*dto.GetHistoryMessagesRes, error) {
|
||||||
user, err := utils.GetUserInfo(ctx)
|
user, err := utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -59,10 +84,9 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 1) Redis
|
// 1) Redis
|
||||||
redisRounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId)
|
if rounds, err := GetFromRedis(ctx, user.TenantId, req.SessionId); err == nil && len(rounds) > 0 {
|
||||||
if err == nil && len(redisRounds) > 0 {
|
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(rounds))
|
||||||
g.Log().Debugf(ctx, "[历史消息] Redis命中 sessionId=%s count=%d", req.SessionId, len(redisRounds))
|
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||||
return &dto.GetHistoryMessagesRes{Messages: parseHistoryRounds(redisRounds)}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) DB
|
// 2) DB
|
||||||
@@ -70,129 +94,108 @@ func GetHistoryMessages(ctx context.Context, req *dto.GetHistoryMessagesReq) (*d
|
|||||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
SessionId: req.SessionId,
|
SessionId: req.SessionId,
|
||||||
|
NodeId: req.NodeId,
|
||||||
}, 1, maxRounds)
|
}, 1, maxRounds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||||
}
|
}
|
||||||
if len(sessions) == 0 {
|
if len(sessions) == 0 {
|
||||||
return &dto.GetHistoryMessagesRes{Messages: []dto.HistoryRound{}}, nil
|
return &dto.GetHistoryMessagesRes{Messages: []dto.FlatMessage{}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 转换 + 异步回种
|
// 3) 转换 + 异步回种
|
||||||
rounds := sessionsToHistoryRounds(sessions)
|
rounds := sessionsToHistoryRounds(sessions)
|
||||||
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, sessions)
|
go asyncCacheToRedis(context.WithoutCancel(ctx), user.TenantId, req.SessionId, rounds)
|
||||||
|
|
||||||
return &dto.GetHistoryMessagesRes{Messages: rounds}, nil
|
return &dto.GetHistoryMessagesRes{Messages: flattenRounds(rounds)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseHistoryRounds Redis 数据转为 HistoryRound
|
// ============================================
|
||||||
func parseHistoryRounds(redisRounds []map[string]any) []dto.HistoryRound {
|
// 删除
|
||||||
rounds := make([]dto.HistoryRound, 0, len(redisRounds))
|
// ============================================
|
||||||
for _, r := range redisRounds {
|
|
||||||
round := dto.HistoryRound{
|
// DeleteMessages 批量删除消息
|
||||||
Id: gconv.Int64(r["id"]),
|
func DeleteMessages(ctx context.Context, req *dto.DeleteMessagesReq) (*dto.DeleteMessagesRes, error) {
|
||||||
CreatedAt: gconv.String(r["createdAt"]),
|
if len(req.MsgIds) == 0 {
|
||||||
}
|
return &dto.DeleteMessagesRes{Ok: false}, fmt.Errorf("msgIds不能为空")
|
||||||
if user, ok := r["user"].(map[string]any); ok {
|
|
||||||
round.User = user
|
|
||||||
}
|
|
||||||
if assistant, ok := r["assistant"].(map[string]any); ok {
|
|
||||||
round.Assistant = assistant
|
|
||||||
}
|
|
||||||
rounds = append(rounds, round)
|
|
||||||
}
|
}
|
||||||
return rounds
|
|
||||||
}
|
|
||||||
|
|
||||||
// sessionsToHistoryRounds DB 数据转为 HistoryRound
|
// 1) 删 DB
|
||||||
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
for _, id := range req.MsgIds {
|
||||||
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
_, _ = dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
|
||||||
for _, s := range sessions {
|
SQLBaseDO: beans.SQLBaseDO{Id: id},
|
||||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
})
|
||||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
|
||||||
|
|
||||||
round := dto.HistoryRound{
|
|
||||||
Id: s.Id,
|
|
||||||
CreatedAt: gconv.String(s.CreatedAt),
|
|
||||||
}
|
|
||||||
if len(reqMsgs) > 0 {
|
|
||||||
round.User = reqMsgs[0]
|
|
||||||
}
|
|
||||||
if len(respMsgs) > 0 {
|
|
||||||
if respMsgs[0]["role"] == nil {
|
|
||||||
respMsgs[0]["role"] = "assistant"
|
|
||||||
}
|
|
||||||
round.Assistant = respMsgs[0]
|
|
||||||
}
|
|
||||||
rounds = append(rounds, round)
|
|
||||||
}
|
}
|
||||||
return rounds
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 2) 删 Redis
|
||||||
|
_ = DeleteRedisMessages(ctx, user.TenantId, req.SessionId, req.MsgIds)
|
||||||
|
|
||||||
|
return &dto.DeleteMessagesRes{Ok: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSession 删除会话
|
// DeleteSession 删除整个会话
|
||||||
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteSessionRes, error) {
|
||||||
hasMsgID := len(req.MsgIds) > 0 && req.MsgIds[0] > 0
|
// 1) 删 DB
|
||||||
|
if _, err := dao.ComposeSession.Delete(ctx, &entity.ComposeSession{
|
||||||
deleteReq := &entity.ComposeSession{
|
|
||||||
SessionId: req.SessionId,
|
SessionId: req.SessionId,
|
||||||
NodeId: req.NodeId,
|
}); err != nil {
|
||||||
}
|
|
||||||
if hasMsgID {
|
|
||||||
deleteReq.Id = req.MsgIds[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := dao.ComposeSession.Delete(ctx, deleteReq); err != nil {
|
|
||||||
return nil, fmt.Errorf("DB删除失败: %w", err)
|
return nil, fmt.Errorf("DB删除失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasMsgID {
|
user, err := utils.GetUserInfo(ctx)
|
||||||
if err := DeleteSingleMessage(ctx, req.TenantId, req.SessionId, req.MsgIds[0]); err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[删除会话] Redis删除单条失败 msgID=%d err=%v", req.MsgIds[0], err)
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
// 2) 删 Redis
|
||||||
if err := DeleteSessionHistory(ctx, req.TenantId, req.SessionId); err != nil {
|
if err := DeleteSessionHistory(ctx, user.TenantId, req.SessionId); err != nil {
|
||||||
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
g.Log().Warningf(ctx, "[删除会话] Redis删除失败 sessionId=%s err=%v", req.SessionId, err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &dto.DeleteSessionRes{Ok: true}, nil
|
return &dto.DeleteSessionRes{Ok: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
// 内部方法
|
// 转换方法(entity ↔ dto,集中管理)
|
||||||
// ============================================
|
// ============================================
|
||||||
|
|
||||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
// entityToHistoryRound entity → HistoryRound
|
||||||
var messages []map[string]any
|
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||||
for i := len(sessions) - 1; i >= 0; i-- {
|
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
||||||
appendRoleMessages(sessions[i].RequestContent, "user", &messages)
|
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
||||||
appendRoleMessages(sessions[i].ResponseContent, "assistant", &messages)
|
|
||||||
|
round := &dto.HistoryRound{
|
||||||
|
Id: s.Id,
|
||||||
|
SessionId: s.SessionId,
|
||||||
|
NodeId: s.NodeId,
|
||||||
|
CreatedAt: gconv.String(s.CreatedAt),
|
||||||
|
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||||
}
|
}
|
||||||
return messages
|
if len(reqMsgs) > 0 {
|
||||||
|
round.User = reqMsgs[0]
|
||||||
|
}
|
||||||
|
if len(respMsgs) > 0 {
|
||||||
|
round.Assistant = respMsgs[0]
|
||||||
|
}
|
||||||
|
return round
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendRoleMessages(content any, defaultRole string, messages *[]map[string]any) {
|
// sessionsToHistoryRounds 批量转换
|
||||||
msgs := util.ConvertToMessages(content)
|
func sessionsToHistoryRounds(sessions []*entity.ComposeSession) []dto.HistoryRound {
|
||||||
for _, m := range msgs {
|
rounds := make([]dto.HistoryRound, 0, len(sessions))
|
||||||
if m["role"] == nil || gconv.String(m["role"]) == "" {
|
|
||||||
m["role"] = defaultRole
|
|
||||||
}
|
|
||||||
*messages = append(*messages, m)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// asyncCacheToRedis 异步缓存会话数据到 Redis
|
|
||||||
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, sessions []*entity.ComposeSession) {
|
|
||||||
for _, s := range sessions {
|
for _, s := range sessions {
|
||||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
rounds = append(rounds, *entityToHistoryRound(s))
|
||||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
}
|
||||||
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
return rounds
|
||||||
_ = SaveToRedis(ctx, tenantID, sessionID, &dto.HistoryRound{
|
}
|
||||||
Id: s.Id,
|
|
||||||
User: s.RequestContent,
|
// asyncCacheToRedis 异步缓存到 Redis
|
||||||
Assistant: s.ResponseContent,
|
func asyncCacheToRedis(ctx context.Context, tenantID uint64, sessionID string, rounds []dto.HistoryRound) {
|
||||||
CreatedAt: gconv.String(s.CreatedAt),
|
for i := range rounds {
|
||||||
})
|
if rounds[i].User != nil || rounds[i].Assistant != nil {
|
||||||
|
_ = SaveToRedis(ctx, tenantID, sessionID, &rounds[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user