refactor: 重构文档向量相关代码结构

This commit is contained in:
2026-04-10 13:12:19 +08:00
parent a7b8713e26
commit 94df015aa9
30 changed files with 335 additions and 506 deletions

View File

@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"rag/common/eino"
"rag/common/task"
"rag/consts/document"
"rag/consts/public"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
@@ -123,8 +123,8 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
return
}
// Process 处理文件(使用eino框架切分和向量化)
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
// Vector 处理文件(使用eino框架切分和向量化)
func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq) (err error) {
// 1. 查询文件信息
documentReq := dto.GetDocumentReq{Id: req.Id}
doc, err := dao.Document.GetByID(ctx, &documentReq)
@@ -403,9 +403,9 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
metaData[entity.DocumentCol.TenantId] = doc.TenantId
metaData[entity.DocumentCol.Creator] = doc.Creator
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
metaData[entity.DocumentChunkCol.DocumentId] = doc.Id
metaData[entity.DocumentChunkCol.ContentHash] = contentHash
metaData[entity.DocumentChunkCol.ChunkIndex] = gconv.Int64(i)
metaData[entity.DocumentVectorCol.DocumentId] = doc.Id
metaData[entity.DocumentVectorCol.ContentHash] = contentHash
metaData[entity.DocumentVectorCol.ChunkIndex] = gconv.Int64(i)
t.MetaData = metaData
docsChunk = append(docsChunk, t)
}
@@ -423,9 +423,9 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
// 4. 发送消息到队列
if len(docsChunk) > 0 {
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
err = gmq.GetGmq(public.GmqMsgPluginsName).GmqPublish(ctx, &mq.RedisPubMessage{
PubMessage: types.PubMessage{
Topic: public.KnowledgeDocumentChunkTopic,
Topic: public.KnowledgeDocumentVectorTopic,
Data: docsChunk,
},
})
@@ -541,12 +541,12 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
continue
}
meiliDocs = append(meiliDocs, map[string]interface{}{
entity.DocumentChunkCol.Id: contentHash,
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
entity.DocumentChunkCol.DocumentId: doc.Id,
entity.DocumentChunkCol.Content: t.Content,
entity.DocumentChunkCol.ContentHash: contentHash,
entity.DocumentChunkCol.ChunkIndex: i,
entity.DocumentVectorCol.Id: contentHash,
entity.DocumentVectorCol.DatasetId: doc.DatasetId,
entity.DocumentVectorCol.DocumentId: doc.Id,
entity.DocumentVectorCol.Content: t.Content,
entity.DocumentVectorCol.ContentHash: contentHash,
entity.DocumentVectorCol.ChunkIndex: i,
})
}
@@ -621,7 +621,7 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
}
// 3. Redis 无数据:根据 contentKey 类型选择查询方式
var dictData = make([]*dto.DocumentChunkRPC, 0)
var dictData = make([]*dto.DocumentVectorRPC, 0)
if public.KnowledgeContentHashSqlKey == contentKey {
// SQL 方式:调用 HTTP 接口查询
dictData, err = s.getHistoryDataFromHttp(ctx, doc)
@@ -658,9 +658,9 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
}
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
// 调用接口获取数据
res, _, err := dao.DocumentChunk.List(ctx, &dto.ListDocumentChunkReq{
res, _, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
DatasetId: doc.DatasetId,
Status: gconv.PtrInt8(1),
})
@@ -669,7 +669,7 @@ func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entit
}
// getHistoryDataFromMeilisearch 通过 meilisearch 查询历史数据
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
// 构建 meilisearch 查询参数
searchParams := &meilisearch.SearchParams{
Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
@@ -684,9 +684,9 @@ func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc
}
// 转换查询结果
dictData = make([]*dto.DocumentChunkRPC, 0)
dictData = make([]*dto.DocumentVectorRPC, 0)
for _, hit := range hits {
item := &dto.DocumentChunkRPC{}
item := &dto.DocumentVectorRPC{}
if err = gconv.Struct(hit, item); err != nil {
return
}

View File

@@ -1,84 +0,0 @@
package service
import (
"context"
"rag/common/eino"
"rag/common/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/beans"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var DocumentChunk = new(documentChunkService)
type documentChunkService struct{}
// Update 更新文件块
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
_, err = dao.DocumentChunk.Update(ctx, req)
return
}
// List 获取文件块列表
func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentChunkReq) (res *dto.ListDocumentChunkRes, err error) {
list, total, err := dao.DocumentChunk.List(ctx, req)
if err != nil {
return
}
res = &dto.ListDocumentChunkRes{
Total: total,
}
err = gconv.Struct(list, &res.List)
return
}
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
var docs = make([]*schema.Document, 0)
msgMap := gconv.Map(msg)
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
g.Log().Error(ctx, "DocsChunkMsg err:", err)
return
}
if len(docs) == 0 {
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
return
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]),
UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]),
})
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
BatchSize: 10,
})
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
if err != nil || rows == 0 {
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
// 写入任务进度失败 任务类型为sql存储
remark := " 向量存储数量: " + gconv.String(rows)
if err != nil {
remark = "向量存储失败: " + err.Error()
}
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: remark,
})
return
}
// 写入任务进度成功 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusCompleted,
Remark: "向量生成完成",
})
return
}

129
service/document_vector.go Normal file
View File

@@ -0,0 +1,129 @@
package service
import (
"context"
"fmt"
"rag/common/eino"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/beans"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var DocumentVector = new(documentVectorService)
type documentVectorService struct{}
// Query 执行RAG查询
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
if req.TopK <= 0 {
req.TopK = 5
}
// 4. 使用向量检索器进行查询
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
Embedder: eino.EmbedderDashscope,
DefaultTopK: req.TopK,
})
if err != nil {
g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
}
// 5. 执行向量检索
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
"dataset_ids": req.DatasetIds,
}))
if err != nil {
g.Log().Errorf(ctx, "向量检索失败: %v", err)
return nil, fmt.Errorf("向量检索失败: %w", err)
}
messages := make([]*schema.Message, 0)
err = gconv.Struct(req.History, &messages)
if err != nil {
g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
return nil, fmt.Errorf("转换历史消息失败: %w", err)
}
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
if err != nil {
g.Log().Errorf(ctx, "向量检索失败: %v", err)
return nil, fmt.Errorf("向量检索失败: %w", err)
}
return &dto.RAGQueryRes{
Answer: replyMsg.Content,
}, nil
}
// Update 更新文件块
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
_, err = dao.DocumentVector.Update(ctx, req)
return
}
// List 获取文件块列表
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
list, total, err := dao.DocumentVector.List(ctx, req)
if err != nil {
return
}
res = &dto.ListDocumentVectorRes{
Total: total,
}
err = gconv.Struct(list, &res.List)
return
}
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
var docs = make([]*schema.Document, 0)
msgMap := gconv.Map(msg)
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
g.Log().Error(ctx, "DocsChunkMsg err:", err)
return
}
if len(docs) == 0 {
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
return
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
})
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
BatchSize: 10,
})
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
if err != nil || rows == 0 {
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
// 写入任务进度失败 任务类型为sql存储
remark := " 向量存储数量: " + gconv.String(rows)
if err != nil {
remark = "向量存储失败: " + err.Error()
}
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: remark,
})
return
}
// 写入任务进度成功 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: documentId,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusCompleted,
Remark: "向量生成完成",
})
return
}

View File

@@ -1,60 +0,0 @@
package service
import (
"context"
"fmt"
"rag/common/eino"
"rag/model/dto"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
var RAGQuery = new(ragQueryService)
type ragQueryService struct{}
// Query 执行RAG查询
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
if req.TopK <= 0 {
req.TopK = 5
}
// 4. 使用向量检索器进行查询
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
Embedder: eino.EmbedderDashscope,
DefaultTopK: req.TopK,
})
if err != nil {
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
}
// 5. 执行向量检索
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
"dataset_ids": req.DatasetIds,
}))
if err != nil {
glog.Errorf(ctx, "向量检索失败: %v", err)
return nil, fmt.Errorf("向量检索失败: %w", err)
}
messages := make([]*schema.Message, 0)
err = gconv.Struct(req.History, &messages)
if err != nil {
glog.Errorf(ctx, "转换历史消息失败: %v", err)
return nil, fmt.Errorf("转换历史消息失败: %w", err)
}
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
if err != nil {
glog.Errorf(ctx, "向量检索失败: %v", err)
return nil, fmt.Errorf("向量检索失败: %w", err)
}
return &dto.RAGQueryRes{
Answer: replyMsg.Content,
}, nil
}

View File

@@ -2,11 +2,10 @@ package service
import (
"context"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/common/task"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"