113 lines
3.3 KiB
Go
113 lines
3.3 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"prompts-core/dao"
|
|
"prompts-core/model/dto"
|
|
"prompts-core/model/entity"
|
|
|
|
"gitea.com/red-future/common/beans"
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
)
|
|
|
|
// sessionService groups the conversation-session operations of this package
// (handling model callbacks and loading chat history). It carries no fields,
// so one shared instance is sufficient.
type sessionService struct{}

// Session is the shared sessionService instance exposed to callers.
var Session = new(sessionService)
|
|
|
|
func (s *sessionService) SessionCallback(ctx context.Context, req *dto.SessionCallbackReq) (res *beans.ResponseEmpty, err error) {
|
|
// 1. 解析AI返回的文本
|
|
result, err := parseOutput(req.Text)
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] 解析模型输出失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
return nil, err
|
|
}
|
|
|
|
// 2. 更新数据库
|
|
result["role"] = "assistant"
|
|
_, err = dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
|
ResponseContent: result,
|
|
})
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] 更新数据库失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
return nil, err
|
|
}
|
|
|
|
// 3. 获取当前轮次完整数据
|
|
session, err := dao.ComposeSession.GetById(ctx, req.EpicycleId)
|
|
if err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] 获取会话数据失败 epicycleId=%d err=%v", req.EpicycleId, err)
|
|
return nil, err
|
|
}
|
|
|
|
// 4. 转换 json 并存入 Redis
|
|
requestMessages := convertToMessages(session.RequestContent)
|
|
responseMessages := convertToMessages(session.ResponseContent)
|
|
|
|
if err = s.saveToRedis(ctx, session.SessionId, requestMessages, responseMessages); err != nil {
|
|
g.Log().Errorf(ctx, "[会话回调] Redis存储失败 sessionId=%s id=%d err=%v",
|
|
session.SessionId, session.Id, err)
|
|
return nil, err
|
|
}
|
|
|
|
g.Log().Infof(ctx, "[会话回调] 存储成功 sessionId=%s id=%d requestLen=%d responseLen=%d",
|
|
session.SessionId, session.Id, len(requestMessages), len(responseMessages))
|
|
return &beans.ResponseEmpty{}, nil
|
|
}
|
|
|
|
// GetHistoryMessages 获取历史信息
|
|
func (s *sessionService) GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
|
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
|
|
|
// 1. 先从 Redis 拿
|
|
redisHistory, err := s.GetSessionHistoryForInference(ctx, sessionId)
|
|
if err == nil && len(redisHistory) > 0 {
|
|
return redisHistory, nil
|
|
}
|
|
|
|
// 2. Redis 没有 → fallback DB
|
|
sessions, err := dao.ComposeSession.GetListBySessionId(ctx, sessionId, maxRounds)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
|
}
|
|
|
|
var messages []map[string]any
|
|
|
|
for _, session := range sessions {
|
|
// request
|
|
reqMsgs := convertToMessages(session.RequestContent)
|
|
for _, m := range reqMsgs {
|
|
role := gconv.String(m["role"])
|
|
if role == "user" || role == "assistant" {
|
|
messages = append(messages, m)
|
|
}
|
|
}
|
|
|
|
// response
|
|
respMsgs := convertToMessages(session.ResponseContent)
|
|
for _, m := range respMsgs {
|
|
if m["role"] == nil {
|
|
m["role"] = "assistant"
|
|
}
|
|
messages = append(messages, m)
|
|
}
|
|
}
|
|
|
|
// 3. 回写 Redis
|
|
for _, session := range sessions {
|
|
reqMsgs := convertToMessages(session.RequestContent)
|
|
respMsgs := convertToMessages(session.ResponseContent)
|
|
for i := range respMsgs {
|
|
if respMsgs[i]["role"] == nil {
|
|
respMsgs[i]["role"] = "assistant"
|
|
}
|
|
}
|
|
if len(reqMsgs) > 0 || len(respMsgs) > 0 {
|
|
_ = s.saveToRedis(ctx, session.SessionId, reqMsgs, respMsgs)
|
|
}
|
|
}
|
|
return messages, nil
|
|
}
|