package eino import ( "context" "database/sql" "errors" "fmt" "rag/consts/model" "rag/dao" "rag/model/dto" "rag/model/entity" "gitea.com/red-future/common/beans" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/util/gconv" "github.com/pgvector/pgvector-go" ) type PGVectorIndexerOptions struct { BatchSize int // 每批处理多少条 } type PGVectorIndexer struct { opts *PGVectorIndexerOptions } func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer { // 默认值 if opts.BatchSize <= 0 { opts.BatchSize = 5 } return &PGVectorIndexer{opts: opts} } func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) { embedderByType, err := GetTenantEmbedderByType(ctx, configType) if err != nil { return } indexer.WithEmbedding(embedderByType) commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...) if commonOpts.Embedding == nil { return 0, errors.New("embedding model not set") } // 回调 ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) rows, err = i.bulkStore(ctx, docs, commonOpts) if err != nil { callbacks.OnError(ctx, err) return } callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)}) return } func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) { var batchDocs []*schema.Document // 官方ES同款逻辑:满 BatchSize 就处理一批 for _, doc := range docs { batchDocs = append(batchDocs, doc) // 满了 → 处理 if len(batchDocs) >= i.opts.BatchSize { var r int64 r, err = i.doStore(ctx, batchDocs, opts) if err != nil { return } rows = rows + r batchDocs = nil } } // 最后一批 if len(batchDocs) > 0 { var r int64 r, err = i.doStore(ctx, batchDocs, opts) if err != nil { return } rows = rows + r } return } func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) { texts := make([]string, len(docs)) for i, d := range docs { texts[i] = d.Content } // 向量化(官方ES也没有重试!) vectors, err := opts.Embedding.EmbedStrings(ctx, texts) if err != nil { return } // 转成业务实体 var chunks []*dto.VectorDocumentVectorMsg for idx, doc := range docs { ck := new(dto.VectorDocumentVectorMsg) err = gconv.Struct(doc.MetaData, ck) if err != nil { glog.Errorf(ctx, "doStore err: %v", err) continue } ck.Content = doc.Content ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx])) ck.VectorStatus = gconv.PtrInt8(1) ck.Status = gconv.PtrInt8(1) chunks = append(chunks, ck) } if len(chunks) == 0 { return } ctx = context.WithValue(ctx, "user", &beans.User{ TenantId: chunks[0].TenantId, UserName: chunks[0].Creator, }) // 创建索引 if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil { return } // 入库 rows, err = dao.DocumentVector.BatchInsert(ctx, chunks) return } func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error { exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId) if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } if exist != nil { _ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount) return nil } indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) idx := &entity.DatasetIndex{ DatasetId: datasetId, Name: indexName, Dimension: dimension, FieldType: "float", MetricType: "COSINE", Status: gconv.PtrInt8(1), VectorCount: vectorCount, Description: fmt.Sprintf("数据集%d向量索引", datasetId), } _, err = dao.DatasetIndex.Insert(ctx, idx) if err != nil { return err } return i.createRealPGVectorIndex(ctx, indexName) } func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error { if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil { glog.Errorf(ctx, "create vector index failed: %v", err) return err } glog.Infof(ctx, "created pgvector index: %s", indexName) return nil } func (i *PGVectorIndexer) GetType() string { return "pgvector_indexer" } func (i *PGVectorIndexer) IsCallbacksEnabled() bool { return true }