// Listing metadata carried over from the original paste (not code): 118 lines, 2.6 KiB, Go.

package eino

import (
	"context"
	"errors"

	"github.com/cloudwego/eino/callbacks"
	"github.com/cloudwego/eino/components/embedding"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/util/gconv"
	"github.com/pgvector/pgvector-go"
)

type PGVectorRetrieverConfig struct {
|
||
Embedder embedding.Embedder
|
||
DefaultTopK int
|
||
DefaultIndex string
|
||
}
|
||
|
||
type PGVectorRetriever struct {
|
||
embedder embedding.Embedder
|
||
topK int
|
||
index string
|
||
}
|
||
|
||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||
if config.Embedder == nil {
|
||
return nil, errors.New("embedder is required")
|
||
}
|
||
if config.DefaultTopK <= 0 {
|
||
config.DefaultTopK = 5
|
||
}
|
||
|
||
return &PGVectorRetriever{
|
||
embedder: config.Embedder,
|
||
topK: config.DefaultTopK,
|
||
index: config.DefaultIndex,
|
||
}, nil
|
||
}
|
||
|
||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||
|
||
// 1. 处理公共 Option(官方标准写法)
|
||
options := &retriever.Options{
|
||
Index: &r.index,
|
||
TopK: &r.topK,
|
||
Embedding: r.embedder,
|
||
}
|
||
options = retriever.GetCommonOptions(options, opts...)
|
||
|
||
// 2. 回调(官方标准)
|
||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||
Query: query,
|
||
TopK: *options.TopK,
|
||
})
|
||
|
||
// 3. 执行检索
|
||
docs, err := r.doRetrieve(ctx, query, options)
|
||
if err != nil {
|
||
callbacks.OnError(ctx, err)
|
||
return nil, err
|
||
}
|
||
|
||
// 4. 完成回调
|
||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
|
||
Docs: docs,
|
||
})
|
||
|
||
return docs, nil
|
||
}
|
||
|
||
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||
|
||
// 1. 生成向量
|
||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(vectors) == 0 {
|
||
return nil, errors.New("empty query vector")
|
||
}
|
||
|
||
queryVec := pgvector.NewVector(vectors[0])
|
||
topK := *opts.TopK
|
||
|
||
// 2. PG 向量相似度检索 SQL
|
||
sql := `
|
||
SELECT id, content, dataset_id, document_id,
|
||
vector <-> ? AS distance
|
||
FROM document_chunk
|
||
ORDER BY distance ASC
|
||
LIMIT ?
|
||
`
|
||
|
||
// 3. 查询
|
||
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 4. 转为 Eino Document
|
||
docs := make([]*schema.Document, 0, len(rows))
|
||
for _, row := range rows {
|
||
docs = append(docs, &schema.Document{
|
||
ID: gconv.String(row["id"]),
|
||
Content: gconv.String(row["content"]),
|
||
Metadata: map[string]any{
|
||
"dataset_id": row["dataset_id"],
|
||
"document_id": row["document_id"],
|
||
"distance": row["distance"],
|
||
},
|
||
})
|
||
}
|
||
|
||
return docs, nil
|
||
}
|