feat: 优化RAG检索与聊天模型支持历史对话
实现双路检索并行优化,使用EINO官方模板重构聊天逻辑,增加多轮对话历史记录管理及相关性过滤,并修复数据库唯一索引。
This commit is contained in:
@@ -14,91 +14,122 @@ import (
|
|||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var globalChatModel *qwen.ChatModel
|
const (
|
||||||
|
MaxHistoryTurns = 5 // 最大历史轮数
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
globalChatModel *qwen.ChatModel
|
||||||
|
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
// 初始化大模型
|
||||||
|
if err := initChatModel(ctx); err != nil {
|
||||||
|
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||||
|
}
|
||||||
|
// 初始化 EINO 提示词模板
|
||||||
|
initRAGPromptTemplate()
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化通义千问
|
||||||
|
func initChatModel(ctx context.Context) error {
|
||||||
|
if globalChatModel != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||||
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||||
|
|
||||||
var err error
|
cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||||
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
Model: model,
|
Model: model,
|
||||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
Timeout: 60 * 1e9,
|
||||||
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||||
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||||
TopP: gconv.PtrFloat32(1.0),
|
TopP: gconv.PtrFloat32(1.0),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
globalChatModel = cm
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatModel 只处理逻辑,不复用创建模型
|
// 初始化 EINO 官方提示词模板(最关键!)
|
||||||
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
|
func initRAGPromptTemplate() {
|
||||||
// 1. 构建参考知识
|
ragPromptTemplate = prompt.FromMessages(
|
||||||
knowledge, sources := buildKnowledgeAndSources(docs)
|
|
||||||
|
|
||||||
// 2. 构建提示词
|
|
||||||
msgs, err := buildPromptMessages(ctx, knowledge, content)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 🔥 直接使用全局单例,不重复创建
|
|
||||||
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
|
|
||||||
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
|
|
||||||
var knowledge string
|
|
||||||
var sources []string
|
|
||||||
|
|
||||||
for i, doc := range docs {
|
|
||||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
|
||||||
|
|
||||||
// 提取 document_id
|
|
||||||
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
|
|
||||||
sources = append(sources, gconv.String(docID))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return knowledge, sources
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPromptMessages 构建提示词模板
|
|
||||||
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
|
|
||||||
promptTpl := prompt.FromMessages(
|
|
||||||
schema.FString,
|
schema.FString,
|
||||||
|
// 系统提示(带参考知识)
|
||||||
&schema.Message{
|
&schema.Message{
|
||||||
Role: schema.System,
|
Role: schema.System,
|
||||||
// Content: `你是专业的客服助手,语气友好。
|
Content: `你是专业客服,语气友好简洁。
|
||||||
//如果参考知识中有相关信息,请优先依据参考知识回答。
|
请严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
|
||||||
//如果没有相关信息,就正常回答,不要说无法回答。
|
|
||||||
//
|
参考知识:
|
||||||
//参考知识:
|
{knowledge}`,
|
||||||
//{knowledge}`,
|
|
||||||
Content: `你是专业的客服助手,语气友好。
|
|
||||||
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
|
|
||||||
|
|
||||||
参考知识:
|
|
||||||
{knowledge}`,
|
|
||||||
},
|
},
|
||||||
|
// 用户问题
|
||||||
&schema.Message{
|
&schema.Message{
|
||||||
Role: schema.User,
|
Role: schema.User,
|
||||||
Content: "{question}",
|
Content: "{question}",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return promptTpl.Format(ctx, map[string]any{
|
// NewChatModel 只处理逻辑,不复用创建模型
|
||||||
|
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) {
|
||||||
|
// 1. 构建参考知识
|
||||||
|
knowledge := buildKnowledgeAndSources(docs)
|
||||||
|
// 2. 历史精简
|
||||||
|
history = limitHistory(history)
|
||||||
|
// 3. ✅ EINO 官方模板格式化(超级干净)
|
||||||
|
msgs, err := ragPromptTemplate.Format(ctx, map[string]any{
|
||||||
"knowledge": knowledge,
|
"knowledge": knowledge,
|
||||||
"question": question,
|
"question": question,
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 4. 历史插入到模板消息中间(标准EINO用法)
|
||||||
|
if len(history) > 0 {
|
||||||
|
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
|
||||||
|
}
|
||||||
|
// 5. 🔥 直接使用全局单例,不重复创建
|
||||||
|
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func limitHistory(history []*schema.Message) []*schema.Message {
|
||||||
|
valid := make([]*schema.Message, 0, len(history))
|
||||||
|
for _, m := range history {
|
||||||
|
if m.Role == schema.User || m.Role == schema.Assistant {
|
||||||
|
valid = append(valid, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
keep := 2 * MaxHistoryTurns
|
||||||
|
if len(valid) > keep {
|
||||||
|
valid = valid[len(valid)-keep:]
|
||||||
|
}
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildKnowledgeAndSources 拼接参考知识
|
||||||
|
func buildKnowledgeAndSources(docs []*schema.Document) string {
|
||||||
|
var knowledge string
|
||||||
|
|
||||||
|
for i, doc := range docs {
|
||||||
|
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||||
|
|
||||||
|
}
|
||||||
|
return knowledge
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamGenerateAnswer 流式生成
|
// streamGenerateAnswer 流式生成
|
||||||
|
|||||||
@@ -5,11 +5,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"rag/dao"
|
"rag/dao"
|
||||||
"sort"
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/callbacks"
|
"github.com/cloudwego/eino/callbacks"
|
||||||
"github.com/cloudwego/eino/components/embedding"
|
"github.com/cloudwego/eino/components/embedding"
|
||||||
"github.com/cloudwego/eino/components/retriever"
|
"github.com/cloudwego/eino/components/retriever"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/os/grpool"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
"github.com/pgvector/pgvector-go"
|
"github.com/pgvector/pgvector-go"
|
||||||
)
|
)
|
||||||
@@ -53,43 +56,139 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
|
|||||||
}
|
}
|
||||||
options = retriever.GetCommonOptions(options, opts...)
|
options = retriever.GetCommonOptions(options, opts...)
|
||||||
|
|
||||||
|
// 安全保护:防止 nil 指针 panic
|
||||||
|
topK := 10
|
||||||
|
if options.TopK != nil {
|
||||||
|
topK = *options.TopK
|
||||||
|
}
|
||||||
|
|
||||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||||
Query: query,
|
Query: query,
|
||||||
TopK: *options.TopK,
|
TopK: *options.TopK,
|
||||||
})
|
})
|
||||||
|
|
||||||
// ==========================================
|
// ==========================================
|
||||||
// 🔥 双路检索:向量 + 全文
|
// 🔥 优化版:grpool 并行双路检索(安全、健壮、无泄漏)
|
||||||
// ==========================================
|
// ==========================================
|
||||||
docsVector, err := r.doRetrieveVector(ctx, query, options)
|
var (
|
||||||
if err != nil {
|
docsVector []*schema.Document
|
||||||
callbacks.OnError(ctx, err)
|
docsFulltext []*schema.Document
|
||||||
|
errVector error
|
||||||
|
errFulltext error
|
||||||
|
|
||||||
|
// 缓冲通道=2,确保无死锁等待
|
||||||
|
done = make(chan struct{}, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 上下文:超时 + 可取消双保障(建议5s超时,根据业务调整)
|
||||||
|
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 封装并行任务函数,消除重复代码
|
||||||
|
runTask := func(task func() error, errTarget *error) {
|
||||||
|
defer func() {
|
||||||
|
// 任务结束必发信号,确保通道不阻塞
|
||||||
|
done <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 捕获 panic + 执行业务逻辑
|
||||||
|
g.TryCatch(taskCtx, func(ctx context.Context) {
|
||||||
|
*errTarget = task()
|
||||||
|
}, func(ctx context.Context, panicErr error) {
|
||||||
|
*errTarget = panicErr
|
||||||
|
})
|
||||||
|
|
||||||
|
// 任务失败:立即取消另一个任务(快速失败)
|
||||||
|
if *errTarget != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------
|
||||||
|
// 并行提交两个检索任务
|
||||||
|
// ----------------------
|
||||||
|
// 任务1:向量检索
|
||||||
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
|
runTask(func() error {
|
||||||
|
docsVector, errVector = r.doRetrieveVector(ctx, query, options)
|
||||||
|
return errVector
|
||||||
|
}, &errVector)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 任务2:全文检索
|
||||||
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
|
runTask(func() error {
|
||||||
|
docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options)
|
||||||
|
return errFulltext
|
||||||
|
}, &errFulltext)
|
||||||
|
})
|
||||||
|
|
||||||
|
// ----------------------
|
||||||
|
// 安全等待所有任务完成
|
||||||
|
// ----------------------
|
||||||
|
<-done
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// ----------------------
|
||||||
|
// 统一错误处理
|
||||||
|
// ----------------------
|
||||||
|
// 用 errors.Join 合并所有错误,不丢失信息
|
||||||
|
if err := errors.Join(errVector, errFulltext); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
|
// 合并 + 智能去重(保留最优分数)
|
||||||
if err != nil {
|
|
||||||
callbacks.OnError(ctx, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 合并 + 去重
|
|
||||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||||
|
|
||||||
// 排序(distance 越小越靠前)
|
// 排序:向量优先,同类型按距离升序
|
||||||
sort.Slice(docs, func(i, j int) bool {
|
sort.Slice(docs, func(i, j int) bool {
|
||||||
|
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
|
||||||
|
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
|
||||||
|
//
|
||||||
|
//// 有类型标记的优先
|
||||||
|
//if okI && !okJ {
|
||||||
|
// return true
|
||||||
|
//}
|
||||||
|
//if !okI && okJ {
|
||||||
|
// return false
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//// 向量永远排前面
|
||||||
|
//if byI == "vector" && byJ == "fulltext" {
|
||||||
|
// return true
|
||||||
|
//}
|
||||||
|
//if byI == "fulltext" && byJ == "vector" {
|
||||||
|
// return false
|
||||||
|
//}
|
||||||
|
|
||||||
|
// 同类型按 distance 升序(越小越相似)
|
||||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||||
return d1 < d2
|
return d1 < d2
|
||||||
})
|
})
|
||||||
|
|
||||||
// 最多保留 topK
|
// 在Retrieve方法末尾,增加相关性校验
|
||||||
if len(docs) > *options.TopK {
|
validDocs := make([]*schema.Document, 0)
|
||||||
docs = docs[:*options.TopK]
|
for i, d := range docs {
|
||||||
|
// 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃)
|
||||||
|
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
|
||||||
|
validDocs = append(validDocs, d)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
|
// 如果没有有效结果,返回空,让LLM回答「暂无相关信息」
|
||||||
return docs, nil
|
if len(validDocs) == 0 {
|
||||||
|
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||||
|
return validDocs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最多保留 topK
|
||||||
|
if len(validDocs) > topK {
|
||||||
|
validDocs = validDocs[:topK]
|
||||||
|
}
|
||||||
|
|
||||||
|
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||||
|
return validDocs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==========================================
|
// ==========================================
|
||||||
@@ -105,7 +204,10 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||||||
topK := *opts.TopK
|
topK := 10
|
||||||
|
if opts.TopK != nil {
|
||||||
|
topK = *opts.TopK
|
||||||
|
}
|
||||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||||
|
|
||||||
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||||
@@ -144,13 +246,16 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str
|
|||||||
|
|
||||||
docs := make([]*schema.Document, 0, len(rows))
|
docs := make([]*schema.Document, 0, len(rows))
|
||||||
for _, row := range rows {
|
for _, row := range rows {
|
||||||
|
score := gconv.Float64(row["_rankingScore"])
|
||||||
|
distance := score
|
||||||
|
|
||||||
docs = append(docs, &schema.Document{
|
docs = append(docs, &schema.Document{
|
||||||
ID: gconv.String(row["id"]),
|
ID: gconv.String(row["id"]),
|
||||||
Content: gconv.String(row["content"]),
|
Content: gconv.String(row["content"]),
|
||||||
MetaData: map[string]any{
|
MetaData: map[string]any{
|
||||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||||
"document_id": gconv.Int64(row["document_id"]),
|
"document_id": gconv.Int64(row["document_id"]),
|
||||||
"distance": 0.1, // 全文结果给高分
|
"distance": distance,
|
||||||
"retrieve_by": "fulltext",
|
"retrieve_by": "fulltext",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -159,18 +264,26 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ==========================================
|
// ==========================================
|
||||||
// 合并去重
|
// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记)
|
||||||
// ==========================================
|
// ==========================================
|
||||||
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||||||
idMap := make(map[string]*schema.Document)
|
idMap := make(map[string]*schema.Document)
|
||||||
|
|
||||||
|
// 先存入向量结果
|
||||||
for _, d := range vecDocs {
|
for _, d := range vecDocs {
|
||||||
idMap[d.ID] = d
|
idMap[d.ID] = d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 再处理全文:不存在则添加;存在则标记“双路命中”,不覆盖向量分数
|
||||||
for _, d := range fullDocs {
|
for _, d := range fullDocs {
|
||||||
if _, exists := idMap[d.ID]; !exists {
|
if existDoc, ok := idMap[d.ID]; ok {
|
||||||
|
// 标记同时被向量和全文检索到
|
||||||
|
existDoc.MetaData["retrieve_by"] = "both"
|
||||||
|
} else {
|
||||||
idMap[d.ID] = d
|
idMap[d.ID] = d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
merged := make([]*schema.Document, 0, len(idMap))
|
merged := make([]*schema.Document, 0, len(idMap))
|
||||||
for _, d := range idMap {
|
for _, d := range idMap {
|
||||||
merged = append(merged, d)
|
merged = append(merged, d)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR
|
|||||||
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||||
sql := `
|
sql := `
|
||||||
SELECT id, content, dataset_id, document_id,
|
SELECT id, content, dataset_id, document_id,
|
||||||
vector <-> ? AS distance
|
vector <=> ? AS distance
|
||||||
FROM rag_vector_document_chunk
|
FROM rag_vector_document_chunk
|
||||||
WHERE dataset_id IN (?)
|
WHERE dataset_id IN (?)
|
||||||
AND vector IS NOT NULL
|
AND vector IS NOT NULL
|
||||||
@@ -84,8 +84,9 @@ func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64
|
|||||||
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||||||
// 构建 meilisearch 查询参数
|
// 构建 meilisearch 查询参数
|
||||||
searchParams := &meilisearch.SearchParams{
|
searchParams := &meilisearch.SearchParams{
|
||||||
Query: query,
|
Query: query,
|
||||||
Limit: int64(topK),
|
Limit: int64(topK),
|
||||||
|
ShowRankingScore: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建 datasetIds 过滤条件
|
// 构建 datasetIds 过滤条件
|
||||||
|
|||||||
@@ -8,14 +8,18 @@ import (
|
|||||||
type RAGQueryReq struct {
|
type RAGQueryReq struct {
|
||||||
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
|
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
|
||||||
|
|
||||||
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
|
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
|
||||||
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
|
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
|
||||||
TopK int `json:"topK" d:"5" dc:"检索topK,默认5"`
|
History []*Message `json:"history" dc:"历史对话"`
|
||||||
|
TopK int `json:"topK" d:"5" dc:"检索topK,默认5"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RAGQueryRes is the response of a RAG query.
type RAGQueryRes struct {
	// Answer is the generated answer.
	Answer string `json:"answer" dc:"生成的答案"`
}
|
||||||
|
|||||||
@@ -12,10 +12,10 @@ import (
|
|||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||||
"gitea.com/red-future/common/http"
|
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
gmq "github.com/bjang03/gmq/core/gmq"
|
gmq "github.com/bjang03/gmq/core/gmq"
|
||||||
"github.com/bjang03/gmq/mq"
|
"github.com/bjang03/gmq/mq"
|
||||||
@@ -159,12 +159,16 @@ func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentR
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
// ======================
|
// ======================
|
||||||
// 核心:grpool + g.Try 最佳实践
|
// 核心:grpool + g.Try 最佳实践
|
||||||
// ======================
|
// ======================
|
||||||
taskCtx, cancel := context.WithCancel(ctx)
|
// 使用带超时的background context,避免HTTP请求完成后context被取消
|
||||||
|
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||||
|
taskCtx = context.WithValue(taskCtx, "user", user)
|
||||||
// 任务1: SQL 切分文档
|
// 任务1: SQL 切分文档
|
||||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
g.TryCatch(ctx, func(ctx context.Context) {
|
g.TryCatch(ctx, func(ctx context.Context) {
|
||||||
@@ -655,23 +659,12 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
|
|||||||
|
|
||||||
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
|
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
|
||||||
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
|
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
|
||||||
headers := make(map[string]string)
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
for k, v := range r.Request.Header {
|
|
||||||
if len(v) > 0 {
|
|
||||||
headers[k] = v[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 调用接口获取数据
|
// 调用接口获取数据
|
||||||
d := &dto.ListDocumentChunkRPC{}
|
res, _, err := dao.DocumentChunk.List(ctx, &dto.ListDocumentChunkReq{
|
||||||
if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d,
|
DatasetId: doc.DatasetId,
|
||||||
"datasetId", gconv.String(doc.DatasetId),
|
Status: gconv.PtrInt8(1),
|
||||||
"status", 1); err != nil {
|
})
|
||||||
return
|
err = gconv.Struct(res, &dictData)
|
||||||
}
|
|
||||||
dictData = d.List
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
"github.com/cloudwego/eino/components/indexer"
|
"github.com/cloudwego/eino/components/indexer"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
@@ -48,7 +49,10 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
|
|||||||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||||
|
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]),
|
||||||
|
UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]),
|
||||||
|
})
|
||||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||||
BatchSize: 10,
|
BatchSize: 10,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ import (
|
|||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/components/retriever"
|
"github.com/cloudwego/eino/components/retriever"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/gogf/gf/v2/os/glog"
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var RAGQuery = new(ragQueryService)
|
var RAGQuery = new(ragQueryService)
|
||||||
@@ -39,14 +41,20 @@ func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto
|
|||||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
|
messages := make([]*schema.Message, 0)
|
||||||
|
err = gconv.Struct(req.History, &messages)
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "转换历史消息失败: %v", err)
|
||||||
|
return nil, fmt.Errorf("转换历史消息失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &dto.RAGQueryRes{
|
return &dto.RAGQueryRes{
|
||||||
Answer: replyMsg.Content,
|
Answer: replyMsg.Content,
|
||||||
Sources: sources,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
|
|
||||||
"rag/common/task"
|
"rag/common/task"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
@@ -24,17 +26,20 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
|
|||||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
taskVO := make([]dto.TaskVO, 0, total)
|
completed := false
|
||||||
err = gconv.Struct(t, taskVO)
|
if total != 0 {
|
||||||
if err != nil {
|
taskVO := make([]*dto.TaskVO, 0, total)
|
||||||
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
err = gconv.Struct(t, &taskVO)
|
||||||
return err
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
taskVO = append(taskVO, &dto.TaskVO{
|
||||||
|
TaskType: req.TaskType,
|
||||||
|
Status: req.Status,
|
||||||
|
})
|
||||||
|
completed = IsAllSubTasksCompleted(taskVO)
|
||||||
}
|
}
|
||||||
taskVO = append(taskVO, dto.TaskVO{
|
|
||||||
TaskType: req.TaskType,
|
|
||||||
Status: req.Status,
|
|
||||||
})
|
|
||||||
completed := IsAllSubTasksCompleted(taskVO)
|
|
||||||
|
|
||||||
// 1. 查询是否已存在该文档的该类型任务
|
// 1. 查询是否已存在该文档的该类型任务
|
||||||
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||||
@@ -45,44 +50,48 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP
|
|||||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||||
|
// 2. 如果不存在,则创建新任务
|
||||||
|
if g.IsEmpty(existTask) {
|
||||||
|
createReq := &dto.CreateTaskReq{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
TaskType: req.TaskType,
|
||||||
|
Status: req.Status,
|
||||||
|
Remark: req.Remark,
|
||||||
|
}
|
||||||
|
_, err = dao.Task.Insert(ctx, createReq)
|
||||||
|
} else {
|
||||||
|
// 3. 如果已存在,则更新任务
|
||||||
|
updateReq := &dto.UpdateTaskReq{
|
||||||
|
Id: existTask[0].Id,
|
||||||
|
Status: req.Status,
|
||||||
|
Remark: req.Remark,
|
||||||
|
}
|
||||||
|
_, err = dao.Task.Update(ctx, updateReq)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 如果不存在,则创建新任务
|
if completed {
|
||||||
if g.IsEmpty(existTask) {
|
// 3. 如果已存在,则更新任务
|
||||||
createReq := &dto.CreateTaskReq{
|
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
TaskType: req.TaskType,
|
Status: task.TaskStatusCompleted,
|
||||||
Status: req.Status,
|
Remark: "文档解析完成",
|
||||||
Remark: req.Remark,
|
})
|
||||||
}
|
}
|
||||||
_, err = dao.Task.Insert(ctx, createReq)
|
|
||||||
} else {
|
|
||||||
// 3. 如果已存在,则更新任务
|
|
||||||
updateReq := &dto.UpdateTaskReq{
|
|
||||||
Id: existTask[0].Id,
|
|
||||||
Status: req.Status,
|
|
||||||
Remark: req.Remark,
|
|
||||||
}
|
|
||||||
_, err = dao.Task.Update(ctx, updateReq)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if completed {
|
return nil
|
||||||
// 3. 如果已存在,则更新任务
|
})
|
||||||
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
|
||||||
TaskId: req.TaskId,
|
|
||||||
Status: task.TaskStatusCompleted,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||||
// 参数:传入当前文档的所有子任务列表
|
// 参数:传入当前文档的所有子任务列表
|
||||||
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
|
func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
|
||||||
// 必须包含 3 种任务类型
|
// 必须包含 3 种任务类型
|
||||||
hasKeywords := false
|
hasKeywords := false
|
||||||
hasVector := false
|
hasVector := false
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ COMMENT ON COLUMN rag_knowledge_keyword.dataset_id IS '数据集ID';
|
|||||||
COMMENT ON COLUMN rag_knowledge_keyword.document_id IS '文档ID';
|
COMMENT ON COLUMN rag_knowledge_keyword.document_id IS '文档ID';
|
||||||
COMMENT ON COLUMN rag_knowledge_keyword.word IS '关键词';
|
COMMENT ON COLUMN rag_knowledge_keyword.word IS '关键词';
|
||||||
COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
|
COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
|
||||||
|
CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word ON rag_knowledge_keyword (tenant_id, dataset_id, document_id, word);
|
||||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||||
|
|
||||||
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user