feat: 支持多租户多模型对话及文档去重优化
This commit is contained in:
@@ -3,6 +3,8 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/consts/model"
|
||||
"rag/dao"
|
||||
"sort"
|
||||
"time"
|
||||
@@ -29,21 +31,25 @@ type PGVectorRetriever struct {
|
||||
topK int
|
||||
index string
|
||||
dslInfo map[string]any
|
||||
reranker *DashScopeReranker // 通义精排
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
if config.Embedder == nil {
|
||||
return nil, errors.New("embedder is required")
|
||||
}
|
||||
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
|
||||
if config.DefaultTopK <= 0 {
|
||||
config.DefaultTopK = 5
|
||||
}
|
||||
|
||||
e, err := GetTenantEmbedderByType(ctx, configType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PGVectorRetriever{
|
||||
embedder: config.Embedder,
|
||||
embedder: e,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
dslInfo: config.DSLInfo,
|
||||
//reranker: NewDashScopeReranker(), // 👈 直接初始化你的精排
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -138,48 +144,37 @@ func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...
|
||||
}
|
||||
|
||||
// 合并 + 智能去重(保留最优分数)
|
||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
|
||||
// 排序:向量优先,同类型按距离升序
|
||||
sort.Slice(docs, func(i, j int) bool {
|
||||
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
|
||||
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
|
||||
//
|
||||
//// 有类型标记的优先
|
||||
//if okI && !okJ {
|
||||
// return true
|
||||
//}
|
||||
//if !okI && okJ {
|
||||
// return false
|
||||
//}
|
||||
//
|
||||
//// 向量永远排前面
|
||||
//if byI == "vector" && byJ == "fulltext" {
|
||||
// return true
|
||||
//}
|
||||
//if byI == "fulltext" && byJ == "vector" {
|
||||
// return false
|
||||
//}
|
||||
|
||||
// 同类型按 distance 升序(越小越相似)
|
||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
|
||||
// 在Retrieve方法末尾,增加相关性校验
|
||||
validDocs := make([]*schema.Document, 0)
|
||||
for i, d := range docs {
|
||||
// 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃)
|
||||
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
|
||||
validDocs = append(validDocs, d)
|
||||
// =========================
|
||||
// 🔥 Cross-Encoder 精排
|
||||
// =========================
|
||||
var finalDocs []*schema.Document
|
||||
if r.reranker != nil {
|
||||
ranked, err := r.reranker.Rerank(ctx, query, mergedDocs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rerank failed: %w", err)
|
||||
}
|
||||
finalDocs = ranked
|
||||
} else {
|
||||
sort.Slice(mergedDocs, func(i, j int) bool {
|
||||
d1 := gconv.Float64(mergedDocs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(mergedDocs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
finalDocs = mergedDocs
|
||||
}
|
||||
|
||||
// 如果没有有效结果,返回空,让LLM回答「暂无相关信息」
|
||||
if len(validDocs) == 0 {
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||
return validDocs, nil
|
||||
// =========================
|
||||
// 过滤无效文档
|
||||
// =========================
|
||||
const maxDistance = 0.8
|
||||
validDocs := make([]*schema.Document, 0, len(finalDocs))
|
||||
for _, doc := range finalDocs {
|
||||
dist := gconv.Float64(doc.MetaData["distance"])
|
||||
if dist <= maxDistance {
|
||||
validDocs = append(validDocs, doc)
|
||||
}
|
||||
}
|
||||
|
||||
// 最多保留 topK
|
||||
@@ -208,9 +203,15 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
|
||||
if opts.TopK != nil {
|
||||
topK = *opts.TopK
|
||||
}
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
var datasetIds, documentIds []int64
|
||||
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
|
||||
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
}
|
||||
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
|
||||
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
|
||||
}
|
||||
|
||||
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, documentIds, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -236,10 +237,17 @@ func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string,
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
var datasetIds, documentIds []int64
|
||||
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
|
||||
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
}
|
||||
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
|
||||
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
|
||||
}
|
||||
|
||||
// 调用你已有的 Meilisearch DAO
|
||||
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||
rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, documentIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user