Files
rag/dao/document_vector.go

112 lines
3.5 KiB
Go

package dao
import (
"context"
"fmt"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// documentVectorDao groups the data-access methods for document vectors.
// It carries no state of its own; every method receives what it needs via
// its arguments.
type documentVectorDao struct{}

// DocumentVector is the shared instance through which callers reach the
// document vector data-access methods.
var DocumentVector = &documentVectorDao{}
// BatchInsert inserts a batch of document chunks into the vector table.
//
// Each element of req is converted to an entity.DocumentVector with
// gconv.Structs before insertion. It returns the number of rows the database
// reports as affected. An empty batch is a no-op that returns (0, nil)
// instead of issuing an insert with an empty data list, which GoFrame
// rejects with an error.
func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentVectorMsg) (rows int64, err error) {
	if len(req) == 0 {
		return 0, nil
	}
	var res []*entity.DocumentVector
	if err = gconv.Structs(req, &res); err != nil {
		return 0, err
	}
	r, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).Data(res).Insert()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Update writes the fields of req to the document vector row whose id equals
// req.Id and returns the number of affected rows.
//
// A nil req is rejected with an error rather than panicking on req.Id.
// NOTE(review): no OmitEmpty is applied here, so whether zero-valued fields
// of req overwrite stored values depends on how dto.UpdateDocumentVectorReq
// is declared — confirm this is intended.
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
	if req == nil {
		return 0, fmt.Errorf("update document vector: nil request")
	}
	model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector)
	// req is already a pointer; taking &req again would only add an
	// indirection for the ORM to unwrap.
	r, err := model.Data(req).Where(entity.DocumentVectorCol.Id, req.Id).Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// List returns document vector rows matching req together with the total
// match count reported by AllAndCount.
//
// The DatasetId, DocumentId, Status and VectorStatus filters are applied only
// when the corresponding req field is non-empty (OmitEmpty). Rows are ordered
// by CreatedAt, newest first. fields optionally restricts the selected
// columns. When req.Page is nil, no pagination is applied.
func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVectorReq, fields ...string) (res []*entity.DocumentVector, total int, err error) {
	model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).Fields(fields).OmitEmpty().
		Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
		Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
		Where(entity.DocumentVectorCol.Status, req.Status).
		Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
		OrderDesc(entity.DocumentVectorCol.CreatedAt)
	if req.Page != nil {
		// Keep the returned model: in GoFrame's safe mode chain methods return
		// a clone, so discarding the result would silently drop pagination.
		model = model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
	}
	r, total, err := model.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	if err = r.Structs(&res); err != nil {
		return nil, 0, err
	}
	return res, total, nil
}
// GetAllByVector returns up to topK document chunks from the given datasets,
// ordered by ascending pgvector cosine distance (`<=>`) between each chunk's
// stored vector and the query vector.
//
// Each row contains id, content, dataset_id, document_id and distance. Rows
// whose vector is NULL are skipped. A non-positive topK or an empty
// datasetIds yields an empty list without querying the database. Otherwise
// Limit(0) would mean "no limit" and return the whole table, and a negative
// limit would be rejected by PostgreSQL.
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
	if topK <= 0 || len(datasetIds) == 0 {
		return gdb.List{}, nil
	}
	// Build the distance expression as a single literal field. GoFrame's
	// Model.Fields treats every argument as a column name and does not bind
	// "?" placeholders, so passing the vector as a second argument would emit
	// it as an extra column. vector.String() renders only numeric components
	// ("[0.1,0.2,...]"), so embedding it in the SQL text is injection-safe.
	selectFields := fmt.Sprintf(
		"id, content, dataset_id, document_id, vector <=> '%s'::vector AS distance",
		vector.String(),
	)
	result, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).
		Fields(selectFields).
		WhereIn(entity.DocumentVectorCol.DatasetId, datasetIds).
		WhereNotNull(entity.DocumentVectorCol.Vector).
		OrderAsc("distance").
		Limit(topK).
		All()
	if err != nil {
		return nil, err
	}
	return result.List(), nil
}
// SearchByKeywords runs a full-text query against the document-chunk
// Meilisearch index and returns the raw hits.
//
// query is passed to Meilisearch unchanged, and at most topK hits are
// requested with ranking scores included (ShowRankingScore). When datasetIds
// is non-empty, hits are restricted with a `dataset_id IN [...]` filter.
// The returned list is never nil on success.
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
	searchParams := &meilisearch.SearchParams{
		Query:            query,
		Limit:            int64(topK),
		ShowRankingScore: true,
	}
	// Dataset IDs are numeric, so they go into the filter expression unquoted.
	// NOTE(review): an empty datasetIds applies no filter and searches every
	// dataset, whereas GetAllByVector matches nothing in that case — confirm
	// this is intended.
	if len(datasetIds) > 0 {
		searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", gconv.Strings(datasetIds)))
	}
	// gdb.List is []map[string]any, so hits can be decoded into it directly.
	var hits gdb.List
	if _, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}
	if hits == nil {
		// Callers receive an empty (non-nil) list when there are no hits.
		hits = gdb.List{}
	}
	return hits, nil
}