// Source file: rag/common/eino/indexer.go (186 lines, 4.5 KiB, Go).
//
// Note: the hosting page flagged this file as containing ambiguous Unicode
// characters, i.e. characters that might be confused with other characters.

package eino
import (
"context"
"database/sql"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/beans"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// PGVectorIndexerOptions carries the tunable settings of a PGVectorIndexer.
type PGVectorIndexerOptions struct {
	// BatchSize is how many documents are processed per batch.
	BatchSize int
}
// PGVectorIndexer embeds documents and persists them through the DAO layer,
// batching the work according to its options.
type PGVectorIndexer struct {
	// opts holds the batching configuration supplied at construction time.
	opts *PGVectorIndexerOptions
}
// NewPGVectorIndexer builds a PGVectorIndexer from opts.
//
// A nil opts is treated as an empty options value instead of panicking, and a
// non-positive BatchSize is replaced by the default of 5. The indexer keeps
// the pointer it was given (the options are not copied), so the default is
// written into the caller's struct as well.
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
	if opts == nil {
		opts = &PGVectorIndexerOptions{}
	}
	// Apply the default batch size.
	if opts.BatchSize <= 0 {
		opts.BatchSize = 5
	}
	return &PGVectorIndexer{opts: opts}
}
// Store embeds docs and persists them in batches.
//
// The embedder resolved by GetTenantEmbedderByType for configType is used as
// the default embedding model; an embedding supplied through opts takes
// precedence over it. The start, error and end callbacks are fired around the
// bulk write.
//
// It returns the row count accumulated from the batch inserts. On failure the
// count of the batches stored before the error is returned together with the
// error. The end callback receives the row count converted with
// gconv.Strings as its IDs value, not per-document identifiers.
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) {
	embedderByType, err := GetTenantEmbedderByType(ctx, configType)
	if err != nil {
		return 0, err
	}

	// Seed the base options with the tenant embedder. Calling
	// indexer.WithEmbedding and dropping the Option it returns has no effect,
	// which left Embedding nil whenever the caller did not pass one.
	commonOpts := indexer.GetCommonOptions(&indexer.Options{Embedding: embedderByType}, opts...)
	if commonOpts.Embedding == nil {
		return 0, errors.New("embedding model not set")
	}

	// Callbacks.
	ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
	rows, err = i.bulkStore(ctx, docs, commonOpts)
	if err != nil {
		callbacks.OnError(ctx, err)
		return rows, err
	}
	callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
	return rows, nil
}
// bulkStore walks docs in consecutive windows of at most BatchSize documents
// and hands each window to doStore, summing the reported row counts.
//
// It stops at the first failing batch and returns the rows stored so far
// together with that error.
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
	for start := 0; start < len(docs); {
		// A non-positive BatchSize degenerates to one document per batch.
		size := i.opts.BatchSize
		if size < 1 {
			size = 1
		}
		// The final window may be shorter than a full batch.
		if left := len(docs) - start; size > left {
			size = left
		}
		end := start + size

		stored, storeErr := i.doStore(ctx, docs[start:end:end], opts)
		if storeErr != nil {
			return rows, storeErr
		}
		rows += stored
		start = end
	}
	return rows, nil
}
// doStore embeds one batch of documents and inserts the resulting rows.
//
// Each document's MetaData is converted into a dto.VectorDocumentVectorMsg;
// documents whose metadata cannot be converted are logged and skipped. If no
// document survives conversion, the last conversion error is returned with a
// zero row count. The tenant, creator and dataset of the first converted
// chunk are applied to the whole batch.
//
// It returns the row count reported by the batch insert.
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
	if len(docs) == 0 {
		return 0, nil
	}

	texts := make([]string, len(docs))
	for idx, d := range docs {
		texts[idx] = d.Content
	}

	// Embedding is attempted once, without retry.
	vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
	if err != nil {
		return 0, err
	}
	// vectors is indexed by document position below, so a short result would
	// otherwise panic with an index-out-of-range.
	if len(vectors) != len(docs) {
		return 0, fmt.Errorf("embedding returned %d vectors for %d documents", len(vectors), len(docs))
	}

	// Convert to business entities.
	chunks := make([]*dto.VectorDocumentVectorMsg, 0, len(docs))
	var lastConvErr error
	for idx, doc := range docs {
		ck := new(dto.VectorDocumentVectorMsg)
		if convErr := gconv.Struct(doc.MetaData, ck); convErr != nil {
			glog.Errorf(ctx, "doStore err: %v", convErr)
			lastConvErr = convErr
			continue
		}
		ck.Content = doc.Content
		ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
		ck.VectorStatus = gconv.PtrInt8(1)
		ck.Status = gconv.PtrInt8(1)
		chunks = append(chunks, ck)
	}
	if len(chunks) == 0 {
		// Nothing convertible: surface the last conversion failure.
		return 0, lastConvErr
	}

	// NOTE(review): the string key "user" is presumably what downstream code
	// reads from the context, so it is kept as-is — confirm before changing it
	// to a typed key.
	ctx = context.WithValue(ctx, "user", &beans.User{
		TenantId: chunks[0].TenantId,
		UserName: chunks[0].Creator,
	})

	// Create the dataset index record, or bump its vector count.
	// NOTE(review): this runs before the insert, so the count is increased
	// even when BatchInsert fails afterwards.
	if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
		return 0, err
	}

	// Persist.
	rows, err = dao.DocumentVector.BatchInsert(ctx, chunks)
	return rows, err
}
// createOrUpdateDatasetIndex makes sure a DatasetIndex record exists for
// datasetId.
//
// When a record already exists its vector count is increased by vectorCount;
// a failure of that update is logged but does not fail the call. Otherwise a
// new record named "idx_dataset_<id>_vector" is inserted with the given
// dimension and the index itself is created via createRealPGVectorIndex.
//
// A lookup error other than sql.ErrNoRows, an insert error, or an index
// creation error is returned to the caller.
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
	exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return err
	}
	if exist != nil {
		// The counter update stays best-effort, but the failure is no longer
		// silently discarded.
		if incErr := dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount); incErr != nil {
			glog.Errorf(ctx, "increase vector count failed, dataset %d: %v", datasetId, incErr)
		}
		return nil
	}

	indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
	idx := &entity.DatasetIndex{
		DatasetId:   datasetId,
		Name:        indexName,
		Dimension:   dimension,
		FieldType:   "float",
		MetricType:  "COSINE",
		Status:      gconv.PtrInt8(1),
		VectorCount: vectorCount,
		Description: fmt.Sprintf("数据集%d向量索引", datasetId),
	}
	if _, err = dao.DatasetIndex.Insert(ctx, idx); err != nil {
		return err
	}
	// NOTE(review): if index creation fails here the record inserted above
	// remains, so later calls take the "exists" branch and never retry.
	return i.createRealPGVectorIndex(ctx, indexName)
}
// createRealPGVectorIndex asks the DatasetIndex DAO to create the index named
// indexName, logging the outcome. The DAO error, if any, is returned as-is.
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
	err := dao.DatasetIndex.InsertIndex(ctx, indexName)
	if err == nil {
		glog.Infof(ctx, "created pgvector index: %s", indexName)
		return nil
	}
	glog.Errorf(ctx, "create vector index failed: %v", err)
	return err
}
// GetType returns the fixed type name of this indexer component.
func (i *PGVectorIndexer) GetType() string {
	const typeName = "pgvector_indexer"
	return typeName
}
// IsCallbacksEnabled always reports true for this indexer.
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
	const enabled = true
	return enabled
}