feat: 支持多数据库配置与PGVector检索
This commit is contained in:
@@ -44,13 +44,13 @@ func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, op
|
||||
// 回调
|
||||
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
||||
|
||||
ids, err := i.bulkStore(ctx, docs, commonOpts)
|
||||
rows, err = i.bulkStore(ctx, docs, commonOpts)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(ids)})
|
||||
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
|
||||
return
|
||||
}
|
||||
|
||||
117
common/eino/retriever.go
Normal file
117
common/eino/retriever.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
type PGVectorRetrieverConfig struct {
|
||||
Embedder embedding.Embedder
|
||||
DefaultTopK int
|
||||
DefaultIndex string
|
||||
}
|
||||
|
||||
type PGVectorRetriever struct {
|
||||
embedder embedding.Embedder
|
||||
topK int
|
||||
index string
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
if config.Embedder == nil {
|
||||
return nil, errors.New("embedder is required")
|
||||
}
|
||||
if config.DefaultTopK <= 0 {
|
||||
config.DefaultTopK = 5
|
||||
}
|
||||
|
||||
return &PGVectorRetriever{
|
||||
embedder: config.Embedder,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 处理公共 Option(官方标准写法)
|
||||
options := &retriever.Options{
|
||||
Index: &r.index,
|
||||
TopK: &r.topK,
|
||||
Embedding: r.embedder,
|
||||
}
|
||||
options = retriever.GetCommonOptions(options, opts...)
|
||||
|
||||
// 2. 回调(官方标准)
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||
Query: query,
|
||||
TopK: *options.TopK,
|
||||
})
|
||||
|
||||
// 3. 执行检索
|
||||
docs, err := r.doRetrieve(ctx, query, options)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 完成回调
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
|
||||
Docs: docs,
|
||||
})
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 生成向量
|
||||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil, errors.New("empty query vector")
|
||||
}
|
||||
|
||||
queryVec := pgvector.NewVector(vectors[0])
|
||||
topK := *opts.TopK
|
||||
|
||||
// 2. PG 向量相似度检索 SQL
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <-> ? AS distance
|
||||
FROM document_chunk
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
// 3. 查询
|
||||
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 转为 Eino Document
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
Metadata: map[string]any{
|
||||
"dataset_id": row["dataset_id"],
|
||||
"document_id": row["document_id"],
|
||||
"distance": row["distance"],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
Reference in New Issue
Block a user