826 lines
24 KiB
Go
826 lines
24 KiB
Go
package service
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"gitea.com/red-future/common/db/gfdb"
	"gitea.com/red-future/common/full-text-search/meilisearch"
	"gitea.com/red-future/common/utils"
	gmq "github.com/bjang03/gmq/core/gmq"
	"github.com/bjang03/gmq/mq"
	"github.com/bjang03/gmq/types"
	"github.com/cloudwego/eino/schema"
	"github.com/gogf/gf/v2/crypto/gmd5"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/grpool"
	"github.com/gogf/gf/v2/util/gconv"

	"rag/common/eino"
	"rag/consts/document"
	"rag/consts/keyword"
	"rag/consts/model"
	"rag/consts/public"
	"rag/consts/task"
	"rag/dao"
	"rag/model/dto"
	"rag/model/entity"
)
|
||
|
||
// Document is the package-level instance through which callers reach the
// document service.
var Document = &documentService{}

// documentService groups document operations (CRUD, splitting, vector
// submission, full-text indexing and keyword extraction). It carries no state.
type documentService struct{}
|
||
|
||
// Create stores a new document in dataset req.DatasetId. Everything runs
// inside one transaction on the knowledge database.
//
// If a document with the same dataset and title already exists, it is
// replaced. Its keywords, vector rows and task records are deleted, the
// dataset's document count and size are reduced by its share (mirroring
// Delete), and its row is removed. The new row is then inserted, the dataset
// counters are increased by one document of req.FileSize, and a completed
// TaskTypeDocParse progress entry is written for the new document ID.
//
// res is non-nil only when err is nil, i.e. when the transaction committed.
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
	var id int64
	err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		existing, err := dao.Document.Get(ctx, &dto.GetDocumentReq{
			DatasetId: req.DatasetId,
			Title:     req.Title,
		})
		if err != nil {
			return err
		}

		if !g.IsEmpty(existing) && existing.Id > 0 {
			// Replacing an existing document: remove everything derived from it
			// and undo its contribution to the dataset counters. Otherwise every
			// re-upload of the same title would inflate DocumentCount/DocumentSize.
			if _, err = dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
				DocumentId: existing.Id,
			}); err != nil {
				return err
			}
			if _, err = dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
				DocumentId: existing.Id,
			}); err != nil {
				return err
			}
			if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
				TaskId: existing.Id,
			}); err != nil {
				return err
			}
			if _, err = dao.Dataset.Update(ctx, &dto.UpdateDatasetReq{
				Id:            existing.DatasetId,
				DocumentCount: -1,
				DocumentSize:  -existing.FileSize,
			}); err != nil {
				return err
			}
			if _, err = dao.Document.Delete(ctx, &dto.DeleteDocumentReq{
				Id: existing.Id,
			}); err != nil {
				return err
			}
		}

		if id, err = dao.Document.Insert(ctx, req); err != nil {
			return err
		}

		if _, err = dao.Dataset.Update(ctx, &dto.UpdateDatasetReq{
			Id:            req.DatasetId,
			DocumentCount: 1,
			DocumentSize:  req.FileSize,
		}); err != nil {
			return err
		}

		// Record the document-parse task as completed for the new document.
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   id,
			TaskType: task.TaskTypeDocParse,
			Status:   task.TaskStatusCompleted,
			Remark:   "文档上传完成",
		})
	})
	if err != nil {
		return nil, err
	}
	return &dto.CreateDocumentRes{Id: id}, nil
}
|
||
|
||
// Update applies the changes described by req to an existing document row.
// The row count reported by the DAO is discarded; only its error is returned.
func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq) (err error) {
	if _, err = dao.Document.Update(ctx, req); err != nil {
		return err
	}
	return nil
}
|
||
|
||
// Delete removes document req.Id and the data derived from it. Everything runs
// inside one transaction on the knowledge database: the dataset's document
// count and size are reduced by the document's share, then the document row,
// its keywords, its vector rows and its task records are deleted.
//
// It returns the lookup error, or an error when the document does not exist.
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	// Sibling methods treat an empty dao.Document.Get result as "not found".
	// Without this guard, reading doc.DatasetId / doc.FileSize below would
	// dereference a nil document.
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	return gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		if _, err := dao.Dataset.Update(ctx, &dto.UpdateDatasetReq{
			Id:            doc.DatasetId,
			DocumentCount: -1,
			DocumentSize:  -doc.FileSize,
		}); err != nil {
			return err
		}

		if _, err := dao.Document.Delete(ctx, req); err != nil {
			return err
		}

		if _, err := dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
			DocumentId: doc.Id,
		}); err != nil {
			return err
		}

		if _, err := dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
			DocumentId: doc.Id,
		}); err != nil {
			return err
		}

		if _, err := dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
			TaskId: doc.Id,
		}); err != nil {
			return err
		}

		return nil
	})
}
|
||
|
||
// Get returns the details of one document, together with the file address
// prefix from utils.GetFileAddressPrefix.
//
// On a DAO error res is nil. On a conversion or prefix error, the partially
// filled res is returned alongside the error.
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
	record, err := dao.Document.Get(ctx, req)
	if err != nil {
		return nil, err
	}

	out := new(dto.GetDocumentRes)
	if err = gconv.Struct(record, &out.DocumentVO); err != nil {
		return out, err
	}

	prefix, err := utils.GetFileAddressPrefix(ctx)
	out.ImgAddressPrefix = prefix
	return out, err
}
|
||
|
||
// List returns the documents matching req together with the total count
// reported by the DAO.
func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto.ListDocumentRes, err error) {
	rows, total, err := dao.Document.List(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.ListDocumentRes{Total: total}
	if err = gconv.Struct(rows, &res.List); err != nil {
		return res, err
	}
	return res, nil
}
|
||
|
||
// VectorSemanticSplit loads document req.Id, records a running
// TaskTypeGenerateVector progress entry, and then runs semanticSplitDocument.
//
// It returns the lookup error, an error when the document does not exist, or
// the result of semanticSplitDocument. A failure to record the running entry
// is logged and does not stop the work. Previously that error was assigned
// and then silently overwritten.
func (s *documentService) VectorSemanticSplit(ctx context.Context, req *dto.VectorSemanticSplitReq) (err error) {
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusRunning,
		Remark:   "向量化执行中",
	}); progressErr != nil {
		g.Log().Warningf(ctx, "write running vector progress for document %d: %v", req.Id, progressErr)
	}

	return s.semanticSplitDocument(ctx, doc)
}
|
||
|
||
// SearchRecursiveSplit loads document req.Id, records a running
// TaskTypeFullTextSearch progress entry, and then runs recursiveSplitDocument.
//
// It returns the lookup error, an error when the document does not exist, or
// the result of recursiveSplitDocument. A failure to record the running entry
// is logged and does not stop the work. Previously that error was assigned
// and then silently overwritten.
func (s *documentService) SearchRecursiveSplit(ctx context.Context, req *dto.SearchRecursiveSplitReq) (err error) {
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeFullTextSearch,
		Status:   task.TaskStatusRunning,
		Remark:   "全文检索执行中",
	}); progressErr != nil {
		g.Log().Warningf(ctx, "write running full-text progress for document %d: %v", req.Id, progressErr)
	}

	return s.recursiveSplitDocument(ctx, doc)
}
|
||
|
||
// KeywordExtract loads document req.Id, records a running
// TaskTypeExtractKeywords progress entry, and then runs extractDocument.
//
// It returns the lookup error, an error when the document does not exist, or
// the result of extractDocument. A failure to record the running entry is
// logged and does not stop the work. Previously that error was assigned and
// then silently overwritten.
func (s *documentService) KeywordExtract(ctx context.Context, req *dto.KeywordExtractReq) (err error) {
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeExtractKeywords,
		Status:   task.TaskStatusRunning,
		Remark:   "提取关键词执行中",
	}); progressErr != nil {
		g.Log().Warningf(ctx, "write running keyword progress for document %d: %v", req.Id, progressErr)
	}

	return s.extractDocument(ctx, doc)
}
|
||
|
||
// Vector 处理文件(使用eino框架切分和向量化)
|
||
func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq) (err error) {
|
||
// 更新文档状态为处理中
|
||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||
updateDocumentReq.Id = req.Id
|
||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||
// 写入任务进度失败 任务类型为文档解析
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: req.Id,
|
||
TaskType: task.TaskTypeDocParse,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "更新文档状态失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
user, err := utils.GetUserInfo(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 使用带超时的background context,避免HTTP请求完成后context被取消
|
||
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||
taskCtx = context.WithValue(taskCtx, "user", user)
|
||
// 任务1: 语义 切分文档
|
||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||
g.TryCatch(ctx, func(ctx context.Context) {
|
||
if innerErr := s.VectorSemanticSplit(ctx, &dto.VectorSemanticSplitReq{Id: req.Id}); innerErr != nil {
|
||
cancel()
|
||
}
|
||
}, func(ctx context.Context, err error) {
|
||
cancel()
|
||
})
|
||
})
|
||
|
||
// 任务2: 递归 切分文档
|
||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||
g.TryCatch(ctx, func(ctx context.Context) {
|
||
if innerErr := s.SearchRecursiveSplit(ctx, &dto.SearchRecursiveSplitReq{Id: req.Id}); innerErr != nil {
|
||
cancel()
|
||
}
|
||
}, func(ctx context.Context, err error) {
|
||
cancel()
|
||
})
|
||
})
|
||
|
||
// 任务3: 提取文档
|
||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||
g.TryCatch(ctx, func(ctx context.Context) {
|
||
if innerErr := s.KeywordExtract(ctx, &dto.KeywordExtractReq{Id: req.Id}); innerErr != nil {
|
||
cancel()
|
||
}
|
||
}, func(ctx context.Context, err error) {
|
||
cancel()
|
||
})
|
||
})
|
||
|
||
return nil
|
||
}
|
||
|
||
// extractDocument extracts keywords from doc's content and stores them with
// dao.Keyword.BatchSaveOrUpdate. The outcome is recorded as a
// TaskTypeExtractKeywords progress entry.
//
// Only the first loaded document's Content is measured. The extraction
// argument depends on its length in bytes: <500 → 4, <2000 → 8, <5000 → 13.
// Longer content is recursively split, and each chunk is extracted with
// argument 6.
//
// ctx cancellation is checked on entry, per chunk, and before saving. Every
// failure writes a failed progress entry and returns the original error.
// Previously the result of the progress write replaced the error, so failures
// usually surfaced as nil.
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
	// fail records a failed progress entry and returns cause unchanged.
	fail := func(remark string, cause error) error {
		if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeExtractKeywords,
			Status:   task.TaskStatusFailed,
			Remark:   remark + cause.Error(),
		}); progressErr != nil {
			// NOTE(review): after cancellation this write reuses the cancelled
			// ctx and may fail for that reason alone.
			g.Log().Warningf(ctx, "write keyword task progress for document %d: %v", doc.Id, progressErr)
		}
		return cause
	}

	// Cancellation check 1: on entry.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return fail("ctx取消: ", ctxErr)
	}

	docs, err := s.loadDocument(ctx, doc)
	if err != nil {
		return fail("加载文件失败: ", err)
	}
	// docs[0] is used below; an empty result previously caused an
	// index-out-of-range panic.
	if len(docs) == 0 {
		return fail("加载文件失败: ", errors.New("loader returned no documents"))
	}
	content := docs[0].Content

	var words []utils.Keyword
	switch size := len(content); {
	case size < 500:
		words = utils.GseTool.Extract(content, 4)
	case size < 2000:
		words = utils.GseTool.Extract(content, 8)
	case size < 5000:
		words = utils.GseTool.Extract(content, 13)
	default:
		chunks, splitErr := eino.RecursiveSplitDocument(ctx, docs)
		if splitErr != nil {
			return fail("递归分割文档失败: ", splitErr)
		}
		for _, chunk := range chunks {
			// Cancellation check 2: per chunk.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return fail("ctx取消: ", ctxErr)
			}
			words = append(words, utils.GseTool.Extract(chunk.Content, 6)...)
		}
	}

	// Cancellation check 3: before the batch save.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return fail("ctx取消: ", ctxErr)
	}

	if len(words) == 0 {
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeExtractKeywords,
			Status:   task.TaskStatusCompleted,
			Remark:   "没有提取到关键词,关键字提取完成",
		})
	}

	keywordReqs := make([]*dto.CreateKeywordReq, 0, len(words))
	for _, word := range words {
		keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
			DatasetId:   doc.DatasetId,
			DocumentId:  doc.Id,
			Word:        word.Word,
			Weight:      gconv.Int16(word.Score),
			KeywordType: keyword.KeywordTypeInitial.Code(),
		})
	}

	if _, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs); err != nil {
		return fail("关键字存储失败: ", err)
	}

	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   doc.Id,
		TaskType: task.TaskTypeExtractKeywords,
		Status:   task.TaskStatusCompleted,
		Remark:   "关键字提取完成",
	})
}
|
||
|
||
// semanticSplitDocument 语义切分
|
||
func (s *documentService) semanticSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||
// ========== 取消检查 1:方法入口 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
// 1. 加载文件
|
||
docs, err := s.loadDocument(ctx, doc)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "加载文件失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 2. 语义切分文件
|
||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "文档切分失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 2. 获取历史数据
|
||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "获取历史数据失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 3. 组装向量文档
|
||
var docsChunk = make([]*schema.Document, 0)
|
||
for i, t := range docsSplit {
|
||
// ========== 取消检查 2:循环内部 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
contentHash := gmd5.MustEncryptString(t.Content)
|
||
var isNew, needCopy bool
|
||
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashSqlKey, contentHash, doc.Id)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "检查重复数据失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
if !isNew && !needCopy {
|
||
continue
|
||
}
|
||
var metaData = make(map[string]any)
|
||
metaData[entity.DocumentCol.TenantId] = doc.TenantId
|
||
metaData[entity.DocumentCol.Creator] = doc.Creator
|
||
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
|
||
metaData[entity.DocumentVectorCol.DocumentId] = doc.Id
|
||
metaData[entity.DocumentVectorCol.ContentHash] = contentHash
|
||
metaData[entity.DocumentVectorCol.ChunkIndex] = gconv.Int64(i + 1)
|
||
if isNew {
|
||
metaData["isNew"] = true
|
||
}
|
||
if needCopy {
|
||
metaData["isNew"] = false
|
||
}
|
||
t.MetaData = metaData
|
||
docsChunk = append(docsChunk, t)
|
||
}
|
||
|
||
// ========== 取消检查 3:批量发送前 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
// 4. 发送消息到队列
|
||
if len(docsChunk) > 0 {
|
||
err = gmq.GetGmq(public.GmqMsgPluginsName).GmqPublish(ctx, &mq.RedisPubMessage{
|
||
PubMessage: types.PubMessage{
|
||
Topic: public.KnowledgeDocumentVectorTopic,
|
||
Data: docsChunk,
|
||
},
|
||
})
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "发送消息到队列失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
// 写入任务进度进行中 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusRunning,
|
||
Remark: "向量生成任务已提交到队列",
|
||
})
|
||
} else {
|
||
// 写入任务进度已完成 任务类型为sql存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeGenerateVector,
|
||
Status: task.TaskStatusCompleted,
|
||
Remark: "无需生成向量,任务完成",
|
||
})
|
||
}
|
||
return
|
||
}
|
||
|
||
// recursiveSplitDocument 递归切分
|
||
func (s *documentService) recursiveSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||
// ========== 取消检查 1:方法入口 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
// 1. 加载文件
|
||
docs, err := s.loadDocument(ctx, doc)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为es存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "加载文件失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 2. 递归切分文件
|
||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为es存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "文档切分失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 2. 获取历史数据
|
||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为es存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "获取历史数据失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
|
||
// 3. 组装向量文档并同时构建meilisearch文档
|
||
var meiliDocs = make([]interface{}, 0)
|
||
for i, t := range docsSplit {
|
||
// ========== 取消检查 2:循环内部 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
contentHash := gmd5.MustEncryptString(t.Content)
|
||
var isNew, needCopy bool
|
||
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashEsKey, contentHash, doc.Id)
|
||
if err != nil {
|
||
// 写入任务进度失败 任务类型为es存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "检查重复数据失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
if !isNew && !needCopy {
|
||
continue
|
||
}
|
||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||
entity.DocumentVectorCol.Id: contentHash,
|
||
entity.DocumentVectorCol.DatasetId: doc.DatasetId,
|
||
entity.DocumentVectorCol.DocumentId: doc.Id,
|
||
entity.DocumentVectorCol.Content: t.Content,
|
||
entity.DocumentVectorCol.ContentHash: contentHash,
|
||
entity.DocumentVectorCol.ChunkIndex: i + 1,
|
||
})
|
||
}
|
||
|
||
// ========== 取消检查 3:批量写入前 ==========
|
||
if ctx.Err() != nil {
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||
})
|
||
return ctx.Err()
|
||
}
|
||
|
||
// 4. 写入到meilisearch数据库中
|
||
if len(meiliDocs) > 0 {
|
||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||
// 写入任务进度失败 任务类型为meilisearch存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusFailed,
|
||
Remark: "写入meilisearch失败: " + err.Error(),
|
||
})
|
||
return
|
||
}
|
||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusCompleted,
|
||
Remark: "全文检索数据写入完成",
|
||
})
|
||
} else {
|
||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||
TaskId: doc.Id,
|
||
TaskType: task.TaskTypeFullTextSearch,
|
||
Status: task.TaskStatusCompleted,
|
||
Remark: "无需生成全文检索数据,任务完成",
|
||
})
|
||
}
|
||
return
|
||
}
|
||
|
||
// loadDocument reads the file at doc.FilePath, in format doc.Format, into eino
// schema documents by delegating to eino.LoadDocument.
func (s *documentService) loadDocument(ctx context.Context, doc *entity.Document) (docs []*schema.Document, err error) {
	docs, err = eino.LoadDocument(ctx, doc.FilePath, doc.Format)
	return docs, err
}
|
||
|
||
// getHistoryData makes sure the Redis content-hash sets used for chunk
// de-duplication (see checkRepeatWithDocId) are populated.
//
// The work runs under utils.Lock with key lockKey formatted with
// doc.DatasetId. The lock argument is 60, presumably seconds — confirm
// against utils.Lock. Inside the lock:
//   - If any key matches contentKey formatted with "*", every matching key's
//     TTL is refreshed to 600 seconds and nothing is loaded.
//   - Otherwise history rows are loaded: via getHistoryDataFromHttp when
//     contentKey is public.KnowledgeContentHashSqlKey, otherwise via
//     getHistoryDataFromMeilisearch. Each row's DocumentId is added to the set
//     for its quote-trimmed content hash, with a best-effort 600-second TTL.
//
// It returns the error reported by utils.Lock. The original code also
// returned exactly that error on every path; its success flag was never used.
//
// NOTE(review): KEYS is O(N) and blocks Redis; SCAN would be safer. contentKey
// is formatted with the hash only, so the pattern presumably matches keys of
// every dataset: another dataset's cached keys skip this dataset's reload.
func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Document, lockKey, contentKey string) (err error) {
	docsLockKey := fmt.Sprintf(lockKey, doc.DatasetId)
	_, err = utils.Lock(ctx, docsLockKey, int64(60), func(ctx context.Context) error {
		pattern := fmt.Sprintf(contentKey, "*")
		keys, err := g.Redis().Keys(ctx, pattern)
		if err != nil {
			return err
		}

		// Cache already present: only extend its lifetime (600 s), skip loading.
		if len(keys) > 0 {
			for _, key := range keys {
				if _, err = g.Redis().Expire(ctx, key, 600); err != nil {
					return err
				}
			}
			return nil
		}

		// Cache empty: load history from the backend matching contentKey.
		var dictData []*dto.DocumentVectorRPC
		if contentKey == public.KnowledgeContentHashSqlKey {
			dictData, err = s.getHistoryDataFromHttp(ctx, doc)
		} else {
			dictData, err = s.getHistoryDataFromMeilisearch(ctx, doc)
		}
		if err != nil {
			return err
		}

		for _, item := range dictData {
			key := fmt.Sprintf(contentKey, strings.Trim(item.ContentHash, `"`))
			// A set, so several documents can own the same chunk without duplicates.
			if _, err = g.Redis().SAdd(ctx, key, item.DocumentId); err != nil {
				return err
			}
			// Best-effort TTL; a failure here only affects cache lifetime.
			_, _ = g.Redis().Expire(ctx, key, 600)
		}
		return nil
	})
	return err
}
|
||
|
||
// getHistoryDataFromHttp loads the document-vector rows of doc's dataset
// through dao.DocumentVector.List and converts them to DocumentVectorRPC
// values. The total count returned by the DAO is ignored.
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
	rows, _, listErr := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
		DatasetId: doc.DatasetId,
	})
	if listErr != nil {
		return nil, listErr
	}

	err = gconv.Struct(rows, &dictData)
	return dictData, err
}
|
||
|
||
// getHistoryDataFromMeilisearch fetches the chunk documents of doc's dataset
// from the public.IndexNameDocumentChunk index and converts each hit to a
// DocumentVectorRPC.
//
// On a search error it returns nil. On a conversion error it returns the
// items converted so far together with the error.
//
// NOTE(review): the search is capped at 10000 hits, so larger datasets are
// silently truncated. The filter uses the literal field "datasetId" while
// recursiveSplitDocument stores entity.DocumentVectorCol.DatasetId; confirm
// these names match.
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
	params := &meilisearch.SearchParams{
		Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
		Limit:  10000,
	}

	var hits []map[string]interface{}
	if _, err = meilisearch.DB().Search(ctx, params, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}

	dictData = make([]*dto.DocumentVectorRPC, 0, len(hits))
	for _, hit := range hits {
		item := new(dto.DocumentVectorRPC)
		if err = gconv.Struct(hit, item); err != nil {
			return dictData, err
		}
		dictData = append(dictData, item)
	}
	return dictData, nil
}
|
||
|
||
// checkRepeatWithDocId classifies the chunk identified by contentHash
// relative to document currentDocId. It uses the Redis set stored at
// contentKey formatted with contentHash, which holds the IDs of documents
// known to contain that chunk.
//
// Results:
//   - isNew=false, needCopy=false: currentDocId is already in the set (skip).
//   - isNew=true,  needCopy=false: the key does not exist (brand-new chunk).
//   - isNew=false, needCopy=true:  the key exists without currentDocId
//     (chunk shared with another document).
//
// In the last two cases currentDocId is added to the set and the key's TTL is
// set to 600 seconds; TTL errors are ignored, and err carries the SAdd error.
// The membership check, existence check and add are separate Redis commands,
// so the classification is not atomic across concurrent callers.
func (s *documentService) checkRepeatWithDocId(ctx context.Context, contentKey string, contentHash string, currentDocId int64) (isNew bool, needCopy bool, err error) {
	key := fmt.Sprintf(contentKey, contentHash)

	member, err := g.Redis().SIsMember(ctx, key, currentDocId)
	switch {
	case err != nil:
		return false, false, err
	case !g.IsEmpty(member):
		// Already owned by this document.
		return false, false, nil
	}

	found, err := g.Redis().Exists(ctx, key)
	if err != nil {
		return false, false, err
	}
	keyIsNew := g.IsEmpty(found)

	// In both remaining cases, record that this document owns the chunk.
	_, err = g.Redis().SAdd(ctx, key, currentDocId)
	_, _ = g.Redis().Expire(ctx, key, 600)
	return keyIsNew, !keyIsNew, err
}
|