// File: rag/common/eino/retriever.go

package eino
import (
	"context"
	"errors"
	"fmt"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)
// PGVectorRetrieverConfig carries the settings consumed by
// NewPGVectorRetriever.
type PGVectorRetrieverConfig struct {
	// Embedder turns the query text into a vector. It is required:
	// NewPGVectorRetriever returns an error when it is nil.
	Embedder embedding.Embedder

	// DefaultTopK is the result limit used when a Retrieve call supplies no
	// TopK option. Values <= 0 are replaced with 5 by NewPGVectorRetriever.
	DefaultTopK int

	// DefaultIndex is the index name used when a Retrieve call supplies no
	// Index option.
	DefaultIndex string
}
// PGVectorRetriever retrieves documents by embedding the query text and
// running a vector-distance search over the document_chunk table.
//
// The fields hold the per-instance defaults that Retrieve seeds its options
// with; construct instances through NewPGVectorRetriever.
type PGVectorRetriever struct {
	// embedder is the default embedder offered to every Retrieve call.
	embedder embedding.Embedder
	// topK is the default maximum number of documents to return.
	topK int
	// index is the default index name offered to every Retrieve call.
	index string
}
// NewPGVectorRetriever builds a PGVectorRetriever from config.
//
// It returns an error when config is nil or when config.Embedder is nil. A
// DefaultTopK that is zero or negative falls back to 5. The caller's config
// is only read, never modified.
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
	// Guard against a nil config so the field accesses below cannot panic.
	if config == nil {
		return nil, errors.New("config is required")
	}
	if config.Embedder == nil {
		return nil, errors.New("embedder is required")
	}

	// Resolve the default into a local instead of writing it back into the
	// caller's struct, which may be shared or reused.
	topK := config.DefaultTopK
	if topK <= 0 {
		topK = 5
	}

	return &PGVectorRetriever{
		embedder: config.Embedder,
		topK:     topK,
		index:    config.DefaultIndex,
	}, nil
}
// Retrieve returns the documents closest to query.
//
// The retriever's own index, topK and embedder act as defaults; any common
// retriever options in opts are applied on top of them. The call is bracketed
// by callbacks: OnStart before the search, then OnError when the search fails
// or OnEnd with the resulting documents when it succeeds.
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	// Seed the common options with this retriever's defaults, then let the
	// caller-supplied options override them.
	defaults := retriever.Options{
		Index:     &r.index,
		TopK:      &r.topK,
		Embedding: r.embedder,
	}
	resolved := retriever.GetCommonOptions(&defaults, opts...)

	// Start callback: report the query and the effective result limit.
	startInput := &retriever.CallbackInput{
		Query: query,
		TopK:  *resolved.TopK,
	}
	ctx = callbacks.OnStart(ctx, startInput)

	// Run the actual search.
	found, err := r.doRetrieve(ctx, query, resolved)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	// End callback: report the documents that were found.
	endOutput := &retriever.CallbackOutput{Docs: found}
	callbacks.OnEnd(ctx, endOutput)

	return found, nil
}
// doRetrieve embeds query with opts.Embedding and returns the rows of
// document_chunk closest to that vector, ordered by ascending distance and
// limited to opts.TopK rows (r.topK when opts.TopK is nil).
//
// Each row becomes a schema.Document whose ID and Content come from the "id"
// and "content" columns and whose Metadata carries "dataset_id",
// "document_id" and "distance". opts.Index is not consulted by the query.
//
// It returns an error when no embedder is available, when embedding fails or
// yields no vector, or when the database query fails.
//
// NOTE(review): dao is not among the imports visible in this file — confirm
// it is in scope.
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	if opts == nil || opts.Embedding == nil {
		return nil, errors.New("embedder is required")
	}

	// Fall back to the retriever's default when no limit was resolved, so a
	// nil pointer is never dereferenced.
	topK := r.topK
	if opts.TopK != nil {
		topK = *opts.TopK
	}

	// 1. Embed the query text.
	vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
	if err != nil {
		return nil, fmt.Errorf("embedding query: %w", err)
	}
	if len(vectors) == 0 || len(vectors[0]) == 0 {
		return nil, errors.New("empty query vector")
	}

	// The embedder yields float64 components while pgvector.NewVector takes
	// []float32, so narrow each component explicitly.
	components := make([]float32, len(vectors[0]))
	for i, v := range vectors[0] {
		components[i] = float32(v)
	}
	queryVec := pgvector.NewVector(components)

	// 2. Vector-distance search; both values are bound as parameters.
	stmt := `
SELECT id, content, dataset_id, document_id,
vector <-> ? AS distance
FROM document_chunk
ORDER BY distance ASC
LIMIT ?
`

	// 3. Run the query.
	rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, stmt, queryVec, topK)
	if err != nil {
		return nil, fmt.Errorf("querying document_chunk: %w", err)
	}

	// 4. Convert each row into an eino Document.
	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			Metadata: map[string]any{
				"dataset_id":  row["dataset_id"],
				"document_id": row["document_id"],
				"distance":    row["distance"],
			},
		})
	}
	return docs, nil
}