package eino

import (
	"context"
	"errors"
	"rag/dao"
	"sort"
	"time"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/grpool"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)

// PGVectorRetrieverConfig holds the construction parameters for a
// PGVectorRetriever.
type PGVectorRetrieverConfig struct {
	// Embedder turns the query text into a vector. Required.
	Embedder embedding.Embedder
	// DefaultTopK is the result limit used when a call does not override it.
	// Non-positive values are replaced by 5.
	DefaultTopK int
	// DefaultIndex is passed through as the default retriever index option.
	DefaultIndex string
	// DSLInfo is passed through as the default DSL info; the retriever reads
	// the "dataset_ids" key from it.
	DSLInfo map[string]any
}

// PGVectorRetriever is a hybrid retriever that queries the vector store and
// the keyword search DAO in parallel and merges both result sets.
type PGVectorRetriever struct {
	embedder embedding.Embedder
	topK     int
	index    string
	dslInfo  map[string]any
}

// NewPGVectorRetriever validates config and builds a retriever from it.
//
// It returns an error when config is nil or has no Embedder. The supplied
// config is never modified.
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
	const defaultTopK = 5

	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.Embedder == nil {
		return nil, errors.New("embedder is required")
	}

	// Resolve the default locally instead of writing it back into the
	// caller's config.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = defaultTopK
	}

	return &PGVectorRetriever{
		embedder: config.Embedder,
		topK:     topK,
		index:    config.DefaultIndex,
		dslInfo:  config.DSLInfo,
	}, nil
}

// resolveTopK returns the effective result limit for opts, falling back to 10
// when the option is missing or non-positive (a negative value would otherwise
// panic when the result slice is truncated).
func resolveTopK(opts *retriever.Options) int {
	const fallbackTopK = 10

	if opts == nil || opts.TopK == nil || *opts.TopK <= 0 {
		return fallbackTopK
	}
	return *opts.TopK
}

// Retrieve runs the vector search and the keyword search concurrently, merges
// and de-duplicates their results, sorts them by ascending distance, drops
// everything with a distance of 0.8 or more, and returns at most TopK
// documents.
//
// If either search fails (or panics), the other one is cancelled and the
// joined error is returned. When nothing passes the distance filter an empty,
// non-nil slice is returned with a nil error.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	const (
		// retrieveTimeout bounds both searches together.
		retrieveTimeout = 5 * time.Second
		// maxDistance is the exclusive relevance cut-off.
		maxDistance = 0.8
	)

	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &r.index,
		TopK:      &r.topK,
		DSLInfo:   r.dslInfo,
		Embedding: r.embedder,
	}, opts...)

	topK := resolveTopK(options)

	// Report the resolved topK; dereferencing options.TopK here would panic
	// when an option cleared it.
	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  topK,
	})

	var (
		docsVector   []*schema.Document
		docsFulltext []*schema.Document
		errVector    error
		errFulltext  error
		// Buffered for both tasks so a finishing task never blocks on send.
		done = make(chan struct{}, 2)
	)

	taskCtx, cancel := context.WithTimeout(ctx, retrieveTimeout)
	defer cancel()

	// submit schedules task on the goroutine pool. Exactly one signal is sent
	// on done per call, whether the task ran, panicked, or could not be
	// scheduled at all.
	submit := func(task func(ctx context.Context) error, errTarget *error) {
		addErr := grpool.Add(taskCtx, func(ctx context.Context) {
			defer func() { done <- struct{}{} }()

			// Convert a panic inside the task into an ordinary error.
			g.TryCatch(ctx, func(ctx context.Context) {
				*errTarget = task(ctx)
			}, func(ctx context.Context, panicErr error) {
				*errTarget = panicErr
			})

			// Fail fast: stop the sibling task as soon as one fails.
			if *errTarget != nil {
				cancel()
			}
		})
		if addErr != nil {
			// The pool rejected the job, so the closure above will never
			// run; signal completion here to avoid waiting forever.
			*errTarget = addErr
			cancel()
			done <- struct{}{}
		}
	}

	// Task 1: vector search.
	submit(func(ctx context.Context) error {
		docs, err := r.doRetrieveVector(ctx, query, options)
		docsVector = docs
		return err
	}, &errVector)

	// Task 2: keyword search.
	submit(func(ctx context.Context) error {
		docs, err := r.doRetrieveMeilisearch(ctx, query, options)
		docsFulltext = docs
		return err
	}, &errFulltext)

	// Wait for both tasks; the channel receives also order the writes above
	// before the reads below.
	<-done
	<-done

	// Join both errors so neither is lost.
	if err := errors.Join(errVector, errFulltext); err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	docs := mergeAndDeduplicate(docsVector, docsFulltext)

	// Ascending distance (smaller means more similar). A stable sort keeps
	// the merge order for ties so results are deterministic.
	sort.SliceStable(docs, func(i, j int) bool {
		return gconv.Float64(docs[i].MetaData["distance"]) < gconv.Float64(docs[j].MetaData["distance"])
	})

	// Drop weak matches so the caller can answer "no relevant information"
	// instead of using unrelated context.
	validDocs := make([]*schema.Document, 0, len(docs))
	for _, d := range docs {
		if gconv.Float64(d.MetaData["distance"]) < maxDistance {
			validDocs = append(validDocs, d)
		}
	}

	if len(validDocs) > topK {
		validDocs = validDocs[:topK]
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
	return validDocs, nil
}

// doRetrieveVector embeds query and looks up the nearest chunks through
// dao.DocumentChunk.GetAllByVector, restricted to the "dataset_ids" found in
// opts.DSLInfo. Every returned document is tagged retrieve_by="vector".
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	// A per-call option may have replaced the embedder with nil.
	if opts.Embedding == nil {
		return nil, errors.New("embedder is required")
	}

	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 {
		return nil, errors.New("empty query vector")
	}

	queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
	topK := resolveTopK(opts)
	datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])

	rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    gconv.Float64(row["distance"]),
				"retrieve_by": "vector",
			},
		})
	}
	return docs, nil
}

// doRetrieveMeilisearch runs the keyword search through
// dao.DocumentChunk.SearchByKeywords, restricted to the "dataset_ids" found in
// opts.DSLInfo. Every returned document is tagged retrieve_by="fulltext".
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	topK := resolveTopK(opts)
	datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])

	rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		// The ranking score grows with relevance, while "distance" is sorted
		// ascending and filtered with an upper bound, so invert it.
		// NOTE(review): assumes _rankingScore is in [0, 1] as documented by
		// Meilisearch; confirm the DAO returns it unmodified.
		// Rows without a score keep the previous behaviour (distance 0).
		distance := 0.0
		if score, ok := row["_rankingScore"]; ok {
			distance = 1 - gconv.Float64(score)
		}

		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    distance,
				"retrieve_by": "fulltext",
			},
		})
	}
	return docs, nil
}

// mergeAndDeduplicate combines both result lists by document ID.
//
// Vector results come first, followed by full-text results not already
// present, each in their input order. When a full-text hit has the same ID as
// an earlier document, that earlier document is kept (including its distance)
// and its retrieve_by tag is set to "both". Nil documents are skipped, and for
// duplicate IDs within one list the first occurrence wins.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	merged := make([]*schema.Document, 0, len(vecDocs)+len(fullDocs))
	seen := make(map[string]*schema.Document, len(vecDocs)+len(fullDocs))

	for _, d := range vecDocs {
		if d == nil {
			continue
		}
		if _, dup := seen[d.ID]; dup {
			continue
		}
		seen[d.ID] = d
		merged = append(merged, d)
	}

	for _, d := range fullDocs {
		if d == nil {
			continue
		}
		if existing, dup := seen[d.ID]; dup {
			// Hit by both searches: keep the earlier document's score and
			// only update the marker.
			if existing.MetaData == nil {
				existing.MetaData = make(map[string]any, 1)
			}
			existing.MetaData["retrieve_by"] = "both"
			continue
		}
		seen[d.ID] = d
		merged = append(merged, d)
	}

	return merged
}