package service

import (
	"context"
	"fmt"

	"gitea.com/red-future/common/beans"
	"github.com/cloudwego/eino/components/indexer"
	"github.com/cloudwego/eino/components/retriever"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"

	"rag/common/eino"
	"rag/consts/task"
	"rag/dao"
	"rag/model/dto"
	"rag/model/entity"
)

// DocumentVector is the package-level instance through which document-vector
// (chunk) operations are invoked.
var DocumentVector = new(documentVectorService)

// documentVectorService groups RAG querying, chunk update/listing and the
// handler for chunk-vectorization messages. It carries no state of its own.
type documentVectorService struct{}

// Query answers req.Content with retrieval-augmented generation.
//
// It retrieves similar chunks through a PGVector retriever using the DashScope
// embedder, converts req.History into chat messages, and passes the question,
// the retrieved documents and the history to eino.NewChatModel. The reply's
// content is returned as the answer.
//
// When req.TopK is not positive, 5 is used as the retriever's default top-K.
// req itself is not modified.
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
	const defaultTopK = 5

	// Use a local copy so the caller's request is left untouched.
	topK := req.TopK
	if topK <= 0 {
		topK = defaultTopK
	}

	r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
		Embedder:    eino.EmbedderDashscope,
		DefaultTopK: topK,
	})
	if err != nil {
		g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
		return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
	}

	// dataset_ids is forwarded as DSL info; presumably the retriever uses it to
	// scope the search to those datasets — confirm in eino.PGVectorRetriever.
	docs, err := r.Retrieve(ctx, req.Content,
		retriever.WithEmbedding(eino.EmbedderDashscope),
		retriever.WithDSLInfo(map[string]any{
			"dataset_ids": req.DatasetIds,
		}),
	)
	if err != nil {
		g.Log().Errorf(ctx, "向量检索失败: %v", err)
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// Convert the caller-supplied conversation history into eino messages.
	messages := make([]*schema.Message, 0)
	if err = gconv.Struct(req.History, &messages); err != nil {
		g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
		return nil, fmt.Errorf("转换历史消息失败: %w", err)
	}

	replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
	if err != nil {
		// This step is answer generation, not retrieval; report it as such.
		g.Log().Errorf(ctx, "生成回答失败: %v", err)
		return nil, fmt.Errorf("生成回答失败: %w", err)
	}
	if replyMsg == nil {
		// Guard against a nil reply with a nil error, which would otherwise panic below.
		return nil, fmt.Errorf("生成回答失败: 模型未返回消息")
	}

	return &dto.RAGQueryRes{
		Answer: replyMsg.Content,
	}, nil
}

// Update applies req to a stored document chunk via dao.DocumentVector.Update.
// The DAO's first return value is discarded; only its error is returned.
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
	_, err = dao.DocumentVector.Update(ctx, req)
	return err
}

// List returns the chunks that dao.DocumentVector.List yields for req, converted
// into the response list type, together with the total reported by the DAO.
// On any error the returned res is nil.
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
	list, total, err := dao.DocumentVector.List(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.ListDocumentVectorRes{
		Total: total,
	}
	if err = gconv.Struct(list, &res.List); err != nil {
		// Do not hand back a half-populated response alongside an error.
		return nil, err
	}
	return res, nil
}

// DocsChunkMsg handles a chunk-vectorization message.
//
// It decodes msg["data"] into schema.Document values, stores them through the
// PGVector indexer (DashScope embeddings, batch size 10), and records the
// outcome as progress of the TaskTypeGenerateVector task. The task ID is the
// document ID taken from the first chunk's metadata.
//
// Return values:
//   - decode failure: the decode error is returned;
//   - empty payload: logged and nil is returned (the message is treated as handled);
//   - store error or zero stored rows: the task is marked failed with a remark,
//     and only the error from writing that progress is returned — the store
//     error itself is logged and kept in the remark, not returned;
//   - success: the task is marked completed and the progress-write error is returned.
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
	docs := make([]*schema.Document, 0)
	msgMap := gconv.Map(msg)
	if err = gconv.Structs(msgMap["data"], &docs); err != nil {
		g.Log().Error(ctx, "DocsChunkMsg err:", err)
		return err
	}
	if len(docs) == 0 {
		g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
		return nil
	}

	// Tenant, creator and document ID are read from the first chunk only;
	// presumably every chunk in one message shares them — confirm with the producer.
	meta := docs[0].MetaData

	// NOTE(review): a built-in string context key is flagged by staticcheck
	// (SA1029), but it must match the key that readers of "user" use; keep in sync.
	ctx = context.WithValue(ctx, "user", &beans.User{
		TenantId: gconv.Uint64(meta[entity.DocumentVectorCol.TenantId]),
		UserName: gconv.String(meta[entity.DocumentVectorCol.Creator]),
	})

	idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
		BatchSize: 10,
	})
	documentID := gconv.Int64(meta[entity.DocumentVectorCol.DocumentId])

	rows, storeErr := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))

	// Progress for the vector-generation task: completed unless storing failed.
	progress := &dto.WriteTaskProgressReq{
		TaskId:   documentID,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusCompleted,
		Remark:   "向量生成完成",
	}
	if storeErr != nil || rows == 0 {
		g.Log().Errorf(ctx, "DocsChunkMsg rows: %v, err: %v", rows, storeErr)

		remark := " 向量存储数量: " + gconv.String(rows)
		if storeErr != nil {
			remark = "向量存储失败: " + storeErr.Error()
		}
		progress.Status = task.TaskStatusFailed
		progress.Remark = remark
	}

	return Task.WriteTaskProgress(ctx, progress)
}