Files
rag/service/document.go
2026-04-09 15:58:05 +08:00

719 lines
21 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"fmt"
"rag/common/eino"
"rag/common/task"
"rag/consts/document"
"rag/consts/public"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"strings"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/crypto/gmd5"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
)
// Document is the package-level documentService instance.
var Document = new(documentService)
// documentService carries the document-related service methods; it has no fields.
type documentService struct{}
// Create inserts a new document and registers its pending parse task.
//
// Inside one database transaction it:
//  1. inserts the document row via dao.Document.Insert,
//  2. calls dao.Dataset.Update with DocumentCount 1 and the uploaded file size
//     (presumably applied as increments, mirroring the negative values Delete
//     passes — confirm in the DAO),
//  3. writes a pending doc-parse task progress record keyed by the new document id.
//
// On success it returns the new document id. On any failure it returns a nil
// result together with the error, so callers never receive the id of a
// document whose transaction did not complete.
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
	var id int64
	err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		var txErr error
		if id, txErr = dao.Document.Insert(ctx, req); txErr != nil {
			return txErr
		}
		datasetReq := &dto.UpdateDatasetReq{
			Id:            req.DatasetId,
			DocumentCount: 1,
			DocumentSize:  req.FileSize,
		}
		if _, txErr = dao.Dataset.Update(ctx, datasetReq); txErr != nil {
			return txErr
		}
		// Record the document-parse task as pending.
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   id,
			TaskType: task.TaskTypeDocParse,
			Status:   task.TaskStatusPending,
			Remark:   "文档上传成功待解析: " + req.Title,
		})
	})
	if err != nil {
		return nil, err
	}
	// Build the result only once every step has succeeded.
	return &dto.CreateDocumentRes{Id: id}, nil
}
// Update applies req to the document through dao.Document.Update.
// The DAO's first return value is discarded; only the error is reported.
func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq) (err error) {
	if _, updateErr := dao.Document.Update(ctx, req); updateErr != nil {
		return updateErr
	}
	return nil
}
// Delete removes a document together with its bookkeeping.
//
// It first loads the document to learn its dataset and file size, then, inside
// one database transaction:
//  1. calls dao.Dataset.Update with DocumentCount -1 and the negated file size,
//  2. deletes the document row,
//  3. deletes the task records whose task id equals the document id.
//
// It returns "document not found" when the lookup yields no document (the same
// check and message Process uses), instead of dereferencing a nil record.
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
	docs, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(docs) {
		return errors.New("document not found")
	}
	return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		datasetReq := &dto.UpdateDatasetReq{
			Id:            docs.DatasetId,
			DocumentCount: -1,
			DocumentSize:  -docs.FileSize,
		}
		if _, txErr := dao.Dataset.Update(ctx, datasetReq); txErr != nil {
			return txErr
		}
		if _, txErr := dao.Document.Delete(ctx, req); txErr != nil {
			return txErr
		}
		_, txErr := dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
			TaskId: docs.Id,
		})
		return txErr
	})
}
// Get returns the document identified by req, converted to a DocumentVO.
//
// A lookup error from dao.Document.GetByID is returned as-is (it used to be
// overwritten by the conversion result and silently lost). A conversion error
// from gconv.Struct is returned with a nil result.
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
	r, err := dao.Document.GetByID(ctx, req)
	if err != nil {
		return nil, err
	}
	if err = gconv.Struct(r, &res); err != nil {
		return nil, err
	}
	return res, nil
}
// List queries documents through dao.Document.List and returns them with the
// total count. The DAO rows are converted into the response list with
// gconv.Struct; a conversion error is returned alongside the response that
// already carries Total.
func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto.ListDocumentRes, err error) {
	rows, count, err := dao.Document.List(ctx, req)
	if err != nil {
		return nil, err
	}
	out := &dto.ListDocumentRes{Total: count}
	err = gconv.Struct(rows, &out.List)
	return out, err
}
// Process starts asynchronous processing of an uploaded document.
//
// Synchronously it:
//  1. loads the document (returns "document not found" when it is empty),
//  2. sets the document's vector status to "processing",
//  3. writes a running doc-parse task progress record.
//
// It then submits three background jobs to grpool that share one cancellable
// context: sqlSplitDocument, esSplitDocument and extractDocument. Any job that
// returns an error or panics (caught by g.TryCatch) cancels the shared context
// so the sibling jobs can stop at their next cancellation check.
//
// Process returns as soon as the jobs are submitted; it does not wait for them.
// It returns an error when the document cannot be loaded or updated, when the
// progress record cannot be written, or when a job cannot be submitted.
//
// NOTE(review): the jobs' context derives from ctx. If callers pass a
// request-scoped context that is cancelled when the request ends, the
// background jobs will be cancelled with it — confirm against the caller.
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
	// 1. Load the document record.
	doc, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	// 2. Mark the document as being processed.
	updateDocumentReq := new(dto.UpdateDocumentReq)
	updateDocumentReq.Id = req.Id
	updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
	if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
		// Record the failure, but keep reporting the update error itself:
		// a successful progress write must not turn the failure into nil.
		progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   req.Id,
			TaskType: task.TaskTypeDocParse,
			Status:   task.TaskStatusFailed,
			Remark:   "更新文档状态失败: " + err.Error(),
		})
		if progressErr != nil {
			return fmt.Errorf("%w (task progress not recorded: %v)", err, progressErr)
		}
		return err
	}

	// 3. Record that document parsing has started.
	err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeDocParse,
		Status:   task.TaskStatusRunning,
		Remark:   "文档解析开始",
	})
	if err != nil {
		return err
	}

	// 4. Run the three processing jobs in the goroutine pool.
	taskCtx, cancel := context.WithCancel(ctx)
	steps := []func(context.Context, *entity.Document) error{
		s.sqlSplitDocument, // job 1: semantic split, chunks published for vector generation
		s.esSplitDocument,  // job 2: recursive split, chunks written to meilisearch
		s.extractDocument,  // job 3: keyword extraction
	}
	// Buffered so finishing jobs never block on the watcher below.
	finished := make(chan struct{}, len(steps))
	started := 0
	for _, step := range steps {
		step := step // per-iteration copy for toolchains older than Go 1.22
		addErr := grpool.Add(taskCtx, func(ctx context.Context) {
			defer func() { finished <- struct{}{} }()
			g.TryCatch(ctx, func(ctx context.Context) {
				if innerErr := step(ctx, doc); innerErr != nil {
					cancel()
				}
			}, func(ctx context.Context, _ error) {
				cancel()
			})
		})
		if addErr != nil {
			// The job could not be submitted: stop the ones already running
			// and report the problem instead of pretending processing started.
			err = addErr
			cancel()
			break
		}
		started++
	}

	// Release the cancellable context once every submitted job has finished;
	// otherwise it is never cancelled when all jobs succeed.
	go func() {
		for i := 0; i < started; i++ {
			<-finished
		}
		cancel()
	}()
	return err
}
// extractDocument extracts keywords from a document and stores them.
//
// Flow:
//  1. load the document via loadDocument;
//  2. choose an extraction strategy from the byte length of the first loaded
//     document's content: under 500 bytes extract with argument 4, under 2000
//     with 8, under 5000 with 13; otherwise split all loaded documents with
//     eino.RecursiveSplitDocument and extract with argument 6 from every piece
//     (the numeric argument is presumably the keyword count — confirm in GseTool);
//  3. save the keywords with dao.Keyword.BatchSaveOrUpdate and write a
//     completed task progress record.
//
// The context is checked at entry, inside the per-piece loop and before the
// batch save. Every failure writes a failed task progress record AND returns
// the underlying error, so the caller can cancel the sibling jobs. If the
// progress record itself cannot be written, that error is appended to the
// returned one.
//
// An empty load result is treated as "no keywords" rather than indexing into
// an empty slice.
//
// NOTE(review): on cancellation the failure record is written with the already
// cancelled ctx, so that write may itself fail — confirm how Task.WriteTaskProgress
// treats a cancelled context.
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
	// fail records a failed keyword task and hands back the original cause.
	fail := func(remark string, cause error) error {
		progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeExtractKeywords,
			Status:   task.TaskStatusFailed,
			Remark:   remark + cause.Error(),
		})
		if progressErr != nil {
			return fmt.Errorf("%w (task progress not recorded: %v)", cause, progressErr)
		}
		return cause
	}
	// complete records a completed keyword task.
	complete := func(remark string) error {
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeExtractKeywords,
			Status:   task.TaskStatusCompleted,
			Remark:   remark,
		})
	}

	// Cancellation check 1: on entry.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	// 1. Load the file.
	docs, err := s.loadDocument(ctx, doc)
	if err != nil {
		return fail("加载文件失败: ", err)
	}

	// 2. Extract keywords; short content is handled in one pass, long content per split piece.
	var words []utils.Keyword
	if len(docs) > 0 {
		content := docs[0].Content
		switch size := len(content); {
		case size < 500:
			words = utils.GseTool.Extract(content, 4)
		case size < 2000:
			words = utils.GseTool.Extract(content, 8)
		case size < 5000:
			words = utils.GseTool.Extract(content, 13)
		default:
			docsSplit, splitErr := eino.RecursiveSplitDocument(ctx, docs)
			if splitErr != nil {
				return fail("递归分割文档失败: ", splitErr)
			}
			for _, piece := range docsSplit {
				// Cancellation check 2: inside the loop.
				if cancelErr := ctx.Err(); cancelErr != nil {
					return fail("ctx取消: ", cancelErr)
				}
				words = append(words, utils.GseTool.Extract(piece.Content, 6)...)
			}
		}
	}

	// Cancellation check 3: before the batch write.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	if len(words) == 0 {
		return complete("没有提取到关键词,关键字提取完成")
	}

	// 3. Persist the keywords.
	keywordReqs := make([]*dto.CreateKeywordReq, 0, len(words))
	for _, word := range words {
		keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
			DatasetId:  doc.DatasetId,
			DocumentId: doc.Id,
			Word:       word.Word,
			Weight:     gconv.Int16(word.Score),
		})
	}
	if _, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs); err != nil {
		return fail("关键字存储失败: ", err)
	}
	return complete("关键字提取完成")
}
// sqlSplitDocument splits a document semantically and hands the new chunks to
// the message queue for vector generation.
//
// Flow:
//  1. load the document via loadDocument;
//  2. split it with eino.SemanticSplitDocument;
//  3. prepare the content-hash keys in Redis via getHistoryData;
//  4. for each chunk compute the MD5 of its content and call checkRepeat;
//     chunks for which checkRepeat reports false are skipped, the others get
//     tenant/creator/dataset/document/hash/index metadata attached;
//  5. publish the remaining chunks to public.KnowledgeDocumentChunkTopic and
//     write a running progress record, or write a completed record when
//     nothing is left to publish.
//
// The context is checked at entry, on every loop iteration and before
// publishing. Every failure writes a failed task progress record AND returns
// the underlying error, so the caller can cancel the sibling jobs. If the
// progress record itself cannot be written, that error is appended to the
// returned one.
//
// NOTE(review): on cancellation the failure record is written with the already
// cancelled ctx, so that write may itself fail — confirm how Task.WriteTaskProgress
// treats a cancelled context.
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
	// fail records a failed vector-generation task and hands back the original cause.
	fail := func(remark string, cause error) error {
		progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeGenerateVector,
			Status:   task.TaskStatusFailed,
			Remark:   remark + cause.Error(),
		})
		if progressErr != nil {
			return fmt.Errorf("%w (task progress not recorded: %v)", cause, progressErr)
		}
		return cause
	}

	// Cancellation check 1: on entry.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	// 1. Load the file.
	docs, err := s.loadDocument(ctx, doc)
	if err != nil {
		return fail("加载文件失败: ", err)
	}

	// 2. Split the file semantically.
	docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
	if err != nil {
		return fail("文档切分失败: ", err)
	}

	// 3. Prepare the historical content hashes.
	if err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey); err != nil {
		return fail("获取历史数据失败: ", err)
	}

	// 4. Assemble the chunks that still need a vector.
	docsChunk := make([]*schema.Document, 0, len(docsSplit))
	for i, chunk := range docsSplit {
		// Cancellation check 2: inside the loop.
		if cancelErr := ctx.Err(); cancelErr != nil {
			return fail("ctx取消: ", cancelErr)
		}
		contentHash := gmd5.MustEncryptString(chunk.Content)
		fresh, checkErr := s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
		if checkErr != nil {
			return fail("检查重复数据失败: ", checkErr)
		}
		if !fresh {
			// The hash key already existed, so this content is skipped.
			continue
		}
		chunk.MetaData = map[string]any{
			entity.DocumentCol.TenantId:         doc.TenantId,
			entity.DocumentCol.Creator:          doc.Creator,
			entity.DocumentCol.DatasetId:        doc.DatasetId,
			entity.DocumentChunkCol.DocumentId:  doc.Id,
			entity.DocumentChunkCol.ContentHash: contentHash,
			entity.DocumentChunkCol.ChunkIndex:  gconv.Int64(i),
		}
		docsChunk = append(docsChunk, chunk)
	}

	// Cancellation check 3: before publishing.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	// 5. Nothing new: the task is already complete.
	if len(docsChunk) == 0 {
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeGenerateVector,
			Status:   task.TaskStatusCompleted,
			Remark:   "无需生成向量,任务完成",
		})
	}

	// 5. Publish the chunks; the task stays "running" after submission.
	err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
		PubMessage: types.PubMessage{
			Topic: public.KnowledgeDocumentChunkTopic,
			Data:  docsChunk,
		},
	})
	if err != nil {
		return fail("发送消息到队列失败: ", err)
	}
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   doc.Id,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusRunning,
		Remark:   "向量生成任务已提交到队列",
	})
}
// esSplitDocument splits a document recursively and writes the new chunks to
// meilisearch for full-text search.
//
// Flow:
//  1. load the document via loadDocument;
//  2. split it with eino.RecursiveSplitDocument;
//  3. prepare the content-hash keys in Redis via getHistoryData;
//  4. for each chunk compute the MD5 of its content and call checkRepeat;
//     chunks for which checkRepeat reports false are skipped, the others are
//     turned into meilisearch documents whose id is the content hash;
//  5. insert the documents into public.IndexNameDocumentChunk and write a
//     completed progress record (also written when nothing is left to insert).
//
// The context is checked at entry, on every loop iteration and before the
// insert. Every failure writes a failed task progress record AND returns the
// underlying error, so the caller can cancel the sibling jobs. If the progress
// record itself cannot be written, that error is appended to the returned one.
//
// NOTE(review): on cancellation the failure record is written with the already
// cancelled ctx, so that write may itself fail — confirm how Task.WriteTaskProgress
// treats a cancelled context.
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
	// fail records a failed full-text task and hands back the original cause.
	fail := func(remark string, cause error) error {
		progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeFullTextSearch,
			Status:   task.TaskStatusFailed,
			Remark:   remark + cause.Error(),
		})
		if progressErr != nil {
			return fmt.Errorf("%w (task progress not recorded: %v)", cause, progressErr)
		}
		return cause
	}
	// complete records a completed full-text task.
	complete := func(remark string) error {
		return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
			TaskId:   doc.Id,
			TaskType: task.TaskTypeFullTextSearch,
			Status:   task.TaskStatusCompleted,
			Remark:   remark,
		})
	}

	// Cancellation check 1: on entry.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	// 1. Load the file.
	docs, err := s.loadDocument(ctx, doc)
	if err != nil {
		return fail("加载文件失败: ", err)
	}

	// 2. Split the file recursively.
	docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
	if err != nil {
		return fail("文档切分失败: ", err)
	}

	// 3. Prepare the historical content hashes.
	if err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey); err != nil {
		return fail("获取历史数据失败: ", err)
	}

	// 4. Build the meilisearch documents for the chunks that are not skipped.
	meiliDocs := make([]any, 0, len(docsSplit))
	for i, chunk := range docsSplit {
		// Cancellation check 2: inside the loop.
		if cancelErr := ctx.Err(); cancelErr != nil {
			return fail("ctx取消: ", cancelErr)
		}
		contentHash := gmd5.MustEncryptString(chunk.Content)
		fresh, checkErr := s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
		if checkErr != nil {
			return fail("检查重复数据失败: ", checkErr)
		}
		if !fresh {
			// The hash key already existed, so this content is skipped.
			continue
		}
		meiliDocs = append(meiliDocs, map[string]any{
			entity.DocumentChunkCol.Id:          contentHash,
			entity.DocumentChunkCol.DatasetId:   doc.DatasetId,
			entity.DocumentChunkCol.DocumentId:  doc.Id,
			entity.DocumentChunkCol.Content:     chunk.Content,
			entity.DocumentChunkCol.ContentHash: contentHash,
			entity.DocumentChunkCol.ChunkIndex:  i,
		})
	}

	// Cancellation check 3: before the batch write.
	if cancelErr := ctx.Err(); cancelErr != nil {
		return fail("ctx取消: ", cancelErr)
	}

	// 5. Nothing new: the task is already complete.
	if len(meiliDocs) == 0 {
		return complete("无需生成全文检索数据,任务完成")
	}

	// 5. Write the documents to meilisearch.
	if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
		return fail("写入meilisearch失败: ", err)
	}
	return complete("全文检索数据写入完成")
}
// loadDocument reads the document's file through eino.LoadDocument, passing
// the stored file path and format, and returns whatever that call yields.
func (s *documentService) loadDocument(ctx context.Context, doc *entity.Document) (docs []*schema.Document, err error) {
	docs, err = eino.LoadDocument(ctx, doc.FilePath, doc.Format)
	return docs, err
}
// getHistoryData makes sure the content-hash keys used for duplicate detection
// are present in Redis.
//
// The work runs inside utils.Lock, keyed by lockKey formatted with the dataset
// id and called with 60 (presumably the lock lifetime in seconds — confirm in
// utils.Lock):
//   - if Redis already holds keys matching contentKey formatted with "*", each
//     of them gets its expiry reset to 600 and nothing else is loaded;
//   - otherwise the existing chunks are fetched — over HTTP when contentKey is
//     the SQL hash key, from meilisearch otherwise — and one key per content
//     hash is written with SET NX and a 600 second expiry.
//
// Only the error reported by utils.Lock is returned; its boolean result is
// ignored.
//
// NOTE(review): the key pattern is not scoped by dataset, so keys written for
// any dataset make this skip the load — confirm that is intended.
// NOTE(review): what utils.Lock returns when the lock is not obtained is not
// visible here; if it reports (false, nil) this returns nil without loading.
func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Document, lockKey, contentKey string) (err error) {
	warmUp := func(ctx context.Context) error {
		// 1. Look for content-hash keys that are already cached.
		cached, keysErr := g.Redis().Keys(ctx, fmt.Sprintf(contentKey, "*"))
		if keysErr != nil {
			return keysErr
		}

		// 2. Cache present: only push the expiry out to 600, no lookup.
		if len(cached) > 0 {
			for _, cachedKey := range cached {
				if _, expireErr := g.Redis().Expire(ctx, cachedKey, 600); expireErr != nil {
					return expireErr
				}
			}
			return nil
		}

		// 3. Cache empty: pick the source that matches the key family.
		var (
			history []*dto.DocumentChunkRPC
			loadErr error
		)
		if contentKey == public.KnowledgeContentHashSqlKey {
			history, loadErr = s.getHistoryDataFromHttp(ctx, doc)
		} else {
			history, loadErr = s.getHistoryDataFromMeilisearch(ctx, doc)
		}
		if loadErr != nil {
			return loadErr
		}

		// 4. Seed Redis with one key per known content hash (600 second expiry).
		for _, chunk := range history {
			// Strip surrounding double quotes that may be left over from JSON.
			hash := strings.Trim(chunk.ContentHash, `"`)
			_, setErr := g.Redis().Set(ctx, fmt.Sprintf(contentKey, hash), true, gredis.SetOption{
				TTLOption: gredis.TTLOption{
					EX: gconv.PtrInt64(600),
				},
				NX: true,
			})
			if setErr != nil {
				return setErr
			}
		}
		return nil
	}

	_, err = utils.Lock(ctx, fmt.Sprintf(lockKey, doc.DatasetId), int64(60), warmUp)
	return err
}
// getHistoryDataFromHttp fetches the dataset's existing chunks from the
// listDocumentChunk endpoint.
//
// When ctx carries an incoming HTTP request, the first value of each of its
// headers is forwarded on the outgoing call. The query sends the dataset id
// and status 1; the List field of the decoded response is returned.
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
	// Copy the first value of every header of the incoming request, if there is one.
	forwarded := make(map[string]string)
	if r := g.RequestFromCtx(ctx); r != nil {
		for name, values := range r.Request.Header {
			if len(values) == 0 {
				continue
			}
			forwarded[name] = values[0]
		}
	}

	// Call the endpoint and decode its response.
	reply := &dto.ListDocumentChunkRPC{}
	err = http.Get(ctx, "rag_binary-vector/document/chunk/listDocumentChunk", forwarded, &reply,
		"datasetId", gconv.String(doc.DatasetId),
		"status", 1)
	if err != nil {
		return nil, err
	}
	return reply.List, nil
}
// getHistoryDataFromMeilisearch fetches the dataset's existing chunks from the
// meilisearch index public.IndexNameDocumentChunk.
//
// The search filters on "datasetId = <dataset id>" with a limit of 10000 hits,
// and each hit is converted into a DocumentChunkRPC with gconv.Struct. If a
// conversion fails, the items converted so far are returned with the error.
//
// NOTE(review): with Limit 10000, datasets holding more chunks than that would
// presumably be read only partially — confirm against the index size.
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
	var rawHits []map[string]interface{}
	query := &meilisearch.SearchParams{
		Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
		Limit:  10000,
	}

	// Run the search.
	if _, err = meilisearch.DB().Search(ctx, query, public.IndexNameDocumentChunk, &rawHits); err != nil {
		return dictData, err
	}

	// Convert every hit into the RPC shape.
	dictData = make([]*dto.DocumentChunkRPC, 0)
	for idx := range rawHits {
		chunk := new(dto.DocumentChunkRPC)
		if err = gconv.Struct(rawHits[idx], chunk); err != nil {
			return dictData, err
		}
		dictData = append(dictData, chunk)
	}
	return dictData, nil
}
// checkRepeat tries to claim the Redis key built from contentKey and
// contentHash using SET with NX and a 600 second expiry.
//
// success is the boolean value of the Redis reply; callers skip the chunk when
// it is false. A Redis error is returned with success false.
func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) {
	redisKey := fmt.Sprintf(contentKey, contentHash)
	claim := gredis.SetOption{
		TTLOption: gredis.TTLOption{
			EX: gconv.PtrInt64(600),
		},
		NX: true,
	}
	reply, err := g.Redis().Set(ctx, redisKey, true, claim)
	if err != nil {
		return false, err
	}
	return reply.Bool(), nil
}