refactor: 重构文档向量相关代码结构
This commit is contained in:
129
service/document_vector.go
Normal file
129
service/document_vector.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var DocumentVector = new(documentVectorService)
|
||||
|
||||
type documentVectorService struct{}
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||
if req.TopK <= 0 {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
// 4. 使用向量检索器进行查询
|
||||
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||
Embedder: eino.EmbedderDashscope,
|
||||
DefaultTopK: req.TopK,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 执行向量检索
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
}))
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
messages := make([]*schema.Message, 0)
|
||||
err = gconv.Struct(req.History, &messages)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
|
||||
return nil, fmt.Errorf("转换历史消息失败: %w", err)
|
||||
}
|
||||
|
||||
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.RAGQueryRes{
|
||||
Answer: replyMsg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Update 更新文件块
|
||||
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
|
||||
_, err = dao.DocumentVector.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// List 获取文件块列表
|
||||
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
|
||||
list, total, err := dao.DocumentVector.List(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.ListDocumentVectorRes{
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
||||
var docs = make([]*schema.Document, 0)
|
||||
msgMap := gconv.Map(msg)
|
||||
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
return
|
||||
}
|
||||
if len(docs) == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
|
||||
UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
|
||||
})
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])
|
||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度成功 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "向量生成完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user