From b00d544fb7a81fdf8da380e840320d61dc571994 Mon Sep 17 00:00:00 2001 From: qhd <1766646056@qq.com> Date: Fri, 3 Apr 2026 11:14:44 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20rag=E5=88=9D=E5=A7=8B=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/eino/a.go | 177 ++++++++++++++++++++ common/eino/b.go | 107 ++++++++++++ common/eino/base_task.go | 49 ++++++ common/eino/c.go | 94 +++++++++++ common/eino/consts.go | 8 + common/eino/document_loader.go | 51 ++++++ common/eino/document_semantic.go | 64 ++++++++ common/eino/embedding.go | 69 ++++++++ common/eino/embedding_batch.go | 47 ++++++ common/eino/embedding_qwen.go | 273 +++++++++++++++++++++++++++++++ common/eino/priority_enum.go | 11 ++ common/eino/status_enum.go | 12 ++ common/eino/task_type.go | 14 ++ common/gse/utils.go | 114 +++++++++++++ controller/dataset_index.go | 5 - go.mod | 32 ++-- go.sum | 2 + model/dto/dataset_index.go | 1 - service/dataset_index.go | 5 - service/document.go | 43 +++-- service/document_chunk.go | 191 ++++++++++----------- 21 files changed, 1228 insertions(+), 141 deletions(-) create mode 100644 common/eino/a.go create mode 100644 common/eino/b.go create mode 100644 common/eino/base_task.go create mode 100644 common/eino/c.go create mode 100644 common/eino/consts.go create mode 100644 common/eino/document_loader.go create mode 100644 common/eino/document_semantic.go create mode 100644 common/eino/embedding.go create mode 100644 common/eino/embedding_batch.go create mode 100644 common/eino/embedding_qwen.go create mode 100644 common/eino/priority_enum.go create mode 100644 common/eino/status_enum.go create mode 100644 common/eino/task_type.go create mode 100644 common/gse/utils.go delete mode 100644 controller/dataset_index.go delete mode 100644 model/dto/dataset_index.go delete mode 100644 service/dataset_index.go diff --git a/common/eino/a.go b/common/eino/a.go new file mode 100644 index 0000000..e9bd1ea --- /dev/null +++ b/common/eino/a.go @@ -0,0 +1,177 @@ +package eino + +import ( + "context" + "database/sql" + "errors" + "fmt" + "rag/dao" + "rag/model/dto" + "rag/model/entity" + + "gitea.com/red-future/common/beans" + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/util/gconv" + "github.com/pgvector/pgvector-go" +) + +type PGVectorIndexerOptions struct { + BatchSize int // 每批处理多少条 +} + +type PGVectorIndexer struct { + opts *PGVectorIndexerOptions +} + +func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer { + // 默认值 + if opts.BatchSize <= 0 { + opts.BatchSize = 5 + } + return &PGVectorIndexer{opts: opts} +} + +func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) { + commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...) + + if commonOpts.Embedding == nil { + return 0, errors.New("embedding model not set") + } + + // 回调 + ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs}) + + ids, err := i.bulkStore(ctx, docs, commonOpts) + if err != nil { + callbacks.OnError(ctx, err) + return + } + + callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(ids)}) + return +} + +func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) { + var batchDocs []*schema.Document + + // 官方ES同款逻辑:满 BatchSize 就处理一批 + for _, doc := range docs { + batchDocs = append(batchDocs, doc) + + // 满了 → 处理 + if len(batchDocs) >= i.opts.BatchSize { + var r int64 + r, err = i.doStore(ctx, batchDocs, opts) + if err != nil { + return + } + rows = rows + r + batchDocs = nil + } + } + + // 最后一批 + if len(batchDocs) > 0 { + var r int64 + r, err = i.doStore(ctx, batchDocs, opts) + if err != nil { + return + } + rows = rows + r + } + + return +} + +func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) { + + texts := make([]string, len(docs)) + for i, d := range docs { + texts[i] = d.Content + } + + // 向量化(官方ES也没有重试!) + vectors, err := opts.Embedding.EmbedStrings(ctx, texts) + if err != nil { + return + } + + // 转成业务实体 + var chunks []*dto.VectorDocumentChunkMsg + for idx, doc := range docs { + ck := new(dto.VectorDocumentChunkMsg) + err = gconv.Struct(doc.MetaData, ck) + if err != nil { + glog.Errorf(ctx, "doStore err: %v", err) + continue + } + ck.Content = doc.Content + ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx])) + ck.VectorStatus = gconv.PtrInt8(1) + ck.Status = gconv.PtrInt8(1) + chunks = append(chunks, ck) + } + if len(chunks) == 0 { + return + } + ctx = context.WithValue(ctx, "user", &beans.User{ + TenantId: chunks[0].TenantId, + UserName: chunks[0].Creator, + }) + // 创建索引 + if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil { + return + } + // 入库 + rows, err = dao.DocumentChunk.BatchInsert(ctx, chunks) + + return +} + +func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error { + exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + if exist != nil { + _ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount) + return nil + } + + indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) + idx := &entity.DatasetIndex{ + DatasetId: datasetId, + Name: indexName, + Dimension: dimension, + FieldType: "float", + MetricType: "COSINE", + Status: gconv.PtrInt8(1), + VectorCount: vectorCount, + Description: fmt.Sprintf("数据集%d向量索引", datasetId), + } + _, err = dao.DatasetIndex.Insert(ctx, idx) + if err != nil { + return err + } + return i.createRealPGVectorIndex(ctx, indexName) +} + +func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error { + if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil { + glog.Errorf(ctx, "create vector index failed: %v", err) + return err + } + glog.Infof(ctx, "created pgvector index: %s", indexName) + return nil +} + +func (i *PGVectorIndexer) GetType() string { + return "pgvector_indexer" +} + +func (i *PGVectorIndexer) IsCallbacksEnabled() bool { + return true +} diff --git a/common/eino/b.go b/common/eino/b.go new file mode 100644 index 0000000..a1f17dc --- /dev/null +++ b/common/eino/b.go @@ -0,0 +1,107 @@ +package eino + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino/schema" + "github.com/elastic/go-elasticsearch/v8" + + "github.com/cloudwego/eino-ext/components/indexer/es8" +) + +const ( + indexName = "eino_example" + fieldContent = "content" + fieldContentVector = "content_vector" + fieldExtraLocation = "location" + docExtraLocation = "location" +) + +func TestIndexer() { + ctx := context.Background() + + // 1. 创建 ES 客户端 + client, err := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{"http://localhost:9200"}, + }) + + if err != nil { + fmt.Printf("create client error: %v\n", err) + return + } + + // 2. 定义 Index Spec(选填:如果索引不存在,将自动创建) + indexSpec := &es8.IndexSpec{ + Settings: map[string]any{ + "number_of_shards": 1, + "number_of_replicas": 0, + }, + Mappings: map[string]any{ + "properties": map[string]any{ + fieldContentVector: map[string]any{ + "type": "dense_vector", + "dims": 1024, + "index": true, + "similarity": "l2_norm", + }, + }, + }, + } + + // 4. 准备文档 + // 文档通常包含 ID 和 Content + // 也可以包含额外的 Metadata 用于过滤或其他用途 + docs := []*schema.Document{ + { + ID: "1", + Content: "Eiffel Tower: Located in Paris, France.", + MetaData: map[string]any{ + docExtraLocation: "France", + }, + }, + { + ID: "2", + Content: "The Great Wall: Located in China.", + MetaData: map[string]any{ + docExtraLocation: "China", + }, + }, + } + + // 5. 创建 ES 索引器组件 + indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{ + Client: client, + Index: indexName, + IndexSpec: indexSpec, // 添加此项以启用自动索引创建 + BatchSize: 10, + // DocumentToFields 指定如何将文档字段映射到 ES 字段 + DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) { + return map[string]es8.FieldValue{ + fieldContent: { + Value: doc.Content, + EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段 + }, + fieldExtraLocation: { + // 额外的 metadata 字段 + Value: doc.MetaData[docExtraLocation], + }, + }, nil + }, + // 提供 embedding 组件用于向量化 + Embedding: EmbedderDashscope, + }) + + if err != nil { + fmt.Printf("create indexer error: %v\n", err) + return + } + + // 6. 索引文档 + ids, err := indexer.Store(ctx, docs) + if err != nil { + fmt.Printf("index error: %v\n", err) + return + } + fmt.Println("indexed ids:", ids) +} diff --git a/common/eino/base_task.go b/common/eino/base_task.go new file mode 100644 index 0000000..b883341 --- /dev/null +++ b/common/eino/base_task.go @@ -0,0 +1,49 @@ +package eino + +import ( + "time" + + "gitea.com/red-future/common/beans" +) + +// BaseTask 任务基类 - MongoDB版本 +type BaseTask struct { + beans.MongoBaseDO `bson:",inline"` + // 任务信息 + TaskType TaskType `bson:"taskType" json:"taskType"` + Status TaskStatus `bson:"status" json:"status"` + Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"` + // 进度 + TotalItems int64 `bson:"totalItems" json:"totalItems"` + ProcessedItems int64 `bson:"processedItems" json:"processedItems"` + Progress float64 `bson:"progress" json:"progress"` + // 结果 + StartTime *time.Time `bson:"startTime" json:"startTime"` + EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"` + Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"` + SuccessCount int64 `bson:"successCount" json:"successCount"` + FailCount int64 `bson:"failCount" json:"failCount"` + // 其他 + Executor string `bson:"executor,omitempty" json:"executor,omitempty"` +} + +// SQLBaseTask 任务基类 - SQL版本 +type SQLBaseTask struct { + beans.SQLBaseDO + // 任务信息 + TaskType TaskType `json:"taskType"` + Status TaskStatus `json:"status"` + Priority TaskPriority `json:"priority,omitempty"` + // 进度 + TotalItems int64 `json:"totalItems"` + ProcessedItems int64 `json:"processedItems"` + Progress float64 `json:"progress"` + // 结果 + StartTime *time.Time `json:"startTime"` + EndTime *time.Time `json:"endTime,omitempty"` + Duration int64 `json:"duration,omitempty"` + SuccessCount int64 `json:"successCount"` + FailCount int64 `json:"failCount"` + // 其他 + Executor string `json:"executor,omitempty"` +} diff --git a/common/eino/c.go b/common/eino/c.go new file mode 100644 index 0000000..ffc634a --- /dev/null +++ b/common/eino/c.go @@ -0,0 +1,94 @@ +package eino + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/cloudwego/eino/schema" + "github.com/elastic/go-elasticsearch/v8" + "github.com/elastic/go-elasticsearch/v8/typedapi/types" + + "github.com/cloudwego/eino-ext/components/retriever/es8" + "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" +) + +func TestRetriever() { + ctx := context.Background() + + client, _ := elasticsearch.NewClient(elasticsearch.Config{ + Addresses: []string{"http://localhost:9200"}, + }) + + // 创建 retriever 组件 + retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{ + Client: client, + Index: indexName, + TopK: 5, + SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{ + QueryFieldName: fieldContent, + VectorFieldName: fieldContentVector, + Hybrid: false, + // RRF 仅在特定许可证下可用 + // 参见: https://www.elastic.co/subscriptions + RRF: false, + RRFRankConstant: nil, + RRFWindowSize: nil, + }), + ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) { + doc = &schema.Document{ + ID: *hit.Id_, + Content: "", + MetaData: map[string]any{}, + } + + var src map[string]any + if err = json.Unmarshal(hit.Source_, &src); err != nil { + return nil, err + } + + for field, val := range src { + switch field { + case fieldContent: + doc.Content = val.(string) + case fieldContentVector: + var v []float64 + for _, item := range val.([]interface{}) { + v = append(v, item.(float64)) + } + doc.WithDenseVector(v) + case fieldExtraLocation: + doc.MetaData[docExtraLocation] = val.(string) + } + } + + if hit.Score_ != nil { + doc.WithScore(float64(*hit.Score_)) + } + + return doc, nil + }, + Embedding: EmbedderDashscope, + }) + + // 不带过滤器的搜索 + docs, _ := retriever.Retrieve(ctx, "tourist attraction") + + // 带过滤器的搜索 + docs, _ = retriever.Retrieve(ctx, "tourist attraction", + es8.WithFilters([]types.Query{{ + Term: map[string]types.TermQuery{ + fieldExtraLocation: { + CaseInsensitive: of(true), + Value: "China", + }, + }, + }}), + ) + + fmt.Printf("retrieved docs: %+v\n", docs) +} + +func of[T any](v T) *T { + return &v +} diff --git a/common/eino/consts.go b/common/eino/consts.go new file mode 100644 index 0000000..e766233 --- /dev/null +++ b/common/eino/consts.go @@ -0,0 +1,8 @@ +package eino + +const ( + providerArk = "ark" + providerOpenai = "openai" + providerQianfan = "qianfan" + providerDashscope = "dashscope" +) diff --git a/common/eino/document_loader.go b/common/eino/document_loader.go new file mode 100644 index 0000000..c5b7f65 --- /dev/null +++ b/common/eino/document_loader.go @@ -0,0 +1,51 @@ +package eino + +import ( + "context" + "fmt" + + "gitea.com/red-future/common/utils" + "github.com/cloudwego/eino-ext/components/document/loader/url" + "github.com/cloudwego/eino-ext/components/document/parser/docx" + "github.com/cloudwego/eino-ext/components/document/parser/pdf" + "github.com/cloudwego/eino-ext/components/document/parser/xlsx" + "github.com/cloudwego/eino/components/document" + "github.com/cloudwego/eino/components/document/parser" + "github.com/cloudwego/eino/schema" +) + +// LoadDocument 业务函数:加载文件 +func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) { + p, err := docsParser(ctx, fileFormat) + if err != nil { + return + } + loader, err := url.NewLoader(ctx, &url.LoaderConfig{ + Parser: p, + }) + imageUrl, err := utils.GetFileAddressPrefix(ctx) + if err != nil { + return + } + docs, err = loader.Load(context.Background(), document.Source{ + URI: fmt.Sprintf("%s%s", imageUrl, filePath), + }) + return +} + +func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) { + switch fileFormat { + case "docx": + p, err = docx.NewDocxParser(ctx, &docx.Config{ + ToSections: true, + IncludeHeaders: true, + IncludeFooters: true, + IncludeTables: true, + }) + case "pdf": + p, err = pdf.NewPDFParser(ctx, &pdf.Config{}) + case "xlsx": + p, err = xlsx.NewXlsxParser(ctx, &xlsx.Config{}) + } + return +} diff --git a/common/eino/document_semantic.go b/common/eino/document_semantic.go new file mode 100644 index 0000000..2f39a1a --- /dev/null +++ b/common/eino/document_semantic.go @@ -0,0 +1,64 @@ +package eino + +import ( + "context" + + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive" + "github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic" + "github.com/cloudwego/eino/schema" + "github.com/gogf/gf/v2/frame/g" +) + +// SemanticSplitDocument 语义分割文档 +func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) { + // 默认分隔符(支持中英文) + separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"} + // 读取配置,使用合理的默认值 + bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int() + minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int() + percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64() + batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int() + if batchSize <= 0 { + batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个 + } + + // 使用批量包装器 + var batchEmbedder *BatchEmbedder + provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String() + switch provider { + case providerArk: + batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize) + case providerOpenai: + batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize) + case providerDashscope: + batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize) + } + + splitter, err := semantic.NewSplitter(ctx, &semantic.Config{ + Embedding: batchEmbedder, + BufferSize: bufferSize, + MinChunkSize: minChunkSize, + Percentile: percentile, + Separators: separators, + }) + if err != nil { + return + } + return splitter.Transform(ctx, docs) +} + +// RecursiveSplitDocument 递归分割文档 +func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) { + // 默认分隔符(支持中英文) + separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"} + splitter, err := recursive.NewSplitter(ctx, &recursive.Config{ + ChunkSize: 512, + OverlapSize: 100, + KeepType: recursive.KeepTypeNone, + Separators: separators, + }) + if err != nil { + return + } + return splitter.Transform(ctx, docs) +} diff --git a/common/eino/embedding.go b/common/eino/embedding.go new file mode 100644 index 0000000..7af67e3 --- /dev/null +++ b/common/eino/embedding.go @@ -0,0 +1,69 @@ +package eino + +import ( + "context" + "fmt" + + "github.com/cloudwego/eino-ext/components/embedding/ark" + "github.com/cloudwego/eino-ext/components/embedding/dashscope" + "github.com/cloudwego/eino-ext/components/embedding/openai" + "github.com/gogf/gf/v2/frame/g" + "github.com/golang/glog" +) + +// 全局只初始化一次 +var ( + EmbedderArk *ark.Embedder + EmbedderDashscope *dashscope.Embedder + EmbedderOpenAI *openai.Embedder +) + +func init() { + ctx := context.Background() + if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() { + var err error + provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String() + switch provider { + case providerArk: + cfg := &ark.EmbeddingConfig{ + APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(), + Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(), + } + if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" { + apiTypeVal := ark.APIType(apiType) + cfg.APIType = &apiTypeVal + } + EmbedderArk, err = ark.NewEmbedder(ctx, cfg) + case providerOpenai: + chatModelConfig := &openai.EmbeddingConfig{ + APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(), + Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(), + } + EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig) + case providerDashscope: + cfg := &dashscope.EmbeddingConfig{ + APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(), + Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(), + } + EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg) + } + if err != nil { + glog.Fatalf("NewEmbedder of %v error: %v", provider, err) + } + } + + return +} + +func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) { + provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String() + switch provider { + case providerArk: + return EmbedderArk.EmbedStrings(ctx, texts) + case providerOpenai: + return EmbedderOpenAI.EmbedStrings(ctx, texts) + case providerDashscope: + return EmbedderDashscope.EmbedStrings(ctx, texts) + } + return nil, fmt.Errorf("unsupported provider: %v", provider) +} diff --git a/common/eino/embedding_batch.go b/common/eino/embedding_batch.go new file mode 100644 index 0000000..83e0606 --- /dev/null +++ b/common/eino/embedding_batch.go @@ -0,0 +1,47 @@ +package eino + +import ( + "context" + + "github.com/cloudwego/eino/components/embedding" +) + +// BatchEmbedder 包装器,支持批量限制 +type BatchEmbedder struct { + embedder embedding.Embedder + batchSize int +} + +// NewBatchEmbedder 创建支持批量限制的 embedding 包装器 +func NewBatchEmbedder(embedder embedding.Embedder, batchSize int) *BatchEmbedder { + if batchSize <= 0 { + batchSize = 10 // 默认每批 10 个 + } + return &BatchEmbedder{ + embedder: embedder, + batchSize: batchSize, + } +} + +// EmbedStrings 分批调用 embedding +func (b *BatchEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) { + if len(texts) <= b.batchSize { + return b.embedder.EmbedStrings(ctx, texts, opts...) + } + + var allEmbeddings [][]float64 + for i := 0; i < len(texts); i += b.batchSize { + end := i + b.batchSize + if end > len(texts) { + end = len(texts) + } + + batch := texts[i:end] + embeddings, err := b.embedder.EmbedStrings(ctx, batch, opts...) + if err != nil { + return nil, err + } + allEmbeddings = append(allEmbeddings, embeddings...) + } + return allEmbeddings, nil +} diff --git a/common/eino/embedding_qwen.go b/common/eino/embedding_qwen.go new file mode 100644 index 0000000..9496874 --- /dev/null +++ b/common/eino/embedding_qwen.go @@ -0,0 +1,273 @@ +/* + * Copyright 2024 Red Future Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package eino + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/cloudwego/eino/callbacks" + "github.com/cloudwego/eino/components" + "github.com/cloudwego/eino/components/embedding" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/gclient" + "github.com/gogf/gf/v2/util/gconv" +) + +var ( + // 千问API默认配置 + defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" + defaultTimeout = 10 * time.Minute + defaultRetryTimes = 2 +) + +type QwenEmbeddingConfig struct { + // Timeout specifies the maximum duration to wait for API responses + // Optional. Default: 10 minutes + Timeout *time.Duration `json:"timeout"` + + // HTTPClient specifies the client to send HTTP requests. + // Optional. Default &http.Client{Timeout: Timeout} + HTTPClient *http.Client `json:"http_client"` + + // RetryTimes specifies the number of retry attempts for failed API calls + // Optional. Default: 2 + RetryTimes *int `json:"retry_times"` + + // BaseURL specifies the base URL for Qwen DashScope service + // Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" + BaseURL string `json:"base_url"` + + // APIKey specifies the API Key for authentication + // Required + APIKey string `json:"api_key"` + + // Model specifies the model name for Qwen embedding + // Required. Examples: "text-embedding-v2", "text-embedding-v3" + Model string `json:"model"` + + // TextType specifies the type of text: "document" or "query" + // Optional. Default: "document" + TextType string `json:"text_type"` + + // MaxConcurrentRequests specifies the maximum number of concurrent requests allowed + // Optional. Default: 5 + MaxConcurrentRequests *int `json:"max_concurrent_requests"` +} + +type QwenEmbedder struct { + client *gclient.Client + conf *QwenEmbeddingConfig +} + +// EmbeddingRequest 千问embedding请求结构 +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +// EmbeddingResponse 千问embedding响应结构 +type EmbeddingResponse struct { + Output struct { + Embeddings []struct { + TextIndex int `json:"text_index"` + Embedding []float64 `json:"embedding"` + } `json:"embeddings"` + } `json:"output"` + Usage struct { + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + RequestID string `json:"request_id"` +} + +type APIError struct { + Code string `json:"code"` + Message string `json:"message"` + RequestID string `json:"request_id"` +} + +func (e *APIError) Error() string { + return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID) +} + +func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client { + if len(config.BaseURL) == 0 { + config.BaseURL = defaultBaseURL + } + if config.Timeout == nil { + config.Timeout = &defaultTimeout + } + if config.RetryTimes == nil { + defaultRetryTimes := 2 + config.RetryTimes = &defaultRetryTimes + } + if len(config.TextType) == 0 { + config.TextType = "document" + } + if config.MaxConcurrentRequests == nil { + defaultMaxConcurrentRequests := 5 + config.MaxConcurrentRequests = &defaultMaxConcurrentRequests + } + + client := g.Client() + client.SetTimeout(*config.Timeout) + + return client +} + +func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) { + if len(config.APIKey) == 0 { + return nil, fmt.Errorf("[Qwen] APIKey is required") + } + if len(config.Model) == 0 { + return nil, fmt.Errorf("[Qwen] Model is required") + } + + client := buildQwenClient(config) + + return &QwenEmbedder{ + client: client, + conf: config, + }, nil +} + +func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ( + [][]float64, error) { + + if len(texts) == 0 { + return nil, fmt.Errorf("[Qwen] texts cannot be empty") + } + + options := embedding.GetCommonOptions(&embedding.Options{ + Model: &e.conf.Model, + }, opts...) + + conf := &embedding.Config{ + Model: dereferenceOrZero(options.Model), + } + + ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding) + ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{ + Texts: texts, + Config: conf, + }) + defer func() { + if err := recover(); err != nil { + callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err)) + } + }() + + var usage *embedding.TokenUsage + var embeddings [][]float64 + var err error + + // 调用千问API获取embedding + embeddings, usage, err = e.callEmbeddingAPI(ctx, texts) + if err != nil { + callbacks.OnError(ctx, err) + return nil, err + } + + callbacks.OnEnd(ctx, &embedding.CallbackOutput{ + Embeddings: embeddings, + Config: conf, + TokenUsage: usage, + }) + + return embeddings, nil +} + +func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) { + // 构建请求 + var req EmbeddingRequest + req.Model = e.conf.Model + req.Input.Texts = texts + req.Parameters.TextType = e.conf.TextType + + // 调用API + client := e.client.Clone() + client.SetHeader("Authorization", "Bearer "+e.conf.APIKey) + client.SetHeader("Content-Type", "application/json") + client.SetTimeout(*e.conf.Timeout) + + resp, err := client.Post(ctx, e.conf.BaseURL, req) + if err != nil { + return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err) + } + + defer resp.Close() + + // 检查状态码 + if resp.StatusCode != http.StatusOK { + var errResp APIError + result := resp.ReadAll() + if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" { + return nil, nil, &errResp + } + return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode) + } + + // 解析响应 + var apiResp EmbeddingResponse + result := resp.ReadAll() + if err = gconv.Struct(result, &apiResp); err != nil { + return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err) + } + + // 解析响应结果 + embeddings := make([][]float64, len(texts)) + for _, emb := range apiResp.Output.Embeddings { + if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) { + embeddings[emb.TextIndex] = emb.Embedding + } + } + + usage := &embedding.TokenUsage{ + TotalTokens: apiResp.Usage.TotalTokens, + } + + g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens) + + return embeddings, usage, nil +} + +func (e *QwenEmbedder) GetType() string { + return getType() +} + +func (e *QwenEmbedder) IsCallbacksEnabled() bool { + return true +} + +func getType() string { + return "Qwen" +} + +func dereferenceOrZero[T any](v *T) T { + if v == nil { + var t T + return t + } + return *v +} diff --git a/common/eino/priority_enum.go b/common/eino/priority_enum.go new file mode 100644 index 0000000..371706b --- /dev/null +++ b/common/eino/priority_enum.go @@ -0,0 +1,11 @@ +package eino + +// TaskPriority 任务优先级 +type TaskPriority string + +const ( + TaskPriorityLow TaskPriority = "low" // 低优先级 + TaskPriorityMedium TaskPriority = "medium" // 中优先级 + TaskPriorityHigh TaskPriority = "high" // 高优先级 + TaskPriorityUrgent TaskPriority = "urgent" // 紧急 +) diff --git a/common/eino/status_enum.go b/common/eino/status_enum.go new file mode 100644 index 0000000..6e12daf --- /dev/null +++ b/common/eino/status_enum.go @@ -0,0 +1,12 @@ +package eino + +// TaskStatus 任务状态 +type TaskStatus string + +const ( + TaskStatusPending TaskStatus = "pending" // 待处理 + TaskStatusRunning TaskStatus = "running" // 运行中 + TaskStatusCompleted TaskStatus = "completed" // 已完成 + TaskStatusFailed TaskStatus = "failed" // 失败 + TaskStatusCancelled TaskStatus = "cancelled" // 已取消 +) diff --git a/common/eino/task_type.go b/common/eino/task_type.go new file mode 100644 index 0000000..0ba5a64 --- /dev/null +++ b/common/eino/task_type.go @@ -0,0 +1,14 @@ +package eino + +// TaskType 任务类型 +type TaskType string + +const ( + TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务 + TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务 + TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务 + TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务 + TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务 + TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务 + TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务 +) diff --git a/common/gse/utils.go b/common/gse/utils.go new file mode 100644 index 0000000..aea4b38 --- /dev/null +++ b/common/gse/utils.go @@ -0,0 +1,114 @@ +package gse + +import ( + "context" + "sort" + + "github.com/go-ego/gse" + "github.com/go-ego/gse/hmm/extracker" + "github.com/go-ego/gse/hmm/segment" + "github.com/gogf/gf/v2/os/glog" +) + +var GseTool *gseTool + +// 初始化函数:程序启动时执行一次 +func init() { + var err error + GseTool, err = newGseTool() + if err != nil { + glog.Error(context.Background(), err) + } +} + +// gseTool 关键词提取工具(gse v1.0.2 标准) +type gseTool struct { + seg gse.Segmenter + tfidf *extracker.TagExtracter + tr *extracker.TextRanker +} + +// newGseTool 初始化工具(内置词典 + 停用词) +func newGseTool() (tool *gseTool, err error) { + // 1. 初始化分词器 + var seg gse.Segmenter + // 内置词典(无外部文件) + err = seg.LoadDictEmbed() + if err != nil { + return + } + // 内置停用词(v1.0.2 标准) + err = seg.LoadStopEmbed() + if err != nil { + return + } + + // 2. 初始化 TF-IDF 提取器 + tfidf := &extracker.TagExtracter{} + tfidf.WithGse(seg) + err = tfidf.LoadIdf() + if err != nil { + return + } + + // 3. 初始化 TextRank 提取器 + tr := &extracker.TextRanker{} + tr.WithGse(seg) + + tool = &gseTool{ + seg: seg, + tfidf: tfidf, + tr: tr, + } + return +} + +// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM) +func (k *gseTool) Cut(text string) []string { + return k.seg.Cut(text, true) +} + +// Keyword 最终输出:关键词 + 权重 +type Keyword struct { + Word string `json:"word"` + Score float64 `json:"score"` +} + +func (k *gseTool) Extract(text string, topN int) []Keyword { + // 1. 提取 TF-IDF + tfTags := k.extractTFIDF(text, topN) + + // 2. 提取 TextRank + trTags := k.extractTextRank(text, topN) + + // 3. 合并成最终关键词(业务最常用) + scoreMap := make(map[string]float64) + for _, tag := range tfTags { + scoreMap[tag.Text] = tag.Weight + } + for _, tag := range trTags { + scoreMap[tag.Text] = tag.Weight + } + + // 转成切片并排序(高分在前) + res := make([]Keyword, 0, len(scoreMap)) + for word, score := range scoreMap { + res = append(res, Keyword{Word: word, Score: score}) + } + + sort.Slice(res, func(i, j int) bool { + return res[i].Score > res[j].Score + }) + + return res +} + +// ExtractTFIDF TF-IDF 关键词(带权重)90% 业务:文章标签、搜索、关键词 +func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments { + return k.tfidf.ExtractTags(text, topN) +} + +// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解 +func (k *gseTool) extractTextRank(text string, topN int) segment.Segments { + return k.tr.TextRank(text, topN) +} diff --git a/controller/dataset_index.go b/controller/dataset_index.go deleted file mode 100644 index e89dc16..0000000 --- a/controller/dataset_index.go +++ /dev/null @@ -1,5 +0,0 @@ -package controller - -type datasetIndex struct{} - -var DatasetIndex = new(datasetIndex) diff --git a/go.mod b/go.mod index 960be24..dd457bd 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,29 @@ module rag go 1.26.0 require ( - gitea.com/red-future/common v0.0.6 + gitea.com/red-future/common v0.0.11 github.com/bjang03/gmq v0.0.0-00010101000000-000000000000 github.com/cloudwego/eino v0.8.6 + github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1 + github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 + github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 + github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 + github.com/elastic/go-elasticsearch/v8 v8.16.0 + github.com/go-ego/gse v1.0.2 github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 github.com/gogf/gf/v2 v2.10.0 + github.com/golang/glog v1.2.5 github.com/pgvector/pgvector-go v0.3.0 ) -replace gitea.com/red-future/common v0.0.6 => ../common +replace gitea.com/red-future/common v0.0.11 => ../common replace github.com/bjang03/gmq => ../gmq @@ -35,18 +49,7 @@ require ( github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419 // indirect github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect - github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1 // indirect - github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 // indirect - github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect - github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect github.com/dgraph-io/badger/v4 v4.2.0 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect @@ -56,13 +59,11 @@ require ( github.com/eino-contrib/docx2md v0.0.1 // indirect github.com/eino-contrib/jsonschema v1.0.3 // indirect github.com/elastic/elastic-transport-go/v8 v8.10.0 // indirect - github.com/elastic/go-elasticsearch/v8 v8.16.0 // indirect github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect github.com/evanphx/json-patch v0.5.2 // indirect github.com/fatih/color v1.19.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect - github.com/go-ego/gse v1.0.2 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -74,7 +75,6 @@ require ( github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v5 v5.3.1 // indirect - github.com/golang/glog v1.2.5 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v1.0.0 // indirect diff --git a/go.sum b/go.sum index e7369ed..d31f601 100644 --- a/go.sum +++ b/go.sum @@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ= entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM= +gitea.com/red-future/common v0.0.11 h1:AV7W3G0uZ8aPpHHSHd4ZHmLWe5+2STPKe/AYPoPCWVc= +gitea.com/red-future/common v0.0.11/go.mod h1:B8syUI4XbLCDQSeRHURYxEwnWw8mEFgmqCxjC+lM+NU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= diff --git a/model/dto/dataset_index.go b/model/dto/dataset_index.go deleted file mode 100644 index 76d3a17..0000000 --- a/model/dto/dataset_index.go +++ /dev/null @@ -1 +0,0 @@ -package dto diff --git a/service/dataset_index.go b/service/dataset_index.go deleted file mode 100644 index 5ebe407..0000000 --- a/service/dataset_index.go +++ /dev/null @@ -1,5 +0,0 @@ -package service - -var DatasetIndex = new(datasetIndexService) - -type datasetIndexService struct{} diff --git a/service/document.go b/service/document.go index 48d7d11..47af42f 100644 --- a/service/document.go +++ b/service/document.go @@ -3,6 +3,8 @@ package service import ( "context" "fmt" + "rag/common/eino" + "rag/common/gse" "rag/consts/document" "rag/consts/public" "rag/dao" @@ -16,8 +18,6 @@ import ( "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/full-text-search/meilisearch" "gitea.com/red-future/common/http" - "gitea.com/red-future/common/rag/eino" - "gitea.com/red-future/common/rag/gse" "gitea.com/red-future/common/utils" gmq "github.com/bjang03/gmq/core/gmq" "github.com/bjang03/gmq/mq" @@ -251,7 +251,7 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu return } // 3. 组装向量文档 - var vectorDocs = make([]dto.VectorDocumentChunkMsg, 0) + var docsChunk = make([]*schema.Document, 0) for i, t := range docsSplit { contentHash := gmd5.MustEncryptString(t.Content) // 检查是否重复 @@ -263,27 +263,26 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu if !success { continue } - vectorDocs = append(vectorDocs, dto.VectorDocumentChunkMsg{ - TenantId: doc.TenantId, - Creator: doc.Creator, - DatasetId: doc.DatasetId, - DocumentId: doc.Id, - Content: t.Content, - ContentHash: contentHash, - ChunkIndex: gconv.Int64(i), - }) - + var metaData = make(map[string]any) + metaData[entity.DocumentCol.TenantId] = doc.TenantId + metaData[entity.DocumentCol.Creator] = doc.Creator + metaData[entity.DocumentCol.DatasetId] = doc.DatasetId + metaData[entity.DocumentChunkCol.DocumentId] = doc.Id + metaData[entity.DocumentChunkCol.ContentHash] = contentHash + metaData[entity.DocumentChunkCol.ChunkIndex] = gconv.Int64(i) + t.MetaData = metaData + docsChunk = append(docsChunk, t) } // 4. 发送消息到队列 - if len(vectorDocs) > 0 { + if len(docsChunk) > 0 { err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{ PubMessage: types.PubMessage{ Topic: public.KnowledgeDocumentChunkTopic, - Data: vectorDocs, + Data: docsChunk, }, }) } - vectorDocsCount = gconv.Int64(len(vectorDocs)) + vectorDocsCount = gconv.Int64(len(docsChunk)) return } @@ -318,12 +317,12 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum } // 构建Meilisearch文档 meiliDocs = append(meiliDocs, map[string]interface{}{ - "id": contentHash, - "datasetId": doc.DatasetId, - "documentId": doc.Id, - "content": t.Content, - "contentHash": contentHash, - "chunkIndex": i, + entity.DocumentChunkCol.Id: contentHash, + entity.DocumentChunkCol.DatasetId: doc.DatasetId, + entity.DocumentChunkCol.DocumentId: doc.Id, + entity.DocumentChunkCol.Content: t.Content, + entity.DocumentChunkCol.ContentHash: contentHash, + entity.DocumentChunkCol.ChunkIndex: i, }) } // 4. 写入到meilisearch数据库中 diff --git a/service/document_chunk.go b/service/document_chunk.go index 4e207de..103308d 100644 --- a/service/document_chunk.go +++ b/service/document_chunk.go @@ -2,23 +2,20 @@ package service import ( "context" - "database/sql" - "errors" - "fmt" + "rag/common/eino" "rag/consts/document" "rag/consts/public" "rag/dao" "rag/model/dto" "rag/model/entity" - "gitea.com/red-future/common/beans" - "gitea.com/red-future/common/rag/eino" gmq "github.com/bjang03/gmq/core/gmq" "github.com/bjang03/gmq/mq" "github.com/bjang03/gmq/types" + "github.com/cloudwego/eino/components/indexer" + "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" - "github.com/pgvector/pgvector-go" ) var DocumentChunk = new(documentChunkService) @@ -49,114 +46,124 @@ func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentCh } func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) { - var req = make([]*dto.VectorDocumentChunkMsg, 0) + var docs = make([]*schema.Document, 0) msgMap := gconv.Map(msg) - if err = gconv.Structs(msgMap["data"], &req); err != nil { + if err = gconv.Structs(msgMap["data"], &docs); err != nil { g.Log().Error(ctx, "DocsChunkMsg err:", err) return } - if len(req) == 0 { + if len(docs) == 0 { g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty") return } - ctx = context.WithValue(ctx, "user", &beans.User{ - TenantId: req[0].TenantId, - UserName: req[0].Creator, - }) + //ctx = context.WithValue(ctx, "user", &beans.User{ + // TenantId: req[0].TenantId, + // UserName: req[0].Creator, + //}) // 调用eino接口获取向量 - var vectorDocsStr = make([]string, 0, len(req)) - for _, t := range req { - vectorDocsStr = append(vectorDocsStr, t.Content) - } - embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr) - if err != nil { - g.Log().Error(ctx, "DocsChunkMsg err:", err) - err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) - return - } + //var vectorDocsStr = make([]string, 0, len(req)) + //for _, t := range req { + // vectorDocsStr = append(vectorDocsStr, t.Content) + //} + //embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr) + //if err != nil { + // g.Log().Error(ctx, "DocsChunkMsg err:", err) + // err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) + // return + //} // 获取向量维度 - dimension := 0 - if len(embeddings) > 0 { - dimension = len(embeddings[0]) - } + //dimension := 0 + //if len(embeddings) > 0 { + // dimension = len(embeddings[0]) + //} // 创建或更新DatasetIndex - err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req))) - if err != nil { - g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err) - err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) - return - } + //err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req))) + //if err != nil { + // g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err) + // err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) + // return + //} // 更新向量文档 - for i, embedding := range embeddings { - req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding)) - req[i].VectorStatus = document.VectorStatusCompleted.Code() - req[i].Status = document.StatusEnable.Code() - } - _, err = dao.DocumentChunk.BatchInsert(ctx, req) - if err != nil { - g.Log().Error(ctx, "DocsChunkMsg err:", err) - err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) + //for i, embedding := range embeddings { + // req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding)) + // req[i].VectorStatus = document.VectorStatusCompleted.Code() + // req[i].Status = document.StatusEnable.Code() + //} + //_, err = dao.DocumentChunk.BatchInsert(ctx, req) + //if err != nil { + // g.Log().Error(ctx, "DocsChunkMsg err:", err) + // err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code()) + // return + //} + idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{ + BatchSize: 10, + }) + rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope)) + if err != nil || rows == 0 { + g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err) return } - - err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusCompleted.Code()) + tenantId := docs[0].MetaData[entity.DocumentChunkCol.TenantId].(uint64) + creator := docs[0].MetaData[entity.DocumentChunkCol.Creator].(string) + documentId := docs[0].MetaData[entity.DocumentChunkCol.DocumentId].(int64) + err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code()) return } -// createOrUpdateDatasetIndex 创建或更新数据集索引 -func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) { - // 查询数据集是否已有索引 - existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return err - } - - // 已有索引 → 只更新数量 - if existIndex != nil { - _ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount) - return nil - } - - // ====================== 创建新索引 ====================== - indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名 - // 1. 插入索引配置 - index := &entity.DatasetIndex{ - DatasetId: datasetId, - Name: indexName, - Dimension: dimension, - FieldType: "float", - MetricType: "COSINE", - Status: gconv.PtrInt8(1), - VectorCount: vectorCount, - Description: fmt.Sprintf("数据集%d向量索引", datasetId), - } - _, err = dao.DatasetIndex.Insert(ctx, index) - if err != nil { - return err - } - - // 2. 真正创建 PGVector 索引(唯一真实索引!) - err = s.createRealPGVectorIndex(ctx, indexName) - return err -} - -// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用) -func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error { - // 执行真实建索引语句 - err := dao.DatasetIndex.InsertIndex(ctx, indexName) - if err != nil { - g.Log().Error(ctx, "创建向量索引失败:", err) - return err - } - g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName) - return nil -} +//// createOrUpdateDatasetIndex 创建或更新数据集索引 +//func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) { +// // 查询数据集是否已有索引 +// existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId) +// if err != nil && !errors.Is(err, sql.ErrNoRows) { +// return err +// } +// +// // 已有索引 → 只更新数量 +// if existIndex != nil { +// _ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount) +// return nil +// } +// +// // ====================== 创建新索引 ====================== +// indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名 +// // 1. 插入索引配置 +// index := &entity.DatasetIndex{ +// DatasetId: datasetId, +// Name: indexName, +// Dimension: dimension, +// FieldType: "float", +// MetricType: "COSINE", +// Status: gconv.PtrInt8(1), +// VectorCount: vectorCount, +// Description: fmt.Sprintf("数据集%d向量索引", datasetId), +// } +// _, err = dao.DatasetIndex.Insert(ctx, index) +// if err != nil { +// return err +// } +// +// // 2. 真正创建 PGVector 索引(唯一真实索引!) +// err = s.createRealPGVectorIndex(ctx, indexName) +// return err +//} +// +//// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用) +//func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error { +// // 执行真实建索引语句 +// err := dao.DatasetIndex.InsertIndex(ctx, indexName) +// if err != nil { +// g.Log().Error(ctx, "创建向量索引失败:", err) +// return err +// } +// g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName) +// return nil +//} // publishKnowledgeDocumentMsg 发布消息 func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {