85 lines
2.4 KiB
Go
85 lines
2.4 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"rag/common/eino"
|
|
"rag/common/task"
|
|
"rag/dao"
|
|
"rag/model/dto"
|
|
"rag/model/entity"
|
|
|
|
"gitea.com/red-future/common/beans"
|
|
"github.com/cloudwego/eino/components/indexer"
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/gogf/gf/v2/util/gconv"
|
|
)
|
|
|
|
// DocumentChunk is the package-level instance through which callers reach
// the document-chunk service methods.
var DocumentChunk = &documentChunkService{}

// documentChunkService groups the document-chunk operations. It has no
// fields, so a zero value is ready to use.
type documentChunkService struct{}
|
|
|
|
// Update 更新文件块
|
|
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
|
_, err = dao.DocumentChunk.Update(ctx, req)
|
|
return
|
|
}
|
|
|
|
// List 获取文件块列表
|
|
func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentChunkReq) (res *dto.ListDocumentChunkRes, err error) {
|
|
list, total, err := dao.DocumentChunk.List(ctx, req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
res = &dto.ListDocumentChunkRes{
|
|
Total: total,
|
|
}
|
|
err = gconv.Struct(list, &res.List)
|
|
return
|
|
}
|
|
|
|
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
|
var docs = make([]*schema.Document, 0)
|
|
msgMap := gconv.Map(msg)
|
|
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
|
|
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
|
return
|
|
}
|
|
if len(docs) == 0 {
|
|
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
|
return
|
|
}
|
|
ctx = context.WithValue(ctx, "user", &beans.User{
|
|
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]),
|
|
UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]),
|
|
})
|
|
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
|
BatchSize: 10,
|
|
})
|
|
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
|
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
|
if err != nil || rows == 0 {
|
|
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
|
// 写入任务进度失败 任务类型为sql存储
|
|
remark := " 向量存储数量: " + gconv.String(rows)
|
|
if err != nil {
|
|
remark = "向量存储失败: " + err.Error()
|
|
}
|
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
|
TaskId: documentId,
|
|
TaskType: task.TaskTypeGenerateVector,
|
|
Status: task.TaskStatusFailed,
|
|
Remark: remark,
|
|
})
|
|
return
|
|
}
|
|
// 写入任务进度成功 任务类型为sql存储
|
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
|
TaskId: documentId,
|
|
TaskType: task.TaskTypeGenerateVector,
|
|
Status: task.TaskStatusCompleted,
|
|
Remark: "向量生成完成",
|
|
})
|
|
return
|
|
}
|