package eino

import (
	"context"
	"errors"
	"fmt"
	"sort"
	"time"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/grpool"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"

	"rag/consts/model"
	"rag/dao"
)

// PGVectorRetrieverConfig carries the construction-time defaults for a
// PGVectorRetriever.
type PGVectorRetrieverConfig struct {
	// Embedder is not read by NewPGVectorRetriever, which resolves the
	// embedder through GetTenantEmbedderByType instead.
	Embedder embedding.Embedder
	// DefaultTopK is the default number of documents to return; values <= 0
	// are replaced by 5.
	DefaultTopK int
	// DefaultIndex is the default index name handed to the retriever options.
	DefaultIndex string
	// DSLInfo holds default filters; the keys "dataset_ids" and
	// "document_ids" are read by the retrieval paths.
	DSLInfo map[string]any
}

// PGVectorRetriever is a hybrid retriever that queries the vector store and
// the keyword search backend in parallel, merges both result sets and then
// ranks and filters them.
type PGVectorRetriever struct {
	embedder embedding.Embedder
	topK     int
	index    string
	dslInfo  map[string]any
	reranker *DashScopeReranker // optional DashScope cross-encoder reranker
}

// NewPGVectorRetriever builds a PGVectorRetriever from config. The embedder is
// resolved via GetTenantEmbedderByType for the given configType.
//
// It returns an error when config is nil or when the embedder cannot be
// resolved. config itself is never modified.
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
	const fallbackTopK = 5

	if config == nil {
		return nil, errors.New("pgvector retriever config is nil")
	}

	// Work on a local copy so the caller's config is left untouched.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = fallbackTopK
	}

	e, err := GetTenantEmbedderByType(ctx, configType)
	if err != nil {
		return nil, err
	}

	return &PGVectorRetriever{
		embedder: e,
		topK:     topK,
		index:    config.DefaultIndex,
		dslInfo:  config.DSLInfo,
		// reranker: NewDashScopeReranker(), // enable to rerank merged results
	}, nil
}

// Retrieve runs vector and full-text retrieval in parallel for query, merges
// and de-duplicates the hits, ranks them (reranker if configured, otherwise
// ascending "distance"), drops hits whose distance exceeds 0.8 and returns at
// most TopK documents.
//
// Any failure of either retrieval path, or of the reranker, fails the whole
// call; the error is also reported through callbacks.OnError.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	const maxDistance = 0.8

	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &r.index,
		TopK:      &r.topK,
		DSLInfo:   r.dslInfo,
		Embedding: r.embedder,
	}, opts...)

	// Resolve TopK once (nil or non-positive falls back to the default) and
	// store it back so both retrieval paths see the same, always non-nil value.
	topK := resolveTopK(options.TopK)
	options.TopK = &topK

	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  topK,
	})

	docsVector, docsFulltext, err := r.retrieveBothPaths(ctx, query, options)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)

	finalDocs := mergedDocs
	if r.reranker != nil {
		// Cross-encoder reranking decides the final order.
		ranked, rerankErr := r.reranker.Rerank(ctx, query, mergedDocs)
		if rerankErr != nil {
			wrapped := fmt.Errorf("rerank failed: %w", rerankErr)
			callbacks.OnError(ctx, wrapped)
			return nil, wrapped
		}
		finalDocs = ranked
	} else {
		// Smaller distance means a closer match. The stable sort keeps the
		// merge order for ties so results are reproducible.
		sort.SliceStable(mergedDocs, func(i, j int) bool {
			return gconv.Float64(mergedDocs[i].MetaData["distance"]) <
				gconv.Float64(mergedDocs[j].MetaData["distance"])
		})
	}

	// Keep the first topK documents that pass the distance threshold.
	// NOTE(review): this also filters reranked documents by the "distance"
	// metadata; whether the reranker updates that key is not visible here.
	validDocs := make([]*schema.Document, 0, len(finalDocs))
	for _, doc := range finalDocs {
		if len(validDocs) == topK {
			break
		}
		if doc == nil {
			continue
		}
		if gconv.Float64(doc.MetaData["distance"]) <= maxDistance {
			validDocs = append(validDocs, doc)
		}
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
	return validDocs, nil
}

// retrieveBothPaths runs the vector and the full-text retrieval concurrently
// on the shared grpool, bounded by a 5 second timeout. The first failure
// cancels the shared context so the other path can stop early. It always
// waits for both tasks before returning and joins their errors.
func (r *PGVectorRetriever) retrieveBothPaths(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, []*schema.Document, error) {
	const retrieveTimeout = 5 * time.Second

	taskCtx, cancel := context.WithTimeout(ctx, retrieveTimeout)
	defer cancel()

	var (
		docsVector   []*schema.Document
		docsFulltext []*schema.Document
		errVector    error
		errFulltext  error
		// Buffered for both tasks so a signal never blocks its sender.
		done = make(chan struct{}, 2)
	)

	// submit schedules task on the pool and guarantees exactly one signal on
	// done per call, whether the task ran, panicked, or was never accepted.
	submit := func(errTarget *error, task func(ctx context.Context) error) {
		run := func(ctx context.Context) {
			defer func() { done <- struct{}{} }()

			// Convert a panic inside the task into an ordinary error.
			g.TryCatch(taskCtx, func(context.Context) {
				*errTarget = task(ctx)
			}, func(_ context.Context, panicErr error) {
				*errTarget = panicErr
			})

			// Fail fast: stop the sibling task as soon as one fails.
			if *errTarget != nil {
				cancel()
			}
		}

		if addErr := grpool.Add(taskCtx, run); addErr != nil {
			// The task will never run, so record the failure and signal
			// here; otherwise the receive below would block forever.
			*errTarget = fmt.Errorf("submit retrieval task: %w", addErr)
			cancel()
			done <- struct{}{}
		}
	}

	// Task 1: vector retrieval.
	submit(&errVector, func(ctx context.Context) error {
		docs, err := r.doRetrieveVector(ctx, query, opts)
		docsVector = docs
		return err
	})
	// Task 2: full-text retrieval.
	submit(&errFulltext, func(ctx context.Context) error {
		docs, err := r.doRetrieveMeilisearch(ctx, query, opts)
		docsFulltext = docs
		return err
	})

	// Wait for both tasks; the channel receives also order the writes above
	// before the reads below.
	<-done
	<-done

	if err := errors.Join(errVector, errFulltext); err != nil {
		return nil, nil, err
	}
	return docsVector, docsFulltext, nil
}

// resolveTopK returns *topK, or 10 when topK is nil or not positive.
func resolveTopK(topK *int) int {
	const fallbackTopK = 10

	if topK == nil || *topK <= 0 {
		return fallbackTopK
	}
	return *topK
}

// dslFilterIDs extracts the optional "dataset_ids" and "document_ids" filters
// from dsl. A missing or empty entry yields a nil slice.
func dslFilterIDs(dsl map[string]any) (datasetIDs, documentIDs []int64) {
	if v := dsl["dataset_ids"]; !g.IsEmpty(v) {
		datasetIDs = gconv.Int64s(v)
	}
	if v := dsl["document_ids"]; !g.IsEmpty(v) {
		documentIDs = gconv.Int64s(v)
	}
	return datasetIDs, documentIDs
}

// doRetrieveVector embeds query and fetches the nearest rows from the vector
// store via dao.DocumentVector.GetAllByVector, restricted by the
// dataset/document filters in opts.DSLInfo. Each hit is tagged with
// retrieve_by = "vector".
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	if opts.Embedding == nil {
		return nil, errors.New("embedder is nil")
	}

	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 {
		return nil, errors.New("empty query vector")
	}
	queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))

	topK := resolveTopK(opts.TopK)
	datasetIDs, documentIDs := dslFilterIDs(opts.DSLInfo)

	rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIDs, documentIDs, queryVec, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    gconv.Float64(row["distance"]),
				"retrieve_by": "vector",
			},
		})
	}
	return docs, nil
}

// doRetrieveMeilisearch runs the keyword search via
// dao.DocumentVector.SearchByKeywords, restricted by the dataset/document
// filters in opts.DSLInfo. Each hit is tagged with retrieve_by = "fulltext".
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	topK := resolveTopK(opts.TopK)
	datasetIDs, documentIDs := dslFilterIDs(opts.DSLInfo)

	rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIDs, documentIDs, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		// Retrieve sorts by ascending "distance" and drops values above its
		// threshold, i.e. smaller is better. A Meilisearch ranking score is
		// the opposite (1 is a perfect match), so invert it.
		// NOTE(review): assumes _rankingScore lies in [0, 1] - confirm the
		// DAO requests it from Meilisearch.
		score := gconv.Float64(row["_rankingScore"])
		distance := 1 - score

		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    distance,
				"retrieve_by": "fulltext",
			},
		})
	}
	return docs, nil
}

// mergeAndDeduplicate merges vector and full-text hits by document ID.
//
// Vector hits come first, followed by full-text hits not already present, each
// in input order. When a full-text hit matches a vector hit, the vector
// document is kept (including its distance) and its retrieve_by metadata is
// set to "both". Nil documents are skipped.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	merged := make([]*schema.Document, 0, len(vecDocs)+len(fullDocs))

	vecByID := make(map[string]*schema.Document, len(vecDocs))
	for _, d := range vecDocs {
		if d == nil {
			continue
		}
		if _, dup := vecByID[d.ID]; dup {
			continue
		}
		vecByID[d.ID] = d
		merged = append(merged, d)
	}

	fullSeen := make(map[string]struct{}, len(fullDocs))
	for _, d := range fullDocs {
		if d == nil {
			continue
		}
		if hit, ok := vecByID[d.ID]; ok {
			// Found by both paths: flag it but keep the vector distance.
			if hit.MetaData == nil {
				hit.MetaData = make(map[string]any, 1)
			}
			hit.MetaData["retrieve_by"] = "both"
			continue
		}
		if _, dup := fullSeen[d.ID]; dup {
			continue
		}
		fullSeen[d.ID] = struct{}{}
		merged = append(merged, d)
	}

	return merged
}