186 lines
4.5 KiB
Go
186 lines
4.5 KiB
Go
package eino
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"rag/consts/model"
|
||
"rag/dao"
|
||
"rag/model/dto"
|
||
"rag/model/entity"
|
||
|
||
"gitea.com/red-future/common/beans"
|
||
"github.com/cloudwego/eino/callbacks"
|
||
"github.com/cloudwego/eino/components/indexer"
|
||
"github.com/cloudwego/eino/schema"
|
||
"github.com/gogf/gf/v2/os/glog"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
"github.com/pgvector/pgvector-go"
|
||
)
|
||
|
||
// PGVectorIndexerOptions holds the tunable settings of a PGVectorIndexer.
type PGVectorIndexerOptions struct {
	// BatchSize is the number of documents processed per batch.
	BatchSize int
}
|
||
|
||
type PGVectorIndexer struct {
|
||
opts *PGVectorIndexerOptions
|
||
}
|
||
|
||
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
|
||
// 默认值
|
||
if opts.BatchSize <= 0 {
|
||
opts.BatchSize = 5
|
||
}
|
||
return &PGVectorIndexer{opts: opts}
|
||
}
|
||
|
||
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) {
|
||
|
||
embedderByType, err := GetTenantEmbedderByType(ctx, configType)
|
||
if err != nil {
|
||
return
|
||
}
|
||
indexer.WithEmbedding(embedderByType)
|
||
|
||
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
|
||
|
||
if commonOpts.Embedding == nil {
|
||
return 0, errors.New("embedding model not set")
|
||
}
|
||
|
||
// 回调
|
||
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
||
|
||
rows, err = i.bulkStore(ctx, docs, commonOpts)
|
||
if err != nil {
|
||
callbacks.OnError(ctx, err)
|
||
return
|
||
}
|
||
|
||
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
|
||
return
|
||
}
|
||
|
||
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||
var batchDocs []*schema.Document
|
||
|
||
// 官方ES同款逻辑:满 BatchSize 就处理一批
|
||
for _, doc := range docs {
|
||
batchDocs = append(batchDocs, doc)
|
||
|
||
// 满了 → 处理
|
||
if len(batchDocs) >= i.opts.BatchSize {
|
||
var r int64
|
||
r, err = i.doStore(ctx, batchDocs, opts)
|
||
if err != nil {
|
||
return
|
||
}
|
||
rows = rows + r
|
||
batchDocs = nil
|
||
}
|
||
}
|
||
|
||
// 最后一批
|
||
if len(batchDocs) > 0 {
|
||
var r int64
|
||
r, err = i.doStore(ctx, batchDocs, opts)
|
||
if err != nil {
|
||
return
|
||
}
|
||
rows = rows + r
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||
|
||
texts := make([]string, len(docs))
|
||
for i, d := range docs {
|
||
texts[i] = d.Content
|
||
}
|
||
|
||
// 向量化(官方ES也没有重试!)
|
||
vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 转成业务实体
|
||
var chunks []*dto.VectorDocumentVectorMsg
|
||
for idx, doc := range docs {
|
||
ck := new(dto.VectorDocumentVectorMsg)
|
||
err = gconv.Struct(doc.MetaData, ck)
|
||
if err != nil {
|
||
glog.Errorf(ctx, "doStore err: %v", err)
|
||
continue
|
||
}
|
||
ck.Content = doc.Content
|
||
ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
|
||
ck.VectorStatus = gconv.PtrInt8(1)
|
||
ck.Status = gconv.PtrInt8(1)
|
||
chunks = append(chunks, ck)
|
||
}
|
||
if len(chunks) == 0 {
|
||
return
|
||
}
|
||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||
TenantId: chunks[0].TenantId,
|
||
UserName: chunks[0].Creator,
|
||
})
|
||
// 创建索引
|
||
if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
|
||
return
|
||
}
|
||
// 入库
|
||
rows, err = dao.DocumentVector.BatchInsert(ctx, chunks)
|
||
|
||
return
|
||
}
|
||
|
||
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
|
||
exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||
return err
|
||
}
|
||
if exist != nil {
|
||
_ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount)
|
||
return nil
|
||
}
|
||
|
||
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
|
||
idx := &entity.DatasetIndex{
|
||
DatasetId: datasetId,
|
||
Name: indexName,
|
||
Dimension: dimension,
|
||
FieldType: "float",
|
||
MetricType: "COSINE",
|
||
Status: gconv.PtrInt8(1),
|
||
VectorCount: vectorCount,
|
||
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||
}
|
||
_, err = dao.DatasetIndex.Insert(ctx, idx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return i.createRealPGVectorIndex(ctx, indexName)
|
||
}
|
||
|
||
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||
if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil {
|
||
glog.Errorf(ctx, "create vector index failed: %v", err)
|
||
return err
|
||
}
|
||
glog.Infof(ctx, "created pgvector index: %s", indexName)
|
||
return nil
|
||
}
|
||
|
||
func (i *PGVectorIndexer) GetType() string {
|
||
return "pgvector_indexer"
|
||
}
|
||
|
||
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
|
||
return true
|
||
}
|