diff --git a/common/util/json.go b/common/util/json.go index e7ef599..83d2615 100644 --- a/common/util/json.go +++ b/common/util/json.go @@ -227,3 +227,24 @@ func MergeConsult(req map[string]any, messages map[string]any, extendMapping map } return result } + +// GetUserMessage 获取用户消息 +func GetUserMessage(taskReq map[string]any) map[string]any { + // 先取 requestPayload + rp, ok := taskReq["requestPayload"].(map[string]any) + if !ok { + return nil + } + // 再取 messages + messages, ok := rp["messages"].([]any) + if !ok { + return nil + } + for _, msg := range messages { + m, ok := msg.(map[string]any) + if ok && m["role"] == "user" { + return m + } + } + return nil +} diff --git a/controller/prompt_session_controller.go b/controller/prompt_session_controller.go index 4bc7613..bb5dae4 100644 --- a/controller/prompt_session_controller.go +++ b/controller/prompt_session_controller.go @@ -4,7 +4,7 @@ import ( "context" "prompts-core/model/dto" - promptService "prompts-core/service/prompt" + sessionService "prompts-core/service/session" ) type session struct{} @@ -14,5 +14,10 @@ var Session = new(session) // SessionCallback 会话回调 func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { - return promptService.SessionCallback(ctx, req) + return sessionService.Callback(ctx, req) } + +//TODO:后期历史相关服务可能拆分(三个接口) +// 1. 添加历史会话 +// 2. 获取历史会话 +// 3. 更新历史信息 diff --git a/model/dto/prompt_compose_dto.go b/model/dto/prompt_compose_dto.go index 9a6ac31..c656c99 100644 --- a/model/dto/prompt_compose_dto.go +++ b/model/dto/prompt_compose_dto.go @@ -21,7 +21,8 @@ type ConsultItem struct { Url string `json:"url" dc:"附件地址"` } type ComposeMessagesRes struct { - TaskId string `json:"taskId" dc:"任务ID"` + TaskId string `json:"taskId" dc:"任务ID"` + EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"` } // MultiRoundResult 多轮返回结果 @@ -32,13 +33,13 @@ type MultiRoundResult struct { type CallbackReq struct { g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"` - TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` - State int `json:"state" dc:"网关任务状态"` - OssFile string `json:"oss_file" dc:"结果文件地址"` - FileType string `json:"file_type" dc:"结果文件类型"` - Text string `json:"text" dc:"文本结果"` - ErrorMsg string `json:"error_msg" dc:"错误信息"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` + State int `json:"state" dc:"网关任务状态"` + OssFile string `json:"oss_file" dc:"结果文件地址"` + FileType string `json:"file_type" dc:"结果文件类型"` + Messages map[string]any `json:"messages" dc:"消息数组"` + ErrorMsg string `json:"error_msg" dc:"错误信息"` + EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` } type CallbackRes struct { diff --git a/model/dto/prompt_session_dto.go b/model/dto/prompt_session_dto.go index c0974c6..4de558f 100644 --- a/model/dto/prompt_session_dto.go +++ b/model/dto/prompt_session_dto.go @@ -4,9 +4,11 @@ import "github.com/gogf/gf/v2/frame/g" type SessionCallbackReq struct { g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"` - Text string `json:"text" dc:"文本结果"` - EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` + Messages map[string]any `json:"messages" dc:"消息数组"` + EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` } type SessionCallbackRes struct { + Status bool `json:"status" dc:"状态"` + SessionId string `json:"sessionId" dc:"会话ID"` } diff --git a/model/entity/prompts_compose_session.go b/model/entity/prompts_compose_session.go index 9dcb38a..bead3e3 100644 --- a/model/entity/prompts_compose_session.go +++ b/model/entity/prompts_compose_session.go @@ -4,10 +4,10 @@ import "gitea.com/red-future/common/beans" type ComposeSession struct { beans.SQLBaseDO `orm:",inline"` - SessionId string `orm:"session_id" json:"sessionId"` - RequestContent any `orm:"request_content" json:"requestContent"` - ResponseContent any `orm:"response_content" json:"responseContent"` - Remark string `orm:"remark" json:"remark"` + SessionId string `orm:"session_id" json:"sessionId"` + RequestContent map[string]any `orm:"request_content" json:"requestContent"` + ResponseContent map[string]any `orm:"response_content" json:"responseContent"` + Remark string `orm:"remark" json:"remark"` } type composeSessionCol struct { diff --git a/model/entity/prompts_compose_task.go b/model/entity/prompts_compose_task.go index a185c75..55318cc 100644 --- a/model/entity/prompts_compose_task.go +++ b/model/entity/prompts_compose_task.go @@ -11,7 +11,7 @@ type ComposeTask struct { CallbackUrl string `orm:"callback_url" json:"callbackUrl"` GatewayState int `orm:"gateway_state" json:"gatewayState"` RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"` - ResultText string `orm:"result_text" json:"resultText"` + ResultText map[string]any `orm:"result_text" json:"resultText"` Messages map[string]any `orm:"messages" json:"messages"` Status string `orm:"status" json:"status"` ErrorMessage string `orm:"error_message" json:"errorMessage"` diff --git a/service/prompt/prompt_build_service.go b/service/prompt/prompt_build_service.go index 9aa5ccd..8ed8ee6 100644 --- a/service/prompt/prompt_build_service.go +++ b/service/prompt/prompt_build_service.go @@ -25,6 +25,7 @@ type UserPromptPayload struct { Consult []dto.ConsultItem `json:"consult"` UserFilesText map[string]string `json:"userFilesText"` Skills string `json:"skills"` + BuildType int `json:"buildType"` } // buildInferenceRequest 构建推理请求 @@ -33,9 +34,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha if err != nil { return nil, fmt.Errorf("处理用户表单分批失败: %w", err) } - ir := NewPromptIR() - switch req.BuildType { case public.BuildTypePrompt: return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches) @@ -65,11 +64,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) } - // 记录历史会话 - _, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ - SessionId: req.SessionId, - RequestContent: ir.User, - }) return compileToProviderRequest(ctx, ir, chatModel) } @@ -168,6 +162,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st Consult: req.Consult, UserFilesText: ExtractFileTexts(ctx, req.Consult), Skills: SkillMdContent(ctx, req.SkillName), + BuildType: req.BuildType, } return gjson.New(payload).String() } diff --git a/service/prompt/prompt_compose_service.go b/service/prompt/prompt_compose_service.go index 9bf0b88..eebec25 100644 --- a/service/prompt/prompt_compose_service.go +++ b/service/prompt/prompt_compose_service.go @@ -5,10 +5,10 @@ import ( "encoding/json" "errors" "fmt" + "prompts-core/service/session" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" - "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "prompts-core/common/util" @@ -44,17 +44,27 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity. if err != nil { return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) } - - chatModel, err := getChatModel(ctx, userInfo.UserName) + chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + IsChatModel: new(1), + }) + if err != nil { + return nil, nil, err + } + if chatModel == nil { + return nil, nil, errors.New("当前没有对话模型,请添加") + } + aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName}, + ModelName: req.ModelName, + }) if err != nil { return nil, nil, err } - aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName) - if err != nil { - return nil, nil, err + if aiModel == nil { + return nil, nil, errors.New("需要构建的模型不存在") } - return chatModel, aiModel, nil } @@ -73,51 +83,24 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens,可用窗口 %d tokens,请精简后重试", exceedTokens, availableWindow) } - return nil } // handlePromptBuild 处理提示词构建(BuildType=1) func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { // 获取历史会话 - history, err := GetHistoryMessages(ctx, req.SessionId) + history, err := session.GetHistoryMessages(ctx, req.SessionId) if err != nil { g.Log().Errorf(ctx, "获取历史会话失败: %v,将不使用历史会话", err) history = nil } // 调用推理模型 - taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history) + taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history) if err != nil { return nil, fmt.Errorf("调用推理模型失败: %w", err) } // 保存任务记录 - if err = saveComposeTask(ctx, taskID, req); err != nil { - return nil, fmt.Errorf("保存任务记录失败: %w", err) - } - return &dto.ComposeMessagesRes{ - TaskId: taskID, - }, nil -} - -// handleNodeBuild 处理节点构建(BuildType=2) -func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { - taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) - if err != nil { - return nil, fmt.Errorf("调用推理模型失败: %w", err) - } - - if err := saveComposeTask(ctx, taskID, req); err != nil { - return nil, fmt.Errorf("保存任务记录失败: %w", err) - } - - return &dto.ComposeMessagesRes{ - TaskId: taskID, - }, nil -} - -// saveComposeTask 保存组合任务记录 -func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error { - _, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ + _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ TaskId: taskID, ModelName: req.ModelName, SkillName: req.SkillName, @@ -126,77 +109,70 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage RequestPayload: util.MustMarshalToMap(req), Status: public.ComposeStatusPending, }) - return err + + if err != nil { + return nil, fmt.Errorf("保存任务记录失败: %w", err) + } + return &dto.ComposeMessagesRes{ + TaskId: taskID, + EpicycleId: id, + }, nil } -// getChatModel 获取聊天模型 -func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) { - chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userName}, - IsChatModel: new(1), +// handleNodeBuild 处理节点构建(BuildType=2) +func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { + taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil) + if err != nil { + return nil, fmt.Errorf("调用推理模型失败: %w", err) + } + // 保存任务记录 + _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{ + TaskId: taskID, + ModelName: req.ModelName, + SkillName: req.SkillName, + BuildType: req.BuildType, + CallbackUrl: req.CallbackUrl, + RequestPayload: util.MustMarshalToMap(req), + Status: public.ComposeStatusPending, }) if err != nil { - return nil, fmt.Errorf("查询聊天模型失败: %w", err) + return nil, fmt.Errorf("保存任务记录失败: %w", err) } - - if chatModel == nil { - return nil, errors.New("当前没有对话模型,请添加") - } - - return chatModel, nil -} - -// getAIModel 获取AI模型 -func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) { - aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Creator: userName}, - ModelName: modelName, - }) - if err != nil { - return nil, fmt.Errorf("查询AI模型失败: %w", err) - } - - if aiModel == nil { - return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName) - } - - return aiModel, nil + return &dto.ComposeMessagesRes{ + TaskId: taskID, + EpicycleId: id, + }, nil } // callInferenceModel 调用推理模型 -func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) { - taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history) +func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) { + taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history) if err != nil { - return "", fmt.Errorf("构建推理请求失败: %w", err) + return "", 0, fmt.Errorf("构建推理请求失败: %w", err) + } + id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{ + SessionId: req.SessionId, + RequestContent: util.GetUserMessage(taskReq), + }) + if err != nil { + return "", 0, fmt.Errorf("保存历史会话失败: %w", err) } - taskID, err := gateway.CreateGatewayTask(ctx, taskReq) if err != nil { - return "", fmt.Errorf("创建网关任务失败: %w", err) + return "", 0, fmt.Errorf("创建网关任务失败: %w", err) } if taskID == "" { - return "", errors.New("网关未返回taskId") + return "", 0, errors.New("网关未返回taskId") } - return taskID, nil -} - -// createDefaultResult 创建默认结果 -func createDefaultResult(data map[string]any) map[string]any { - if data == nil { - data = make(map[string]any) - } - return map[string]any{ - "total_rounds": 1, - "rounds": []map[string]any{data}, - } + return taskID, id, nil } // Callback 回调处理 func Callback(ctx context.Context, req *dto.CallbackReq) error { g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", - req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) + req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages)) // 查询任务 composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ TaskId: req.TaskId, @@ -220,7 +196,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { GatewayState: req.State, OssFile: req.OssFile, FileType: req.FileType, - ResultText: req.Text, + ResultText: req.Messages, }) // 用更新后的值发送回调 if composeTask.CallbackUrl != "" { @@ -241,11 +217,11 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { var messages map[string]any switch composeTask.BuildType { case public.BuildTypePrompt: // 提示词构建解析 - messages = ParsePromptResult(req.Text) + messages = ParsePromptResult(req.Messages) case public.BuildTypeNode: // 节点构建解析 - messages = ParseNodeResult(req.Text) + messages = ParseNodeResult(req.Messages) default: - messages = gjson.New(req.Text).Map() + messages = req.Messages } // 2. 处理附加字段 messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping) @@ -257,7 +233,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { GatewayState: req.State, OssFile: req.OssFile, FileType: req.FileType, - ResultText: req.Text, + ResultText: req.Messages, }) if err != nil { g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err) @@ -278,18 +254,12 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error { } // ParsePromptResult 解析提示词构建结果 -func ParsePromptResult(raw string) map[string]any { - var wrapper map[string]any - if err := json.Unmarshal([]byte(raw), &wrapper); err != nil { - return createDefaultResult(map[string]any{"raw": raw}) - } - - contentStr, ok := wrapper["content"].(string) +func ParsePromptResult(raw map[string]any) map[string]any { + contentStr, ok := raw["content"].(string) if !ok || contentStr == "" { - return createDefaultResult(wrapper) + return raw } - // 先尝试解析为数组 if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { return map[string]any{ "total_rounds": len(roundsArray), @@ -297,7 +267,6 @@ func ParsePromptResult(raw string) map[string]any { } } - // 再尝试解析为单个对象 if singleRound := tryParseAsMap(contentStr); singleRound != nil { return map[string]any{ "total_rounds": 1, @@ -305,7 +274,7 @@ func ParsePromptResult(raw string) map[string]any { } } - return createDefaultResult(map[string]any{"content": contentStr}) + return map[string]any{"content": contentStr} } func tryParseAsMapArray(jsonStr string) []map[string]any { @@ -330,22 +299,20 @@ func tryParseAsMap(jsonStr string) map[string]any { return obj } -// ParseNodeResult 解析节点构建结果 -func ParseNodeResult(raw string) map[string]any { - var result map[string]any - if err := json.Unmarshal([]byte(raw), &result); err != nil { - return createDefaultResult(map[string]any{"raw": raw}) - } - - if contentStr, ok := result["content"].(string); ok && contentStr != "" { +func ParseNodeResult(raw map[string]any) map[string]any { + contentStr, ok := raw["content"].(string) + if ok && contentStr != "" { var inner map[string]any if err := json.Unmarshal([]byte(contentStr), &inner); err == nil { - result = inner + return map[string]any{ + "total_rounds": 1, + "rounds": []map[string]any{inner}, + } } } return map[string]any{ "total_rounds": 1, - "rounds": []map[string]any{result}, + "rounds": []map[string]any{raw}, } } diff --git a/service/prompt/prompt_session_redis_service.go b/service/session/prompt_session_redis_service.go similarity index 89% rename from service/prompt/prompt_session_redis_service.go rename to service/session/prompt_session_redis_service.go index 4024628..0599e42 100644 --- a/service/prompt/prompt_session_redis_service.go +++ b/service/session/prompt_session_redis_service.go @@ -1,9 +1,10 @@ -package prompt +package session import ( "context" "encoding/json" "fmt" + "prompts-core/model/entity" "time" "github.com/gogf/gf/v2/frame/g" @@ -13,37 +14,33 @@ const ( redisKeyPrefix = "chat:session:%s" ) -// saveToRedis 保存会话数据到Redis -func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { - key := formatRedisKey(sessionId) +// formatRedisKey 格式化Redis键 +func formatRedisKey(sessionId string) string { + return fmt.Sprintf(redisKeyPrefix, sessionId) +} +// saveToRedis 保存会话数据到Redis +func saveToRedis(ctx context.Context, session *entity.ComposeSession) error { + key := formatRedisKey(session.SessionId) maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int() expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64() - data := map[string]any{ - "sessionId": sessionId, - "requestContent": requestMessages, - "responseContent": responseMessages, + "sessionId": session.SessionId, + "requestContent": session.RequestContent, + "responseContent": session.ResponseContent, "timestamp": time.Now().Unix(), } - b, err := json.Marshal(data) if err != nil { return fmt.Errorf("序列化会话数据失败: %w", err) } - if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { + if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { return err } - return nil } -// formatRedisKey 格式化Redis键 -func formatRedisKey(sessionId string) string { - return fmt.Sprintf(redisKeyPrefix, sessionId) -} - // executeRedisCommands 执行Redis命令 func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error { if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil { diff --git a/service/prompt/prompt_session_service.go b/service/session/prompt_session_service.go similarity index 59% rename from service/prompt/prompt_session_service.go rename to service/session/prompt_session_service.go index c72c6f3..9933457 100644 --- a/service/prompt/prompt_session_service.go +++ b/service/session/prompt_session_service.go @@ -1,4 +1,4 @@ -package prompt +package session import ( "context" @@ -14,74 +14,36 @@ import ( "prompts-core/model/entity" ) -// SessionCallback 会话回调 -func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { - result, err := util.ParseOutput(req.Text) - if err != nil { - g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err) - return nil, fmt.Errorf("解析模型输出失败: %w", err) - } - - result["role"] = "assistant" - if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil { - return nil, err - } - - session, err := getSessionById(ctx, req.EpicycleId) - if err != nil { - return nil, err - } - - if err := saveSessionToRedis(ctx, session); err != nil { - return nil, err - } - - requestMessages := util.ConvertToMessages(session.RequestContent) - responseMessages := util.ConvertToMessages(session.ResponseContent) - - g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", - session.SessionId, session.Id, len(requestMessages), len(responseMessages)) - - return &dto.SessionCallbackRes{}, nil -} - -// updateSessionResponse 更新会话响应 -func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error { +// Callback 会话回调 +func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { + req.Messages["role"] = "assistant" _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ - SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, - ResponseContent: response, + SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, + ResponseContent: req.Messages, }) if err != nil { - g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err) - return fmt.Errorf("更新数据库失败: %w", err) + g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err) + return nil, fmt.Errorf("更新数据库失败: %w", err) } - return nil -} - -// getSessionById 根据ID获取会话 -func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) { session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{ - SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, + SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, }) + if session == nil { + return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId) + } if err != nil { - g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err) + g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err) return nil, fmt.Errorf("获取会话数据失败: %w", err) } - return session, nil -} - -// saveSessionToRedis 保存会话到Redis -func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error { - requestMessages := util.ConvertToMessages(session.RequestContent) - responseMessages := util.ConvertToMessages(session.ResponseContent) - - if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil { - g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v", - session.SessionId, session.Id, err) - return fmt.Errorf("Redis存储失败: %w", err) + if err = saveToRedis(ctx, session); err != nil { + return nil, fmt.Errorf("redis存储失败: %w", err) } - - return nil + g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d", + session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent)) + return &dto.SessionCallbackRes{ + Status: true, + SessionId: session.SessionId, + }, nil } // GetHistoryMessages 获取历史信息 @@ -159,7 +121,7 @@ func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession } if len(reqMsgs) > 0 || len(respMsgs) > 0 { - _ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) + _ = saveToRedis(ctx, session) } } }