diff --git a/common/eino/chat_model.go b/common/eino/chat_model.go index 4d9e4b7..9f36db0 100644 --- a/common/eino/chat_model.go +++ b/common/eino/chat_model.go @@ -14,91 +14,122 @@ import ( "github.com/gogf/gf/v2/util/gconv" ) -var globalChatModel *qwen.ChatModel +const ( + MaxHistoryTurns = 5 // 最大历史轮数 +) + +var ( + globalChatModel *qwen.ChatModel + ragPromptTemplate prompt.ChatTemplate // EINO 官方模板 +) func init() { ctx := context.Background() + // 初始化大模型 + if err := initChatModel(ctx); err != nil { + glog.Errorf(ctx, "初始化大模型失败: %v", err) + } + // 初始化 EINO 提示词模板 + initRAGPromptTemplate() + + return +} + +// 初始化通义千问 +func initChatModel(ctx context.Context) error { + if globalChatModel != nil { + return nil + } apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String() model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String() - var err error - globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{ + cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{ APIKey: apiKey, Model: model, BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + Timeout: 60 * 1e9, Temperature: gconv.PtrFloat32(0.7), // 客服最佳 MaxTokens: gconv.PtrInt(1024), // 最长回答 TopP: gconv.PtrFloat32(1.0), }) if err != nil { - glog.Errorf(ctx, "初始化大模型失败: %v", err) + return err } - return + globalChatModel = cm + return nil } -// NewChatModel 只处理逻辑,不复用创建模型 -func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) { - // 1. 构建参考知识 - knowledge, sources := buildKnowledgeAndSources(docs) - - // 2. 构建提示词 - msgs, err := buildPromptMessages(ctx, knowledge, content) - if err != nil { - return - } - - // 3. 🔥 直接使用全局单例,不重复创建 - replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs) - - return -} - -// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源 -func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) { - var knowledge string - var sources []string - - for i, doc := range docs { - knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content) - - // 提取 document_id - if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 { - sources = append(sources, gconv.String(docID)) - } - } - return knowledge, sources -} - -// buildPromptMessages 构建提示词模板 -func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) { - promptTpl := prompt.FromMessages( +// 初始化 EINO 官方提示词模板(最关键!) +func initRAGPromptTemplate() { + ragPromptTemplate = prompt.FromMessages( schema.FString, + // 系统提示(带参考知识) &schema.Message{ Role: schema.System, - // Content: `你是专业的客服助手,语气友好。 - //如果参考知识中有相关信息,请优先依据参考知识回答。 - //如果没有相关信息,就正常回答,不要说无法回答。 - // - //参考知识: - //{knowledge}`, - Content: `你是专业的客服助手,语气友好。 - 请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。 - - 参考知识: - {knowledge}`, + Content: `你是专业客服,语气友好简洁。 +请严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。 + +参考知识: +{knowledge}`, }, + // 用户问题 &schema.Message{ Role: schema.User, Content: "{question}", }, ) +} - return promptTpl.Format(ctx, map[string]any{ +// NewChatModel 只处理逻辑,不复用创建模型 +func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) { + // 1. 构建参考知识 + knowledge := buildKnowledgeAndSources(docs) + // 2. 历史精简 + history = limitHistory(history) + // 3. ✅ EINO 官方模板格式化(超级干净) + msgs, err := ragPromptTemplate.Format(ctx, map[string]any{ "knowledge": knowledge, "question": question, }) + if err != nil { + return nil, err + } + // 4. 历史插入到模板消息中间(标准EINO用法) + if len(history) > 0 { + msgs = append(msgs[:1], append(history, msgs[1:]...)...) + } + // 5. 🔥 直接使用全局单例,不重复创建 + replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs) + + return +} + +func limitHistory(history []*schema.Message) []*schema.Message { + valid := make([]*schema.Message, 0, len(history)) + for _, m := range history { + if m.Role == schema.User || m.Role == schema.Assistant { + valid = append(valid, m) + } + } + + keep := 2 * MaxHistoryTurns + if len(valid) > keep { + valid = valid[len(valid)-keep:] + } + return valid +} + +// buildKnowledgeAndSources 拼接参考知识 +func buildKnowledgeAndSources(docs []*schema.Document) string { + var knowledge string + + for i, doc := range docs { + knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content) + + } + return knowledge } // streamGenerateAnswer 流式生成 diff --git a/common/eino/retriever.go b/common/eino/retriever.go index 84b1505..f7750ae 100644 --- a/common/eino/retriever.go +++ b/common/eino/retriever.go @@ -5,11 +5,14 @@ import ( "errors" "rag/dao" "sort" + "time" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" "github.com/cloudwego/eino/schema" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/grpool" "github.com/gogf/gf/v2/util/gconv" "github.com/pgvector/pgvector-go" ) @@ -53,43 +56,139 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ... } options = retriever.GetCommonOptions(options, opts...) + // 安全保护:防止 nil 指针 panic + topK := 10 + if options.TopK != nil { + topK = *options.TopK + } + ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ Query: query, TopK: *options.TopK, }) // ========================================== - // 🔥 双路检索:向量 + 全文 + // 🔥 优化版:grpool 并行双路检索(安全、健壮、无泄漏) // ========================================== - docsVector, err := r.doRetrieveVector(ctx, query, options) - if err != nil { - callbacks.OnError(ctx, err) + var ( + docsVector []*schema.Document + docsFulltext []*schema.Document + errVector error + errFulltext error + + // 缓冲通道=2,确保无死锁等待 + done = make(chan struct{}, 2) + ) + + // 上下文:超时 + 可取消双保障(建议5s超时,根据业务调整) + taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + // 封装并行任务函数,消除重复代码 + runTask := func(task func() error, errTarget *error) { + defer func() { + // 任务结束必发信号,确保通道不阻塞 + done <- struct{}{} + }() + + // 捕获 panic + 执行业务逻辑 + g.TryCatch(taskCtx, func(ctx context.Context) { + *errTarget = task() + }, func(ctx context.Context, panicErr error) { + *errTarget = panicErr + }) + + // 任务失败:立即取消另一个任务(快速失败) + if *errTarget != nil { + cancel() + } + } + + // ---------------------- + // 并行提交两个检索任务 + // ---------------------- + // 任务1:向量检索 + grpool.Add(taskCtx, func(ctx context.Context) { + runTask(func() error { + docsVector, errVector = r.doRetrieveVector(ctx, query, options) + return errVector + }, &errVector) + }) + + // 任务2:全文检索 + grpool.Add(taskCtx, func(ctx context.Context) { + runTask(func() error { + docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options) + return errFulltext + }, &errFulltext) + }) + + // ---------------------- + // 安全等待所有任务完成 + // ---------------------- + <-done + <-done + + // ---------------------- + // 统一错误处理 + // ---------------------- + // 用 errors.Join 合并所有错误,不丢失信息 + if err := errors.Join(errVector, errFulltext); err != nil { return nil, err } - docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options) - if err != nil { - callbacks.OnError(ctx, err) - return nil, err - } - - // 合并 + 去重 + // 合并 + 智能去重(保留最优分数) docs := mergeAndDeduplicate(docsVector, docsFulltext) - // 排序(distance 越小越靠前) + // 排序:向量优先,同类型按距离升序 sort.Slice(docs, func(i, j int) bool { + //byI, okI := docs[i].MetaData["retrieve_by"].(string) + //byJ, okJ := docs[j].MetaData["retrieve_by"].(string) + // + //// 有类型标记的优先 + //if okI && !okJ { + // return true + //} + //if !okI && okJ { + // return false + //} + // + //// 向量永远排前面 + //if byI == "vector" && byJ == "fulltext" { + // return true + //} + //if byI == "fulltext" && byJ == "vector" { + // return false + //} + + // 同类型按 distance 升序(越小越相似) d1 := gconv.Float64(docs[i].MetaData["distance"]) d2 := gconv.Float64(docs[j].MetaData["distance"]) return d1 < d2 }) - // 最多保留 topK - if len(docs) > *options.TopK { - docs = docs[:*options.TopK] + // 在Retrieve方法末尾,增加相关性校验 + validDocs := make([]*schema.Document, 0) + for i, d := range docs { + // 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃) + if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 { + validDocs = append(validDocs, d) + } } - callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs}) - return docs, nil + // 如果没有有效结果,返回空,让LLM回答「暂无相关信息」 + if len(validDocs) == 0 { + callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs}) + return validDocs, nil + } + + // 最多保留 topK + if len(validDocs) > topK { + validDocs = validDocs[:topK] + } + + callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs}) + return validDocs, nil } // ========================================== @@ -105,7 +204,10 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, } queryVec := pgvector.NewVector(gconv.Float32s(vectors[0])) - topK := *opts.TopK + topK := 10 + if opts.TopK != nil { + topK = *opts.TopK + } datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"]) rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK) @@ -144,13 +246,16 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str docs := make([]*schema.Document, 0, len(rows)) for _, row := range rows { + score := gconv.Float64(row["_rankingScore"]) + distance := score + docs = append(docs, &schema.Document{ ID: gconv.String(row["id"]), Content: gconv.String(row["content"]), MetaData: map[string]any{ "dataset_id": gconv.Int64(row["dataset_id"]), "document_id": gconv.Int64(row["document_id"]), - "distance": 0.1, // 全文结果给高分 + "distance": distance, "retrieve_by": "fulltext", }, }) @@ -159,18 +264,26 @@ func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query str } // ========================================== -// 合并去重 +// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记) // ========================================== func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document { idMap := make(map[string]*schema.Document) + + // 先存入向量结果 for _, d := range vecDocs { idMap[d.ID] = d } + + // 再处理全文:不存在则添加;存在则标记“双路命中”,不覆盖向量分数 for _, d := range fullDocs { - if _, exists := idMap[d.ID]; !exists { + if existDoc, ok := idMap[d.ID]; ok { + // 标记同时被向量和全文检索到 + existDoc.MetaData["retrieve_by"] = "both" + } else { idMap[d.ID] = d } } + merged := make([]*schema.Document, 0, len(idMap)) for _, d := range idMap { merged = append(merged, d) diff --git a/dao/document_chunk.go b/dao/document_chunk.go index 218cf70..577c8f3 100644 --- a/dao/document_chunk.go +++ b/dao/document_chunk.go @@ -64,7 +64,7 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) { sql := ` SELECT id, content, dataset_id, document_id, - vector <-> ? AS distance + vector <=> ? AS distance FROM rag_vector_document_chunk WHERE dataset_id IN (?) AND vector IS NOT NULL @@ -84,8 +84,9 @@ func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64 func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) { // 构建 meilisearch 查询参数 searchParams := &meilisearch.SearchParams{ - Query: query, - Limit: int64(topK), + Query: query, + Limit: int64(topK), + ShowRankingScore: true, } // 构建 datasetIds 过滤条件 diff --git a/model/dto/rag_query.go b/model/dto/rag_query.go index 8682c03..9682aad 100644 --- a/model/dto/rag_query.go +++ b/model/dto/rag_query.go @@ -8,14 +8,18 @@ import ( type RAGQueryReq struct { g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"` - Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"` - DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"` - TopK int `json:"topK" d:"5" dc:"检索topK,默认5"` + Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"` + DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"` + History []*Message `json:"history" dc:"历史对话"` + TopK int `json:"topK" d:"5" dc:"检索topK,默认5"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` } // RAGQueryRes RAG查询响应 type RAGQueryRes struct { - Answer string `json:"answer" dc:"生成的答案"` - DatasetId string `json:"datasetId" dc:"使用的数据集ID"` - Sources []string `json:"sources" dc:"参考来源"` + Answer string `json:"answer" dc:"生成的答案"` } diff --git a/service/document.go b/service/document.go index a78448f..a00fd06 100644 --- a/service/document.go +++ b/service/document.go @@ -12,10 +12,10 @@ import ( "rag/model/dto" "rag/model/entity" "strings" + "time" "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/full-text-search/meilisearch" - "gitea.com/red-future/common/http" "gitea.com/red-future/common/utils" gmq "github.com/bjang03/gmq/core/gmq" "github.com/bjang03/gmq/mq" @@ -159,12 +159,16 @@ func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentR if err != nil { return } - + user, err := utils.GetUserInfo(ctx) + if err != nil { + return err + } // ====================== // 核心:grpool + g.Try 最佳实践 // ====================== - taskCtx, cancel := context.WithCancel(ctx) - + // 使用带超时的background context,避免HTTP请求完成后context被取消 + taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + taskCtx = context.WithValue(taskCtx, "user", user) // 任务1: SQL 切分文档 grpool.Add(taskCtx, func(ctx context.Context) { g.TryCatch(ctx, func(ctx context.Context) { @@ -655,23 +659,12 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume // getHistoryDataFromHttp 通过 HTTP 接口查询历史数据 func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) { - headers := make(map[string]string) - if r := g.RequestFromCtx(ctx); r != nil { - for k, v := range r.Request.Header { - if len(v) > 0 { - headers[k] = v[0] - } - } - } - // 调用接口获取数据 - d := &dto.ListDocumentChunkRPC{} - if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d, - "datasetId", gconv.String(doc.DatasetId), - "status", 1); err != nil { - return - } - dictData = d.List + res, _, err := dao.DocumentChunk.List(ctx, &dto.ListDocumentChunkReq{ + DatasetId: doc.DatasetId, + Status: gconv.PtrInt8(1), + }) + err = gconv.Struct(res, &dictData) return } diff --git a/service/document_chunk.go b/service/document_chunk.go index 4dd3c11..d9d4187 100644 --- a/service/document_chunk.go +++ b/service/document_chunk.go @@ -8,6 +8,7 @@ import ( "rag/model/dto" "rag/model/entity" + "gitea.com/red-future/common/beans" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/frame/g" @@ -48,7 +49,10 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty") return } - + ctx = context.WithValue(ctx, "user", &beans.User{ + TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]), + UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]), + }) idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{ BatchSize: 10, }) diff --git a/service/rag_query.go b/service/rag_query.go index f802167..92ffe7b 100644 --- a/service/rag_query.go +++ b/service/rag_query.go @@ -7,7 +7,9 @@ import ( "rag/model/dto" "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/util/gconv" ) var RAGQuery = new(ragQueryService) @@ -39,14 +41,20 @@ func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto return nil, fmt.Errorf("向量检索失败: %w", err) } - replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs) + messages := make([]*schema.Message, 0) + err = gconv.Struct(req.History, &messages) + if err != nil { + glog.Errorf(ctx, "转换历史消息失败: %v", err) + return nil, fmt.Errorf("转换历史消息失败: %w", err) + } + + replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages) if err != nil { glog.Errorf(ctx, "向量检索失败: %v", err) return nil, fmt.Errorf("向量检索失败: %w", err) } return &dto.RAGQueryRes{ - Answer: replyMsg.Content, - Sources: sources, + Answer: replyMsg.Content, }, nil } diff --git a/service/task.go b/service/task.go index b039497..9d16e72 100644 --- a/service/task.go +++ b/service/task.go @@ -7,6 +7,8 @@ import ( "rag/common/task" + "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) @@ -24,17 +26,20 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP g.Log().Errorf(ctx, "查询任务失败: %v", err) return err } - taskVO := make([]dto.TaskVO, 0, total) - err = gconv.Struct(t, taskVO) - if err != nil { - g.Log().Errorf(ctx, "转换任务失败: %v", err) - return err + completed := false + if total != 0 { + taskVO := make([]*dto.TaskVO, 0, total) + err = gconv.Struct(t, &taskVO) + if err != nil { + g.Log().Errorf(ctx, "转换任务失败: %v", err) + return err + } + taskVO = append(taskVO, &dto.TaskVO{ + TaskType: req.TaskType, + Status: req.Status, + }) + completed = IsAllSubTasksCompleted(taskVO) } - taskVO = append(taskVO, dto.TaskVO{ - TaskType: req.TaskType, - Status: req.Status, - }) - completed := IsAllSubTasksCompleted(taskVO) // 1. 查询是否已存在该文档的该类型任务 existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{ @@ -45,44 +50,48 @@ func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskP g.Log().Errorf(ctx, "查询任务失败: %v", err) return err } + err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { + // 2. 如果不存在,则创建新任务 + if g.IsEmpty(existTask) { + createReq := &dto.CreateTaskReq{ + TaskId: req.TaskId, + TaskType: req.TaskType, + Status: req.Status, + Remark: req.Remark, + } + _, err = dao.Task.Insert(ctx, createReq) + } else { + // 3. 如果已存在,则更新任务 + updateReq := &dto.UpdateTaskReq{ + Id: existTask[0].Id, + Status: req.Status, + Remark: req.Remark, + } + _, err = dao.Task.Update(ctx, updateReq) + if err != nil { + g.Log().Errorf(ctx, "更新任务失败: %v", err) + return err + } + } - // 2. 如果不存在,则创建新任务 - if g.IsEmpty(existTask) { - createReq := &dto.CreateTaskReq{ - TaskId: req.TaskId, - TaskType: req.TaskType, - Status: req.Status, - Remark: req.Remark, + if completed { + // 3. 如果已存在,则更新任务 + _, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{ + TaskId: req.TaskId, + Status: task.TaskStatusCompleted, + Remark: "文档解析完成", + }) } - _, err = dao.Task.Insert(ctx, createReq) - } else { - // 3. 如果已存在,则更新任务 - updateReq := &dto.UpdateTaskReq{ - Id: existTask[0].Id, - Status: req.Status, - Remark: req.Remark, - } - _, err = dao.Task.Update(ctx, updateReq) - if err != nil { - g.Log().Errorf(ctx, "更新任务失败: %v", err) - return err - } - } - if completed { - // 3. 如果已存在,则更新任务 - _, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{ - TaskId: req.TaskId, - Status: task.TaskStatusCompleted, - }) - } + return nil + }) return } // IsAllSubTasksCompleted 判断三个子任务是否全部完成 // 参数:传入当前文档的所有子任务列表 -func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool { +func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool { // 必须包含 3 种任务类型 hasKeywords := false hasVector := false diff --git a/update.sql b/update.sql index 69decb3..d3972eb 100644 --- a/update.sql +++ b/update.sql @@ -159,7 +159,7 @@ COMMENT ON COLUMN rag_knowledge_keyword.dataset_id IS '数据集ID'; COMMENT ON COLUMN rag_knowledge_keyword.document_id IS '文档ID'; COMMENT ON COLUMN rag_knowledge_keyword.word IS '关键词'; COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重'; - +CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word ON rag_knowledge_keyword (tenant_id, dataset_id, document_id, word); --------------------pgsql创建rag_knowledge_keyword表语句--------------------------- --------------------pgsql创建rag_knowledge_task表语句---------------------------