feat: 新增关键词类型及优化查询逻辑
支持关键词类型区分，优化文件向量查询 SQL 及 DAO 更新逻辑，移除冗余配置和注释代码。
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
server:
|
||||
address: :3006
|
||||
name: rag
|
||||
workerId: 1
|
||||
|
||||
# Database.
|
||||
database:
|
||||
|
||||
26
consts/keyword/type.go
Normal file
26
consts/keyword/type.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package keyword
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
var (
|
||||
KeywordTypeDefined = newKeywordType(gconv.PtrInt8(1), "自定义")
|
||||
KeywordTypeInitial = newKeywordType(gconv.PtrInt8(2), "初始化")
|
||||
)
|
||||
|
||||
// KeywordType is the code of a keyword type, held as a pointer to int8.
// NOTE(review): presumably a pointer so that an unset value (nil) can be
// told apart from a real code in request filters — confirm against callers.
type KeywordType *int8
|
||||
|
||||
type keywordType struct {
|
||||
code KeywordType
|
||||
desc string
|
||||
}
|
||||
|
||||
func (s keywordType) Code() KeywordType {
|
||||
return s.code
|
||||
}
|
||||
func (s keywordType) Desc() string {
|
||||
return s.desc
|
||||
}
|
||||
|
||||
func newKeywordType(code KeywordType, desc string) keywordType {
|
||||
return keywordType{code: code, desc: desc}
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (c *document) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (res
|
||||
}
|
||||
|
||||
// Get 获取文件详情
|
||||
func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
|
||||
func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
|
||||
res, err = service.Document.Get(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
@@ -34,7 +35,7 @@ func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDo
|
||||
|
||||
// Update 更新文件块
|
||||
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector)
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).OmitEmpty()
|
||||
r, err := model.Data(&req).Where(entity.DocumentVectorCol.Id, req.Id).Update()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -48,8 +49,11 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto
|
||||
Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
|
||||
Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
|
||||
Where(entity.DocumentVectorCol.Status, req.Status).
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
|
||||
OrderDesc(entity.DocumentVectorCol.CreatedAt)
|
||||
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus)
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
|
||||
}
|
||||
model.OrderDesc(entity.DocumentVectorCol.CreatedAt)
|
||||
if req.Page != nil {
|
||||
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
|
||||
}
|
||||
@@ -62,13 +66,27 @@ func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVecto
|
||||
}
|
||||
|
||||
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
|
||||
Fields("id, content, dataset_id, document_id, vector <=> ? AS distance", vector).
|
||||
WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
|
||||
WhereNotNull(entity.DocumentVectorCol.Vector).
|
||||
OrderAsc("distance").
|
||||
Limit(topK).
|
||||
All()
|
||||
//result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
|
||||
// Fields("id, content, dataset_id, document_id, vector <=> ? AS distance").
|
||||
// WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
|
||||
// WhereNotNull(entity.DocumentVectorCol.Vector).
|
||||
// OrderAsc("distance").
|
||||
// Limit(topK).
|
||||
// All()
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <=> ? AS distance
|
||||
FROM rag_vector_document_vector
|
||||
WHERE dataset_id IN (?)
|
||||
AND vector IS NOT NULL
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
// 顺序:vector, dataset_id, topK
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, vector, datasetIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func (d *keywordDao) BatchSaveOrUpdate(ctx context.Context, req []*dto.CreateKey
|
||||
}
|
||||
|
||||
func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword)
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).OmitEmpty()
|
||||
r, err := model.Data(&req).Where(entity.KeywordCol.Id, req.Id).Update()
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
@@ -29,8 +29,8 @@ func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64,
|
||||
|
||||
// Update 更新任务
|
||||
func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask)
|
||||
r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty()
|
||||
r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).Update()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
5
main.go
5
main.go
@@ -33,8 +33,7 @@ func main() {
|
||||
controller.Keyword,
|
||||
})
|
||||
|
||||
err := utils.InitGseTool(ctx)
|
||||
if err != nil {
|
||||
if err := utils.InitGseTool(ctx); err != nil {
|
||||
g.Log().Error(ctx, "gse 分词工具初始化失败:", err)
|
||||
}
|
||||
|
||||
@@ -46,7 +45,7 @@ func main() {
|
||||
Port: redisAddressList[1],
|
||||
},
|
||||
})
|
||||
if err = gmq.GetGmq(public.GmqMsgPluginsName).GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
if err := gmq.GetGmq(public.GmqMsgPluginsName).GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
SubMessage: types.SubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorTopic,
|
||||
ConsumerName: public.KnowledgeDocumentVectorConsumer,
|
||||
|
||||
@@ -49,6 +49,7 @@ type ListDatasetReq struct {
|
||||
g.Meta `path:"/list" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)列表" dc:"分页查询知识库(数据集)列表,支持多条件筛选"`
|
||||
|
||||
Page *beans.Page `json:"page"`
|
||||
Ids []int64 `json:"ids" dc:"数据集ID列表"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,11 @@ type GetDocumentReq struct {
|
||||
Id int64 `json:"id" v:"required#ID不能为空"`
|
||||
}
|
||||
|
||||
type GetDocumentRes struct {
|
||||
*DocumentVO
|
||||
ImgAddressPrefix string `json:"imgAddressPrefix"`
|
||||
}
|
||||
|
||||
// ListDocumentReq 文件列表请求
|
||||
type ListDocumentReq struct {
|
||||
g.Meta `path:"/list" method:"get" tags:"文件管理" summary:"获取文件列表" dc:"分页查询文件列表,支持多条件筛选"`
|
||||
@@ -68,6 +73,8 @@ type DocumentVO struct {
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
DatasetId int64 `json:"datasetId,string"`
|
||||
Title string `json:"title" dc:"文件标题"`
|
||||
Format string `orm:"format" json:"format" dc:"文件格式"`
|
||||
FilePath string `orm:"file_path" json:"filePath" dc:"文件存储路径"`
|
||||
Status document.Status `json:"status" dc:"状态1启用/0停用"`
|
||||
VectorStatus document.VectorStatus `json:"vectorStatus" dc:"向量化状态 状态: 1 待定, 2 处理, 3 完成, 4 失败"`
|
||||
ChunkCount int64 `json:"chunkCount" dc:"分块数"`
|
||||
|
||||
@@ -42,6 +42,7 @@ type ListDocumentVectorReq struct {
|
||||
g.Meta `path:"/list" method:"get" tags:"文件块向量管理" summary:"获取文件块向量列表" dc:"分页查询文件块向量列表,支持多条件筛选"`
|
||||
|
||||
Page *beans.Page `json:"page"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
DatasetId int64 `json:"datasetId"`
|
||||
DocumentId int64 `json:"documentId"`
|
||||
Status document.Status `json:"status"`
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"rag/consts/keyword"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
@@ -10,10 +12,11 @@ import (
|
||||
type CreateKeywordReq struct {
|
||||
g.Meta `path:"/create" method:"post" tags:"关键词管理" summary:"创建关键词" dc:"创建关键词"`
|
||||
|
||||
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
||||
DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"`
|
||||
Word string `json:"word" v:"required#名称不能为空"`
|
||||
Weight int16 `json:"weight" v:"required#权重不能为空"`
|
||||
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
||||
DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"`
|
||||
Word string `json:"word" v:"required#名称不能为空"`
|
||||
Weight int16 `json:"weight" v:"required#权重不能为空"`
|
||||
KeywordType keyword.KeywordType `json:"keywordType" v:"required#类型不能为空"`
|
||||
}
|
||||
|
||||
// CreateKeywordRes 创建关键词响应
|
||||
@@ -48,12 +51,13 @@ type GetKeywordReq struct {
|
||||
type ListKeywordReq struct {
|
||||
g.Meta `path:"/list" method:"get" tags:"关键词管理" summary:"获取关键词列表" dc:"分页查询关键词列表,支持多条件筛选"`
|
||||
|
||||
Page *beans.Page `json:"page"`
|
||||
DatasetId int64 `json:"datasetId"`
|
||||
DocumentId int64 `json:"documentId"`
|
||||
Word string `json:"word"`
|
||||
Words []string `json:"words"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
Page *beans.Page `json:"page"`
|
||||
DatasetId int64 `json:"datasetId"`
|
||||
DocumentId int64 `json:"documentId"`
|
||||
Word string `json:"word"`
|
||||
Words []string `json:"words"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
KeywordType keyword.KeywordType `json:"keywordType"`
|
||||
}
|
||||
|
||||
// ListKeywordRes 关键词列表响应
|
||||
@@ -63,11 +67,12 @@ type ListKeywordRes struct {
|
||||
}
|
||||
|
||||
type KeywordVO struct {
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
|
||||
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
KeywordType keyword.KeywordType `json:"keywordType" dc:"类型"`
|
||||
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
|
||||
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
|
||||
@@ -1,27 +1,34 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
import (
|
||||
"rag/consts/keyword"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type keywordCol struct {
|
||||
beans.SQLBaseCol
|
||||
DatasetId string
|
||||
DocumentId string
|
||||
Word string
|
||||
Weight string
|
||||
DatasetId string
|
||||
DocumentId string
|
||||
Word string
|
||||
Weight string
|
||||
KeywordType string
|
||||
}
|
||||
|
||||
var KeywordCol = keywordCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
DatasetId: "dataset_id",
|
||||
DocumentId: "document_id",
|
||||
Word: "word",
|
||||
Weight: "weight",
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
DatasetId: "dataset_id",
|
||||
DocumentId: "document_id",
|
||||
Word: "word",
|
||||
Weight: "weight",
|
||||
KeywordType: "keyword_type",
|
||||
}
|
||||
|
||||
type Keyword struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"`
|
||||
DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"`
|
||||
Word string `orm:"word" json:"word" dc:"关键词"`
|
||||
Weight int16 `orm:"weight" json:"weight" dc:"权重"`
|
||||
DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"`
|
||||
DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"`
|
||||
Word string `orm:"word" json:"word" dc:"关键词"`
|
||||
Weight int16 `orm:"weight" json:"weight" dc:"权重"`
|
||||
KeywordType keyword.KeywordType `orm:"keyword_type" json:"keywordType" dc:"类型"`
|
||||
}
|
||||
|
||||
@@ -45,43 +45,3 @@ func (s *datasetService) List(ctx context.Context, req *dto.ListDatasetReq) (res
|
||||
err = gconv.Struct(list, &res.List)
|
||||
return
|
||||
}
|
||||
|
||||
//// Search 搜索(示例,实际需要调用向量库)
|
||||
//func (s *datasetService) Search(ctx context.Context, req *dto.SearchReq) (res *dto.SearchRes, err error) {
|
||||
// // 1. 获取数据集信息
|
||||
// kb, err := dao.Dataset.GetByID(ctx, req)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // 2. 获取文件块
|
||||
// chunks, err := dao.Chunk.FindChunksByKBIDWithLimit(ctx, req.KBID, 0, req.TopK)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
//
|
||||
// // 3. TODO: 使用向量检索(需要集成向量库)
|
||||
// // 暂时使用简单的关键词匹配
|
||||
// results := make([]dto.SearchResult, 0)
|
||||
// for _, chunk := range chunks {
|
||||
// results = append(results, dto.SearchResult{
|
||||
// Content: chunk.Content,
|
||||
// Score: 0.8, // TODO: 计算实际向量相似度
|
||||
// DocumentID: chunk.DocumentID,
|
||||
// ChunkIndex: chunk.Index,
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// g.Log().Infof(ctx, "数据集[%s]搜索完成,查询:%s,结果数:%d", kb.Name, req.Query, len(results))
|
||||
//
|
||||
// return &dto.SearchRes{Results: results}, nil
|
||||
//}
|
||||
//
|
||||
//// formatChunks 格式化文件块为上下文
|
||||
//func (s *datasetService) formatChunks(chunks []*entity.DocumentChunk) string {
|
||||
// var sb strings.Builder
|
||||
// for i, chunk := range chunks {
|
||||
// sb.WriteString(fmt.Sprintf("[%d] %s\n\n", i+1, chunk.Content))
|
||||
// }
|
||||
// return sb.String()
|
||||
//}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/consts/document"
|
||||
"rag/consts/keyword"
|
||||
"rag/consts/public"
|
||||
"rag/consts/task"
|
||||
"rag/dao"
|
||||
@@ -104,9 +105,17 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
}
|
||||
|
||||
// Get 获取文件详情
|
||||
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
|
||||
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
|
||||
r, err := dao.Document.GetByID(ctx, req)
|
||||
err = gconv.Struct(r, &res)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res = &dto.GetDocumentRes{}
|
||||
err = gconv.Struct(r, &res.DocumentVO)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res.ImgAddressPrefix, err = utils.GetFileAddressPrefix(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -280,10 +289,11 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||
for _, word := range words {
|
||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||
DatasetId: doc.DatasetId,
|
||||
DocumentId: doc.Id,
|
||||
Word: word.Word,
|
||||
Weight: gconv.Int16(word.Score),
|
||||
DatasetId: doc.DatasetId,
|
||||
DocumentId: doc.Id,
|
||||
Word: word.Word,
|
||||
Weight: gconv.Int16(word.Score),
|
||||
KeywordType: keyword.KeywordTypeInitial.Code(),
|
||||
})
|
||||
}
|
||||
if len(keywordReqs) > 0 {
|
||||
|
||||
Reference in New Issue
Block a user