180 lines
4.7 KiB
Go
180 lines
4.7 KiB
Go
package eino
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"rag/dao"
|
||
"sort"
|
||
|
||
"github.com/cloudwego/eino/callbacks"
|
||
"github.com/cloudwego/eino/components/embedding"
|
||
"github.com/cloudwego/eino/components/retriever"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
"github.com/pgvector/pgvector-go"
|
||
)
|
||
|
||
type PGVectorRetrieverConfig struct {
|
||
Embedder embedding.Embedder
|
||
DefaultTopK int
|
||
DefaultIndex string
|
||
DSLInfo map[string]any
|
||
}
|
||
|
||
type PGVectorRetriever struct {
|
||
embedder embedding.Embedder
|
||
topK int
|
||
index string
|
||
dslInfo map[string]any
|
||
}
|
||
|
||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||
if config.Embedder == nil {
|
||
return nil, errors.New("embedder is required")
|
||
}
|
||
if config.DefaultTopK <= 0 {
|
||
config.DefaultTopK = 5
|
||
}
|
||
|
||
return &PGVectorRetriever{
|
||
embedder: config.Embedder,
|
||
topK: config.DefaultTopK,
|
||
index: config.DefaultIndex,
|
||
dslInfo: config.DSLInfo,
|
||
}, nil
|
||
}
|
||
|
||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||
options := &retriever.Options{
|
||
Index: &r.index,
|
||
TopK: &r.topK,
|
||
DSLInfo: r.dslInfo,
|
||
Embedding: r.embedder,
|
||
}
|
||
options = retriever.GetCommonOptions(options, opts...)
|
||
|
||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||
Query: query,
|
||
TopK: *options.TopK,
|
||
})
|
||
|
||
// ==========================================
|
||
// 🔥 双路检索:向量 + 全文
|
||
// ==========================================
|
||
docsVector, err := r.doRetrieveVector(ctx, query, options)
|
||
if err != nil {
|
||
callbacks.OnError(ctx, err)
|
||
return nil, err
|
||
}
|
||
|
||
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
|
||
if err != nil {
|
||
callbacks.OnError(ctx, err)
|
||
return nil, err
|
||
}
|
||
|
||
// 合并 + 去重
|
||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||
|
||
// 排序(distance 越小越靠前)
|
||
sort.Slice(docs, func(i, j int) bool {
|
||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||
return d1 < d2
|
||
})
|
||
|
||
// 最多保留 topK
|
||
if len(docs) > *options.TopK {
|
||
docs = docs[:*options.TopK]
|
||
}
|
||
|
||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
|
||
return docs, nil
|
||
}
|
||
|
||
// ==========================================
|
||
// 1. 向量检索(PG)
|
||
// ==========================================
|
||
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(vectors) == 0 {
|
||
return nil, errors.New("empty query vector")
|
||
}
|
||
|
||
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||
topK := *opts.TopK
|
||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||
|
||
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
docs := make([]*schema.Document, 0, len(rows))
|
||
for _, row := range rows {
|
||
docs = append(docs, &schema.Document{
|
||
ID: gconv.String(row["id"]),
|
||
Content: gconv.String(row["content"]),
|
||
MetaData: map[string]any{
|
||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||
"document_id": gconv.Int64(row["document_id"]),
|
||
"distance": gconv.Float64(row["distance"]),
|
||
"retrieve_by": "vector",
|
||
},
|
||
})
|
||
}
|
||
return docs, nil
|
||
}
|
||
|
||
// ==========================================
|
||
// 2. 全文检索(Meilisearch)🔥 新增
|
||
// ==========================================
|
||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||
topK := *opts.TopK
|
||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||
|
||
// 调用你已有的 Meilisearch DAO
|
||
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
docs := make([]*schema.Document, 0, len(rows))
|
||
for _, row := range rows {
|
||
docs = append(docs, &schema.Document{
|
||
ID: gconv.String(row["id"]),
|
||
Content: gconv.String(row["content"]),
|
||
MetaData: map[string]any{
|
||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||
"document_id": gconv.Int64(row["document_id"]),
|
||
"distance": 0.1, // 全文结果给高分
|
||
"retrieve_by": "fulltext",
|
||
},
|
||
})
|
||
}
|
||
return docs, nil
|
||
}
|
||
|
||
// ==========================================
|
||
// 合并去重
|
||
// ==========================================
|
||
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||
idMap := make(map[string]*schema.Document)
|
||
for _, d := range vecDocs {
|
||
idMap[d.ID] = d
|
||
}
|
||
for _, d := range fullDocs {
|
||
if _, exists := idMap[d.ID]; !exists {
|
||
idMap[d.ID] = d
|
||
}
|
||
}
|
||
merged := make([]*schema.Document, 0, len(idMap))
|
||
for _, d := range idMap {
|
||
merged = append(merged, d)
|
||
}
|
||
return merged
|
||
}
|