53 lines
1.3 KiB
Go
53 lines
1.3 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"rag/common/eino"
|
|
"rag/model/dto"
|
|
|
|
"github.com/cloudwego/eino/components/retriever"
|
|
"github.com/gogf/gf/v2/os/glog"
|
|
)
|
|
|
|
// ragQueryService groups the RAG (retrieval-augmented generation) query
// operations. It holds no fields, so one shared instance is sufficient.
type ragQueryService struct{}

// RAGQuery is the shared, package-level instance of the RAG query service.
var RAGQuery = &ragQueryService{}
|
|
|
|
// Query 执行RAG查询
|
|
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
|
if req.TopK <= 0 {
|
|
req.TopK = 5
|
|
}
|
|
|
|
// 4. 使用向量检索器进行查询
|
|
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
|
Embedder: eino.EmbedderDashscope,
|
|
DefaultTopK: req.TopK,
|
|
})
|
|
if err != nil {
|
|
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
|
|
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
|
}
|
|
|
|
// 5. 执行向量检索
|
|
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
|
"dataset_ids": req.DatasetIds,
|
|
}))
|
|
if err != nil {
|
|
glog.Errorf(ctx, "向量检索失败: %v", err)
|
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
|
}
|
|
|
|
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
|
|
if err != nil {
|
|
glog.Errorf(ctx, "向量检索失败: %v", err)
|
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
|
}
|
|
|
|
return &dto.RAGQueryRes{
|
|
Answer: replyMsg.Content,
|
|
Sources: sources,
|
|
}, nil
|
|
}
|