222 lines
6.4 KiB
Go
222 lines
6.4 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"rag/common/eino"
|
||
"rag/consts/model"
|
||
"rag/consts/task"
|
||
"rag/dao"
|
||
"rag/model/dto"
|
||
"rag/model/entity"
|
||
|
||
"gitea.com/red-future/common/beans"
|
||
"github.com/cloudwego/eino/components/retriever"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
"github.com/pgvector/pgvector-go"
|
||
)
|
||
|
||
// documentVectorService is the stateless receiver for the document-vector
// service methods defined in this file.
type documentVectorService struct{}

// DocumentVector is the package-level instance through which the
// documentVectorService methods are called.
var DocumentVector = &documentVectorService{}
|
||
|
||
// Query 执行RAG查询
|
||
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||
|
||
modelInfo, err := dao.Model.Get(ctx, &dto.GetModelReq{
|
||
ModelType: model.ModelTypeChat.Code(),
|
||
})
|
||
if err != nil {
|
||
g.Log().Errorf(ctx, "获取模型失败: %v", err)
|
||
return nil, fmt.Errorf("获取模型失败: %w", err)
|
||
}
|
||
if modelInfo == nil {
|
||
g.Log().Errorf(ctx, "模型不存在: %v", model.ModelTypeChat.Code())
|
||
return nil, fmt.Errorf("模型不存在: %w", err)
|
||
}
|
||
|
||
// 4. 使用向量检索器进行查询
|
||
r, err := eino.NewPGVectorRetriever(ctx, &eino.PGVectorRetrieverConfig{
|
||
DefaultTopK: req.TopK,
|
||
}, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||
if err != nil {
|
||
g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||
}
|
||
|
||
// 5. 执行向量检索
|
||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithDSLInfo(map[string]any{
|
||
"dataset_ids": req.DatasetIds,
|
||
"document_ids": req.DocumentIds,
|
||
}))
|
||
if err != nil {
|
||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||
}
|
||
|
||
messages := make([]*schema.Message, 0)
|
||
err = gconv.Struct(req.History, &messages)
|
||
if err != nil {
|
||
g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
|
||
return nil, fmt.Errorf("转换历史消息失败: %w", err)
|
||
}
|
||
|
||
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages, modelInfo.ConfigType)
|
||
if err != nil {
|
||
g.Log().Errorf(ctx, "向量检索失败: %v", err)
|
||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||
}
|
||
|
||
return &dto.RAGQueryRes{
|
||
Answer: replyMsg.Content,
|
||
}, nil
|
||
}
|
||
|
||
// Update 更新文件块
|
||
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
|
||
_, err = dao.DocumentVector.Update(ctx, req)
|
||
return
|
||
}
|
||
|
||
// List 获取文件块列表
|
||
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
|
||
list, total, err := dao.DocumentVector.List(ctx, req)
|
||
if err != nil {
|
||
return
|
||
}
|
||
res = &dto.ListDocumentVectorRes{
|
||
Total: total,
|
||
}
|
||
err = gconv.Struct(list, &res.List)
|
||
return
|
||
}
|
||
|
||
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
||
var docs = make([]*schema.Document, 0)
|
||
msgMap := gconv.Map(msg)
|
||
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
|
||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||
return
|
||
}
|
||
if len(docs) == 0 {
|
||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||
return
|
||
}
|
||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
|
||
UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
|
||
})
|
||
|
||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])
|
||
|
||
var docsStore = make([]*schema.Document, 0)
|
||
var docsInsert = make([]*dto.VectorDocumentVectorMsg, 0)
|
||
for _, doc := range docs {
|
||
if gconv.Bool(doc.MetaData["isNew"]) {
|
||
docsStore = append(docsStore, doc)
|
||
} else {
|
||
ck := new(dto.VectorDocumentVectorMsg)
|
||
err = gconv.Struct(doc.MetaData, ck)
|
||
ck.Content = doc.Content
|
||
ck.VectorStatus = gconv.PtrInt8(1)
|
||
ck.Status = gconv.PtrInt8(1)
|
||
docsInsert = append(docsInsert, ck)
|
||
}
|
||
}
|
||
|
||
if !g.IsEmpty(docsStore) {
|
||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||
BatchSize: 10,
|
||
})
|
||
var rows int64
|
||
rows, err = idx.Store(ctx, docsStore, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||
|
||
if err != nil || rows == 0 {
|
||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
remark := " 向量存储数量: " + gconv.String(rows)
|
||
if err != nil {
|
||
remark = "向量存储失败: " + err.Error()
|
||
}
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: documentId,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: remark,
|
||
})
|
||
return
|
||
}
|
||
}
|
||
|
||
if !g.IsEmpty(docsInsert) {
|
||
// 1. 提取所有 contentHash
|
||
contentHashs := make([]string, 0, len(docsInsert))
|
||
for _, d := range docsInsert {
|
||
contentHashs = append(contentHashs, d.ContentHash)
|
||
}
|
||
|
||
// 2. 分页查询已存在的向量(一页1000,避免大查询)
|
||
var existVectors []*entity.DocumentVector
|
||
for page := 1; ; page++ {
|
||
res, total, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
|
||
Page: &beans.Page{PageSize: 1000, PageNum: int64(page)},
|
||
ContentHashs: contentHashs,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(res) == 0 {
|
||
break
|
||
}
|
||
existVectors = append(existVectors, res...)
|
||
if len(existVectors) >= total {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 3. 构建哈希 -> 向量 的映射表(O(1) 查找,性能提升巨大)
|
||
vectorMap := make(map[string]pgvector.Vector, len(existVectors))
|
||
for _, v := range existVectors {
|
||
vectorMap[v.ContentHash] = v.Vector
|
||
}
|
||
|
||
// 4. 回填向量 + 过滤掉数据库已存在的数据(避免重复插入)
|
||
for _, d := range docsInsert {
|
||
// 回填已有向量
|
||
if vec, ok := vectorMap[d.ContentHash]; ok {
|
||
d.Vector = vec
|
||
}
|
||
}
|
||
|
||
var rows int64
|
||
rows, err = dao.DocumentVector.BatchInsert(ctx, docsInsert)
|
||
|
||
if err != nil || rows == 0 {
|
||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
remark := " 向量存储数量: " + gconv.String(rows)
|
||
if err != nil {
|
||
remark = "向量存储失败: " + err.Error()
|
||
}
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: documentId,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: remark,
|
||
})
|
||
return
|
||
}
|
||
}
|
||
|
||
// 写入任务进度成功 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: documentId,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusCompleted,
|
||
Remark: "向量生成完成",
|
||
})
|
||
return
|
||
}
|