diff --git a/config.yml b/config.yml index 9c1aef1..808c664 100644 --- a/config.yml +++ b/config.yml @@ -1,7 +1,6 @@ server: address: :3006 name: rag - workerId: 1 # Database. database: diff --git a/consts/keyword/type.go b/consts/keyword/type.go new file mode 100644 index 0000000..5844f5c --- /dev/null +++ b/consts/keyword/type.go @@ -0,0 +1,26 @@ +package keyword + +import "github.com/gogf/gf/v2/util/gconv" + +var ( + KeywordTypeDefined = newKeywordType(gconv.PtrInt8(1), "自定义") + KeywordTypeInitial = newKeywordType(gconv.PtrInt8(2), "初始化") +) + +type KeywordType *int8 + +type keywordType struct { + code KeywordType + desc string +} + +func (s keywordType) Code() KeywordType { + return s.code +} +func (s keywordType) Desc() string { + return s.desc +} + +func newKeywordType(code KeywordType, desc string) keywordType { + return keywordType{code: code, desc: desc} +} diff --git a/controller/document.go b/controller/document.go index 9da0aea..49b49d6 100644 --- a/controller/document.go +++ b/controller/document.go @@ -32,7 +32,7 @@ func (c *document) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (res } // Get 获取文件详情 -func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) { +func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) { res, err = service.Document.Get(ctx, req) return } diff --git a/dao/document_vector.go b/dao/document_vector.go index e51370e..96ecd5b 100644 --- a/dao/document_vector.go +++ b/dao/document_vector.go @@ -10,6 +10,7 @@ import ( "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/full-text-search/meilisearch" "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "github.com/pgvector/pgvector-go" @@ -34,7 +35,7 @@ func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDo // Update 更新文件块 func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) { - model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector) + model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).OmitEmpty() r, err := model.Data(&req).Where(entity.DocumentVectorCol.Id, req.Id).Update() if err != nil { return @@ -48,8 +49,11 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto Where(entity.DocumentVectorCol.DatasetId, req.DatasetId). Where(entity.DocumentVectorCol.DocumentId, req.DocumentId). Where(entity.DocumentVectorCol.Status, req.Status). - Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus). - OrderDesc(entity.DocumentVectorCol.CreatedAt) + Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus) + if !g.IsEmpty(req.Keyword) { + model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%") + } + model.OrderDesc(entity.DocumentVectorCol.CreatedAt) if req.Page != nil { model.Page(int(req.Page.PageNum), int(req.Page.PageSize)) } @@ -62,13 +66,27 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto } func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) { - result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector). - Fields("id, content, dataset_id, document_id, vector <=> ? AS distance", vector). - WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds). - WhereNotNull(entity.DocumentVectorCol.Vector). - OrderAsc("distance"). - Limit(topK). - All() + //result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector). + // Fields("id, content, dataset_id, document_id, vector <=> ? AS distance"). + // WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds). + // WhereNotNull(entity.DocumentVectorCol.Vector). + // OrderAsc("distance"). + // Limit(topK). + // All() + //if err != nil { + // return nil, err + //} + sql := ` + SELECT id, content, dataset_id, document_id, + vector <=> ? AS distance + FROM rag_vector_document_vector + WHERE dataset_id IN (?) + AND vector IS NOT NULL + ORDER BY distance ASC + LIMIT ? +` + // 顺序:vector, dataset_id, topK + result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, vector, datasetIds, topK) if err != nil { return nil, err } diff --git a/dao/keyword.go b/dao/keyword.go index a0cc8fe..589799f 100644 --- a/dao/keyword.go +++ b/dao/keyword.go @@ -44,7 +44,7 @@ func (d *keywordDao) BatchSaveOrUpdate(ctx context.Context, req []*dto.CreateKey } func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (rows int64, err error) { - model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword) + model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).OmitEmpty() r, err := model.Data(&req).Where(entity.KeywordCol.Id, req.Id).Update() if err != nil { return diff --git a/dao/task.go b/dao/task.go index acf08c1..023ccb2 100644 --- a/dao/task.go +++ b/dao/task.go @@ -29,8 +29,8 @@ func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, // Update 更新任务 func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) { - model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask) - r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update() + model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty() + r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).Update() if err != nil { return } diff --git a/main.go b/main.go index 9d793bf..d9853e4 100644 --- a/main.go +++ b/main.go @@ -33,8 +33,7 @@ func main() { controller.Keyword, }) - err := utils.InitGseTool(ctx) - if err != nil { + if err := utils.InitGseTool(ctx); err != nil { g.Log().Error(ctx, "gse 分词工具初始化失败:", err) } @@ -46,7 +45,7 @@ func main() { Port: redisAddressList[1], }, }) - if err = gmq.GetGmq(public.GmqMsgPluginsName).GmqSubscribe(ctx, &mq.RedisSubMessage{ + if err := gmq.GetGmq(public.GmqMsgPluginsName).GmqSubscribe(ctx, &mq.RedisSubMessage{ SubMessage: types.SubMessage{ Topic: public.KnowledgeDocumentVectorTopic, ConsumerName: public.KnowledgeDocumentVectorConsumer, diff --git a/model/dto/dataset.go b/model/dto/dataset.go index a2aa297..c634e73 100644 --- a/model/dto/dataset.go +++ b/model/dto/dataset.go @@ -49,6 +49,7 @@ type ListDatasetReq struct { g.Meta `path:"/list" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)列表" dc:"分页查询知识库(数据集)列表,支持多条件筛选"` Page *beans.Page `json:"page"` + Ids []int64 `json:"ids" dc:"数据集ID列表"` Keyword string `json:"keyword" dc:"关键词搜索"` } diff --git a/model/dto/document.go b/model/dto/document.go index d01f348..c597a33 100644 --- a/model/dto/document.go +++ b/model/dto/document.go @@ -48,6 +48,11 @@ type GetDocumentReq struct { Id int64 `json:"id" v:"required#ID不能为空"` } +type GetDocumentRes struct { + *DocumentVO + ImgAddressPrefix string `json:"imgAddressPrefix"` +} + // ListDocumentReq 文件列表请求 type ListDocumentReq struct { g.Meta `path:"/list" method:"get" tags:"文件管理" summary:"获取文件列表" dc:"分页查询文件列表,支持多条件筛选"` @@ -68,6 +73,8 @@ type DocumentVO struct { Id int64 `json:"id,string" dc:"id"` DatasetId int64 `json:"datasetId,string"` Title string `json:"title" dc:"文件标题"` + Format string `orm:"format" json:"format" dc:"文件格式"` + FilePath string `orm:"file_path" json:"filePath" dc:"文件存储路径"` Status document.Status `json:"status" dc:"状态1启用/0停用"` VectorStatus document.VectorStatus `json:"vectorStatus" dc:"向量化状态 状态: 1 待定, 2 处理, 3 完成, 4 失败"` ChunkCount int64 `json:"chunkCount" dc:"分块数"` diff --git a/model/dto/document_vector.go b/model/dto/document_vector.go index 27724ee..31c2acb 100644 --- a/model/dto/document_vector.go +++ b/model/dto/document_vector.go @@ -42,6 +42,7 @@ type ListDocumentVectorReq struct { g.Meta `path:"/list" method:"get" tags:"文件块向量管理" summary:"获取文件块向量列表" dc:"分页查询文件块向量列表,支持多条件筛选"` Page *beans.Page `json:"page"` + Keyword string `json:"keyword" dc:"关键词搜索"` DatasetId int64 `json:"datasetId"` DocumentId int64 `json:"documentId"` Status document.Status `json:"status"` diff --git a/model/dto/keyword.go b/model/dto/keyword.go index 5fd2a72..cd2ecce 100644 --- a/model/dto/keyword.go +++ b/model/dto/keyword.go @@ -1,6 +1,8 @@ package dto import ( + "rag/consts/keyword" + "gitea.com/red-future/common/beans" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gtime" @@ -10,10 +12,11 @@ import ( type CreateKeywordReq struct { g.Meta `path:"/create" method:"post" tags:"关键词管理" summary:"创建关键词" dc:"创建关键词"` - DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"` - DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"` - Word string `json:"word" v:"required#名称不能为空"` - Weight int16 `json:"weight" v:"required#权重不能为空"` + DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"` + DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"` + Word string `json:"word" v:"required#名称不能为空"` + Weight int16 `json:"weight" v:"required#权重不能为空"` + KeywordType keyword.KeywordType `json:"keywordType" v:"required#类型不能为空"` } // CreateKeywordRes 创建关键词响应 @@ -48,12 +51,13 @@ type GetKeywordReq struct { type ListKeywordReq struct { g.Meta `path:"/list" method:"get" tags:"关键词管理" summary:"获取关键词列表" dc:"分页查询关键词列表,支持多条件筛选"` - Page *beans.Page `json:"page"` - DatasetId int64 `json:"datasetId"` - DocumentId int64 `json:"documentId"` - Word string `json:"word"` - Words []string `json:"words"` - Keyword string `json:"keyword" dc:"关键词搜索"` + Page *beans.Page `json:"page"` + DatasetId int64 `json:"datasetId"` + DocumentId int64 `json:"documentId"` + Word string `json:"word"` + Words []string `json:"words"` + Keyword string `json:"keyword" dc:"关键词搜索"` + KeywordType keyword.KeywordType `json:"keywordType"` } // ListKeywordRes 关键词列表响应 @@ -63,11 +67,12 @@ type ListKeywordRes struct { } type KeywordVO struct { - Id int64 `json:"id,string" dc:"id"` - Word string `json:"word" dc:"关键词名称"` - Weight int16 `json:"weight" dc:"权重"` - DatasetId int64 `json:"datasetId,string" dc:"数据集ID"` - DocumentId int64 `json:"documentId,string" dc:"文档ID"` - CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"` - UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"` + Id int64 `json:"id,string" dc:"id"` + Word string `json:"word" dc:"关键词名称"` + Weight int16 `json:"weight" dc:"权重"` + KeywordType keyword.KeywordType `json:"keywordType" dc:"类型"` + DatasetId int64 `json:"datasetId,string" dc:"数据集ID"` + DocumentId int64 `json:"documentId,string" dc:"文档ID"` + CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"` + UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"` } diff --git a/model/entity/keyword.go b/model/entity/keyword.go index 9c7c4f7..e949ce2 100644 --- a/model/entity/keyword.go +++ b/model/entity/keyword.go @@ -1,27 +1,34 @@ package entity -import "gitea.com/red-future/common/beans" +import ( + "rag/consts/keyword" + + "gitea.com/red-future/common/beans" +) type keywordCol struct { beans.SQLBaseCol - DatasetId string - DocumentId string - Word string - Weight string + DatasetId string + DocumentId string + Word string + Weight string + KeywordType string } var KeywordCol = keywordCol{ - SQLBaseCol: beans.DefSQLBaseCol, - DatasetId: "dataset_id", - DocumentId: "document_id", - Word: "word", - Weight: "weight", + SQLBaseCol: beans.DefSQLBaseCol, + DatasetId: "dataset_id", + DocumentId: "document_id", + Word: "word", + Weight: "weight", + KeywordType: "keyword_type", } type Keyword struct { beans.SQLBaseDO `orm:",inline"` - DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"` - DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"` - Word string `orm:"word" json:"word" dc:"关键词"` - Weight int16 `orm:"weight" json:"weight" dc:"权重"` + DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"` + DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"` + Word string `orm:"word" json:"word" dc:"关键词"` + Weight int16 `orm:"weight" json:"weight" dc:"权重"` + KeywordType keyword.KeywordType `orm:"keyword_type" json:"keywordType" dc:"类型"` } diff --git a/service/dataset.go b/service/dataset.go index fe7437b..67c8a9d 100644 --- a/service/dataset.go +++ b/service/dataset.go @@ -45,43 +45,3 @@ func (s *datasetService) List(ctx context.Context, req *dto.ListDatasetReq) (res err = gconv.Struct(list, &res.List) return } - -//// Search 搜索(示例,实际需要调用向量库) -//func (s *datasetService) Search(ctx context.Context, req *dto.SearchReq) (res *dto.SearchRes, err error) { -// // 1. 获取数据集信息 -// kb, err := dao.Dataset.GetByID(ctx, req) -// if err != nil { -// return nil, err -// } -// -// // 2. 获取文件块 -// chunks, err := dao.Chunk.FindChunksByKBIDWithLimit(ctx, req.KBID, 0, req.TopK) -// if err != nil { -// return nil, err -// } -// -// // 3. TODO: 使用向量检索(需要集成向量库) -// // 暂时使用简单的关键词匹配 -// results := make([]dto.SearchResult, 0) -// for _, chunk := range chunks { -// results = append(results, dto.SearchResult{ -// Content: chunk.Content, -// Score: 0.8, // TODO: 计算实际向量相似度 -// DocumentID: chunk.DocumentID, -// ChunkIndex: chunk.Index, -// }) -// } -// -// g.Log().Infof(ctx, "数据集[%s]搜索完成,查询:%s,结果数:%d", kb.Name, req.Query, len(results)) -// -// return &dto.SearchRes{Results: results}, nil -//} -// -//// formatChunks 格式化文件块为上下文 -//func (s *datasetService) formatChunks(chunks []*entity.DocumentChunk) string { -// var sb strings.Builder -// for i, chunk := range chunks { -// sb.WriteString(fmt.Sprintf("[%d] %s\n\n", i+1, chunk.Content)) -// } -// return sb.String() -//} diff --git a/service/document.go b/service/document.go index 6caac1f..1328eee 100644 --- a/service/document.go +++ b/service/document.go @@ -6,6 +6,7 @@ import ( "fmt" "rag/common/eino" "rag/consts/document" + "rag/consts/keyword" "rag/consts/public" "rag/consts/task" "rag/dao" @@ -104,9 +105,17 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq } // Get 获取文件详情 -func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) { +func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) { r, err := dao.Document.GetByID(ctx, req) - err = gconv.Struct(r, &res) + if err != nil { + return + } + res = &dto.GetDocumentRes{} + err = gconv.Struct(r, &res.DocumentVO) + if err != nil { + return + } + res.ImgAddressPrefix, err = utils.GetFileAddressPrefix(ctx) return } @@ -280,10 +289,11 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum var keywordReqs = make([]*dto.CreateKeywordReq, 0) for _, word := range words { keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{ - DatasetId: doc.DatasetId, - DocumentId: doc.Id, - Word: word.Word, - Weight: gconv.Int16(word.Score), + DatasetId: doc.DatasetId, + DocumentId: doc.Id, + Word: word.Word, + Weight: gconv.Int16(word.Score), + KeywordType: keyword.KeywordTypeInitial.Code(), }) } if len(keywordReqs) > 0 {