117 lines
3.5 KiB
Go
117 lines
3.5 KiB
Go
package dao
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"rag/consts/public"
|
||
"rag/model/dto"
|
||
"rag/model/entity"
|
||
|
||
"gitea.com/red-future/common/db/gfdb"
|
||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/text/gstr"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
"github.com/pgvector/pgvector-go"
|
||
)
|
||
|
||
// DocumentChunk is the package-level documentChunkDao instance through which
// the DAO methods are called.
var DocumentChunk = &documentChunkDao{}

// documentChunkDao groups the data-access methods for document chunks.
// It carries no state.
type documentChunkDao struct{}
|
||
|
||
// BatchInsert 批量插入文件块
|
||
func (d *documentChunkDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentChunkMsg) (rows int64, err error) {
|
||
var res []*entity.DocumentChunk
|
||
if err = gconv.Structs(req, &res); err != nil {
|
||
return
|
||
}
|
||
r, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Data(&res).Insert()
|
||
if err != nil {
|
||
return
|
||
}
|
||
return r.RowsAffected()
|
||
}
|
||
|
||
// Update 更新文件块
|
||
func (d *documentChunkDao) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (rows int64, err error) {
|
||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk)
|
||
r, err := model.Data(&req).Where(entity.DocumentChunkCol.Id, req.Id).Update()
|
||
if err != nil {
|
||
return
|
||
}
|
||
return r.RowsAffected()
|
||
}
|
||
|
||
// List 文件块列表
|
||
func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkReq, fields ...string) (res []*entity.DocumentChunk, total int, err error) {
|
||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Fields(fields).OmitEmpty().
|
||
Where(entity.DocumentChunkCol.DatasetId, req.DatasetId).
|
||
Where(entity.DocumentChunkCol.DocumentId, req.DocumentId).
|
||
Where(entity.DocumentChunkCol.Status, req.Status).
|
||
Where(entity.DocumentChunkCol.VectorStatus, req.VectorStatus).
|
||
OrderDesc(entity.DocumentChunkCol.CreatedAt)
|
||
if req.Page != nil {
|
||
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
|
||
}
|
||
r, total, err := model.AllAndCount(false)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = r.Structs(&res)
|
||
return
|
||
}
|
||
|
||
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
||
sql := `
|
||
SELECT id, content, dataset_id, document_id,
|
||
vector <=> ? AS distance
|
||
FROM rag_vector_document_chunk
|
||
WHERE dataset_id IN (?)
|
||
AND vector IS NOT NULL
|
||
ORDER BY distance ASC
|
||
LIMIT ?
|
||
`
|
||
// 顺序:vector, dataset_id, topK
|
||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryVec, datasetId, topK)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return result.List(), nil
|
||
}
|
||
|
||
// SearchByKeywords 通过关键词全文检索文档块
|
||
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||
// 构建 meilisearch 查询参数
|
||
searchParams := &meilisearch.SearchParams{
|
||
Query: query,
|
||
Limit: int64(topK),
|
||
ShowRankingScore: true,
|
||
}
|
||
|
||
// 构建 datasetIds 过滤条件
|
||
if len(datasetIds) > 0 {
|
||
datasetIdStrs := gconv.Strings(datasetIds)
|
||
quotedIds := make([]string, len(datasetIdStrs))
|
||
for i, id := range datasetIdStrs {
|
||
quotedIds[i] = fmt.Sprintf("%s", id)
|
||
}
|
||
searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds))
|
||
}
|
||
|
||
// 执行搜索
|
||
var hits []map[string]interface{}
|
||
_, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 转换查询结果为 gdb.List
|
||
resultList := make(gdb.List, 0, len(hits))
|
||
for _, hit := range hits {
|
||
resultList = append(resultList, hit)
|
||
}
|
||
|
||
return resultList, nil
|
||
}
|