400 lines
13 KiB
Go
400 lines
13 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"customer-server/consts/account"
|
||
"customer-server/consts/public"
|
||
"customer-server/consts/scriptedSpeech"
|
||
"customer-server/dao"
|
||
"customer-server/model/dto"
|
||
"customer-server/model/entity"
|
||
"encoding/json"
|
||
"fmt"
|
||
"slices"
|
||
"strings"
|
||
"time"
|
||
|
||
"gitea.com/red-future/common/beans"
|
||
"gitea.com/red-future/common/http"
|
||
"gitea.com/red-future/common/jaeger"
|
||
"gitea.com/red-future/common/utils"
|
||
gmq "github.com/bjang03/gmq/core/gmq"
|
||
"github.com/bjang03/gmq/mq"
|
||
"github.com/bjang03/gmq/types"
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// sessionToolService groups the helper operations used while handling a
// customer-service session. It declares no fields.
type sessionToolService struct{}

// SessionToolService is the package-level instance through which the
// session helper methods are invoked.
var SessionToolService = &sessionToolService{}
|
||
|
||
func (s *sessionToolService) PushOpeningRemark(ctx context.Context, userId string, accountInfo *dto.AccountVO, headers map[string]string) (content string, err error) {
|
||
content = ""
|
||
var sceneType = scriptedSpeech.SceneTypeOpeningRemark
|
||
var key = fmt.Sprintf("account:%s:%s:%s", accountInfo.AccountCode, account.GetDescByCode(accountInfo.Platform), userId)
|
||
get, err := g.Redis().Get(ctx, key)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if g.IsEmpty(get) {
|
||
// 构建开场白内容
|
||
if len(accountInfo.DatasetIds) > 1 {
|
||
var datasetInfo *dto.RagListDatasetRes
|
||
datasetInfo, err = SessionToolService.GetDatasetInfo(ctx, accountInfo.DatasetIds, headers)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if g.IsEmpty(datasetInfo) {
|
||
err = fmt.Errorf("数据集不存在")
|
||
return
|
||
}
|
||
var datasetDescriptions []string
|
||
for _, dataset := range datasetInfo.List {
|
||
datasetDescriptions = append(datasetDescriptions, dataset.Name)
|
||
}
|
||
content = SessionToolService.BuildMenuContent(accountInfo.Greeting, datasetDescriptions, len(accountInfo.DatasetIds))
|
||
} else {
|
||
content = SessionToolService.BuildMenuContent(accountInfo.Greeting, accountInfo.KeywordOption, len(accountInfo.DatasetIds))
|
||
}
|
||
err = s.pushDelayMsg(ctx, key, sceneType.Code(), sceneType.Desc(), accountInfo.DatasetIds)
|
||
if err != nil {
|
||
return
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (s *sessionToolService) PushDialog(ctx context.Context, userId string, questionContent string, accountInfo *dto.AccountVO, headers map[string]string) (content string, err error) {
|
||
sceneType := scriptedSpeech.SceneTypeDialog
|
||
// 删除延迟消息
|
||
//err = s.DeleteDelayMsg(ctx)
|
||
//if err != nil {
|
||
// return nil, err
|
||
//}
|
||
content = ""
|
||
var key = fmt.Sprintf("account:%s:%s:%s", accountInfo.AccountCode, account.GetDescByCode(accountInfo.Platform), userId)
|
||
get, err := g.Redis().Get(ctx, key)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if !g.IsEmpty(get) {
|
||
// 获取用户对话上下文
|
||
var history []*dto.Message
|
||
history, err = SessionToolService.GetUserHistory(ctx, userId)
|
||
if err != nil {
|
||
err = fmt.Errorf("获取用户对话上下文失败: %w", err)
|
||
return
|
||
}
|
||
|
||
// 获取用户对话记录
|
||
var accountUserDialog *entity.AccountUserDialog
|
||
accountUserDialog, err = dao.AccountUserDialog.Get(ctx, &dto.GetAccountUserDialogReq{
|
||
AccountId: accountInfo.Id,
|
||
UserId: userId,
|
||
})
|
||
if err != nil {
|
||
err = fmt.Errorf("获取用户对话记录失败: %w", err)
|
||
return
|
||
}
|
||
if g.IsEmpty(accountUserDialog.Id) {
|
||
// 保存用户对话记录
|
||
if _, err = dao.AccountUserDialog.Insert(ctx, &dto.AddAccountUserDialogReq{
|
||
AccountId: accountInfo.Id,
|
||
UserId: userId,
|
||
DialogCount: 1,
|
||
}); err != nil {
|
||
err = fmt.Errorf("保存用户对话记录失败: %w", err)
|
||
return
|
||
}
|
||
} else {
|
||
if accountUserDialog.DialogCount >= g.Cfg().MustGet(ctx, "card.triggerCount").Int64() {
|
||
// TODO 替换为实际卡片发送逻辑
|
||
content = "请加一下卡片的联系方式,进行更专业的咨询"
|
||
sceneType = scriptedSpeech.SceneTypeCardSend
|
||
if _, err = SessionToolService.ClearUserHistory(ctx, userId); err != nil {
|
||
err = fmt.Errorf("清除用户对话上下文失败: %w", err)
|
||
return
|
||
}
|
||
} else {
|
||
// 更新用户对话记录
|
||
if _, err = dao.AccountUserDialog.Update(ctx, &dto.UpdateAccountUserDialogReq{
|
||
Id: accountUserDialog.Id,
|
||
DialogCount: 1,
|
||
}); err != nil {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
if sceneType.Code() != scriptedSpeech.SceneTypeCardSend.Code() {
|
||
// 通过HTTP调用rag服务的RAG查询接口
|
||
var ragQuery *dto.RagQueryRes
|
||
ragQuery, err = SessionToolService.GetRagQuery(ctx, questionContent, accountInfo.DatasetIds, history, headers)
|
||
if err != nil {
|
||
err = fmt.Errorf("调用rag服务的RAG查询接口失败: %w", err)
|
||
return
|
||
}
|
||
content = ragQuery.Answer
|
||
|
||
// 保存用户对话上下文
|
||
err = SessionToolService.SaveUserHistory(ctx, userId, []*dto.Message{
|
||
{Role: "user", Content: questionContent},
|
||
{Role: "assistant", Content: content},
|
||
})
|
||
if err != nil {
|
||
err = fmt.Errorf("保存用户对话上下文失败: %w", err)
|
||
return
|
||
}
|
||
}
|
||
|
||
err = s.pushDelayMsg(ctx, key, sceneType.Code(), sceneType.Desc(), accountInfo.DatasetIds)
|
||
if err != nil {
|
||
return
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (s *sessionToolService) pushDelayMsg(ctx context.Context, key string, sceneTypeCode scriptedSpeech.SceneType, sceneTypeDesc string, datasetIds []int64) (err error) {
|
||
err = g.Redis().SetEX(ctx, key, sceneTypeDesc, gconv.Int64(10*time.Second))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 获取追问话术内容
|
||
var msg string
|
||
if len(datasetIds) == 1 {
|
||
scriptedSpeechInfo, err := SessionToolService.GetScriptedSpeechContent(ctx, datasetIds[0], sceneTypeCode)
|
||
if err != nil {
|
||
return fmt.Errorf("获取追问话术内容失败: %w", err)
|
||
}
|
||
if g.IsEmpty(scriptedSpeechInfo) {
|
||
if sceneTypeCode == scriptedSpeech.SceneTypeOpeningRemark.Code() {
|
||
msg = "宝子,刚才给您发的信息您有看到吗?有任何问题都能直接问我,加微信也能更方便沟通~"
|
||
} else if sceneTypeCode == scriptedSpeech.SceneTypeDialog.Code() {
|
||
msg = "看您暂时没回复,是不是还有什么疑问?加微信我详细给您说明~"
|
||
} else if sceneTypeCode == scriptedSpeech.SceneTypeCardSend.Code() {
|
||
msg = "宝子,加上没~要及时加哦,不然卡片容易失效哒✨"
|
||
}
|
||
}
|
||
msg = scriptedSpeechInfo.QuestionContent
|
||
} else {
|
||
msg = "宝子,刚才给您发的信息您有看到吗?有任何问题都能直接问我,加微信也能更方便沟通~"
|
||
}
|
||
var msgMap = map[string]string{
|
||
"key": key,
|
||
"data": msg,
|
||
}
|
||
err = gmq.GetGmq(public.GmqMsgPluginsName).GmqPublishDelay(ctx, &mq.NatsPubDelayMessage{
|
||
PubDelayMessage: types.PubDelayMessage{
|
||
PubMessage: types.PubMessage{
|
||
Topic: public.AccountFollowupTopic,
|
||
Data: msgMap,
|
||
},
|
||
DelaySeconds: 60,
|
||
},
|
||
})
|
||
return
|
||
}
|
||
|
||
// GetAccountInfo 获取客服账号信息
|
||
func (s *sessionToolService) GetAccountInfo(ctx context.Context, accountCode string) (res *dto.AccountVO, err error) {
|
||
r, err := dao.Account.GetByAccountCode(ctx, &dto.GetByAccountCodeReq{
|
||
AccountCode: accountCode,
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取客服账号信息失败: %w", err)
|
||
}
|
||
err = gconv.Struct(r, &res)
|
||
return
|
||
}
|
||
|
||
// SetUserInfo 设置用户信息
|
||
func (s *sessionToolService) SetUserInfo(ctx context.Context, creator string, tenantId uint64) (headers map[string]string, err error) {
|
||
// 创建完整的用户信息
|
||
userInfo := &beans.User{
|
||
UserName: creator,
|
||
TenantId: tenantId,
|
||
}
|
||
ctx = context.WithValue(ctx, "user", *userInfo)
|
||
// 提取并保存请求头(在连接升级前)
|
||
headers = make(map[string]string)
|
||
// 提取其他headers
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
for k, v := range r.Request.Header {
|
||
if len(v) > 0 {
|
||
headers[k] = v[0]
|
||
}
|
||
}
|
||
}
|
||
// 将完整用户信息序列化为JSON,放到X-User-Info请求头
|
||
userInfoJson, err := gjson.Encode(userInfo)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("用户信息序列化失败: %w", err)
|
||
}
|
||
headers["X-User-Info"] = string(userInfoJson)
|
||
return
|
||
}
|
||
|
||
// GetDatasetInfo 获取数据集信息
|
||
func (s *sessionToolService) GetDatasetInfo(ctx context.Context, datasetIds []int64, headers map[string]string) (res *dto.RagListDatasetRes, err error) {
|
||
// 通过HTTP调用rag服务的关键词查询接口
|
||
res = &dto.RagListDatasetRes{}
|
||
if err = http.Get(ctx, "rag/dataset/list", headers, &res, &dto.RagListDatasetReq{
|
||
Ids: datasetIds,
|
||
}); err != nil {
|
||
return nil, fmt.Errorf("获取数据集信息失败: %w", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
// BuildMenuContent 生成菜单话术内容
|
||
func (s *sessionToolService) BuildMenuContent(greeting string, options []string, datasetCount int) string {
|
||
var sb strings.Builder
|
||
// 问候语
|
||
if datasetCount > 1 {
|
||
greeting = "您好,很高兴为您服务!请问咨询什么方面问题?"
|
||
} else {
|
||
if greeting == "" {
|
||
greeting = "您好,很高兴为您服务!请问有什么可以帮您?"
|
||
}
|
||
}
|
||
|
||
sb.WriteString(greeting)
|
||
sb.WriteByte('\n')
|
||
// 拼接选项 1、xx 2、xx...
|
||
for i, opt := range options {
|
||
sb.WriteString(fmt.Sprintf("%d、%s\n", i+1, opt))
|
||
if i == len(options)-1 {
|
||
sb.WriteString(fmt.Sprintf("%s\n", "💗回复数字就好~"))
|
||
}
|
||
}
|
||
// 固定结尾
|
||
sb.WriteString("🌟也可直接点击下方咨询专业老师~")
|
||
|
||
return sb.String()
|
||
}
|
||
|
||
// GetScriptedSpeechContent 获取话术内容
|
||
func (s *sessionToolService) GetScriptedSpeechContent(ctx context.Context, datasetId int64, sceneType scriptedSpeech.SceneType) (res *dto.ScriptedSpeechVO, err error) {
|
||
r, err := dao.ScriptedSpeech.GetByDatasetIdAndSceneType(ctx, &dto.ListScriptedSpeechReq{
|
||
DatasetId: datasetId,
|
||
SceneType: sceneType,
|
||
})
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = gconv.Struct(r, &res)
|
||
return
|
||
}
|
||
|
||
// GetRagQuery 获取rag查询结果
|
||
func (s *sessionToolService) GetRagQuery(ctx context.Context, questionContent string, datasetIds []int64, history []*dto.Message, headers map[string]string) (res *dto.RagQueryRes, err error) {
|
||
resp := new(dto.RagQueryRes)
|
||
if err = http.Post(ctx, "rag/document/vector/ragQuery", headers, &resp, &dto.RagQueryReq{
|
||
Content: questionContent,
|
||
DatasetIds: datasetIds,
|
||
History: history,
|
||
TopK: 5,
|
||
}); err != nil {
|
||
return
|
||
}
|
||
return resp, nil
|
||
}
|
||
|
||
// SaveUserHistory 保存用户对话历史到Redis
|
||
func (s *sessionToolService) SaveUserHistory(ctx context.Context, userKey string, newMessages []*dto.Message) (err error) {
|
||
key := fmt.Sprintf(public.AccountDialogKeyUserId, userKey)
|
||
|
||
// 1. 先读旧历史
|
||
var oldMessages []*dto.Message
|
||
oldMessages, err = s.GetUserHistory(ctx, key)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 2. 合并
|
||
allMessages := append(oldMessages, newMessages...)
|
||
|
||
// 3. 限制长度(保留最新 N 轮)
|
||
maxMsgCount := 2 * g.Cfg().MustGet(ctx, "history.contextLimit", 5).Int()
|
||
if len(allMessages) > maxMsgCount {
|
||
allMessages = allMessages[len(allMessages)-maxMsgCount:]
|
||
}
|
||
|
||
// 4. 存回Redis
|
||
data, err := json.Marshal(allMessages)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return g.Redis().SetEX(ctx, key, data, gconv.Int64(15*time.Second))
|
||
}
|
||
|
||
// GetUserHistory 从Redis获取用户历史
|
||
func (s *sessionToolService) GetUserHistory(ctx context.Context, key string) ([]*dto.Message, error) {
|
||
data, err := g.Redis().Get(ctx, key)
|
||
if err != nil || data.IsEmpty() {
|
||
return []*dto.Message{}, nil
|
||
}
|
||
|
||
var messages []*dto.Message
|
||
if err = json.Unmarshal(data.Bytes(), &messages); err != nil {
|
||
return []*dto.Message{}, err
|
||
}
|
||
return messages, nil
|
||
}
|
||
|
||
// ClearUserHistory 清空历史(可选)
|
||
func (s *sessionToolService) ClearUserHistory(ctx context.Context, userKey string) (int64, error) {
|
||
key := fmt.Sprintf(public.AccountDialogKeyUserId, userKey)
|
||
return g.Redis().Del(ctx, key)
|
||
}
|
||
|
||
// getDatasetIdsByKeywords 通过关键词查询数据集ID
|
||
func (s *sessionToolService) getDatasetIdsByKeywords(ctx context.Context, content string, headers map[string]string) (res []int64, err error) {
|
||
// 1. 提取关键词
|
||
keywords := s.extractKeywords(content)
|
||
g.Log().Infof(ctx, "提取关键词: %v", keywords)
|
||
|
||
// 通过HTTP调用rag服务的关键词查询接口
|
||
respKeyword := &dto.RAGListKeywordRes{}
|
||
if err = http.Get(ctx, "rag/keyword/listKeyword", headers, &respKeyword, &dto.RAGListKeywordReq{
|
||
Words: keywords,
|
||
}); err != nil {
|
||
jaeger.RecordError(ctx, err, "RAG查询关键词失败")
|
||
g.Log().Errorf(ctx, "RAG查询关键词失败: %v", err)
|
||
return
|
||
}
|
||
var datasetIds []int64
|
||
for _, v := range respKeyword.List {
|
||
if !slices.Contains(datasetIds, v.DatasetId) {
|
||
datasetIds = append(datasetIds, v.DatasetId)
|
||
}
|
||
}
|
||
return datasetIds, nil
|
||
}
|
||
|
||
// extractKeywords 提取关键词
|
||
func (s *sessionToolService) extractKeywords(text string) []string {
|
||
if text == "" {
|
||
return []string{}
|
||
}
|
||
|
||
// 使用gse分词工具提取关键词
|
||
keywords := utils.GseTool.Extract(text, 5)
|
||
|
||
words := make([]string, 0, len(keywords))
|
||
for _, kw := range keywords {
|
||
if kw.Word != "" {
|
||
words = append(words, kw.Word)
|
||
}
|
||
}
|
||
|
||
// 如果没有提取到关键词,使用分词结果
|
||
if len(words) == 0 {
|
||
words = utils.GseTool.Cut(text)
|
||
}
|
||
|
||
return words
|
||
}
|