refactor(task): 重构异步任务处理流程

This commit is contained in:
2026-05-27 09:36:26 +08:00
parent 2548ffc7ac
commit d74559ae74
10 changed files with 162 additions and 212 deletions

View File

@@ -227,3 +227,24 @@ func MergeConsult(req map[string]any, messages map[string]any, extendMapping map
} }
return result return result
} }
// GetUserMessage 获取用户消息
func GetUserMessage(taskReq map[string]any) map[string]any {
// 先取 requestPayload
rp, ok := taskReq["requestPayload"].(map[string]any)
if !ok {
return nil
}
// 再取 messages
messages, ok := rp["messages"].([]any)
if !ok {
return nil
}
for _, msg := range messages {
m, ok := msg.(map[string]any)
if ok && m["role"] == "user" {
return m
}
}
return nil
}

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"prompts-core/model/dto" "prompts-core/model/dto"
promptService "prompts-core/service/prompt" sessionService "prompts-core/service/session"
) )
type session struct{} type session struct{}
@@ -14,5 +14,10 @@ var Session = new(session)
// SessionCallback 会话回调 // SessionCallback 会话回调
func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) { func (c *session) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *dto.SessionCallbackRes, err error) {
return promptService.SessionCallback(ctx, req) return sessionService.Callback(ctx, req)
} }
//TODO:后期历史相关服务可能拆分(三个接口)
// 1. 添加历史会话
// 2. 获取历史会话
// 3. 更新历史信息

View File

@@ -22,6 +22,7 @@ type ConsultItem struct {
} }
type ComposeMessagesRes struct { type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"` TaskId string `json:"taskId" dc:"任务ID"`
EpicycleId int64 `json:"epicycle_id" dc:"轮次ID"`
} }
// MultiRoundResult 多轮返回结果 // MultiRoundResult 多轮返回结果
@@ -36,7 +37,7 @@ type CallbackReq struct {
State int `json:"state" dc:"网关任务状态"` State int `json:"state" dc:"网关任务状态"`
OssFile string `json:"oss_file" dc:"结果文件地址"` OssFile string `json:"oss_file" dc:"结果文件地址"`
FileType string `json:"file_type" dc:"结果文件类型"` FileType string `json:"file_type" dc:"结果文件类型"`
Text string `json:"text" dc:"文本结果"` Messages map[string]any `json:"messages" dc:"消息数组"`
ErrorMsg string `json:"error_msg" dc:"错误信息"` ErrorMsg string `json:"error_msg" dc:"错误信息"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
} }

View File

@@ -4,9 +4,11 @@ import "github.com/gogf/gf/v2/frame/g"
type SessionCallbackReq struct { type SessionCallbackReq struct {
g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"` g.Meta `path:"/sessionCallback" method:"post" tags:"提示词处理"`
Text string `json:"text" dc:"文本结果"` Messages map[string]any `json:"messages" dc:"消息数组"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
} }
type SessionCallbackRes struct { type SessionCallbackRes struct {
Status bool `json:"status" dc:"状态"`
SessionId string `json:"sessionId" dc:"会话ID"`
} }

View File

@@ -5,8 +5,8 @@ import "gitea.com/red-future/common/beans"
type ComposeSession struct { type ComposeSession struct {
beans.SQLBaseDO `orm:",inline"` beans.SQLBaseDO `orm:",inline"`
SessionId string `orm:"session_id" json:"sessionId"` SessionId string `orm:"session_id" json:"sessionId"`
RequestContent any `orm:"request_content" json:"requestContent"` RequestContent map[string]any `orm:"request_content" json:"requestContent"`
ResponseContent any `orm:"response_content" json:"responseContent"` ResponseContent map[string]any `orm:"response_content" json:"responseContent"`
Remark string `orm:"remark" json:"remark"` Remark string `orm:"remark" json:"remark"`
} }

View File

@@ -11,7 +11,7 @@ type ComposeTask struct {
CallbackUrl string `orm:"callback_url" json:"callbackUrl"` CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
GatewayState int `orm:"gateway_state" json:"gatewayState"` GatewayState int `orm:"gateway_state" json:"gatewayState"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"` RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
ResultText string `orm:"result_text" json:"resultText"` ResultText map[string]any `orm:"result_text" json:"resultText"`
Messages map[string]any `orm:"messages" json:"messages"` Messages map[string]any `orm:"messages" json:"messages"`
Status string `orm:"status" json:"status"` Status string `orm:"status" json:"status"`
ErrorMessage string `orm:"error_message" json:"errorMessage"` ErrorMessage string `orm:"error_message" json:"errorMessage"`

View File

@@ -25,6 +25,7 @@ type UserPromptPayload struct {
Consult []dto.ConsultItem `json:"consult"` Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"` UserFilesText map[string]string `json:"userFilesText"`
Skills string `json:"skills"` Skills string `json:"skills"`
BuildType int `json:"buildType"`
} }
// buildInferenceRequest 构建推理请求 // buildInferenceRequest 构建推理请求
@@ -33,9 +34,7 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
if err != nil { if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err) return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
} }
ir := NewPromptIR() ir := NewPromptIR()
switch req.BuildType { switch req.BuildType {
case public.BuildTypePrompt: case public.BuildTypePrompt:
return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches) return buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, history, ir, totalBatches)
@@ -65,11 +64,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig) availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow) return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
} }
// 记录历史会话
_, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: ir.User,
})
return compileToProviderRequest(ctx, ir, chatModel) return compileToProviderRequest(ctx, ir, chatModel)
} }
@@ -168,6 +162,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
Consult: req.Consult, Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult), UserFilesText: ExtractFileTexts(ctx, req.Consult),
Skills: SkillMdContent(ctx, req.SkillName), Skills: SkillMdContent(ctx, req.SkillName),
BuildType: req.BuildType,
} }
return gjson.New(payload).String() return gjson.New(payload).String()
} }

View File

@@ -5,10 +5,10 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans" "gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils" "gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util" "prompts-core/common/util"
@@ -44,17 +44,27 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err) return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
} }
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
chatModel, err := getChatModel(ctx, userInfo.UserName) SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: new(1),
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName) if aiModel == nil {
if err != nil { return nil, nil, errors.New("需要构建的模型不存在")
return nil, nil, err
} }
return chatModel, aiModel, nil return chatModel, aiModel, nil
} }
@@ -73,51 +83,24 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试", return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow) exceedTokens, availableWindow)
} }
return nil return nil
} }
// handlePromptBuild 处理提示词构建BuildType=1 // handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) { func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话 // 获取历史会话
history, err := GetHistoryMessages(ctx, req.SessionId) history, err := session.GetHistoryMessages(ctx, req.SessionId)
if err != nil { if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err) g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil history = nil
} }
// 调用推理模型 // 调用推理模型
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history) taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil { if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err) return nil, fmt.Errorf("调用推理模型失败: %w", err)
} }
// 保存任务记录 // 保存任务记录
if err = saveComposeTask(ctx, taskID, req); err != nil { _, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// saveComposeTask 保存组合任务记录
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID, TaskId: taskID,
ModelName: req.ModelName, ModelName: req.ModelName,
SkillName: req.SkillName, SkillName: req.SkillName,
@@ -126,77 +109,70 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage
RequestPayload: util.MustMarshalToMap(req), RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending, Status: public.ComposeStatusPending,
}) })
return err
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
} }
// getChatModel 获取聊天模型 // handleNodeBuild 处理节点构建BuildType=2
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) { func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
SQLBaseDO: beans.SQLBaseDO{Creator: userName}, if err != nil {
IsChatModel: new(1), return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err) return nil, fmt.Errorf("保存任务记录失败: %w", err)
} }
return &dto.ComposeMessagesRes{
if chatModel == nil { TaskId: taskID,
return nil, errors.New("当前没有对话模型,请添加") EpicycleId: id,
} }, nil
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
} }
// callInferenceModel 调用推理模型 // callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) { func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history) taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil { if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err) return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
} }
taskID, err := gateway.CreateGatewayTask(ctx, taskReq) taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil { if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err) return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
} }
if taskID == "" { if taskID == "" {
return "", errors.New("网关未返回taskId") return "", 0, errors.New("网关未返回taskId")
} }
return taskID, nil return taskID, id, nil
}
// createDefaultResult 创建默认结果
func createDefaultResult(data map[string]any) map[string]any {
if data == nil {
data = make(map[string]any)
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{data},
}
} }
// Callback 回调处理 // Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error { func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d", g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text)) req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
// 查询任务 // 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{ composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId, TaskId: req.TaskId,
@@ -220,7 +196,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State, GatewayState: req.State,
OssFile: req.OssFile, OssFile: req.OssFile,
FileType: req.FileType, FileType: req.FileType,
ResultText: req.Text, ResultText: req.Messages,
}) })
// 用更新后的值发送回调 // 用更新后的值发送回调
if composeTask.CallbackUrl != "" { if composeTask.CallbackUrl != "" {
@@ -241,11 +217,11 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
var messages map[string]any var messages map[string]any
switch composeTask.BuildType { switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析 case public.BuildTypePrompt: // 提示词构建解析
messages = ParsePromptResult(req.Text) messages = ParsePromptResult(req.Messages)
case public.BuildTypeNode: // 节点构建解析 case public.BuildTypeNode: // 节点构建解析
messages = ParseNodeResult(req.Text) messages = ParseNodeResult(req.Messages)
default: default:
messages = gjson.New(req.Text).Map() messages = req.Messages
} }
// 2. 处理附加字段 // 2. 处理附加字段
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping) messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
@@ -257,7 +233,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State, GatewayState: req.State,
OssFile: req.OssFile, OssFile: req.OssFile,
FileType: req.FileType, FileType: req.FileType,
ResultText: req.Text, ResultText: req.Messages,
}) })
if err != nil { if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err) g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
@@ -278,18 +254,12 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
} }
// ParsePromptResult 解析提示词构建结果 // ParsePromptResult 解析提示词构建结果
func ParsePromptResult(raw string) map[string]any { func ParsePromptResult(raw map[string]any) map[string]any {
var wrapper map[string]any contentStr, ok := raw["content"].(string)
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
contentStr, ok := wrapper["content"].(string)
if !ok || contentStr == "" { if !ok || contentStr == "" {
return createDefaultResult(wrapper) return raw
} }
// 先尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil { if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return map[string]any{ return map[string]any{
"total_rounds": len(roundsArray), "total_rounds": len(roundsArray),
@@ -297,7 +267,6 @@ func ParsePromptResult(raw string) map[string]any {
} }
} }
// 再尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil { if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return map[string]any{ return map[string]any{
"total_rounds": 1, "total_rounds": 1,
@@ -305,7 +274,7 @@ func ParsePromptResult(raw string) map[string]any {
} }
} }
return createDefaultResult(map[string]any{"content": contentStr}) return map[string]any{"content": contentStr}
} }
func tryParseAsMapArray(jsonStr string) []map[string]any { func tryParseAsMapArray(jsonStr string) []map[string]any {
@@ -330,22 +299,20 @@ func tryParseAsMap(jsonStr string) map[string]any {
return obj return obj
} }
// ParseNodeResult 解析节点构建结果 func ParseNodeResult(raw map[string]any) map[string]any {
func ParseNodeResult(raw string) map[string]any { contentStr, ok := raw["content"].(string)
var result map[string]any if ok && contentStr != "" {
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
var inner map[string]any var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil { if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
result = inner return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{inner},
}
} }
} }
return map[string]any{ return map[string]any{
"total_rounds": 1, "total_rounds": 1,
"rounds": []map[string]any{result}, "rounds": []map[string]any{raw},
} }
} }

View File

@@ -1,9 +1,10 @@
package prompt package session
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"prompts-core/model/entity"
"time" "time"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
@@ -13,37 +14,33 @@ const (
redisKeyPrefix = "chat:session:%s" redisKeyPrefix = "chat:session:%s"
) )
// saveToRedis 保存会话数据到Redis // formatRedisKey 格式化Redis
func saveToRedis(ctx context.Context, sessionId string, requestMessages []map[string]any, responseMessages []map[string]any) error { func formatRedisKey(sessionId string) string {
key := formatRedisKey(sessionId) return fmt.Sprintf(redisKeyPrefix, sessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": sessionId,
"requestContent": requestMessages,
"responseContent": responseMessages,
"timestamp": time.Now().Unix(),
} }
// saveToRedis 保存会话数据到Redis
func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
key := formatRedisKey(session.SessionId)
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
expireSeconds := g.Cfg().MustGet(ctx, "session.expireTime", 1800).Int64()
data := map[string]any{
"sessionId": session.SessionId,
"requestContent": session.RequestContent,
"responseContent": session.ResponseContent,
"timestamp": time.Now().Unix(),
}
b, err := json.Marshal(data) b, err := json.Marshal(data)
if err != nil { if err != nil {
return fmt.Errorf("序列化会话数据失败: %w", err) return fmt.Errorf("序列化会话数据失败: %w", err)
} }
if err := executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil { if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
return err return err
} }
return nil return nil
} }
// formatRedisKey 格式化Redis键
func formatRedisKey(sessionId string) string {
return fmt.Sprintf(redisKeyPrefix, sessionId)
}
// executeRedisCommands 执行Redis命令 // executeRedisCommands 执行Redis命令
func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error { func executeRedisCommands(ctx context.Context, key string, value string, maxRounds int, expireSeconds int64) error {
if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil { if _, err := g.Redis().Do(ctx, "LPUSH", key, value); err != nil {

View File

@@ -1,4 +1,4 @@
package prompt package session
import ( import (
"context" "context"
@@ -14,74 +14,36 @@ import (
"prompts-core/model/entity" "prompts-core/model/entity"
) )
// SessionCallback 会话回调 // Callback 会话回调
func SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
result, err := util.ParseOutput(req.Text) req.Messages["role"] = "assistant"
if err != nil {
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
result["role"] = "assistant"
if err = updateSessionResponse(ctx, req.EpicycleId, result); err != nil {
return nil, err
}
session, err := getSessionById(ctx, req.EpicycleId)
if err != nil {
return nil, err
}
if err := saveSessionToRedis(ctx, session); err != nil {
return nil, err
}
requestMessages := util.ConvertToMessages(session.RequestContent)
responseMessages := util.ConvertToMessages(session.ResponseContent)
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
return &dto.SessionCallbackRes{}, nil
}
// updateSessionResponse 更新会话响应
func updateSessionResponse(ctx context.Context, epicycleId int64, response any) error {
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
ResponseContent: response, ResponseContent: req.Messages,
}) })
if err != nil { if err != nil {
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", epicycleId, err) g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
return fmt.Errorf("更新数据库失败: %w", err) return nil, fmt.Errorf("更新数据库失败: %w", err)
} }
return nil
}
// getSessionById 根据ID获取会话
func getSessionById(ctx context.Context, epicycleId int64) (*entity.ComposeSession, error) {
session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{ session, err := dao.ComposeSession.Get(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: epicycleId}, SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
}) })
if session == nil {
return nil, fmt.Errorf("会话不存在: epicycleId=%d", req.EpicycleId)
}
if err != nil { if err != nil {
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", epicycleId, err) g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
return nil, fmt.Errorf("获取会话数据失败: %w", err) return nil, fmt.Errorf("获取会话数据失败: %w", err)
} }
return session, nil if err = saveToRedis(ctx, session); err != nil {
return nil, fmt.Errorf("redis存储失败: %w", err)
} }
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
// saveSessionToRedis 保存会话到Redis session.SessionId, session.Id, len(session.RequestContent), len(session.ResponseContent))
func saveSessionToRedis(ctx context.Context, session *entity.ComposeSession) error { return &dto.SessionCallbackRes{
requestMessages := util.ConvertToMessages(session.RequestContent) Status: true,
responseMessages := util.ConvertToMessages(session.ResponseContent) SessionId: session.SessionId,
}, nil
if err := saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
session.SessionId, session.Id, err)
return fmt.Errorf("Redis存储失败: %w", err)
}
return nil
} }
// GetHistoryMessages 获取历史信息 // GetHistoryMessages 获取历史信息
@@ -159,7 +121,7 @@ func cacheSessionsToRedis(ctx context.Context, sessions []*entity.ComposeSession
} }
if len(reqMsgs) > 0 || len(respMsgs) > 0 { if len(reqMsgs) > 0 || len(respMsgs) > 0 {
_ = saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs) _ = saveToRedis(ctx, session)
} }
} }
} }