Files
rag/common/eino/retriever.go

301 lines
8.0 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
"context"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"sort"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// PGVectorRetrieverConfig holds the construction options for a
// PGVectorRetriever.
type PGVectorRetrieverConfig struct {
	// Embedder is the embedding model for query vectors.
	// NOTE(review): NewPGVectorRetriever does not read this field; it always
	// resolves the embedder through GetTenantEmbedderByType — confirm intent.
	Embedder embedding.Embedder
	// DefaultTopK is the default maximum number of documents returned by
	// Retrieve. NewPGVectorRetriever uses 5 when this is <= 0.
	DefaultTopK int
	// DefaultIndex becomes the default retriever Index option.
	// NOTE(review): options.Index is not read by the retrieval code visible here.
	DefaultIndex string
	// DSLInfo carries search filters; the retrieval helpers read the
	// "dataset_ids" and "document_ids" keys from it.
	DSLInfo map[string]any
}
// PGVectorRetriever is a hybrid retriever that combines pgvector similarity
// search with Meilisearch keyword search and optionally reranks the merged
// results.
type PGVectorRetriever struct {
	embedder embedding.Embedder // default embedder for query vectors
	topK     int                // default TopK when no option overrides it
	index    string             // default Index option
	dslInfo  map[string]any     // default DSLInfo filters ("dataset_ids" / "document_ids")
	reranker *DashScopeReranker // Tongyi (DashScope) reranker; nil disables reranking
}
// NewPGVectorRetriever builds a hybrid pgvector + Meilisearch retriever.
//
// The query embedder is resolved for the current tenant through
// GetTenantEmbedderByType(ctx, configType); config.Embedder is currently not
// consulted. A nil config is treated as an empty one, and a DefaultTopK <= 0
// falls back to 5. The caller's config is never modified.
//
// Reranking is disabled (reranker is nil) until a reranker is wired in.
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
	const defaultTopK = 5

	// Work on a copy so defaults applied here do not leak back to the caller.
	var cfg PGVectorRetrieverConfig
	if config != nil {
		cfg = *config
	}
	if cfg.DefaultTopK <= 0 {
		cfg.DefaultTopK = defaultTopK
	}

	e, err := GetTenantEmbedderByType(ctx, configType)
	if err != nil {
		return nil, fmt.Errorf("resolve tenant embedder: %w", err)
	}

	return &PGVectorRetriever{
		embedder: e,
		topK:     cfg.DefaultTopK,
		index:    cfg.DefaultIndex,
		dslInfo:  cfg.DSLInfo,
		// reranker: NewDashScopeReranker(), // uncomment to enable DashScope reranking
	}, nil
}
// Retrieve performs a hybrid search for query.
//
// The pgvector similarity search and the Meilisearch keyword search run in
// parallel on the goroutine pool, bounded together by a 5s timeout. Their
// results are merged and de-duplicated by document ID, then ordered by the
// reranker when one is configured (otherwise by ascending "distance"
// metadata). Finally they are filtered to distance <= 0.8 and truncated to
// TopK.
//
// opts override the retriever defaults; a missing or non-positive TopK falls
// back to 10. If either search fails or panics, the other one is cancelled
// and the joined error is returned. Failures are reported through
// callbacks.OnError and successful results through callbacks.OnEnd.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) (docs []*schema.Document, err error) {
	const (
		searchTimeout = 5 * time.Second // shared deadline for both searches
		maxDistance   = 0.8             // documents farther than this are dropped
		fallbackTopK  = 10              // used when no positive TopK is configured
	)

	options := retriever.GetCommonOptions(&retriever.Options{
		Index:     &r.index,
		TopK:      &r.topK,
		DSLInfo:   r.dslInfo,
		Embedding: r.embedder,
	}, opts...)

	// Resolve TopK once. The search helpers receive a pointer to this local
	// copy, so they never dereference nil and never alias r.topK.
	topK := fallbackTopK
	if options.TopK != nil && *options.TopK > 0 {
		topK = *options.TopK
	}
	options.TopK = &topK

	ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
		Query: query,
		TopK:  topK,
	})
	defer func() {
		if err != nil {
			callbacks.OnError(ctx, err)
		}
	}()

	taskCtx, cancel := context.WithTimeout(ctx, searchTimeout)
	defer cancel()

	var (
		docsVector, docsFulltext []*schema.Document
		errVector, errFulltext   error
		// One slot per task, so senders never block, including the inline
		// signal sent when the pool rejects a task.
		done = make(chan struct{}, 2)
	)

	// launch submits one search to the goroutine pool. The task turns a panic
	// into an error, cancels the sibling search on failure (fail fast) and
	// always signals done exactly once.
	launch := func(
		search func(context.Context, string, *retriever.Options) ([]*schema.Document, error),
		out *[]*schema.Document,
		errOut *error,
	) {
		task := func(ctx context.Context) {
			defer func() { done <- struct{}{} }()
			g.TryCatch(ctx, func(ctx context.Context) {
				*out, *errOut = search(ctx, query, options)
			}, func(ctx context.Context, panicErr error) {
				*errOut = panicErr
			})
			if *errOut != nil {
				cancel()
			}
		}
		if addErr := grpool.Add(taskCtx, task); addErr != nil {
			// The task will never run: record why and signal on its behalf,
			// otherwise the receives below would block forever.
			*errOut = fmt.Errorf("submit search task: %w", addErr)
			cancel()
			done <- struct{}{}
		}
	}

	launch(r.doRetrieveVector, &docsVector, &errVector)
	launch(r.doRetrieveMeilisearch, &docsFulltext, &errFulltext)

	// Each receive pairs with exactly one send, which also makes the tasks'
	// writes to the result variables visible here.
	<-done
	<-done

	// Join keeps both failures instead of hiding one behind the other.
	if err = errors.Join(errVector, errFulltext); err != nil {
		return nil, err
	}

	merged := mergeAndDeduplicate(docsVector, docsFulltext)

	var ranked []*schema.Document
	if r.reranker != nil {
		if ranked, err = r.reranker.Rerank(ctx, query, merged); err != nil {
			return nil, fmt.Errorf("rerank failed: %w", err)
		}
	} else {
		// Stable sort keeps the deterministic merge order for equal distances.
		sort.SliceStable(merged, func(i, j int) bool {
			return gconv.Float64(merged[i].MetaData["distance"]) < gconv.Float64(merged[j].MetaData["distance"])
		})
		ranked = merged
	}

	valid := make([]*schema.Document, 0, len(ranked))
	for _, doc := range ranked {
		if len(valid) == topK {
			break
		}
		if gconv.Float64(doc.MetaData["distance"]) <= maxDistance {
			valid = append(valid, doc)
		}
	}

	callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: valid})
	return valid, nil
}
// doRetrieveVector embeds query with opts.Embedding and runs a pgvector
// similarity search through dao.DocumentVector.GetAllByVector, limited to
// *opts.TopK rows (10 when unset).
//
// The search is restricted by the "dataset_ids" / "document_ids" entries of
// opts.DSLInfo when they are non-empty. Each returned document carries
// "dataset_id", "document_id", "distance" (as returned by the DAO) and
// retrieve_by="vector" metadata.
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	if opts.Embedding == nil {
		return nil, errors.New("vector retrieval: no embedder configured")
	}
	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, fmt.Errorf("embed query: %w", err)
	}
	if len(vectors) == 0 || len(vectors[0]) == 0 {
		return nil, errors.New("empty query vector")
	}
	queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))

	topK := 10
	if opts.TopK != nil {
		topK = *opts.TopK
	}

	// Convert a filter only when the DSL actually supplies one; otherwise the
	// slices stay nil.
	// NOTE(review): presumably the DAO treats nil as "no restriction" — confirm.
	var datasetIDs, documentIDs []int64
	if v := opts.DSLInfo["dataset_ids"]; !g.IsEmpty(v) {
		datasetIDs = gconv.Int64s(v)
	}
	if v := opts.DSLInfo["document_ids"]; !g.IsEmpty(v) {
		documentIDs = gconv.Int64s(v)
	}

	rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIDs, documentIDs, queryVec, topK)
	if err != nil {
		return nil, fmt.Errorf("pgvector search: %w", err)
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    gconv.Float64(row["distance"]),
				"retrieve_by": "vector",
			},
		})
	}
	return docs, nil
}
// doRetrieveMeilisearch runs a keyword (full-text) search through
// dao.DocumentVector.SearchByKeywords, limited to *opts.TopK hits (10 when
// unset).
//
// The search is restricted by the "dataset_ids" / "document_ids" entries of
// opts.DSLInfo when they are non-empty. Each returned document carries
// "dataset_id", "document_id", "distance" and retrieve_by="fulltext" metadata.
// "distance" is derived from the hit's _rankingScore so that it sorts and
// filters the same way as pgvector distances (lower = better).
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	// Guard against a nil TopK, matching doRetrieveVector.
	topK := 10
	if opts.TopK != nil {
		topK = *opts.TopK
	}

	// Convert a filter only when the DSL actually supplies one; otherwise the
	// slices stay nil.
	var datasetIDs, documentIDs []int64
	if v := opts.DSLInfo["dataset_ids"]; !g.IsEmpty(v) {
		datasetIDs = gconv.Int64s(v)
	}
	if v := opts.DSLInfo["document_ids"]; !g.IsEmpty(v) {
		documentIDs = gconv.Int64s(v)
	}

	rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIDs, documentIDs, topK)
	if err != nil {
		return nil, fmt.Errorf("keyword search: %w", err)
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		// Meilisearch's _rankingScore lies in [0, 1] with 1 = most relevant,
		// so invert it to obtain a distance where lower is better.
		// NOTE(review): assumes the DAO requests showRankingScore; without it
		// every hit gets distance 1 and is filtered out by Retrieve.
		score := gconv.Float64(row["_rankingScore"])
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    1 - score,
				"retrieve_by": "fulltext",
			},
		})
	}
	return docs, nil
}
// mergeAndDeduplicate merges vector and full-text hits into one slice,
// de-duplicated by document ID, in a deterministic order.
//
// Vector hits come first in their original order, followed by full-text hits
// whose ID has not been seen yet. When a document is found by both searches,
// the vector copy (and its distance) is kept and its "retrieve_by" metadata
// is set to "both". A repeated ID within one list keeps its first occurrence,
// and nil documents are skipped.
//
// Note: the "both" marker is written into the MetaData map of the vector
// document that was passed in.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	merged := make([]*schema.Document, 0, len(vecDocs)+len(fullDocs))
	pos := make(map[string]int, len(vecDocs)+len(fullDocs)) // ID -> index in merged

	for _, d := range vecDocs {
		if d == nil {
			continue
		}
		if _, seen := pos[d.ID]; !seen {
			pos[d.ID] = len(merged)
			merged = append(merged, d)
		}
	}
	vectorCount := len(merged) // merged[:vectorCount] came from the vector search

	for _, d := range fullDocs {
		if d == nil {
			continue
		}
		i, seen := pos[d.ID]
		if !seen {
			pos[d.ID] = len(merged)
			merged = append(merged, d)
			continue
		}
		// Only an overlap with the vector results counts as "both";
		// duplicate full-text hits are simply dropped.
		if i < vectorCount {
			existing := merged[i]
			if existing.MetaData == nil {
				existing.MetaData = make(map[string]any)
			}
			existing.MetaData["retrieve_by"] = "both"
		}
	}
	return merged
}