feat: 优化RAG检索与聊天模型支持历史对话

实现双路检索并行优化,使用EINO官方模板重构聊天逻辑,增加多轮对话历史记录管理及相关性过滤,并修复数据库唯一索引。
This commit is contained in:
2026-04-09 13:57:46 +08:00
parent 14a429f4ae
commit 2ced0a43e5
9 changed files with 310 additions and 147 deletions

View File

@@ -14,91 +14,122 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
var globalChatModel *qwen.ChatModel
const (
MaxHistoryTurns = 5 // 最大历史轮数
)
var (
globalChatModel *qwen.ChatModel
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
)
func init() {
ctx := context.Background()
// 初始化大模型
if err := initChatModel(ctx); err != nil {
glog.Errorf(ctx, "初始化大模型失败: %v", err)
}
// 初始化 EINO 提示词模板
initRAGPromptTemplate()
return
}
// 初始化通义千问
func initChatModel(ctx context.Context) error {
if globalChatModel != nil {
return nil
}
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
var err error
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: apiKey,
Model: model,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
Timeout: 60 * 1e9,
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
MaxTokens: gconv.PtrInt(1024), // 最长回答
TopP: gconv.PtrFloat32(1.0),
})
if err != nil {
glog.Errorf(ctx, "初始化大模型失败: %v", err)
return err
}
return
globalChatModel = cm
return nil
}
// NewChatModel 只处理逻辑,不复用创建模型
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
// 1. 构建参考知识
knowledge, sources := buildKnowledgeAndSources(docs)
// 2. 构建提示词
msgs, err := buildPromptMessages(ctx, knowledge, content)
if err != nil {
return
}
// 3. 🔥 直接使用全局单例,不重复创建
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
return
}
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
var knowledge string
var sources []string
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
// 提取 document_id
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
sources = append(sources, gconv.String(docID))
}
}
return knowledge, sources
}
// buildPromptMessages 构建提示词模板
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
promptTpl := prompt.FromMessages(
// 初始化 EINO 官方提示词模板(最关键!)
func initRAGPromptTemplate() {
ragPromptTemplate = prompt.FromMessages(
schema.FString,
// 系统提示(带参考知识)
&schema.Message{
Role: schema.System,
// Content: `你是专业客服助手,语气友好。
//如果参考知识中有相关信息,请优先依据参考知识回答。
//如果没有相关信息,就正常回答,不要说无法回答。
//
//参考知识:
//{knowledge}`,
Content: `你是专业的客服助手,语气友好。
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
Content: `你是专业客服,语气友好简洁
请严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
},
// 用户问题
&schema.Message{
Role: schema.User,
Content: "{question}",
},
)
}
return promptTpl.Format(ctx, map[string]any{
// NewChatModel 只处理逻辑,不复用创建模型
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) {
// 1. 构建参考知识
knowledge := buildKnowledgeAndSources(docs)
// 2. 历史精简
history = limitHistory(history)
// 3. ✅ EINO 官方模板格式化(超级干净)
msgs, err := ragPromptTemplate.Format(ctx, map[string]any{
"knowledge": knowledge,
"question": question,
})
if err != nil {
return nil, err
}
// 4. 历史插入到模板消息中间标准EINO用法
if len(history) > 0 {
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
}
// 5. 🔥 直接使用全局单例,不重复创建
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
return
}
func limitHistory(history []*schema.Message) []*schema.Message {
valid := make([]*schema.Message, 0, len(history))
for _, m := range history {
if m.Role == schema.User || m.Role == schema.Assistant {
valid = append(valid, m)
}
}
keep := 2 * MaxHistoryTurns
if len(valid) > keep {
valid = valid[len(valid)-keep:]
}
return valid
}
// buildKnowledgeAndSources 拼接参考知识
func buildKnowledgeAndSources(docs []*schema.Document) string {
var knowledge string
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
}
return knowledge
}
// streamGenerateAnswer 流式生成

View File

@@ -5,11 +5,14 @@ import (
"errors"
"rag/dao"
"sort"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
@@ -53,43 +56,139 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
}
options = retriever.GetCommonOptions(options, opts...)
// 安全保护:防止 nil 指针 panic
topK := 10
if options.TopK != nil {
topK = *options.TopK
}
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: query,
TopK: *options.TopK,
})
// ==========================================
// 🔥 双路检索:向量 + 全文
// 🔥 优化版grpool 并行双路检索(安全、健壮、无泄漏)
// ==========================================
docsVector, err := r.doRetrieveVector(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
var (
docsVector []*schema.Document
docsFulltext []*schema.Document
errVector error
errFulltext error
// 缓冲通道=2确保无死锁等待
done = make(chan struct{}, 2)
)
// 上下文:超时 + 可取消双保障建议5s超时根据业务调整
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// 封装并行任务函数,消除重复代码
runTask := func(task func() error, errTarget *error) {
defer func() {
// 任务结束必发信号,确保通道不阻塞
done <- struct{}{}
}()
// 捕获 panic + 执行业务逻辑
g.TryCatch(taskCtx, func(ctx context.Context) {
*errTarget = task()
}, func(ctx context.Context, panicErr error) {
*errTarget = panicErr
})
// 任务失败:立即取消另一个任务(快速失败)
if *errTarget != nil {
cancel()
}
}
// ----------------------
// 并行提交两个检索任务
// ----------------------
// 任务1向量检索
grpool.Add(taskCtx, func(ctx context.Context) {
runTask(func() error {
docsVector, errVector = r.doRetrieveVector(ctx, query, options)
return errVector
}, &errVector)
})
// 任务2全文检索
grpool.Add(taskCtx, func(ctx context.Context) {
runTask(func() error {
docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options)
return errFulltext
}, &errFulltext)
})
// ----------------------
// 安全等待所有任务完成
// ----------------------
<-done
<-done
// ----------------------
// 统一错误处理
// ----------------------
// 用 errors.Join 合并所有错误,不丢失信息
if err := errors.Join(errVector, errFulltext); err != nil {
return nil, err
}
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
// 合并 + 去重
// 合并 + 智能去重(保留最优分数)
docs := mergeAndDeduplicate(docsVector, docsFulltext)
// 排序distance 越小越靠前)
// 排序:向量优先,同类型按距离升序
sort.Slice(docs, func(i, j int) bool {
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
//
//// 有类型标记的优先
//if okI && !okJ {
// return true
//}
//if !okI && okJ {
// return false
//}
//
//// 向量永远排前面
//if byI == "vector" && byJ == "fulltext" {
// return true
//}
//if byI == "fulltext" && byJ == "vector" {
// return false
//}
// 同类型按 distance 升序(越小越相似)
d1 := gconv.Float64(docs[i].MetaData["distance"])
d2 := gconv.Float64(docs[j].MetaData["distance"])
return d1 < d2
})
// 最多保留 topK
if len(docs) > *options.TopK {
docs = docs[:*options.TopK]
// 在Retrieve方法末尾增加相关性校验
validDocs := make([]*schema.Document, 0)
for i, d := range docs {
// 过滤distance过大的垃圾结果比如distance>0.8的直接丢弃)
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
validDocs = append(validDocs, d)
}
}
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
return docs, nil
// 如果没有有效结果返回空让LLM回答「暂无相关信息」
if len(validDocs) == 0 {
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
return validDocs, nil
}
// 最多保留 topK
if len(validDocs) > topK {
validDocs = validDocs[:topK]
}
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
return validDocs, nil
}
// ==========================================
@@ -105,7 +204,10 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
}
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
topK := *opts.TopK
topK := 10
if opts.TopK != nil {
topK = *opts.TopK
}
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
@@ -144,13 +246,16 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
score := gconv.Float64(row["_rankingScore"])
distance := score
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": 0.1, // 全文结果给高分
"distance": distance,
"retrieve_by": "fulltext",
},
})
@@ -159,18 +264,26 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str
}
// ==========================================
// 合并去重
// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记)
// ==========================================
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
idMap := make(map[string]*schema.Document)
// 先存入向量结果
for _, d := range vecDocs {
idMap[d.ID] = d
}
// 再处理全文:不存在则添加;存在则标记“双路命中”,不覆盖向量分数
for _, d := range fullDocs {
if _, exists := idMap[d.ID]; !exists {
if existDoc, ok := idMap[d.ID]; ok {
// 标记同时被向量和全文检索到
existDoc.MetaData["retrieve_by"] = "both"
} else {
idMap[d.ID] = d
}
}
merged := make([]*schema.Document, 0, len(idMap))
for _, d := range idMap {
merged = append(merged, d)